Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions examples/convert_jax_model_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
Usage:
# Just inspect keys:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only

# Convert to PyTorch:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output

Example:
# pi0_droid
Expand Down Expand Up @@ -44,7 +42,6 @@
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config


def slice_paligemma_state_dict(state_dict, config):
Expand Down Expand Up @@ -419,24 +416,33 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
print(utils.array_tree_to_info(metadata))


def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str):
"""
Convert PI0 JAX checkpoint to PyTorch format.

Args:
checkpoint_dir: Path to the JAX checkpoint
precision: Model precision (float32, bfloat16, float16)
output_path: Path to save the converted PyTorch model
model_config: Model config
"""
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
print(f"Model config: {model_config}")

# Break down orbax ckpts by restoring via JAX to respect dtype
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")

# Detect pi05 from checkpoint contents: pi05 uses time_mlp_in/out instead of state_proj
is_pi05 = "time_mlp_in" in initial_params["projection_params"]
print(f"Auto-detected model type: {'Pi0.5' if is_pi05 else 'Pi0'}")

# Infer action_dim from action_in_proj kernel shape: [action_dim, expert_hidden_size]
action_in_proj_kernel = initial_params["projection_params"]["action_in_proj"]["kernel"]
if isinstance(action_in_proj_kernel, dict):
action_in_proj_kernel = action_in_proj_kernel["value"]
action_dim = action_in_proj_kernel.shape[0]
print(f"Auto-detected action_dim: {action_dim}")

model_config = openpi.models.pi0_config.Pi0Config(pi05=is_pi05, action_dim=action_dim)

# Process projection params
if model_config.pi05:
keys = [
Expand Down Expand Up @@ -557,7 +563,6 @@ def __init__(self):

def main(
checkpoint_dir: str,
config_name: str,
output_path: str | None = None,
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
*,
Expand All @@ -571,16 +576,13 @@ def main(
precision: Precision for model conversion
inspect_only: Only inspect parameter keys, don't convert
"""
model_config = _config.get_config(config_name).model
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
raise ValueError(f"Config {config_name} is not a Pi0Config")
if inspect_only:
load_jax_model_and_print_keys(checkpoint_dir)
else:
if not output_path:
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
return
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
convert_pi0_checkpoint(checkpoint_dir, precision, output_path)


if __name__ == "__main__":
Expand Down