Fixes #870

2 changes: 2 additions & 0 deletions .gitignore
@@ -166,3 +166,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
108 changes: 106 additions & 2 deletions README.md
@@ -1,4 +1,108 @@
# openpi
# openpi with CUDA inference!

Making pi05 go 🚀🚀🚀🚀.

Tested on:
- Ubuntu 22.04
- CUDA 12.6
- Python 3.11.14
- A100 40GB GPU

## Installation

```
export CUDA_HOME=/usr/local/cuda-12.6
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```

## Setup

```
# Convert JAX weights -> PyTorch weights, then copy the assets next to the converted checkpoint
uv run python examples/convert_jax_model_to_pytorch.py --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch --config_name pi05_droid --precision float32
cp -r ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid/assets/ ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch/

# PyTorch hacks: overwrite files in the installed transformers package with the patched copies
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
```
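
The `cp` command above hardcodes `.venv/lib/python3.11/site-packages/transformers/`, which only matches a Python 3.11 virtual environment. As a minimal sketch, assuming you would rather look the directory up than hardcode it, the standard library can report where a package is installed. The helper names `package_dir` and `overlay_package` are hypothetical and are not part of this repository.

```python
import importlib.util
import shutil
from pathlib import Path


def package_dir(name: str) -> Path:
    """Return the on-disk directory of an installed top-level package.

    Hypothetical helper: find_spec locates a top-level package without
    importing it, so this works even if the package is heavy to import.
    """
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{name!r} is not an installed package")
    return Path(next(iter(spec.submodule_search_locations)))


def overlay_package(src: Path, package: str) -> Path:
    """Copy the tree under src over the installed package, like `cp -r src/* dest/`."""
    dest = package_dir(package)
    # dirs_exist_ok=True (Python 3.8+) merges into the existing directory.
    shutil.copytree(src, dest, dirs_exist_ok=True)
    return dest


# Example call (not executed here), run from the repository root inside the venv:
# overlay_package(Path("src/openpi/models_pytorch/transformers_replace"), "transformers")
```

Run it with `uv run python` so the lookup resolves against the project's virtual environment rather than a system interpreter.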

## Benchmarking
```
uv run scripts/benchmark.py
```
We also provide individual benchmarking and testing scripts in the `tests/` folder.

## Benchmark Results
### JAX Results
| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **inference_ms** | 92.9 | 50.9 | 79.8 | 80.0 | 81.0 | 90.9 | 102.2 | 272.0 |
| **policy_infer_ms** | 57.2 | 2.8 | 56.2 | 56.3 | 56.4 | 57.4 | 65.5 | 65.5 |

### PyTorch Results
| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **inference_ms** | 322.0 | 0.9 | 321.5 | 321.8 | 322.6 | 323.3 | 323.4 | 323.7 |
| **policy_infer_ms** | 317.3 | 0.5 | 317.0 | 317.3 | 317.6 | 317.9 | 318.2 | 318.2 |

### CUDA Results
| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **inference_ms** | 303.5 | 1.3 | 303.1 | 303.6 | 304.5 | 304.7 | 305.5 | 305.9 |
| **policy_infer_ms** | 298.7 | 1.2 | 298.6 | 298.8 | 299.2 | 299.7 | 300.0 | 301.0 |
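
For readers who want to reproduce the columns in the tables above, the following is a minimal sketch of one way to summarize a list of latency samples. It is not taken from `scripts/benchmark.py`: the function names are hypothetical, and the choices of linear-interpolation percentiles and population standard deviation are assumptions that may differ from what the benchmark script does.

```python
import math
import statistics


def percentile(sorted_values, q):
    """Linear-interpolation percentile (q in [0, 100]) of a pre-sorted list."""
    if not sorted_values:
        raise ValueError("need at least one sample")
    rank = (len(sorted_values) - 1) * q / 100.0
    lo = math.floor(rank)
    hi = math.ceil(rank)
    # Interpolate between the two neighbouring samples.
    return sorted_values[lo] + (sorted_values[hi] - sorted_values[lo]) * (rank - lo)


def summarize(samples_ms):
    """Return mean, std and P25..P99 for a list of latencies in milliseconds."""
    ordered = sorted(samples_ms)
    summary = {
        "mean": statistics.fmean(ordered),
        # Assumption: population std; a sample std would use statistics.stdev.
        "std": statistics.pstdev(ordered),
    }
    for q in (25, 50, 75, 90, 95, 99):
        summary[f"p{q}"] = percentile(ordered, q)
    return summary
```

A summary like this makes tail behaviour visible: in the JAX `inference_ms` row, the mean (92.9) sits well above the median (80.0) because a small number of slow calls pull the P99 up to 272.0.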

Comparing JAX vs PyTorch:
```
JAX vs PyTorch
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Action Min │ -0.724857 │
│ Action Max │ 0.966833 │
│ Action Mean │ 0.046649 │
│ │ │
│ Mean Absolute Diff │ 0.005507 │
│ Max Absolute Diff │ 0.183711 │
│ Median Absolute Diff │ 0.002901 │
│ │ │
│ % within 0.001 │ 24.83% │
│ % within 0.01 │ 86.12% │
│ % within 0.1 │ 99.96% │
│ % within 1.0 │ 100.00% │
└──────────────────────┴───────────┘
```

Comparing JAX vs CUDA:
```
JAX vs CUDA
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Action Min │ -0.724857 │
│ Action Max │ 0.966833 │
│ Action Mean │ 0.046649 │
│ │ │
│ Mean Absolute Diff │ 0.004960 │
│ Max Absolute Diff │ 0.192997 │
│ Median Absolute Diff │ 0.002596 │
│ │ │
│ % within 0.001 │ 27.29% │
│ % within 0.01 │ 87.54% │
│ % within 0.1 │ 99.96% │
│ % within 1.0 │ 100.00% │
└──────────────────────┴───────────┘
```
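
The rows in the two comparison tables can be computed from a pair of flattened action arrays. The sketch below shows one way to do that under stated assumptions: `compare_actions` is a hypothetical name, the inputs are plain Python sequences of floats, and "within" is taken to mean an absolute difference less than or equal to the tolerance. The comparison script used for this PR may define these differently.

```python
import statistics


def compare_actions(reference, candidate, tolerances=(0.001, 0.01, 0.1, 1.0)):
    """Summarize element-wise absolute differences between two action sequences."""
    if not reference or len(reference) != len(candidate):
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    report = {
        "mean_abs_diff": statistics.fmean(diffs),
        "max_abs_diff": max(diffs),
        "median_abs_diff": statistics.median(diffs),
    }
    for tol in tolerances:
        # Percentage of elements whose difference is at most tol (assumed inclusive).
        within = sum(d <= tol for d in diffs)
        report[f"pct_within_{tol}"] = 100.0 * within / len(diffs)
    return report
```

Reporting several tolerances side by side shows how the error is distributed: in both tables above, most elements agree to within 0.01, while the maximum absolute difference is close to 0.19.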

## TODO
- [ ] Still trying to figure out how to pin CUDA 12.6 in uv; pinning it currently produces cmake resolution errors. Without the pin, `uv sync` installs 12.6 on my machine by default.



<!-- # openpi

openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).

@@ -320,4 +424,4 @@ We will collect common issues and their solutions here. If you encounter an issu
| CUDA/GPU errors | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |
| Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. | -->
6 changes: 3 additions & 3 deletions examples/droid/main.py
@@ -72,9 +72,9 @@ def handler(signum, frame):

def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
    assert args.external_camera is not None and args.external_camera in ["left", "right"], (
        f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
    )

    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
8 changes: 4 additions & 4 deletions examples/libero/main.py
@@ -100,7 +100,7 @@ def eval_libero(args: Args) -> None:
t = 0
replay_images = []

logging.info(f"Starting episode {task_episodes+1}...")
logging.info(f"Starting episode {task_episodes + 1}...")
while t < max_steps + args.num_steps_wait:
try:
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
@@ -142,9 +142,9 @@ def eval_libero(args: Args) -> None:

# Query model to get action
action_chunk = client.infer(element)["actions"]
assert (
len(action_chunk) >= args.replan_steps
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
assert len(action_chunk) >= args.replan_steps, (
f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
)
action_plan.extend(action_chunk[: args.replan_steps])

action = action_plan.popleft()
14 changes: 12 additions & 2 deletions pyproject.toml
@@ -37,6 +37,7 @@ dependencies = [
"transformers==4.53.2",
"rich>=14.0.0",
"polars>=1.30.0",
"gpustat>=1.1.1",
]


@@ -66,6 +67,11 @@ override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
# torch = { index = "pytorch-cu126" }

# [[tool.uv.index]]
# name = "pytorch-cu126"
# url = "https://download.pytorch.org/whl/cu126"

[tool.uv.workspace]
members = ["packages/*"]
@@ -128,9 +134,13 @@ force-sort-within-sections = true
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
known-third-party = ["wandb"]

# [build-system]
# requires = ["hatchling"]
# build-backend = "hatchling.build"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
requires = ["setuptools", "torch==2.7.1"]
build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
markers = ["manual: should be run manually."]