From da642d0d2633f9092af19d7d2e5dba438e96d788 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 21 Apr 2026 11:21:13 +0800 Subject: [PATCH] Update: migrate scope3 sort/gather to tensor-level API and reduce MAX_SEQ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace pl.tile.sort32/mrgsort/gather + explicit pl.load/pl.store with pl.tensor.sort32/mrgsort/gather + pl.slice/pl.assemble, adapting to the new tensor-level ops merged in pypto (#1097) - Reduce MAX_SEQ from 8192 to 4096; introduce SORT_LEN=8192 to keep the sort buffer at full width — scores tensor is [BATCH, SORT_LEN] and Stage 0 fills the entire row with -inf so the [MAX_SEQ, SORT_LEN) tail is always -inf without an extra fillpad in the sort kernel - idx_init length widened from MAX_SEQ to SORT_LEN to match the full sort buffer; its dtype stays pl.UINT32 (as tensor.sort32 requires) while the TensorSpec keeps torch.int32 (same bit layout, and the runtime init path stays simpler) --- .../deepseek_v3_2_decode_front_scope3.py | 75 +++++++++---------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py index edb80d1..dae83df 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py @@ -33,7 +33,7 @@ BATCH = 16 -MAX_SEQ = 8192 +MAX_SEQ = 4096 INDEX_HEAD_DIM = 128 INDEX_TOPK = 2048 CACHE_ROWS_IDX = BATCH * MAX_SEQ @@ -45,7 +45,10 @@ Q_VALID = 1 Q_PAD = 16 -# sort32 + 4 mrgsort iterations (block_len 64,256,1024,4096) sort MAX_SEQ=8192. +# sort32 + 4 mrgsort iterations (block_len 64,256,1024,4096) sort SORT_LEN=8192. +# SORT_LEN > MAX_SEQ so the full sort buffer is pre-filled with -inf (Stage 0) +# and only [0, ctx_len) contains real scores; tail stays -inf. +SORT_LEN = 8192 MRGSORT_ITERS = 4 # -inf sentinel for score tail. FP32 lowest, since ptoas rejects literal -inf. 
@@ -61,7 +64,7 @@ def deepseek_v3_2_decode_front_scope3( q_idx: pl.Tensor[[BATCH, INDEX_HEAD_DIM], pl.BF16], k_cache_idx: pl.Tensor[[CACHE_ROWS_IDX, INDEX_HEAD_DIM], pl.BF16], seq_lens: pl.Tensor[[BATCH], pl.INT32], - idx_init: pl.Tensor[[1, MAX_SEQ], pl.UINT32], + idx_init: pl.Tensor[[1, SORT_LEN], pl.UINT32], topk_vals_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.FP32], topk_idx_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.INT32], ) -> tuple[ @@ -89,19 +92,20 @@ def deepseek_v3_2_decode_front_scope3( ) # Transient GM buffers. - scores = pl.create_tensor([BATCH, MAX_SEQ], dtype=pl.FP32) - sorted_gm = pl.create_tensor([BATCH, 2 * MAX_SEQ], dtype=pl.FP32) + # scores is [BATCH, SORT_LEN]: Stage 0 fills the full row with -inf + # so the [MAX_SEQ, SORT_LEN) tail is always -inf for the sort. + scores = pl.create_tensor([BATCH, SORT_LEN], dtype=pl.FP32) + sorted_gm = pl.create_tensor([BATCH, 2 * SORT_LEN], dtype=pl.FP32) raw_idx_gm = pl.create_tensor([BATCH, INDEX_TOPK], dtype=pl.INT32) for b in pl.range(0, BATCH, 1): ctx_len = pl.tensor.read(seq_lens, [b]) ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE - # Stage 0: pre-fill scores[b] with -inf. Stage 2's parallel - # loop only covers [0, ctx_blocks), so untouched tail slots - # keep the create_tensor default of 0.0 without this. + # Stage 0: pre-fill scores[b, 0:SORT_LEN] with -inf so both the + # untouched ctx tail and the [MAX_SEQ, SORT_LEN) pad are -inf. with pl.at(level=pl.Level.CORE_GROUP): - neg_inf_row = pl.full([1, MAX_SEQ], dtype=pl.FP32, value=FP32_NEG_INF) + neg_inf_row = pl.full([1, SORT_LEN], dtype=pl.FP32, value=FP32_NEG_INF) scores = pl.assemble(scores, neg_inf_row, [b, 0]) # Stage 1: tiled QK matmul into all_scores[sb*Q_PAD, 0] (row 0 valid). @@ -133,43 +137,32 @@ def deepseek_v3_2_decode_front_scope3( tile_padded = pl.fillpad(tile_valid, pad_value=pl.PadValue.min) scores = pl.assemble(scores, tile_padded, [b, s0]) - # Stage 3: sort32 + 4 mrgsort iterations. Result is [1, 2*MAX_SEQ] - # interleaved (val, idx). 
Stored to GM because vec→vec tile slicing - # is not supported on a2a3, so downstream gather re-loads from GM. - # Distinct SSA names for each mrgsort so the final tile is the one - # that gets stored. + # Stage 3: sort32 + 4 mrgsort iterations (tensor-level). Operates + # directly on GM slices; result is [1, 2*SORT_LEN] interleaved + # (val, idx). Stored to sorted_gm for gather in Stage 4. with pl.at(level=pl.Level.CORE_GROUP): - score_row = pl.load(scores, offsets=[b, 0], shapes=[1, MAX_SEQ]) - idx_row = pl.load(idx_init, offsets=[0, 0], shapes=[1, MAX_SEQ]) - sorted0 = pl.tile.sort32(score_row, idx_row) - sorted1 = pl.tile.mrgsort(sorted0, block_len=64) - sorted2 = pl.tile.mrgsort(sorted1, block_len=256) - sorted3 = pl.tile.mrgsort(sorted2, block_len=1024) - sorted4 = pl.tile.mrgsort(sorted3, block_len=4096) - sorted_gm = pl.store( - sorted4, offsets=[b, 0], output_tensor=sorted_gm - ) - - # Stage 4: GM-load the first INDEX_TOPK pairs (2*INDEX_TOPK cols), - # gather P0101/P1010 to split vals / idx bits, store to outputs. + score_row = pl.slice(scores, [1, SORT_LEN], [b, 0]) + sorted_t = pl.tensor.sort32(score_row, idx_init) + sorted_t = pl.tensor.mrgsort(sorted_t, block_len=64) + sorted_t = pl.tensor.mrgsort(sorted_t, block_len=256) + sorted_t = pl.tensor.mrgsort(sorted_t, block_len=1024) + sorted_t = pl.tensor.mrgsort(sorted_t, block_len=4096) + sorted_gm = pl.assemble(sorted_gm, sorted_t, [b, 0]) + + # Stage 4: gather P0101/P1010 to split vals / idx bits from the + # first INDEX_TOPK pairs (2*INDEX_TOPK cols) in sorted_gm. 
with pl.at(level=pl.Level.CORE_GROUP): - topk_pairs = pl.load( - sorted_gm, offsets=[b, 0], shapes=[1, 2 * INDEX_TOPK] - ) - topk_v = pl.tile.gather( + topk_pairs = pl.slice(sorted_gm, [1, 2 * INDEX_TOPK], [b, 0]) + topk_v = pl.tensor.gather( topk_pairs, mask_pattern=pl.tile.MaskPattern.P0101 ) - topk_i_raw = pl.tile.gather( + topk_i_raw = pl.tensor.gather( topk_pairs, mask_pattern=pl.tile.MaskPattern.P1010, output_dtype=pl.INT32, ) - topk_vals_out = pl.store( - topk_v, offsets=[b, 0], output_tensor=topk_vals_out - ) - raw_idx_gm = pl.store( - topk_i_raw, offsets=[b, 0], output_tensor=raw_idx_gm - ) + topk_vals_out = pl.assemble(topk_vals_out, topk_v, [b, 0]) + raw_idx_gm = pl.assemble(raw_idx_gm, topk_i_raw, [b, 0]) # Stage 5: GM reload + valid_shape fillpad to mark idx slots past # ctx_len with PadValue.min (= INT32_MIN < 0). @@ -198,7 +191,7 @@ def golden_decode_front_scope3(tensors): topk_vals_out = tensors["topk_vals_out"] topk_idx_out = tensors["topk_idx_out"] - scores = torch.full((BATCH, MAX_SEQ), FP32_NEG_INF, dtype=torch.float32) + scores = torch.full((BATCH, SORT_LEN), FP32_NEG_INF, dtype=torch.float32) for b in range(BATCH): ctx_len = int(seq_lens[b].item()) q_b = q_idx[b : b + 1] @@ -233,7 +226,7 @@ def init_k_cache_idx(): return torch.rand(CACHE_ROWS_IDX, INDEX_HEAD_DIM) - 0.5 def init_idx_init(): - return torch.arange(MAX_SEQ, dtype=torch.int32).unsqueeze(0) + return torch.arange(SORT_LEN, dtype=torch.int32).unsqueeze(0) def init_topk_vals_out(): return torch.zeros((BATCH, INDEX_TOPK), dtype=torch.float32) @@ -250,7 +243,7 @@ def init_topk_idx_out(): init_value=init_k_cache_idx, ), TensorSpec("seq_lens", [BATCH], torch.int32, init_value=seq_lens_data), - TensorSpec("idx_init", [1, MAX_SEQ], torch.int32, init_value=init_idx_init), + TensorSpec("idx_init", [1, SORT_LEN], torch.int32, init_value=init_idx_init), TensorSpec( "topk_vals_out", [BATCH, INDEX_TOPK],