Add Gram Newton-Schulz orthogonalization for Muon optimizer #7953
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4fe6a0b4cf
```python
        raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
                         f"Found {self.muon_beta} and {group_beta}.")
    self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')
```
Preserve per-group ns_method in ZeRO-3 Muon updates
The ZeRO-3 setup stores ns_method in a single optimizer-wide field that is overwritten for each Muon param group, and _apply_distributed_muon_update later uses that single value for all Muon subgroups. If a user configures multiple use_muon=True groups with different ns_method values, earlier groups silently run with the last group's method, producing incorrect optimizer behavior and invalid experiment comparisons. This should either enforce one shared ns_method (like momentum) or track/apply ns_method per subgroup.
ns_method is actually decided by the single ns_method field in the JSON config, so it cannot diverge across param groups.
Integrate Gram Newton-Schulz (Gram NS) as the default orthogonalization
method for Muon, with a configurable ns_method switch to fall back to
standard NS when needed (e.g., for debugging convergence issues).
Gram NS iterates on the small square Gram matrix R = X @ X.T (n x n)
instead of the full rectangular X (n x m), reducing FLOPs by ~50% for
typical transformer weight matrices (aspect ratio ~5). It uses fp16
instead of bf16 for better numerical precision at the same compute cost,
with a restart at iteration 2 for half-precision stability.
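The iteration described above can be sketched as follows. This is an illustrative fp32-only sketch, not the PR's implementation: the function name is hypothetical, the coefficients are the standard Muon quintic, and the fp16 restart-at-iteration-2 logic mentioned above is omitted.

```python
import torch

def gram_newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Orthogonalize G via Newton-Schulz on the Gram matrix R = X @ X.T (sketch)."""
    a, b, c = 3.4445, -4.7750, 2.0315   # Muon quintic coefficients
    X = G.float()
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.mT                        # keep the short dimension as rows (n <= m)
    X = X / (X.norm() + eps)
    R = X @ X.mT                        # small n x n Gram matrix
    Q = torch.eye(R.size(0), dtype=X.dtype, device=X.device)
    for _ in range(steps):
        P = b * R + c * (R @ R)
        P.diagonal().add_(a)            # P = a*I + b*R + c*R^2
        Q = P @ Q                       # accumulate the polynomial applied to X
        R = P @ R @ P                   # Gram matrix of the next iterate (P is symmetric)
    X = Q @ X                           # single rectangular matmul at the end
    if transposed:
        X = X.mT.contiguous()           # restore the original orientation, contiguously
    return X.to(G.dtype)
```

In exact arithmetic this matches standard Newton-Schulz with the same polynomial: X_{k+1} = P(R_k) X_k implies R_{k+1} = P(R_k) R_k P(R_k), so only square matmuls are needed inside the loop.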
Benchmark results on A100:
- (2048, 11059): 2.25x GPU speedup, 1.85x CPU speedup
- (3584, 19353): 2.07x GPU speedup, 1.35x CPU speedup
- Falls back to standard NS for square matrices (no FLOP advantage)
Usage: set ns_method in DeepSpeed config:
{"optimizer": {"type": "muon", "params": {"ns_method": "gram"}}}
Use "standard" to disable Gram NS and revert to original behavior.
Reference: https://arxiv.org/abs/2503.02022
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Both NS functions now query the accelerator to choose the compute dtype instead of hardcoding it: standard NS uses is_bf16_supported() to select bf16 vs fp32; Gram NS uses is_fp16_supported() to select fp16 vs fp32.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
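A hedged sketch of that selection logic. The capability checks are passed in as plain booleans here for illustration; in DeepSpeed they would come from the accelerator abstraction (get_accelerator().is_bf16_supported() / is_fp16_supported()), and the function name is hypothetical.

```python
import torch

def ns_compute_dtype(ns_method: str, bf16_ok: bool, fp16_ok: bool) -> torch.dtype:
    """Pick the Newton-Schulz compute dtype from accelerator capability flags (sketch)."""
    if ns_method == "gram":
        # Gram NS prefers fp16 for its extra mantissa bits at the same compute cost
        return torch.float16 if fp16_ok else torch.float32
    # standard NS keeps its original bf16 path
    return torch.bfloat16 if bf16_ok else torch.float32
```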
Gram Newton-Schulz produces non-contiguous tensors via .mT for tall weight matrices (e.g., gate_proj/up_proj in LLaMA). This caused downstream grad norm computation (g.data.double()) to be ~1.8x slower due to strided memory access, adding ~75ms to optimizer step time. Add .contiguous() to the Gram NS return path for tall matrices, and ensure muon_update casts back to the original gradient dtype (Gram NS uses fp16 internally while gradients are bf16).

Benchmark (Qwen2.5-3B, 2xA100, ZeRO-2, 3 runs avg):
- Before fix: 945.1ms/step (optimizer: 229.9ms)
- After fix: 936.6ms/step (optimizer: 204.3ms)
- Standard NS baseline: 1054.5ms/step
- Gram NS speedup: 10.4% -> 11.2%

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Replace (Q @ X).mT.contiguous() with X.mT @ Q.mT, which produces a contiguous result directly. cuBLAS handles transposed inputs natively via transpose flags, so the matmul cost is identical but the extra memcpy from .contiguous() is eliminated.

Benchmark (Qwen2.5-3B, 2xA100, ZeRO-2, 3 runs avg):
- Before: 936.6ms/step (backward: 628.4ms)
- After: 931.5ms/step (backward: 612.8ms)
- Speedup vs standard NS: 11.2% -> 11.7%

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
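The layout point can be checked directly (illustrative shapes, not the PR's code): .mT only swaps strides, so the transpose of a matmul result is a non-contiguous view, while writing the transposed product with a fresh matmul yields contiguous memory.

```python
import torch

Q = torch.randn(64, 64)     # square Newton-Schulz factor
X = torch.randn(64, 320)    # wide weight-shaped matrix

via_view = (Q @ X).mT       # .mT only swaps strides: a non-contiguous view
direct = X.mT @ Q.mT        # matmul writes a fresh, contiguous result

assert not via_view.is_contiguous()
assert direct.is_contiguous()
# mathematically identical: (Q @ X)^T == X^T @ Q^T
assert torch.allclose(via_view, direct, atol=1e-3)
```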
Replace separate scalar-multiply + matmul + add operations with single torch.addmm calls for the Q and R updates, reducing kernel launch overhead. Remove the torch.eye allocation by using diagonal().add_() instead.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
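A sketch of that fusion on an illustrative 16 x 16 Gram-like matrix (not the PR's exact code): torch.addmm(input, mat1, mat2, beta=..., alpha=...) computes beta*input + alpha*(mat1 @ mat2) in one kernel, and the identity term is added in place on the diagonal.

```python
import torch

a, b, c = 3.4445, -4.7750, 2.0315
R0 = torch.randn(16, 16)
R = R0 @ R0.mT / 16.0                 # symmetric PSD, like a Gram matrix

# Unfused reference: scalar-multiply + matmul + add, plus a torch.eye allocation
P_ref = a * torch.eye(16) + b * R + c * (R @ R)

# Fused: one addmm kernel for b*R + c*(R @ R), then an in-place diagonal add for a*I
P = torch.addmm(R, R, R, beta=b, alpha=c)
P.diagonal().add_(a)

assert torch.allclose(P, P_ref, atol=1e-4)
```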

Authors: @delock and @PKUWZP
Summary
Integrate Gram Newton-Schulz (Gram NS) as the default orthogonalization method for the Muon optimizer, with a configurable `ns_method` switch to fall back to the original iteration when needed. Based on the Gram Newton-Schulz method from https://tridao.me/blog/2026/gram-newton-schulz/
Motivation
Standard Newton-Schulz iterates on the full rectangular matrix X (n × m). Gram NS iterates on the much smaller Gram matrix R = X @ X.T (n × n), which is significantly cheaper when m >> n — the common case for transformer weight matrices (typical aspect ratio α ≈ 5).
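A rough FLOP count makes the saving concrete. This is an illustrative sketch, not the PR's code: `ns_flops` is a hypothetical helper that counts 2nmk FLOPs per (n x k)(k x m) matmul and ignores elementwise work.

```python
def ns_flops(n: int, m: int, steps: int = 5) -> tuple[int, int]:
    """Approximate FLOPs for 'steps' Newton-Schulz iterations on an n x m matrix."""
    # Standard NS per step: A = X @ X.T (2*n*n*m), A @ A (2*n^3), B @ X (2*n*n*m)
    standard = steps * (2 * n * n * m + 2 * n**3 + 2 * n * n * m)
    # Gram NS: R0 = X @ X.T once, four n x n matmuls per step
    # (R @ R, P @ Q, and P @ R @ P), plus one final Q @ X
    gram = 2 * n * n * m + steps * 8 * n**3 + 2 * n * n * m
    return standard, gram

s, g = ns_flops(2048, 11059)          # one of the benchmarked shapes, aspect ratio ~5.4
print(f"gram/standard FLOP ratio: {g / s:.2f}")   # prints 0.52
```

For square matrices the square-matmul loop costs more than it saves, consistent with the fallback to standard NS noted above.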
Changes
- Add zeropower_via_gram_newtonschulz in original_muon.py with fp16 compute (better precision than bf16 at the same cost) and a restart at iteration 2 for half-precision stability
- Add an ns_method parameter ("gram" | "standard") to muon_update and all Muon optimizer classes
- Plumb ns_method through the ZeRO Stage 1/2/3 call sites and the DeepSpeed JSON config
Usage
Performance improvement:
