Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def fused_recurrent_gated_delta_rule_main(
T.annotate_layout({h_smem: tilelang.layout.make_swizzled_layout(h_smem)})
for i, j in T.Parallel(K, v_per_cta):
v_idx = v_start * v_per_cta + j
if v_idx < V:
if v_idx < V and state_id >= 0:
h_smem[i, j] = State[state_id, state_seq_id, hv_id, i, v_idx]
else:
h_smem[i, j] = 0.0
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def _reduce_split_kernel(
if sinks_ptr is not None:
sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype)
l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max)
acc = acc / l_sum
acc = acc / (l_sum + 1e-10)
Comment thread
grimoire marked this conversation as resolved.

out_offs = (cur_batch * stride_obs + cur_head * stride_oh + offs_dv * stride_od)
tl.store(out_ptr + out_offs, acc, mask=mask_dv)
Expand Down Expand Up @@ -716,7 +716,7 @@ def _fused_reduce_hadamard_kernel(
sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype)
l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max)

acc = acc / l_sum
acc = acc / (l_sum + 1e-10)

# Walsh-Hadamard butterfly via acc buffer as float32 scratch
scratch_base = cur_batch * stride_abs + cur_head * stride_ah
Expand Down
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/models/qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Iterable

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -83,12 +82,14 @@ def __init__(self,
prefix=add_prefix('experts', prefix),
)

self.moe_all_reduce = self.experts.build_moe_all_reduce()

self.shared_expert = Qwen3_5MLP(
config=config,
intermediate_size=config.shared_expert_intermediate_size,
dtype=dtype,
device=device,
is_tp=is_tp,
is_tp=self.moe_all_reduce.enable_shared_tp(),
all_reduce=False,
prefix=add_prefix('shared_expert', prefix),
)
Expand Down Expand Up @@ -122,8 +123,7 @@ def forward(self, hidden_states: torch.Tensor, all_routed_experts: torch.Tensor
out_states += shared_states
out_states = out_states.reshape(batch_size, sequence_length, -1)

if self._all_reduce:
dist.all_reduce(out_states)
out_states = self.moe_all_reduce(out_states)
return out_states


Expand Down
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/utils/cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p

qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))
input_buffers['qkv_lens'].zero_()
input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)
# initialize q_seqlens and kv_seqlens to max_tokens // max_batchs
# to avoid out of bound in flash attention kernels
# padding kv should be the same as padding q so q-kv=0
input_buffers['qkv_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)
input_buffers['qkv_lens'][:, :batch_size] = qkv
input_buffers['cu_seqlens'][:, 1:] = input_buffers['qkv_seqlens'].cumsum(1)
if inputs_embeds is not None:
Expand Down
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/nn/gated_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tens
self.conv_state_indices = state_ids.to(torch.int32)
# we assume 0 is dummy state, shared by all invalid states.
self.valid_state = state_ids >= 0
# keep state_ids < 0 so we can ignore invalid init state
self.origin_state_ids = state_ids
self.state_ids = state_ids.clamp(0)


Expand Down Expand Up @@ -241,7 +243,7 @@ def __call__(
initial_state=recurrent_state,
output_final_state=True,
use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,
state_indices=state_ids,
state_indices=gated_delta_meta.origin_state_ids,
cache_seqlens=cache_seqlens,
)
# out (seqlen, B, ...) -> (1, seqlen * B, ...)
Expand Down
43 changes: 43 additions & 0 deletions lmdeploy/pytorch/nn/moe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,46 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_
def renormalize(self, topk_weights):
"""renormalize."""
return _renormalize(topk_weights, self.do_renormalize)

def build_moe_all_reduce(self):
    """Build the post-MoE all-reduce module for this fused-MoE layer.

    Only enables the reduce when dp == 1 and this module was constructed
    with ``all_reduce=False`` (i.e. the fused MoE kernel itself does not
    perform the tensor-parallel all-reduce, so it must happen afterwards).

    Returns:
        MoEAllReduce: module that conditionally all-reduces the MoE output
        over the MoE tensor-parallel group.
    """
    dist_ctx = get_dist_manager().current_context()
    dp = dist_ctx.dist_config.dp
    # enable only if dp==1 and the fused-MoE forward skipped its own reduce;
    # with dp > 1 the reduce is presumably handled elsewhere — TODO confirm.
    enable = (dp == 1) and (not self.all_reduce)
    return MoEAllReduce(enable, self.tp, self.tp_mode)


class MoEAllReduce(nn.Module):
    """Conditional all-reduce applied to fused-MoE layer outputs.

    Decides at construction time whether an all-reduce over the MoE
    tensor-parallel group is required, and whether the shared expert may
    use the same tensor parallelism as the MoE layer (so both can share a
    single reduce). When disabled, :meth:`forward` is a no-op pass-through.
    """

    def __init__(self, enable: bool, moe_tp: int, tp_mode: TPMode):
        """
        Args:
            enable (bool): caller-side request to perform the all-reduce
                (e.g. dp == 1 and the fused MoE did not reduce internally).
            moe_tp (int): MoE tensor-parallel world size; tp > 1 means the
                output is partial and needs reducing.
            tp_mode (TPMode): tensor-parallel mode of the MoE layer.
        """
        super().__init__()
        enable_moe_tp = moe_tp > 1
        if tp_mode == TPMode.DEFAULT and enable_moe_tp:
            # Default TP: the shared expert can use the same tp as the MoE
            # layer, and a single all-reduce here covers both outputs.
            self._enable_shared_tp = enable_moe_tp
            self._all_reduce = enable and enable_moe_tp
        else:
            # Non-default TP mode (or tp == 1): the shared expert is not
            # allowed to run TP, and no all-reduce is performed here.
            self._enable_shared_tp = False
            self._all_reduce = False

        if self._all_reduce:
            # Reduce over the MoE TP group, not the (possibly different)
            # attention TP group.
            dist_ctx = get_dist_manager().current_context()
            self.group = dist_ctx.moe_tp_group.gpu_group
        else:
            self.group = None

    def enable_shared_tp(self):
        """Return True if the shared expert should use the MoE layer's tp."""
        return self._enable_shared_tp

    def forward(self, x: torch.Tensor):
        """All-reduce ``x`` in place over the MoE TP group when enabled."""
        if self._all_reduce:
            dist.all_reduce(x, group=self.group)
        return x
Comment on lines +360 to +364
Copy link

Copilot AI Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MoEAllReduce.forward() calls dist.all_reduce(x) without specifying a group. In lmdeploy.pytorch.distributed.all_reduce, the default group is 'tp' which maps to the attention TP group, not the MoE TP group, so this can reduce across the wrong ranks when attn_tp != moe_tp. Please pass the correct MoE TP process group (e.g., from DistContext.moe_tp_group.gpu_group) into MoEAllReduce and use it in the all-reduce call.

Copilot uses AI. Check for mistakes.
Loading