Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 34 additions & 41 deletions examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


BATCH = 16
MAX_SEQ = 8192
MAX_SEQ = 4096
INDEX_HEAD_DIM = 128
INDEX_TOPK = 2048
CACHE_ROWS_IDX = BATCH * MAX_SEQ
Expand All @@ -45,7 +45,10 @@
Q_VALID = 1
Q_PAD = 16

# sort32 + 4 mrgsort iterations (block_len 64,256,1024,4096) sort MAX_SEQ=8192.
# sort32 + 4 mrgsort iterations (block_len 64,256,1024,4096) sort SORT_LEN=8192.
# SORT_LEN > MAX_SEQ so the full sort buffer is pre-filled with -inf (Stage 0)
# and only [0, ctx_len) contains real scores; tail stays -inf.
SORT_LEN = 8192
MRGSORT_ITERS = 4

# -inf sentinel for score tail. FP32 lowest, since ptoas rejects literal -inf.
Expand All @@ -61,7 +64,7 @@ def deepseek_v3_2_decode_front_scope3(
q_idx: pl.Tensor[[BATCH, INDEX_HEAD_DIM], pl.BF16],
k_cache_idx: pl.Tensor[[CACHE_ROWS_IDX, INDEX_HEAD_DIM], pl.BF16],
seq_lens: pl.Tensor[[BATCH], pl.INT32],
idx_init: pl.Tensor[[1, MAX_SEQ], pl.UINT32],
idx_init: pl.Tensor[[1, SORT_LEN], pl.UINT32],
topk_vals_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.FP32],
topk_idx_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.INT32],
) -> tuple[
Expand Down Expand Up @@ -89,19 +92,20 @@ def deepseek_v3_2_decode_front_scope3(
)

# Transient GM buffers.
scores = pl.create_tensor([BATCH, MAX_SEQ], dtype=pl.FP32)
sorted_gm = pl.create_tensor([BATCH, 2 * MAX_SEQ], dtype=pl.FP32)
# scores is [BATCH, SORT_LEN]: Stage 0 fills the full row with -inf
# so the [MAX_SEQ, SORT_LEN) tail is always -inf for the sort.
scores = pl.create_tensor([BATCH, SORT_LEN], dtype=pl.FP32)
sorted_gm = pl.create_tensor([BATCH, 2 * SORT_LEN], dtype=pl.FP32)
raw_idx_gm = pl.create_tensor([BATCH, INDEX_TOPK], dtype=pl.INT32)

for b in pl.range(0, BATCH, 1):
ctx_len = pl.tensor.read(seq_lens, [b])
ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE

# Stage 0: pre-fill scores[b] with -inf. Stage 2's parallel
# loop only covers [0, ctx_blocks), so untouched tail slots
# keep the create_tensor default of 0.0 without this.
# Stage 0: pre-fill scores[b, 0:SORT_LEN] with -inf so both the
# untouched ctx tail and the [MAX_SEQ, SORT_LEN) pad are -inf.
with pl.at(level=pl.Level.CORE_GROUP):
neg_inf_row = pl.full([1, MAX_SEQ], dtype=pl.FP32, value=FP32_NEG_INF)
neg_inf_row = pl.full([1, SORT_LEN], dtype=pl.FP32, value=FP32_NEG_INF)
scores = pl.assemble(scores, neg_inf_row, [b, 0])

# Stage 1: tiled QK matmul into all_scores[sb*Q_PAD, 0] (row 0 valid).
Expand Down Expand Up @@ -133,43 +137,32 @@ def deepseek_v3_2_decode_front_scope3(
tile_padded = pl.fillpad(tile_valid, pad_value=pl.PadValue.min)
scores = pl.assemble(scores, tile_padded, [b, s0])

# Stage 3: sort32 + 4 mrgsort iterations. Result is [1, 2*MAX_SEQ]
# interleaved (val, idx). Stored to GM because vec→vec tile slicing
# is not supported on a2a3, so downstream gather re-loads from GM.
# Distinct SSA names for each mrgsort so the final tile is the one
# that gets stored.
# Stage 3: sort32 + 4 mrgsort iterations (tensor-level). Operates
# directly on GM slices; result is [1, 2*SORT_LEN] interleaved
# (val, idx). Stored to sorted_gm for gather in Stage 4.
with pl.at(level=pl.Level.CORE_GROUP):
score_row = pl.load(scores, offsets=[b, 0], shapes=[1, MAX_SEQ])
idx_row = pl.load(idx_init, offsets=[0, 0], shapes=[1, MAX_SEQ])
sorted0 = pl.tile.sort32(score_row, idx_row)
sorted1 = pl.tile.mrgsort(sorted0, block_len=64)
sorted2 = pl.tile.mrgsort(sorted1, block_len=256)
sorted3 = pl.tile.mrgsort(sorted2, block_len=1024)
sorted4 = pl.tile.mrgsort(sorted3, block_len=4096)
sorted_gm = pl.store(
sorted4, offsets=[b, 0], output_tensor=sorted_gm
)

# Stage 4: GM-load the first INDEX_TOPK pairs (2*INDEX_TOPK cols),
# gather P0101/P1010 to split vals / idx bits, store to outputs.
score_row = pl.slice(scores, [1, SORT_LEN], [b, 0])
sorted_t = pl.tensor.sort32(score_row, idx_init)
sorted_t = pl.tensor.mrgsort(sorted_t, block_len=64)
sorted_t = pl.tensor.mrgsort(sorted_t, block_len=256)
sorted_t = pl.tensor.mrgsort(sorted_t, block_len=1024)
sorted_t = pl.tensor.mrgsort(sorted_t, block_len=4096)
sorted_gm = pl.assemble(sorted_gm, sorted_t, [b, 0])

# Stage 4: gather P0101/P1010 to split vals / idx bits from the
# first INDEX_TOPK pairs (2*INDEX_TOPK cols) in sorted_gm.
with pl.at(level=pl.Level.CORE_GROUP):
topk_pairs = pl.load(
sorted_gm, offsets=[b, 0], shapes=[1, 2 * INDEX_TOPK]
)
topk_v = pl.tile.gather(
topk_pairs = pl.slice(sorted_gm, [1, 2 * INDEX_TOPK], [b, 0])
topk_v = pl.tensor.gather(
topk_pairs, mask_pattern=pl.tile.MaskPattern.P0101
)
topk_i_raw = pl.tile.gather(
topk_i_raw = pl.tensor.gather(
topk_pairs,
mask_pattern=pl.tile.MaskPattern.P1010,
output_dtype=pl.INT32,
)
topk_vals_out = pl.store(
topk_v, offsets=[b, 0], output_tensor=topk_vals_out
)
raw_idx_gm = pl.store(
topk_i_raw, offsets=[b, 0], output_tensor=raw_idx_gm
)
topk_vals_out = pl.assemble(topk_vals_out, topk_v, [b, 0])
raw_idx_gm = pl.assemble(raw_idx_gm, topk_i_raw, [b, 0])

# Stage 5: GM reload + valid_shape fillpad to mark idx slots past
# ctx_len with PadValue.min (= INT32_MIN < 0).
Expand Down Expand Up @@ -198,7 +191,7 @@ def golden_decode_front_scope3(tensors):
topk_vals_out = tensors["topk_vals_out"]
topk_idx_out = tensors["topk_idx_out"]

scores = torch.full((BATCH, MAX_SEQ), FP32_NEG_INF, dtype=torch.float32)
scores = torch.full((BATCH, SORT_LEN), FP32_NEG_INF, dtype=torch.float32)
for b in range(BATCH):
ctx_len = int(seq_lens[b].item())
q_b = q_idx[b : b + 1]
Expand Down Expand Up @@ -233,7 +226,7 @@ def init_k_cache_idx():
return torch.rand(CACHE_ROWS_IDX, INDEX_HEAD_DIM) - 0.5

def init_idx_init():
    """Build the initial index row fed to sort32 as the (value, index) pairing.

    Returns a [1, SORT_LEN] int32 row of ascending positions 0..SORT_LEN-1
    (the kernel declares it UINT32; values fit in both). SORT_LEN, not
    MAX_SEQ: the sort operates over the padded SORT_LEN buffer, so every
    slot — including the -inf tail — needs an index.
    """
    # NOTE: the diff capture had left the stale MAX_SEQ-sized return above
    # this one; only the merged SORT_LEN version is kept.
    return torch.arange(SORT_LEN, dtype=torch.int32).unsqueeze(0)

def init_topk_vals_out():
    """Host-side initializer: zeroed FP32 buffer for the top-k score output."""
    shape = (BATCH, INDEX_TOPK)
    return torch.full(shape, 0.0, dtype=torch.float32)
Expand All @@ -250,7 +243,7 @@ def init_topk_idx_out():
init_value=init_k_cache_idx,
),
TensorSpec("seq_lens", [BATCH], torch.int32, init_value=seq_lens_data),
TensorSpec("idx_init", [1, MAX_SEQ], torch.int32, init_value=init_idx_init),
TensorSpec("idx_init", [1, SORT_LEN], torch.int32, init_value=init_idx_init),
TensorSpec(
"topk_vals_out",
[BATCH, INDEX_TOPK],
Expand Down
Loading