
Add torch.compile support for actor and critic models#198

Draft
adenzler-nvidia wants to merge 8 commits into leggedrobotics:main from adenzler-nvidia:feature/torch-compile

Conversation

@adenzler-nvidia

Summary

  • Add opt-in torch.compile support for actor and critic models in OnPolicyRunner, controlled via torch_compile_mode config key (default: "default", set to null to disable)
  • ~1.2-1.4x total iteration speedup on CNN-based policies (Kuka-Allegro dexsuite with CNNModel); MLP-only policies see no benefit as models are too small for compilation overhead to pay off
  • Graceful fallback to eager mode if torch.compile fails
  • Clean serialization via _unwrap_compiled() helper that detects compiled models by state_dict key prefixes — no reliance on private class names
  • CUDA-graph-based compile modes (reduce-overhead, max-autotune) are blocked as they're incompatible with the two-model actor/critic pattern
  • Support CNNModel with zero 1D observation groups (perception-only configurations)

Closes #196

Test plan

  • Verify torch.compile speedup on CNN-based policy (dexsuite single_camera)
  • Verify MLP-only policy is unaffected
  • Verify save/load round-trip with compiled models produces clean state_dict keys
  • Verify JIT and ONNX export work with compiled models
  • Verify torch_compile_mode: null disables compilation
  • Verify graceful fallback when compile fails

🤖 Generated with Claude Code

adenzler-nvidia and others added 7 commits April 2, 2026 15:19
Wrap the actor with torch.compile in OnPolicyRunner, defaulting to
mode="default" for automatic Triton kernel fusion. Export methods
(JIT, ONNX, get_inference_policy) unwrap the compiled model so
scripting and tracing work unchanged. A one-time message after the
first iteration reminds users that compile overhead is amortized.
- Compile both actor and critic (was actor-only)
- Reorder act() to consume all actor distribution state before critic call
- Add cudagraph_mark_step_begin() in learning loop for future CUDA graph compat
- Swap uncompiled models for save/load to keep clean state_dict keys
- Reject CUDA-graph modes (reduce-overhead, max-autotune) with clear error:
  critic's graph replay invalidates actor graph output buffers
- Supported modes: default, max-autotune-no-cudagraphs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove cudagraph_mark_step_begin() call and duplicate comments from
the CUDA graph investigation — not needed since those modes are rejected.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reset ppo.py to main and re-apply only the torch-compile changes:
- Reorder act() to consume actor distribution state before critic
- Reorder learning loop similarly

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The actor/critic call reorder is not required for the supported compile
modes (default, max-autotune-no-cudagraphs). Remove to minimize diff.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…lback

Replace manual _uncompiled_actor/_uncompiled_critic reference tracking
with a stateless _unwrap_compiled() helper that detects compiled models
via state_dict key prefixes and unwraps them for serialization/export.

- Add _unwrap_compiled() that detects _orig_mod. prefix in state_dict
  keys and returns the inner module, with a clear error if PyTorch
  changes the internal API
- Remove _uncompiled_actor/_uncompiled_critic mutable state
- Add try/except around torch.compile with fallback to eager mode
- Simplify save/load/export to use _unwrap_compiled() uniformly

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Allow CNNModel to work when only 2D (image) observation groups are
provided, with no 1D (state) groups. This enables perception-only
actor/critic configurations.

- CNNModel.get_latent: skip 1D concatenation when self.obs_groups is empty
- MLPModel.update_normalization: skip when self.obs_groups is empty
- _TorchCNNModel/_OnnxCNNModel: same fix for JIT/ONNX export paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
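The perception-only path in this commit amounts to skipping the 1D concatenation when no state observation groups exist. A minimal sketch, with assumed names (`build_latent`, `obs_groups_1d`) that stand in for the repo's actual `CNNModel.get_latent` internals:

```python
import torch


def build_latent(cnn_features: torch.Tensor,
                 obs: dict[str, torch.Tensor],
                 obs_groups_1d: list[str]) -> torch.Tensor:
    """Combine CNN features with 1D state observations, if any exist.

    With an empty `obs_groups_1d` (perception-only config), the 1D
    concatenation is skipped and the CNN features pass through directly.
    """
    if not obs_groups_1d:
        return cnn_features
    state = torch.cat([obs[g] for g in obs_groups_1d], dim=-1)
    return torch.cat([cnn_features, state], dim=-1)
```

The same guard applies to the normalization update and to the `_TorchCNNModel`/`_OnnxCNNModel` export wrappers, since `torch.cat` over an empty list of 1D groups would otherwise raise.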