
Add Gram Newton-Schulz orthogonalization for Muon optimizer #7953

Open
delock wants to merge 7 commits into master from gma/gram_muon

Conversation


@delock delock commented Apr 3, 2026

Author: @delock and @PKUWZP

Summary

Integrate Gram Newton-Schulz (Gram NS) as the default orthogonalization method for the Muon optimizer, with a configurable ns_method switch to fall back to the original iteration when needed.

Based on the Gram Newton-Schulz method from https://tridao.me/blog/2026/gram-newton-schulz/

Motivation

Standard Newton-Schulz iterates on the full rectangular matrix X (n × m). Gram NS iterates on the much smaller Gram matrix R = X @ X.T (n × n), which is significantly cheaper when m >> n — the common case for transformer weight matrices (typical aspect ratio α ≈ 5).
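To illustrate why the Gram route works: for a wide matrix X = U S Vᵀ (n ≤ m), the Gram matrix G = X Xᵀ equals U S² Uᵀ, so G^{-1/2} X = U Vᵀ, exactly the orthogonalized update Muon wants, and all iteration work happens on an n × n matrix. A minimal sketch using the classic coupled Newton-Schulz inverse-square-root iteration (not the PR's exact polynomial, dtype handling, or restart logic):

```python
import torch

def orthogonalize_via_gram_ns(X: torch.Tensor, steps: int = 25) -> torch.Tensor:
    # Illustrative sketch only. For wide X = U S V^T (n <= m), the Gram
    # matrix G = X X^T equals U S^2 U^T, so G^{-1/2} X = U V^T.
    n, m = X.shape
    assert n <= m, "use standard NS for square/tall matrices"
    G = X @ X.T
    c = torch.linalg.matrix_norm(G, ord=2)  # spectral norm; scales eigenvalues into (0, 1]
    Y = G / c
    Z = torch.eye(n, dtype=X.dtype, device=X.device)
    for _ in range(steps):
        T = 1.5 * torch.eye(n, dtype=X.dtype, device=X.device) - 0.5 * (Z @ Y)
        Y = Y @ T   # Y -> (G/c)^{1/2}
        Z = T @ Z   # Z -> (G/c)^{-1/2}
    return (Z / torch.sqrt(c)) @ X  # approximately U V^T
```

Each iteration costs O(n³) instead of the O(n²m) of iterating on the full rectangular matrix, which is where the savings come from when m >> n.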

Changes

  • Add zeropower_via_gram_newtonschulz in original_muon.py with fp16 compute (better precision than bf16 at the same cost)
    and a restart at iteration 2 for half-precision stability
  • Add ns_method parameter ("gram" | "standard") to muon_update and all Muon optimizer classes
  • Thread ns_method through ZeRO Stage 1/2/3 call sites and DeepSpeed JSON config
  • Automatic fallback to standard NS for square matrices (m ≤ n) where Gram NS has no FLOP advantage
  • Documentation and unit tests for both methods across ZeRO Stage 1, 2, and 3
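The square-matrix fallback above can be sketched as a small dispatch helper (choose_ns_method is a hypothetical name for illustration, not the PR's actual API):

```python
def choose_ns_method(shape, requested="gram"):
    # Hypothetical dispatch mirroring the PR's described fallback:
    # Gram NS iterates on an n x n matrix, so it only saves FLOPs when
    # the weight is wide (m > n); otherwise use standard NS.
    n, m = shape
    if requested == "gram" and m <= n:
        return "standard"
    return requested
```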

Usage

"optimizer": {                                                                                                               
    "type": "Muon",                                                                                                          
    "params": {                                                                                                              
        "ns_method": "gram"                                                                                                  
    }                                                                                                                        
}                                                                                                                            
                                                                                                                             
Set "ns_method": "standard" to disable Gram NS and revert to original behavior (e.g., for debugging convergence issues).     

Performance improvement: [benchmark image]


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4fe6a0b4cf


        raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
                         f"Found {self.muon_beta} and {group_beta}.")
    self.muon_beta = group_beta
    self.muon_ns_method = param_group.get('ns_method', 'gram')

P2: Preserve per-group ns_method in ZeRO-3 Muon updates

The ZeRO-3 setup stores ns_method in a single optimizer-wide field that is overwritten for each Muon param group, and _apply_distributed_muon_update later uses that single value for all Muon subgroups. If a user configures multiple use_muon=True groups with different ns_method values, earlier groups silently run with the last group's method, producing incorrect optimizer behavior and invalid experiment comparisons. This should either enforce one shared ns_method (like momentum) or track/apply ns_method per subgroup.


delock (Collaborator, Author) replied:

ns_method is actually decided by the ns_method field in the JSON config, so it cannot diverge across parameter groups.

delock added 7 commits April 2, 2026 23:58
Integrate Gram Newton-Schulz (Gram NS) as the default orthogonalization
method for Muon, with a configurable ns_method switch to fall back to
standard NS when needed (e.g., for debugging convergence issues).

Gram NS iterates on the small square Gram matrix R = X @ X.T (n x n)
instead of the full rectangular X (n x m), reducing FLOPs by ~50% for
typical transformer weight matrices (aspect ratio ~5). It uses fp16
instead of bf16 for better numerical precision at the same compute cost,
with a restart at iteration 2 for half-precision stability.

Benchmark results on A100:
- (2048, 11059): 2.25x GPU speedup, 1.85x CPU speedup
- (3584, 19353): 2.07x GPU speedup, 1.35x CPU speedup
- Falls back to standard NS for square matrices (no FLOP advantage)

Usage: set ns_method in DeepSpeed config:
  {"optimizer": {"type": "muon", "params": {"ns_method": "gram"}}}
  Use "standard" to disable Gram NS and revert to original behavior.

Reference: https://arxiv.org/abs/2503.02022
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Both NS functions now query the accelerator to choose compute dtype
instead of hardcoding. Standard NS uses is_bf16_supported() to select
bf16 vs fp32; Gram NS uses is_fp16_supported() to select fp16 vs fp32.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
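The dtype selection described in this commit can be sketched portably; this uses plain torch.cuda capability checks as a stand-in for DeepSpeed's accelerator abstraction, and select_ns_dtype is a hypothetical name:

```python
import torch

def select_ns_dtype(ns_method: str) -> torch.dtype:
    # Stand-in for the accelerator query described in the commit:
    # Gram NS prefers fp16, standard NS prefers bf16, and both fall
    # back to fp32 when half precision is unavailable.
    if torch.cuda.is_available():
        if ns_method == "gram":
            return torch.float16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    return torch.float32  # conservative CPU fallback for this sketch
```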
Gram Newton-Schulz produces non-contiguous tensors via .mT for tall
weight matrices (e.g., gate_proj/up_proj in LLaMA). This caused
downstream grad norm computation (g.data.double()) to be ~1.8x slower
due to strided memory access, adding ~75ms to optimizer step time.

Add .contiguous() to the Gram NS return path for tall matrices, and
ensure muon_update casts back to the original gradient dtype (Gram NS
uses fp16 internally while gradients are bf16).

Benchmark (Qwen2.5-3B, 2xA100, ZeRO-2, 3 runs avg):
  Before fix: 945.1ms/step (optimizer: 229.9ms)
  After fix:  936.6ms/step (optimizer: 204.3ms)
  Standard NS baseline: 1054.5ms/step
  Gram NS speedup: 10.4% -> 11.2%

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Replace (Q @ X).mT.contiguous() with X.mT @ Q.mT which produces a
contiguous result directly. cuBLAS handles transposed inputs natively
via transpose flags, so the matmul cost is identical but the extra
memcpy from .contiguous() is eliminated.

Benchmark (Qwen2.5-3B, 2xA100, ZeRO-2, 3 runs avg):
  Before: 936.6ms/step (backward: 628.4ms)
  After:  931.5ms/step (backward: 612.8ms)
  Speedup vs standard NS: 11.2% -> 11.7%

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
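The claim in this commit is easy to verify: (Q @ X).mT is a strided transposed view, while X.mT @ Q.mT computes the same values into freshly allocated, contiguous storage:

```python
import torch

Q = torch.randn(4, 4)   # orthogonalizing transform (n x n)
X = torch.randn(4, 8)   # wide weight/gradient matrix (n x m)

a = (Q @ X).mT          # m x n, a transposed view with strided memory
b = X.mT @ Q.mT         # m x n, same values, written out contiguously

assert not a.is_contiguous() and b.is_contiguous()
assert torch.allclose(a, b, atol=1e-5)
```

cuBLAS consumes the transposed operands via transpose flags, so the matmul itself costs the same; only the extra copy from a later .contiguous() call is avoided.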
Replace separate scalar-multiply + matmul + add operations with single
torch.addmm calls for Q and R updates, reducing kernel launch overhead.
Remove torch.eye allocation by using diagonal().add_() instead.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
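A sketch of the two fusions described above; the coefficients are illustrative, not the actual Newton-Schulz polynomial constants:

```python
import torch

a, b = 3.0, -1.0
Q = torch.randn(4, 4)
R = torch.randn(4, 4)

# Unfused: scalar-multiply + matmul + add (three kernel launches)
ref = a * Q + b * (R @ Q)
# Fused: torch.addmm computes beta * input + alpha * (mat1 @ mat2) in one call
fused = torch.addmm(Q, R, Q, beta=a, alpha=b)
assert torch.allclose(ref, fused, atol=1e-5)

# Identity add without allocating torch.eye:
M = torch.randn(4, 4)
ref2 = M + torch.eye(4)
M.diagonal().add_(1.0)   # in-place, no temporary identity matrix
assert torch.allclose(M, ref2)
```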
@PKUWZP PKUWZP self-requested a review April 3, 2026 17:45

delock commented Apr 8, 2026

Convergence is almost identical between standard NS and Gram NS at the same learning rate. [convergence plot image]
