-
Notifications
You must be signed in to change notification settings - Fork 343
feat: Add pinned memory optimizer offload for Megatron policy worker #2248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1255,22 +1255,92 @@ def move_optimizer(self, device: str): | |
| optimizer_state = self.optimizer.state | ||
| else: | ||
| optimizer_state = self.optimizer._get_state() | ||
|
|
||
| use_pinned = self.cfg.get("use_pinned_optimizer_offload", False) | ||
|
|
||
| if device == "cpu": | ||
| if use_pinned: | ||
| self._coalesced_optimizer_to_cpu(optimizer_state) | ||
| else: | ||
| self._optimizer_to_cpu(optimizer_state) | ||
| elif device == "cuda": | ||
| if use_pinned: | ||
| self._coalesced_optimizer_to_cuda(optimizer_state) | ||
| else: | ||
| self._optimizer_to_cuda(optimizer_state) | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid device: {device}. Only strings 'cpu' and 'cuda' are supported." | ||
| ) | ||
|
|
||
| def _optimizer_to_cpu(self, optimizer_state): | ||
| """Offload optimizer state tensors to CPU using default pageable memory.""" | ||
| for _, state in optimizer_state.items(): | ||
| # Iterate through the state items (e.g., momentum, variance) for a parameter | ||
| for k, v in state.items(): | ||
| # Check if the item is a tensor | ||
| if torch.is_tensor(v): | ||
| # Move the tensor to device and update the state dictionary | ||
| if device == "cpu": | ||
| if v.is_cuda: | ||
| state[k] = v.to("cpu") | ||
| elif device == "cuda": | ||
| if not v.is_cuda: | ||
| state[k] = v.to("cuda") | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid device: {device}. Only strings 'cpu' and 'cuda' are supported." | ||
| ) | ||
| if torch.is_tensor(v) and v.is_cuda: | ||
| state[k] = v.to("cpu") | ||
|
|
||
| def _optimizer_to_cuda(self, optimizer_state): | ||
| """Reload optimizer state tensors to CUDA.""" | ||
| for _, state in optimizer_state.items(): | ||
| for k, v in state.items(): | ||
| if torch.is_tensor(v) and not v.is_cuda: | ||
| state[k] = v.to("cuda") | ||
|
|
||
| def _get_or_alloc_pinned_buf( | ||
| self, attr_name: str, total_bytes: int | ||
| ) -> torch.Tensor: | ||
| """Return a cached pinned CPU buffer, allocating only on first use or resize.""" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I recommend having a corresponding deallocation API. So either we add a deletion API here too and show people how to use them in pairs, or we make this allocation API specific to optimizer offloading (in naming, or just by hardcoding it in the `_optimizer_to_*` functions), so that folks don't use it for other purposes.
||
| buf = getattr(self, attr_name, None) | ||
| if buf is None or buf.numel() < total_bytes: | ||
| buf = torch.empty( | ||
| total_bytes, device="cpu", dtype=torch.uint8, pin_memory=True | ||
| ) | ||
| setattr(self, attr_name, buf) | ||
| return buf | ||
|
|
||
| def _coalesced_optimizer_to_cpu(self, optimizer_state): | ||
| """Offload all optimizer state tensors to CPU via a cached pinned buffer. | ||
|
|
||
| Packs all CUDA tensors into a single pre-allocated pinned CPU buffer, | ||
| eliminating per-tensor cudaHostAlloc overhead. The pinned buffer is | ||
| allocated once on first call and reused across iterations. | ||
| """ | ||
| ALIGN = 512 | ||
| entries = [] | ||
| total_bytes = 0 | ||
|
|
||
| for _, state in optimizer_state.items(): | ||
| for k, v in state.items(): | ||
| if not torch.is_tensor(v) or not v.is_cuda: | ||
| continue | ||
| if v.dim() == 0: | ||
| state[k] = v.cpu() | ||
| continue | ||
| offset = (total_bytes + ALIGN - 1) // ALIGN * ALIGN | ||
| nbytes = v.numel() * v.element_size() | ||
| entries.append((state, k, v, offset, nbytes)) | ||
| total_bytes = offset + nbytes | ||
|
|
||
| if not entries: | ||
| return | ||
|
|
||
| cpu_buf = self._get_or_alloc_pinned_buf("_optimizer_pinned_buf", total_bytes) | ||
|
|
||
| for state, k, v, offset, nbytes in entries: | ||
| dst = cpu_buf[offset : offset + nbytes].view(v.dtype).reshape(v.shape) | ||
| dst.copy_(v, non_blocking=True) | ||
| state[k] = dst | ||
|
|
||
| torch.cuda.synchronize() | ||
|
|
||
| def _coalesced_optimizer_to_cuda(self, optimizer_state): | ||
| """Reload all optimizer state tensors back to CUDA.""" | ||
| for _, state in optimizer_state.items(): | ||
| for k, v in state.items(): | ||
| if torch.is_tensor(v) and not v.is_cuda: | ||
| state[k] = v.to("cuda", non_blocking=True) | ||
| torch.cuda.synchronize() | ||
|
|
||
| def save_checkpoint( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not ideal; we should use
`self.cfg['use_pinned_optimizer_offload']` instead. We should add this field to all necessary base configs under `examples/configs` to make sure the read doesn't fail.