diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 55ffaf6f52b..d65d437fbfb 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -61,7 +61,7 @@ def add_llm_args(parser): default='AUTO', choices=[ 'AUTO', 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM', - 'CUTEDSL', 'TRITON' + 'CUTEDSL', 'TRITON', 'DENSEGEMM' ], help= 'MoE backend to use. AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases.' @@ -203,6 +203,19 @@ def add_llm_args(parser): parser.add_argument('--relaxed_topk', type=int, default=1) parser.add_argument('--relaxed_delta', type=float, default=0.) + # CuTe DSL + parser.add_argument( + '--use_cute_dsl_bf16_bmm', + default=False, + action='store_true', + help='Use CuTe DSL bf16 persistent GEMM for BMM on Blackwell.') + parser.add_argument( + '--use_cute_dsl_bf16_gemm', + default=False, + action='store_true', + help='Use CuTe DSL bf16 persistent GEMM for Linear layers on Blackwell.' + ) + # HF parser.add_argument('--trust_remote_code', default=False, @@ -331,6 +344,8 @@ def setup_llm(args, **kwargs): gather_generation_logits=args.return_generation_logits, max_beam_width=args.max_beam_width, orchestrator_type=args.orchestrator_type, + use_cute_dsl_bf16_bmm=args.use_cute_dsl_bf16_bmm, + use_cute_dsl_bf16_gemm=args.use_cute_dsl_bf16_gemm, **kwargs) use_beam_search = args.max_beam_width > 1 diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index f91d331dbe5..2f647f66a0c 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -336,6 +336,8 @@ def get_dense_gemm_approximate_cta_nums( Sm100BlockwiseGemmKernel from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \ Sm100BlockScaledPersistentDenseGemmKernel + from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import \ + PersistentDenseGemmKernel from ..cute_dsl_kernels.blackwell.moe_as_dense_gemm.fc1 import \ Sm100BlockScaledPersistentDenseGemmKernel as DenseGemmSwigluKernel from ..cute_dsl_kernels.blackwell.top_k.filtered_top_k_decode_varlen import \ @@ -715,7 +717,7 @@ def forward( max_active_clusters, stream, swap_ab, - options=f"--opt-level 2 --enable-tvm-ffi" + options="--opt-level 2 --enable-tvm-ffi" if self.use_tvm_ffi else "--opt-level 2", ) @@ -2597,7 +2599,7 @@ def forward( c_cute_tensor, max_active_clusters=max_active_clusters, stream=stream, - options=f"--opt-level 2 --enable-tvm-ffi" + options="--opt-level 2 --enable-tvm-ffi" if self.use_tvm_ffi else "--opt-level 2", ) self.__class__.kernel_cache[cache_key] = compiled_gemm @@ -2911,7 +2913,7 @@ def forward( c_cute_tensor, max_active_clusters=max_active_clusters, stream=stream, - options=f"--opt-level 2 --enable-tvm-ffi" + options="--opt-level 2 --enable-tvm-ffi" if self.use_tvm_ffi else "--opt-level 2", ) self.__class__.kernel_cache[cache_key] = compiled_gemm @@ -3042,14 +3044,21 @@ def __init__( weight_per_expert: int, output_dtype: torch.dtype, scaling_vector_size: int = 16, + sm_budget: int = -1, ): super().__init__() self.expert_count = expert_count self.weight_per_expert = weight_per_expert self.output_dtype = output_dtype self.scaling_vector_size = scaling_vector_size + self.sm_budget = sm_budget def unique_id(self): + # sm_budget is intentionally excluded: inner tuning is performed once + # with the full SM budget (sm_budget=0) and the resulting tactic is + # reused across all GreenContext SM splits. The actual sm_budget is + # still applied at kernel-execution time (max_active_clusters), so + # runtime throughput correctly reflects the partition size. return ( self.expert_count, self.weight_per_expert, @@ -3234,15 +3243,24 @@ def forward( self.scaling_vector_size, self.expert_count, alpha_post is not None, # Whether alpha_post is enabled - self. - output_dtype, # Include output dtype to avoid cache collision + self.sm_budget, + c_cutlass_dtype, # output dtype affects compiled kernel ) if cache_key not in self.__class__.kernel_cache: - # Get max active clusters only when compiling kernel + # Get max active clusters only when compiling kernel. + # self.sm_budget is in SM units; convert to cluster count by + # dividing by the number of SMs per cluster. + sms_per_cluster = cluster_shape_mn[0] * cluster_shape_mn[1] hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1]) + max_active_clusters_hw = hardware_info.get_max_active_clusters( + sms_per_cluster) + if self.sm_budget > 0: + constrained = max(1, self.sm_budget // sms_per_cluster) + max_active_clusters = min(max_active_clusters_hw, + constrained) + else: + max_active_clusters = max_active_clusters_hw kernel = self.kernel_class( sf_vec_size=self.scaling_vector_size, @@ -3313,6 +3331,7 @@ def cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( weight_per_expert: int, output_dtype: torch.dtype, scaling_vector_size: int = 16, + sm_budget: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """Dense GEMM with SwiGLU fusion for MoE FC1 layer. @@ -3332,15 +3351,50 @@ def cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( weight_per_expert: Number of weight columns per expert output_dtype: Output data type (bfloat16 or float16) scaling_vector_size: Block scaling vector size (default: 16) + sm_budget: Number of physical SMs available for this kernel (-1 = unconstrained). + The actual max_active_clusters is derived from the best tactic's cluster_shape_mn: + max_active_clusters = sm_budget // (cluster_m * cluster_n) + where cluster_m * cluster_n is the number of SMs per software cluster. Returns: Tuple of (output, output_scale_factor) """ - runner = CuteDSLNVFP4DenseGemmSwigluRunner( + # Auto-detect SM budget from the current stream's GreenContext when not + # provided explicitly. If the op is dispatched inside + # ``torch.cuda.stream(gc_stream)`` where *gc_stream* was created by + # ``cuGreenCtxStreamCreate``, the stream carries a CUgreenCtx whose SM + # partition size is used as the budget. Falls back to -1 (unconstrained) + # for plain streams or when the Driver API is unavailable. + if sm_budget == -1: + from tensorrt_llm._torch.modules.fused_moe.green_context import \ + get_current_stream_gc_sm_count + detected = get_current_stream_gc_sm_count() + if detected > 0: + sm_budget = detected + # print( + # f"Auto-detected SM budget from GreenContext: {sm_budget} SMs" + # ) + + # unique_id() intentionally excludes sm_budget, so all GC splits share + # the same autotuner cache entry. The cache is pre-warmed once with + # sm_budget=0 (full SM) by DenseGEMMGCSMRunner.do_preparation before + # the outer profiling loop. Subsequent calls with any sm_budget find a + # cache hit and skip profiling; the actual sm_budget is still applied + # inside forward() when deriving max_active_clusters for execution. + tuning_runner = CuteDSLNVFP4DenseGemmSwigluRunner( expert_count=expert_count, weight_per_expert=weight_per_expert, output_dtype=output_dtype, scaling_vector_size=scaling_vector_size, + sm_budget=-1, + ) + + exec_runner = CuteDSLNVFP4DenseGemmSwigluRunner( + expert_count=expert_count, + weight_per_expert=weight_per_expert, + output_dtype=output_dtype, + scaling_vector_size=scaling_vector_size, + sm_budget=sm_budget, ) inputs = [ @@ -3351,12 +3405,16 @@ def cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( tuner = AutoTuner.get() _, best_tactic = tuner.choose_one( "trtllm::cute_dsl_nvfp4_dense_gemm_swiglu_blackwell", - [runner], - runner.get_tuning_config(), + [tuning_runner], + tuning_runner.get_tuning_config(), inputs, ) - output, output_sf = runner(inputs, tactic=best_tactic) + # print(f"{best_tactic=}") + + # Phase 2: execute with the SM budget reflecting the actual runtime + # constraint (Green Context partition or unconstrained). + output, output_sf = exec_runner(inputs, tactic=best_tactic) return output, output_sf @torch.library.register_fake( @@ -3373,6 +3431,7 @@ def _( weight_per_expert: int, output_dtype: torch.dtype, scaling_vector_size: int = 16, + sm_budget: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: # weight: [num_expert, weight_per_expert, k//2] (fp4 packed) m = input.shape[0] @@ -3455,8 +3514,8 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs, - ) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]: - """Return valid (mma_tiler_mn, cluster_shape_mn) combinations.""" + ) -> List[Tuple[Tuple[int, int], Tuple[int, int], int]]: + """Return valid (mma_tiler_mn, cluster_shape_mn, split_k) combinations.""" # Check SM version - only supports SM 100 and SM 103 major, minor = torch.cuda.get_device_capability() if not (major == 10 and minor in [0, 3]): @@ -3470,9 +3529,11 @@ def get_valid_tactics( n = b.shape[0] l = 1 # dense GEMM - # Define candidates together - mma_tiler_mn_candidates = [(128, 64), (128, 128), (128, 256)] - cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4)] + # Define candidates + mma_tiler_mn_candidates = [(128, 64), (128, 128), (128, 256), + (256, 128)] + cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4), (2, 1)] + split_k_candidates = [1, 2, 4] # Map torch dtype to cutlass dtype if self.output_dtype not in self._CUTLASS_DTYPE_MAP: @@ -3481,6 +3542,9 @@ def get_valid_tactics( ) c_cutlass_dtype = self._CUTLASS_DTYPE_MAP[self.output_dtype] + # MMA tile K size for split-K divisibility check + _MMA_TILE_K = 256 + tactics = [] for mma_tiler_mn, cluster_shape_mn in itertools.product( mma_tiler_mn_candidates, cluster_shape_mn_candidates): @@ -3501,7 +3565,15 @@ def get_valid_tactics( self.expert_count, self.weight_per_expert, ): - tactics.append((mma_tiler_mn, cluster_shape_mn)) + for split_k in split_k_candidates: + # K-tiles must be evenly divisible by split_k, + # and each split must contain whole experts. + k_tiles = k // _MMA_TILE_K + tiles_per_expert = self.weight_per_expert // _MMA_TILE_K + if (k_tiles % split_k == 0 and + (k_tiles // split_k) % tiles_per_expert == 0): + tactics.append( + (mma_tiler_mn, cluster_shape_mn, split_k)) return tactics @@ -3510,14 +3582,13 @@ def get_tuning_config(self) -> TuningConfig: if key not in self.tuning_config_cache: self.tuning_config_cache[key] = TuningConfig( dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2), ), + 0, 0, deep_gemm_gen_tuning_buckets), ), constraint_specs=( ConstraintSpec(2, 0, fp4_scale_infer_shape), ConstraintSpec(4, 0, lambda shapes: shapes[0][0]), ), use_cold_l2_cache=True, - tune_max_num_tokens=256, + tune_max_num_tokens=512, distributed_tuning_strategy=DistributedTuningStrategy. PARALLEL, ) @@ -3526,13 +3597,13 @@ def get_tuning_config(self) -> TuningConfig: def forward( self, inputs: List[torch.Tensor], - tactic: Optional[Tuple[Tuple[int, int], Tuple[int, int]]], + tactic: Optional[Tuple[Tuple[int, int], Tuple[int, int], int]], ) -> torch.Tensor: """Execute the dense GEMM FC2. Args: inputs: [a, b, a_sf, b_sf, alpha_scale] - tactic: ((mma_m, mma_n), (cluster_m, cluster_n)) + tactic: ((mma_m, mma_n), (cluster_m, cluster_n), split_k) Returns: Output tensor @@ -3547,14 +3618,21 @@ def forward( l = 1 # dense GEMM # Default tactic if not provided - if isinstance(tactic, tuple): + if isinstance(tactic, tuple) and len(tactic) == 3: + mma_tiler_mn, cluster_shape_mn, split_k = tactic + elif isinstance(tactic, tuple) and len(tactic) == 2: mma_tiler_mn, cluster_shape_mn = tactic + split_k = 1 else: - mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1) + mma_tiler_mn, cluster_shape_mn, split_k = (128, 128), (1, 1), 1 # Allocate output tensor c_dtype = self.output_dtype - c = torch.empty((m, n), dtype=c_dtype, device=a.device) + if split_k > 1: + # Atomic reduction accumulates onto C; must be zero-initialized + c = torch.zeros((m, n), dtype=c_dtype, device=a.device) + else: + c = torch.empty((m, n), dtype=c_dtype, device=a.device) # Get CUDA stream torch_stream = torch.cuda.current_stream() @@ -3599,6 +3677,7 @@ def forward( self.weight_per_expert, mma_tiler_mn, cluster_shape_mn, + split_k, self.scaling_vector_size, self. output_dtype, # Include output dtype to avoid cache collision @@ -3616,6 +3695,7 @@ def forward( cluster_shape_mn=cluster_shape_mn, expert_count=self.expert_count, weight_per_expert=self.weight_per_expert, + split_k=split_k, ) # Compile the kernel and cache it @@ -5127,3 +5207,627 @@ def warmup_cute_dsl_indexer_topk( f"Warmed up CuTE DSL indexer top-k kernels: dtype={dtype}, " f"SingleCTA bucketed_num_cols=[2^{min_seq_len_log2}..2^{max_seq_len_log2}], " f"{multi_cta_info}, top_k={top_k}, next_n={next_n}") + + # ====================================================================== + # BF16 Dense Persistent BMM (CuTe DSL) for Blackwell + # ====================================================================== + + class CuteDSLBf16BlackwellBmmRunner(TunableRunner): + kernel_class = PersistentDenseGemmKernel + kernel_cache = dict() + + tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( + 0, 1, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), ) + + def __init__(self, use_tvm_ffi: bool = True): + super().__init__() + self.use_tvm_ffi = use_tvm_ffi + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL BF16 BMM only supports SM 100 family. Skipping all tactics." + ) + return [] + # [b, m, k] + batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[ + 0].shape[2] + # [b, n, k] + n = inputs[1].shape[1] + # m,k + a_major = "k" + # n, k + b_major = "k" + # m, n + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if self.__class__.kernel_class.can_implement( + cutlass.BFloat16, # ab_dtype + cutlass.Float32, # acc_dtype + cutlass.BFloat16, # c_dtype + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> None: + """ + Performs bf16 dense persistent batched gemm using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (batch_size, m, k), dtype: bf16. + inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: bf16. + inputs[2]: Output tensor of shape (batch_size, m, n), dtype: bf16. + tactic: Tiling and cluster strategy, typically a tuple + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + """ + if isinstance(tactic, tuple): + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic + else: + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [ + False, + (128, 128), + (1, 1), + ] + + a_tensor, b_tensor, c_tensor = inputs + + # Permute C from [B, M, N] to [M, N, B] for CuTe layout. + # from_dlpack captures the actual strides, so non-contiguous + # views (e.g. from .transpose(0,1)) are handled natively by + # TMA without an extra copy. + c_tmp = c_tensor.permute(1, 2, 0) + + batch_size = a_tensor.shape[0] + m = a_tensor.shape[1] + k = a_tensor.shape[2] + n = b_tensor.shape[1] + + # Compute A strides so the kernel can handle non-contiguous + # views (e.g. [M,B,K].transpose(0,1) → [B,M,K] with + # non-standard strides) without a .contiguous() copy. + # CuTe tensor is (M, K, B) so strides map as: + # M stride = a_tensor.stride(1) + # K stride = 1 (always innermost) + # B stride = a_tensor.stride(0) + a_stride_m = a_tensor.stride(1) + a_stride_batch = a_tensor.stride(0) + + if not self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.BFloat16, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.BFloat16, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + cache_key = ( + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + self.use_tvm_ffi, + ) + if cache_key not in self.__class__.kernel_cache: + if self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.BFloat16, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.BFloat16, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + stream = cute.runtime.make_fake_stream( + use_tvm_ffi_env_stream=True) + + gemm = self.__class__.kernel_class( + cutlass.Float32, # acc_dtype + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + + compiled_gemm = cute.compile( + gemm.wrapper_strided, + m, + n, + k, + batch_size, + a_ptr, + b_ptr, + c_cute_tensor, + a_stride_m, + a_stride_batch, + max_active_clusters=max_active_clusters, + stream=stream, + options="--opt-level 2 --enable-tvm-ffi" + if self.use_tvm_ffi else "--opt-level 2", + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + # launch gemm kernel + if self.use_tvm_ffi: + compiled_gemm( + m, + n, + k, + batch_size, + a_tensor.data_ptr(), + b_tensor.data_ptr(), + c_tmp, + a_stride_m, + a_stride_batch, + ) + else: + compiled_gemm( + m, + n, + k, + batch_size, + a_ptr, + b_ptr, + c_cute_tensor, + a_stride_m, + a_stride_batch, + stream=stream, + ) + + # a/b: bf16, output: bf16 + @torch.library.custom_op("trtllm::cute_dsl_bf16_bmm_blackwell", + mutates_args=("output", ), + device_types="cuda") + def cute_dsl_bf16_bmm_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + use_tvm_ffi: bool = True, + ) -> None: + if not is_sm_100f(): + raise ValueError( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL BF16 BMM only supports SM 100 family.") + + tuner = AutoTuner.get() + + runner = CuteDSLBf16BlackwellBmmRunner(use_tvm_ffi=use_tvm_ffi) + + inputs = [input, weight, output] + + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_bf16_bmm_blackwell::gemm", + [runner], + runner.__class__.tuning_config, + inputs, + ) + runner(inputs, tactic=best_tactic) + + @torch.library.register_fake("trtllm::cute_dsl_bf16_bmm_blackwell") + def _( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + output: torch.Tensor, + use_tvm_ffi: bool = True, + ) -> None: + batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2] + n = mat_b.shape[1] + assert output.dtype == torch.bfloat16, "CuTe DSL bf16 bmm output dtype must be bf16" + assert output.shape == ( + batch_size, m, n), "CuTe DSL bf16 bmm output shape is incorrect" + + # ====================================================================== + # BF16 Dense Persistent GEMM (CuTe DSL) for Blackwell - Linear layers + # ====================================================================== + + class CuteDSLBf16BlackwellGemmRunner(TunableRunner): + """ + CuTe DSL BF16 GEMM runner for Linear layers. + + Unlike BMM which operates on [B, M, K] @ [B, N, K] -> [B, M, N], + GEMM operates on [M, K] @ [N, K]^T -> [M, N] (standard Linear). + + We reuse PersistentDenseGemmKernel with batch_size=1. + """ + kernel_class = PersistentDenseGemmKernel + kernel_cache = dict() + + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=( + # Output rows (tensor 2, dim 0) equal input rows (tensor 0, dim 0). + # Without this, when the autotuner sweeps sub-profiles with smaller + # num_tokens the output tensor stays at its original size, producing a + # distinct cache key for every (num_tokens_bucket, sub_profile) pair + # and causing O(N^2) redundant router-GEMM re-tunings. + ConstraintSpec(2, 0, lambda shapes: shapes[0][0]), ), + ) + + def __init__(self, + use_tvm_ffi: bool = True, + max_active_clusters: Optional[int] = None, + sm_budget: int = -1): + super().__init__() + self.use_tvm_ffi = use_tvm_ffi + self.max_active_clusters = max_active_clusters + self.sm_budget = sm_budget + + def unique_id(self): + # Exclude SM-count fields so that the autotuner cache is shared + # across different GreenContext partitions. Tuning is always done + # with an unconstrained runner (all SMs); the SM budget is applied + # only at execution time. + return (self.use_tvm_ffi, ) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL BF16 GEMM only supports SM 100 family. Skipping all tactics." + ) + return [] + + # input: [M, K], weight: [N, K], output: [M, N] + m, k = inputs[0].shape[0], inputs[0].shape[1] + n = inputs[1].shape[0] + batch_size = 1 + + # Detect output dtype from the output tensor (supports BF16 and FP32) + c_dtype_cutlass = _TORCH_TO_CUTLASS_DTYPE[inputs[2].dtype] + + # Layouts: A is [M, K] K-major, B is [N, K] K-major + a_major = "k" + b_major = "k" + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if self.__class__.kernel_class.can_implement( + cutlass.BFloat16, # ab_dtype + cutlass.Float32, # acc_dtype + c_dtype_cutlass, # c_dtype + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> None: + """ + Performs bf16 dense persistent GEMM using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (m, k), dtype: bf16. + inputs[1]: Weight tensor of shape (n, k), dtype: bf16. + inputs[2]: Output tensor of shape (m, n), dtype: bf16 or fp32. + tactic: Tiling and cluster strategy, typically a tuple + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + """ + if isinstance(tactic, tuple): + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic + else: + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [ + False, + (128, 128), + (1, 1), + ] + + a_tensor, b_tensor, c_tensor = inputs + + # Input: [M, K], Weight: [N, K], Output: [M, N] + m, k = a_tensor.shape[0], a_tensor.shape[1] + n = b_tensor.shape[0] + batch_size = 1 + + # Ensure inputs are contiguous + a_tensor = a_tensor.contiguous() + b_tensor = b_tensor.contiguous() + + # For output, use contiguous buffer if needed + c_needs_copy = not c_tensor.is_contiguous() + if c_needs_copy: + c_buf = torch.empty_like(c_tensor) + else: + c_buf = c_tensor + + # Reshape to [1, M, K], [1, N, K], [1, M, N] for the batched kernel + a_batched = a_tensor.unsqueeze(0) # [1, M, K] + b_batched = b_tensor.unsqueeze(0) # [1, N, K] + # c_buf is [M, N], permute to [M, N, 1] for cute layout + c_tmp = c_buf.unsqueeze(-1) # [M, N, 1] + + # Detect output dtype (supports BF16 and FP32) + c_dtype_cutlass = _TORCH_TO_CUTLASS_DTYPE[c_tensor.dtype] + + if not self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.BFloat16, + a_batched.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.BFloat16, + b_batched.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters_hw = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + if self.sm_budget > 0: + # sm_budget is in SM units; convert to cluster count based on + # the best tactic's actual cluster size (no hardcoded assumption). + cluster_sms = cluster_shape_mn[0] * cluster_shape_mn[1] + constrained = max(1, self.sm_budget // cluster_sms) + max_active_clusters = min(max_active_clusters_hw, constrained) + elif self.max_active_clusters is not None: + max_active_clusters = min(max_active_clusters_hw, + self.max_active_clusters) + else: + max_active_clusters = max_active_clusters_hw + + cache_key = ( + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + self.use_tvm_ffi, + c_dtype_cutlass, + max_active_clusters, + ) + if cache_key not in self.__class__.kernel_cache: + if self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.BFloat16, + a_batched.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.BFloat16, + b_batched.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + stream = cute.runtime.make_fake_stream( + use_tvm_ffi_env_stream=True) + + gemm = self.__class__.kernel_class( + cutlass.Float32, # acc_dtype + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + + compiled_gemm = cute.compile( + gemm.wrapper, + m, + n, + k, + batch_size, + a_ptr, + b_ptr, + c_cute_tensor, + max_active_clusters=max_active_clusters, + stream=stream, + options="--opt-level 2 --enable-tvm-ffi" + if self.use_tvm_ffi else "--opt-level 2", + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + # launch gemm kernel + if self.use_tvm_ffi: + compiled_gemm( + m, + n, + k, + batch_size, + a_batched.data_ptr(), + b_batched.data_ptr(), + c_tmp, + ) + else: + compiled_gemm( + m, + n, + k, + batch_size, + a_ptr, + b_ptr, + c_cute_tensor, + stream=stream, + ) + + # Copy result back if original output was non-contiguous + if c_needs_copy: + c_tensor.copy_(c_buf) + + # input: [M, K], weight: [N, K], output: [M, N] + @torch.library.custom_op("trtllm::cute_dsl_bf16_gemm_blackwell", + mutates_args=("output", ), + device_types="cuda") + def cute_dsl_bf16_gemm_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + use_tvm_ffi: bool = True, + max_active_clusters: int = 0, + sm_budget: int = -1, + ) -> None: + """ + CuTe DSL BF16 GEMM for Linear layers on Blackwell. + + Computes: output = input @ weight^T + - input: [M, K] (num_tokens, in_features) + - weight: [N, K] (out_features, in_features) + - output: [M, N] (num_tokens, out_features) + + Args: + max_active_clusters: Direct cluster count limit (0 = unconstrained). + sm_budget: Number of physical SMs available (-1 = unconstrained). + When > 0, a two-phase tune is performed: first the best cluster_shape_mn + is determined without constraint, then max_active_clusters is derived as + sm_budget // (cluster_m * cluster_n). Takes precedence over max_active_clusters. + """ + if not is_sm_100f(): + raise ValueError( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL BF16 GEMM only supports SM 100 family.") + + tuner = AutoTuner.get() + + # Auto-detect SM budget from the current stream's GreenContext. + if sm_budget == -1: + from tensorrt_llm._torch.modules.fused_moe.green_context import \ + get_current_stream_gc_sm_count + detected = get_current_stream_gc_sm_count() + if detected > 0: + sm_budget = detected + # print( + # f"Auto-detected SM budget from GreenContext: {sm_budget} SMs" + # ) + + # Tuning runner: unconstrained (all SMs). unique_id() intentionally + # excludes sm_budget so the autotuner cache is shared across all + # GreenContext partitions — the best tile/cluster shape is the same + # regardless of how many SMs are available. + tune_runner = CuteDSLBf16BlackwellGemmRunner(use_tvm_ffi=use_tvm_ffi) + + # Execution runner: applies the actual SM budget so that + # max_active_clusters is correctly constrained at launch time. + exec_runner = CuteDSLBf16BlackwellGemmRunner(use_tvm_ffi=use_tvm_ffi, + sm_budget=sm_budget) + + inputs = [input, weight, output] + + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_bf16_gemm_blackwell::gemm", + [tune_runner], + tune_runner.__class__.tuning_config, + inputs, + ) + + # print( + # f"Chosen tactic for CuTe DSL BF16 GEMM: {best_tactic}, sm_budget={exec_runner.sm_budget}" + # ) + + exec_runner(inputs, tactic=best_tactic) + + @torch.library.register_fake("trtllm::cute_dsl_bf16_gemm_blackwell") + def _( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + output: torch.Tensor, + use_tvm_ffi: bool = True, + max_active_clusters: int = 0, + sm_budget: int = -1, + ) -> None: + m, k = mat_a.shape[0], mat_a.shape[1] + n = mat_b.shape[0] + assert output.dtype in (torch.bfloat16, torch.float32), \ + "CuTe DSL bf16 gemm output dtype must be bf16 or fp32" + assert output.shape == ( + m, n), "CuTe DSL bf16 gemm output shape is incorrect" diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py new file mode 100644 index 00000000000..67e035a1d51 --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py @@ -0,0 +1,1061 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This file is copied and modified from cutlass example https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py + +from typing import Literal, Optional, Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from .custom_pipeline import PipelineTmaUmma, PipelineUmmaAsync +from .utils import ( + TRTLLM_ENABLE_PDL, + griddepcontrol_launch_dependents, + griddepcontrol_wait, + is_power_of_2, +) + + +def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + c_smem_layout: Union[cute.Layout, None], +) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics.""" + num_acc_stage = 2 + num_c_stage = 2 if use_tma_store else 0 + + a_smem_layout_stage_one = utils.sm100.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1) + b_smem_layout_staged_one = utils.sm100.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1) + + ab_bytes_per_stage = cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + cute.size_in_bytes( + b_dtype, b_smem_layout_staged_one + ) + mbar_helpers_bytes = 1024 + + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout) + c_bytes = c_bytes_per_stage * num_c_stage + + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + if use_tma_store: + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + +class PersistentDenseGemmKernel: + """Persistent batched dense GEMM (C = A x B) for Blackwell SM100 using CuTe DSL. + + Supports BF16/FP16 inputs with FP32 accumulator and BF16/FP16 output. + + Notes: + - A and B tensor must have the same data type. + - Supported A/B data types: Float16, BFloat16, TFloat32, Float8E4M3FN, + Float8E5M2, Int8, Uint8 + - Supported accumulator: Float32, Float16, Int32 + - MMA tiler M: 64/128 (1CTA) or 128/256 (2CTA) + - MMA tiler N: 32-256, step 32 + - Cluster M must be multiple of 2 if 2CTA + - Cluster M*N <= 16, positive powers of 2 + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool = True, + swizzle_size: int = 1, + raster_along: Literal["m", "n"] = "m", + ): + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + self.swizzle_size = swizzle_size + self.raster_along = raster_along + self.mma_tiler_mn = mma_tiler_mn + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + self.arch = "sm_100" + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilogue_warp_id) + ) + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + + def _create_tiled_mma(self): + return utils.sm100.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs.""" + tiled_mma = self._create_tiled_mma() + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + c_smem_layout = None + if cutlass.const_expr(self.use_tma_store): + c_smem_layout = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, 1 + ) + + self.smem_capacity = utils.get_smem_capacity_in_bytes() + + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = _compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.c_dtype, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + c_smem_layout, + ) + + self.a_smem_layout_staged = utils.sm100.make_smem_layout_a( + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage + ) + self.b_smem_layout_staged = utils.sm100.make_smem_layout_b( + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage + ) + + self.c_smem_layout_staged = None + if self.use_tma_store: + self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage + ) + + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation.""" + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + tiled_mma = self._create_tiled_mma() + self._setup_attributes() + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None), + ) + + # Setup TMA load for B + b_op = utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + self.swizzle_size, + self.raster_along, + max_active_clusters, + ) + + # Launch the kernel + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + use_pdl=TRTLLM_ENABLE_PDL, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """GPU device kernel performing the Persistent batched GEMM computation.""" + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch tma desc + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # Setup cta/thread coordinates + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + tidx, _, _ = cute.arch.thread_idx() + + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc_mbar: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ).make_participants() + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + _tmem_dealloc_barrier = None + if cutlass.const_expr(not self.use_tma_store): + _tmem_dealloc_barrier = pipeline.NamedBarrier( # noqa: F841 + barrier_id=self.tmem_dealloc_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # Setup smem tensor A/B + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + + # Compute multicast mask for A/B buffer full + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # Local_tile partition global tensors + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # Partition global tensor for TiledMMA_A/B/C + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + tCgC = thr_mma.partition_C(gC_mnl) + + # Partition global/shared tensor for TMA load A/B + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # Cluster wait before tensor memory alloc + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # Construct the scheduler + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + ) + work_tile = tile_sched.initial_work_tile_info() + + # PDL: Wait for previous grid to finish + griddepcontrol_wait() + + # Specialized TMA load warp + if warp_idx == self.tma_warp_id: + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) + + cute.copy( + tma_atom_a, + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + mcast_mask=b_full_mcast_mask, + ) + + peek_ab_empty_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + ab_producer.tail() + + # Specialized MMA warp + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if is_leader_cta: + peek_ab_full_status = ab_consumer.try_wait() + + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # Reset ACCUMULATE for each new output tile + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in range(k_tile_cnt): + if is_leader_cta: + handle = ab_consumer.wait_and_advance(peek_ab_full_status) + + # Inner loop over kblocks within each K tile. + # Set ACCUMULATE=True after first gemm call to + # avoid clearing the accumulator on each sub-MMA. + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_crd = (None, None, kblock_idx, handle.index) + cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_crd], tCrB[kblock_crd], tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + handle.release() + + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + acc_pipeline.producer_tail(acc_producer_state) + + sC = None + if cutlass.const_expr(self.use_tma_store): + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # Specialized epilogue warps + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # -- Epilogue partition setup (TMA store path) -- + assert cutlass.const_expr(self.use_tma_store) + assert tma_atom_c is not None and sC is not None + + # TMEM -> RMEM copy setup + copy_atom_t2r = utils.sm100.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + tAcc_epi = cute.flat_divide(tCtAcc_base[((None, None), 0, 0, None)], epi_tile) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_tAcc_base = thr_copy_t2r.partition_S(tAcc_epi) + + gC_mnl_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None, None)], epi_tile) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + # RMEM -> SMEM copy setup + copy_atom_r2s = utils.sm100.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + + # SMEM -> GMEM TMA store setup + sC_for_tma = cute.group_modes(sC, 0, 2) + gC_for_tma = cute.group_modes(gC_mnl_epi, 0, 2) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, 0, cute.make_layout(1), sC_for_tma, gC_for_tma + ) + + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + # -- Epilogue tile scheduling loop -- + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + num_tiles_executed = tile_sched.num_tiles_executed + + # Slice to per mma tile + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + acc_stage_index = acc_consumer_state.index + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Store accumulator to global memory in sub-tiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = (num_tiles_executed - 1) * subtile_cnt + + for subtile_idx in cutlass.range(subtile_cnt): + # Load accumulator from TMEM to RMEM + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Convert to output type and apply epilogue op + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # Store to SMEM + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + + # Fence and barrier + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilogue_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # TMA store from SMEM to GMEM + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)]) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # Release accumulator buffer + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Wait for C store complete and deallocate TMEM + c_pipeline.producer_tail() + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + # PDL: Launch dependent kernels + griddepcontrol_launch_dependents() + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + swizzle_size: int, + raster_along: Literal["m", "n"], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid size using persistent tile scheduler.""" + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl, swizzle_size, raster_along == "m" + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """Compute the number of tensor memory allocation columns.""" + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + return num_tmem_alloc_cols + + @staticmethod + def check_supported_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """Check if the dtypes are valid.""" + valid_ab_dtypes = { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + } + if a_dtype not in valid_ab_dtypes or b_dtype not in valid_ab_dtypes: + return False + if a_dtype != b_dtype: + return False + if acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: + return False + + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Float16: {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + if a_dtype not in acc_ab_compatibility.get(acc_dtype, set()): + return False + + acc_c_compatibility = { + cutlass.Float32: { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + cutlass.Float16: {cutlass.BFloat16, cutlass.Float16}, + cutlass.Int32: { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + } + if c_dtype not in acc_c_compatibility.get(acc_dtype, set()): + return False + + return True + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """Check if the mma tiler and cluster shape are valid.""" + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + return False + if mma_tiler_mn[1] not in range(32, 257, 32): + return False + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + return False + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + return False + return True + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + batch_size: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """Check if the tensor alignment is valid (contiguous dim 16-byte aligned).""" + + def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, batch_size)) + or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, batch_size)) + or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, batch_size)) + ): + return False + return True + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + batch_size: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """Check if the gemm can be implemented.""" + if not PersistentDenseGemmKernel.check_supported_dtypes( + ab_dtype, ab_dtype, acc_dtype, c_dtype + ): + return False + if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + return False + if not PersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, batch_size, ab_dtype, c_dtype, a_major, b_major, c_major + ): + return False + # Check epilogue store alignment for non-TMA store + # (TMA store handles OOB; we always use TMA store) + return True + + @cute.jit + def wrapper( + self, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + batch_size: cutlass.Int32, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Executes the wrapped GEMM kernel with dynamically shaped tensors. + + Args: + m: The M dimension of the GEMM problem. + n: The N dimension of the GEMM problem. + k: The K dimension of the GEMM problem. + batch_size: The batch dimension. + a_ptr: Pointer to the A tensor (M x K x batch_size, row-major K). + b_ptr: Pointer to the B tensor (N x K x batch_size, row-major K). + c_tensor: Output tensor as cute.Tensor for TVM FFI stream detection. + max_active_clusters: Maximum number of active clusters. + stream: CUDA stream for the operation. + """ + # m, k, batch_size with inner most dimension as k + a_tensor = cute.make_tensor( + a_ptr, + layout=cute.make_ordered_layout((m, k, batch_size), order=(1, 0, 2)), + ) + # n, k, batch_size with inner most dimension as k + b_tensor = cute.make_tensor( + b_ptr, + layout=cute.make_ordered_layout( + (n, k, batch_size), + order=(1, 0, 2), + ), + ) + + self( + a_tensor, + b_tensor, + c_tensor, + max_active_clusters, + stream, + ) + + @cute.jit + def wrapper_strided( + self, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + batch_size: cutlass.Int32, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_tensor: cute.Tensor, + a_stride_m: cutlass.Int32, + a_stride_batch: cutlass.Int32, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Executes the GEMM kernel with explicit A tensor strides. + + Like ``wrapper`` but allows non-contiguous A tensors by accepting + the M and batch strides directly. The K stride is assumed to be 1 + (row-major in K). B is always contiguous. + + Args: + m: The M dimension of the GEMM problem. + n: The N dimension of the GEMM problem. + k: The K dimension of the GEMM problem. + batch_size: The batch dimension. + a_ptr: Pointer to the A tensor data. + b_ptr: Pointer to the B tensor data. + c_tensor: Output tensor as cute.Tensor. + a_stride_m: Stride of A along the M dimension (in elements). + a_stride_batch: Stride of A along the batch dimension (in elements). + max_active_clusters: Maximum number of active clusters. + stream: CUDA stream for the operation. + """ + # A with explicit strides: (M, K, batch_size), K stride = 1 + a_tensor = cute.make_tensor( + a_ptr, + layout=cute.make_layout( + (m, k, batch_size), + stride=(a_stride_m, 1, a_stride_batch), + ), + ) + # B is always contiguous: (N, K, batch_size) with K innermost + b_tensor = cute.make_tensor( + b_ptr, + layout=cute.make_ordered_layout( + (n, k, batch_size), + order=(1, 0, 2), + ), + ) + + self( + a_tensor, + b_tensor, + c_tensor, + max_active_clusters, + stream, + ) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py index 8a26748c4e8..fe35021064b 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py @@ -24,6 +24,13 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.cute.nvgpu import cpasync, tcgen05 +from ..utils import ( + atomic_add_func, + vectorized_atomic_add_bf16x8, + vectorized_atomic_add_fp16x8, + vectorized_atomic_add_fp32x2, +) + """ This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel @@ -159,6 +166,7 @@ def __init__( weight_per_expert: int, use_prefetch: bool = False, prefetch_dist: int = 3, + split_k: int = 1, ): """Initializes the configuration for a Blackwell dense GEMM kernel. @@ -207,15 +215,11 @@ def __init__( ) self.mma_warp_id = 4 self.tma_warp_id = 5 - self.alpha_scale_load_warp_id = 6 - self.dummy_warp_id = 7 self.threads_per_cta = 32 * len( ( self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id, - self.alpha_scale_load_warp_id, - self.dummy_warp_id, ) ) # Set barrier id for cta sync, epilogue sync and tmem ptr sync @@ -240,6 +244,7 @@ def __init__( self.use_prefetch = use_prefetch self.prefetch_dist = prefetch_dist + self.split_k = split_k def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -332,23 +337,43 @@ def _setup_attributes(self): self.c_dtype, ) + # Atomic add parameters for split-K epilogue + if self.split_k > 1: + epi_tile_m = cute.size(self.epi_tile[0]) + epi_tile_n = cute.size(self.epi_tile[1]) + num_epilogue_threads = 32 * len(self.epilog_warp_id) + self.ttr_racc_size = (epi_tile_m * epi_tile_n) // num_epilogue_threads + if self.c_dtype in (cutlass.Float16, cutlass.BFloat16): + self.epi_layout_atomic = cute.make_layout( + shape=(self.ttr_racc_size // 8, 4, 2), stride=(8, 2, 1) + ) + self.epi_loop_size_atomic = self.ttr_racc_size // 8 + self.element_offset_atomic = 8 + elif self.c_dtype == cutlass.Float32: + self.epi_layout_atomic = cute.make_layout( + shape=(self.ttr_racc_size // 2, 2), stride=(2, 1) + ) + self.epi_loop_size_atomic = self.ttr_racc_size // 2 + self.element_offset_atomic = 2 + else: + self.epi_layout_atomic = cute.make_layout(shape=(self.ttr_racc_size,), stride=(1,)) + self.epi_loop_size_atomic = self.ttr_racc_size + self.element_offset_atomic = 1 + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory - self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_alpha_scale_stage = ( - self._compute_stages( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.b_dtype, - self.epi_tile, - self.c_dtype, - self.c_layout, - self.sf_dtype, - self.sf_vec_size, - self.smem_capacity, - self.occupancy, - ) + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, ) - # Compute A/B/SFA/SFB/C shared memory layout self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( tiled_mma, @@ -381,17 +406,6 @@ def _setup_attributes(self): self.num_c_stage, ) - self.alpha_scale_smem_layout_staged = cute.make_layout( - ( - self.cta_tile_shape_mnk[0], - self.num_alpha_scale_stage, - ), - stride=( - self.num_alpha_scale_stage, - 1, - ), - ) - @cute.jit def __call__( self, @@ -546,12 +560,13 @@ def __call__( self.epi_tile, ) - # Compute grid size + # Compute grid size (inflated by split_k for split-K decomposition) self.tile_sched_params, grid = self._compute_grid( c_tensor, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters, + self.split_k, ) self.buffer_align_bytes = 1024 @@ -563,9 +578,6 @@ class SharedStorage: ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - alpha_scale_load_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.num_alpha_scale_stage * 2 - ] tmem_dealloc_mbar_ptr: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) @@ -596,13 +608,6 @@ class SharedStorage: cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], self.buffer_align_bytes, ] - # Alpha scale shared memory - sAlphaScale: cute.struct.Align[ - cute.struct.MemRange[ - cutlass.Float32, cute.cosize(self.alpha_scale_smem_layout_staged) - ], - 16, - ] self.shared_storage = SharedStorage @@ -627,11 +632,12 @@ class SharedStorage: self.b_smem_layout_staged, self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, - self.alpha_scale_smem_layout_staged, self.c_smem_layout_staged, self.epi_tile, self.tile_sched_params, epilogue_op, + c_tensor, + self.epi_layout_atomic if self.split_k > 1 else cute.make_layout(1), ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], @@ -664,11 +670,12 @@ def kernel( b_smem_layout_staged: cute.ComposedLayout, sfa_smem_layout_staged: cute.Layout, sfb_smem_layout_staged: cute.Layout, - alpha_scale_smem_layout_staged: cute.Layout, c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], epi_tile: cute.Tile, tile_sched_params: utils.PersistentTileSchedulerParams, epilogue_op: cutlass.Constexpr, + mC_raw: cute.Tensor, + epi_layout_atomic: cute.Layout, ): """ GPU device kernel performing the Persistent batched GEMM computation. @@ -738,24 +745,6 @@ def kernel( cta_layout_vmnk=cluster_layout_vmnk, ) - # Initialize alpha_scale_pipeline (barrier) and states - alpha_scale_pipeline_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - 32 * 1, # alpha_scale_load_warp_id threads - 32 * 1, - ) - alpha_scale_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - 32 * len(self.epilog_warp_id), # epilogue warps - 32 * len(self.epilog_warp_id), - ) - alpha_scale_pipeline = pipeline.PipelineCpAsync.create( - barrier_storage=storage.alpha_scale_load_mbar_ptr.data_ptr(), - num_stages=self.num_alpha_scale_stage, - producer_group=alpha_scale_pipeline_producer_group, - consumer_group=alpha_scale_pipeline_consumer_group, - ) - # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( storage.tmem_holding_buf, @@ -782,8 +771,6 @@ def kernel( sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) # (MMA, MMA_N, MMA_K, STAGE) sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) - # Alpha scale shared memory tensor - sAlphaScale = storage.sAlphaScale.get_tensor(alpha_scale_smem_layout_staged) # # Compute multicast mask for A/B/SFA/SFB buffer full @@ -839,7 +826,10 @@ def kernel( gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) - k_tile_cnt = cutlass.Int32(cute.size(gA_mkl, mode=[3])) + k_tile_total = cute.size(gA_mkl, mode=[3]) + # For split-K: each CTA processes k_tile_total // split_k K-tiles + k_tiles_per_split = k_tile_total // self.split_k + k_tile_cnt_local = cutlass.Int32(k_tiles_per_split) # # Partition global tensor for TiledMMA_A/B/C @@ -952,10 +942,17 @@ def kernel( while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx + + # Split-K: decompose L coord into batch_idx and split_k_idx + # Input tensors use batch_idx for L; K-tiles offset by k_start + # For split_k=1: batch_idx = coord[2], k_start = 0 + batch_idx = cur_tile_coord[2] // self.split_k + k_start = (cur_tile_coord[2] % self.split_k) * k_tile_cnt_local + mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], - cur_tile_coord[2], + batch_idx, ) # @@ -979,92 +976,100 @@ def kernel( # Prefetch logic: use_prefetch for both A&B, or explicit A-only/B-only if self.use_prefetch: # Prefetch both A and B (default behavior) - for k_tile in cutlass.range(0, min(prefetch_dist, k_tile_cnt), unroll=1): + for k_tile in cutlass.range(0, min(prefetch_dist, k_tile_cnt_local), unroll=1): # Prefetch both A and B (default behavior) cute.prefetch( tma_atom_a, - tAgA_slice[(None, k_tile)], + tAgA_slice[(None, k_tile + k_start)], ) cute.prefetch( tma_atom_b, - tBgB_slice[(None, k_tile)], + tBgB_slice[(None, k_tile + k_start)], ) cute.prefetch( tma_atom_sfa, - tAgSFA_slice[(None, k_tile)], + tAgSFA_slice[(None, k_tile + k_start)], ) cute.prefetch( tma_atom_sfb, - tBgSFB_slice[(None, k_tile)], + tBgSFB_slice[(None, k_tile + k_start)], ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: + if ab_producer_state.count < k_tile_cnt_local: peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Tma load loop # - for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + for k_tile in cutlass.range(0, k_tile_cnt_local, 1, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) - # TMA load A/B/SFA/SFB + # TMA load A/B/SFA/SFB (offset by k_start for split-K) cute.copy( tma_atom_a, - tAgA_slice[(None, ab_producer_state.count)], + tAgA_slice[(None, ab_producer_state.count + k_start)], tAsA[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB_slice[(None, ab_producer_state.count)], + tBgB_slice[(None, ab_producer_state.count + k_start)], tBsB[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=b_full_mcast_mask, ) cute.copy( tma_atom_sfa, - tAgSFA_slice[(None, ab_producer_state.count)], + tAgSFA_slice[(None, ab_producer_state.count + k_start)], tAsSFA[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=sfa_full_mcast_mask, ) cute.copy( tma_atom_sfb, - tBgSFB_slice[(None, ab_producer_state.count)], + tBgSFB_slice[(None, ab_producer_state.count + k_start)], tBsSFB[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=sfb_full_mcast_mask, ) # Prefetch logic in the loop: use_prefetch for both A&B, or explicit A-only/B-only - if k_tile < k_tile_cnt - prefetch_dist: + if k_tile < k_tile_cnt_local - prefetch_dist: if self.use_prefetch: # Prefetch both A and B (default behavior) cute.prefetch( tma_atom_a, - tAgA_slice[(None, ab_producer_state.count + prefetch_dist)], + tAgA_slice[ + (None, ab_producer_state.count + k_start + prefetch_dist) + ], ) cute.prefetch( tma_atom_b, - tBgB_slice[(None, ab_producer_state.count + prefetch_dist)], + tBgB_slice[ + (None, ab_producer_state.count + k_start + prefetch_dist) + ], ) cute.prefetch( tma_atom_sfa, - tAgSFA_slice[(None, ab_producer_state.count + prefetch_dist)], + tAgSFA_slice[ + (None, ab_producer_state.count + k_start + prefetch_dist) + ], ) cute.prefetch( tma_atom_sfb, - tBgSFB_slice[(None, ab_producer_state.count + prefetch_dist)], + tBgSFB_slice[ + (None, ab_producer_state.count + k_start + prefetch_dist) + ], ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: + if ab_producer_state.count < k_tile_cnt_local: peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # @@ -1078,101 +1083,6 @@ def kernel( # ab_pipeline.producer_tail(ab_producer_state) - # - # Specialized Alpha Scale Load warp - # - if warp_idx == self.alpha_scale_load_warp_id: - # - # Persistent tile scheduling loop - # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - - alpha_scale_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_alpha_scale_stage - ) - - # Setup copy atom for alpha scale loading - atom_copy = cute.make_copy_atom( - cute.nvgpu.cpasync.CopyG2SOp(), - malpha_scale_mnl.element_type, - num_bits_per_copy=malpha_scale_mnl.element_type.width, - ) - tiled_copy_alpha_scale = cute.make_tiled_copy_tv( - atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) - ) - thr_copy_alpha_scale = tiled_copy_alpha_scale.get_slice(cute.arch.lane_idx()) - - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler - cur_tile_coord = work_tile.tile_idx - - # Reset producer state for this tile - alpha_scale_producer_state.reset_count() - peek_alpha_scale_empty_status = cutlass.Boolean(1) - if alpha_scale_producer_state.count < k_tile_cnt: - peek_alpha_scale_empty_status = alpha_scale_pipeline.producer_try_acquire( - alpha_scale_producer_state - ) - - galpha_scale_mnl_current_tile = galpha_scale_mnl[ - None, cur_tile_coord[0], None, cur_tile_coord[2] - ] - - tAgAlphaScale = thr_copy_alpha_scale.partition_S(galpha_scale_mnl_current_tile) - tAsAlphaScale = thr_copy_alpha_scale.partition_D(sAlphaScale) - - # Load alpha scale for each k tile - for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): - # Calculate expert index for this k_tile - expert_idx = k_tile // (self.weight_per_expert // self.mma_tiler[2]) - - # Slice alpha scale for current tile and expert - tAgAlphaScale_slice = tAgAlphaScale[(None, None, expert_idx)] - tAsAlphaScale_slice = tAsAlphaScale[ - (None, None, alpha_scale_producer_state.index) - ] - - # Wait for alpha scale buffer empty - alpha_scale_pipeline.producer_acquire( - alpha_scale_producer_state, peek_alpha_scale_empty_status - ) - - num_iters = cute.size(tAgAlphaScale_slice, mode=[1]) - - # Load alpha scale from global to shared memory - for iter_idx in cutlass.range(num_iters, unroll_full=True): - iter_coord = (None, iter_idx) - pred = cutlass.Boolean( - 32 * iter_idx + cute.arch.lane_idx() < malpha_scale_mnl.shape[0] - ) - if pred: - cute.copy( - tiled_copy_alpha_scale, - tAgAlphaScale_slice[iter_coord], - tAsAlphaScale_slice[iter_coord], - ) - - # Commit and advance - alpha_scale_pipeline.producer_commit(alpha_scale_producer_state) - alpha_scale_producer_state.advance() - - # Peek next - peek_alpha_scale_empty_status = cutlass.Boolean(1) - if alpha_scale_producer_state.count < k_tile_cnt: - peek_alpha_scale_empty_status = alpha_scale_pipeline.producer_try_acquire( - alpha_scale_producer_state - ) - - # Advance to next tile - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - # Wait for alpha scale buffer empty - alpha_scale_pipeline.producer_tail(alpha_scale_producer_state) - # # Specialized MMA warp # @@ -1273,25 +1183,30 @@ def kernel( # Peek (try_wait) AB buffer full for k_tile = 0 ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + if ab_consumer_state.count < k_tile_cnt_local and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # - # Mma mainloop + # Mma mainloop with expert grouping + # Accumulate tiles_per_expert k-tiles per acc buffer to halve + # acc pipeline traffic (512 → 256 round-trips). + # For split-K: only process k_tiles_per_split K-tiles (a subset of experts). # - for k_tile in range(k_tile_cnt): - # Set tensor memory buffer for current tile - # (MMA, MMA_M, MMA_N) - tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + tiles_per_expert_mma = self.weight_per_expert // self.mma_tiler[2] + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + for k_tile in range(k_tiles_per_split): + is_first_of_expert = k_tile % tiles_per_expert_mma == 0 + is_last_of_expert = k_tile % tiles_per_expert_mma == tiles_per_expert_mma - 1 + + if is_first_of_expert: + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] if is_leader_cta: - # Wait for accumulator buffer empty(each kblock) - acc_pipeline.producer_acquire(acc_producer_state) + if is_first_of_expert: + acc_pipeline.producer_acquire(acc_producer_state) - # Conditionally wait for AB buffer full ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) - # Copy SFA/SFB from smem to tmem s2t_stage_coord = ( None, None, @@ -1312,12 +1227,9 @@ def kernel( tCtSFB_compact_s2t, ) - # - # Reset the ACCUMULATE field for each tile - # - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + if is_first_of_expert: + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB num_kblocks = cute.size(tCrA, mode=[2]) for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): kblock_coord = ( @@ -1327,7 +1239,6 @@ def kernel( ab_consumer_state.index, ) - # Set SFA/SFB tensor to tiled_mma sf_kblock_coord = (None, None, kblock_idx) tiled_mma.set( tcgen05.Field.SFA, @@ -1346,25 +1257,20 @@ def kernel( tCtAcc, ) - # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) - # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt: + if ab_consumer_state.count < k_tile_cnt_local: if is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) - # - # Async arrive accumulator buffer full - # - if is_leader_cta: - acc_pipeline.producer_commit(acc_producer_state) - acc_producer_state.advance() + if is_last_of_expert: + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() # # Advance to next tile @@ -1432,11 +1338,6 @@ def kernel( pipeline.PipelineUserType.Consumer, self.num_acc_stage ) - # Alpha scale consumer state - alpha_scale_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_alpha_scale_stage - ) - # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, @@ -1448,30 +1349,29 @@ def kernel( producer_group=c_producer_group, ) - # Create copy atom for loading alpha scale from smem - alpha_scale_copy_atom_s2r = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - cutlass.Float32, - num_bits_per_copy=32, - ) - tiled_copy_alpha_scale_s2r = cute.make_tiled_copy_tv( - alpha_scale_copy_atom_s2r, - cute.make_layout((32 * len(self.epilog_warp_id),)), - cute.make_layout((1,)), - ) - thr_copy_alpha_scale_s2r = tiled_copy_alpha_scale_s2r.get_slice(tidx) + tiles_per_expert = self.weight_per_expert // self.mma_tiler[2] + # For split-K: each CTA handles experts_per_split experts + experts_per_split = k_tiles_per_split // tiles_per_expert + m_total = malpha_scale_mnl.shape[0] while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx + + # Split-K: decompose L coord into batch_idx and split_k_idx + # For split_k=1: batch_idx = coord[2], expert_offset = 0 + batch_idx = cur_tile_coord[2] // self.split_k + expert_offset = (cur_tile_coord[2] % self.split_k) * experts_per_split + + # Use batch_idx for output L coordinate (correct for both split_k=1 and >1) mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], - cur_tile_coord[2], + batch_idx, ) # - # Slice to per mma tile index + # Slice to per mma tile index (for TMA store path) # # ((ATOM_V, REST_V), EPI_M, EPI_N) bSG_gC = bSG_gC_partitioned[ @@ -1486,54 +1386,33 @@ def kernel( # initialize the final accumulator tTR_rAcc_final.fill(0.0) - # Initialize alpha_scale consumer state for this tile - alpha_scale_consumer_state.reset_count() - peek_alpha_scale_full_status = cutlass.Boolean(1) - if alpha_scale_consumer_state.count < k_tile_cnt: - peek_alpha_scale_full_status = alpha_scale_pipeline.consumer_try_wait( - alpha_scale_consumer_state - ) + # Epilogue iterates over experts_per_split experts (not full expert_count) + num_experts_k = cutlass.Int32(experts_per_split) acc_consumer_state.reset_count() peek_acc_full_status = cutlass.Boolean(1) - if acc_consumer_state.count < k_tile_cnt: + if acc_consumer_state.count < num_experts_k: peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) - for k_tile in cutlass.range(k_tile_cnt): - # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + m_start = cur_tile_coord[0] * self.cta_tile_shape_mnk[0] + m_in_bounds = m_start < m_total + thread_in_bounds = epi_tidx < (m_total - m_start) if m_in_bounds else False + + # Prologue: prefetch alpha for first expert in this split + prefetched_alpha = cutlass.Float32(0.0) + if thread_in_bounds: + prefetched_alpha = galpha_scale_mnl[ + epi_tidx, cur_tile_coord[0], expert_offset, batch_idx + ] + + for expert_idx in cutlass.range(num_experts_k): tTR_tAcc = tTR_tAcc_base[ (None, None, None, None, None, acc_consumer_state.index) ] tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - # - # Update accumulator by scale factor in subtiles - # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - # for subtile_idx in cutlass.range_dynamic(subtile_cnt): - - # - # Wait for alpha scale buffer full - # - alpha_scale_pipeline.consumer_wait( - alpha_scale_consumer_state, peek_alpha_scale_full_status - ) - - # Load alpha scale from shared memory for current expert - alpha_scale_smem_slice = sAlphaScale[(None, alpha_scale_consumer_state.index)] - - tAsAlphaScale_slice = thr_copy_alpha_scale_s2r.partition_S( - alpha_scale_smem_slice - ) - current_alpha_scale_reg = cute.make_rmem_tensor( - tAsAlphaScale_slice.shape, cutlass.Float32 - ) - - cute.copy( - alpha_scale_copy_atom_s2r, tAsAlphaScale_slice, current_alpha_scale_reg - ) - current_alpha_scale = current_alpha_scale_reg[0] + current_alpha_scale = prefetched_alpha # # Wait for accumulator buffer full @@ -1559,16 +1438,6 @@ def kernel( final_vec = acc_vec * current_alpha_scale + final_vec tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) - # Release alpha scale buffer - alpha_scale_pipeline.consumer_release(alpha_scale_consumer_state) - alpha_scale_consumer_state.advance() - - # Peek next alpha scale - peek_alpha_scale_full_status = cutlass.Boolean(1) - if alpha_scale_consumer_state.count < k_tile_cnt: - peek_alpha_scale_full_status = alpha_scale_pipeline.consumer_try_wait( - alpha_scale_consumer_state - ) # # Async arrive accumulator buffer empty # @@ -1576,59 +1445,97 @@ def kernel( acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() + # Prefetch next expert's alpha after releasing acc buffer + next_expert = expert_idx + 1 + clamped_expert = ( + next_expert if next_expert < num_experts_k else num_experts_k - 1 + ) + prefetched_alpha = cutlass.Float32(0.0) + if thread_in_bounds: + prefetched_alpha = galpha_scale_mnl[ + epi_tidx, cur_tile_coord[0], expert_offset + clamped_expert, batch_idx + ] + peek_acc_full_status = cutlass.Boolean(1) - if acc_consumer_state.count < k_tile_cnt: + if acc_consumer_state.count < num_experts_k: peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) # # Store accumulator to global memory in subtiles # - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - subtile_cnt = cute.size(tTR_rAcc_final.shape, mode=[3]) - num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt - for subtile_idx in cutlass.range(subtile_cnt): - # - # Store accumulator to shared memory or global memory - # - tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] + if cutlass.const_expr(self.split_k > 1): # - # Convert to C type + # Split-K atomic add path: each CTA atomically adds its + # partial result directly to the output C tensor. + # Guard with M bounds check to avoid OOB writes when + # M is not a multiple of tile_M. # - acc_vec = tiled_copy_r2s.retile(tTR_rAcc_subtile).load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tRS_rC.store(acc_vec) - + rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout_atomic) + m_coord = cur_tile_coord[0] * self.cta_tile_shape_mnk[0] + epi_tidx + scatter_base = cute.domain_offset((m_coord, 0, batch_idx), mC_raw) + + if thread_in_bounds: + for subtile_idx in cutlass.range(subtile_cnt): + tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] + acc_vec = tTR_rAcc_subtile.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + base_coord_n = cur_tile_coord[1] * self.cta_tile_shape_mnk[ + 1 + ] + subtile_idx * cute.size(tTR_rC) + + for index in cutlass.range(self.epi_loop_size_atomic, unroll_full=True): + coord_n = base_coord_n + index * self.element_offset_atomic + scatter_out = cute.domain_offset((0, coord_n, 0), scatter_base) + if cutlass.const_expr(self.c_dtype == cutlass.Float16): + vectorized_atomic_add_fp16x8( + rOut_epi[index, None, None], scatter_out + ) + elif cutlass.const_expr(self.c_dtype == cutlass.BFloat16): + vectorized_atomic_add_bf16x8( + rOut_epi[index, None, None], scatter_out + ) + elif cutlass.const_expr(self.c_dtype == cutlass.Float32): + vectorized_atomic_add_fp32x2(rOut_epi[index, None], scatter_out) + else: + atomic_add_func(rOut_epi[index], scatter_out) + else: # - # Store C to shared memory + # Standard TMA store path (split_k=1) # - c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage - cute.copy( - tiled_copy_r2s, - tRS_rC, - tRS_sC[(None, None, None, c_buffer)], - ) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) - self.epilog_sync_barrier.arrive_and_wait() + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] - # - # TMA store C to global memory - # - if warp_idx == self.epilog_warp_id[0]: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc_subtile).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage cute.copy( - tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], ) - # Fence and barrier to make sure shared memory store is visible to TMA store - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - self.epilog_sync_barrier.arrive_and_wait() + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() # # Advance to next tile # @@ -1642,9 +1549,10 @@ def kernel( self.epilog_sync_barrier.arrive_and_wait() tmem.free(acc_tmem_ptr) # - # Wait for C store complete + # Wait for C store complete (only needed for TMA store path) # - c_pipeline.producer_tail() + if cutlass.const_expr(self.split_k <= 1): + c_pipeline.producer_tail() def mainloop_s2t_copy_and_partition( self, @@ -1856,8 +1764,8 @@ def _compute_stages( sf_vec_size: int, smem_capacity: int, occupancy: int, - ) -> Tuple[int, int, int, int]: - """Computes the number of stages for A/B/C/Alpha_Scale operands based on heuristics. + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. :param tiled_mma: The tiled MMA object defining the core computation. :type tiled_mma: cute.TiledMma @@ -1883,22 +1791,17 @@ def _compute_stages( :type occupancy: int :return: A tuple containing the computed number of stages for: - (ACC stages, A/B operand stages, C stages, Alpha_Scale stages) - :rtype: tuple[int, int, int, int] + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] """ - # ACC stages - num_acc_stage = 2 + # ACC stages: 3 for N=128 to allow MMA to run ahead of epilogue + num_acc_stage = 3 if mma_tiler_mnk[1] == 256: num_acc_stage = 1 - # else if mma_tiler_mnk[1] == 64: - # num_acc_stage = 6 # Default C stages num_c_stage = 2 - # Default Alpha scale stages - num_alpha_scale_stage = 10 - # Calculate smem layout and size for one stage of A, B, SFA, SFB and C a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( tiled_mma, @@ -1942,21 +1845,12 @@ def _compute_stages( c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) c_bytes = c_bytes_per_stage * num_c_stage - # Alpha scale shared memory - # hardcode cta tile shape - alpha_bytes = cute.size_in_bytes( - cutlass.Float32, - cute.make_layout( - (mma_tiler_mnk[0] // tiled_mma.thr_id.shape, num_alpha_scale_stage), - stride=(num_alpha_scale_stage, 1), - ), - ) # Calculate A/B/SFA/SFB stages: # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B/SFA/SFB stage num_ab_stage = ( - smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes + alpha_bytes) + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage # Refine epilogue stages: @@ -1965,10 +1859,10 @@ def _compute_stages( num_c_stage += ( smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - - occupancy * (mbar_helpers_bytes + c_bytes + alpha_bytes) + - occupancy * (mbar_helpers_bytes + c_bytes) ) // (occupancy * c_bytes_per_stage) - return num_acc_stage, num_ab_stage, num_c_stage, num_alpha_scale_stage + return num_acc_stage, num_ab_stage, num_c_stage @staticmethod def _compute_grid( @@ -1976,6 +1870,7 @@ def _compute_grid( cta_tile_shape_mnk: Tuple[int, int, int], cluster_shape_mn: Tuple[int, int], max_active_clusters: cutlass.Constexpr, + split_k: int = 1, ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: """Use persistent tile scheduler to compute the grid size for the output tensor C. @@ -1987,6 +1882,8 @@ def _compute_grid( :type cluster_shape_mn: tuple[int, int] :param max_active_clusters: Maximum number of active clusters. :type max_active_clusters: cutlass.Constexpr + :param split_k: Split-K factor to inflate L dimension. + :type split_k: int :return: A tuple containing: - tile_sched_params: Parameters for the persistent tile scheduler. @@ -1996,6 +1893,8 @@ def _compute_grid( c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) gc = cute.zipped_divide(c, tiler=c_shape) num_ctas_mnl = gc[(0, (None, None, None))].shape + if split_k > 1: + num_ctas_mnl = (num_ctas_mnl[0], num_ctas_mnl[1], num_ctas_mnl[2] * split_k) cluster_shape_mnl = (*cluster_shape_mn, 1) tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) @@ -2246,6 +2145,13 @@ def wrapper( :param stream: CUDA stream :type stream: cuda.CUstream """ + # Cast Int64 → Int32 so all derived tensor shapes and cute.size() calls stay in 32-bit, + # which is required by cutlass.range(). The wrapper accepts Int64 for API compatibility + # but practical tensor dimensions always fit in Int32. + m = cutlass.Int32(m) + n = cutlass.Int32(n) + k = cutlass.Int32(k) + l = cutlass.Int32(l) # noqa: E741 scale_k = k // scaling_vector_size # Create A tensor (M, K, L) - K-major diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py index 98f9294d1dc..760be6d7d14 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py @@ -256,6 +256,28 @@ def vectorized_atomic_add_bf16x8(rOut_epi_packed, ) +@dsl_user_op +def vectorized_atomic_add_fp16x8(rOut_epi_packed, + scatter_out_offset, + loc=None, + ip=None): + llvm.inline_asm( + None, + [ + scatter_out_offset.iterator.llvm_ptr, + llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()), + llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()), + llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()), + llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()), + ], + "red.global.v4.f16x2.add.noftz [$0], {$1, $2, $3, $4};", + "l,r,r,r,r", + has_side_effects=True, + loc=loc, + ip=ip, + ) + + @dsl_user_op def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, @@ -304,6 +326,19 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None): loc=loc, ip=ip, ) + elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float16): + llvm.inline_asm( + None, + [ + scatter_out_offset.iterator.llvm_ptr, + llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()), + ], + "red.add.noftz.f16 [$0], $1;", + "l,h", + has_side_effects=True, + loc=loc, + ip=ip, + ) @dsl_user_op diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index f9988fc9442..83ec6820a42 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -126,6 +126,8 @@ class ModelConfig(Generic[TConfig]): # cute dsl op configs use_cute_dsl_blockscaling_mm: bool = False use_cute_dsl_blockscaling_bmm: bool = False + use_cute_dsl_bf16_bmm: bool = False + use_cute_dsl_bf16_gemm: bool = False _frozen: bool = field(default=False, init=False, repr=False) @@ -294,7 +296,7 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir, json_extended_quant_configs = json.load(fm) except Exception: logger.info( - f"No quant_cfg.json found for layer quant info, using hf_quant_config.json." + "No quant_cfg.json found for layer quant info, using hf_quant_config.json." ) json_quant_configs.update(json_extended_quant_configs) # kv_cache_quant_algo is global regardless of MIXED_PRECISION diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 138c6b8af6d..3153b4ee30e 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -41,7 +41,7 @@ from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ ConsumableWeightsDict -from tensorrt_llm._utils import get_sm_version +from tensorrt_llm._utils import get_sm_version, is_sm_100f from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -577,6 +577,25 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, "up_proj": "w3", "gate_proj": "w1", }) + # For DenseGEMM with fused shared expert: inject shared expert + # weights as the last expert (index num_experts). + grandparent_name = '.'.join( + names[:-2]) # e.g. "model.layers.3.mlp" + parent_mlp = all_named_modules.get(grandparent_name) + if getattr(parent_mlp, '_fuse_shared_expert', False): + shared_prefix = grandparent_name + ".shared_experts" + shared_raw = filter_weights(shared_prefix, weights) + shared_renamed = rename_moe_weight( + shared_raw, { + "down_proj": "w2", + "up_proj": "w3", + "gate_proj": "w1", + }) + shared_idx = module.num_experts - 1 + for k, v in shared_renamed.items(): + module_weights[f"{shared_idx}.{k}"] = v + if can_mark_consumed: + weights.mark_consumed(shared_prefix) module.load_weights(weights=[module_weights]) # Mark consumed MoE weights using parent name if can_mark_consumed: @@ -820,8 +839,10 @@ def __init__( fuse_routing_kernel: bool = True, apply_routing: bool = False, moe_backend: str = 'CUTLASS', + use_cute_dsl_bf16_gemm: bool = False, ): super().__init__() + self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), dtype=dtype), requires_grad=False) @@ -846,10 +867,24 @@ def __init__( is_fused=fuse_routing_kernel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states, - self.weight.t(), - bias=None, - out_dtype=torch.float32) + if (self.use_cute_dsl_bf16_gemm and is_sm_100f() + and self.weight.dtype == torch.bfloat16): + input_2d = hidden_states.view(-1, hidden_states.shape[-1]) + m, k = input_2d.shape + n = self.weight.shape[0] + output = torch.empty(m, + n, + dtype=torch.float32, + device=hidden_states.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + input_2d.contiguous(), self.weight, output) + logits = output.view(*hidden_states.shape[:-1], n) + else: + logits = torch.ops.trtllm.dsv3_router_gemm_op( + hidden_states, + self.weight.t(), + bias=None, + out_dtype=torch.float32) return logits def load_weights(self, weights: List[Dict]): @@ -906,18 +941,29 @@ def __init__(self, gate_cls = DeepseekV3Gate if hasattr(model_config.pretrained_config, "gate_cls"): gate_cls = model_config.pretrained_config.gate_cls - self.gate = gate_cls(hidden_size, - num_experts, - top_k=top_k, - n_group=config.n_group, - topk_group=config.topk_group, - routed_scaling_factor=config.routed_scaling_factor, - dtype=dtype, - fuse_routing_kernel=True, - apply_routing=False, - moe_backend=model_config.moe_backend) + self.gate = gate_cls( + hidden_size, + num_experts, + top_k=top_k, + n_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + dtype=dtype, + fuse_routing_kernel=True, + apply_routing=False, + moe_backend=model_config.moe_backend, + use_cute_dsl_bf16_gemm=model_config.use_cute_dsl_bf16_gemm) + + # For DenseGEMM, fuse the shared expert as the last expert in the dense + # GEMM kernel (index num_experts), always activated with scale=1.0. + # This avoids a separate GatedMLP forward and add operation. + self._fuse_shared_expert = model_config.moe_backend.upper( + ) == 'DENSEGEMM' + self.num_routed_experts = num_experts + fused_num_experts = num_experts + 1 if self._fuse_shared_expert else num_experts + self.experts = create_moe( - num_experts=num_experts, + num_experts=fused_num_experts, routing_method=self.gate.routing_method, hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -940,36 +986,42 @@ def __init__(self, self.mapping = model_config.mapping - shared_quant_config = self._get_shared_experts_quant_config( - model_config, layer_idx) - shared_model_config = model_config - if shared_quant_config is not model_config.quant_config: - shared_model_config = copy.copy(model_config) - shared_model_config.quant_config = shared_quant_config - - # For shared experts, use the block size implied by their quant config. - block_size = 1 - if (shared_quant_config is not None - and shared_quant_config.quant_algo is not None - and shared_quant_config.group_size is not None): - block_size = shared_quant_config.group_size - - shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( - shared_expert_intermediate_size, block_size) - - self.shared_experts = GatedMLP( - hidden_size=hidden_size, - intermediate_size=shared_expert_intermediate_size, - bias=False, - dtype=dtype, - config=shared_model_config, - overridden_tp_size=shared_tp_size, - reduce_output=False, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, - ) - self.shared_experts_use_fp4 = ( - shared_quant_config is not None - and shared_quant_config.layer_quant_mode.has_nvfp4()) + if self._fuse_shared_expert: + # Shared expert is fused into self.experts; no separate module needed. + self.shared_experts = None + self.shared_output_scale = None + self.shared_experts_use_fp4 = False + else: + shared_quant_config = self._get_shared_experts_quant_config( + model_config, layer_idx) + shared_model_config = model_config + if shared_quant_config is not model_config.quant_config: + shared_model_config = copy.copy(model_config) + shared_model_config.quant_config = shared_quant_config + + # For shared experts, use the block size implied by their quant config. + block_size = 1 + if (shared_quant_config is not None + and shared_quant_config.quant_algo is not None + and shared_quant_config.group_size is not None): + block_size = shared_quant_config.group_size + + shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( + shared_expert_intermediate_size, block_size) + + self.shared_experts = GatedMLP( + hidden_size=hidden_size, + intermediate_size=shared_expert_intermediate_size, + bias=False, + dtype=dtype, + config=shared_model_config, + overridden_tp_size=shared_tp_size, + reduce_output=False, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + ) + self.shared_experts_use_fp4 = ( + shared_quant_config is not None + and shared_quant_config.layer_quant_mode.has_nvfp4()) self.allreduce = None if not self.use_dp and self.mapping.tp_size > 1: @@ -1113,9 +1165,20 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4, num_experts=num_experts, device=hidden_states.device) + is_densegemm = self.gate.moe_backend.upper() == 'DENSEGEMM' + if is_densegemm: + # DenseGEMM handles routing and FP4 quantization internally. + # Pass BF16 hidden_states so quantize_input() works correctly, and + # provide router_weight_t for the internal routing GEMM. + expert_input = hidden_states + extra_kwargs = {"router_weight_t": self.gate.weight.t()} + else: + expert_input = (hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states) + extra_kwargs = {} + routed_output = self.experts( - hidden_states_fp4 - if hidden_states_fp4 is not None else hidden_states, + expert_input, router_logits, do_finalize=do_finalize, output_dtype=hidden_states.dtype, @@ -1124,6 +1187,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4, **({ "alltoall_result_do_sum": False } if isinstance(self.experts, WideEPMoE) else {}), + **extra_kwargs, ) return routed_output @@ -1139,6 +1203,19 @@ def forward( if not do_finalize: assert not self.use_dp + if self._fuse_shared_expert: + # Shared expert is fused into self.experts (DenseGEMM always-active + # expert at index num_routed_experts). No separate forward needed. + final_hidden_states = self.compute_routed_output( + hidden_states, hidden_states_fp4, all_rank_num_tokens, + do_finalize) + if not self.use_dp and self.mapping.tp_size > 1: + final_hidden_states = self.allreduce( + final_hidden_states, + all_reduce_params=final_all_reduce_params) + + return final_hidden_states + def _compute_shared_output(): shared_input = (hidden_states_fp4 if (hidden_states_fp4 is not None @@ -1799,7 +1876,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): # at the end of __init__. if model_config.mapping.has_cp_helix(): print( - f"[DeepseekV3ForCausalLM::__init__] Repurposing KVP ranks to TP while keeping other details the same." + "[DeepseekV3ForCausalLM::__init__] Repurposing KVP ranks to TP while keeping other details the same." ) self.mapping_with_cp = copy.deepcopy(model_config.mapping) # Repurpose KVP ranks to TP while keeping other details the same. @@ -1862,8 +1939,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): # Undo any manipulations done to mapping. if self.mapping_with_cp is not None: print( - f"[DeepseekV3ForCausalLM::__init__] Restoring original mapping." - ) + "[DeepseekV3ForCausalLM::__init__] Restoring original mapping.") model_config._frozen = False model_config.mapping = self.mapping_with_cp model_config._frozen = True diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 6946294259f..3dc1c585476 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -462,6 +462,8 @@ def __init__( self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + self.use_cute_dsl_bf16_bmm = config.use_cute_dsl_bf16_bmm + self.use_cute_dsl_bf16_gemm = config.use_cute_dsl_bf16_gemm qkv_shard_indices_mapping = { "q": (0, self.q_size * (2 if self.attn_output_gate else 1)), @@ -1219,6 +1221,8 @@ def __init__( self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + self.use_cute_dsl_bf16_bmm = config.use_cute_dsl_bf16_bmm + self.use_cute_dsl_bf16_gemm = config.use_cute_dsl_bf16_gemm if not self.is_lite: self.kv_a_proj_with_mqa = Linear( @@ -1230,7 +1234,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, force_dynamic_quantization=config.force_dynamic_quantization, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm) self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank, eps=rms_norm_eps, @@ -1247,7 +1252,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm) else: self.kv_a_proj_with_mqa = Linear( hidden_size, @@ -1258,7 +1264,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, force_dynamic_quantization=config.force_dynamic_quantization, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm) self.q_proj = Linear( self.q_lora_rank, @@ -1271,7 +1278,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm) self.q_b_proj = self.q_proj self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank, @@ -1289,7 +1297,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm) # This parameter will view into self.kv_b_proj.weight after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. # Used in forward_absorption only @@ -2381,9 +2390,14 @@ def forward_absorption_generation( # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim] # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank] # The output of bmm is written directly into fused_q + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + bmm_fn = lambda: torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + q_nope_t, self.k_b_proj_trans, q_nope_out) + else: + bmm_fn = lambda: torch.ops.trtllm.bmm_out( + q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out) maybe_execute_in_parallel( - lambda: torch.ops.trtllm.bmm_out( - q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out), + bmm_fn, lambda: self.mqa.mla_rope_generation( fused_q, q_pe, @@ -2480,9 +2494,14 @@ def forward_absorption_generation( if self.v_b_proj.dtype == torch.bfloat16: # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim] # -> [num_heads, seq, v_head_dim] - torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), - self.v_b_proj.transpose(1, 2), - attn_output.transpose(0, 1)) + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + attn_out_latent.transpose(0, 1), self.v_b_proj, + attn_output.transpose(0, 1)) + else: + torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), + self.v_b_proj.transpose(1, 2), + attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: fp8_block_scaling_bmm_out( attn_out_latent, @@ -2533,9 +2552,13 @@ def forward_absorption_context( # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim] # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank] # The output of bmm is written directly into fused_q - torch.ops.trtllm.bmm_out(q_nope_t, - self.k_b_proj_trans.transpose(1, 2), - q_nope_out) + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + q_nope_t, self.k_b_proj_trans, q_nope_out) + else: + torch.ops.trtllm.bmm_out(q_nope_t, + self.k_b_proj_trans.transpose(1, 2), + q_nope_out) elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn: # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) @@ -2592,9 +2615,14 @@ def forward_absorption_context( if self.v_b_proj.dtype == torch.bfloat16: # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim] # -> [num_heads, seq, v_head_dim] - torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), - self.v_b_proj.transpose(1, 2), - attn_output.transpose(0, 1)) + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + attn_out_latent.transpose(0, 1), self.v_b_proj, + attn_output.transpose(0, 1)) + else: + torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), + self.v_b_proj.transpose(1, 2), + attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: fp8_block_scaling_bmm_out( attn_out_latent, @@ -2657,9 +2685,13 @@ def forward_sparse_mla_kvcache_bf16( # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim] # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank] # The output of bmm is written directly into fused_q - torch.ops.trtllm.bmm_out(q_nope_t, - self.k_b_proj_trans.transpose(1, 2), - q_nope_out) + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + q_nope_t, self.k_b_proj_trans, q_nope_out) + else: + torch.ops.trtllm.bmm_out(q_nope_t, + self.k_b_proj_trans.transpose(1, 2), + q_nope_out) elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn: # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = q_nope_out.transpose(0, 1) @@ -2737,9 +2769,14 @@ def forward_sparse_mla_kvcache_bf16( if self.v_b_proj.dtype == torch.bfloat16: # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim] # -> [num_heads, seq, v_head_dim] - torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), - self.v_b_proj.transpose(1, 2), - attn_output.transpose(0, 1)) + if self.use_cute_dsl_bf16_bmm and is_sm_100f(): + torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell( + attn_out_latent.transpose(0, 1), self.v_b_proj, + attn_output.transpose(0, 1)) + else: + torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), + self.v_b_proj.transpose(1, 2), + attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: fp8_block_scaling_bmm_out( attn_out_latent, diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index beff454807f..647ba9c56e1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -452,7 +452,7 @@ def _create_comm_strategy_auto(self) -> Communication: def forward_impl( self, x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, + router_logits: Optional[torch.Tensor], *, do_finalize: bool = True, output_dtype: Optional[torch.dtype] = None, @@ -492,6 +492,10 @@ def forward_impl( self.determine_communication_method(all_rank_num_tokens_padded, num_chunks) # ========== Step 3: Execute MoE computation ========== + router_weight_t = kwargs.get("router_weight_t") + if router_weight_t is None and kwargs.get("router_weight") is not None: + router_weight_t = kwargs["router_weight"].t() + if num_chunks == 1: # Single chunk case outputs = self._forward_single_chunk( @@ -501,6 +505,7 @@ def forward_impl( all_rank_num_tokens_padded, use_dp_padding, do_finalize, + router_weight_t=router_weight_t, ) else: # Multiple chunks case @@ -512,6 +517,7 @@ def forward_impl( all_rank_num_tokens_padded, use_dp_padding, do_finalize, + router_weight_t=router_weight_t, ) # DWDP: record compute and trigger next prefetch (per-layer, not per-chunk) @@ -565,11 +571,12 @@ def _prepare_workspace_deepgemm( def _forward_single_chunk( self, x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, + router_logits: Optional[torch.Tensor], output_dtype: Optional[torch.dtype], all_rank_num_tokens: List[int], use_dp_padding: Optional[bool], do_finalize: bool = True, + router_weight_t: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Single chunk execution path @@ -593,6 +600,7 @@ def _forward_single_chunk( is_last_call, do_finalize, workspace=workspace, + router_weight_t=router_weight_t, ) return outputs @@ -600,7 +608,7 @@ def _forward_single_chunk( def _forward_chunk_impl( self, x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, + router_logits: Optional[torch.Tensor], output_dtype: Optional[torch.dtype], all_rank_num_tokens: List[int], use_dp_padding: bool, @@ -608,6 +616,7 @@ def _forward_chunk_impl( is_last_call: bool, do_finalize: bool = True, workspace: Optional[dict] = None, + router_weight_t: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Unified execution flow for all backends @@ -632,36 +641,58 @@ def _forward_chunk_impl( # ========== Step 2: Apply routing (only if backend supports load balancer) ========== if self.backend._supports_load_balancer(): - # Separated routing: ConfigurableMoE calls routing_method - token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) - - # Convert to standard dtypes for consistency with other MoE implementations - token_selected_experts = token_selected_experts.to(torch.int32) - - assert token_selected_experts.shape[1] == self.routing_method.experts_per_token - assert token_selected_experts.shape == token_final_scales.shape - # CutlassFusedMoE and DenseGEMMFusedMoE expect float32, while TRTLLMGenFusedMoE uses bfloat16 - if isinstance(self.backend, (CutlassFusedMoE, DenseGEMMFusedMoE)): - assert token_final_scales.dtype == torch.float32 - assert token_selected_experts.dtype == torch.int32 - - # Convert token_final_scales to bfloat16 if needed (TRTLLMGen backend requires it) - if token_final_scales is not None and isinstance(self.backend, TRTLLMGenFusedMoE): - token_final_scales = token_final_scales.to(torch.bfloat16) - - # Apply router weight on input if enabled - if self.apply_router_weight_on_input: - assert x.dtype != torch.float8_e4m3fn, ( - "Current workaround for apply_router_weight_on_input does not support fp8 input" + if isinstance(self.backend, DenseGEMMFusedMoE): + # DenseGEMM always uses internal routing path. + assert isinstance(self.backend, DenseGEMMFusedMoE), ( + "router_logits=None is only supported by DenseGEMMFusedMoE." ) - x = x * token_final_scales.to(x.dtype) - # TODO: remove this once we have correct fusedmoe kernel ready - # Check if using DeepEP strategies (they don't support token_final_scales=None) - if isinstance(self.comm, (DeepEP, DeepEPLowLatency)): - # DeepEP doesn't support token_final_scales is None - token_final_scales = torch.ones_like(token_final_scales) - else: - token_final_scales = None + assert router_weight_t is not None, ( + "router_weight_t (or router_weight) is required for DenseGEMMFusedMoE." + ) + assert self.comm is None, ( + "DenseGEMM internal routing with router_weight_t currently only supports non-communication path." + ) + assert not self.apply_router_weight_on_input, ( + "apply_router_weight_on_input is not supported with DenseGEMM internal routing." + ) + token_selected_experts = None + token_final_scales = None + else: + assert router_logits is not None, ( + f"router_logits must be provided for backend {self.backend.__class__.__name__}." + ) + # Separated routing: ConfigurableMoE calls routing_method + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits + ) + + # Convert to standard dtypes for consistency with other MoE implementations + token_selected_experts = token_selected_experts.to(torch.int32) + + assert token_selected_experts.shape[1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + # CutlassFusedMoE expects float32, while TRTLLMGenFusedMoE uses bfloat16 + if isinstance(self.backend, CutlassFusedMoE): + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + # Convert token_final_scales to bfloat16 if needed (TRTLLMGen backend requires it) + if token_final_scales is not None and isinstance(self.backend, TRTLLMGenFusedMoE): + token_final_scales = token_final_scales.to(torch.bfloat16) + + # Apply router weight on input if enabled + if self.apply_router_weight_on_input: + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + ) + x = x * token_final_scales.to(x.dtype) + # TODO: remove this once we have correct fusedmoe kernel ready + # Check if using DeepEP strategies (they don't support token_final_scales=None) + if isinstance(self.comm, (DeepEP, DeepEPLowLatency)): + # DeepEP doesn't support token_final_scales is None + token_final_scales = torch.ones_like(token_final_scales) + else: + token_final_scales = None else: # Fused routing: Backend handles routing internally @@ -739,6 +770,9 @@ def _forward_chunk_impl( eplb_dispatch_kwargs["eplb_local_stats"] = local_statistic_tensor_for_dispatch should_update_eplb_after_dispatch = True + # Keep pre-quant hidden states for DenseGEMM internal routing. + densegemm_router_input = x if self.backend.__class__ == DenseGEMMFusedMoE else None + # ========== Step 4 & 5: Quantization and Communication Dispatch ========== # Order depends on whether strategy supports post-quant dispatch if self.comm is not None: @@ -810,7 +844,14 @@ def _forward_chunk_impl( token_final_scales=token_final_scales, x_sf=x_sf, **self._get_backend_kwargs( - router_logits, do_finalize, all_rank_num_tokens, output_dtype, x, workspace + router_logits, + do_finalize, + all_rank_num_tokens, + output_dtype, + x, + workspace, + router_weight_t=router_weight_t, + router_input=densegemm_router_input, ), ) @@ -884,12 +925,13 @@ def _prepare_workspaces_for_chunk( def _forward_multiple_chunks( self, x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, + router_logits: Optional[torch.Tensor], num_chunks: int, output_dtype: Optional[torch.dtype], all_rank_num_tokens: List[int], use_dp_padding: Optional[bool], do_finalize: bool = True, + router_weight_t: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Multiple chunks execution path with auxiliary stream for overlapping @@ -921,7 +963,11 @@ def _forward_multiple_chunks( chunk_size_list = self.split_chunk(x.shape[0], num_chunks) x_list = x.split(chunk_size_list) - router_logits_list = router_logits.split(chunk_size_list) + router_logits_list = ( + router_logits.split(chunk_size_list) + if router_logits is not None + else tuple([None] * num_chunks) + ) # Determine if we need multiple streams for overlapped execution use_multi_stream = not self.enable_alltoall and self.aux_stream is not None @@ -980,6 +1026,7 @@ def _forward_multiple_chunks( is_last_call, do_finalize, workspace=workspace_0, + router_weight_t=router_weight_t, ) else: # Odd chunk: execute on main stream @@ -993,6 +1040,7 @@ def _forward_multiple_chunks( is_last_call, do_finalize, workspace=workspace_1, + router_weight_t=router_weight_t, ) else: # No overlap @@ -1006,6 +1054,7 @@ def _forward_multiple_chunks( is_last_call, do_finalize, workspace=workspace_0, + router_weight_t=router_weight_t, ) if chunked_used[idx_chunk]: @@ -1119,6 +1168,8 @@ def _get_backend_kwargs( output_dtype: Optional[torch.dtype] = None, x: Optional[torch.Tensor] = None, workspace: Optional[dict] = None, + router_weight_t: Optional[torch.Tensor] = None, + router_input: Optional[torch.Tensor] = None, ) -> Dict: """ Get backend-specific keyword arguments for run_moe @@ -1207,6 +1258,13 @@ def _get_backend_kwargs( if workspace is not None: kwargs["workspace"] = workspace + # DenseGEMM-specific parameters + elif self.backend.__class__ == DenseGEMMFusedMoE: + if router_weight_t is not None: + kwargs["router_weight_t"] = router_weight_t + if router_input is not None: + kwargs["router_input"] = router_input + # TRTLLMGen-specific parameters elif self.backend.__class__ == TRTLLMGenFusedMoE: # Determine router_logits based on whether routing has been done diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py index 43d457389b1..e6150d888e3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py @@ -1,603 +1,37 @@ # SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import inspect -import os -from typing import Dict, List, Optional, Union - -import torch - -from tensorrt_llm.models.modeling_utils import QuantAlgo -from tensorrt_llm.quantization.utils import fp4_utils - -from ...distributed import allgather -from ...memory_buffer_utils import get_memory_buffers -from ...model_config import ModelConfig -from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf -from .interface import MoE, MoEWeightLoadingMode -from .quantization import NVFP4CuteDslFusedMoEMethod -from .routing import BaseMoeRoutingMethod - - -@torch.compile(options={"max-autotune": True}) -def gen_fc2_alpha_fused( - token_selected_experts: torch.Tensor, - token_final_scales: torch.Tensor, - alpha: Optional[torch.Tensor], - alpha_max: Optional[torch.Tensor] = None, - output: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Generate fc2 alpha values, optionally normalized for FC1 alpha_post fusion. - - Instead of: - 1. zeros() -> scatter_() -> multiply with alpha (operates on large [N, E] tensor) - - We do: - 1. Gather alpha values for selected experts (small [N, top_k] tensor) - 2. Multiply scales with gathered alpha (small tensor operation) - 3. Optionally normalize by alpha_max for FC1 alpha_post fusion - 4. Scatter to output (single write to large tensor) - - This reduces memory bandwidth by avoiding read-modify-write on the large tensor. - - Args: - token_selected_experts: Expert indices for each token [num_tokens, top_k] - token_final_scales: Final scaling factors [num_tokens, top_k] - alpha: Per-expert alpha values [expert_size] - alpha_max: Max alpha value for normalization (optional) - output: Pre-allocated output buffer [num_tokens, expert_size] (optional). - If None, a new tensor will be allocated (not compatible with CUDA graph). - """ - # Pre-compute scaled values on small tensor [num_tokens, top_k] - if alpha is not None: - # Gather alpha for selected experts: alpha[expert_idx] for each selection - gathered_alpha = alpha[token_selected_experts.long()] # [num_tokens, top_k] - scaled_values = token_final_scales * gathered_alpha - else: - scaled_values = token_final_scales - - # Normalize by alpha_max for FC1 alpha_post fusion - if alpha_max is not None: - scaled_values = scaled_values / alpha_max - - # Use pre-allocated output or create new tensor - if output is not None: - output.zero_() - fc2_alpha = output - else: - assert alpha is not None, ( - "alpha must be provided when output buffer is not pre-allocated, " - "since expert_size cannot be inferred from token_final_scales alone" - ) - num_tokens = token_selected_experts.shape[0] - expert_size = alpha.shape[0] - fc2_alpha = torch.zeros( - [num_tokens, expert_size], - dtype=torch.float32, - device=token_selected_experts.device, - ) - - return fc2_alpha.scatter_(1, token_selected_experts.long(), scaled_values) - - -class DenseGEMMFusedMoE(MoE): - """CuteDSL DenseGEMM flow of fused mixture of experts (MoE) Layer. - - This backend uses CuTe DSL dense GEMM kernels with fused SwiGLU for MoE - computation. It supports NVFP4 quantization only and is restricted to - SM100/SM103 (Blackwell) architectures. - - Unlike CutlassFusedMoE which uses per-expert scattered GEMM, DenseGEMM - packs all experts into a single dense matrix and uses standard GEMM operations, - which can be more efficient for small token counts (min-latency scenarios). - - Args: - num_experts (int): Number of experts in the MoE layer. - top_k (int): Number of top experts to select for each input token. - hidden_size (int): Size of the hidden state. - intermediate_size (int): Size of the intermediate state. - aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. - dtype (Optional[torch.dtype]): Data type for the weights. - reduce_results (bool): Whether to reduce the results across devices. - model_config (ModelConfig): Configuration object for the model. - """ - - # Memory buffer pool for CUDA graph compatibility - buffers = get_memory_buffers() - - # DenseGEMM only supports SM100 and SM103 (Blackwell CuTe DSL kernels). - _SUPPORTED_SM_VERSIONS = (100, 103) - - @classmethod - def can_implement( - cls, - quant_algo: Optional[QuantAlgo], - dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, - ) -> tuple: - """Check if DenseGEMMFusedMoE can implement the given configuration. - - DenseGEMMFusedMoE supports: - - NVFP4 quantization only - - SM100/SM103 (Blackwell) only - - SwiGLU activation only (swiglu_gptoss_style not supported) - """ - from tensorrt_llm._utils import get_sm_version - - from .interface import _warn_and_return - - sm_version = get_sm_version() - if sm_version not in cls._SUPPORTED_SM_VERSIONS: - return _warn_and_return( - f"DenseGEMMFusedMoE requires SM {cls._SUPPORTED_SM_VERSIONS}, got SM{sm_version}" - ) - - if quant_algo != QuantAlgo.NVFP4: - return _warn_and_return( - f"DenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})" - ) - - if swiglu_gptoss_style: - return _warn_and_return("DenseGEMMFusedMoE does not support swiglu_gptoss_style") - - return (True, None) - - def __init__( - self, - *, - routing_method: BaseMoeRoutingMethod, - num_experts: int, - hidden_size: int, - intermediate_size: int, - dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - model_config: ModelConfig = ModelConfig(), - aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, - weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, - apply_router_weight_on_input: bool = False, - layer_idx: Optional[int] = None, - init_load_balancer: bool = True, - without_comm: bool = False, - activation_type=None, - ): - # DenseGEMM CuTe DSL kernels only support SM100 and SM103. - from tensorrt_llm._utils import get_sm_version - - from ...utils import ActivationType - - sm_version = get_sm_version() - assert sm_version in self._SUPPORTED_SM_VERSIONS, ( - f"DenseGEMMFusedMoE only supports SM {self._SUPPORTED_SM_VERSIONS} " - f"(got SM {sm_version}). The CuTe DSL kernels require Blackwell architecture." - ) - - # DenseGEMM kernel hardcodes SwiGLU fusion — reject other activation types - # before calling super().__init__() to fail fast with a clear message. - if activation_type is None: - activation_type = ActivationType.Swiglu - assert activation_type == ActivationType.Swiglu, ( - f"DenseGEMMFusedMoE only supports SwiGLU activation " - f"(got activation_type={activation_type}). " - f"The FC1 kernel fuses SwiGLU into the GEMM epilogue." - ) - - # FC2 DenseGEMM kernel tiles K dimension with MMA tile size 256. - # weight_per_expert (= intermediate_size) must be 256-aligned so that - # expert boundaries align with MMA tile boundaries. - _MMA_TILE_K = 256 - assert intermediate_size % _MMA_TILE_K == 0, ( - f"DenseGEMMFusedMoE requires intermediate_size to be a multiple of " - f"{_MMA_TILE_K} (got intermediate_size={intermediate_size}). " - f"FC2 kernel cannot correctly split alpha_scale at expert boundaries " - f"when weight_per_expert is not MMA tile-K aligned." - ) - - # Call MoE base class directly (not CutlassFusedMoE). - # Note: `without_comm` and `apply_router_weight_on_input` are accepted - # for API compatibility with create_moe_backend() but are not passed to - # MoE.__init__() since DenseGEMM does not use alltoall communication. - super().__init__( - routing_method=routing_method, - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, - reduce_results=reduce_results, - model_config=model_config, - aux_stream_dict=aux_stream_dict, - weight_loading_mode=weight_loading_mode, - layer_idx=layer_idx, - init_load_balancer=init_load_balancer, - activation_type=activation_type, - ) - - # Environment variable to control fc2_alpha fusion into FC1's alpha_post. - # Default: disabled (0). Set to "1" to enable fusion (known accuracy issue under TP). - self.use_fused_fc2_alpha = os.environ.get("TRTLLM_MOE_FUSED_FC2_ALPHA", "0") == "1" - - # Pre-register fc2_alpha_max buffer for fused fc2_alpha optimization. - # Populated in load_weights() with max(fc2_alpha). - self.register_buffer("fc2_alpha_max", torch.zeros(1, dtype=torch.float32)) - - # Initialize auxiliary stream and events for gen_fc2_alpha_fused overlap with fc1 - if self.aux_stream_dict is None: - self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} - if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict: - self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = torch.cuda.Stream() - self.event_dict = {} - for key in [EventType.Main, EventType.MoeFc2Alpha]: - self.event_dict[key] = torch.cuda.Event() - - # Weight creation - self._weights_created = False - if not model_config.skip_create_weights_in_init: - self.create_weights() - - def _supports_load_balancer(self) -> bool: - """DenseGEMMFusedMoE supports load balancer.""" - return True - - def _get_quant_method(self): - if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( - exclude_kv_cache=True - ): - if self.quant_config.layer_quant_mode.has_nvfp4(): - return NVFP4CuteDslFusedMoEMethod() - raise ValueError( - f"{self.__class__.__name__} only supports NVFP4 quantization, " - f"got {self.quant_config.quant_mode}." - ) - raise ValueError( - f"{self.__class__.__name__} requires quantization (NVFP4), " - f"but no quantization config was provided." - ) +"""Selector for DenseGEMM MoE SM-partition strategy. - def create_weights(self): - if self._weights_created: - return +Controls which parallel strategy is used when ``TRTLLM_MOE_FUSED_FC2_ALPHA=0`` +(i.e. the non-fused path where FC1 and Router GEMM run concurrently): - self.quant_method = self._get_quant_method() - self.quant_method.create_weights(self) +- ``gc`` (default) – Green Context: hardware-level SM pool isolation via + ``torch.cuda.GreenContext``. Tuned by ``TRTLLM_MOE_FC1_SM_NUNBER`` (default 0.5). +- ``smp`` – SM Partition: soft SM limit via the ``sm_budget`` kernel launch + parameter. Tuned by ``TRTLLM_MOE_FC1_SM_NUNBER`` (default 0.5). +- ``no_overlap`` – Single stream, no SM partitioning. All kernels use all available + SMs and run sequentially. Useful as a performance baseline. - self._weights_created = True +Set ``TRTLLM_MOE_SM_SPLIT_MODE`` to one of the above values to select the backend. +""" - def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False): - assert self._weights_created - assert len(weights) == 1 - weights = weights[0] - - kargs = {} - if "allow_partial_loading" in inspect.getfullargspec(self.quant_method.load_weights).args: - kargs["allow_partial_loading"] = allow_partial_loading - self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) - - # Transpose w2_weight layout: (E, H, ...) -> (H, E, ...) for dense GEMM. - # NOTE: .contiguous() on the transposed view allocates a full-size temporary, - # temporarily doubling peak memory. An in-place multi-dim transpose is not - # feasible without complex cycle-following, and this runs only once during - # weight loading, so the trade-off is acceptable. - w2_transposed = self.w2_weight.transpose(0, 1).contiguous() - self.w2_weight.reshape([-1]).copy_(w2_transposed.reshape([-1]), non_blocking=True) - del w2_transposed - if self.has_any_quant: - if self.has_nvfp4: - self._transform_w2_weight_scale_for_min_latency() - # Compute fc2_alpha_max for fused fc2_alpha optimization - self.fc2_alpha_max.copy_(torch.max(self.fc2_alpha).reshape(1), non_blocking=True) - else: - raise ValueError( - f"{self.__class__.__name__} only supports nvfp4 quantization, " - f"got {self.quant_config.quant_mode}." - ) - - def post_load_weights(self): - self.quant_method.post_load_weights(self) - - def _transform_w2_weight_scale_for_min_latency(self): - """Transform w2_weight_scale for minimum latency path optimization.""" - # Calculate padded dimensions - nrows = fp4_utils.pad_up(self.hidden_size, 128) - ncols = fp4_utils.pad_up( - self.intermediate_size_per_partition // self.scaling_vector_size, 4 - ) - - # Clone and convert weight scale to uint8 - w2_weight_scale = self.w2_weight_scale.clone().view(torch.uint8) - - # Unswizzle the scale factor - w2_weight_scale = unswizzle_sf( - w2_weight_scale, - self.hidden_size * self.expert_size_per_partition, - self.intermediate_size_per_partition, - ) - - # Reshape and transpose for min latency layout - w2_weight_scale = w2_weight_scale.reshape([self.expert_size_per_partition, nrows, ncols]) - w2_weight_scale = w2_weight_scale.transpose(0, 1).reshape( - nrows, self.expert_size_per_partition * ncols - ) - - # Swizzle back with new layout - w2_weight_scale = swizzle_sf( - w2_weight_scale, - self.hidden_size, - self.expert_size_per_partition * self.intermediate_size_per_partition, - ) - - # Copy back to original tensor - self.w2_weight_scale.copy_( - w2_weight_scale.view(self.w2_weight_scale.dtype).view(self.w2_weight_scale.shape), - non_blocking=True, - ) - - def quantize_input( - self, x: Union[torch.Tensor, Fp4QuantizedTensor], post_quant_comm: bool = True - ): - """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. - - Args: - x: Input tensor to quantize - post_quant_comm: - If True, quantize for post-quant communication path. - If False, quantize for non-communication path - - Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed - - For quantization methods that produce scaling factors: - - x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)] - - The 2D shape is required for proper handling in alltoall/allgather operations - - scaling_vector_size is typically the group size for block-wise quantization - """ - x_sf = None - if self.has_nvfp4: - if isinstance(x, Fp4QuantizedTensor): - assert not x.is_sf_swizzled, ( - "Fp4QuantizedTensor should not be swizzled before communication" - ) - x_row = x.shape[0] - x, x_sf = x.fp4_tensor, x.scaling_factor - else: - x_row = x.shape[0] - x, x_sf = torch.ops.trtllm.fp4_quantize( - x, self.fc31_input_scale, self.scaling_vector_size, False, False - ) - else: - raise ValueError( - f"{self.__class__.__name__} only supports nvfp4 quantization, " - f"got {self.quant_config.quant_mode}." - ) - - if x_sf is not None: - x_sf = x_sf.view(x_row, -1) - - return x, x_sf - - def run_moe_nvfp4( - self, - x: torch.Tensor, - token_selected_experts: torch.Tensor, - token_final_scales: Optional[torch.Tensor], - x_sf: Optional[torch.Tensor] = None, - enable_alltoall: bool = False, - ) -> torch.Tensor: - """Run MoE computation with NVFP4 quantization. - - Args: - x: Input tensor - token_selected_experts: Expert indices for each token - token_final_scales: Final scaling factors for each token - x_sf: Input scale factors - enable_alltoall: Whether alltoall communication is enabled - - Note: - The implementation is controlled by TRTLLM_MOE_FUSED_FC2_ALPHA env var (default: enabled). - When enabled, fc2_alpha is fused into FC1's alpha_post with scalar fc2_alpha_max in FC2. - When disabled, uses the original per-token per-expert fc2_alpha in FC2. - """ - assert self.has_nvfp4 - num_tokens = x.shape[0] - - # Get pre-allocated buffer for fc2_alpha (CUDA graph compatible) - capture_graph = torch.cuda.is_current_stream_capturing() - fc2_alpha_buffer = DenseGEMMFusedMoE.buffers.get_buffer( - (num_tokens, self.expert_size_per_partition), - dtype=torch.float32, - buffer_name="fc2_alpha", - reserve_buffer=capture_graph, - ) - - if self.use_fused_fc2_alpha: - # New implementation: fuse fc2_alpha into FC1's alpha_post - x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) - - # Generate normalized fc2_alpha for FC1 alpha_post fusion - fc2_alpha_normalized = gen_fc2_alpha_fused( - token_selected_experts, - token_final_scales, - self.fc2_alpha, - self.fc2_alpha_max, # Normalize by max for FC1 alpha_post - fc2_alpha_buffer, # Pre-allocated buffer - ) - - # FC1: GEMM + SwiGLU with post-SwiGLU alpha scaling (fused fc2_alpha) - fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( - x, - self.w3_w1_weight.view(torch.uint8), - x_sf, - self.w3_w1_weight_scale, - self.fc31_alpha, - fc2_alpha_normalized, # Pass normalized fc2_alpha as alpha_post - self.fc2_input_scale, - expert_count=self.expert_size_per_partition, - weight_per_expert=2 * self.intermediate_size_per_partition, - output_dtype=torch.float4_e2m1fn_x2, - scaling_vector_size=self.scaling_vector_size, - ) - - # FC2: Standard nvfp4_gemm with scalar alpha = fc2_alpha_max - final_hidden_states = torch.ops.trtllm.nvfp4_gemm( - fc1_output.view(torch.uint8), - self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), - fc1_output_sf.view(torch.uint8).reshape(-1), - self.w2_weight_scale.view(torch.uint8), - self.fc2_alpha_max, - torch.bfloat16, - to_userbuffers=False, - allowed_backends="cutlass,cublaslt,cutedsl,cuda_core", - ) - else: - # Original implementation: per-token per-expert fc2_alpha in FC2 - self.event_dict[EventType.Main].record() - x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) - - # FC1: GEMM + SwiGLU, output is fp4 quantized - fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( - x, - self.w3_w1_weight.view(torch.uint8), - x_sf, - self.w3_w1_weight_scale, - self.fc31_alpha, - None, # alpha_post: no post-SwiGLU scaling - self.fc2_input_scale, - expert_count=self.expert_size_per_partition, - weight_per_expert=2 * self.intermediate_size_per_partition, - output_dtype=torch.float4_e2m1fn_x2, - scaling_vector_size=self.scaling_vector_size, - ) - - with torch.cuda.stream(self.aux_stream_dict[AuxStreamType.MoeFc2Alpha]): - self.event_dict[EventType.Main].wait() - fc2_alpha = gen_fc2_alpha_fused( - token_selected_experts, - token_final_scales, - self.fc2_alpha, - output=fc2_alpha_buffer, # Use pre-allocated buffer - ) - self.event_dict[EventType.MoeFc2Alpha].record() - - self.event_dict[EventType.MoeFc2Alpha].wait() - - # FC2: input k = expert_count * intermediate_size (after SwiGLU) - final_hidden_states = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_fc2_blackwell( - fc1_output, - self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), - fc1_output_sf.reshape(-1), - self.w2_weight_scale, - fc2_alpha, - expert_count=self.expert_size_per_partition, - weight_per_expert=self.intermediate_size_per_partition, - output_dtype=torch.bfloat16, - scaling_vector_size=self.scaling_vector_size, - ) - - return final_hidden_states - - def run_moe( - self, - x: torch.Tensor, - token_selected_experts: torch.Tensor, - token_final_scales: Optional[torch.Tensor], - x_sf: Optional[torch.Tensor] = None, - enable_alltoall: bool = False, - **kwargs, - ) -> torch.Tensor: - """ - Run MoE computation with DenseGEMM backend (NVFP4 only). - - Args: - x: Input hidden states (pre-quantized to NVFP4) - token_selected_experts: Expert IDs [num_tokens, top_k]. If EPLB is enabled, - this represents expert slots [num_tokens, top_k] instead. - token_final_scales: Final scaling factors for each token - x_sf: Input scale factors for NVFP4 - enable_alltoall: Whether alltoall communication is enabled. - **kwargs: Additional arguments for forward compatibility. - - Returns: - final_hidden_states tensor. - """ - assert self.has_nvfp4, ( - f"{self.__class__.__name__} only supports nvfp4 quantization, " - f"got {self.quant_config.quant_mode}." - ) - return self.run_moe_nvfp4( - x=x, - token_selected_experts=token_selected_experts, - token_final_scales=token_final_scales, - x_sf=x_sf, - enable_alltoall=enable_alltoall, - ) - - def forward_chunk( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - repeating_info: tuple = (True, True), - ) -> torch.Tensor: - # Currently, the default path is that ConfigurableMoE calls DenseGEMMFusedMoE.run_moe. - # This forward_chunk method is a reference implementation of the legacy path. - # Apply routing - token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) - assert token_selected_experts.shape[1] == self.routing_method.experts_per_token - assert token_selected_experts.shape == token_final_scales.shape - assert token_selected_experts.shape[0] == router_logits.shape[0] - assert token_final_scales.dtype == torch.float32 - assert token_selected_experts.dtype == torch.int32 - - x, x_sf = self.quantize_input(x) - - if self.use_dp and self.parallel_size > 1: - x, x_sf, token_selected_experts, token_final_scales = allgather( - [x, x_sf, token_selected_experts, token_final_scales], - self.mapping, - dim=0, - sizes=None if use_dp_padding else all_rank_num_tokens, - ) - - x = self.run_moe( - x=x, - token_selected_experts=token_selected_experts, - token_final_scales=token_final_scales, - x_sf=x_sf, - enable_alltoall=False, - ) - return x +import os - def forward_impl( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - *, - do_finalize: bool = True, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - **kwargs, - ) -> torch.Tensor: - assert do_finalize, "DenseGEMMFusedMoE does not support do_finalize=False" +_mode = os.environ.get("TRTLLM_MOE_SM_SPLIT_MODE", "gc").strip().lower() - is_first_call = self.repeat_idx == 0 - is_last_call = self.repeat_idx == self.repeat_count - 1 +if _mode == "smp": + from .fused_moe_densegemm_smp import DenseGEMMFusedMoE +elif _mode == "no_overlap": + from .fused_moe_densegemm_no_overlap import NoOverlapDenseGEMMFusedMoE as DenseGEMMFusedMoE +else: + if _mode != "gc": + import warnings - outputs = self.forward_chunk( - x, - router_logits, - output_dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - repeating_info=(is_first_call, is_last_call), - ) - outputs = self.reducescatter_or_allreduce( - outputs, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, + warnings.warn( + f"Unknown TRTLLM_MOE_SM_SPLIT_MODE='{_mode}', falling back to 'gc'.", + stacklevel=1, ) + from .fused_moe_densegemm_gc import DenseGEMMFusedMoE - if self.use_dp and self.parallel_size > 1: - rank = self.parallel_rank - outputs = outputs[: all_rank_num_tokens[rank]] - self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 - return outputs +__all__ = ["DenseGEMMFusedMoE"] diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_gc.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_gc.py new file mode 100644 index 00000000000..ec273329e21 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_gc.py @@ -0,0 +1,1034 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import os +from typing import Any, Dict, List, Optional, Union + +import torch + +from tensorrt_llm.logger import logger +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils import fp4_utils + +from ...autotuner import ( + AutoTuner, + ConstraintSpec, + DynamicTensorSpec, + OptimizationProfile, + TunableRunner, + TuningConfig, +) +from ...distributed import allgather +from ...memory_buffer_utils import get_memory_buffers +from ...model_config import ModelConfig +from ...utils import ( + AuxStreamType, + EventType, + Fp4QuantizedTensor, + deep_gemm_tuning_buckets, + prev_deep_gemm_bucket, + swizzle_sf, + unswizzle_sf, +) +from .green_context import create_sm_only_gc_streams +from .interface import MoE, MoEWeightLoadingMode +from .quantization import NVFP4CuteDslFusedMoEMethod +from .routing import BaseMoeRoutingMethod + + +@torch.compile(options={"max-autotune": True}) +def gen_fc2_alpha_fused( + token_selected_experts: torch.Tensor, + token_final_scales: torch.Tensor, + alpha: Optional[torch.Tensor], + alpha_max: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Generate fc2 alpha values, optionally normalized for FC1 alpha_post fusion. + + Instead of: + 1. zeros() -> scatter_() -> multiply with alpha (operates on large [N, E] tensor) + + We do: + 1. Gather alpha values for selected experts (small [N, top_k] tensor) + 2. Multiply scales with gathered alpha (small tensor operation) + 3. Optionally normalize by alpha_max for FC1 alpha_post fusion + 4. Scatter to output (single write to large tensor) + + This reduces memory bandwidth by avoiding read-modify-write on the large tensor. + + Args: + token_selected_experts: Expert indices for each token [num_tokens, top_k] + token_final_scales: Final scaling factors [num_tokens, top_k] + alpha: Per-expert alpha values [expert_size] + alpha_max: Max alpha value for normalization (optional) + output: Pre-allocated output buffer [num_tokens, expert_size] (optional). + If None, a new tensor will be allocated (not compatible with CUDA graph). + """ + # Pre-compute scaled values on small tensor [num_tokens, top_k] + if alpha is not None: + # Gather alpha for selected experts: alpha[expert_idx] for each selection + gathered_alpha = alpha[token_selected_experts.long()] # [num_tokens, top_k] + scaled_values = token_final_scales * gathered_alpha + else: + scaled_values = token_final_scales + + # Normalize by alpha_max for FC1 alpha_post fusion + if alpha_max is not None: + scaled_values = scaled_values / alpha_max + + # Use pre-allocated output or create new tensor + if output is not None: + output.zero_() + fc2_alpha = output + else: + assert alpha is not None, ( + "alpha must be provided when output buffer is not pre-allocated, " + "since expert_size cannot be inferred from token_final_scales alone" + ) + num_tokens = token_selected_experts.shape[0] + expert_size = alpha.shape[0] + fc2_alpha = torch.zeros( + [num_tokens, expert_size], + dtype=torch.float32, + device=token_selected_experts.device, + ) + + return fc2_alpha.scatter_(1, token_selected_experts.long(), scaled_values) + + +class DenseGEMMGCSMRunner(TunableRunner): + """TunableRunner that sweeps the FC1 SM count for the Green Context overlap path. + + Each tactic is an integer representing the requested SM count assigned to FC1. + The router partition receives the remaining SMs. GreenContext stream pairs for + all candidate SM splits are pre-created in ``DenseGEMMFusedMoE.__init__`` and + stored in ``moe_layer._gc_stream_pool``. + """ + + tuning_config_cache: dict = {} + + def __init__(self, moe_layer: "DenseGEMMFusedMoE") -> None: + super().__init__() + self.moe = moe_layer + + def unique_id(self): + moe = self.moe + return ( + moe.num_experts, + moe.hidden_size, + moe.intermediate_size_per_partition, + moe.expert_size_per_partition, + ) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + candidates = list(self.moe._gc_stream_pool.keys()) + # print( + # f"[DenseGEMMGCSMRunner] get_valid_tactics: num_tokens={inputs[0].shape[0]}, " + # f"fc1_sms candidates={candidates}" + # ) + return candidates + + def get_tuning_config(self) -> TuningConfig: + key = self.unique_id() + if key not in self.__class__.tuning_config_cache: + self.__class__.tuning_config_cache[key] = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + 0, + 0, + # Use the same fine-grained bucket grid as the inner + # GEMM runners so that the outer sm-split tactic is + # tuned at the same resolution as the inner tactics. + deep_gemm_tuning_buckets, + # Floor lookup: map actual num_tokens to the largest + # cached bucket strictly smaller than x. This ensures + # the cached tactic was profiled with fewer tokens than + # the actual workload (conservative choice). + prev_deep_gemm_bucket, + ), + ), + constraint_specs=( + # x_sf: shape[0] tracks num_tokens (pre-swizzle, 2-D) + ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), + # router_input: shape[0] tracks num_tokens + ConstraintSpec(3, 0, lambda shapes: shapes[0][0]), + ), + use_cold_l2_cache=True, + ) + return self.__class__.tuning_config_cache[key] + + def forward( + self, + inputs: List[torch.Tensor], + *, + tactic: Any = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + x, x_sf, router_weight_t, router_input = inputs + num_tokens = x.shape[0] + pool = self.moe._gc_stream_pool + # do_preparation is called by the autotuner with tactic=-1 before the + # actual sweep; use the first pool entry so the kernel compiles. + fc1_sms = tactic if tactic in pool else next(iter(pool)) + if do_preparation: + # Two-phase inner-runner warm-up before the outer CUDA-graph capture. + # + # Phase A — full-SM tuning (once, cheap in the outer dimension): + # Call _warm_inner_runners_full_sm() with sm_budget=0 so the inner + # autotuners (CuteDSLNVFP4DenseGemmSwigluRunner and + # CuteDSLBf16BlackwellGemmRunner) profile all token-buckets once + # with the full hardware SM count and cache the best tactic. + # Because sm_budget is excluded from both runners' unique_id(), this + # single entry is shared across every fc1_sms split. + # + # Phase B — per-candidate kernel-compilation warm-up (S passes, fast): + # Iterate over all fc1_sms candidates as before. Since the inner + # autotuner caches are now warm (Phase A), choose_one() returns the + # cached full-SM tactic without re-profiling. The only work done is + # kernel compilation for the specific (tactic, max_active_clusters) + # combination corresponding to each split, which is required before + # CUDA-graph capture (cute.compile() is not graph-safe). + # + # Net cost: O(inner_buckets × inner_tactics) + O(S) instead of + # O(S × inner_buckets × inner_tactics). + # print( + # f"[DenseGEMMGCSMRunner] [preparation] Phase A: warming inner runners " + # f"with full SM, num_tokens={num_tokens}" + # ) + try: + self.moe._warm_inner_runners_full_sm(x, x_sf, router_weight_t, router_input) + except Exception as e: + logger.warning( + f"[DenseGEMMGCSMRunner] full-SM inner warm-up failed, " + f"num_tokens={num_tokens}: {e}" + ) + + for cand_fc1_sms in pool: + # print( + # f"[DenseGEMMGCSMRunner] [preparation] Phase B: kernel compilation " + # f"fc1_sms={cand_fc1_sms}, num_tokens={num_tokens}" + # ) + try: + self.moe._run_moe_nvfp4_gc_impl( + x, x_sf, router_weight_t, router_input, cand_fc1_sms + ) + except Exception as e: + # Some candidates may be incompatible with the hardware GC + # topology (e.g. unsupported cluster shape for a given SM + # count). Catch and log the failure so the outer autotuner + # can still profile all remaining candidates; this candidate + # will be naturally rejected during the outer profiling sweep. + logger.warning( + f"[DenseGEMMGCSMRunner] preparation failed for " + f"fc1_sms={cand_fc1_sms}, num_tokens={num_tokens}: {e}" + ) + else: + pass + # print( + # f"[DenseGEMMGCSMRunner] fc1_sms={fc1_sms}, " + # f"num_tokens={num_tokens}" + # ) + return self.moe._run_moe_nvfp4_gc_impl(x, x_sf, router_weight_t, router_input, fc1_sms) + + +class DenseGEMMFusedMoE(MoE): + """CuteDSL DenseGEMM flow of fused mixture of experts (MoE) Layer — Green Context variant. + + This backend uses CuTe DSL dense GEMM kernels with fused SwiGLU for MoE + computation. It supports NVFP4 quantization only and is restricted to + SM100/SM103 (Blackwell) architectures. + + Unlike CutlassFusedMoE which uses per-expert scattered GEMM, DenseGEMM + packs all experts into a single dense matrix and uses standard GEMM operations, + which can be more efficient for small token counts (min-latency scenarios). + + This variant (gc) uses GreenContext for hardware-level SM isolation so that + FC1 and the Router GEMM truly run in parallel without SM contention. The SM + split follows the same formula as the smp variant (raw SM count), but enforcement + is via hardware GreenContext partitions rather than soft sm_budget kernel hints. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + """ + + # Memory buffer pool for CUDA graph compatibility + buffers = get_memory_buffers() + + # DenseGEMM only supports SM100 and SM103 (Blackwell CuTe DSL kernels). + _SUPPORTED_SM_VERSIONS = (100, 103) + + @classmethod + def can_implement( + cls, + quant_algo: Optional[QuantAlgo], + dtype_activation: torch.dtype = torch.bfloat16, + swiglu_gptoss_style: bool = False, + ) -> tuple: + """Check if DenseGEMMFusedMoE can implement the given configuration. + + DenseGEMMFusedMoE supports: + - NVFP4 quantization only + - SM100/SM103 (Blackwell) only + - SwiGLU activation only (swiglu_gptoss_style not supported) + """ + from tensorrt_llm._utils import get_sm_version + + from .interface import _warn_and_return + + sm_version = get_sm_version() + if sm_version not in cls._SUPPORTED_SM_VERSIONS: + return _warn_and_return( + f"DenseGEMMFusedMoE requires SM {cls._SUPPORTED_SM_VERSIONS}, got SM{sm_version}" + ) + + if quant_algo != QuantAlgo.NVFP4: + return _warn_and_return( + f"DenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})" + ) + + if swiglu_gptoss_style: + return _warn_and_return("DenseGEMMFusedMoE does not support swiglu_gptoss_style") + + return (True, None) + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + init_load_balancer: bool = True, + without_comm: bool = False, + activation_type=None, + ): + # DenseGEMM CuTe DSL kernels only support SM100 and SM103. + from tensorrt_llm._utils import get_sm_version + + from ...utils import ActivationType + + sm_version = get_sm_version() + assert sm_version in self._SUPPORTED_SM_VERSIONS, ( + f"DenseGEMMFusedMoE only supports SM {self._SUPPORTED_SM_VERSIONS} " + f"(got SM {sm_version}). The CuTe DSL kernels require Blackwell architecture." + ) + + # DenseGEMM kernel hardcodes SwiGLU fusion — reject other activation types + # before calling super().__init__() to fail fast with a clear message. + if activation_type is None: + activation_type = ActivationType.Swiglu + assert activation_type == ActivationType.Swiglu, ( + f"DenseGEMMFusedMoE only supports SwiGLU activation " + f"(got activation_type={activation_type}). " + f"The FC1 kernel fuses SwiGLU into the GEMM epilogue." + ) + + # FC2 DenseGEMM kernel tiles K dimension with MMA tile size 256. + # weight_per_expert (= intermediate_size) must be 256-aligned so that + # expert boundaries align with MMA tile boundaries. + _MMA_TILE_K = 256 + assert intermediate_size % _MMA_TILE_K == 0, ( + f"DenseGEMMFusedMoE requires intermediate_size to be a multiple of " + f"{_MMA_TILE_K} (got intermediate_size={intermediate_size}). " + f"FC2 kernel cannot correctly split alpha_scale at expert boundaries " + f"when weight_per_expert is not MMA tile-K aligned." + ) + + # Call MoE base class directly (not CutlassFusedMoE). + # Note: `without_comm` and `apply_router_weight_on_input` are accepted + # for API compatibility with create_moe_backend() but are not passed to + # MoE.__init__() since DenseGEMM does not use alltoall communication. + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream_dict=aux_stream_dict, + weight_loading_mode=weight_loading_mode, + layer_idx=layer_idx, + init_load_balancer=init_load_balancer, + activation_type=activation_type, + ) + + # Environment variable to control fc2_alpha fusion into FC1's alpha_post. + # Default: disabled (0). Set to "1" to enable fusion (known accuracy issue under TP). + self.use_fused_fc2_alpha = os.environ.get("TRTLLM_MOE_FUSED_FC2_ALPHA", "0") == "1" + + # Pre-register fc2_alpha_max buffer for fused fc2_alpha optimization. + # Populated in load_weights() with max(fc2_alpha). + self.register_buffer("fc2_alpha_max", torch.zeros(1, dtype=torch.float32)) + + device_id = torch.cuda.current_device() + num_sms = torch.cuda.get_device_properties(device_id).multi_processor_count + self.num_sms = num_sms + print(f"DenseGEMMFusedMoE initializing on device {device_id} with {num_sms} SMs") + + # Initialize GreenContext SM partitions and events for FC1/Router overlap. + # Each GreenContext provides hardware-level SM isolation so that FC1 and + # the Router GEMM truly run in parallel without SM contention. + # Uses PyTorch torch.cuda.GreenContext path: SM isolation only, + # workqueue remains shared across GreenContexts. + + self._gc_cleanup = None + if not self.use_fused_fc2_alpha: + # Build a pool of GC stream pairs for all SM-split candidates so the + # autotuner can profile each split and pick the best one. + # Sweep fc1_sms from 104 to num_sms (exclusive) with step 8. + # Step 8 matches smCoscheduledAlignment on Blackwell so every + # candidate maps to a distinct hardware SM partition; non-aligned + # values would be rounded to the same partition as the next + # multiple of 8, making duplicate or incompatible GC configs. + _alignment = 8 # smCoscheduledAlignment on Blackwell (SM100/SM103) + _start = 108 # first sensible fc1_sms (leaves at least 1 SM for Router) + _candidates = sorted(range(_start, num_sms, _alignment)) + + self._gc_stream_pool: Dict[int, tuple] = {} + self._gc_cleanup_fns: List = [] + for cand_fc1_sms in _candidates: + cand_router_sms = max(1, num_sms - cand_fc1_sms) + try: + fc1_s, router_s, cleanup = create_sm_only_gc_streams( + cand_fc1_sms, cand_router_sms, device_id + ) + self._gc_stream_pool[cand_fc1_sms] = (fc1_s, router_s) + self._gc_cleanup_fns.append(cleanup) + except Exception as e: + logger.warning( + f"skipping fc1_sms={cand_fc1_sms} (router_sms={cand_router_sms}): {e}" + ) + + self.fc1_gc = None # managed by Driver API + self.router_gc = None + # Pick the router stream from the first successfully-created candidate. + _first_valid = next(iter(self._gc_stream_pool)) + router_gc_stream = self._gc_stream_pool[_first_valid][1] + else: + self._gc_stream_pool = {} + self._gc_cleanup_fns = [] + self.fc1_gc = None + self.router_gc = None + router_gc_stream = torch.cuda.Stream() + + # Initialize auxiliary stream and events for gen_fc2_alpha_fused overlap with fc1 + if self.aux_stream_dict is None: + self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} + if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict: + self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = router_gc_stream + self.event_dict = {} + for key in [EventType.Main, EventType.MoeFc2Alpha]: + self.event_dict[key] = torch.cuda.Event(enable_timing=False) + self.fc1_done_event = torch.cuda.Event(enable_timing=False) + + # Weight creation + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + def __del__(self) -> None: + """Release raw CUDA Driver handles held by all GreenContext stream pairs.""" + for cleanup in getattr(self, "_gc_cleanup_fns", []): + try: + cleanup() + except Exception: + pass # best-effort; avoid raising in __del__ + # Legacy single-cleanup fallback (kept for safety). + cleanup = getattr(self, "_gc_cleanup", None) + if cleanup is not None: + try: + cleanup() + except Exception: + pass + + def _supports_load_balancer(self) -> bool: + """DenseGEMMFusedMoE supports load balancer.""" + return True + + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True + ): + if self.quant_config.layer_quant_mode.has_nvfp4(): + return NVFP4CuteDslFusedMoEMethod() + raise ValueError( + f"{self.__class__.__name__} only supports NVFP4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + raise ValueError( + f"{self.__class__.__name__} requires quantization (NVFP4), " + f"but no quantization config was provided." + ) + + def create_weights(self): + if self._weights_created: + return + + self.quant_method = self._get_quant_method() + self.quant_method.create_weights(self) + + self._weights_created = True + + def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False): + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + kargs = {} + if "allow_partial_loading" in inspect.getfullargspec(self.quant_method.load_weights).args: + kargs["allow_partial_loading"] = allow_partial_loading + self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) + + # Transpose w2_weight layout: (E, H, ...) -> (H, E, ...) for dense GEMM. + # NOTE: .contiguous() on the transposed view allocates a full-size temporary, + # temporarily doubling peak memory. An in-place multi-dim transpose is not + # feasible without complex cycle-following, and this runs only once during + # weight loading, so the trade-off is acceptable. + w2_transposed = self.w2_weight.transpose(0, 1).contiguous() + self.w2_weight.reshape([-1]).copy_(w2_transposed.reshape([-1]), non_blocking=True) + del w2_transposed + if self.has_any_quant: + if self.has_nvfp4: + self._transform_w2_weight_scale_for_min_latency() + # Compute fc2_alpha_max for fused fc2_alpha optimization + self.fc2_alpha_max.copy_(torch.max(self.fc2_alpha).reshape(1), non_blocking=True) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + def post_load_weights(self): + self.quant_method.post_load_weights(self) + + def _transform_w2_weight_scale_for_min_latency(self): + """Transform w2_weight_scale for minimum latency path optimization.""" + # Calculate padded dimensions + nrows = fp4_utils.pad_up(self.hidden_size, 128) + ncols = fp4_utils.pad_up( + self.intermediate_size_per_partition // self.scaling_vector_size, 4 + ) + + # Clone and convert weight scale to uint8 + w2_weight_scale = self.w2_weight_scale.clone().view(torch.uint8) + + # Unswizzle the scale factor + w2_weight_scale = unswizzle_sf( + w2_weight_scale, + self.hidden_size * self.expert_size_per_partition, + self.intermediate_size_per_partition, + ) + + # Reshape and transpose for min latency layout + w2_weight_scale = w2_weight_scale.reshape([self.expert_size_per_partition, nrows, ncols]) + w2_weight_scale = w2_weight_scale.transpose(0, 1).reshape( + nrows, self.expert_size_per_partition * ncols + ) + + # Swizzle back with new layout + w2_weight_scale = swizzle_sf( + w2_weight_scale, + self.hidden_size, + self.expert_size_per_partition * self.intermediate_size_per_partition, + ) + + # Copy back to original tensor + self.w2_weight_scale.copy_( + w2_weight_scale.view(self.w2_weight_scale.dtype).view(self.w2_weight_scale.shape), + non_blocking=True, + ) + + def quantize_input( + self, x: Union[torch.Tensor, Fp4QuantizedTensor], post_quant_comm: bool = True + ): + """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. + + Args: + x: Input tensor to quantize + post_quant_comm: + If True, quantize for post-quant communication path. + If False, quantize for non-communication path + + Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed + + For quantization methods that produce scaling factors: + - x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)] + - The 2D shape is required for proper handling in alltoall/allgather operations + - scaling_vector_size is typically the group size for block-wise quantization + """ + x_sf = None + if self.has_nvfp4: + if isinstance(x, Fp4QuantizedTensor): + assert not x.is_sf_swizzled, ( + "Fp4QuantizedTensor should not be swizzled before communication" + ) + x_row = x.shape[0] + x, x_sf = x.fp4_tensor, x.scaling_factor + else: + x_row = x.shape[0] + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, False, False + ) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + if x_sf is not None: + x_sf = x_sf.view(x_row, -1) + + return x, x_sf + + def run_moe_nvfp4( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + router_weight_t: Optional[torch.Tensor] = None, + router_input: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run MoE computation with NVFP4 quantization using Green Context stream overlap. + + Args: + x: Input tensor + token_selected_experts: Must be None (routing done internally) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors + enable_alltoall: Whether alltoall communication is enabled + router_weight_t: Router weight matrix [hidden, num_experts] (transposed) + router_input: Optional separate router input; defaults to x if None + + Note: + The implementation is controlled by TRTLLM_MOE_FUSED_FC2_ALPHA env var (default: disabled). + When enabled, fc2_alpha is fused into FC1's alpha_post with scalar fc2_alpha_max in FC2. + When disabled, uses Green Context for hardware-level SM isolation so that FC1 and + Router GEMM run in parallel on separate SM partitions. + """ + assert self.has_nvfp4 + assert token_selected_experts is None and token_final_scales is None, ( + "DenseGEMMFusedMoE only supports internal routing path. " + "Expected token_selected_experts/token_final_scales to be None." + ) + assert router_weight_t is not None, ( + "DenseGEMMFusedMoE internal routing requires router_weight_t." + ) + if router_input is None: + router_input = x + assert router_input.ndim == 2 and router_weight_t.ndim == 2, ( + "router_input and router_weight_t must both be rank-2 tensors." + ) + assert router_input.shape[1] == router_weight_t.shape[0], ( + "DenseGEMMFusedMoE internal routing shape mismatch: " + f"router_input.shape={tuple(router_input.shape)}, " + f"router_weight_t.shape={tuple(router_weight_t.shape)}" + ) + + num_tokens = x.shape[0] + + if self.use_fused_fc2_alpha: + # Fused path: router GEMM and FC1 run sequentially on main stream (no overlap). + # No SM budget applied — both kernels use all available SMs for maximum throughput. + capture_graph = torch.cuda.is_current_stream_capturing() + fc2_alpha_buffer = DenseGEMMFusedMoE.buffers.get_buffer( + (num_tokens, self.expert_size_per_partition), + dtype=torch.float32, + buffer_name="fc2_alpha", + reserve_buffer=capture_graph, + ) + m, n = router_input.shape[0], router_weight_t.shape[1] + router_logits = torch.empty(m, n, dtype=torch.float32, device=router_input.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + router_input.contiguous(), + router_weight_t.t().contiguous(), + router_logits, + ) + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + token_selected_experts = token_selected_experts.to(torch.int32) + assert token_final_scales is not None + assert token_final_scales.dtype == torch.float32 + + # Append fused shared experts (always active, scale=1.0) if any. + # Shared experts are the experts beyond the routed experts range + # (expert_size_per_partition > num_routing_experts). + num_routing_experts = router_weight_t.shape[1] + if self.expert_size_per_partition > num_routing_experts: + n_shared = self.expert_size_per_partition - num_routing_experts + M_tok = router_input.shape[0] + shared_ids = ( + torch.arange( + num_routing_experts, + self.expert_size_per_partition, + dtype=torch.int32, + device=x.device, + ) + .unsqueeze(0) + .expand(M_tok, n_shared) + ) + shared_scales = torch.ones(M_tok, n_shared, dtype=torch.float32, device=x.device) + token_selected_experts = torch.cat([token_selected_experts, shared_ids], dim=1) + token_final_scales = torch.cat([token_final_scales, shared_scales], dim=1) + + # New implementation: fuse fc2_alpha into FC1's alpha_post + x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + # Generate normalized fc2_alpha for FC1 alpha_post fusion + fc2_alpha_normalized = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + self.fc2_alpha_max, # Normalize by max for FC1 alpha_post + fc2_alpha_buffer, # Pre-allocated buffer + ) + + # FC1: GEMM + SwiGLU with post-SwiGLU alpha scaling (fused fc2_alpha) + fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + fc2_alpha_normalized, # Pass normalized fc2_alpha as alpha_post + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + + # FC2: Standard nvfp4_gemm with scalar alpha = fc2_alpha_max + final_hidden_states = torch.ops.trtllm.nvfp4_gemm( + fc1_output.view(torch.uint8), + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.view(torch.uint8).reshape(-1), + self.w2_weight_scale.view(torch.uint8), + self.fc2_alpha_max, + torch.bfloat16, + to_userbuffers=False, + allowed_backends="cutlass,cublaslt,cutedsl,cuda_core", + ) + else: + # Green Context path: use AutoTuner to find the optimal FC1 SM split. + # Pass unswizzled x_sf so the constraint spec (shape[0] == num_tokens) + # remains valid; swizzle_sf is applied inside _run_moe_nvfp4_gc_impl. + tuner = AutoTuner.get() + runner = DenseGEMMGCSMRunner(self) + inputs = [x, x_sf, router_weight_t, router_input] + _, best_tactic = tuner.choose_one( + "DenseGEMMFusedMoE::run_moe_nvfp4_gc", + [runner], + runner.get_tuning_config(), + inputs, + ) + final_hidden_states = runner(inputs, tactic=best_tactic) + + return final_hidden_states + + def _warm_inner_runners_full_sm( + self, + x: torch.Tensor, + x_sf: torch.Tensor, + router_weight_t: torch.Tensor, + router_input: torch.Tensor, + ) -> None: + """Warm FC1 and Router GEMM autotuner caches with the full SM budget. + + This must be called once before the outer profiling loop in + DenseGEMMGCSMRunner.do_preparation. Because sm_budget is excluded from + both inner runners' unique_id(), a single full-SM tuning pass populates + the shared cache entry used by every GC SM split. Subsequent calls with + any sm_budget find a cache hit and skip profiling entirely, reducing the + total tune cost from O(S × inner) to O(inner + S). + + Args: + x: NVFP4-quantised activations (packed), unswizzled x_sf expected. + x_sf: Unswizzled scale factors for x (2-D: [num_tokens, hidden/16]). + router_weight_t: Router weight matrix [hidden, num_experts]. + router_input: BF16 input to the router GEMM [num_tokens, hidden]. + """ + num_tokens = x.shape[0] + # swizzle_sf is normally applied inside _run_moe_nvfp4_gc_impl; do it + # here because we are calling the FC1 op directly. + x_sf_swizzled = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + # FC1: sm_budget=0 → unconstrained (bypasses GC auto-detect, full SM). + torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf_swizzled, + self.w3_w1_weight_scale, + self.fc31_alpha, + None, # alpha_post: not needed for cache warm-up + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + sm_budget=0, # full SM: sm_budget <= 0 → max_active_clusters_hw + ) + + # Router GEMM: sm_budget=0 → unconstrained (stream has no GC, so + # auto-detect would also return 0; passing 0 explicitly is cleaner). + m, n = router_input.shape[0], router_weight_t.shape[1] + router_logits = torch.empty(m, n, dtype=torch.float32, device=router_input.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + router_input.contiguous(), + router_weight_t.t().contiguous(), + router_logits, + sm_budget=0, # full SM + ) + + def _run_moe_nvfp4_gc_impl( + self, + x: torch.Tensor, + x_sf: torch.Tensor, + router_weight_t: torch.Tensor, + router_input: torch.Tensor, + fc1_sms: int, + ) -> torch.Tensor: + """Execute the GC overlap path with the given FC1 SM partition. + + Args: + x: NVFP4-quantised activations (already packed). + x_sf: Swizzled scaling factors for ``x``. + router_weight_t: Router weight matrix ``[hidden, num_experts]``. + router_input: Input fed to the router GEMM (same shape as ``x`` + before quantisation, i.e. ``[num_tokens, hidden_size]``). + fc1_sms: Requested SM count for the FC1 GreenContext partition. + Must be a key in ``self._gc_stream_pool``. + + Returns: + ``final_hidden_states`` tensor of shape ``[num_tokens, hidden_size]``. + """ + fc1_gc_stream, router_gc_stream = self._gc_stream_pool[fc1_sms] + num_tokens = x.shape[0] + + x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + capture_graph = torch.cuda.is_current_stream_capturing() + fc2_alpha_buffer = DenseGEMMFusedMoE.buffers.get_buffer( + (num_tokens, self.expert_size_per_partition), + dtype=torch.float32, + buffer_name="fc2_alpha", + reserve_buffer=capture_graph, + ) + + # Fork: record on main stream so both GC streams can wait on it. + self.event_dict[EventType.Main].record() + + router_sms = max(1, self.num_sms - fc1_sms) + router_sms = ( + (router_sms + 7) // 8 * 8 + ) # Round up to nearest multiple of 8 for better GC compatibility + fc1_sms = self.num_sms - router_sms # Recalculate fc1_sms to reflect any rounding + + with torch.cuda.stream(fc1_gc_stream): + self.event_dict[EventType.Main].wait() + fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + None, # alpha_post: no post-SwiGLU scaling in this path + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + self.fc1_done_event.record() + + with torch.cuda.stream(router_gc_stream): + self.event_dict[EventType.Main].wait() + m, n = router_input.shape[0], router_weight_t.shape[1] + router_logits = torch.empty(m, n, dtype=torch.float32, device=router_input.device) + _router_input_c = router_input.contiguous() + _router_weight_c = router_weight_t.t().contiguous() + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + _router_input_c, + _router_weight_c, + router_logits, + # sm_budget=router_sms + ) + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + token_selected_experts = token_selected_experts.to(torch.int32) + # Append fused shared experts (always active, scale=1.0) if any. + num_routing_experts = router_weight_t.shape[1] + if self.expert_size_per_partition > num_routing_experts: + n_shared = self.expert_size_per_partition - num_routing_experts + M_tok = router_input.shape[0] + shared_ids = ( + torch.arange( + num_routing_experts, + self.expert_size_per_partition, + dtype=torch.int32, + device=x.device, + ) + .unsqueeze(0) + .expand(M_tok, n_shared) + ) + shared_scales = torch.ones(M_tok, n_shared, dtype=torch.float32, device=x.device) + token_selected_experts = torch.cat([token_selected_experts, shared_ids], dim=1) + token_final_scales = torch.cat([token_final_scales, shared_scales], dim=1) + fc2_alpha = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + output=fc2_alpha_buffer, + ) + self.event_dict[EventType.MoeFc2Alpha].record() + + # Join: main stream waits for both FC1 output and fc2_alpha. + self.fc1_done_event.wait() + self.event_dict[EventType.MoeFc2Alpha].wait() + + # FC2: input k = expert_count * intermediate_size (after SwiGLU halving) + result = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_fc2_blackwell( + fc1_output, + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.reshape(-1), + self.w2_weight_scale, + fc2_alpha, + expert_count=self.expert_size_per_partition, + weight_per_expert=self.intermediate_size_per_partition, + output_dtype=torch.bfloat16, + scaling_vector_size=self.scaling_vector_size, + ) + return result + + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + **kwargs, + ) -> torch.Tensor: + """ + Run MoE computation with DenseGEMM backend (NVFP4 only). + + Args: + x: Input hidden states (pre-quantized to NVFP4) + token_selected_experts: Must be None (routing done internally via router_weight_t) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors for NVFP4 + enable_alltoall: Whether alltoall communication is enabled. + **kwargs: Must contain 'router_weight_t'; optionally 'router_input'. + + Returns: + final_hidden_states tensor. + """ + assert self.has_nvfp4, ( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + return self.run_moe_nvfp4( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=enable_alltoall, + router_weight_t=kwargs.get("router_weight_t"), + router_input=kwargs.get("router_input"), + ) + + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: tuple = (True, True), + ) -> torch.Tensor: + # Currently, the default path is that ConfigurableMoE calls DenseGEMMFusedMoE.run_moe. + # This forward_chunk method is a reference implementation of the legacy path. + # Apply routing + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + assert token_selected_experts.shape[1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + x, x_sf = self.quantize_input(x) + + if self.use_dp and self.parallel_size > 1: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens, + ) + + x = self.run_moe( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=False, + ) + return x + + def forward_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert do_finalize, "DenseGEMMFusedMoE does not support do_finalize=False" + + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + outputs = self.forward_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call), + ) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + + if self.use_dp and self.parallel_size > 1: + rank = self.parallel_rank + outputs = outputs[: all_rank_num_tokens[rank]] + self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 + return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_no_overlap.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_no_overlap.py new file mode 100644 index 00000000000..25bceac6322 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_no_overlap.py @@ -0,0 +1,617 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import os +from typing import Dict, List, Optional, Union + +import torch + +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils import fp4_utils + +from ...distributed import allgather +from ...memory_buffer_utils import get_memory_buffers +from ...model_config import ModelConfig +from ...utils import AuxStreamType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf +from .interface import MoE, MoEWeightLoadingMode +from .quantization import NVFP4CuteDslFusedMoEMethod +from .routing import BaseMoeRoutingMethod + + +@torch.compile(options={"max-autotune": True}) +def gen_fc2_alpha_fused( + token_selected_experts: torch.Tensor, + token_final_scales: torch.Tensor, + alpha: Optional[torch.Tensor], + alpha_max: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Generate fc2 alpha values, optionally normalized for FC1 alpha_post fusion. + + Instead of: + 1. zeros() -> scatter_() -> multiply with alpha (operates on large [N, E] tensor) + + We do: + 1. Gather alpha values for selected experts (small [N, top_k] tensor) + 2. Multiply scales with gathered alpha (small tensor operation) + 3. Optionally normalize by alpha_max for FC1 alpha_post fusion + 4. Scatter to output (single write to large tensor) + + This reduces memory bandwidth by avoiding read-modify-write on the large tensor. + + Args: + token_selected_experts: Expert indices for each token [num_tokens, top_k] + token_final_scales: Final scaling factors [num_tokens, top_k] + alpha: Per-expert alpha values [expert_size] + alpha_max: Max alpha value for normalization (optional) + output: Pre-allocated output buffer [num_tokens, expert_size] (optional). + If None, a new tensor will be allocated (not compatible with CUDA graph). + """ + # Pre-compute scaled values on small tensor [num_tokens, top_k] + if alpha is not None: + # Gather alpha for selected experts: alpha[expert_idx] for each selection + gathered_alpha = alpha[token_selected_experts.long()] # [num_tokens, top_k] + scaled_values = token_final_scales * gathered_alpha + else: + scaled_values = token_final_scales + + # Normalize by alpha_max for FC1 alpha_post fusion + if alpha_max is not None: + scaled_values = scaled_values / alpha_max + + # Use pre-allocated output or create new tensor + if output is not None: + output.zero_() + fc2_alpha = output + else: + assert alpha is not None, ( + "alpha must be provided when output buffer is not pre-allocated, " + "since expert_size cannot be inferred from token_final_scales alone" + ) + num_tokens = token_selected_experts.shape[0] + expert_size = alpha.shape[0] + fc2_alpha = torch.zeros( + [num_tokens, expert_size], + dtype=torch.float32, + device=token_selected_experts.device, + ) + + return fc2_alpha.scatter_(1, token_selected_experts.long(), scaled_values) + + +class NoOverlapDenseGEMMFusedMoE(MoE): + """Single-stream, no-SM-partition DenseGEMM MoE for performance baseline. + + This backend uses CuTe DSL dense GEMM kernels with fused SwiGLU for MoE + computation. It supports NVFP4 quantization only and is restricted to + SM100/SM103 (Blackwell) architectures. + + Unlike the gc/smp variants, all ops run sequentially on a single CUDA stream. + No max_active_clusters or sm_budget limit is applied to any kernel (GPU uses + all available SMs). Use this as a baseline to measure the benefit of SM + partitioning + stream overlap in DenseGEMMFusedMoE. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams (unused). + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + """ + + # Memory buffer pool for CUDA graph compatibility + buffers = get_memory_buffers() + + # DenseGEMM only supports SM100 and SM103 (Blackwell CuTe DSL kernels). + _SUPPORTED_SM_VERSIONS = (100, 103) + + @classmethod + def can_implement( + cls, + quant_algo: Optional[QuantAlgo], + dtype_activation: torch.dtype = torch.bfloat16, + swiglu_gptoss_style: bool = False, + ) -> tuple: + """Check if NoOverlapDenseGEMMFusedMoE can implement the given configuration.""" + from tensorrt_llm._utils import get_sm_version + + from .interface import _warn_and_return + + sm_version = get_sm_version() + if sm_version not in cls._SUPPORTED_SM_VERSIONS: + return _warn_and_return( + f"NoOverlapDenseGEMMFusedMoE requires SM {cls._SUPPORTED_SM_VERSIONS}, got SM{sm_version}" + ) + + if quant_algo != QuantAlgo.NVFP4: + return _warn_and_return( + f"NoOverlapDenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})" + ) + + if swiglu_gptoss_style: + return _warn_and_return( + "NoOverlapDenseGEMMFusedMoE does not support swiglu_gptoss_style" + ) + + return (True, None) + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + init_load_balancer: bool = True, + without_comm: bool = False, + activation_type=None, + ): + # DenseGEMM CuTe DSL kernels only support SM100 and SM103. + from tensorrt_llm._utils import get_sm_version + + from ...utils import ActivationType + + sm_version = get_sm_version() + assert sm_version in self._SUPPORTED_SM_VERSIONS, ( + f"NoOverlapDenseGEMMFusedMoE only supports SM {self._SUPPORTED_SM_VERSIONS} " + f"(got SM {sm_version}). The CuTe DSL kernels require Blackwell architecture." + ) + + # DenseGEMM kernel hardcodes SwiGLU fusion — reject other activation types. + if activation_type is None: + activation_type = ActivationType.Swiglu + assert activation_type == ActivationType.Swiglu, ( + f"NoOverlapDenseGEMMFusedMoE only supports SwiGLU activation " + f"(got activation_type={activation_type}). " + f"The FC1 kernel fuses SwiGLU into the GEMM epilogue." + ) + + # FC2 DenseGEMM kernel tiles K dimension with MMA tile size 256. + _MMA_TILE_K = 256 + assert intermediate_size % _MMA_TILE_K == 0, ( + f"NoOverlapDenseGEMMFusedMoE requires intermediate_size to be a multiple of " + f"{_MMA_TILE_K} (got intermediate_size={intermediate_size}). " + f"FC2 kernel cannot correctly split alpha_scale at expert boundaries " + f"when weight_per_expert is not MMA tile-K aligned." + ) + + # Call MoE base class directly (not CutlassFusedMoE). + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream_dict=aux_stream_dict, + weight_loading_mode=weight_loading_mode, + layer_idx=layer_idx, + init_load_balancer=init_load_balancer, + activation_type=activation_type, + ) + + # Environment variable to control fc2_alpha fusion into FC1's alpha_post. + # Default: disabled (0). Set to "1" to enable fusion (known accuracy issue under TP). + self.use_fused_fc2_alpha = os.environ.get("TRTLLM_MOE_FUSED_FC2_ALPHA", "0") == "1" + + # Pre-register fc2_alpha_max buffer for fused fc2_alpha optimization. + # Populated in load_weights() with max(fc2_alpha). + self.register_buffer("fc2_alpha_max", torch.zeros(1, dtype=torch.float32)) + + # No SM budget computation — all SMs used for all kernels. + # No auxiliary streams needed — sequential single-stream execution. + if self.aux_stream_dict is None: + self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} + self.event_dict = {} + + # Weight creation + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + def _supports_load_balancer(self) -> bool: + """NoOverlapDenseGEMMFusedMoE supports load balancer.""" + return True + + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True + ): + if self.quant_config.layer_quant_mode.has_nvfp4(): + return NVFP4CuteDslFusedMoEMethod() + raise ValueError( + f"{self.__class__.__name__} only supports NVFP4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + raise ValueError( + f"{self.__class__.__name__} requires quantization (NVFP4), " + f"but no quantization config was provided." + ) + + def create_weights(self): + if self._weights_created: + return + + self.quant_method = self._get_quant_method() + self.quant_method.create_weights(self) + + self._weights_created = True + + def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False): + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + kargs = {} + if "allow_partial_loading" in inspect.getfullargspec(self.quant_method.load_weights).args: + kargs["allow_partial_loading"] = allow_partial_loading + self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) + + # Transpose w2_weight layout: (E, H, ...) -> (H, E, ...) for dense GEMM. + w2_transposed = self.w2_weight.transpose(0, 1).contiguous() + self.w2_weight.reshape([-1]).copy_(w2_transposed.reshape([-1]), non_blocking=True) + del w2_transposed + if self.has_any_quant: + if self.has_nvfp4: + self._transform_w2_weight_scale_for_min_latency() + # Compute fc2_alpha_max for fused fc2_alpha optimization + self.fc2_alpha_max.copy_(torch.max(self.fc2_alpha).reshape(1), non_blocking=True) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + def post_load_weights(self): + self.quant_method.post_load_weights(self) + + def _transform_w2_weight_scale_for_min_latency(self): + """Transform w2_weight_scale for minimum latency path optimization.""" + # Calculate padded dimensions + nrows = fp4_utils.pad_up(self.hidden_size, 128) + ncols = fp4_utils.pad_up( + self.intermediate_size_per_partition // self.scaling_vector_size, 4 + ) + + # Clone and convert weight scale to uint8 + w2_weight_scale = self.w2_weight_scale.clone().view(torch.uint8) + + # Unswizzle the scale factor + w2_weight_scale = unswizzle_sf( + w2_weight_scale, + self.hidden_size * self.expert_size_per_partition, + self.intermediate_size_per_partition, + ) + + # Reshape and transpose for min latency layout + w2_weight_scale = w2_weight_scale.reshape([self.expert_size_per_partition, nrows, ncols]) + w2_weight_scale = w2_weight_scale.transpose(0, 1).reshape( + nrows, self.expert_size_per_partition * ncols + ) + + # Swizzle back with new layout + w2_weight_scale = swizzle_sf( + w2_weight_scale, + self.hidden_size, + self.expert_size_per_partition * self.intermediate_size_per_partition, + ) + + # Copy back to original tensor + self.w2_weight_scale.copy_( + w2_weight_scale.view(self.w2_weight_scale.dtype).view(self.w2_weight_scale.shape), + non_blocking=True, + ) + + def quantize_input( + self, x: Union[torch.Tensor, Fp4QuantizedTensor], post_quant_comm: bool = True + ): + """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.""" + x_sf = None + if self.has_nvfp4: + if isinstance(x, Fp4QuantizedTensor): + assert not x.is_sf_swizzled, ( + "Fp4QuantizedTensor should not be swizzled before communication" + ) + x_row = x.shape[0] + x, x_sf = x.fp4_tensor, x.scaling_factor + else: + x_row = x.shape[0] + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, False, False + ) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + if x_sf is not None: + x_sf = x_sf.view(x_row, -1) + + return x, x_sf + + def run_moe_nvfp4( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + router_weight_t: Optional[torch.Tensor] = None, + router_input: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Single-stream MoE forward, no SM partitioning. + + Sequential execution order on main stream: + Router GEMM -> routing -> (gen_fc2_alpha | swizzle_sf) -> FC1 -> FC2 + + All kernels run without SM budget constraints (GPU uses all available SMs). + + Args: + x: Input tensor + token_selected_experts: Must be None (routing done internally) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors + enable_alltoall: Whether alltoall communication is enabled + router_weight_t: Router weight matrix [hidden, num_experts] (transposed) + router_input: Optional separate router input; defaults to x if None + """ + assert self.has_nvfp4 + assert token_selected_experts is None and token_final_scales is None, ( + "NoOverlapDenseGEMMFusedMoE only supports internal routing path. " + "Expected token_selected_experts/token_final_scales to be None." + ) + assert router_weight_t is not None, ( + "NoOverlapDenseGEMMFusedMoE internal routing requires router_weight_t." + ) + if router_input is None: + router_input = x + assert router_input.ndim == 2 and router_weight_t.ndim == 2, ( + "router_input and router_weight_t must both be rank-2 tensors." + ) + assert router_input.shape[1] == router_weight_t.shape[0], ( + "NoOverlapDenseGEMMFusedMoE internal routing shape mismatch: " + f"router_input.shape={tuple(router_input.shape)}, " + f"router_weight_t.shape={tuple(router_weight_t.shape)}" + ) + + num_tokens = x.shape[0] + + # Get pre-allocated buffer for fc2_alpha (CUDA graph compatible) + capture_graph = torch.cuda.is_current_stream_capturing() + fc2_alpha_buffer = NoOverlapDenseGEMMFusedMoE.buffers.get_buffer( + (num_tokens, self.expert_size_per_partition), + dtype=torch.float32, + buffer_name="fc2_alpha_no_overlap", + reserve_buffer=capture_graph, + ) + + # Router GEMM (all SMs, single stream, no SM budget) + router_logits = torch.ops.trtllm.dsv3_router_gemm_op( + router_input, + router_weight_t, + None, + torch.float32, + ) + + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + token_selected_experts = token_selected_experts.to(torch.int32) + + # Append fused shared experts (always active, scale=1.0) if any. + num_routing_experts = router_weight_t.shape[1] + if self.expert_size_per_partition > num_routing_experts: + n_shared = self.expert_size_per_partition - num_routing_experts + M_tok = router_input.shape[0] + shared_ids = ( + torch.arange( + num_routing_experts, + self.expert_size_per_partition, + dtype=torch.int32, + device=x.device, + ) + .unsqueeze(0) + .expand(M_tok, n_shared) + ) + shared_scales = torch.ones(M_tok, n_shared, dtype=torch.float32, device=x.device) + token_selected_experts = torch.cat([token_selected_experts, shared_ids], dim=1) + token_final_scales = torch.cat([token_final_scales, shared_scales], dim=1) + + x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + if self.use_fused_fc2_alpha: + # Fused path: fc2_alpha fused into FC1's alpha_post + fc2_alpha_normalized = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + self.fc2_alpha_max, # Normalize by max for FC1 alpha_post + fc2_alpha_buffer, # Pre-allocated buffer + ) + + # FC1: GEMM + SwiGLU with post-SwiGLU alpha scaling (fused fc2_alpha), no SM limit + fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + fc2_alpha_normalized, # Pass normalized fc2_alpha as alpha_post + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + + # FC2: Standard nvfp4_gemm with scalar alpha = fc2_alpha_max + final_hidden_states = torch.ops.trtllm.nvfp4_gemm( + fc1_output.view(torch.uint8), + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.view(torch.uint8).reshape(-1), + self.w2_weight_scale.view(torch.uint8), + self.fc2_alpha_max, + torch.bfloat16, + to_userbuffers=False, + allowed_backends="cutlass,cublaslt,cutedsl,cuda_core", + ) + else: + # Sequential path: gen_fc2_alpha then FC1 then FC2, all without SM budget + fc2_alpha = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + output=fc2_alpha_buffer, # Use pre-allocated buffer + ) + + # FC1: GEMM + SwiGLU, output is fp4 quantized, no SM limit + fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + None, # alpha_post: no post-SwiGLU scaling + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + + # FC2: input k = expert_count * intermediate_size (after SwiGLU) + final_hidden_states = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_fc2_blackwell( + fc1_output, + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.reshape(-1), + self.w2_weight_scale, + fc2_alpha, + expert_count=self.expert_size_per_partition, + weight_per_expert=self.intermediate_size_per_partition, + output_dtype=torch.bfloat16, + scaling_vector_size=self.scaling_vector_size, + ) + + return final_hidden_states + + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + **kwargs, + ) -> torch.Tensor: + """ + Run MoE computation with DenseGEMM backend (NVFP4 only), single-stream no-overlap variant. + + Args: + x: Input hidden states (pre-quantized to NVFP4) + token_selected_experts: Must be None (routing done internally via router_weight_t) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors for NVFP4 + enable_alltoall: Whether alltoall communication is enabled. + **kwargs: Must contain 'router_weight_t'; optionally 'router_input'. + + Returns: + final_hidden_states tensor. + """ + assert self.has_nvfp4, ( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + return self.run_moe_nvfp4( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=enable_alltoall, + router_weight_t=kwargs.get("router_weight_t"), + router_input=kwargs.get("router_input"), + ) + + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: tuple = (True, True), + ) -> torch.Tensor: + # Currently, the default path is that ConfigurableMoE calls run_moe. + # This forward_chunk method is a reference implementation of the legacy path. + # Apply routing + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + assert token_selected_experts.shape[1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + x, x_sf = self.quantize_input(x) + + if self.use_dp and self.parallel_size > 1: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens, + ) + + x = self.run_moe( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=False, + ) + return x + + def forward_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert do_finalize, "NoOverlapDenseGEMMFusedMoE does not support do_finalize=False" + + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + outputs = self.forward_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call), + ) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + + if self.use_dp and self.parallel_size > 1: + rank = self.parallel_rank + outputs = outputs[: all_rank_num_tokens[rank]] + self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 + return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_smp.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_smp.py new file mode 100644 index 00000000000..aa7e0a0b3d3 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm_smp.py @@ -0,0 +1,735 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import os +from typing import Dict, List, Optional, Union + +import torch + +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils import fp4_utils + +from ...distributed import allgather +from ...memory_buffer_utils import get_memory_buffers +from ...model_config import ModelConfig +from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf +from .interface import MoE, MoEWeightLoadingMode +from .quantization import NVFP4CuteDslFusedMoEMethod +from .routing import BaseMoeRoutingMethod + + +@torch.compile(options={"max-autotune": True}) +def gen_fc2_alpha_fused( + token_selected_experts: torch.Tensor, + token_final_scales: torch.Tensor, + alpha: Optional[torch.Tensor], + alpha_max: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Generate fc2 alpha values, optionally normalized for FC1 alpha_post fusion. + + Instead of: + 1. zeros() -> scatter_() -> multiply with alpha (operates on large [N, E] tensor) + + We do: + 1. Gather alpha values for selected experts (small [N, top_k] tensor) + 2. Multiply scales with gathered alpha (small tensor operation) + 3. Optionally normalize by alpha_max for FC1 alpha_post fusion + 4. Scatter to output (single write to large tensor) + + This reduces memory bandwidth by avoiding read-modify-write on the large tensor. + + Args: + token_selected_experts: Expert indices for each token [num_tokens, top_k] + token_final_scales: Final scaling factors [num_tokens, top_k] + alpha: Per-expert alpha values [expert_size] + alpha_max: Max alpha value for normalization (optional) + output: Pre-allocated output buffer [num_tokens, expert_size] (optional). + If None, a new tensor will be allocated (not compatible with CUDA graph). + """ + # Pre-compute scaled values on small tensor [num_tokens, top_k] + if alpha is not None: + # Gather alpha for selected experts: alpha[expert_idx] for each selection + gathered_alpha = alpha[token_selected_experts.long()] # [num_tokens, top_k] + scaled_values = token_final_scales * gathered_alpha + else: + scaled_values = token_final_scales + + # Normalize by alpha_max for FC1 alpha_post fusion + if alpha_max is not None: + scaled_values = scaled_values / alpha_max + + # Use pre-allocated output or create new tensor + if output is not None: + output.zero_() + fc2_alpha = output + else: + assert alpha is not None, ( + "alpha must be provided when output buffer is not pre-allocated, " + "since expert_size cannot be inferred from token_final_scales alone" + ) + num_tokens = token_selected_experts.shape[0] + expert_size = alpha.shape[0] + fc2_alpha = torch.zeros( + [num_tokens, expert_size], + dtype=torch.float32, + device=token_selected_experts.device, + ) + + return fc2_alpha.scatter_(1, token_selected_experts.long(), scaled_values) + + +class DenseGEMMFusedMoE(MoE): + """CuteDSL DenseGEMM flow of fused mixture of experts (MoE) Layer — SM Partition variant. + + This backend uses CuTe DSL dense GEMM kernels with fused SwiGLU for MoE + computation. It supports NVFP4 quantization only and is restricted to + SM100/SM103 (Blackwell) architectures. + + Unlike CutlassFusedMoE which uses per-expert scattered GEMM, DenseGEMM + packs all experts into a single dense matrix and uses standard GEMM operations, + which can be more efficient for small token counts (min-latency scenarios). + + This variant (smp) uses soft SM limits via the sm_budget kernel parameter + (no GreenContext) to overlap FC1 and Router GEMM on separate auxiliary streams. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + """ + + # Memory buffer pool for CUDA graph compatibility + buffers = get_memory_buffers() + + # DenseGEMM only supports SM100 and SM103 (Blackwell CuTe DSL kernels). + _SUPPORTED_SM_VERSIONS = (100, 103) + + @classmethod + def can_implement( + cls, + quant_algo: Optional[QuantAlgo], + dtype_activation: torch.dtype = torch.bfloat16, + swiglu_gptoss_style: bool = False, + ) -> tuple: + """Check if DenseGEMMFusedMoE can implement the given configuration. + + DenseGEMMFusedMoE supports: + - NVFP4 quantization only + - SM100/SM103 (Blackwell) only + - SwiGLU activation only (swiglu_gptoss_style not supported) + """ + from tensorrt_llm._utils import get_sm_version + + from .interface import _warn_and_return + + sm_version = get_sm_version() + if sm_version not in cls._SUPPORTED_SM_VERSIONS: + return _warn_and_return( + f"DenseGEMMFusedMoE requires SM {cls._SUPPORTED_SM_VERSIONS}, got SM{sm_version}" + ) + + if quant_algo != QuantAlgo.NVFP4: + return _warn_and_return( + f"DenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})" + ) + + if swiglu_gptoss_style: + return _warn_and_return("DenseGEMMFusedMoE does not support swiglu_gptoss_style") + + return (True, None) + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + init_load_balancer: bool = True, + without_comm: bool = False, + activation_type=None, + ): + # DenseGEMM CuTe DSL kernels only support SM100 and SM103. + from tensorrt_llm._utils import get_sm_version + + from ...utils import ActivationType + + sm_version = get_sm_version() + assert sm_version in self._SUPPORTED_SM_VERSIONS, ( + f"DenseGEMMFusedMoE only supports SM {self._SUPPORTED_SM_VERSIONS} " + f"(got SM {sm_version}). The CuTe DSL kernels require Blackwell architecture." + ) + + # DenseGEMM kernel hardcodes SwiGLU fusion — reject other activation types + # before calling super().__init__() to fail fast with a clear message. + if activation_type is None: + activation_type = ActivationType.Swiglu + assert activation_type == ActivationType.Swiglu, ( + f"DenseGEMMFusedMoE only supports SwiGLU activation " + f"(got activation_type={activation_type}). " + f"The FC1 kernel fuses SwiGLU into the GEMM epilogue." + ) + + # FC2 DenseGEMM kernel tiles K dimension with MMA tile size 256. + # weight_per_expert (= intermediate_size) must be 256-aligned so that + # expert boundaries align with MMA tile boundaries. + _MMA_TILE_K = 256 + assert intermediate_size % _MMA_TILE_K == 0, ( + f"DenseGEMMFusedMoE requires intermediate_size to be a multiple of " + f"{_MMA_TILE_K} (got intermediate_size={intermediate_size}). " + f"FC2 kernel cannot correctly split alpha_scale at expert boundaries " + f"when weight_per_expert is not MMA tile-K aligned." + ) + + # Call MoE base class directly (not CutlassFusedMoE). + # Note: `without_comm` and `apply_router_weight_on_input` are accepted + # for API compatibility with create_moe_backend() but are not passed to + # MoE.__init__() since DenseGEMM does not use alltoall communication. + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream_dict=aux_stream_dict, + weight_loading_mode=weight_loading_mode, + layer_idx=layer_idx, + init_load_balancer=init_load_balancer, + activation_type=activation_type, + ) + + # Environment variable to control fc2_alpha fusion into FC1's alpha_post. + # Default: disabled (0). Set to "1" to enable fusion (known accuracy issue under TP). + self.use_fused_fc2_alpha = os.environ.get("TRTLLM_MOE_FUSED_FC2_ALPHA", "0") == "1" + + # Pre-register fc2_alpha_max buffer for fused fc2_alpha optimization. + # Populated in load_weights() with max(fc2_alpha). + self.register_buffer("fc2_alpha_max", torch.zeros(1, dtype=torch.float32)) + + # Whether to use the CLC dynamic tile scheduler for FC1. + # TRTLLM_MOE_FC1_DYNAMIC_SCHED=1 uses Sm100BlockScaledDynamicDenseGemmKernel + # (fc1_dynamic_sched.py) which launches a full-problem grid and lets the CLC + # hardware dispatch tiles to newly freed SMs automatically. + # Default: disabled (0). + self.use_dynamic_fc1 = os.environ.get("TRTLLM_MOE_FC1_DYNAMIC_SCHED", "0") == "1" + + # SM budget for FC1 and router kernels (SM-based, not cluster-based). + # TRTLLM_MOE_FC1_SM_NUNBER: controls SM count for FC1 GEMM. + # If > 1, treated as an absolute SM count; if <= 1, treated as a fraction of total SMs. + # Default: 0.5 (50% of total SMs). + device_id = torch.cuda.current_device() + num_sms = torch.cuda.get_device_properties(device_id).multi_processor_count + _fc1_sm_config = float(os.environ.get("TRTLLM_MOE_FC1_SM_NUNBER", 0.5)) + if _fc1_sm_config > 1: + fc1_sms_raw = int(_fc1_sm_config) + else: + fc1_sms_raw = int(num_sms * _fc1_sm_config) + self.fc1_sms = max(1, fc1_sms_raw) + self.router_sms = max(1, num_sms - self.fc1_sms) + + # Initialize auxiliary stream and events for gen_fc2_alpha_fused overlap with fc1. + # Use regular stream with priority=-1 (no GreenContext). + if self.aux_stream_dict is None: + self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {} + if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict: + self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = torch.cuda.Stream(priority=-1) + self.event_dict = {} + for key in [EventType.Main, EventType.MoeFc2Alpha]: + self.event_dict[key] = torch.cuda.Event() + + # Weight creation + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + def _supports_load_balancer(self) -> bool: + """DenseGEMMFusedMoE supports load balancer.""" + return True + + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True + ): + if self.quant_config.layer_quant_mode.has_nvfp4(): + return NVFP4CuteDslFusedMoEMethod() + raise ValueError( + f"{self.__class__.__name__} only supports NVFP4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + raise ValueError( + f"{self.__class__.__name__} requires quantization (NVFP4), " + f"but no quantization config was provided." + ) + + def create_weights(self): + if self._weights_created: + return + + self.quant_method = self._get_quant_method() + self.quant_method.create_weights(self) + + self._weights_created = True + + def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False): + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + kargs = {} + if "allow_partial_loading" in inspect.getfullargspec(self.quant_method.load_weights).args: + kargs["allow_partial_loading"] = allow_partial_loading + self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) + + # Transpose w2_weight layout: (E, H, ...) -> (H, E, ...) for dense GEMM. + w2_transposed = self.w2_weight.transpose(0, 1).contiguous() + self.w2_weight.reshape([-1]).copy_(w2_transposed.reshape([-1]), non_blocking=True) + del w2_transposed + if self.has_any_quant: + if self.has_nvfp4: + self._transform_w2_weight_scale_for_min_latency() + # Compute fc2_alpha_max for fused fc2_alpha optimization + self.fc2_alpha_max.copy_(torch.max(self.fc2_alpha).reshape(1), non_blocking=True) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + def post_load_weights(self): + self.quant_method.post_load_weights(self) + + def _transform_w2_weight_scale_for_min_latency(self): + """Transform w2_weight_scale for minimum latency path optimization.""" + # Calculate padded dimensions + nrows = fp4_utils.pad_up(self.hidden_size, 128) + ncols = fp4_utils.pad_up( + self.intermediate_size_per_partition // self.scaling_vector_size, 4 + ) + + # Clone and convert weight scale to uint8 + w2_weight_scale = self.w2_weight_scale.clone().view(torch.uint8) + + # Unswizzle the scale factor + w2_weight_scale = unswizzle_sf( + w2_weight_scale, + self.hidden_size * self.expert_size_per_partition, + self.intermediate_size_per_partition, + ) + + # Reshape and transpose for min latency layout + w2_weight_scale = w2_weight_scale.reshape([self.expert_size_per_partition, nrows, ncols]) + w2_weight_scale = w2_weight_scale.transpose(0, 1).reshape( + nrows, self.expert_size_per_partition * ncols + ) + + # Swizzle back with new layout + w2_weight_scale = swizzle_sf( + w2_weight_scale, + self.hidden_size, + self.expert_size_per_partition * self.intermediate_size_per_partition, + ) + + # Copy back to original tensor + self.w2_weight_scale.copy_( + w2_weight_scale.view(self.w2_weight_scale.dtype).view(self.w2_weight_scale.shape), + non_blocking=True, + ) + + def quantize_input( + self, x: Union[torch.Tensor, Fp4QuantizedTensor], post_quant_comm: bool = True + ): + """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. + + Args: + x: Input tensor to quantize + post_quant_comm: + If True, quantize for post-quant communication path. + If False, quantize for non-communication path + + Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed + """ + x_sf = None + if self.has_nvfp4: + if isinstance(x, Fp4QuantizedTensor): + assert not x.is_sf_swizzled, ( + "Fp4QuantizedTensor should not be swizzled before communication" + ) + x_row = x.shape[0] + x, x_sf = x.fp4_tensor, x.scaling_factor + else: + x_row = x.shape[0] + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, False, False + ) + else: + raise ValueError( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + + if x_sf is not None: + x_sf = x_sf.view(x_row, -1) + + return x, x_sf + + def run_moe_nvfp4( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + router_weight_t: Optional[torch.Tensor] = None, + router_input: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run MoE computation with NVFP4 quantization using SM-partition stream overlap. + + Args: + x: Input tensor + token_selected_experts: Must be None (routing done internally) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors + enable_alltoall: Whether alltoall communication is enabled + router_weight_t: Router weight matrix [hidden, num_experts] (transposed) + router_input: Optional separate router input; defaults to x if None + + Note: + The implementation is controlled by TRTLLM_MOE_FUSED_FC2_ALPHA env var (default: disabled). + When enabled, fc2_alpha is fused into FC1's alpha_post with scalar fc2_alpha_max in FC2. + When disabled, uses SM budget (sm_budget) to overlap FC1 and Router GEMM on + separate auxiliary streams without hardware-level SM isolation. + """ + assert self.has_nvfp4 + assert token_selected_experts is None and token_final_scales is None, ( + "DenseGEMMFusedMoE only supports internal routing path. " + "Expected token_selected_experts/token_final_scales to be None." + ) + assert router_weight_t is not None, ( + "DenseGEMMFusedMoE internal routing requires router_weight_t." + ) + if router_input is None: + router_input = x + assert router_input.ndim == 2 and router_weight_t.ndim == 2, ( + "router_input and router_weight_t must both be rank-2 tensors." + ) + assert router_input.shape[1] == router_weight_t.shape[0], ( + "DenseGEMMFusedMoE internal routing shape mismatch: " + f"router_input.shape={tuple(router_input.shape)}, " + f"router_weight_t.shape={tuple(router_weight_t.shape)}" + ) + + num_tokens = x.shape[0] + + # Get pre-allocated buffer for fc2_alpha (CUDA graph compatible) + capture_graph = torch.cuda.is_current_stream_capturing() + fc2_alpha_buffer = DenseGEMMFusedMoE.buffers.get_buffer( + (num_tokens, self.expert_size_per_partition), + dtype=torch.float32, + buffer_name="fc2_alpha", + reserve_buffer=capture_graph, + ) + + if self.use_fused_fc2_alpha: + # Fused path: router GEMM runs on main stream before FC1. + m, n = router_input.shape[0], router_weight_t.shape[1] + router_logits = torch.empty(m, n, dtype=torch.float32, device=router_input.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + router_input.contiguous(), + router_weight_t.t().contiguous(), + router_logits, + sm_budget=self.router_sms, + ) + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + token_selected_experts = token_selected_experts.to(torch.int32) + assert token_final_scales is not None + assert token_final_scales.dtype == torch.float32 + + # Append fused shared experts (always active, scale=1.0) if any. + num_routing_experts = router_weight_t.shape[1] + if self.expert_size_per_partition > num_routing_experts: + n_shared = self.expert_size_per_partition - num_routing_experts + M_tok = router_input.shape[0] + shared_ids = ( + torch.arange( + num_routing_experts, + self.expert_size_per_partition, + dtype=torch.int32, + device=x.device, + ) + .unsqueeze(0) + .expand(M_tok, n_shared) + ) + shared_scales = torch.ones(M_tok, n_shared, dtype=torch.float32, device=x.device) + token_selected_experts = torch.cat([token_selected_experts, shared_ids], dim=1) + token_final_scales = torch.cat([token_final_scales, shared_scales], dim=1) + + # New implementation: fuse fc2_alpha into FC1's alpha_post + x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + # Generate normalized fc2_alpha for FC1 alpha_post fusion + fc2_alpha_normalized = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + self.fc2_alpha_max, # Normalize by max for FC1 alpha_post + fc2_alpha_buffer, # Pre-allocated buffer + ) + + # FC1: GEMM + SwiGLU with post-SwiGLU alpha scaling (fused fc2_alpha) + fc1_output, fc1_output_sf = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + fc2_alpha_normalized, # Pass normalized fc2_alpha as alpha_post + self.fc2_input_scale, + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + + # FC2: Standard nvfp4_gemm with scalar alpha = fc2_alpha_max + final_hidden_states = torch.ops.trtllm.nvfp4_gemm( + fc1_output.view(torch.uint8), + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.view(torch.uint8).reshape(-1), + self.w2_weight_scale.view(torch.uint8), + self.fc2_alpha_max, + torch.bfloat16, + to_userbuffers=False, + allowed_backends="cutlass,cublaslt,cutedsl,cuda_core", + ) + else: + # SM Partition implementation: per-token per-expert fc2_alpha in FC2. + # + # Launch order: + # 1. Router GEMM fires first on the aux stream so that routing + # work is in-flight before FC1 starts. + # 2. FC1 launches on the main stream immediately after. + # + # Static path: router_sms + fc1_sms together cover the full GPU, + # so both kernels overlap spatially via soft SM budgets. + # Dynamic path (use_dynamic_fc1=True): FC1 uses the CLC scheduler + # and is launched without an sm_budget cap so it can saturate all + # available SMs. Router GEMM still runs with router_sms on the + # aux stream; the CLC will reclaim SMs as the router frees them. + x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size) + + # Record dependency so the aux stream sees consistent inputs. + self.event_dict[EventType.Main].record() + + # Step 1: launch Router GEMM on the aux stream. + with torch.cuda.stream(self.aux_stream_dict[AuxStreamType.MoeFc2Alpha]): + self.event_dict[EventType.Main].wait() + m, n = router_input.shape[0], router_weight_t.shape[1] + router_logits = torch.empty(m, n, dtype=torch.float32, device=router_input.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + router_input.contiguous(), + router_weight_t.t().contiguous(), + router_logits, + sm_budget=self.router_sms, + ) + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits + ) + token_selected_experts = token_selected_experts.to(torch.int32) + # Append fused shared experts (always active, scale=1.0) if any. + num_routing_experts = router_weight_t.shape[1] + if self.expert_size_per_partition > num_routing_experts: + n_shared = self.expert_size_per_partition - num_routing_experts + M_tok = router_input.shape[0] + shared_ids = ( + torch.arange( + num_routing_experts, + self.expert_size_per_partition, + dtype=torch.int32, + device=x.device, + ) + .unsqueeze(0) + .expand(M_tok, n_shared) + ) + shared_scales = torch.ones( + M_tok, n_shared, dtype=torch.float32, device=x.device + ) + token_selected_experts = torch.cat([token_selected_experts, shared_ids], dim=1) + token_final_scales = torch.cat([token_final_scales, shared_scales], dim=1) + fc2_alpha = gen_fc2_alpha_fused( + token_selected_experts, + token_final_scales, + self.fc2_alpha, + output=fc2_alpha_buffer, + ) + self.event_dict[EventType.MoeFc2Alpha].record() + + # Step 2: launch FC1 on the main stream. + # Dynamic: no sm_budget — CLC scheduler fills all available SMs. + # Static: sm_budget=fc1_sms — soft cap leaving room for router GEMM. + _fc1_op = ( + torch.ops.trtllm.cute_dsl_nvfp4_dynamic_dense_gemm_swiglu_blackwell + if self.use_dynamic_fc1 + else torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_swiglu_blackwell + ) + fc1_kwargs = dict( + expert_count=self.expert_size_per_partition, + weight_per_expert=2 * self.intermediate_size_per_partition, + output_dtype=torch.float4_e2m1fn_x2, + scaling_vector_size=self.scaling_vector_size, + ) + if not self.use_dynamic_fc1: + fc1_kwargs["sm_budget"] = self.fc1_sms + fc1_output, fc1_output_sf = _fc1_op( + x, + self.w3_w1_weight.view(torch.uint8), + x_sf, + self.w3_w1_weight_scale, + self.fc31_alpha, + None, # alpha_post: no post-SwiGLU scaling + self.fc2_input_scale, + **fc1_kwargs, + ) + + self.event_dict[EventType.MoeFc2Alpha].wait() + + # FC2: input k = expert_count * intermediate_size (after SwiGLU) + final_hidden_states = torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_fc2_blackwell( + fc1_output, + self.w2_weight.view(torch.uint8).reshape(self.hidden_size, -1), + fc1_output_sf.reshape(-1), + self.w2_weight_scale, + fc2_alpha, + expert_count=self.expert_size_per_partition, + weight_per_expert=self.intermediate_size_per_partition, + output_dtype=torch.bfloat16, + scaling_vector_size=self.scaling_vector_size, + ) + + return final_hidden_states + + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + **kwargs, + ) -> torch.Tensor: + """ + Run MoE computation with DenseGEMM backend (NVFP4 only). + + Args: + x: Input hidden states (pre-quantized to NVFP4) + token_selected_experts: Must be None (routing done internally via router_weight_t) + token_final_scales: Must be None (routing done internally) + x_sf: Input scale factors for NVFP4 + enable_alltoall: Whether alltoall communication is enabled. + **kwargs: Must contain 'router_weight_t'; optionally 'router_input'. + + Returns: + final_hidden_states tensor. + """ + assert self.has_nvfp4, ( + f"{self.__class__.__name__} only supports nvfp4 quantization, " + f"got {self.quant_config.quant_mode}." + ) + return self.run_moe_nvfp4( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=enable_alltoall, + router_weight_t=kwargs.get("router_weight_t"), + router_input=kwargs.get("router_input"), + ) + + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: tuple = (True, True), + ) -> torch.Tensor: + # Currently, the default path is that ConfigurableMoE calls DenseGEMMFusedMoE.run_moe. + # This forward_chunk method is a reference implementation of the legacy path. + # Apply routing + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + assert token_selected_experts.shape[1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + x, x_sf = self.quantize_input(x) + + if self.use_dp and self.parallel_size > 1: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens, + ) + + x = self.run_moe( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=False, + ) + return x + + def forward_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert do_finalize, "DenseGEMMFusedMoE does not support do_finalize=False" + + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + outputs = self.forward_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call), + ) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + + if self.use_dp and self.parallel_size > 1: + rank = self.parallel_rank + outputs = outputs[: all_rank_num_tokens[rank]] + self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 + return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/green_context.py b/tensorrt_llm/_torch/modules/fused_moe/green_context.py new file mode 100644 index 00000000000..387730b653c --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/green_context.py @@ -0,0 +1,596 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""CUDA Driver API helpers for Workqueue-isolated GreenContext stream creation.""" + +import ctypes +from typing import Optional + +import torch + +_libcuda: Optional[ctypes.CDLL] = None + + +def _get_libcuda() -> ctypes.CDLL: + """Lazily load libcuda.so.1 (CUDA Driver API).""" + global _libcuda + if _libcuda is None: + for name in ("libcuda.so.1", "libcuda.so"): + try: + _libcuda = ctypes.CDLL(name) + break + except OSError: + continue + if _libcuda is None: + raise RuntimeError("Cannot load libcuda (CUDA Driver API) shared library") + return _libcuda + + +def get_current_stream_gc_sm_count() -> int: + """Return the SM count of the GreenContext bound to the current CUDA stream. + + When a kernel is dispatched inside ``torch.cuda.stream(gc_stream)`` where + *gc_stream* was created by ``cuGreenCtxStreamCreate``, the stream carries a + ``CUgreenCtx`` handle that encodes the SM partition assigned to that + GreenContext. This function queries that partition size so callers can + derive ``sm_budget`` without hard-coding it at the call site. + + Workflow: + 1. ``cuStreamGetGreenCtx`` — retrieve the ``CUgreenCtx`` bound to the + current PyTorch stream (fails with ``CUDA_ERROR_INVALID_HANDLE`` for + plain streams, returning ``-1``). + 2. ``cuCtxFromGreenCtx`` — convert ``CUgreenCtx`` to a ``CUcontext`` + handle (does not push/pop; the handle is valid for the lifetime of the + GreenContext). + 3. ``cuCtxGetDevResource(CU_DEV_RESOURCE_TYPE_SM)`` — fill a raw + ``CUdevResource`` buffer and read ``smCount`` at the version-specific + byte offset. + + Returns: + ``smCount`` from the GreenContext's SM resource partition, or ``-1`` if: + - the current stream is not a GC-bound stream, + - ``cuStreamGetGreenCtx`` is unavailable (CUDA < 12.4), or + - any Driver API call fails. + """ + from ctypes import byref, c_int, c_void_p, create_string_buffer + + libcuda = _get_libcuda() + + CU_DEV_RESOURCE_TYPE_SM = 1 + CUDA_SUCCESS = 0 + + if not hasattr(libcuda, "cuStreamGetGreenCtx"): + return -1 # CUDA < 12.4 + + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + # Step 1: get GreenContext handle from the stream. + # Signature: cuStreamGetGreenCtx(CUstream hStream, CUgreenCtx *phCtx) + gc_handle = c_void_p() + if libcuda.cuStreamGetGreenCtx(stream_ptr, byref(gc_handle)) != CUDA_SUCCESS: + return -1 # not a cuGreenCtxStreamCreate-created stream + + # Step 2: derive a CUcontext from the CUgreenCtx. + ctx_handle = c_void_p() + if libcuda.cuCtxFromGreenCtx(byref(ctx_handle), gc_handle) != CUDA_SUCCESS: + return -1 + + # Step 3: determine CUDA version to pick the correct smCount byte offset. + # CUdevResource raw layout: + # CUDA 12.x: { type:u32(4) } + { union(92) } — smCount at offset 4 + # CUDA 13.x: { type:u32(4) } + { padding(92) } + { union(40) } + { ptr(8) } + # — smCount is the first field of the union → offset 4+92 = 96 + cuda_ver = c_int(0) + libcuda.cuDriverGetVersion(byref(cuda_ver)) + if cuda_ver.value >= 13000: + buf_size, sm_count_offset = 144, 96 + else: + buf_size, sm_count_offset = 96, 4 + + res_buf = create_string_buffer(buf_size) + if ( + libcuda.cuCtxGetDevResource(ctx_handle, res_buf, c_int(CU_DEV_RESOURCE_TYPE_SM)) + != CUDA_SUCCESS + ): + return -1 + + # smCount is a little-endian uint32 at the computed offset. + return int.from_bytes(res_buf[sm_count_offset : sm_count_offset + 4], byteorder="little") + + +def create_wq_isolated_gc_streams( + fc1_sms: int, + router_sms: int, + device_id: int, +) -> tuple: + """Create two GreenContext streams with SM *and* Workqueue isolation. + + PyTorch's ``torch.cuda.GreenContext.create()`` only partitions SM resources, + leaving the hardware workqueue shared across GreenContexts. That shared WQ + introduces a ~10 µs dispatch serialisation overhead even when the two kernels + execute on disjoint SM partitions. + + This function bypasses PyTorch and calls the CUDA Driver API directly to + build resource descriptors that include *both* an SM partition and a + ``CU_WORKQUEUE_SCOPE_GREEN_CTX_BALANCED`` workqueue config. The result is + two streams that are truly independent at both the SM and WQ scheduling + levels. + + SM counts are rounded up to the hardware ``smCoscheduledAlignment`` boundary + (queried at runtime). If the requested counts cannot both be satisfied after + alignment, the FC1 partition is shrunk to leave the Router at least one + scheduling unit. + + Args: + fc1_sms: Requested SM count for the FC1 partition. + router_sms: Requested SM count for the Router partition. + device_id: CUDA device index. + + Returns: + ``(fc1_stream, router_stream, cleanup_fn)`` where the two streams are + :class:`torch.cuda.ExternalStream` objects backed by GC-bound + ``CUstream`` handles, and ``cleanup_fn()`` destroys those handles + (call it when the layer is no longer needed). + + Raises: + RuntimeError: if any CUDA Driver API call fails. + """ + from ctypes import ( + Structure, + Union, + addressof, + byref, + c_int, + c_ubyte, + c_uint, + c_void_p, + memmove, + sizeof, + ) + + libcuda = _get_libcuda() + + # ------------------------------------------------------------------ + # CUDA Driver API constants + # ------------------------------------------------------------------ + CU_DEV_RESOURCE_TYPE_SM = 1 + CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG = 1000 + CU_GREEN_CTX_DEFAULT_STREAM = 0x1 + CU_STREAM_NON_BLOCKING = 0x1 + CU_WORKQUEUE_SCOPE_GREEN_CTX_BALANCED = 1 + CUDA_SUCCESS = 0 + + # ------------------------------------------------------------------ + # CUdevResource ctypes layout — version-aware + # + # CUDA 12.x: + # struct { CUdevResourceType type; union { sm; wqConfig; wqResource; char raw[92]; }; } + # CUdevSmResource: smCount, smCoscheduledAlignment, minSmPartitionSize, reserved[13] + # CUdevWorkqueueConfigResource: wqConcurrencyLimit, sharingScope, reserved[6] + # + # CUDA 13.0+: + # struct { CUdevResourceType type; unsigned char _internal_padding[92]; + # union { sm; wqConfig; wq; char raw[40]; }; struct CUdevResource_st* nextResource; } + # CUdevSmResource: smCount, minSmPartitionSize, smCoscheduledAlignment, flags + # CUdevWorkqueueConfigResource: device, wqConcurrencyLimit, sharingScope + # ------------------------------------------------------------------ + cuda_ver = c_int(0) + libcuda.cuDriverGetVersion(byref(cuda_ver)) + _cuda_version = cuda_ver.value # e.g. 13010 for CUDA 13.1, 12060 for CUDA 12.6 + + if _cuda_version >= 13000: + + class _SmData(Structure): + _fields_ = [ + ("smCount", c_uint), + ("minSmPartitionSize", c_uint), + ("smCoscheduledAlignment", c_uint), + ("flags", c_uint), + ] + + class _WqConfigData(Structure): + _fields_ = [ + ("device", c_int), + ("wqConcurrencyLimit", c_uint), + ("sharingScope", c_uint), + ] + + class _ResData(Union): + _fields_ = [ + ("sm", _SmData), + ("wqConfig", _WqConfigData), + # RESOURCE_ABI_BYTES = 40 in CUDA 13.x + ("raw", c_ubyte * 40), + ] + + class CUdevResource(Structure): + _fields_ = [ + ("type", c_uint), + ("_internal_padding", c_ubyte * 92), + ("data", _ResData), + ("nextResource", c_void_p), + ] + else: + + class _SmData(Structure): + _fields_ = [ + ("smCount", c_uint), + ("smCoscheduledAlignment", c_uint), + ("minSmPartitionSize", c_uint), + ("reserved", c_uint * 13), + ] + + class _WqConfigData(Structure): + _fields_ = [ + ("wqConcurrencyLimit", c_uint), + ("sharingScope", c_uint), + ("reserved", c_uint * 6), + ] + + class _ResData(Union): + _fields_ = [ + ("sm", _SmData), + ("wqConfig", _WqConfigData), + # 92 bytes: CUDA 12.x union size + ("raw", c_ubyte * 92), + ] + + class CUdevResource(Structure): + _fields_ = [("type", c_uint), ("data", _ResData)] + + def _check(ret: int, fn_name: str) -> None: + if ret != CUDA_SUCCESS: + raise RuntimeError(f"{fn_name} failed with CUDA error code {ret}") + + # 1. Query device SM resource to get total SM count ------------------- + sm_res = CUdevResource() + _check( + libcuda.cuDeviceGetDevResource( + c_int(device_id), byref(sm_res), c_int(CU_DEV_RESOURCE_TYPE_SM) + ), + "cuDeviceGetDevResource(SM)", + ) + total_sms = sm_res.data.sm.smCount + + # 2. Split SM resources via cuDevSmResourceSplitByCount --------------- + # Carve the Router partition first; FC1 receives the unstructured + # remainder. Reversing the order ensures the router gets an exactly + # sized, aligned partition, while FC1's remainder is still fully usable + # for a GreenContext. + router_sm_result = (CUdevResource * 1)() + fc1_sm_res = CUdevResource() + nbGroups = c_uint(1) + _check( + libcuda.cuDevSmResourceSplitByCount( + router_sm_result, + byref(nbGroups), + byref(sm_res), + byref(fc1_sm_res), + c_uint(0), + c_uint(router_sms), + ), + "cuDevSmResourceSplitByCount(router)", + ) + if nbGroups.value == 0: + raise RuntimeError( + f"cuDevSmResourceSplitByCount returned 0 groups " + f"(router_sms={router_sms}, total_sms={total_sms})" + ) + + # 3. Build WQ config resources with GREEN_CTX_BALANCED scope ---------- + # WQ_CONFIG must be queried with cuDeviceGetDevResource (device-level), not + # cuCtxGetDevResource (context-level). When mixing SM + WQ_CONFIG resources + # in cuDevResourceGenerateDesc, the driver requires both to originate from the + # same device. SM resources come from cuDeviceGetDevResource/cuDevSmResourceSplitByCount; + # using cuCtxGetDevResource for WQ gives them different provenance, causing + # CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION (914). + # CUDA docs: "In case of workqueues, an existing one queried via the + # cuDeviceGetDevResource API." + wq_fc1 = CUdevResource() + _check( + libcuda.cuDeviceGetDevResource( + c_int(device_id), + byref(wq_fc1), + c_int(CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG), + ), + "cuDeviceGetDevResource(WQ_CONFIG)", + ) + wq_router = CUdevResource() + memmove(addressof(wq_router), addressof(wq_fc1), sizeof(CUdevResource)) + # Override only the sharing scope; all other fields (wqConcurrencyLimit, etc.) + # retain the driver-provided values. + wq_fc1.data.wqConfig.sharingScope = CU_WORKQUEUE_SCOPE_GREEN_CTX_BALANCED + wq_router.data.wqConfig.sharingScope = CU_WORKQUEUE_SCOPE_GREEN_CTX_BALANCED + + # 4. Build resource descriptors (SM partition + WQ config, per GC) --- + fc1_res_arr = (CUdevResource * 2)() + router_res_arr = (CUdevResource * 2)() + res_sz = sizeof(CUdevResource) + # fc1: fc1_sm_res (remainder) + wq_fc1 + memmove(addressof(fc1_res_arr), addressof(fc1_sm_res), res_sz) + memmove(addressof(fc1_res_arr) + res_sz, addressof(wq_fc1), res_sz) + # router: router_sm_result[0] + wq_router + memmove(addressof(router_res_arr), addressof(router_sm_result), res_sz) + memmove(addressof(router_res_arr) + res_sz, addressof(wq_router), res_sz) + + fc1_desc = c_void_p() + router_desc = c_void_p() + _check( + libcuda.cuDevResourceGenerateDesc(byref(fc1_desc), fc1_res_arr, c_uint(2)), + "cuDevResourceGenerateDesc(fc1)", + ) + _check( + libcuda.cuDevResourceGenerateDesc(byref(router_desc), router_res_arr, c_uint(2)), + "cuDevResourceGenerateDesc(router)", + ) + + # 5. Create GreenContexts --------------------------------------------- + fc1_gc_h = c_void_p() + router_gc_h = c_void_p() + _check( + libcuda.cuGreenCtxCreate( + byref(fc1_gc_h), + fc1_desc, + c_int(device_id), + c_uint(CU_GREEN_CTX_DEFAULT_STREAM), + ), + "cuGreenCtxCreate(fc1)", + ) + _check( + libcuda.cuGreenCtxCreate( + byref(router_gc_h), + router_desc, + c_int(device_id), + c_uint(CU_GREEN_CTX_DEFAULT_STREAM), + ), + "cuGreenCtxCreate(router)", + ) + + # 6. Create NON_BLOCKING streams bound to each GreenContext ----------- + fc1_stream_ptr = c_void_p() + router_stream_ptr = c_void_p() + _check( + libcuda.cuGreenCtxStreamCreate( + byref(fc1_stream_ptr), + fc1_gc_h, + c_uint(CU_STREAM_NON_BLOCKING), + c_int(0), + ), + "cuGreenCtxStreamCreate(fc1)", + ) + _check( + libcuda.cuGreenCtxStreamCreate( + byref(router_stream_ptr), + router_gc_h, + c_uint(CU_STREAM_NON_BLOCKING), + c_int(0), + ), + "cuGreenCtxStreamCreate(router)", + ) + + # 7. Wrap raw CUstream handles in torch ExternalStream ---------------- + fc1_stream = torch.cuda.ExternalStream(fc1_stream_ptr.value, device=f"cuda:{device_id}") + router_stream = torch.cuda.ExternalStream(router_stream_ptr.value, device=f"cuda:{device_id}") + + # 8. Cleanup closure -- call when the layer is destroyed -------------- + _fc1_gc_val = fc1_gc_h.value + _router_gc_val = router_gc_h.value + _fc1_s_val = fc1_stream_ptr.value + _router_s_val = router_stream_ptr.value + + def _cleanup() -> None: + _lib = _get_libcuda() + _lib.cuStreamDestroy(c_void_p(_fc1_s_val)) + _lib.cuStreamDestroy(c_void_p(_router_s_val)) + _lib.cuGreenCtxDestroy(c_void_p(_fc1_gc_val)) + _lib.cuGreenCtxDestroy(c_void_p(_router_gc_val)) + + return fc1_stream, router_stream, _cleanup + + +def create_sm_only_gc_streams( + fc1_sms: int, + router_sms: int, + device_id: int, +) -> tuple: + """Create GreenContext streams with SM isolation only (no WQ isolation). + + This is a CUDA-Graph-compatible alternative to the PyTorch + ``torch.cuda.GreenContext`` + ``torch.cuda.Stream()`` path. + + **Why PyTorch streams break CUDA Graph:** + ``torch.cuda.Stream()`` created inside ``GreenContext.set_context() / pop_context()`` + lives in a ``cuCtxFromGreenCtx``-derived regular context. CUDA Graph + capture/replay runs on the primary context's default stream. When the graph is + replayed, streams belonging to the GC-derived context silently **lose their SM + partition** — FC1 and Router CTAs compete for all SMs, adding ~10 µs of latency. + + **Why this helper works:** + Streams created by ``cuGreenCtxStreamCreate`` are bound directly to the + ``CUgreenCtx`` handle itself (not to any derived context). The SM partition + encoded in the ``CUgreenCtx`` is preserved across CUDA Graph capture and replay + because the stream identity remains stable. + + The difference from :func:`create_wq_isolated_gc_streams` is that only the SM + resource is included in ``cuDevResourceGenerateDesc`` (count=1). The hardware + workqueue remains shared across the two GreenContexts. + + SM counts are rounded up to the hardware ``smCoscheduledAlignment`` boundary via + ``cuDevSmResourceSplitByCount``. + + Args: + fc1_sms: Requested SM count for the FC1 partition. + router_sms: Requested SM count for the Router partition (informational only; + router gets the remainder after FC1 is carved out). + device_id: CUDA device index. + + Returns: + ``(fc1_stream, router_stream, cleanup_fn)`` where the two streams are + :class:`torch.cuda.ExternalStream` objects backed by GC-bound ``CUstream`` + handles, and ``cleanup_fn()`` destroys those handles and the GreenContexts. + + Raises: + RuntimeError: if any CUDA Driver API call fails. + """ + from ctypes import Structure, Union, byref, c_int, c_ubyte, c_uint, c_void_p + + libcuda = _get_libcuda() + + CU_DEV_RESOURCE_TYPE_SM = 1 + CU_GREEN_CTX_DEFAULT_STREAM = 0x1 + CU_STREAM_NON_BLOCKING = 0x1 + CUDA_SUCCESS = 0 + + cuda_ver = c_int(0) + libcuda.cuDriverGetVersion(byref(cuda_ver)) + _cuda_version = cuda_ver.value + + # Reuse the same version-aware CUdevResource layout as create_wq_isolated_gc_streams. + if _cuda_version >= 13000: + + class _SmData(Structure): + _fields_ = [ + ("smCount", c_uint), + ("minSmPartitionSize", c_uint), + ("smCoscheduledAlignment", c_uint), + ("flags", c_uint), + ] + + class _ResData(Union): + _fields_ = [ + ("sm", _SmData), + ("raw", c_ubyte * 40), + ] + + class CUdevResource(Structure): + _fields_ = [ + ("type", c_uint), + ("_internal_padding", c_ubyte * 92), + ("data", _ResData), + ("nextResource", c_void_p), + ] + else: + + class _SmData(Structure): + _fields_ = [ + ("smCount", c_uint), + ("smCoscheduledAlignment", c_uint), + ("minSmPartitionSize", c_uint), + ("reserved", c_uint * 13), + ] + + class _ResData(Union): + _fields_ = [ + ("sm", _SmData), + ("raw", c_ubyte * 92), + ] + + class CUdevResource(Structure): + _fields_ = [("type", c_uint), ("data", _ResData)] + + def _check(ret: int, fn_name: str) -> None: + if ret != CUDA_SUCCESS: + raise RuntimeError(f"{fn_name} failed with CUDA error code {ret}") + + # 1. Query total SM resource. + sm_res = CUdevResource() + _check( + libcuda.cuDeviceGetDevResource( + c_int(device_id), byref(sm_res), c_int(CU_DEV_RESOURCE_TYPE_SM) + ), + "cuDeviceGetDevResource(SM)", + ) + total_sms = sm_res.data.sm.smCount + + # 2. Split SMs: carve Router partition first; FC1 gets the remainder. + router_sm_result = (CUdevResource * 1)() + fc1_sm_res = CUdevResource() + nbGroups = c_uint(1) + _check( + libcuda.cuDevSmResourceSplitByCount( + router_sm_result, + byref(nbGroups), + byref(sm_res), + byref(fc1_sm_res), + c_uint(0), + c_uint(router_sms), + ), + "cuDevSmResourceSplitByCount(router)", + ) + if nbGroups.value == 0: + raise RuntimeError( + f"cuDevSmResourceSplitByCount returned 0 groups " + f"(router_sms={router_sms}, total_sms={total_sms})" + ) + + # 3. Build resource descriptors with SM only (count=1, no WQ config). + fc1_desc = c_void_p() + router_desc = c_void_p() + _check( + libcuda.cuDevResourceGenerateDesc(byref(fc1_desc), byref(fc1_sm_res), c_uint(1)), + "cuDevResourceGenerateDesc(fc1)", + ) + _check( + libcuda.cuDevResourceGenerateDesc(byref(router_desc), router_sm_result, c_uint(1)), + "cuDevResourceGenerateDesc(router)", + ) + + # 4. Create GreenContexts. + fc1_gc_h = c_void_p() + router_gc_h = c_void_p() + _check( + libcuda.cuGreenCtxCreate( + byref(fc1_gc_h), + fc1_desc, + c_int(device_id), + c_uint(CU_GREEN_CTX_DEFAULT_STREAM), + ), + "cuGreenCtxCreate(fc1)", + ) + _check( + libcuda.cuGreenCtxCreate( + byref(router_gc_h), + router_desc, + c_int(device_id), + c_uint(CU_GREEN_CTX_DEFAULT_STREAM), + ), + "cuGreenCtxCreate(router)", + ) + + # 5. Create NON_BLOCKING streams bound directly to each CUgreenCtx. + fc1_stream_ptr = c_void_p() + router_stream_ptr = c_void_p() + _check( + libcuda.cuGreenCtxStreamCreate( + byref(fc1_stream_ptr), + fc1_gc_h, + c_uint(CU_STREAM_NON_BLOCKING), + c_int(0), + ), + "cuGreenCtxStreamCreate(fc1)", + ) + _check( + libcuda.cuGreenCtxStreamCreate( + byref(router_stream_ptr), + router_gc_h, + c_uint(CU_STREAM_NON_BLOCKING), + c_int(0), + ), + "cuGreenCtxStreamCreate(router)", + ) + + fc1_stream = torch.cuda.ExternalStream(fc1_stream_ptr.value, device=f"cuda:{device_id}") + router_stream = torch.cuda.ExternalStream(router_stream_ptr.value, device=f"cuda:{device_id}") + + _fc1_gc_val = fc1_gc_h.value + _router_gc_val = router_gc_h.value + _fc1_s_val = fc1_stream_ptr.value + _router_s_val = router_stream_ptr.value + + def _cleanup() -> None: + _lib = _get_libcuda() + _lib.cuStreamDestroy(c_void_p(_fc1_s_val)) + _lib.cuStreamDestroy(c_void_p(_router_s_val)) + _lib.cuGreenCtxDestroy(c_void_p(_fc1_gc_val)) + _lib.cuGreenCtxDestroy(c_void_p(_router_gc_val)) + + return fc1_stream, router_stream, _cleanup diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1efb539fa01..48238ce86fe 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -470,7 +470,27 @@ def create_weights(self, module: Linear, in_features: int, def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]): - if module.use_custom_cublas_mm: + # CuTe DSL BF16 GEMM path for Blackwell + if (module.use_cute_dsl_bf16_gemm and is_sm_100f() + and module.weight.dtype == torch.bfloat16): + # input: [*, K], weight: [N, K], output: [*, N] + input_2d = input.view(-1, input.shape[-1]) # [M, K] + m, k = input_2d.shape + n = module.weight.shape[0] + output = torch.empty(m, + n, + dtype=torch.bfloat16, + device=input.device) + torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell( + input_2d.contiguous(), + module.weight, + output, + ) + # Reshape output back to match input batch dims + output = output.view(*input.shape[:-1], n) + if bias is not None: + output = output + bias + elif module.use_custom_cublas_mm: output = torch.ops.trtllm.cublas_mm(input, module.weight.t(), bias, @@ -2467,6 +2487,7 @@ def __init__( reduce_output: bool = True, # ROW parallel only skip_create_weights_in_init: bool = False, use_custom_cublas_mm: bool = False, + use_cute_dsl_bf16_gemm: bool = False, lora: Optional[LoraLayer] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, @@ -2542,6 +2563,7 @@ def __init__( self._weights_created = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm + self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm self.lora = lora mpi_enabled = not mpi_disabled() diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index b3c57a29f75..336ca06c4c4 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -486,6 +486,8 @@ def _load_and_validate_config( use_cute_dsl_blockscaling_bmm=self.llm_args. use_cute_dsl_blockscaling_bmm, video_pruning_rate=self.llm_args.video_pruning_rate, + use_cute_dsl_bf16_bmm=self.llm_args.use_cute_dsl_bf16_bmm, + use_cute_dsl_bf16_gemm=self.llm_args.use_cute_dsl_bf16_gemm, ) # Only pass model_kwargs if it's explicitly set (not None) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index e9df1bf2f25..aa013db9bb4 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -298,7 +298,10 @@ def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]: def deep_gemm_gen_tuning_buckets(x: int): - buckets = tuple(range(8, 128, 8)) + # Include 1 as the first bucket so that small token counts (1, 2, 4, ...) + # all resolve to the same [min=1, opt=1, max=8] profile and avoid separate + # per-num_tokens inner autotuner sweeps for each outer token bucket. + buckets = (1, 2, 4) + tuple(range(8, 128, 8)) # Clamp x to be between 4096 and 8192. if x >= 128: x = min(x, 8192) @@ -307,6 +310,51 @@ def deep_gemm_gen_tuning_buckets(x: int): return buckets +def deep_gemm_tuning_buckets(x: int): + # Include 1 as the first bucket so that small token counts (1, 2, 4, ...) + # all resolve to the same [min=1, opt=1, max=8] profile and avoid separate + # per-num_tokens inner autotuner sweeps for each outer token bucket. + # Sub-128 buckets: 1, 2, 4, 8, 16, 32, 64, 96 (step-32 from 32 onward) + buckets = (1, 2, 4, 8, 16) + tuple(range(32, 128, 32)) + # Clamp x to be between 4096 and 8192. + if x >= 128: + x = min(x, 8192) + x = max(x, 4096) + buckets += tuple(range(128, x, 128)) + return buckets + + +def prev_deep_gemm_bucket(x: int) -> int: + """Return the largest deep_gemm_tuning_buckets value <= x (floor lookup). + + Mirrors the bucket layout of deep_gemm_tuning_buckets: + [1, 2, 4, 8, 16, 32, 64, 96] step-32 range + [128, 256, 384, ...] step-128 range + + Used as ``map_to_tuning_buckets`` so that at runtime an actual token count + is mapped to the largest cached bucket that is <= x. + """ + if x < 2: + return 1 + if x < 4: + return 2 + if x < 8: + return 4 + if x < 16: + return 8 + if x < 32: + return 16 + if x < 64: + return 32 + if x < 96: + return 64 + if x < 128: + # x in [96, 128): nearest bucket below is 96 + return 96 + # x >= 128: floor to nearest multiple of 128 + return ((x + 127) // 128) * 128 + + def fp4_scale_infer_shape(input_shapes: List[List[int]]): """Calculate the dimensions of the fp4 scale tensor. """ diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 9f001b4e5ae..ef2f436fef1 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -90,7 +90,6 @@ def Field(default: Any = ..., Returns: A Pydantic FieldInfo object with the status added to json_schema_extra if provided """ - if status is not None: json_schema_extra = kwargs.get('json_schema_extra', {}) if isinstance(json_schema_extra, dict): @@ -104,8 +103,7 @@ def Field(default: Any = ..., class CudaGraphConfig(StrictBaseModel): - """ - Configuration for CUDA graphs. + """Configuration for CUDA graphs. """ # List of batch sizes to create CUDA graphs for. batch_sizes: Optional[List[int]] = Field( @@ -247,8 +245,7 @@ class GuidedDecodingBackend(Enum): class BaseSparseAttentionConfig(StrictBaseModel): - """ - Configuration for sparse attention. + """Configuration for sparse attention. """ algorithm: str @@ -259,8 +256,7 @@ class BaseSparseAttentionConfig(StrictBaseModel): ) def supports_backend(self, backend: str) -> bool: - """ - Override if the sparse attention algorithm does not support + """Override if the sparse attention algorithm does not support a subset of the possible backends. """ return True @@ -269,8 +265,7 @@ def get_indices_block_size(self) -> int: return 1 def needs_separate_short_long_cuda_graphs(self) -> bool: - """ - Determines whether to capture a dedicated CUDA graph for batches consisting entirely of short sequences. + """Determines whether to capture a dedicated CUDA graph for batches consisting entirely of short sequences. If True, capture distinct graphs for short-only batches and general cases (e.g., long or mixed batches). If False, capture a single unified CUDA graph for all sequences regardless of length. The seq_len_threshold parameter defines the cutoff boundary between short and long sequences. @@ -279,8 +274,7 @@ def needs_separate_short_long_cuda_graphs(self) -> bool: class RocketSparseAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for RocketKV sparse attention. + """Configuration for RocketKV sparse attention. """ algorithm: Literal["rocket"] = "rocket" window_size: Optional[int] = Field( @@ -306,8 +300,7 @@ def get_indices_block_size(self) -> int: class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for DeepSeek Sparse Attention. + """Configuration for DeepSeek Sparse Attention. """ algorithm: Literal["dsa"] = "dsa" index_n_heads: Optional[int] = Field( @@ -347,8 +340,7 @@ def supports_backend(self, backend: str) -> bool: return backend == "pytorch" def needs_separate_short_long_cuda_graphs(self) -> bool: - """ - Whether to capture separate CUDA graphs for short and long sequences. + """Whether to capture separate CUDA graphs for short and long sequences. Use seq_len_threshold to determine the threshold for separating short and long sequences. """ self.seq_len_threshold = self.index_topk @@ -356,8 +348,7 @@ def needs_separate_short_long_cuda_graphs(self) -> bool: class SkipSoftmaxAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for skip softmax attention. + """Configuration for skip softmax attention. """ algorithm: Literal["skip_softmax"] = "skip_softmax" threshold_scale_factor: Optional[Union[float, Dict[str, float]]] = Field( @@ -436,8 +427,7 @@ def _compute(phase: str, sparsity: Optional[float]) -> Optional[float]: class MoeLoadBalancerConfig(StrictBaseModel): - """ - Pydantic configuration model for the Mixture of Experts (MoE) load balancer. + """Pydantic configuration model for the Mixture of Experts (MoE) load balancer. This model holds configuration data (`num_slots`, etc.) as well as runtime state (`_ep_rank`, `_ep_size`) which must be set via the @@ -456,8 +446,7 @@ class MoeLoadBalancerConfig(StrictBaseModel): # --- Methods --- def setup(self, ep_rank: int, ep_size: int) -> None: - """ - Initializes the runtime state of the configuration. + """Initializes the runtime state of the configuration. This must be called before accessing properties like `num_local_slots`. """ self._ep_rank = ep_rank @@ -513,8 +502,7 @@ def slot_end(self) -> int: def get_layer_initial_global_assignments( self, layer_idx: int) -> Optional[List[int]]: - """ - Retrieves the initial global assignments for a specific layer. + """Retrieves the initial global assignments for a specific layer. """ if self.initial_global_assignments is None: return None @@ -539,8 +527,7 @@ def get_layer_initial_global_assignments( class MoeConfig(StrictBaseModel): - """ - Configuration for MoE. + """Configuration for MoE. """ backend: Literal[ "AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", @@ -578,8 +565,7 @@ class MoeConfig(StrictBaseModel): class Nvfp4GemmConfig(StrictBaseModel): - """ - Configuration for NVFP4 GEMM backend selection. + """Configuration for NVFP4 GEMM backend selection. """ allowed_backends: List[Nvfp4Backend] = Field( default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'], @@ -591,8 +577,7 @@ class Nvfp4GemmConfig(StrictBaseModel): class AttentionDpConfig(StrictBaseModel): - """ - Configuration for attention DP. + """Configuration for attention DP. """ enable_balance: bool = Field(default=False, description="Whether to enable balance.") @@ -629,8 +614,7 @@ def validate_attention_dp_config(self) -> 'AttentionDpConfig': class CpConfig(StrictBaseModel): - """ - Configuration for context parallelism. + """Configuration for context parallelism. """ # TODO: given that multiple fields here are only used with specific cp_types, consider # making this a Pydantic discriminated union. @@ -737,8 +721,7 @@ def to_mapping(self) -> Mapping: class CalibConfig(StrictBaseModel): - """ - Calibration configuration. + """Calibration configuration. """ device: Literal['cuda', 'cpu'] = Field(default='cuda', @@ -916,8 +899,7 @@ def validate_max_concurrency_and_draft_len_schedule_mutually_exclusive( return self def supports_backend(self, backend: str) -> bool: - """ - Override if the speculation algorithm does not support + """Override if the speculation algorithm does not support a subset of the possible backends. """ return True @@ -946,8 +928,7 @@ def num_capture_layers(self) -> int: class KvCacheConnectorConfig(StrictBaseModel): - """ - Configuration for the KV Cache Connector. + """Configuration for the KV Cache Connector. """ connector_module: str = Field( ..., @@ -961,8 +942,7 @@ class KvCacheConnectorConfig(StrictBaseModel): class LayerwiseBenchmarksConfig(StrictBaseModel): - """ - Configuration for layer-wise benchmarks calibration. + """Configuration for layer-wise benchmarks calibration. """ calibration_mode: Literal["NONE", "MARK", "COLLECT"] = Field( default="NONE", @@ -1170,8 +1150,7 @@ def spec_dec_mode(self): @functools.cached_property def num_capture_layers(self) -> int: - """ - Returns the number of layers to capture of the target model. + """Returns the number of layers to capture of the target model. If eagle3_layers_to_capture is not None, return the length of the set. Otherwise, assume Eagle3 base set and return 3. """ @@ -1276,8 +1255,7 @@ def spec_dec_mode(self): @functools.cached_property def num_capture_layers(self): - """ - Returns the number of layers to save. + """Returns the number of layers to save. The following hidden states are saved: - If eagle3_layers_to_capture is None, save the eagle3 base set plus the post norm last hidden state. @@ -1318,8 +1296,7 @@ def set_max_total_draft_tokens(self): class NGramDecodingConfig(DecodingBaseConfig): - """ - Configuration for NGram drafter speculative decoding. + """Configuration for NGram drafter speculative decoding. """ decoding_type: Literal["NGram"] = "NGram" max_matching_ngram_size: PositiveInt = Field( @@ -1353,7 +1330,7 @@ def supports_backend(self, backend: str) -> bool: class SADecodingConfig(DecodingBaseConfig): - """Configuration for standalone Suffix Automaton (SA) speculative decoding. + """Configuration for standalone Suffix Automaton (SA) speculative decoding (one-model design). Uses a GPU-native suffix automaton for pattern matching. Drafting runs inside the target model forward; supports CUDA graph and overlap scheduler. @@ -1582,8 +1559,7 @@ def spec_dec_mode(self): class AutoDecodingConfig(DecodingBaseConfig): - """ - Configuration for auto speculative decoding. + """Configuration for auto speculative decoding. This config will automatically select a good, draft-model free speculation algorithm with some heuristic. @@ -1603,8 +1579,7 @@ def supports_backend(self, backend: str) -> bool: class RayPlacementConfig(StrictBaseModel): - """ - Configuration for Ray GPU workers placement. + """Configuration for Ray GPU workers placement. Currently, this config is only used with AsyncLLM for RL scenarios. """ defer_workers_init: bool = Field( @@ -1670,8 +1645,8 @@ def validate_ray_placement(self) -> 'RayPlacementConfig': class ExecutorMemoryType(StrEnum): """Types of GPU memory used by executor. - These are used by the sleep/wakeup feature to target specific type of memory. - """ + These are used by the sleep/wakeup feature to target specific type of memory. + """ SAMPLER = "sampler" DRAFTER = "drafter" GUIDED_DECODER = "guided_decoder" @@ -1781,9 +1756,9 @@ def _validate_restore_modes(cls, v): class PybindMirror(ABC): - ''' A class containing the utilities for mirroring Python classes to + """A class containing the utilities for mirroring Python classes to pybind classes. - ''' + """ @abstractmethod def _to_pybind(self): @@ -1799,8 +1774,7 @@ def maybe_to_pybind(ins): @staticmethod def mirror_pybind_fields(pybind_class): - """ - Class decorator that ensures Python class fields mirror those of a C++ class. + """Class decorator that ensures Python class fields mirror those of a C++ class. Args: pybind_class: The C++ class whose fields should be mirrored @@ -1829,7 +1803,7 @@ def decorator(cls): @staticmethod def get_pybind_enum_fields(pybind_class): - ''' Get all the enum fields from the pybind class. ''' + """Get all the enum fields from the pybind class.""" return [ f for f in pybind_class.__members__.keys() if not f.startswith('_') and not callable(getattr(pybind_class, f)) @@ -1837,7 +1811,7 @@ def get_pybind_enum_fields(pybind_class): @staticmethod def mirror_pybind_enum(pybind_class): - ''' Mirror the enum fields from the pybind class to the Python class. ''' + """Mirror the enum fields from the pybind class to the Python class.""" def decorator(cls): assert issubclass(cls, Enum) @@ -1855,7 +1829,7 @@ def decorator(cls): @staticmethod def get_pybind_variable_fields(config_cls): - ''' Get all the variable fields from the pybind class. ''' + """Get all the variable fields from the pybind class.""" return [ f for f in dir(config_cls) if not f.startswith('_') and not callable(getattr(config_cls, f)) @@ -1863,7 +1837,7 @@ def get_pybind_variable_fields(config_cls): @staticmethod def pybind_equals(obj0, obj1): - ''' Check if two pybind objects are equal. ''' + """Check if two pybind objects are equal.""" assert type(obj0) is type(obj1) for field in PybindMirror.get_pybind_variable_fields(type(obj0)): if getattr(obj0, field) != getattr(obj1, field): @@ -1932,8 +1906,7 @@ class PybindMirrorMeta(type(PybindMirror)): class PybindMirrorEnumMeta(EnumMeta, PybindMirrorMeta): - """ - Combined metaclass for Enum and PybindMirror. This is crucial. + """Combined metaclass for Enum and PybindMirror. This is crucial. """ @@ -1958,7 +1931,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_enum(_ContextChunkingPolicy) class ContextChunkingPolicy(StrEnum, metaclass=PybindMirrorEnumMeta): - ''' Context chunking policy. ''' + """Context chunking policy.""" FIRST_COME_FIRST_SERVED = "FIRST_COME_FIRST_SERVED" EQUAL_PROGRESS = "EQUAL_PROGRESS" FORCE_CHUNK = "FORCE_CHUNK" @@ -2038,8 +2011,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_PeftCacheConfig) class PeftCacheConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the PEFT cache. + """Configuration for the PEFT cache. """ num_host_module_layer: int = Field( default=0, @@ -2108,8 +2080,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig) class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror): - """ - Configuration for lookahead speculative decoding. + """Configuration for lookahead speculative decoding. """ decoding_type: Literal["Lookahead"] = "Lookahead" @@ -2174,8 +2145,7 @@ def supports_backend(self, backend: str) -> bool: @PybindMirror.mirror_pybind_fields(_KvCacheConfig) class KvCacheConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the KV cache. + """Configuration for the KV cache. """ enable_block_reuse: bool = Field( default=True, @@ -2376,8 +2346,7 @@ def validate_max_util_for_resume(cls, v: float): @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror): - """ - Configuration for extended runtime performance knobs. + """Configuration for extended runtime performance knobs. """ multi_block_mode: bool = Field( @@ -2407,8 +2376,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig) class CacheTransceiverConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the cache transceiver. + """Configuration for the cache transceiver. """ backend: Optional[Literal[ @@ -2520,8 +2488,7 @@ class DwdpConfig(StrictBaseModel): class BaseLlmArgs(StrictBaseModel): - """ - Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. + """Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. """ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") @@ -3037,8 +3004,7 @@ class TrtLlmArgs(BaseLlmArgs): @model_validator(mode="after") def init_build_config(self): - """ - Creating a default BuildConfig if none is provided + """Creating a default BuildConfig if none is provided """ build_config = getattr(self, "build_config", None) if build_config is None: @@ -3204,12 +3170,11 @@ def _load_config_from_ckpt(self, ckpt_dir: Path): @model_validator(mode="after") def validate_model_format_misc(self): - ''' - Load the model format, and do the following: + """Load the model format, and do the following: 1. Load the build_config if got an engine. 2. Load the parallel_config if got a checkpoint. - ''' + """ model_obj = _ModelWrapper(self.model) if model_obj.is_local_model and self.backend not in [ @@ -3316,8 +3281,7 @@ class SamplerType(StrEnum): class TorchCompileConfig(StrictBaseModel): - """ - Configuration for torch.compile. + """Configuration for torch.compile. """ enable_fullgraph: bool = Field( default=True, @@ -3587,6 +3551,18 @@ class TorchLlmArgs(BaseLlmArgs): description="If true, use CuTe DSL fp8 blockscaling bmm implementation.", status="prototype", ) + use_cute_dsl_bf16_bmm: bool = Field( + default=False, + description= + "If true, use CuTe DSL bf16 persistent GEMM for BMM on Blackwell.", + status="prototype", + ) + use_cute_dsl_bf16_gemm: bool = Field( + default=False, + description= + "If true, use CuTe DSL bf16 persistent GEMM for Linear layers on Blackwell.", + status="prototype", + ) # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) @@ -3836,7 +3812,7 @@ def warn_on_unstable_feature_usage(self) -> 'TorchLlmArgs': def validate_ray_worker_extension_cls(self) -> 'TorchLlmArgs': if self.ray_worker_extension_cls is not None and self.orchestrator_type != "ray": raise ValueError( - f"ray_worker_extension_cls is only supported with orchestrator_type='ray'" + "ray_worker_extension_cls is only supported with orchestrator_type='ray'" ) return self @@ -3929,7 +3905,7 @@ def update_llm_args_with_extra_options(llm_args: Dict, def get_model_format(model_dir: str, trust_remote_code: bool = False) -> _ModelFormatKind: - ''' Get the format of the model. ''' + """Get the format of the model.""" if not (Path(model_dir) / 'config.json').exists(): raise ValueError( f"Failed to infer model format because no config.json exists in {model_dir}" diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2c88b1e11f8..03ace30a00e 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -25,8 +25,7 @@ def patch_mpi_pool_session_for_env(mocker, env_vars: dict): - """ - Patch MpiPoolSession._start_mpi_pool to propagate environment variables to MPI child processes. + """Patch MpiPoolSession._start_mpi_pool to propagate environment variables to MPI child processes. Uses MPIPoolExecutor's built-in `env` parameter instead of `initializer` to avoid segfault issues during process cleanup (UCX memory cache conflicts with PyTorch @@ -152,7 +151,7 @@ def test_nvfp4_with_norm_quant(self, monkeypatch): sm_version = get_sm_version() if sm_version not in (100, 103): pytest.skip( - f"test_nvfp4_with_norm_quant supports SM 100 and 103 only") + "test_nvfp4_with_norm_quant supports SM 100 and 103 only") monkeypatch.setenv("TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "0") assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 task = CnnDailymail(self.MODEL_NAME) @@ -1962,6 +1961,42 @@ def test_cute_dsl_fp8_block_scales( task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + @parametrize_with_ids("cuda_graph", [False, True]) + def test_cute_dsl_bf16_bmm(self, cuda_graph): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + use_cute_dsl_bf16_bmm=True, + ) + + with LLM( + self.MODEL_PATH, + kv_cache_config=kv_cache_config, + **pytorch_config, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + @parametrize_with_ids("cuda_graph", [False, True]) + def test_cute_dsl_bf16_gemm(self, cuda_graph): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + use_cute_dsl_bf16_gemm=True, + ) + + with LLM( + self.MODEL_PATH, + kv_cache_config=kv_cache_config, + **pytorch_config, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_hopper @parametrize_with_ids("mtp_nextn", [0, 2]) def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn): @@ -2128,6 +2163,56 @@ def test_cute_dsl_fp8_block_scales_4gpus( task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @pytest.mark.skip_less_device(4) + @skip_pre_blackwell + @parametrize_with_ids("cuda_graph", [False, True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 1), (4, 1, 4)], + ids=["tp4", "ep4"]) + def test_cute_dsl_bf16_bmm_4gpus(self, tp_size, pp_size, ep_size, + cuda_graph): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + use_cute_dsl_bf16_bmm=True, + ) + + with LLM( + self.MODEL_PATH, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=kv_cache_config, + **pytorch_config, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device(4) + @skip_pre_blackwell + @parametrize_with_ids("cuda_graph", [False, True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 1), (4, 1, 4)], + ids=["tp4", "ep4"]) + def test_cute_dsl_bf16_gemm_4gpus(self, tp_size, pp_size, ep_size, + cuda_graph): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + use_cute_dsl_bf16_gemm=True, + ) + + with LLM( + self.MODEL_PATH, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=kv_cache_config, + **pytorch_config, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device(4) @skip_pre_hopper def test_fp8_block_scales_4gpus_static_eplb(self): @@ -2920,8 +3005,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size, @skip_pre_blackwell @pytest.mark.skip_less_device(8) def test_nvfp4_multi_gpus_corner_case(self): - """ - This test is used to test the corner case of the NVFP4 model. + """This test is used to test the corner case of the NVFP4 model. When using the same value for max_seq_len and max_num_tokens, there will be no enough kv block for the dummy requests in CUDA graph warmup when creating the py_executor before estimating kv cache. Then CUDA graph capture will be @@ -3770,8 +3854,7 @@ def test_nvfp4(self, tp_size): "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning" ) def test_nvfp4_longseq_trtllm_moe_stress(self, mocker): - """ - Long-sequence MoE stress test with PDL enabled. + """Long-sequence MoE stress test with PDL enabled. RCCA: https://nvbugspro.nvidia.com/bug/5661741 """ patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"}) @@ -3852,8 +3935,7 @@ def test_nvfp4_longseq_trtllm_moe_stress(self, mocker): "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning" ) def test_nvfp4_longseq_trtllm_moe_async_cancel(self, mocker): - """ - Long-sequence MoE async streaming test with cancellation. + """Long-sequence MoE async streaming test with cancellation. RCCA: https://nvbugspro.nvidia.com/bug/5661741 """ patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"}) @@ -4305,7 +4387,7 @@ class TestQwen3_4B(LlmapiAccuracyTestHarness): MODEL_NAME = "Qwen3/Qwen3-4B" def test_eagle3(self): - "RCCA: https://nvbugspro.nvidia.com/bug/5698434" + """RCCA: https://nvbugspro.nvidia.com/bug/5698434""" pytorch_config = dict( disable_overlap_scheduler=True, cuda_graph_config=CudaGraphConfig(), @@ -4522,7 +4604,7 @@ def test_dummy_load_format(self): ids=["latency"]) def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, torch_compile): - "RCCA: https://nvbugspro.nvidia.com/bug/5284463" + """RCCA: https://nvbugspro.nvidia.com/bug/5284463""" "Need to check Ada support" torch_compile_config = _get_default_torch_compile_config(torch_compile) @@ -5001,7 +5083,7 @@ class TestKanana_Instruct(LlmapiAccuracyTestHarness): @pytest.mark.skip_device_not_contain(["H20", "H100"]) def test_auto_dtype(self): - "RCCA: https://nvbugspro.nvidia.com/bug/5310520" + """RCCA: https://nvbugspro.nvidia.com/bug/5310520""" pytorch_config = dict(cuda_graph_config=CudaGraphConfig( enable_padding=True, max_batch_size=384)) with LLM(self.MODEL_PATH, **pytorch_config, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 229a7f1e180..be4f3812dd6 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -557,6 +557,20 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tra disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] +# CuTe DSL BF16 BMM/GEMM tests +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm[cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm[cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm_4gpus[tp4-cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm_4gpus[tp4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm_4gpus[ep4-cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_bmm_4gpus[ep4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm[cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm[cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm_4gpus[tp4-cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm_4gpus[tp4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm_4gpus[ep4-cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_cute_dsl_bf16_gemm_4gpus[ep4-cuda_graph=True] + # llm-api promote pytorch to default llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_tensorrt diff --git a/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py b/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py index 81d52d24b22..7d3fcbf1eff 100644 --- a/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py +++ b/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py @@ -102,6 +102,7 @@ def run( skip_ref_check: bool = False, use_cold_l2: bool = False, use_cupti: bool = True, + split_k: int = 1, **kwargs, ): """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture. @@ -164,6 +165,7 @@ def run( print(f"Skip reference checking: {skip_ref_check}") print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") print(f"Use CUPTI: {'True' if use_cupti else 'False'}") + print(f"Split-K: {split_k}") # Skip unsupported testcase if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( @@ -350,8 +352,13 @@ def create_alpha_scale_tensor(l, m, n, expert_count, dtype): # noqa: E741 weight_per_expert, use_prefetch, prefetch_dist, + split_k, ) + # For split-K > 1: zero-initialize C (kernel uses atomic add for reduction) + if split_k > 1: + c_torch.zero_() + # Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( @@ -438,9 +445,12 @@ def generate_tensors(): b_tensor, _ = cutlass_torch.cute_tensor_like( b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 ) - c_tensor, _ = cutlass_torch.cute_tensor_like( + c_tensor, c_torch_gen = cutlass_torch.cute_tensor_like( c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 ) + # Zero-init C for split-K (atomic adds accumulate onto initial values) + if split_k > 1: + c_torch_gen.zero_() # Mark tensor to be byte aligned a_tensor.mark_compact_shape_dynamic( @@ -575,6 +585,12 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: default=True, help="Use CUPTI for profiling (default: True)", ) + parser.add_argument( + "--split_k", + type=int, + default=1, + help="Split-K factor (default: 1)", + ) args = parser.parse_args() if len(args.mnkl) != 4: @@ -606,6 +622,7 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: args.skip_ref_check, args.use_cold_l2, args.use_cupti, + args.split_k, ) print(f"Execution time: {exec_time:.2f} us") print("PASS") diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index f436ead7c82..7708f4099f2 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -255,6 +255,14 @@ methods: annotation: bool default: False status: prototype + use_cute_dsl_bf16_bmm: + annotation: bool + default: False + status: prototype + use_cute_dsl_bf16_gemm: + annotation: bool + default: False + status: prototype return_annotation: None generate: parameters: