Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def _configure_basic_optimizer(self, model_parameters):
param_groups = []
if muon_params:
accepted_parameters = dict()
for key in ["lr", "momentum", "weight_decay", "muon_lr"]:
for key in ["lr", "momentum", "weight_decay", "muon_lr", "ns_method"]:
if key in optimizer_parameters:
if key == "muon_lr": # muon_lr will override lr
accepted_parameters['lr'] = optimizer_parameters[key]
Expand Down
129 changes: 116 additions & 13 deletions deepspeed/runtime/zero/muon/original_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
from deepspeed.runtime import compiler
from deepspeed.accelerator import get_accelerator


@compiler.compile()
Expand All @@ -45,7 +46,9 @@ def zeropower_via_newtonschulz5(G, steps: int):
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
# Use bf16 when hardware supports it; fp32 otherwise
compute_dtype = torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32
X = G.to(compute_dtype)
if G.size(-2) > G.size(-1):
X = X.mT

Expand All @@ -63,13 +66,93 @@ def zeropower_via_newtonschulz5(G, steps: int):


@compiler.compile()
def zeropower_via_gram_newtonschulz(G, steps: int):
    """
    Gram Newton-Schulz iteration for orthogonalization.

    Mathematically equivalent to standard Newton-Schulz but iterates on the
    small square Gram matrix R = X @ X.T (n x n) instead of the full rectangular
    X (n x m). This reduces FLOPs significantly when m >> n (typical for
    transformer weight matrices with aspect ratio ~5).

    Uses fp16 instead of bf16 for better numerical precision at the same
    compute cost (fp32 fallback when fp16 is unsupported). Includes a restart
    at iteration 2 to maintain stability in half-precision.

    Falls back to standard Newton-Schulz for square matrices (n == m)
    where there is no FLOP advantage.

    NOTE(review): the Gram branch uses ``torch.addmm``, which only accepts
    2-D matrices, so batched (ndim > 2) rectangular inputs are not supported
    on that path even though the top-level assert allows ndim >= 2.

    Reference: https://tridao.me/blog/2026/gram-newton-schulz/
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    # Use fp16 for better precision than bf16 when hardware supports it; fp32 otherwise
    compute_dtype = torch.float16 if get_accelerator().is_fp16_supported() else torch.float32
    X = G.to(compute_dtype)
    if G.size(-2) > G.size(-1):
        X = X.mT

    n, m = X.size(-2), X.size(-1)

    # Normalize so the iteration starts inside the Newton-Schulz convergence region.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # For square matrices, no FLOP advantage; use standard iteration
    if m <= n:
        for _ in range(steps):
            A = X @ X.mT
            B = b * A + c * A @ A
            X = a * X + B @ X
        if G.size(-2) > G.size(-1):
            X = X.mT
        return X

    # With no iterations requested, Q would never be formed below and the
    # final `Q @ X` would fail; mirror the square branch and return the
    # normalized input unchanged.
    if steps <= 0:
        if G.size(-2) > G.size(-1):
            X = X.mT
        return X

    # Gram NS: iterate on R = X @ X.T (n x n) instead of X (n x m).
    # Q accumulates the product (a*I + Z_k) ... (a*I + Z_0) so that the
    # orthogonalized result is Q @ X.
    R = X @ X.mT
    Q = None
    restart_at = 2

    for i in range(steps):
        if i == restart_at and i != 0:
            # Restart: fold the accumulated polynomial into X and re-form R
            # to limit half-precision error growth.
            X = Q @ X
            R = X @ X.mT
            Q = None

        Z = b * R + c * R @ R

        if Q is None:
            # Q = a*I + Z. Add on the matrix diagonal explicitly via the last
            # two dims (the default diagonal dims (0, 1) are only correct for
            # plain 2-D inputs).
            Q = Z.clone()
            Q.diagonal(dim1=-2, dim2=-1).add_(a)
        else:
            # Q = a*Q + Z @ Q  ==  (a*I + Z) @ Q
            Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0)

        # R_{k+1} = (a*I + Z)^2 @ R (valid because Z is a polynomial in R and
        # therefore commutes with it). Skip on the last step and when the next
        # iteration restarts, since R is re-formed there anyway.
        if i < steps - 1 and (i + 1) != restart_at:
            RZ = torch.addmm(R, Z, R, beta=a, alpha=1.0)
            R = torch.addmm(RZ, Z, RZ, beta=a, alpha=1.0)

    if G.size(-2) > G.size(-1):
        # Undo the initial transpose: (Q @ X).mT == X.mT @ Q.mT
        X = X.mT @ Q.mT
    else:
        X = Q @ X
    return X


# Valid values for the `ns_method` option (see `muon_update`):
# "standard" = original Newton-Schulz-5 iteration, "gram" = Gram-matrix variant.
NS_METHODS = {"standard", "gram"}


@compiler.compile()
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"):
    """
    Compute a single Muon update from a gradient and its momentum buffer.

    Args:
        grad: Gradient tensor (ndim >= 2; 4-D conv filters are flattened).
        momentum: Momentum buffer, updated in place via lerp.
        beta: Momentum coefficient.
        ns_steps: Number of Newton-Schulz iterations.
        nesterov: If True, apply Nesterov-style momentum (also mutates `grad`
            in place via `lerp_`).
        ns_method: "gram" for Gram Newton-Schulz, anything else falls back to
            the standard Newton-Schulz-5 iteration.

    Returns:
        The orthogonalized, scale-corrected update, cast back to `grad`'s dtype.
    """
    orig_dtype = grad.dtype
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    # Exactly one orthogonalization pass, selected by ns_method.
    if ns_method == "gram":
        update = zeropower_via_gram_newtonschulz(update, steps=ns_steps)
    else:
        update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    # Aspect-ratio correction keeps the spectral-norm scale of tall matrices consistent.
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    if update.dtype != orig_dtype:
        update = update.to(orig_dtype)
    return update


Expand All @@ -93,10 +176,12 @@ class Muon(torch.optim.Optimizer):
lr: The learning rate, in units of spectral norm per update.
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
ns_method: Newton-Schulz method. "gram" (default) uses Gram NS for ~2x speedup
on rectangular matrices. "standard" uses the original iteration.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
params = sorted(params, key=lambda x: x.size(), reverse=True)
super().__init__(params, defaults)
Expand All @@ -122,7 +207,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand All @@ -136,8 +224,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
Muon variant for usage in non-distributed settings.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method)
super().__init__(params, defaults)

@torch.no_grad()
Expand All @@ -156,7 +244,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])

Expand Down Expand Up @@ -208,7 +299,10 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group[
"ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}"
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand Down Expand Up @@ -240,7 +334,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand Down Expand Up @@ -277,7 +374,10 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group[
"ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}"
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand All @@ -304,7 +404,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
if self.use_muon:
self.sub_groups_using_muon = []
self.muon_beta = None
self.muon_ns_method = None
for idx, param_group in enumerate(fp16_param_groups):
if getattr(param_group['params'][0], 'use_muon', False):
self.sub_groups_using_muon.extend([True] * len(param_groups[idx]))
Expand All @@ -799,6 +800,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
f"Found {self.muon_beta} and {group_beta}.")
self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve per-group ns_method in ZeRO-3 Muon updates

The ZeRO-3 setup stores ns_method in a single optimizer-wide field that is overwritten for each Muon param group, and _apply_distributed_muon_update later uses that single value for all Muon subgroups. If a user configures multiple use_muon=True groups with different ns_method values, earlier groups silently run with the last group's method, producing incorrect optimizer behavior and invalid experiment comparisons. This should either enforce one shared ns_method (like momentum) or track/apply ns_method per subgroup.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ns_method is actually determined by a single `ns_method` field in the JSON config, so it cannot diverge across Muon parameter groups.

else:
self.sub_groups_using_muon.extend([False] * len(param_groups[idx]))
# bookkeeping related to param groups
Expand Down Expand Up @@ -1547,7 +1549,7 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
param = params[base_i + rank]
g = param.grad
m = gathered_momentums_pad[base_i + rank]
update = muon_update(g, m, beta=self.muon_beta)
update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram'))
g.data.copy_(update, non_blocking=False)
grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz],
grads_pad[base_i + rank],
Expand Down
6 changes: 5 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1995,7 +1995,11 @@ def get_flat_partition(self,
assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).view(tensor.size())
grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram')
grad_accum = muon_update(grad_accum,
buffer,
self.optimizer.param_groups[param_group_idx]['momentum'],
ns_method=ns_method)
tensor = grad_accum
num_elements = tensor.numel()
buffer_idx += num_elements
Expand Down
14 changes: 13 additions & 1 deletion docs/_pages/config-json.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ toc_label: "Contents"

Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.

Muon supports the following parameters:

| "params" key | Description | Default |
| -------------- | -------------------------------------------------------------------------------------------------------------------- | --------- |
| lr | Learning rate for all parameters. Overridden by `muon_lr` / `adam_lr` if set. | 0.001 |
| momentum | Momentum coefficient for the Muon update. | 0.95 |
| weight\_decay | Weight decay (AdamW-style). | 0.0 |
| muon\_lr | Learning rate override for Muon parameters. Defaults to `lr` if not set. | - |
| adam\_lr | Learning rate override for non-Muon (Adam) parameters. Defaults to `lr` if not set. | - |
| ns\_method | Newton-Schulz orthogonalization method: `"gram"` for Gram NS (~2x faster on rectangular matrices), `"standard"` for the original iteration. Use `"standard"` to fall back if you encounter convergence issues. | `"gram"` |

Example of <i>**optimizer**</i> with Adam

```json
Expand Down Expand Up @@ -73,7 +84,8 @@ If not set, muon_lr will default to lr.
"lr": 0.001,
"momentum": 0.9,
"weight_decay": 0.0,
"muon_lr": 0.001
"muon_lr": 0.001,
"ns_method": "gram"
}
},
"zero_optimization": {
Expand Down
91 changes: 91 additions & 0 deletions tests/unit/ops/muon/test_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,94 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optim
after_training = [p.clone().cpu() for p in model.parameters()]
for initial, final in zip(initial_params, after_training):
assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"


class TestGramNewtonSchulz(DistributedTest):
"""Test Gram Newton-Schulz integration with Muon optimizer."""

world_size = 2
reuse_dist_env = True

    @pytest.mark.parametrize('ns_method', ['gram', 'standard'])
    @pytest.mark.parametrize('zero_stage', [1, 2])
    def test_ns_method_training(self, ns_method, zero_stage):
        """Verify both ns_method values work end-to-end with DeepSpeed."""
        # Small model and batch keep this distributed test fast.
        hidden_dim = 64
        batch_size = 8
        config_dict = {
            "train_batch_size": batch_size,
            "optimizer": {
                "type": "muon",
                "params": {
                    "lr": 0.01,
                    # The option under test: Newton-Schulz method selection.
                    "ns_method": ns_method,
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
            },
            "zero_optimization": {
                "stage": zero_stage,
                "reduce_scatter": False,
            },
        }

        model = SimpleModel(hidden_dim=hidden_dim, nlayers=3)
        # Snapshot parameters before training so updates can be detected below.
        initial_params = [p.clone().cpu() for p in model.parameters()]
        engine, optimizer, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters(),
            dist_init_required=False,
        )

        for _ in range(3):
            x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
            y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
            loss = engine(x, y)
            engine.backward(loss)
            engine.step()

        after_training = [p.clone().cpu() for p in model.parameters()]
        # Weaker than checking exact values, but enough to confirm the Muon
        # path executed: every parameter must have changed after 3 steps.
        for initial, final in zip(initial_params, after_training):
            assert not torch.equal(initial, final), "Parameters should have been updated"

@pytest.mark.parametrize('ns_method', ['gram', 'standard'])
def test_ns_method_stage3(self, ns_method):
"""Verify ns_method works with ZeRO Stage 3."""
hidden_dim = 64
batch_size = 8
config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": "muon",
"params": {
"lr": 0.01,
"ns_method": ns_method,
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True,
},
"zero_optimization": {
"stage": 3,
"reduce_scatter": False,
},
}

model = SimpleModel(hidden_dim=hidden_dim, nlayers=3)
engine, optimizer, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)

for _ in range(3):
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step()
Loading