Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def _configure_basic_optimizer(self, model_parameters):
param_groups = []
if muon_params:
accepted_parameters = dict()
for key in ["lr", "momentum", "weight_decay", "muon_lr"]:
for key in ["lr", "momentum", "weight_decay", "muon_lr", "ns_method"]:
if key in optimizer_parameters:
if key == "muon_lr": # muon_lr will override lr
accepted_parameters['lr'] = optimizer_parameters[key]
Expand Down
153 changes: 136 additions & 17 deletions deepspeed/runtime/zero/muon/original_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@

import torch
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
from deepspeed.runtime import compiler
from deepspeed.accelerator import get_accelerator


@compiler.compile()
def zeropower_via_newtonschulz5(G, steps: int):
def _zeropower_via_newtonschulz5(G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
Expand All @@ -45,7 +44,8 @@ def zeropower_via_newtonschulz5(G, steps: int):
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
compute_dtype = torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32
X = G.to(compute_dtype)
if G.size(-2) > G.size(-1):
X = X.mT

Expand All @@ -62,14 +62,91 @@ def zeropower_via_newtonschulz5(G, steps: int):
return X


@compiler.compile()
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
def _zeropower_via_gram_newtonschulz(G, steps: int):
    """
    Gram Newton-Schulz iteration for orthogonalization.

    Mathematically equivalent to standard Newton-Schulz but iterates on the
    small square Gram matrix R = X @ X.T (n x n) instead of the full
    rectangular X (n x m). This reduces FLOPs significantly when m >> n
    (typical for transformer weight matrices with aspect ratio ~5).

    Uses fp16 (when the accelerator supports it) instead of bf16 for better
    numerical precision at the same compute cost. Includes a restart at
    iteration 2 to maintain stability in half-precision.

    Falls back to standard Newton-Schulz for square matrices (n == m)
    where there is no FLOP advantage.

    Supports batched inputs (G.ndim > 2): every product below uses batched
    matmul, and the diagonal shift targets the last two dimensions. (The
    previous implementation used ``torch.addmm`` and ``Tensor.diagonal()``
    with default dims, both of which are 2-D only and broke batched Muon
    parameters.)

    Reference: https://tridao.me/blog/2026/gram-newton-schulz/
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    compute_dtype = torch.float16 if get_accelerator().is_fp16_supported() else torch.float32
    X = G.to(compute_dtype)
    if G.size(-2) > G.size(-1):
        X = X.mT

    n, m = X.size(-2), X.size(-1)

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # For square matrices, no FLOP advantage; use standard iteration
    if m <= n:
        for _ in range(steps):
            A = X @ X.mT
            B = b * A + c * A @ A
            X = a * X + B @ X
        if G.size(-2) > G.size(-1):
            X = X.mT
        return X

    # Gram NS: iterate on R = X @ X.T (n x n) instead of X (n x m)
    R = X @ X.mT
    Q = None
    restart_at = 2

    for i in range(steps):
        if i == restart_at and i != 0:
            # Fold the accumulated polynomial into X and rebuild the Gram
            # matrix; this restart bounds half-precision error growth.
            X = Q @ X
            R = X @ X.mT
            Q = None

        Z = b * R + c * R @ R

        if Q is None:
            Q = Z.clone()
            # Add `a` on the diagonal of the last two dims (batched-safe).
            Q.diagonal(dim1=-2, dim2=-1).add_(a)
        else:
            # Batched-safe form of torch.addmm(Q, Z, Q, beta=a, alpha=1.0).
            Q = a * Q + Z @ Q

        if i < steps - 1 and (i + 1) != restart_at:
            # Advance R two polynomial orders: batched-safe addmm equivalents.
            RZ = a * R + Z @ R
            R = a * RZ + Z @ RZ

    if G.size(-2) > G.size(-1):
        X = X.mT @ Q.mT
    else:
        X = Q @ X
    return X


# Valid values for the ``ns_method`` option: "standard" runs the original
# Newton-Schulz iteration; "gram" runs the Gram-matrix variant.
NS_METHODS = {"standard", "gram"}


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"):
    """Compute a single Muon update for ``grad``.

    Updates ``momentum`` in place via an EMA, optionally applies a Nesterov
    lookahead (which mutates ``grad`` in place), orthogonalizes the result
    with the selected Newton-Schulz method, and rescales by the matrix
    aspect ratio. Returns the update cast back to ``grad``'s original dtype.
    """
    original_dtype = grad.dtype
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        update = grad.lerp_(momentum, beta)
    else:
        update = momentum
    # Conv filters arrive as 4-D tensors; flatten to a 2-D matrix for NS.
    if update.ndim == 4:
        update = update.view(len(update), -1)
    orthogonalize = (_zeropower_via_gram_newtonschulz
                     if ns_method == "gram" else _zeropower_via_newtonschulz5)
    update = orthogonalize(update, steps=ns_steps)
    # Scale by sqrt(max(1, rows/cols)) so tall matrices keep spectral norm.
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    if update.dtype != original_dtype:
        update = update.to(original_dtype)
    return update


Expand All @@ -93,10 +170,12 @@ class Muon(torch.optim.Optimizer):
lr: The learning rate, in units of spectral norm per update.
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
ns_method: Newton-Schulz method. "gram" (default) uses Gram NS for ~2x speedup
on rectangular matrices. "standard" uses the original iteration.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
    """Build per-group defaults and register params sorted by size (descending)."""
    assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
    defaults = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "ns_method": ns_method,
    }
    super().__init__(sorted(params, key=lambda x: x.size(), reverse=True), defaults)
Expand All @@ -122,7 +201,12 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(
p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"),
)
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand All @@ -136,8 +220,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
Muon variant for usage in non-distributed settings.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
    """Register params with per-group defaults; no size sorting is needed
    in the single-device setting."""
    defaults = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "ns_method": ns_method,
    }
    super().__init__(params, defaults)

@torch.no_grad()
Expand All @@ -156,7 +240,12 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(
p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"),
)
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])

Expand Down Expand Up @@ -208,7 +297,17 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group["ns_method"] in NS_METHODS, (
f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}")
assert set(group.keys()) == set([
"params",
"lr",
"momentum",
"weight_decay",
"use_muon",
"ns_method",
])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand Down Expand Up @@ -240,7 +339,12 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(
p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"),
)
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand Down Expand Up @@ -277,7 +381,17 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group["ns_method"] in NS_METHODS, (
f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}")
assert set(group.keys()) == set([
"params",
"lr",
"momentum",
"weight_decay",
"use_muon",
"ns_method",
])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand All @@ -304,7 +418,12 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(
p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"),
)
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
if self.use_muon:
self.sub_groups_using_muon = []
self.muon_beta = None
self.muon_ns_method = None
for idx, param_group in enumerate(fp16_param_groups):
if getattr(param_group['params'][0], 'use_muon', False):
self.sub_groups_using_muon.extend([True] * len(param_groups[idx]))
Expand All @@ -767,6 +768,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
f"Found {self.muon_beta} and {group_beta}.")
self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve ns_method per Muon param group in ZeRO-3

ZeRO-3 stores ns_method in a single self.muon_ns_method while iterating all Muon param groups, so later groups overwrite earlier values. _muon_update_grads_in_place then applies that one method to every Muon subgroup, which silently ignores per-group configuration when users provide multiple Muon groups (a pattern already handled for momentum via explicit consistency checks).

Useful? React with 👍 / 👎.

else:
self.sub_groups_using_muon.extend([False] * len(param_groups[idx]))
# bookkeeping related to param groups
Expand Down Expand Up @@ -1515,7 +1517,7 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
param = params[base_i + rank]
g = param.grad
m = gathered_momentums_pad[base_i + rank]
update = muon_update(g, m, beta=self.muon_beta)
update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram'))
g.data.copy_(update, non_blocking=False)
grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz],
grads_pad[base_i + rank],
Expand Down
82 changes: 80 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,79 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
return torch.tensor(total_norm, device=self.device, dtype=torch.float)

############################################################################################
def _apply_muon_update_for_cpu_offload(self, param):
    """Apply muon_update for a parameter in the CPU offload path.

    For Muon parameters (use_muon=True), runs Newton-Schulz
    orthogonalization on GPU (momentum is temporarily copied from
    CPU to GPU) and writes only the partition slice back to the
    CPU FP32 grad buffer. Cross-boundary parameters are
    redundantly processed by each involved rank with the full
    gradient, matching the non-offload path behavior in
    get_flat_partition.

    Returns True if muon_update was applied (caller should skip
    the normal copy for this param).
    """
    # Only act for Muon-enabled params wrapped by a Muon optimizer variant.
    if not getattr(param, 'use_muon', False):
        return False
    if 'muon' not in self.optimizer.__class__.__name__.lower():
        return False

    param_id = self.get_param_id(param)
    # grad_position maps param -> (group idx, offset within param, offset
    # within the flat fp32 partition, number of elements in this partition).
    [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

    grad_accum = self.get_param_gradient_attribute(param)
    if grad_accum is None:
        return False

    # Momentum state is keyed on the group's first (flattened) param and
    # lazily allocated on CPU to cover this rank's whole partition.
    flatten_copy = self.optimizer.param_groups[i]['params'][0]
    if "momentum_buffer" not in self.optimizer.state[flatten_copy]:
        total_size = sum(p.numel() for p in self.params_in_partition[i])
        self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size,
                                                                            dtype=torch.float32,
                                                                            device=self.device)

    momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"]

    # Locate this param's slice inside the flat per-group momentum buffer.
    muon_offset = 0
    for p in self.params_in_partition[i]:
        if p is param:
            break
        muon_offset += p.numel()

    # View (not copy) into momentum_flat, reshaped to the param's shape.
    momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())

    beta = self.optimizer.param_groups[i].get('momentum', 0.95)
    ns_method = self.optimizer.param_groups[i].get('ns_method', 'gram')

    # Run NS on GPU: keep grad on GPU, temporarily move momentum to GPU
    gpu_device = grad_accum.device
    grad_gpu = grad_accum.detach().clone().to(dtype=torch.float32)
    momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32)
    update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method)
    momentum_cpu.copy_(momentum_gpu.to(device='cpu'))
    update_cpu = update.to(device='cpu')
    del grad_gpu, momentum_gpu

    # NOTE(review): momentum_cpu is a view of momentum_flat, so the copy_
    # above already persisted the state; this assignment looks redundant
    # but harmless — confirm before removing.
    momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1)

    # Write only the partition slice of the update to CPU FP32 grad buffer
    tensor_offset = 0
    actual_num_elements = param.numel()
    if source_offset > 0:
        # Param starts mid-partition on this rank: skip the leading elements.
        tensor_offset = source_offset
        actual_num_elements = param.numel() - tensor_offset
    if actual_num_elements > num_elements:
        # Param extends past this rank's partition: clamp to partition size.
        actual_num_elements = num_elements

    dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements)
    update_slice = update_cpu.view(-1).narrow(0, tensor_offset, actual_num_elements)
    dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype))

    # Gradient has been consumed; release it so the normal copy path is skipped.
    self.clear_grad_attribute(param)
    return True

def copy_grads_in_partition(self, param):
if self.cpu_offload:

Expand All @@ -1513,7 +1586,8 @@ def copy_grads_in_partition(self, param):

self.update_offload_overflow_tracker_for_param_grad(param)

self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
if not self._apply_muon_update_for_cpu_offload(param):
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
Comment on lines 1587 to +1590
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Track Muon overflow after CPU-offload update

In the CPU-offload path, overflow tracking is performed before _apply_muon_update_for_cpu_offload(param), but has_overflow() for offload relies only on self.local_overflow. Because no post-muon_update inf/nan check is done, numerical failures introduced by the Newton-Schulz step can bypass loss-scaling overflow handling and still be applied to optimizer state. This diverges from the non-offload path, where overflow is checked on the Muon-transformed gradients.

Useful? React with 👍 / 👎.


return
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
Expand Down Expand Up @@ -1996,7 +2070,11 @@ def get_flat_partition(self,
assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).view(tensor.size())
grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram')
grad_accum = muon_update(grad_accum,
buffer,
self.optimizer.param_groups[param_group_idx]['momentum'],
ns_method=ns_method)
tensor = grad_accum
num_elements = tensor.numel()
buffer_idx += num_elements
Expand Down
Loading
Loading