Add torch.compile support for actor and critic models#198
Draft
adenzler-nvidia wants to merge 8 commits intoleggedrobotics:mainfrom
Draft
Add torch.compile support for actor and critic models#198adenzler-nvidia wants to merge 8 commits intoleggedrobotics:mainfrom
adenzler-nvidia wants to merge 8 commits intoleggedrobotics:mainfrom
Conversation
Wrap the actor with torch.compile in OnPolicyRunner, defaulting to mode="default" for automatic Triton kernel fusion. Export methods (JIT, ONNX, get_inference_policy) unwrap the compiled model so scripting and tracing work unchanged. A one-time message after the first iteration reminds users that compile overhead is amortized.
- Compile both actor and critic (was actor-only) - Reorder act() to consume all actor distribution state before critic call - Add cudagraph_mark_step_begin() in learning loop for future CUDA graph compat - Swap uncompiled models for save/load to keep clean state_dict keys - Reject CUDA-graph modes (reduce-overhead, max-autotune) with clear error: critic's graph replay invalidates actor graph output buffers - Supported modes: default, max-autotune-no-cudagraphs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove cudagraph_mark_step_begin() call and duplicate comments from the CUDA graph investigation — not needed since those modes are rejected. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reset ppo.py to main and re-apply only the torch-compile changes: - Reorder act() to consume actor distribution state before critic - Reorder learning loop similarly Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The actor/critic call reorder is not required for the supported compile modes (default, max-autotune-no-cudagraphs). Remove to minimize diff. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…lback Replace manual _uncompiled_actor/_uncompiled_critic reference tracking with a stateless _unwrap_compiled() helper that detects compiled models via state_dict key prefixes and unwraps them for serialization/export. - Add _unwrap_compiled() that detects _orig_mod. prefix in state_dict keys and returns the inner module, with a clear error if PyTorch changes the internal API - Remove _uncompiled_actor/_uncompiled_critic mutable state - Add try/except around torch.compile with fallback to eager mode - Simplify save/load/export to use _unwrap_compiled() uniformly Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Allow CNNModel to work when only 2D (image) observation groups are provided, with no 1D (state) groups. This enables perception-only actor/critic configurations. - CNNModel.get_latent: skip 1D concatenation when self.obs_groups is empty - MLPModel.update_normalization: skip when self.obs_groups is empty - _TorchCNNModel/_OnnxCNNModel: same fix for JIT/ONNX export paths Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This reverts commit d465de3.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
torch.compilesupport for actor and critic models inOnPolicyRunner, controlled viatorch_compile_modeconfig key (default:"default", set tonullto disable)CNNModel); MLP-only policies see no benefit as models are too small for compilation overhead to pay offtorch.compilefails_unwrap_compiled()helper that detects compiled models bystate_dictkey prefixes — no reliance on private class namesreduce-overhead,max-autotune) are blocked as they're incompatible with the two-model actor/critic patternCNNModelwith zero 1D observation groups (perception-only configurations)Closes #196
Test plan
torch_compile_mode: nulldisables compilation🤖 Generated with Claude Code