diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index 632c0b8782..d90c7cd5e9 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -9,11 +9,9 @@ Usage: # Just inspect keys: python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only # Convert to PyTorch: python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output Example: # pi0_droid @@ -44,7 +42,6 @@ import openpi.models.pi0_config import openpi.models_pytorch.pi0_pytorch from openpi.training import utils -import openpi.training.config as _config def slice_paligemma_state_dict(state_dict, config): @@ -419,9 +416,7 @@ def load_jax_model_and_print_keys(checkpoint_dir: str): print(utils.array_tree_to_info(metadata)) -def convert_pi0_checkpoint( - checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config -): +def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str): """ Convert PI0 JAX checkpoint to PyTorch format. @@ -429,14 +424,25 @@ def convert_pi0_checkpoint( checkpoint_dir: Path to the JAX checkpoint precision: Model precision (float32, bfloat16, float16) output_path: Path to save the converted PyTorch model - model_config: Model config """ print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}") - print(f"Model config: {model_config}") # Break down orbax ckpts by restoring via JAX to respect dtype initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32") + # Detect pi05 from checkpoint contents: pi05 uses time_mlp_in/out instead of state_proj + is_pi05 = "time_mlp_in" in initial_params["projection_params"] + print(f"Auto-detected model type: {'Pi0.5' if is_pi05 else 'Pi0'}") + + # Infer action_dim from action_in_proj kernel shape: [action_dim, expert_hidden_size] + action_in_proj_kernel = initial_params["projection_params"]["action_in_proj"]["kernel"] + if isinstance(action_in_proj_kernel, dict): + action_in_proj_kernel = action_in_proj_kernel["value"] + action_dim = action_in_proj_kernel.shape[0] + print(f"Auto-detected action_dim: {action_dim}") + + model_config = openpi.models.pi0_config.Pi0Config(pi05=is_pi05, action_dim=action_dim) + # Process projection params if model_config.pi05: keys = [ @@ -557,7 +563,6 @@ def __init__(self): def main( checkpoint_dir: str, - config_name: str, output_path: str | None = None, precision: Literal["float32", "bfloat16", "float16"] = "bfloat16", *, @@ -571,16 +576,13 @@ def main( precision: Precision for model conversion inspect_only: Only inspect parameter keys, don't convert """ - model_config = _config.get_config(config_name).model - if not isinstance(model_config, openpi.models.pi0_config.Pi0Config): - raise ValueError(f"Config {config_name} is not a Pi0Config") if inspect_only: load_jax_model_and_print_keys(checkpoint_dir) else: if not output_path: print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.") return - convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config) + convert_pi0_checkpoint(checkpoint_dir, precision, output_path) if __name__ == "__main__":