diff --git a/lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py b/lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py index 1774416328..0720b0958e 100644 --- a/lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py +++ b/lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py @@ -203,7 +203,7 @@ def fused_recurrent_gated_delta_rule_main( T.annotate_layout({h_smem: tilelang.layout.make_swizzled_layout(h_smem)}) for i, j in T.Parallel(K, v_per_cta): v_idx = v_start * v_per_cta + j - if v_idx < V: + if v_idx < V and state_id >= 0: h_smem[i, j] = State[state_id, state_seq_id, hv_id, i, v_idx] else: h_smem[i, j] = 0.0 diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 1a279b96cb..a4f0bcef99 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -602,7 +602,7 @@ def _reduce_split_kernel( if sinks_ptr is not None: sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype) l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max) - acc = acc / l_sum + acc = acc / (l_sum + 1e-10) out_offs = (cur_batch * stride_obs + cur_head * stride_oh + offs_dv * stride_od) tl.store(out_ptr + out_offs, acc, mask=mask_dv) @@ -716,7 +716,7 @@ def _fused_reduce_hadamard_kernel( sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype) l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max) - acc = acc / l_sum + acc = acc / (l_sum + 1e-10) # Walsh-Hadamard butterfly via acc buffer as float32 scratch scratch_base = cur_batch * stride_abs + cur_head * stride_ah diff --git a/lmdeploy/pytorch/models/qwen3_5_moe.py b/lmdeploy/pytorch/models/qwen3_5_moe.py index 112b9e116b..4fee786d07 100644 --- a/lmdeploy/pytorch/models/qwen3_5_moe.py +++ b/lmdeploy/pytorch/models/qwen3_5_moe.py @@ -3,7 +3,6 @@ from collections.abc import Iterable import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig @@ 
-83,12 +82,14 @@ def __init__(self, prefix=add_prefix('experts', prefix), ) + self.moe_all_reduce = self.experts.build_moe_all_reduce() + self.shared_expert = Qwen3_5MLP( config=config, intermediate_size=config.shared_expert_intermediate_size, dtype=dtype, device=device, - is_tp=is_tp, + is_tp=self.moe_all_reduce.enable_shared_tp(), all_reduce=False, prefix=add_prefix('shared_expert', prefix), ) @@ -122,8 +123,7 @@ def forward(self, hidden_states: torch.Tensor, all_routed_experts: torch.Tensor out_states += shared_states out_states = out_states.reshape(batch_size, sequence_length, -1) - if self._all_reduce: - dist.all_reduce(out_states) + out_states = self.moe_all_reduce(out_states) return out_states diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index c7c59edb6d..3082dc9075 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -229,7 +229,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens)) input_buffers['qkv_lens'].zero_() - input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs) + # initialize q_seqlens and kv_seqlens to max_tokens // max_batchs + # to avoid out of bound in flash attention kernels + # padding kv should be the same as padding q so q-kv=0 + input_buffers['qkv_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs) input_buffers['qkv_lens'][:, :batch_size] = qkv input_buffers['cu_seqlens'][:, 1:] = input_buffers['qkv_seqlens'].cumsum(1) if inputs_embeds is not None: diff --git a/lmdeploy/pytorch/nn/gated_delta.py b/lmdeploy/pytorch/nn/gated_delta.py index b57ac6cd95..2ddbcd7f75 100644 --- a/lmdeploy/pytorch/nn/gated_delta.py +++ b/lmdeploy/pytorch/nn/gated_delta.py @@ -73,6 +73,8 @@ def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tens self.conv_state_indices = state_ids.to(torch.int32) # we 
assume 0 is dummy state, shared by all invalid states. + self.valid_state = state_ids >= 0 + # keep state_ids < 0 so we can ignore invalid init state + self.origin_state_ids = state_ids self.state_ids = state_ids.clamp(0) @@ -241,7 +243,7 @@ def __call__( initial_state=recurrent_state, output_final_state=True, use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, - state_indices=state_ids, + state_indices=gated_delta_meta.origin_state_ids, cache_seqlens=cache_seqlens, ) # out (seqlen, B, ...) -> (1, seqlen * B, ...) diff --git a/lmdeploy/pytorch/nn/moe/base.py b/lmdeploy/pytorch/nn/moe/base.py index 76f1927d46..c47b6854de 100644 --- a/lmdeploy/pytorch/nn/moe/base.py +++ b/lmdeploy/pytorch/nn/moe/base.py @@ -319,3 +319,46 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ def renormalize(self, topk_weights): """renormalize.""" return _renormalize(topk_weights, self.do_renormalize) + + def build_moe_all_reduce(self): + """Build moe all reduce. + + This is only used when dp==1 and tp>1, and fused moe module does not perform all_reduce + """ + dist_ctx = get_dist_manager().current_context() + dp = dist_ctx.dist_config.dp + enable = (dp == 1) and (not self.all_reduce) + return MoEAllReduce(enable, self.tp, self.tp_mode) + + +class MoEAllReduce(nn.Module): + + def __init__(self, enable: bool, moe_tp: int, tp_mode: TPMode): + super().__init__() + enable_moe_tp = moe_tp > 1 + if tp_mode == TPMode.DEFAULT and enable_moe_tp: + # else, shared expert should have the same tp as moe layer + # + self._enable_shared_tp = enable_moe_tp + self._all_reduce = enable and enable_moe_tp + else: + # do not support shared layer to perform tp + # do not perform all reduce here + self._enable_shared_tp = False + self._all_reduce = False + + if self._all_reduce: + dist_ctx = get_dist_manager().current_context() + self.group = dist_ctx.moe_tp_group.gpu_group + else: + self.group = None + + def enable_shared_tp(self): + """Shared tp.""" + return self._enable_shared_tp
+ + def forward(self, x: torch.Tensor): + """forward.""" + if self._all_reduce: + dist.all_reduce(x, group=self.group) + return x