
Fix qwen35 dp #4535

Open

grimoire wants to merge 3 commits into InternLM:main from grimoire:fix-qwen35-tp

Conversation

@grimoire Collaborator

dp/cudagraph might pad state_ids with -1, which would be clamped to 0 in the model.

State 0 is reserved for dummy inputs; multiple dummy inputs might write to the same state, leading to invalid output (nan/inf).
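The failure mode can be sketched in plain Python (values are illustrative, not the actual lmdeploy buffers):

```python
# cudagraph padding fills unused batch slots with -1; clamping then maps
# every padded slot onto the reserved state 0.
state_ids = [3, 7, -1, -1]                  # last two slots are padding
clamped = [max(s, 0) for s in state_ids]    # [3, 7, 0, 0]: both pads alias state 0
# Two dummy entries now write to the same state slot, which can corrupt it
# and surface later as nan/inf. The fix keeps the negatives so the kernel
# can skip them instead:
valid = [s >= 0 for s in state_ids]         # kernel ignores slots where s < 0
```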

@grimoire grimoire requested a review from yao-fengchen April 18, 2026 08:58
Contributor

Copilot AI left a comment


Pull request overview

This PR addresses NaN/Inf issues seen with Qwen3.5 under DP + CUDA graph execution by preventing padded/invalid state_ids from colliding with the reserved dummy state, and by hardening a few related kernels/buffer initializations.

Changes:

  • Preserve negative state_ids for gated-delta decoding and update the CUDA kernel to ignore invalid (<0) states.
  • Adjust CUDA graph buffer initialization to fill both Q/KV seqlens padding consistently to avoid flash-attn OOB.
  • Introduce a MoE all-reduce wrapper and route Qwen3.5 MoE output reduction through it; add a small div-by-zero guard in split-K reduce.
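The CUDA graph seqlen change in the list above can be sketched in plain Python (buffer names and the padding value are illustrative stand-ins for the real lmdeploy buffers):

```python
# Fill the padded tail of BOTH the q and kv seqlen buffers with the same safe
# value, so flash-attn never reads a stale or garbage length past the real batch.
max_batches = 8
real_seqlens = [5, 7, 2]                 # lengths of the live requests
pad = max_batches - len(real_seqlens)
q_seqlens = real_seqlens + [1] * pad     # padding = 1: a harmless, valid length
kv_seqlens = real_seqlens + [1] * pad    # kv side padded identically
```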

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
lmdeploy/pytorch/nn/moe/base.py Adds build_moe_all_reduce() and MoEAllReduce helper for post-MoE reduction/shared-TP enablement.
lmdeploy/pytorch/nn/gated_delta.py Keeps original (possibly negative) state_ids and uses them for decoding-state indices.
lmdeploy/pytorch/models/utils/cudagraph.py Initializes both q/kv seqlens padding via qkv_seqlens to prevent kernel OOB.
lmdeploy/pytorch/models/qwen3_5_moe.py Uses new MoE all-reduce helper and updates shared expert TP selection.
lmdeploy/pytorch/kernels/cuda/pagedattention.py Adds epsilon to avoid div-by-zero in _reduce_split_kernel.
lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py Skips loading initial state when state_id < 0.
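The gated_delta_rule.py guard in the table can be sketched as a plain-Python stand-in for the Triton kernel (names and the state layout are illustrative):

```python
# A negative state_id marks a padded/dummy slot: skip the initial-state load
# and start from a zero state instead of touching the shared state 0.
def load_initial_state(states, state_id, dim=4):
    if state_id < 0:                 # padded slot: never read/alias state 0
        return [0.0] * dim
    return states[state_id]

states = {0: [9.0] * 4, 3: [1.0, 2.0, 3.0, 4.0]}
```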
Comments suppressed due to low confidence (1)

lmdeploy/pytorch/models/qwen3_5_moe.py:106

  • After switching to self.moe_all_reduce(...), the old DP/TP gating block that sets self._all_reduce (and the dp/world_size locals) is now unused, and the new reduction path is unconditional. Once MoEAllReduce is fixed to use the right group/enable conditions, please either delete the dead _all_reduce logic or wire its condition into build_moe_all_reduce() so behavior matches the previous (dp==1 and moe_tp>1) guard.
        self.moe_all_reduce = self.experts.build_moe_all_reduce()

        self.shared_expert = Qwen3_5MLP(
            config=config,
            intermediate_size=config.shared_expert_intermediate_size,
            dtype=dtype,
            device=device,
            is_tp=self.moe_all_reduce.enable_shared_tp(),
            all_reduce=False,
            prefix=add_prefix('shared_expert', prefix),
        )
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)

        # get all reduce
        dist_ctx = get_dist_manager().current_context()
        dp = dist_ctx.dist_config.dp
        world_size = dist_ctx.dist_config.moe_tp
        if dp == 1 and world_size > 1:
            self._all_reduce = True
        else:
            self._all_reduce = False



Comment on lines +351 to +355
def forward(self, x: torch.Tensor):
    """forward."""
    if self._all_reduce:
        dist.all_reduce(x)
    return x

Copilot AI Apr 22, 2026


MoEAllReduce.forward() calls dist.all_reduce(x) without specifying a group. In lmdeploy.pytorch.distributed.all_reduce, the default group is 'tp' which maps to the attention TP group, not the MoE TP group, so this can reduce across the wrong ranks when attn_tp != moe_tp. Please pass the correct MoE TP process group (e.g., from DistContext.moe_tp_group.gpu_group) into MoEAllReduce and use it in the all-reduce call.
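A minimal sketch of the wiring this comment asks for, with `dist_fn` standing in for `torch.distributed.all_reduce` so the group plumbing can be shown without an initialized process group (the group name here is purely illustrative):

```python
# Pass the MoE TP group explicitly instead of relying on the default
# ('tp' / attention) group, so the reduce spans the MoE TP ranks only.
class MoEAllReduce:
    def __init__(self, enable, group, dist_fn):
        self.enable = enable
        self.group = group        # e.g. DistContext.moe_tp_group.gpu_group
        self.dist_fn = dist_fn    # torch.distributed.all_reduce in real code

    def forward(self, x):
        if self.enable:
            self.dist_fn(x, group=self.group)
        return x

calls = []
reducer = MoEAllReduce(True, "moe_tp_group", lambda x, group: calls.append(group))
out = reducer.forward([1.0])
```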

Comment thread lmdeploy/pytorch/nn/moe/base.py Outdated

This is only used when dp==1 and tp>1, and fused moe module does not perform all_reduce
"""
return MoEAllReduce(not self.all_reduce, self.tp, self.tp_mode)

Copilot AI Apr 22, 2026


build_moe_all_reduce()/MoEAllReduce currently enables all-reduce based on not self.all_reduce + tp_mode, but it does not enforce the documented dp==1 condition. In DP runs, an extra all-reduce here would incorrectly mix data-parallel replicas. Please include a DP guard (e.g., check get_dist_manager().current_context().dist_config.dp) in the enable logic, or require the caller to pass an explicit enable flag derived from (dp==1 and moe_tp>1 and fused_moe_not_reducing).

Suggested change
-        return MoEAllReduce(not self.all_reduce, self.tp, self.tp_mode)
+        dist_ctx = get_dist_manager().current_context()
+        dp = dist_ctx.dist_config.dp
+        enable = (dp == 1) and (not self.all_reduce)
+        return MoEAllReduce(enable, self.tp, self.tp_mode)

Comment on lines 602 to +605
    if sinks_ptr is not None:
        sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype)
        l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max)
-   acc = acc / l_sum
+   acc = acc / (l_sum + 1e-10)

Copilot AI Apr 22, 2026


The div-by-zero guard was added in _reduce_split_kernel, but _fused_reduce_hadamard_kernel in the same file still does acc = acc / l_sum and can hit the same l_sum==0 case (e.g., when all splits are masked out). Consider applying a consistent safeguard (e.g., clamp l_sum to an epsilon or conditional divide) to the fused kernel as well to avoid NaN/Inf in TURBO_QUANT paths.
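The effect of the epsilon guard, in a plain-Python sketch (values are illustrative; in the Triton kernel IEEE division would silently produce nan where plain Python raises):

```python
import math

# When every split is masked out, the accumulated weight l_sum is 0.0, so
# acc / l_sum is 0/0 -> nan in the kernel. Adding a tiny epsilon keeps the
# quotient finite (and exactly 0.0 when acc is 0.0).
acc, l_sum = 0.0, 0.0
guarded = acc / (l_sum + 1e-10)
```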


3 participants