-
Notifications
You must be signed in to change notification settings - Fork 4.8k
feat(zero2): add CPU offload support for Muon optimizer #7939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -759,6 +759,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): | |
| if self.use_muon: | ||
| self.sub_groups_using_muon = [] | ||
| self.muon_beta = None | ||
| self.muon_ns_method = None | ||
| for idx, param_group in enumerate(fp16_param_groups): | ||
| if getattr(param_group['params'][0], 'use_muon', False): | ||
| self.sub_groups_using_muon.extend([True] * len(param_groups[idx])) | ||
|
|
@@ -767,6 +768,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): | |
| raise ValueError(f"All Muon parameter groups must have the same momentum (beta). " | ||
| f"Found {self.muon_beta} and {group_beta}.") | ||
| self.muon_beta = group_beta | ||
| self.muon_ns_method = param_group.get('ns_method', 'gram') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
ZeRO-3 stores … [review comment truncated in page extraction — original text not recoverable here]. Useful? React with 👍 / 👎. |
||
| else: | ||
| self.sub_groups_using_muon.extend([False] * len(param_groups[idx])) | ||
| # bookkeeping related to param groups | ||
|
|
@@ -1515,7 +1517,7 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b | |
| param = params[base_i + rank] | ||
| g = param.grad | ||
| m = gathered_momentums_pad[base_i + rank] | ||
| update = muon_update(g, m, beta=self.muon_beta) | ||
| update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram')) | ||
| g.data.copy_(update, non_blocking=False) | ||
| grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], | ||
| grads_pad[base_i + rank], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1502,6 +1502,79 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): | |
| return torch.tensor(total_norm, device=self.device, dtype=torch.float) | ||
|
|
||
| ############################################################################################ | ||
| def _apply_muon_update_for_cpu_offload(self, param): | ||
| """Apply muon_update for a parameter in the CPU offload path. | ||
|
|
||
| For Muon parameters (use_muon=True), runs Newton-Schulz | ||
| orthogonalization on GPU (momentum is temporarily copied from | ||
| CPU to GPU) and writes only the partition slice back to the | ||
| CPU FP32 grad buffer. Cross-boundary parameters are | ||
| redundantly processed by each involved rank with the full | ||
| gradient, matching the non-offload path behavior in | ||
| get_flat_partition. | ||
|
|
||
| Returns True if muon_update was applied (caller should skip | ||
| the normal copy for this param). | ||
| """ | ||
| if not getattr(param, 'use_muon', False): | ||
| return False | ||
| if 'muon' not in self.optimizer.__class__.__name__.lower(): | ||
| return False | ||
|
|
||
| param_id = self.get_param_id(param) | ||
| [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] | ||
|
|
||
| grad_accum = self.get_param_gradient_attribute(param) | ||
| if grad_accum is None: | ||
| return False | ||
|
|
||
| flatten_copy = self.optimizer.param_groups[i]['params'][0] | ||
| if "momentum_buffer" not in self.optimizer.state[flatten_copy]: | ||
| total_size = sum(p.numel() for p in self.params_in_partition[i]) | ||
| self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size, | ||
| dtype=torch.float32, | ||
| device=self.device) | ||
|
|
||
| momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"] | ||
|
|
||
| muon_offset = 0 | ||
| for p in self.params_in_partition[i]: | ||
| if p is param: | ||
| break | ||
| muon_offset += p.numel() | ||
|
|
||
| momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size()) | ||
|
|
||
| beta = self.optimizer.param_groups[i].get('momentum', 0.95) | ||
| ns_method = self.optimizer.param_groups[i].get('ns_method', 'gram') | ||
|
|
||
| # Run NS on GPU: keep grad on GPU, temporarily move momentum to GPU | ||
| gpu_device = grad_accum.device | ||
| grad_gpu = grad_accum.detach().clone().to(dtype=torch.float32) | ||
| momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32) | ||
| update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method) | ||
| momentum_cpu.copy_(momentum_gpu.to(device='cpu')) | ||
| update_cpu = update.to(device='cpu') | ||
| del grad_gpu, momentum_gpu | ||
|
|
||
| momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1) | ||
|
|
||
| # Write only the partition slice of the update to CPU FP32 grad buffer | ||
| tensor_offset = 0 | ||
| actual_num_elements = param.numel() | ||
| if source_offset > 0: | ||
| tensor_offset = source_offset | ||
| actual_num_elements = param.numel() - tensor_offset | ||
| if actual_num_elements > num_elements: | ||
| actual_num_elements = num_elements | ||
|
|
||
| dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements) | ||
| update_slice = update_cpu.view(-1).narrow(0, tensor_offset, actual_num_elements) | ||
| dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype)) | ||
|
|
||
| self.clear_grad_attribute(param) | ||
| return True | ||
|
|
||
| def copy_grads_in_partition(self, param): | ||
| if self.cpu_offload: | ||
|
|
||
|
|
@@ -1513,7 +1586,8 @@ def copy_grads_in_partition(self, param): | |
|
|
||
| self.update_offload_overflow_tracker_for_param_grad(param) | ||
|
|
||
| self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) | ||
| if not self._apply_muon_update_for_cpu_offload(param): | ||
| self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) | ||
|
Comment on lines 1587 to 1590
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the CPU-offload path, overflow tracking is performed before … [review comment truncated in page extraction — original text not recoverable here]. Useful? React with 👍 / 👎. |
||
|
|
||
| return | ||
| #print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}") | ||
|
|
@@ -1996,7 +2070,11 @@ def get_flat_partition(self, | |
| assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}" | ||
| buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx, | ||
| tensor.numel()).view(tensor.size()) | ||
| grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum']) | ||
| ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram') | ||
| grad_accum = muon_update(grad_accum, | ||
| buffer, | ||
| self.optimizer.param_groups[param_group_idx]['momentum'], | ||
| ns_method=ns_method) | ||
| tensor = grad_accum | ||
| num_elements = tensor.numel() | ||
| buffer_idx += num_elements | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new default `ns_method="gram"` regresses Muon's stated batched-input support (`grad.ndim >= 2`): this path uses `torch.addmm`, which only accepts 2D inputs, so a Muon parameter with a shape like `(B, N, M)` will now fail at runtime in `_zeropower_via_gram_newtonschulz`. Previously, the standard Newton-Schulz implementation used batched matmuls and handled these shapes, so this commit introduces a crash for valid prior inputs unless users manually switch to `ns_method="standard"`.

Useful? React with 👍 / 👎.