
[Feature] Enable AutoEP Compatibility with ZeRO-3#7928

Open
nathon-lee wants to merge 11 commits into deepspeedai:master from nathon-lee:feat_autoEP_zero3

Conversation


@nathon-lee commented Mar 28, 2026 (Contributor)

[Feature] Enable AutoEP Compatibility with ZeRO-3


📌 Summary

This PR introduces compatibility between AutoEP (Expert Parallelism) and ZeRO-3.

AutoEP has historically relied on ZeRO-2 due to inherent conflicts between expert-parallel parameter partitioning and ZeRO-3’s data-parallel sharding. This PR resolves those conflicts through a minimal and targeted decoupling strategy, allowing:

  • Expert parameters to follow AutoEP semantics
  • Non-expert parameters (e.g., attention, embeddings) to fully benefit from ZeRO-3 sharding

This preserves AutoEP’s high-throughput execution while unlocking the memory efficiency of ZeRO-3 where applicable.


🔍 Design Overview

Instead of modifying core ZeRO-3 logic, this PR selectively bypasses ZeRO-3 mechanisms for expert parameters, while keeping the default behavior unchanged for all other parameters.

The implementation consists of four focused components:

1. Parameter Partition Bypass

Expert parameters are tagged (_autoep_expert=True) and excluded from ZeRO-3 partitioning and gathering logic.
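The tag-and-skip idea can be sketched as follows. The `_autoep_expert` attribute is the tag named in this PR; the `Param` class and `zero3_partitionable` filter are illustrative stand-ins, not the actual DeepSpeed code:

```python
# Illustrative sketch: ZeRO-3 partitioning skips parameters carrying the
# `_autoep_expert` tag. `Param` and `zero3_partitionable` are hypothetical.

class Param:
    def __init__(self, name, is_expert=False):
        self.name = name
        if is_expert:
            self._autoep_expert = True  # tag set when AutoEP registers experts

def zero3_partitionable(params):
    """Return only the parameters ZeRO-3 should partition and gather."""
    return [p for p in params if not getattr(p, "_autoep_expert", False)]

params = [Param("attn.qkv.weight"), Param("experts.w1", is_expert=True)]
print([p.name for p in zero3_partitionable(params)])  # ['attn.qkv.weight']
```

Untagged parameters keep the default ZeRO-3 behavior, which is why the change stays localized.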

2. Gradient Reduction Isolation

Expert gradients bypass ZeRO-3 reduce-scatter and instead use all_reduce within the EP data-parallel group, matching AutoEP semantics.
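The effect of routing expert gradients through all_reduce can be simulated in pure Python (the `allreduce_avg` helper and the two-rank setup are illustrative; real code would call `torch.distributed.all_reduce` on the EP data-parallel group):

```python
# Pure-Python simulation of the all_reduce(AVG) that replaces ZeRO-3's
# reduce-scatter for expert gradients: every rank in the EP data-parallel
# group ends up with the full averaged gradient, not a shard of it.

def allreduce_avg(grads_per_rank):
    world = len(grads_per_rank)
    avg = [sum(vals) / world for vals in zip(*grads_per_rank)]
    return [avg[:] for _ in range(world)]  # each rank holds the whole mean

# two EP data-parallel ranks holding replicas of the same expert weight
reduced = allreduce_avg([[2.0, 4.0], [4.0, 8.0]])
print(reduced)  # [[3.0, 6.0], [3.0, 6.0]]
```

This matches AutoEP semantics, where expert replicas within the EP data-parallel group must stay identical after each step.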

3. Optimizer State Isolation

A dedicated optimizer is introduced for expert parameters, along with FP32 master weights to ensure numerical stability during updates.
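Why FP32 master weights matter can be shown with a toy example. `ExpertSGD` and the rounding helper are stand-ins, not DeepSpeed APIs; the point is that sub-rounding-step updates vanish without a full-precision master copy:

```python
def to_low_precision(x, step=2 ** -8):
    # crude stand-in for bf16/fp16 rounding of the training copy
    return round(x / step) * step

class ExpertSGD:
    """Toy dedicated expert optimizer keeping an FP32 master copy."""
    def __init__(self, weights, lr):
        self.lr = lr
        self.master = list(weights)          # full precision, never rounded

    def step(self, grads):
        self.master = [m - self.lr * g for m, g in zip(self.master, grads)]
        return [to_low_precision(m) for m in self.master]

# with a master copy, four updates smaller than the rounding step accumulate
opt = ExpertSGD([1.0], lr=1e-3)
for _ in range(4):
    low = opt.step([1.0])

# without one, every update is lost to rounding and the weight never moves
naive = 1.0
for _ in range(4):
    naive = to_low_precision(naive - 1e-3)
```

Here `low[0]` drops below 1.0 while `naive` stays pinned at 1.0, which is exactly the numerical-stability failure the master weights prevent.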

4. Checkpoint Compatibility

Expert parameters and their optimizer states are explicitly integrated into checkpoint save/load paths to ensure correct training resumption.
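The extended checkpoint layout might look like the following. The key names are illustrative, not DeepSpeed's actual checkpoint schema:

```python
# Hypothetical shape of the extended checkpoint: expert weights and their
# dedicated optimizer state travel alongside the usual ZeRO-3 payload.

def save_checkpoint(zero3_state, expert_params, expert_opt_state):
    return {
        "module": zero3_state,                      # sharded non-expert state
        "autoep_expert_params": expert_params,      # replicated per EP rank
        "autoep_expert_optimizer": expert_opt_state,
    }

def load_checkpoint(ckpt):
    return (ckpt["module"],
            ckpt["autoep_expert_params"],
            ckpt["autoep_expert_optimizer"])

ckpt = save_checkpoint({"attn.w": [0.1]}, {"experts.w1": [0.2]}, {"step": 10})
module, experts, opt_state = load_checkpoint(ckpt)
```

Because expert parameters bypass ZeRO-3 gathering, they must be written and restored explicitly or resumption would silently drop them.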


✅ Benefits

  • Enables AutoEP + ZeRO-3 co-existence
  • Reduces memory footprint for non-expert parameters via ZeRO-3
  • Preserves AutoEP’s performance characteristics (Grouped-GEMM + AllToAll)
  • Keeps changes localized without impacting standard ZeRO-3 workflows

⚠️ Trade-offs

  • Expert parameters are not sharded by ZeRO-3
  • Their memory footprint remains similar to AutoEP + ZeRO-2

🧪 Testing

  • Verified end-to-end training correctness

  • Added unit tests for:

    • Gradient reduction isolation
    • Optimizer state handling
    • Checkpoint save/load consistency

Due to limited GPU resources, validation has been performed on 2 GPUs.

If additional resources (e.g., 8 GPUs) are available, I would be very happy to further validate scalability and robustness. The additional verification should only require a few hours.


🙏 Notes

Feedback and suggestions are very welcome.
If possible, I would greatly appreciate access to larger-scale testing resources to further strengthen validation.


Signed-off-by: nathon-lee <leejianwoo@gmail.com>

Copilot AI and others added 10 commits February 27, 2026 06:30
This reverts commit ff88670.

Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com>
Revert "fix: update 1 file reformatted." (ff88670)
Revert accidental Muon optimizer code re-introduction from copilot PRs
Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: move torch.distributed as dist

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: update docs _tutorials autoep.md .

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d231f6b3bc


Comment on lines +197 to +200
if ep_group is not None:
    self.ep_group = ep_group
    self.ep_group_name = f"ep_group_{id(ep_group)}"
    for param in self.experts.parameters():

P0: Preserve process-group handle in AutoEP parallelism setup

DeepSpeedEngine._configure_distributed_model passes a boolean (use_data_before_expert_parallel_) into every module's set_deepspeed_parallelism, but this method treats any non-None value as an EP process group and overwrites self.ep_group with True/False. Once that happens, AutoEP forward calls (all_gather/all_to_all_single) run with a boolean instead of a process group and fail at runtime as soon as the layer executes.

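The guard the reviewer asks for could look like this sketch. `Layer` and `FakeGroup` are illustrative stand-ins; the point is that a boolean must never be stored where a process-group handle is expected:

```python
# Sketch of a type-guarded set_deepspeed_parallelism: the boolean that
# _configure_distributed_model passes (use_data_before_expert_parallel_)
# is ignored, so only genuine process-group handles reach self.ep_group.

class Layer:
    def __init__(self):
        self.ep_group = None

    def set_deepspeed_parallelism(self, arg):
        if arg is None or isinstance(arg, bool):
            return                      # engine flag: leave ep_group alone
        self.ep_group = arg             # real process-group handle
        self.ep_group_name = f"ep_group_{id(arg)}"

class FakeGroup:                        # stand-in for a ProcessGroup handle
    pass

layer = Layer()
layer.set_deepspeed_parallelism(True)   # boolean from the engine: ignored
group = FakeGroup()
layer.set_deepspeed_parallelism(group)  # genuine handle: stored
```

With this guard, later all_gather/all_to_all_single calls always receive a real group object.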

Comment on lines +156 to +158
self.reorderer = TokenReorderer(
    num_experts=self.num_local_experts,
    top_k=spec.top_k,

P1: Use global expert cardinality for token reordering

The router emits expert IDs in the global range [0, num_experts), but TokenReorderer is initialized with num_local_experts. For ep_size > 1, expert IDs outside the local range are mis-bucketed/dropped by the histogram logic, so token counts no longer match the sorted token stream; this corrupts dispatch metadata and can trigger incorrect routing or downstream shape/index failures in multi-rank EP runs.

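The mis-bucketing the reviewer describes is easy to see in a toy histogram. `count_tokens_per_expert` is illustrative, not the TokenReorderer implementation:

```python
# The histogram must span the GLOBAL expert range [0, num_experts); sizing
# it with num_local_experts would drop or mis-bucket IDs of remote experts.

def count_tokens_per_expert(expert_ids, num_experts):
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1   # would raise IndexError if sized with local count
    return counts

# ep_size=2, 4 global experts (2 local per rank); the router emits global IDs
ids = [0, 3, 1, 3]
print(count_tokens_per_expert(ids, num_experts=4))  # [1, 1, 0, 2]
```

With `num_experts=2` (the local count) the same IDs would index out of range, which is the multi-rank failure mode flagged above.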

Comment on lines +2622 to +2626
if not hasattr(self, '_autoep_expert_optimizer'):
    optimizer_cls = type(self.optimizer)
    base_group = self.optimizer.param_groups[0]
    expert_group = {k: v for k, v in base_group.items() if k != 'params'}
    expert_group['params'] = expert_params

P2: Keep expert optimizer hyperparameters in schedule sync

The dedicated AutoEP expert optimizer is created once from self.optimizer.param_groups[0] and then reused without any hyperparameter refresh. If a scheduler (or manual LR/WD update) changes the main optimizer during training, expert params keep stale hyperparameters while non-expert params follow the new values, causing silent optimization drift between parameter sets.

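One possible remedy is to re-copy schedule-driven hyperparameters from the main optimizer into the expert optimizer before each step. `sync_expert_hparams` and the `_Opt` stand-in are illustrative, not part of this PR:

```python
# Sketch: mirror schedule-driven hyperparameters from the main optimizer's
# first param group into every expert param group, so both parameter sets
# follow the same LR/weight-decay schedule.

def sync_expert_hparams(main_opt, expert_opt,
                        keys=("lr", "weight_decay", "betas")):
    src = main_opt.param_groups[0]
    for group in expert_opt.param_groups:
        for key in keys:
            if key in src:
                group[key] = src[key]

class _Opt:  # minimal stand-in exposing torch.optim's param_groups shape
    def __init__(self, groups):
        self.param_groups = groups

main = _Opt([{"lr": 1e-3, "weight_decay": 0.1, "params": []}])
expert = _Opt([{"lr": 5e-4, "weight_decay": 0.0, "params": []}])

main.param_groups[0]["lr"] = 2e-4   # scheduler updates the main optimizer
sync_expert_hparams(main, expert)   # expert group now follows the schedule
```

Calling this once per step (before `expert_opt.step()`) removes the silent drift between expert and non-expert parameter sets.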

@PKUWZP self-requested a review March 28, 2026 15:33

@tohtana commented Mar 29, 2026 (Collaborator)

Thank you, @nathon-lee! This is amazing.
AutoEP is not officially released yet. Did you use my fork?

Maybe we should focus on merging the branch first? I have left it for a while, but I will prioritize it if you can help me.


@nathon-lee commented Mar 30, 2026 (Contributor, Author)

> Thank you, @nathon-lee! This is amazing. AutoEP is not officially released yet. Did you use my fork?
>
> Maybe we should focus on merging the branch first? I have left it for a while, but I will prioritize it if you can help me.

@tohtana Thanks for pointing this out — you’re right. I did use tohtana/DeepSpeedExamples/training/expert_parallel as a reference, and I should have acknowledged that more clearly.

I picked this up because ZeRO-3 compatibility for AutoEP did seem to be covered in the 2026 roadmap. This wasn’t meant as a direct port of tohtana/add_autoep, but it was definitely informed by your earlier work. I’ll update the PR description and address the review comments first. I’d really appreciate your guidance on how best to align it, and I’d be very happy to collaborate and revise it accordingly.

@tohtana mentioned this pull request Mar 31, 2026

@tohtana commented Mar 31, 2026 (Collaborator)

Hi @nathon-lee,
I found this PR is missing some features (universal checkpoint support, metadata saving/loading, some EP implementation, tests) in my branch. I opened a new PR (#7938) based on my branch.
So, how about merging this PR to #7938?

@nathon-lee

> Hi @nathon-lee, I found this PR is missing some features (universal checkpoint support, metadata saving/loading, some EP implementation, tests) in my branch. I opened a new PR based on my branch. So, how about merging this PR to #7938?

Hi @tohtana, thanks for the heads-up and for adding the missing features on top of your branch.
I’m fine with proceeding with #7938 as the main PR. Once we confirm everything is covered there, we can close this PR.


@tohtana commented Mar 31, 2026 (Collaborator)

@nathon-lee #7938 is missing Z3 support. Do you think you can add it? What about creating a new PR focusing on Z3 support and merging it into #7938?

@nathon-lee

> @nathon-lee #7938 is missing Z3 support. Do you think you can add it? What about creating a new PR focusing on Z3 support and merging it into #7938?

@tohtana ok

@nathon-lee

> @nathon-lee #7938 is missing Z3 support. Do you think you can add it? What about creating a new PR focusing on Z3 support and merging it into #7938?

@tohtana OK

I’ll probably wait until your AutoEP branch is merged into main before opening my PR, since my changes depend on your branch.

