diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5fc07b3a7238..1fc71f1a8e9e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1754,7 +1754,7 @@ def _configure_basic_optimizer(self, model_parameters): param_groups = [] if muon_params: accepted_parameters = dict() - for key in ["lr", "momentum", "weight_decay", "muon_lr"]: + for key in ["lr", "momentum", "weight_decay", "muon_lr", "ns_method"]: if key in optimizer_parameters: if key == "muon_lr": # muon_lr will override lr accepted_parameters['lr'] = optimizer_parameters[key] diff --git a/deepspeed/runtime/zero/muon/original_muon.py b/deepspeed/runtime/zero/muon/original_muon.py index f4dc7a0909bb..e4f4e955d22e 100644 --- a/deepspeed/runtime/zero/muon/original_muon.py +++ b/deepspeed/runtime/zero/muon/original_muon.py @@ -29,11 +29,10 @@ import torch import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check -from deepspeed.runtime import compiler +from deepspeed.accelerator import get_accelerator -@compiler.compile() -def zeropower_via_newtonschulz5(G, steps: int): +def _zeropower_via_newtonschulz5(G, steps: int): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose @@ -45,7 +44,8 @@ def zeropower_via_newtonschulz5(G, steps: int): """ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() + compute_dtype = torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32 + X = G.to(compute_dtype) if G.size(-2) > G.size(-1): X = X.mT @@ -62,14 +62,91 @@ def zeropower_via_newtonschulz5(G, steps: int): return X -@compiler.compile() -def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): +def _zeropower_via_gram_newtonschulz(G, steps: int): + """ + Gram Newton-Schulz iteration for orthogonalization. + + Mathematically equivalent to standard Newton-Schulz but iterates on the + small square Gram matrix R = X @ X.T (n x n) instead of the full rectangular + X (n x m). This reduces FLOPs significantly when m >> n (typical for + transformer weight matrices with aspect ratio ~5). + + Uses fp16 instead of bf16 for better numerical precision at the same + compute cost. Includes a restart at iteration 2 to maintain stability + in half-precision. + + Falls back to standard Newton-Schulz for square matrices (n == m) + where there is no FLOP advantage. 
+ + Reference: https://tridao.me/blog/2026/gram-newton-schulz/ + """ + assert G.ndim >= 2 + a, b, c = (3.4445, -4.7750, 2.0315) + compute_dtype = torch.float16 if get_accelerator().is_fp16_supported() else torch.float32 + X = G.to(compute_dtype) + if G.size(-2) > G.size(-1): + X = X.mT + + n, m = X.size(-2), X.size(-1) + + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # For square matrices, no FLOP advantage; use standard iteration + if m <= n: + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(-2) > G.size(-1): + X = X.mT + return X + + # Gram NS: iterate on R = X @ X.T (n x n) instead of X (n x m) + R = X @ X.mT + Q = None + restart_at = 2 + + for i in range(steps): + if i == restart_at and i != 0: + X = Q @ X + R = X @ X.mT + Q = None + + Z = b * R + c * R @ R + + if Q is None: + Q = Z.clone() + Q.diagonal().add_(a) + else: + Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0) + + if i < steps - 1 and (i + 1) != restart_at: + RZ = torch.addmm(R, Z, R, beta=a, alpha=1.0) + R = torch.addmm(RZ, Z, RZ, beta=a, alpha=1.0) + + if G.size(-2) > G.size(-1): + X = X.mT @ Q.mT + else: + X = Q @ X + return X + + +NS_METHODS = {"standard", "gram"} + + +def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"): + orig_dtype = grad.dtype momentum.lerp_(grad, 1 - beta) update = grad.lerp_(momentum, beta) if nesterov else momentum if update.ndim == 4: # for the case of conv filters update = update.view(len(update), -1) - update = zeropower_via_newtonschulz5(update, steps=ns_steps) + if ns_method == "gram": + update = _zeropower_via_gram_newtonschulz(update, steps=ns_steps) + else: + update = _zeropower_via_newtonschulz5(update, steps=ns_steps) update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + if update.dtype != orig_dtype: + update = update.to(orig_dtype) return update @@ -93,10 +170,12 @@ class Muon(torch.optim.Optimizer): lr: The learning rate, in units of spectral norm per update. 
weight_decay: The AdamW-style weight decay. momentum: The momentum. A value of 0.95 here is usually fine. + ns_method: Newton-Schulz method. "gram" (default) uses Gram NS for ~2x speedup + on rectangular matrices. "standard" uses the original iteration. """ - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method) assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) params = sorted(params, key=lambda x: x.size(), reverse=True) super().__init__(params, defaults) @@ -122,7 +201,12 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram"), + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], @@ -136,8 +220,8 @@ class SingleDeviceMuon(torch.optim.Optimizer): Muon variant for usage in non-distributed settings. 
""" - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method) super().__init__(params, defaults) @torch.no_grad() @@ -156,7 +240,12 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram"), + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) @@ -208,7 +297,17 @@ def __init__(self, param_groups): group["lr"] = group.get("lr", 0.02) group["momentum"] = group.get("momentum", 0.95) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + group["ns_method"] = group.get("ns_method", "gram") + assert group["ns_method"] in NS_METHODS, ( + f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}") + assert set(group.keys()) == set([ + "params", + "lr", + "momentum", + "weight_decay", + "use_muon", + "ns_method", + ]) else: # defaults group["lr"] = group.get("lr", 3e-4) @@ -240,7 +339,12 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram"), + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], @@ -277,7 
+381,17 @@ def __init__(self, param_groups): group["lr"] = group.get("lr", 0.02) group["momentum"] = group.get("momentum", 0.95) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + group["ns_method"] = group.get("ns_method", "gram") + assert group["ns_method"] in NS_METHODS, ( + f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}") + assert set(group.keys()) == set([ + "params", + "lr", + "momentum", + "weight_decay", + "use_muon", + "ns_method", + ]) else: # defaults group["lr"] = group.get("lr", 3e-4) @@ -304,7 +418,12 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram"), + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) else: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 8f28ee4f8685..3175e3165b11 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -759,6 +759,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): if self.use_muon: self.sub_groups_using_muon = [] self.muon_beta = None + self.muon_ns_method = None for idx, param_group in enumerate(fp16_param_groups): if getattr(param_group['params'][0], 'use_muon', False): self.sub_groups_using_muon.extend([True] * len(param_groups[idx])) @@ -767,6 +768,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): raise ValueError(f"All Muon parameter groups must have the same momentum (beta). 
" f"Found {self.muon_beta} and {group_beta}.") self.muon_beta = group_beta + self.muon_ns_method = param_group.get('ns_method', 'gram') else: self.sub_groups_using_muon.extend([False] * len(param_groups[idx])) # bookkeeping related to param groups @@ -1515,7 +1517,7 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b param = params[base_i + rank] g = param.grad m = gathered_momentums_pad[base_i + rank] - update = muon_update(g, m, beta=self.muon_beta) + update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram')) g.data.copy_(update, non_blocking=False) grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], grads_pad[base_i + rank], diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 0bfb18877f2d..a702e25fdbc5 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1502,6 +1502,79 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): return torch.tensor(total_norm, device=self.device, dtype=torch.float) ############################################################################################ + def _apply_muon_update_for_cpu_offload(self, param): + """Apply muon_update for a parameter in the CPU offload path. + + For Muon parameters (use_muon=True), runs Newton-Schulz + orthogonalization on GPU (momentum is temporarily copied from + CPU to GPU) and writes only the partition slice back to the + CPU FP32 grad buffer. Cross-boundary parameters are + redundantly processed by each involved rank with the full + gradient, matching the non-offload path behavior in + get_flat_partition. + + Returns True if muon_update was applied (caller should skip + the normal copy for this param). 
+ """ + if not getattr(param, 'use_muon', False): + return False + if 'muon' not in self.optimizer.__class__.__name__.lower(): + return False + + param_id = self.get_param_id(param) + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + return False + + flatten_copy = self.optimizer.param_groups[i]['params'][0] + if "momentum_buffer" not in self.optimizer.state[flatten_copy]: + total_size = sum(p.numel() for p in self.params_in_partition[i]) + self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size, + dtype=torch.float32, + device=self.device) + + momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"] + + muon_offset = 0 + for p in self.params_in_partition[i]: + if p is param: + break + muon_offset += p.numel() + + momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size()) + + beta = self.optimizer.param_groups[i].get('momentum', 0.95) + ns_method = self.optimizer.param_groups[i].get('ns_method', 'gram') + + # Run NS on GPU: keep grad on GPU, temporarily move momentum to GPU + gpu_device = grad_accum.device + grad_gpu = grad_accum.detach().clone().to(dtype=torch.float32) + momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32) + update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method) + momentum_cpu.copy_(momentum_gpu.to(device='cpu')) + update_cpu = update.to(device='cpu') + del grad_gpu, momentum_gpu + + momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1) + + # Write only the partition slice of the update to CPU FP32 grad buffer + tensor_offset = 0 + actual_num_elements = param.numel() + if source_offset > 0: + tensor_offset = source_offset + actual_num_elements = param.numel() - tensor_offset + if actual_num_elements > num_elements: + actual_num_elements = num_elements + + dest_tensor = 
self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements) + update_slice = update_cpu.view(-1).narrow(0, tensor_offset, actual_num_elements) + dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype)) + + self.clear_grad_attribute(param) + return True + def copy_grads_in_partition(self, param): if self.cpu_offload: @@ -1513,7 +1586,8 @@ def copy_grads_in_partition(self, param): self.update_offload_overflow_tracker_for_param_grad(param) - self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) + if not self._apply_muon_update_for_cpu_offload(param): + self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) return #print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}") @@ -1996,7 +2070,11 @@ def get_flat_partition(self, assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}" buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx, tensor.numel()).view(tensor.size()) - grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum']) + ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram') + grad_accum = muon_update(grad_accum, + buffer, + self.optimizer.param_groups[param_group_idx]['momentum'], + ns_method=ns_method) tensor = grad_accum num_elements = tensor.numel() buffer_idx += num_elements diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index f8209c8d8068..c80be6342cc4 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -39,7 +39,18 @@ toc_label: "Contents" | type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). 
| `"Adam"` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | -Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe. +Muon optimizer is supported with ZeRO Stage 1, 2, and 3, including CPU offload (`offload_optimizer`) for all stages. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe. + +Muon supports the following params: + +| "params" key | Description | Default | +| -------------- | -------------------------------------------------------------------------------------------------------------------- | --------- | +| lr | Learning rate for all parameters. Overridden by `muon_lr` / `adam_lr` if set. | 0.001 | +| momentum | Momentum coefficient for the Muon update. | 0.95 | +| weight\_decay | Weight decay (AdamW-style). | 0.0 | +| muon\_lr | Learning rate override for Muon parameters. Defaults to `lr` if not set. | - | +| adam\_lr | Learning rate override for non-Muon (Adam) parameters. Defaults to `lr` if not set. | - | +| ns\_method | Newton-Schulz orthogonalization method: `"gram"` for Gram NS (~2x faster on rectangular matrices), `"standard"` for the original iteration. Use `"standard"` to fall back if you encounter convergence issues. 
| `"gram"` | Example of **optimizer** with Adam @@ -73,7 +84,8 @@ If not set, muon_lr will default to lr. "lr": 0.001, "momentum": 0.9, "weight_decay": 0.0, - "muon_lr": 0.001 + "muon_lr": 0.001, + "ns_method": "gram" } }, "zero_optimization": { diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py index 02594941cef0..84b06dd96265 100644 --- a/tests/unit/ops/muon/test_muon.py +++ b/tests/unit/ops/muon/test_muon.py @@ -86,3 +86,94 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optim after_training = [p.clone().cpu() for p in model.parameters()] for initial, final in zip(initial_params, after_training): assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training" + + +class TestGramNewtonSchulz(DistributedTest): + """Test Gram Newton-Schulz integration with Muon optimizer.""" + + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('ns_method', ['gram', 'standard']) + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_ns_method_training(self, ns_method, zero_stage): + """Verify both ns_method values work end-to-end with DeepSpeed.""" + hidden_dim = 64 + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01, + "ns_method": ns_method, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False, + }, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=3) + initial_params = [p.clone().cpu() for p in model.parameters()] + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + 
engine.backward(loss) + engine.step() + + after_training = [p.clone().cpu() for p in model.parameters()] + for initial, final in zip(initial_params, after_training): + assert not torch.equal(initial, final), "Parameters should have been updated" + + @pytest.mark.parametrize('ns_method', ['gram', 'standard']) + def test_ns_method_stage3(self, ns_method): + """Verify ns_method works with ZeRO Stage 3.""" + hidden_dim = 64 + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01, + "ns_method": ns_method, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 3, + "reduce_scatter": False, + }, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=3) + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() diff --git a/tests/unit/ops/muon/test_muon_cpu_offload.py b/tests/unit/ops/muon/test_muon_cpu_offload.py new file mode 100644 index 000000000000..083de623d2f9 --- /dev/null +++ b/tests/unit/ops/muon/test_muon_cpu_offload.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator + +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported", allow_module_level=True) + + +@pytest.mark.parametrize('zero_stage', [2]) +class TestMuonCPUOffload(DistributedTest): + + def test_momentum_buffer_on_cpu(self, zero_stage): + """Verify Muon CPU offload creates momentum buffer 
on CPU. + + This is the key invariant: after a training step with CPU offload, + the Muon momentum buffer must reside on CPU (not GPU), confirming + that the CPU offload path invoked muon_update and persists its momentum state on CPU (the Newton-Schulz computation itself runs temporarily on GPU). + """ + hidden_dim = 32 + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + }, + }, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=5) + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() + + # Muon momentum buffer must exist and be on CPU. + # If muon_update was silently skipped, momentum_buffer would not be created. + flatten_copy = optimizer.optimizer.param_groups[0]['params'][0] + state = optimizer.optimizer.state[flatten_copy] + assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. " + "muon_update was not called in the CPU offload path.") + assert state['momentum_buffer'].device.type == 'cpu', ( + f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU") + + +@pytest.mark.parametrize('zero_stage', [2]) +class TestMuonCPUOffloadCosim(DistributedTest): + + def test_cosim_offload_vs_no_offload(self, zero_stage): + """Verify CPU offload produces results consistent with GPU path. + + With the same random seed, offload and non-offload should produce + close parameters. If muon_update is skipped or wrong in either path, + the results diverge significantly. 
+ """ + hidden_dim = 32 + batch_size = 8 + + def train(offload): + torch.manual_seed(42) + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False, + }, + } + if offload: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=5) + engine, _, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() + + return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()} + + params_offload = train(offload=True) + params_no_offload = train(offload=False) + + for name in params_offload: + p_off = params_offload[name] + p_no = params_no_offload[name] + # Both paths should produce the same NaN pattern + nan_mask = p_off.isnan() | p_no.isnan() + assert nan_mask.equal(p_off.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. " + "muon_update produced different results.") + # On non-NaN elements, cosine similarity should be very high + valid = ~nan_mask + if valid.sum() > 0: + cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0), + p_no[valid].unsqueeze(0)).item() + assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and " + f"non-offload is too low, indicating muon_update results diverge.")