From f54a3bd63238afa9e88ba2934f7a2fa185b76507 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 21 Apr 2026 10:02:41 +0800 Subject: [PATCH] Update: replace decode-front with ds32exp and add scope3 - Replace deepseek_v3_2_decode_front.py with ds32exp.py from ds32 branch: reorganised into 4 explicit scopes (qkv proj, indexer proj, score+topk, sparse MQA dispatch), adds Hadamard/FP8 placeholders, k_cache_idx write, and per-head weighted q_idx aggregation - Add deepseek_v3_2_decode_front_scope3.py (scope3.py from ds32): standalone scope covering score+topk via tiled QK matmul, sort32+mrgsort, and gather to produce topk_vals_out / topk_idx_out --- .../deepseek_v3_2_decode_front.py | 1358 ++++++++++------- .../deepseek_v3_2_decode_front_scope3.py | 305 ++++ 2 files changed, 1146 insertions(+), 517 deletions(-) create mode 100644 examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py index 0736727..2c5906b 100644 --- a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py +++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front.py @@ -6,30 +6,30 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -from __future__ import annotations - """ -DeepSeek V3.2-EXP single-layer decode FRONT part (batch=16, max_seq=4096). - -This version is aligned to official v3.2-exp MLA shapes: -- qk_nope_head_dim = 128 -- qk_rope_head_dim = 64 -- v_head_dim = 128 -- kv_lora_rank = 512 -- index_topk = 2048 - -FRONT boundary: -- run pre-attention path (RMSNorm + MLA projections + cache update) -- apply sparse attention by index_topk positions (DSA abstraction) -- write dispatch tensor into cross-node GM tensor and return - -Note: -- official indexer module is abstracted as external `index_topk_pos` input. -- dispatch payload uses attention output width `NUM_HEADS * V_HEAD_DIM`. +DeepSeek V3.2-EXP single-layer decode FRONT, organised as 4 scopes. + + Scope 1: qkv proj + qkv rope + input RMSNorm, wq_a / q_norm / wq_b, wkv_a, q_pe RoPE, + kv_norm + k_pe RoPE, write kv_cache & pe_cache. + + Scope 2: indexer proj + indexer rope + wq_b_idx (qr -> 64-head index query), wk_idx + LayerNorm, + non-interleaved RoPE on q_pe / k_pe halves, + weights_proj (per-head head weights), + TODO(hadamard_transform) + TODO(fp8 quant) placeholders, + weighted-sum aggregation -> q_idx [B, INDEX_HEAD_DIM], + write k_cache_idx. + + Scope 3: score + topk + q_idx x k_cache_idx tiled matmul (matches scope2b), + TODO(topk) placeholder producing topk_idx [B, INDEX_TOPK]. + + Scope 4: post topk + sparse MQA over topk positions in (kv_cache, pe_cache), + project latent -> v, dispatch write to cross-node buffer. """ -from typing import Optional - import pypto.language as pl @@ -44,461 +44,582 @@ QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM V_HEAD_DIM = 128 ATTN_OUT = NUM_HEADS * V_HEAD_DIM + +# Indexer (per ds32exp_official ModelArgs). 
INDEX_HEADS = 64 +INDEX_HEAD_DIM = 128 INDEX_TOPK = 2048 -EP_NODES = 128 # configurable + +EP_NODES = 128 EPS = 1e-6 -ATTN_SCALE = 1.0 / (QK_HEAD_DIM**0.5) HIDDEN_INV = 1.0 / HIDDEN +ATTN_SCALE = 1.0 / (QK_HEAD_DIM ** 0.5) +INDEX_SOFTMAX_SCALE = 1.0 / (INDEX_HEAD_DIM ** 0.5) +INDEX_HEADS_INV_SQRT = INDEX_HEADS ** -0.5 +HADAMARD_SCALE = INDEX_HEAD_DIM ** -0.5 + +KV_A_OUT = KV_LORA_RANK + QK_ROPE_HEAD_DIM +CACHE_ROWS = BATCH * MAX_SEQ +HALF_ROPE = QK_ROPE_HEAD_DIM // 2 +HALF_INDEX_ROPE = QK_ROPE_HEAD_DIM // 2 # indexer reuses the same rope dim +# Tiling / chunking. K_CHUNK = 512 +LORA_CHUNK = 128 Q_OUT_CHUNK = 512 KV_OUT_CHUNK = 128 -LORA_CHUNK = 128 V_OUT_CHUNK = 64 -SEQ_TILE = 120 +IDX_OUT_CHUNK = 128 +WEIGHTS_OUT_CHUNK = 64 BATCH_TILE = 4 -# Extra local pad tensor width to raise explicit Vec occupancy in memory report. -LOCAL_PAD_WIDTH = 16384 - -# Conservative software guard for AIV Vec/UB working set (bytes). This helps -# keep source-side tile settings near practical limits without overshooting. -UB_SOFT_LIMIT_BYTES = 160 * 1024 - - -def build_deepseek_v3_2_decode_front_program( - batch: int = BATCH, - max_seq_len: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - q_lora_rank: int = Q_LORA_RANK, - kv_lora_rank: int = KV_LORA_RANK, - qk_nope_head_dim: int = QK_NOPE_HEAD_DIM, - qk_rope_head_dim: int = QK_ROPE_HEAD_DIM, - v_head_dim: int = V_HEAD_DIM, - index_heads: int = INDEX_HEADS, - index_topk: int = INDEX_TOPK, - ep_nodes: int = EP_NODES, -): - BATCH_CFG = batch - MAX_SEQ_CFG = max_seq_len - HIDDEN_CFG = hidden_size - NUM_HEADS_CFG = num_heads - Q_LORA_RANK_CFG = q_lora_rank - KV_LORA_RANK_CFG = kv_lora_rank - QK_NOPE_HEAD_DIM_CFG = qk_nope_head_dim - QK_ROPE_HEAD_DIM_CFG = qk_rope_head_dim - QK_HEAD_DIM_CFG = qk_nope_head_dim + qk_rope_head_dim - V_HEAD_DIM_CFG = v_head_dim - INDEX_HEADS_CFG = index_heads - ATTN_OUT_CFG = num_heads * v_head_dim - INDEX_TOPK_CFG = index_topk - EP_NODES_CFG = ep_nodes - - HIDDEN_BLOCKS = (HIDDEN_CFG + K_CHUNK - 1) // K_CHUNK - QR_BLOCKS = (Q_LORA_RANK_CFG + LORA_CHUNK - 1) // LORA_CHUNK - Q_OUT_BLOCKS = (NUM_HEADS_CFG * QK_HEAD_DIM_CFG + Q_OUT_CHUNK - 1) // Q_OUT_CHUNK - KV_A_OUT = KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG - KV_A_BLOCKS = (KV_A_OUT + KV_OUT_CHUNK - 1) // KV_OUT_CHUNK - CACHE_ROWS = BATCH_CFG * MAX_SEQ_CFG - V_OUT_BLOCKS = (V_HEAD_DIM_CFG + V_OUT_CHUNK - 1) // V_OUT_CHUNK - - # Capacity-oriented source tuning guard: - # - stage1_est_bytes models dominant projection-side tile tensors. - # - stage2_est_bytes models topk buffers + major sparse-attention vectors. - stage1_est_bytes = ( - BATCH_TILE * K_CHUNK * 4 - + BATCH_TILE * LORA_CHUNK * 4 - + BATCH_TILE * Q_OUT_CHUNK * 4 - + BATCH_TILE * KV_OUT_CHUNK * 4 - + BATCH_TILE * LOCAL_PAD_WIDTH * 2 - ) - stage2_est_bytes = ( - (1 + 2) * INDEX_TOPK_CFG * 4 # topk vals + blk topk vals - + (1 + 2) * INDEX_TOPK_CFG * 4 # topk idx + blk topk idx - + KV_LORA_RANK_CFG * 4 # oi/ctx_latent dominant row vectors - + QK_ROPE_HEAD_DIM_CFG * 4 # q/pe rope vectors - + V_HEAD_DIM_CFG * 4 # ctx_v - ) - peak_est_bytes = max(stage1_est_bytes, stage2_est_bytes) - if peak_est_bytes > UB_SOFT_LIMIT_BYTES: - raise ValueError( - f"Estimated local working set {peak_est_bytes} bytes exceeds " - f"UB soft limit {UB_SOFT_LIMIT_BYTES} bytes. " - "Reduce BATCH_TILE/Q_OUT_CHUNK/K_CHUNK/KV_OUT_CHUNK/LORA_CHUNK." 
- ) +SEQ_TILE = 64 +HIDDEN_BLOCKS = HIDDEN // K_CHUNK +QR_BLOCKS = Q_LORA_RANK // LORA_CHUNK +Q_OUT_BLOCKS = (NUM_HEADS * QK_HEAD_DIM) // Q_OUT_CHUNK +KV_A_BLOCKS = KV_A_OUT // KV_OUT_CHUNK +IDX_OUT_BLOCKS = (INDEX_HEADS * INDEX_HEAD_DIM) // IDX_OUT_CHUNK +WK_OUT_BLOCKS = INDEX_HEAD_DIM // KV_OUT_CHUNK +V_OUT_BLOCKS = V_HEAD_DIM // V_OUT_CHUNK +MAX_SEQ_BLOCKS = MAX_SEQ // SEQ_TILE + + +def build_ds32exp_program(): @pl.program - class DeepSeekV32DecodeFront: + class Ds32Exp: @pl.function(type=pl.FunctionType.Opaque) - def deepseek_v3_2_decode_front_layer( + def ds32exp_decode_front( self, - hidden_states: pl.Tensor[[BATCH_CFG, HIDDEN_CFG], pl.BF16], - seq_lens: pl.Tensor[[BATCH_CFG], pl.INT32], + hidden_states: pl.Tensor[[BATCH, HIDDEN], pl.BF16], + seq_lens: pl.Tensor[[BATCH], pl.INT32], layer_id_t: pl.Tensor[[1], pl.INT32], - rope_cos: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32], - rope_sin: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32], - kv_cache: pl.Tensor[[CACHE_ROWS, KV_LORA_RANK_CFG], pl.BF16], - pe_cache: pl.Tensor[[CACHE_ROWS, QK_ROPE_HEAD_DIM_CFG], pl.BF16], - input_rms_weight: pl.Tensor[[1, HIDDEN_CFG], pl.FP32], - wq_a: pl.Tensor[[HIDDEN_CFG, Q_LORA_RANK_CFG], pl.BF16], - q_norm_weight: pl.Tensor[[1, Q_LORA_RANK_CFG], pl.FP32], - wq_b: pl.Tensor[[Q_LORA_RANK_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16], - wkv_a: pl.Tensor[[HIDDEN_CFG, KV_A_OUT], pl.BF16], - kv_norm_weight: pl.Tensor[[1, KV_LORA_RANK_CFG], pl.FP32], - w_q_nope_to_latent: pl.Tensor[[NUM_HEADS_CFG, QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], pl.BF16], - w_latent_to_v: pl.Tensor[[NUM_HEADS_CFG, KV_LORA_RANK_CFG, V_HEAD_DIM_CFG], pl.BF16], - # FRONT output: cross-node dispatch buffer - dispatch_buf: pl.Tensor[[EP_NODES_CFG, BATCH_CFG, ATTN_OUT_CFG], pl.BF16], - ) -> pl.Tensor[[EP_NODES_CFG, BATCH_CFG, ATTN_OUT_CFG], pl.BF16]: - # Scope 1: input RMSNorm + Q/K/V projection. - qr = pl.create_tensor([BATCH_CFG, Q_LORA_RANK_CFG], dtype=pl.BF16) - q_proj = pl.create_tensor([BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16) - kv_a = pl.create_tensor([BATCH_CFG, KV_A_OUT], dtype=pl.BF16) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) - sq_sum = pl.mul(sq_sum, 0) - # Keep an explicit local Vec pad tensor alive in this scope so - # AllocateMemoryAddr reflects a high-occupancy tuning point. 
- usage_pad = pl.create_tensor([BATCH_TILE, LOCAL_PAD_WIDTH], dtype=pl.BF16) - usage_pad = pl.mul(usage_pad, 0) - usage_pad_fp = pl.cast(usage_pad, target_type=pl.FP32) - usage_pad_sum = pl.row_sum(usage_pad_fp) - + rope_cos: pl.Tensor[[MAX_SEQ, QK_ROPE_HEAD_DIM], pl.FP32], + rope_sin: pl.Tensor[[MAX_SEQ, QK_ROPE_HEAD_DIM], pl.FP32], + kv_cache: pl.Tensor[[CACHE_ROWS, KV_LORA_RANK], pl.BF16], + pe_cache: pl.Tensor[[CACHE_ROWS, QK_ROPE_HEAD_DIM], pl.BF16], + k_cache_idx: pl.Tensor[[CACHE_ROWS, INDEX_HEAD_DIM], pl.BF16], + input_rms_weight: pl.Tensor[[1, HIDDEN], pl.FP32], + wq_a: pl.Tensor[[HIDDEN, Q_LORA_RANK], pl.BF16], + q_norm_weight: pl.Tensor[[1, Q_LORA_RANK], pl.FP32], + wq_b: pl.Tensor[[Q_LORA_RANK, NUM_HEADS * QK_HEAD_DIM], pl.BF16], + wkv_a: pl.Tensor[[HIDDEN, KV_A_OUT], pl.BF16], + kv_norm_weight: pl.Tensor[[1, KV_LORA_RANK], pl.FP32], + wq_b_idx: pl.Tensor[[Q_LORA_RANK, INDEX_HEADS * INDEX_HEAD_DIM], pl.BF16], + wk_idx: pl.Tensor[[HIDDEN, INDEX_HEAD_DIM], pl.BF16], + k_norm_weight: pl.Tensor[[1, INDEX_HEAD_DIM], pl.FP32], + k_norm_bias: pl.Tensor[[1, INDEX_HEAD_DIM], pl.FP32], + weights_proj: pl.Tensor[[HIDDEN, INDEX_HEADS], pl.FP32], + w_q_nope_to_latent: pl.Tensor[[NUM_HEADS, QK_NOPE_HEAD_DIM, KV_LORA_RANK], pl.BF16], + w_latent_to_v: pl.Tensor[[NUM_HEADS, KV_LORA_RANK, V_HEAD_DIM], pl.BF16], + dispatch_buf: pl.InOut[pl.Tensor[[EP_NODES, BATCH, ATTN_OUT], pl.BF16]], + ) -> pl.Tensor[[EP_NODES, BATCH, ATTN_OUT], pl.BF16]: + # Cross-scope intermediates. + qr = pl.create_tensor([BATCH, Q_LORA_RANK], dtype=pl.BF16) + q_proj = pl.create_tensor([BATCH, NUM_HEADS * QK_HEAD_DIM], dtype=pl.BF16) + kv_a = pl.create_tensor([BATCH, KV_A_OUT], dtype=pl.BF16) + + # ── Scope 1: qkv proj + qkv rope ── + # Stage 1.1: input RMSNorm (sq_sum -> rsqrt over HIDDEN). + inv_rms = pl.create_tensor([BATCH, 1], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP): + sq_sum = pl.full([BATCH, 1], dtype=pl.FP32, value=0.0) for kb in pl.range(HIDDEN_BLOCKS): k0 = kb * K_CHUNK x_chunk = pl.cast( - pl.slice(hidden_states, [BATCH_CFG, K_CHUNK], [0, k0]), + pl.slice(hidden_states, [BATCH, K_CHUNK], [0, k0]), target_type=pl.FP32, ) sq_sum = pl.add(sq_sum, pl.row_sum(pl.mul(x_chunk, x_chunk))) + inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) - inv_rms = pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)) - for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): + # Stage 1.2: qr = q_norm(wq_a(normed_x)). 
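For reference, Stage 1.1's blocked reduction above is the one-shot RMSNorm denominator computed one K_CHUNK slice at a time; a standalone torch sketch with stand-in sizes (the module's real HIDDEN/K_CHUNK substitute directly) confirming the two agree:

import torch

B, HIDDEN, K_CHUNK, EPS = 16, 1024, 256, 1e-6  # stand-in sizes, not the module's values

x = torch.randn(B, HIDDEN, dtype=torch.bfloat16)
sq_sum = torch.zeros(B, 1)
for k0 in range(0, HIDDEN, K_CHUNK):
    chunk = x[:, k0:k0 + K_CHUNK].float()             # cast-then-square, as the kernel does
    sq_sum += (chunk * chunk).sum(dim=1, keepdim=True)
inv_rms = torch.rsqrt(sq_sum / HIDDEN + EPS)          # pl.recip(pl.sqrt(...)) == rsqrt

ref = torch.rsqrt(x.float().pow(2).mean(dim=1, keepdim=True) + EPS)
assert torch.allclose(inv_rms, ref, rtol=1e-4, atol=1e-6)

The Stage 1.2 body follows below.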
+ with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): inv_rms_tile = pl.slice(inv_rms, [BATCH_TILE, 1], [b0, 0]) - inv_rms_tile = pl.add(inv_rms_tile, pl.mul(usage_pad_sum, 0.0)) - for ob in pl.parallel(0, QR_BLOCKS, 1, chunk=4): + for ob in pl.range(QR_BLOCKS): q0 = ob * LORA_CHUNK - q_acc = pl.create_tensor([BATCH_TILE, LORA_CHUNK], dtype=pl.FP32) - q_acc = pl.mul(q_acc, 0.0) + q_acc = pl.full([BATCH_TILE, LORA_CHUNK], dtype=pl.FP32, value=0.0) for kb in pl.range(HIDDEN_BLOCKS): k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_tile_qr = pl.cast(x_chunk_bf16, target_type=pl.FP32) + x_tile = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.FP32, + ) gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_tile_qr, inv_rms_tile), gamma) + normed = pl.col_expand_mul( + pl.row_expand_mul(x_tile, inv_rms_tile), gamma + ) wq_chunk = pl.slice(wq_a, [K_CHUNK, LORA_CHUNK], [k0, q0]) - q_acc = pl.add(q_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk)) - qr = pl.assemble(qr, pl.cast(q_acc, target_type=pl.BF16), [b0, q0]) + q_acc = pl.add( + q_acc, + pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk, out_dtype=pl.FP32), + ) + q_gamma = pl.slice(q_norm_weight, [1, LORA_CHUNK], [0, q0]) + qn = pl.col_expand_mul(q_acc, q_gamma) + qr = pl.assemble(qr, pl.cast(qn, target_type=pl.BF16), [b0, q0]) - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): + # Stage 1.3: q_proj = qr @ wq_b. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + for ob in pl.range(Q_OUT_BLOCKS): q0 = ob * Q_OUT_CHUNK - q_out_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - q_out_acc = pl.mul(q_out_acc, 0.0) + q_out_acc = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0) for kb in pl.range(QR_BLOCKS): k0 = kb * LORA_CHUNK - q_chunk = pl.cast( - pl.slice(qr, [BATCH_TILE, LORA_CHUNK], [b0, k0]), target_type=pl.FP32 - ) - q_gamma = pl.slice(q_norm_weight, [1, LORA_CHUNK], [0, k0]) - qn = pl.col_expand_mul(q_chunk, q_gamma) + qr_chunk = pl.slice(qr, [BATCH_TILE, LORA_CHUNK], [b0, k0]) wq_b_chunk = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [k0, q0]) q_out_acc = pl.add( - q_out_acc, pl.matmul(pl.cast(qn, target_type=pl.BF16), wq_b_chunk) + q_out_acc, + pl.matmul(qr_chunk, wq_b_chunk, out_dtype=pl.FP32), ) - q_proj = pl.assemble(q_proj, pl.cast(q_out_acc, target_type=pl.BF16), [b0, q0]) + q_proj = pl.assemble( + q_proj, pl.cast(q_out_acc, target_type=pl.BF16), [b0, q0] + ) - for ob in pl.parallel(0, KV_A_BLOCKS, 1, chunk=8): + # Stage 1.4: kv_a = wkv_a(normed_x). 
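Stages 1.2-1.4 (including Stage 1.4 below) all use the same pattern: tile the contraction dim, accumulate BF16 partial matmuls into an FP32 accumulator (`out_dtype=pl.FP32`), and cast once at the end. A minimal torch sketch, with stand-in sizes, of why the blocked accumulation equals the full matmul:

import torch

M, K, N, K_CHUNK = 4, 1024, 256, 256  # stand-in sizes
x = torch.randn(M, K, dtype=torch.bfloat16)
w = torch.randn(K, N, dtype=torch.bfloat16)

acc = torch.zeros(M, N)  # FP32 accumulator, like pl.full(..., dtype=pl.FP32)
for k0 in range(0, K, K_CHUNK):
    acc += x[:, k0:k0 + K_CHUNK].float() @ w[k0:k0 + K_CHUNK, :].float()

ref = x.float() @ w.float()
assert torch.allclose(acc, ref, rtol=1e-4, atol=1e-3)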
+ with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + inv_rms_tile = pl.slice(inv_rms, [BATCH_TILE, 1], [b0, 0]) + for ob in pl.range(KV_A_BLOCKS): kv0 = ob * KV_OUT_CHUNK - kv_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - kv_acc = pl.mul(kv_acc, 0.0) + kv_acc = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0) for kb in pl.range(HIDDEN_BLOCKS): k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_tile_kv = pl.cast(x_chunk_bf16, target_type=pl.FP32) - gamma_kv = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_tile_kv, inv_rms_tile), gamma_kv) + x_tile = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.FP32, + ) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul( + pl.row_expand_mul(x_tile, inv_rms_tile), gamma + ) wkv_chunk = pl.slice(wkv_a, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) kv_acc = pl.add( - kv_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wkv_chunk) + kv_acc, + pl.matmul(pl.cast(normed, target_type=pl.BF16), wkv_chunk, out_dtype=pl.FP32), ) - kv_a = pl.assemble(kv_a, pl.cast(kv_acc, target_type=pl.BF16), [b0, kv0]) - - # Scope 2: RoPE + cache update + indexer topk + sparse attention. - # Fusion policy (aligned with prefill_front): - # - Stage A/B/C all stay in ONE auto_incore scope. - # - A: current-token cache write - # - B1/B2: two-stage topk (block-local then global merge) - # - C: sparse attention consumes merged topk immediately - # This avoids materializing topk intermediates across kernel boundaries. - attn_front = pl.create_tensor([BATCH_CFG, ATTN_OUT_CFG], dtype=pl.FP32) + kv_a = pl.assemble( + kv_a, pl.cast(kv_acc, target_type=pl.BF16), [b0, kv0] + ) + + # Stage 1.5: q_pe RoPE on every MLA head, k_pe RoPE on kv_a, kv_norm, + # write kv_cache and pe_cache at row b*MAX_SEQ + (seq_lens[b]-1). + # NOTE: official applies interleaved=True for MLA, but the existing + # decode/prefill paths in this repo use the lo/hi half split form + # (see deepseek_v3_2_decode_front.py:241-274). We follow the same + # convention so the cached pe matches the in-tree consumers. with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - layer_id = pl.tensor.read(layer_id_t, [0]) - for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): + for b in pl.parallel(0, BATCH, 1, chunk=BATCH): ctx_len = pl.tensor.read(seq_lens, [b]) pos = ctx_len - 1 - cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) - cos_hi = pl.slice( - rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] + cache_row = b * MAX_SEQ + pos + + cos_lo = pl.slice(rope_cos, [1, HALF_ROPE], [pos, 0]) + cos_hi = pl.slice(rope_cos, [1, HALF_ROPE], [pos, HALF_ROPE]) + sin_lo = pl.slice(rope_sin, [1, HALF_ROPE], [pos, 0]) + sin_hi = pl.slice(rope_sin, [1, HALF_ROPE], [pos, HALF_ROPE]) + + # MLA q_pe RoPE: rotate every head's pe half in-place. 
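The rotation loop below (and the k_pe path after it) uses the lo/hi half-split form named in the NOTE above; as a reference, the rotation applied to each [1, QK_ROPE_HEAD_DIM] slice is, in plain torch:

import torch

def rope_half_split(vec, cos, sin):
    # vec/cos/sin: [..., D]; split at D // 2 and rotate the halves against
    # each other -- the lo/hi convention, not the interleaved (even, odd) pairing.
    half = vec.shape[-1] // 2
    lo, hi = vec[..., :half], vec[..., half:]
    rot_lo = lo * cos[..., :half] - hi * sin[..., :half]
    rot_hi = hi * cos[..., half:] + lo * sin[..., half:]
    return torch.cat([rot_lo, rot_hi], dim=-1)

D = 64  # QK_ROPE_HEAD_DIM
pos_angles = torch.rand(1, D)  # stand-in for one rope_cos / rope_sin row
out = rope_half_split(torch.randn(1, D), torch.cos(pos_angles), torch.sin(pos_angles))

This matches the golden's `rope_half` helper further down; the kernel just spells out rot_lo / rot_hi as separate `pl.assemble` writes.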
+ for h in pl.range(NUM_HEADS): + q_col = h * QK_HEAD_DIM + q_lo = pl.cast( + pl.slice(q_proj, [1, HALF_ROPE], [b, q_col + QK_NOPE_HEAD_DIM]), + target_type=pl.FP32, + ) + q_hi = pl.cast( + pl.slice( + q_proj, [1, HALF_ROPE], + [b, q_col + QK_NOPE_HEAD_DIM + HALF_ROPE], + ), + target_type=pl.FP32, + ) + rot_lo = pl.sub( + pl.col_expand_mul(q_lo, cos_lo), + pl.col_expand_mul(q_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(q_hi, cos_hi), + pl.col_expand_mul(q_lo, sin_hi), + ) + q_proj = pl.assemble( + q_proj, pl.cast(rot_lo, target_type=pl.BF16), + [b, q_col + QK_NOPE_HEAD_DIM], + ) + q_proj = pl.assemble( + q_proj, pl.cast(rot_hi, target_type=pl.BF16), + [b, q_col + QK_NOPE_HEAD_DIM + HALF_ROPE], + ) + + # kv_norm on the latent half then write kv_cache. + kv_row = pl.cast( + pl.slice(kv_a, [1, KV_LORA_RANK], [b, 0]), target_type=pl.FP32 ) - sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) - sin_hi = pl.slice( - rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] + kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK], [0, 0]) + kv_normed = pl.col_expand_mul(kv_row, kv_gamma) + kv_cache = pl.assemble( + kv_cache, pl.cast(kv_normed, target_type=pl.BF16), [cache_row, 0] ) - cache_row = b * MAX_SEQ_CFG + pos - kv_row = pl.cast(pl.slice(kv_a, [1, KV_LORA_RANK_CFG], [b, 0]), target_type=pl.FP32) - kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0]) - kv_normed = pl.col_expand_mul(kv_row, kv_gamma) + # k_pe RoPE on the rope half then write pe_cache. pe_lo = pl.cast( - pl.slice(kv_a, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, KV_LORA_RANK_CFG]), + pl.slice(kv_a, [1, HALF_ROPE], [b, KV_LORA_RANK]), target_type=pl.FP32, ) pe_hi = pl.cast( - pl.slice( - kv_a, - [1, QK_ROPE_HEAD_DIM_CFG // 2], - [b, KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG // 2], - ), + pl.slice(kv_a, [1, HALF_ROPE], [b, KV_LORA_RANK + HALF_ROPE]), target_type=pl.FP32, ) - pe_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - pe_lo_cos = pl.col_expand_mul(pe_lo, cos_lo) - pe_hi_sin = pl.col_expand_mul(pe_hi, sin_lo) - pe_rot_lo = pl.sub(pe_lo_cos, pe_hi_sin) - pe_rot = pl.assemble(pe_rot, pe_rot_lo, [0, 0]) - pe_hi_cos = pl.col_expand_mul(pe_hi, cos_hi) - pe_lo_sin = pl.col_expand_mul(pe_lo, sin_hi) - pe_rot_hi = pl.add(pe_hi_cos, pe_lo_sin) - pe_rot = pl.assemble(pe_rot, pe_rot_hi, [0, QK_ROPE_HEAD_DIM_CFG // 2]) - kv_cache = pl.assemble(kv_cache, pl.cast(kv_normed, target_type=pl.BF16), [cache_row, 0]) - pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot, target_type=pl.BF16), [cache_row, 0]) - - # Stage B1: block-local topk (2 blocks, each 2K candidates). 
- topk_vals = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.FP32) - topk_idx = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.INT32) - blk_topk_vals = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.FP32) - blk_topk_idx = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.INT32) - topk_vals = pl.mul(topk_vals, -3.402823e38) - topk_idx = pl.cast(pl.mul(topk_idx, 0), target_type=pl.INT32) - blk_topk_vals = pl.mul(blk_topk_vals, -3.402823e38) - blk_topk_idx = pl.cast(pl.mul(blk_topk_idx, 0), target_type=pl.INT32) - for kk in pl.range(INDEX_TOPK_CFG): - neg_one = pl.create_tensor([1, 1], dtype=pl.INT32) - neg_one = pl.cast(pl.mul(neg_one, 0), target_type=pl.INT32) - neg_one = pl.cast(pl.add(neg_one, -1), target_type=pl.INT32) - topk_idx = pl.assemble(topk_idx, neg_one, [0, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [0, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [1, kk]) - - q_col0 = 0 - q_nope0 = pl.cast( - pl.slice(q_proj, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col0]), + pe_rot_lo = pl.sub( + pl.col_expand_mul(pe_lo, cos_lo), + pl.col_expand_mul(pe_hi, sin_lo), + ) + pe_rot_hi = pl.add( + pl.col_expand_mul(pe_hi, cos_hi), + pl.col_expand_mul(pe_lo, sin_hi), + ) + pe_cache = pl.assemble( + pe_cache, pl.cast(pe_rot_lo, target_type=pl.BF16), [cache_row, 0] + ) + pe_cache = pl.assemble( + pe_cache, pl.cast(pe_rot_hi, target_type=pl.BF16), + [cache_row, HALF_ROPE], + ) + + # ── Scope 2: indexer proj + indexer rope ── + # Cross-stage indexer intermediates. + q_idx_full = pl.create_tensor( + [BATCH, INDEX_HEADS * INDEX_HEAD_DIM], dtype=pl.BF16 + ) + k_idx = pl.create_tensor([BATCH, INDEX_HEAD_DIM], dtype=pl.BF16) + weights = pl.create_tensor([BATCH, INDEX_HEADS], dtype=pl.FP32) + q_idx = pl.create_tensor([BATCH, INDEX_HEAD_DIM], dtype=pl.BF16) + + # Stage 2.1: q_idx_full = qr @ wq_b_idx -> [B, INDEX_HEADS * INDEX_HEAD_DIM]. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + for ob in pl.range(IDX_OUT_BLOCKS): + q0 = ob * IDX_OUT_CHUNK + idx_acc = pl.full([BATCH_TILE, IDX_OUT_CHUNK], dtype=pl.FP32, value=0.0) + for kb in pl.range(QR_BLOCKS): + k0 = kb * LORA_CHUNK + qr_chunk = pl.slice(qr, [BATCH_TILE, LORA_CHUNK], [b0, k0]) + wq_b_idx_chunk = pl.slice( + wq_b_idx, [LORA_CHUNK, IDX_OUT_CHUNK], [k0, q0] + ) + idx_acc = pl.add( + idx_acc, + pl.matmul(qr_chunk, wq_b_idx_chunk, out_dtype=pl.FP32), + ) + q_idx_full = pl.assemble( + q_idx_full, pl.cast(idx_acc, target_type=pl.BF16), [b0, q0] + ) + + # Stage 2.2: k_idx_pre = hidden_states @ wk_idx -> [B, INDEX_HEAD_DIM] BF16. + # Then LayerNorm with k_norm_weight / k_norm_bias. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + for ob in pl.range(WK_OUT_BLOCKS): + kv0 = ob * KV_OUT_CHUNK + wk_acc = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_tile = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) + wk_chunk = pl.slice(wk_idx, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + wk_acc = pl.add( + wk_acc, + pl.matmul(x_tile, wk_chunk, out_dtype=pl.FP32), + ) + k_idx = pl.assemble( + k_idx, pl.cast(wk_acc, target_type=pl.BF16), [b0, kv0] + ) + + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + # LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta. 
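The with-block below implements exactly the formula in its header comment, with biased variance (no Bessel correction), which is also what `torch.nn.functional.layer_norm` computes; a quick cross-check:

import torch
import torch.nn.functional as F

D, EPS = 128, 1e-6  # INDEX_HEAD_DIM and this module's EPS
x = torch.randn(16, D)
gamma, beta = torch.randn(D), torch.randn(D)

mean = x.mean(dim=-1, keepdim=True)
var = (x - mean).pow(2).mean(dim=-1, keepdim=True)  # biased variance
y = (x - mean) * torch.rsqrt(var + EPS) * gamma + beta

ref = F.layer_norm(x, (D,), weight=gamma, bias=beta, eps=EPS)
assert torch.allclose(y, ref, rtol=1e-5, atol=1e-5)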
+ for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + k_tile_fp32 = pl.cast( + pl.slice(k_idx, [BATCH_TILE, INDEX_HEAD_DIM], [b0, 0]), target_type=pl.FP32, ) - q0_lo = pl.cast( - pl.slice(q_proj, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, q_col0 + QK_NOPE_HEAD_DIM_CFG]), + mean = pl.mul(pl.row_sum(k_tile_fp32), 1.0 / INDEX_HEAD_DIM) + centered = pl.row_expand_sub(k_tile_fp32, mean) + var = pl.mul( + pl.row_sum(pl.mul(centered, centered)), 1.0 / INDEX_HEAD_DIM + ) + inv_std = pl.recip(pl.sqrt(pl.add(var, EPS))) + normed = pl.row_expand_mul(centered, inv_std) + gamma = pl.slice(k_norm_weight, [1, INDEX_HEAD_DIM], [0, 0]) + beta = pl.slice(k_norm_bias, [1, INDEX_HEAD_DIM], [0, 0]) + y = pl.add(pl.col_expand_mul(normed, gamma), beta) + k_idx = pl.assemble(k_idx, pl.cast(y, target_type=pl.BF16), [b0, 0]) + + # Stage 2.3: non-interleaved RoPE on q_pe (per index head) and k_pe. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b in pl.parallel(0, BATCH, 1, chunk=BATCH): + ctx_len = pl.tensor.read(seq_lens, [b]) + pos = ctx_len - 1 + cos_lo = pl.slice(rope_cos, [1, HALF_INDEX_ROPE], [pos, 0]) + cos_hi = pl.slice(rope_cos, [1, HALF_INDEX_ROPE], [pos, HALF_INDEX_ROPE]) + sin_lo = pl.slice(rope_sin, [1, HALF_INDEX_ROPE], [pos, 0]) + sin_hi = pl.slice(rope_sin, [1, HALF_INDEX_ROPE], [pos, HALF_INDEX_ROPE]) + + # q_pe rotation per index head. Indexer uses interleaved=False + # (official line 464); our lo/hi-half implementation matches + # the existing ds_q0_rope.py convention. + for h in pl.range(INDEX_HEADS): + q_col = h * INDEX_HEAD_DIM + q_lo = pl.cast( + pl.slice(q_idx_full, [1, HALF_INDEX_ROPE], [b, q_col]), + target_type=pl.FP32, + ) + q_hi = pl.cast( + pl.slice( + q_idx_full, [1, HALF_INDEX_ROPE], [b, q_col + HALF_INDEX_ROPE] + ), + target_type=pl.FP32, + ) + rot_lo = pl.sub( + pl.col_expand_mul(q_lo, cos_lo), + pl.col_expand_mul(q_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(q_hi, cos_hi), + pl.col_expand_mul(q_lo, sin_hi), + ) + q_idx_full = pl.assemble( + q_idx_full, pl.cast(rot_lo, target_type=pl.BF16), [b, q_col] + ) + q_idx_full = pl.assemble( + q_idx_full, pl.cast(rot_hi, target_type=pl.BF16), + [b, q_col + HALF_INDEX_ROPE], + ) + + # k_pe rotation (single head). + k_lo = pl.cast( + pl.slice(k_idx, [1, HALF_INDEX_ROPE], [b, 0]), target_type=pl.FP32 + ) + k_hi = pl.cast( + pl.slice(k_idx, [1, HALF_INDEX_ROPE], [b, HALF_INDEX_ROPE]), target_type=pl.FP32, ) - q0_hi = pl.cast( + k_rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + k_rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + k_idx = pl.assemble( + k_idx, pl.cast(k_rot_lo, target_type=pl.BF16), [b, 0] + ) + k_idx = pl.assemble( + k_idx, pl.cast(k_rot_hi, target_type=pl.BF16), [b, HALF_INDEX_ROPE] + ) + + # Stage 2.4: TODO(hadamard_transform). + # Official: rotate_activation(q), rotate_activation(k) using + # fast_hadamard_transform with scale = INDEX_HEAD_DIM ** -0.5 + # (ds32exp_official.py:428-432). Hadamard is orthogonal/linear, so + # the eventual scoring still scales by the same factor; here we + # only keep the scaling and leave the orthogonal mixing as a TODO. 
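For orientation, what the TODO above would compute if filled in: a fast Walsh-Hadamard transform over the 128-dim head, scaled by INDEX_HEAD_DIM ** -0.5. A minimal torch sketch as a reference for a future implementation — not what the placeholder below does, since the placeholder keeps only the scale:

import torch

def fwht(x: torch.Tensor) -> torch.Tensor:
    # Iterative butterfly FWHT over the last dim (Sylvester ordering);
    # orthogonal once the d ** -0.5 scale is applied.
    d = x.shape[-1]
    assert (d & (d - 1)) == 0, "FWHT needs a power-of-two dim"
    y = x.clone().float()
    h = 1
    while h < d:
        y = y.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-1], d)
        h *= 2
    return y * d ** -0.5

v = torch.randn(128)
assert torch.allclose(fwht(fwht(v)), v, atol=1e-4)  # scaled FWHT is an involution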
+ with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + q_tile = pl.cast( pl.slice( - q_proj, - [1, QK_ROPE_HEAD_DIM_CFG // 2], - [b, q_col0 + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2], + q_idx_full, [BATCH_TILE, INDEX_HEADS * INDEX_HEAD_DIM], [b0, 0] ), target_type=pl.FP32, ) - q0_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - q0_rot = pl.assemble( - q0_rot, - pl.sub(pl.col_expand_mul(q0_lo, cos_lo), pl.col_expand_mul(q0_hi, sin_lo)), - [0, 0], + q_tile_scaled = pl.mul(q_tile, HADAMARD_SCALE) + q_idx_full = pl.assemble( + q_idx_full, pl.cast(q_tile_scaled, target_type=pl.BF16), [b0, 0] ) - q0_rot = pl.assemble( - q0_rot, - pl.add(pl.col_expand_mul(q0_hi, cos_hi), pl.col_expand_mul(q0_lo, sin_hi)), - [0, QK_ROPE_HEAD_DIM_CFG // 2], + k_tile = pl.cast( + pl.slice(k_idx, [BATCH_TILE, INDEX_HEAD_DIM], [b0, 0]), + target_type=pl.FP32, + ) + k_tile_scaled = pl.mul(k_tile, HADAMARD_SCALE) + k_idx = pl.assemble( + k_idx, pl.cast(k_tile_scaled, target_type=pl.BF16), [b0, 0] + ) + + # Stage 2.5: TODO(fp8 quant). + # Official: q_fp8, q_scale = act_quant(q, block_size); same for k. + # weights are then multiplied by q_scale (line 479). Placeholder + # below performs a BF16 -> FP8E4M3FN -> BF16 round-trip so the + # numerical loss matches in spirit, and skips the q_scale fold-in. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + q_tile = pl.slice( + q_idx_full, [BATCH_TILE, INDEX_HEADS * INDEX_HEAD_DIM], [b0, 0] ) - q0_nope_latent = pl.matmul( - pl.cast(q_nope0, target_type=pl.BF16), - pl.reshape( + q_fp8 = pl.cast(q_tile, target_type=pl.FP8E4M3FN) + q_back = pl.cast(q_fp8, target_type=pl.BF16) + q_idx_full = pl.assemble(q_idx_full, q_back, [b0, 0]) + k_tile = pl.slice(k_idx, [BATCH_TILE, INDEX_HEAD_DIM], [b0, 0]) + k_fp8 = pl.cast(k_tile, target_type=pl.FP8E4M3FN) + k_back = pl.cast(k_fp8, target_type=pl.BF16) + k_idx = pl.assemble(k_idx, k_back, [b0, 0]) + + # Stage 2.6: weights = (hidden_states.float() @ weights_proj) + # * INDEX_HEADS ** -0.5 * INDEX_SOFTMAX_SCALE. + # TODO(fp8 quant): also multiply by q_scale once Stage 2.5 keeps it. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + w_scale = INDEX_HEADS_INV_SQRT * INDEX_SOFTMAX_SCALE + for b0 in pl.parallel(0, BATCH, BATCH_TILE, chunk=BATCH // BATCH_TILE): + for ob in pl.range(INDEX_HEADS // WEIGHTS_OUT_CHUNK): + w0 = ob * WEIGHTS_OUT_CHUNK + w_acc = pl.full( + [BATCH_TILE, WEIGHTS_OUT_CHUNK], dtype=pl.FP32, value=0.0 + ) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_tile = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.FP32, + ) + wp_chunk = pl.slice( + weights_proj, [K_CHUNK, WEIGHTS_OUT_CHUNK], [k0, w0] + ) + w_acc = pl.add( + w_acc, + pl.matmul( + pl.cast(x_tile, target_type=pl.BF16), + pl.cast(wp_chunk, target_type=pl.BF16), + out_dtype=pl.FP32, + ), + ) + weights = pl.assemble( + weights, pl.mul(w_acc, w_scale), [b0, w0] + ) + + # Stage 2.7: aggregate q_idx[b] = sum_h weights[b,h] * q_idx_full[b, h*HD:(h+1)*HD]. + # See deepseek_v3_2_decode_front_scope2b.py header for the algebraic + # reduction that lets scope3 score with a single [INDEX_HEAD_DIM] vector. 
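The reduction the comment refers to is the linearity of the dot product in the query: sum_h w[b,h] * (q[b,h] . k) == (sum_h w[b,h] * q[b,h]) . k, exact in real arithmetic (the BF16 casts in the kernel make it approximate). A worked torch check:

import torch

B, H, D, S = 2, 64, 128, 10  # batch, INDEX_HEADS, INDEX_HEAD_DIM, a few keys
w = torch.randn(B, H)
q = torch.randn(B, H, D)
k = torch.randn(B, S, D)

per_head = torch.einsum('bhd,bsd->bhs', q, k)   # score each head separately ...
lhs = torch.einsum('bh,bhs->bs', w, per_head)   # ... then weight-sum over heads
q_agg = torch.einsum('bh,bhd->bd', w, q)        # Stage 2.7's q_idx
rhs = torch.einsum('bd,bsd->bs', q_agg, k)      # score once with the aggregate
assert torch.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)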
+ with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b in pl.parallel(0, BATCH, 1, chunk=BATCH): + q_acc = pl.full([1, INDEX_HEAD_DIM], dtype=pl.FP32, value=0.0) + for h in pl.range(INDEX_HEADS): + q_h = pl.cast( pl.slice( - w_q_nope_to_latent, [1, QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [0, 0, 0] + q_idx_full, [1, INDEX_HEAD_DIM], [b, h * INDEX_HEAD_DIM] ), - [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], - ), - ) + target_type=pl.FP32, + ) + w_h = pl.slice(weights, [1, 1], [b, h]) + q_acc = pl.add(q_acc, pl.col_expand_mul(q_h, w_h)) + q_idx = pl.assemble(q_idx, pl.cast(q_acc, target_type=pl.BF16), [b, 0]) - sparse_k_gen = pl.min(INDEX_TOPK_CFG, ctx_len) - for blk in pl.range(2): - blk_start = blk * INDEX_TOPK_CFG - blk_end = pl.min(ctx_len, blk_start + INDEX_TOPK_CFG) - for ss in pl.range(INDEX_TOPK_CFG): - s = blk_start + ss - if s < blk_end: - cache_s = b * MAX_SEQ_CFG + s - kv_s = pl.cast( - pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), - target_type=pl.FP32, - ) - pe_s = pl.cast( - pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), - target_type=pl.FP32, - ) - score_nope = pl.row_sum(pl.mul(q0_nope_latent, kv_s)) - score_pe = pl.row_sum(pl.mul(q0_rot, pe_s)) - score_fp32 = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) - score_fp8 = pl.cast(score_fp32, target_type=pl.FP8E4M3FN) - score_a5 = pl.cast(score_fp8, target_type=pl.FP32) - cur_score = pl.tensor.read(score_a5, [0, 0]) - - inserted = pl.create_tensor([1, 1], dtype=pl.INT32) - inserted = pl.cast(pl.mul(inserted, 0), target_type=pl.INT32) - for kk in pl.range(sparse_k_gen): - ins = pl.tensor.read(inserted, [0, 0]) - kth_val = pl.tensor.read(blk_topk_vals, [blk, kk]) - if ins == 0: - if cur_score > kth_val: - for sh in pl.range(sparse_k_gen - 1, kk, -1): - prev_val = pl.tensor.read(blk_topk_vals, [blk, sh - 1]) - prev_idx = pl.tensor.read(blk_topk_idx, [blk, sh - 1]) - prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - prev_val_t = pl.mul(prev_val_t, 0.0) - prev_idx_t = pl.cast( - pl.mul(prev_idx_t, 0), target_type=pl.INT32 - ) - prev_val_t = pl.add(prev_val_t, prev_val) - prev_idx_t = pl.add(prev_idx_t, prev_idx) - blk_topk_vals = pl.assemble( - blk_topk_vals, prev_val_t, [blk, sh] - ) - blk_topk_idx = pl.assemble( - blk_topk_idx, prev_idx_t, [blk, sh] - ) - cur_score_t = pl.create_tensor([1, 1], dtype=pl.FP32) - cur_index_t = pl.create_tensor([1, 1], dtype=pl.INT32) - one_t = pl.create_tensor([1, 1], dtype=pl.INT32) - cur_score_t = pl.mul(cur_score_t, 0.0) - cur_index_t = pl.cast( - pl.mul(cur_index_t, 0), target_type=pl.INT32 - ) - one_t = pl.cast(pl.mul(one_t, 0), target_type=pl.INT32) - cur_score_t = pl.add(cur_score_t, cur_score) - cur_index_t = pl.cast( - pl.add(cur_index_t, s), target_type=pl.INT32 - ) - one_t = pl.cast(pl.add(one_t, 1), target_type=pl.INT32) - blk_topk_vals = pl.assemble(blk_topk_vals, cur_score_t, [blk, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, cur_index_t, [blk, kk]) - inserted = pl.assemble(inserted, one_t, [0, 0]) - - # Stage B2: global merge from 2x(local topk) -> final topk. 
- for blk in pl.range(2): - for kk in pl.range(sparse_k_gen): - cand_idx = pl.tensor.read(blk_topk_idx, [blk, kk]) - if cand_idx >= 0: - cand_val = pl.tensor.read(blk_topk_vals, [blk, kk]) - inserted = pl.create_tensor([1, 1], dtype=pl.INT32) - inserted = pl.cast(pl.mul(inserted, 0), target_type=pl.INT32) - for tkk in pl.range(sparse_k_gen): - ins = pl.tensor.read(inserted, [0, 0]) - kth_val = pl.tensor.read(topk_vals, [0, tkk]) - if ins == 0: - if cand_val > kth_val: - for sh in pl.range(sparse_k_gen - 1, tkk, -1): - prev_val = pl.tensor.read(topk_vals, [0, sh - 1]) - prev_idx = pl.tensor.read(topk_idx, [0, sh - 1]) - prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - prev_val_t = pl.mul(prev_val_t, 0.0) - prev_idx_t = pl.cast( - pl.mul(prev_idx_t, 0), target_type=pl.INT32 - ) - prev_val_t = pl.add(prev_val_t, prev_val) - prev_idx_t = pl.add(prev_idx_t, prev_idx) - topk_vals = pl.assemble(topk_vals, prev_val_t, [0, sh]) - topk_idx = pl.assemble(topk_idx, prev_idx_t, [0, sh]) - cand_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - cand_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - one_t = pl.create_tensor([1, 1], dtype=pl.INT32) - cand_val_t = pl.mul(cand_val_t, 0.0) - cand_idx_t = pl.cast(pl.mul(cand_idx_t, 0), target_type=pl.INT32) - one_t = pl.cast(pl.mul(one_t, 0), target_type=pl.INT32) - cand_val_t = pl.add(cand_val_t, cand_val) - cand_idx_t = pl.cast( - pl.add(cand_idx_t, cand_idx), target_type=pl.INT32 - ) - one_t = pl.cast(pl.add(one_t, 1), target_type=pl.INT32) - topk_vals = pl.assemble(topk_vals, cand_val_t, [0, tkk]) - topk_idx = pl.assemble(topk_idx, cand_idx_t, [0, tkk]) - inserted = pl.assemble(inserted, one_t, [0, 0]) - - # Stage C: sparse attention directly consumes merged topk_idx. - attn_row = pl.create_tensor([1, ATTN_OUT_CFG], dtype=pl.FP32) - attn_row = pl.mul(attn_row, 0.0) - for h in pl.parallel(0, NUM_HEADS_CFG, 1, chunk=8): - q_col = h * QK_HEAD_DIM_CFG + # Stage 2.8: write k_cache_idx[b*MAX_SEQ + pos] = k_idx[b]. + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b in pl.parallel(0, BATCH, 1, chunk=BATCH): + pos = pl.tensor.read(seq_lens, [b]) - 1 + cache_row = b * MAX_SEQ + pos + k_row = pl.slice(k_idx, [1, INDEX_HEAD_DIM], [b, 0]) + k_cache_idx = pl.assemble(k_cache_idx, k_row, [cache_row, 0]) + + # ── Scope 3: score + topk ── + # Stage 3.1: scoring follows deepseek_v3_2_decode_front_scope2b.py + # (lines 68-95): per-batch tiled q_idx[b] x k_cache_idx[b, :], then + # fillpad invalid tail to -inf so a downstream topk naturally drops it. 
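Host-side picture of Stage 3.1 below: tile the context into SEQ_TILE rows, run one small QK matmul per tile, and pad everything past ctx_len with FP32 min so a later topk ignores it. A sketch with stand-in sizes:

import torch

D, SEQ_TILE, MAX_SEQ, ctx_len = 128, 64, 256, 130  # stand-in MAX_SEQ / ctx_len
q = torch.randn(1, D)
k = torch.randn(MAX_SEQ, D)

scores = torch.full((MAX_SEQ,), torch.finfo(torch.float32).min)
for s0 in range(0, ctx_len, SEQ_TILE):
    valid = min(SEQ_TILE, ctx_len - s0)
    # one [1, D] x [valid, D]^T tile, like pl.matmul(..., b_trans=True)
    scores[s0:s0 + valid] = (q @ k[s0:s0 + valid].T).squeeze(0)

ref = (q @ k.T).squeeze(0)
ref[ctx_len:] = torch.finfo(torch.float32).min  # fillpad(PadValue.min) analogue
assert torch.allclose(scores, ref, rtol=1e-5, atol=1e-5)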
+ scores = pl.create_tensor([BATCH, MAX_SEQ], dtype=pl.FP32) + for b in pl.range(BATCH): + ctx_len = pl.tensor.read(seq_lens, [b]) + ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE + all_scores = pl.create_tensor([MAX_SEQ_BLOCKS, SEQ_TILE], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for sb in pl.parallel(ctx_blocks, chunk=MAX_SEQ_BLOCKS): + s0 = sb * SEQ_TILE + cache_row0 = b * MAX_SEQ + s0 + q_b = pl.slice(q_idx, [1, INDEX_HEAD_DIM], [b, 0]) + k_tile = pl.slice( + k_cache_idx, [SEQ_TILE, INDEX_HEAD_DIM], [cache_row0, 0] + ) + score_tile = pl.matmul(q_b, k_tile, b_trans=True, out_dtype=pl.FP32) + all_scores = pl.assemble(all_scores, score_tile, [sb, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for sb in pl.parallel(ctx_blocks, chunk=MAX_SEQ_BLOCKS): + s0 = sb * SEQ_TILE + valid_len = pl.min(SEQ_TILE, ctx_len - s0) + tile_valid = pl.slice( + all_scores, [1, SEQ_TILE], [sb, 0], valid_shape=[1, valid_len] + ) + tile_padded = pl.fillpad(tile_valid, pad_value=pl.PadValue.min) + scores = pl.assemble(scores, tile_padded, [b, s0]) + + # Stage 3.2: TODO(topk). + # Real implementation should pick the INDEX_TOPK largest entries of + # scores[b, :ctx_len] per batch. See the legacy monolithic + # deepseek_v3_2_decode_front.py:278-436 for a B1/B2 two-stage + # insertion-sort sketch, or replace with a dedicated topk kernel. + # Placeholder: topk_idx is all-zero, which makes scope4 attend only + # to position 0 in each batch row but keeps the pipeline runnable. + topk_idx = pl.create_tensor([BATCH, INDEX_TOPK], dtype=pl.INT32) + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b in pl.parallel(0, BATCH, 1, chunk=BATCH): + zero_row = pl.full([1, INDEX_TOPK], dtype=pl.INT32, value=0) + topk_idx = pl.assemble(topk_idx, zero_row, [b, 0]) + + # ── Scope 4: post topk (sparse MQA + dispatch) ── + attn_front = pl.create_tensor([BATCH, ATTN_OUT], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + for b in pl.parallel(0, BATCH, 1, chunk=4): + ctx_len = pl.tensor.read(seq_lens, [b]) + sparse_k = pl.min(INDEX_TOPK, ctx_len) + attn_row = pl.full([1, ATTN_OUT], dtype=pl.FP32, value=0.0) + + for h in pl.parallel(0, NUM_HEADS, 1, chunk=8): + q_col = h * QK_HEAD_DIM + # q_pe was already RoPE-rotated in Scope 1, so we read it + # back as-is. q_nope is projected to the latent space + # via per-head w_q_nope_to_latent. 
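The kk-loop below is the standard streaming-softmax recurrence: keep a running max `mi`, denominator `li`, and unnormalised numerator `oi`, rescaling the old accumulators by `alpha = exp(mi - mi_new)` whenever a new score raises the max. The kernel's `kk == 0` branch is the same recurrence with the -inf initial max spelled out. A self-contained torch check that it reproduces softmax(s) @ V:

import torch

S, D = 7, 16  # stand-in: a few scores over D-dim values
s = torch.randn(S)
v = torch.randn(S, D)

mi, li, oi = torch.tensor(float('-inf')), torch.tensor(0.0), torch.zeros(D)
for k in range(S):
    mi_new = torch.maximum(mi, s[k])
    alpha = torch.exp(mi - mi_new)   # rescale old accumulators
    beta = torch.exp(s[k] - mi_new)  # weight of the new key (cur_li == 1)
    li = alpha * li + beta
    oi = alpha * oi + beta * v[k]
    mi = mi_new

assert torch.allclose(oi / li, torch.softmax(s, dim=0) @ v, rtol=1e-5, atol=1e-6)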
q_nope = pl.cast( - pl.slice(q_proj, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col]), + pl.slice(q_proj, [1, QK_NOPE_HEAD_DIM], [b, q_col]), target_type=pl.FP32, ) q_pe = pl.cast( - pl.slice(q_proj, [1, QK_ROPE_HEAD_DIM_CFG], [b, q_col + QK_NOPE_HEAD_DIM_CFG]), + pl.slice( + q_proj, [1, QK_ROPE_HEAD_DIM], [b, q_col + QK_NOPE_HEAD_DIM] + ), target_type=pl.FP32, ) - q_lo = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - q_hi = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - q_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - q_rot = pl.assemble( - q_rot, - pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), - [0, 0], - ) - q_rot = pl.assemble( - q_rot, - pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), - [0, QK_ROPE_HEAD_DIM_CFG // 2], + w_qn_h = pl.reshape( + pl.slice( + w_q_nope_to_latent, + [1, QK_NOPE_HEAD_DIM, KV_LORA_RANK], + [h, 0, 0], + ), + [QK_NOPE_HEAD_DIM, KV_LORA_RANK], ) q_nope_latent = pl.matmul( - pl.cast(q_nope, target_type=pl.BF16), - pl.reshape( - pl.slice( - w_q_nope_to_latent, [1, QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [h, 0, 0] - ), - [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], - ), + pl.cast(q_nope, target_type=pl.BF16), w_qn_h, out_dtype=pl.FP32 ) - oi = pl.create_tensor([1, KV_LORA_RANK_CFG], dtype=pl.FP32) - li = pl.create_tensor([1, 1], dtype=pl.FP32) - mi = pl.create_tensor([1, 1], dtype=pl.FP32) - oi = pl.mul(oi, 0.0) - li = pl.mul(li, 0.0) - mi = pl.mul(mi, 0.0) + oi = pl.full([1, KV_LORA_RANK], dtype=pl.FP32, value=0.0) + li = pl.full([1, 1], dtype=pl.FP32, value=0.0) + mi = pl.full([1, 1], dtype=pl.FP32, value=0.0) - sparse_k = pl.min(INDEX_TOPK_CFG, ctx_len) for kk in pl.range(sparse_k): - topk_pos = pl.tensor.read(topk_idx, [0, kk]) + topk_pos = pl.tensor.read(topk_idx, [b, kk]) if topk_pos >= 0: - cache_s = b * MAX_SEQ_CFG + topk_pos + cache_s = b * MAX_SEQ + topk_pos kv_s = pl.cast( - pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), + pl.slice(kv_cache, [1, KV_LORA_RANK], [cache_s, 0]), target_type=pl.FP32, ) pe_s = pl.cast( - pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), + pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM], [cache_s, 0]), target_type=pl.FP32, ) score_nope = pl.row_sum(pl.mul(q_nope_latent, kv_s)) - score_pe = pl.row_sum(pl.mul(q_rot, pe_s)) - score = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) - cur_mi = score - cur_li = pl.exp(pl.sub(score, cur_mi)) - oi_tmp = pl.row_expand_mul(kv_s, cur_li) + score_pe = pl.row_sum(pl.mul(q_pe, pe_s)) + cur_mi = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) + cur_li = pl.full([1, 1], dtype=pl.FP32, value=1.0) if kk == 0: - oi = oi_tmp + oi = kv_s li = cur_li mi = cur_mi else: @@ -506,172 +627,375 @@ def deepseek_v3_2_decode_front_layer( alpha = pl.exp(pl.sub(mi, mi_new)) beta = pl.exp(pl.sub(cur_mi, mi_new)) li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) - oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) + oi = pl.add( + pl.row_expand_mul(oi, alpha), + pl.row_expand_mul(kv_s, beta), + ) mi = mi_new ctx_latent = pl.row_expand_div(oi, li) - v_col = h * V_HEAD_DIM_CFG - ctx_v = pl.create_tensor([1, V_HEAD_DIM_CFG], dtype=pl.FP32) - ctx_v = pl.mul(ctx_v, 0.0) + + v_col = h * V_HEAD_DIM for vb in pl.range(V_OUT_BLOCKS): v0 = vb * V_OUT_CHUNK - wv_tile_3d = pl.slice( - w_latent_to_v, [1, KV_LORA_RANK_CFG, V_OUT_CHUNK], [h, 0, v0] + wv_tile = pl.reshape( + pl.slice( + w_latent_to_v, + [1, KV_LORA_RANK, V_OUT_CHUNK], + [h, 0, v0], + ), + [KV_LORA_RANK, 
V_OUT_CHUNK], ) - wv_tile = pl.reshape(wv_tile_3d, [KV_LORA_RANK_CFG, V_OUT_CHUNK]) v_part = pl.matmul( - pl.cast(ctx_latent, target_type=pl.BF16), wv_tile, out_dtype=pl.FP32 + pl.cast(ctx_latent, target_type=pl.BF16), + wv_tile, + out_dtype=pl.FP32, ) - ctx_v = pl.assemble(ctx_v, v_part, [0, v0]) - attn_row = pl.assemble(attn_row, ctx_v, [0, v_col]) - attn_front = pl.assemble(attn_front, attn_row, [b, 0]) - - # Scope 3: dispatch write to cross-node GM tensor and return. - for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): - target_node = (b + layer_id) % EP_NODES_CFG - token_row = pl.cast(pl.slice(attn_front, [1, ATTN_OUT_CFG], [b, 0]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, v_part, [0, v_col + v0]) + attn_front = pl.assemble(attn_front, attn_row, [b, 0]) + + # Dispatch write to cross-node GM tensor. + layer_id = pl.tensor.read(layer_id_t, [0]) + for b in pl.parallel(0, BATCH, 1, chunk=4): + target_node = (b + layer_id) % EP_NODES + token_row = pl.cast( + pl.slice(attn_front, [1, ATTN_OUT], [b, 0]), target_type=pl.BF16 + ) dispatch_buf = pl.assemble(dispatch_buf, token_row, [target_node, b, 0]) return dispatch_buf - return DeepSeekV32DecodeFront - - -def build_tensor_specs( - batch: int = BATCH, - max_seq_len: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - q_lora_rank: int = Q_LORA_RANK, - kv_lora_rank: int = KV_LORA_RANK, - qk_nope_head_dim: int = QK_NOPE_HEAD_DIM, - qk_rope_head_dim: int = QK_ROPE_HEAD_DIM, - v_head_dim: int = V_HEAD_DIM, - index_heads: int = INDEX_HEADS, - index_topk: int = INDEX_TOPK, - ep_nodes: int = EP_NODES, -): + return Ds32Exp + + +def build_tensor_specs(): + """TensorSpecs for `run` driver. Initialisers mirror the scope1 example: + centred uniform with 1/sqrt(fan_in) scaling on weights so RMSNorm / + matmul outputs stay in BF16's well-resolved range. 
+ """ import torch # type: ignore[import] - from pypto.runtime import TensorSpec + from golden import TensorSpec + + def init_hidden_states(): + return torch.rand(BATCH, HIDDEN) - 0.5 - qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - kv_a_out = kv_lora_rank + qk_rope_head_dim - cache_rows = batch * max_seq_len - attn_out = num_heads * v_head_dim + def init_rms_weight(): + return torch.rand(1, HIDDEN) - 0.5 - seq_lens_data = torch.randint(1, max_seq_len + 1, (batch,), dtype=torch.int32) - layer_id_data = torch.tensor([0], dtype=torch.int32) + def init_q_norm_weight(): + return torch.rand(1, Q_LORA_RANK) - 0.5 + + def init_kv_norm_weight(): + return torch.rand(1, KV_LORA_RANK) - 0.5 + + def init_k_norm_weight(): + return torch.rand(1, INDEX_HEAD_DIM) - 0.5 + + def init_k_norm_bias(): + return torch.rand(1, INDEX_HEAD_DIM) - 0.5 + + def init_wq_a(): + return (torch.rand(HIDDEN, Q_LORA_RANK) - 0.5) / HIDDEN ** 0.5 + + def init_wq_b(): + return (torch.rand(Q_LORA_RANK, NUM_HEADS * QK_HEAD_DIM) - 0.5) / Q_LORA_RANK ** 0.5 + + def init_wkv_a(): + return (torch.rand(HIDDEN, KV_A_OUT) - 0.5) / HIDDEN ** 0.5 + + def init_wq_b_idx(): + return (torch.rand(Q_LORA_RANK, INDEX_HEADS * INDEX_HEAD_DIM) - 0.5) / Q_LORA_RANK ** 0.5 + + def init_wk_idx(): + return (torch.rand(HIDDEN, INDEX_HEAD_DIM) - 0.5) / HIDDEN ** 0.5 + + def init_weights_proj(): + return (torch.rand(HIDDEN, INDEX_HEADS) - 0.5) / HIDDEN ** 0.5 + + def init_w_q_nope_to_latent(): + return (torch.rand(NUM_HEADS, QK_NOPE_HEAD_DIM, KV_LORA_RANK) - 0.5) / QK_NOPE_HEAD_DIM ** 0.5 + + def init_w_latent_to_v(): + return (torch.rand(NUM_HEADS, KV_LORA_RANK, V_HEAD_DIM) - 0.5) / KV_LORA_RANK ** 0.5 + + def init_kv_cache(): + return torch.rand(CACHE_ROWS, KV_LORA_RANK) - 0.5 + + def init_pe_cache(): + return torch.rand(CACHE_ROWS, QK_ROPE_HEAD_DIM) - 0.5 + + def init_k_cache_idx(): + return torch.rand(CACHE_ROWS, INDEX_HEAD_DIM) - 0.5 + + def init_rope_cos(): + return torch.rand(MAX_SEQ, QK_ROPE_HEAD_DIM) - 0.5 + + def init_rope_sin(): + return torch.rand(MAX_SEQ, QK_ROPE_HEAD_DIM) - 0.5 + + def init_seq_lens(): + return torch.randint(1, MAX_SEQ + 1, (BATCH,), dtype=torch.int32) + + def init_layer_id(): + return torch.tensor([0], dtype=torch.int32) + + def init_dispatch_buf(): + return torch.zeros(EP_NODES, BATCH, ATTN_OUT) return [ - TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, init_value=torch.randn), - TensorSpec("seq_lens", [batch], torch.int32, init_value=seq_lens_data), - TensorSpec("layer_id_t", [1], torch.int32, init_value=layer_id_data), - TensorSpec("rope_cos", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=torch.randn), - TensorSpec("rope_sin", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=torch.randn), - TensorSpec("kv_cache", [cache_rows, kv_lora_rank], torch.bfloat16, init_value=torch.randn), - TensorSpec("pe_cache", [cache_rows, qk_rope_head_dim], torch.bfloat16, init_value=torch.randn), - TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, init_value=torch.randn), - TensorSpec("wq_a", [hidden_size, q_lora_rank], torch.bfloat16, init_value=torch.randn), - TensorSpec("q_norm_weight", [1, q_lora_rank], torch.float32, init_value=torch.randn), - TensorSpec("wq_b", [q_lora_rank, num_heads * qk_head_dim], torch.bfloat16, init_value=torch.randn), - TensorSpec("wkv_a", [hidden_size, kv_a_out], torch.bfloat16, init_value=torch.randn), - TensorSpec("kv_norm_weight", [1, kv_lora_rank], torch.float32, init_value=torch.randn), + TensorSpec("hidden_states", [BATCH, HIDDEN], 
torch.bfloat16, init_value=init_hidden_states), + TensorSpec("seq_lens", [BATCH], torch.int32, init_value=init_seq_lens), + TensorSpec("layer_id_t", [1], torch.int32, init_value=init_layer_id), + TensorSpec("rope_cos", [MAX_SEQ, QK_ROPE_HEAD_DIM], torch.float32, init_value=init_rope_cos), + TensorSpec("rope_sin", [MAX_SEQ, QK_ROPE_HEAD_DIM], torch.float32, init_value=init_rope_sin), + TensorSpec("kv_cache", [CACHE_ROWS, KV_LORA_RANK], torch.bfloat16, init_value=init_kv_cache), + TensorSpec("pe_cache", [CACHE_ROWS, QK_ROPE_HEAD_DIM], torch.bfloat16, init_value=init_pe_cache), + TensorSpec("k_cache_idx", [CACHE_ROWS, INDEX_HEAD_DIM], torch.bfloat16, init_value=init_k_cache_idx), + TensorSpec("input_rms_weight", [1, HIDDEN], torch.float32, init_value=init_rms_weight), + TensorSpec("wq_a", [HIDDEN, Q_LORA_RANK], torch.bfloat16, init_value=init_wq_a), + TensorSpec("q_norm_weight", [1, Q_LORA_RANK], torch.float32, init_value=init_q_norm_weight), + TensorSpec("wq_b", [Q_LORA_RANK, NUM_HEADS * QK_HEAD_DIM], torch.bfloat16, init_value=init_wq_b), + TensorSpec("wkv_a", [HIDDEN, KV_A_OUT], torch.bfloat16, init_value=init_wkv_a), + TensorSpec("kv_norm_weight", [1, KV_LORA_RANK], torch.float32, init_value=init_kv_norm_weight), + TensorSpec( + "wq_b_idx", + [Q_LORA_RANK, INDEX_HEADS * INDEX_HEAD_DIM], + torch.bfloat16, + init_value=init_wq_b_idx, + ), + TensorSpec("wk_idx", [HIDDEN, INDEX_HEAD_DIM], torch.bfloat16, init_value=init_wk_idx), + TensorSpec("k_norm_weight", [1, INDEX_HEAD_DIM], torch.float32, init_value=init_k_norm_weight), + TensorSpec("k_norm_bias", [1, INDEX_HEAD_DIM], torch.float32, init_value=init_k_norm_bias), + TensorSpec("weights_proj", [HIDDEN, INDEX_HEADS], torch.float32, init_value=init_weights_proj), TensorSpec( "w_q_nope_to_latent", - [num_heads, qk_nope_head_dim, kv_lora_rank], + [NUM_HEADS, QK_NOPE_HEAD_DIM, KV_LORA_RANK], torch.bfloat16, - init_value=torch.randn, + init_value=init_w_q_nope_to_latent, ), TensorSpec( - "w_latent_to_v", [num_heads, kv_lora_rank, v_head_dim], torch.bfloat16, init_value=torch.randn + "w_latent_to_v", + [NUM_HEADS, KV_LORA_RANK, V_HEAD_DIM], + torch.bfloat16, + init_value=init_w_latent_to_v, + ), + TensorSpec( + "dispatch_buf", + [EP_NODES, BATCH, ATTN_OUT], + torch.bfloat16, + init_value=init_dispatch_buf, + is_output=True, ), - TensorSpec("dispatch_buf", [ep_nodes, batch, attn_out], torch.bfloat16, is_output=True), ] -def compile_and_run( - batch: int = BATCH, - max_seq_len: int = MAX_SEQ, - hidden_size: int = HIDDEN, - num_heads: int = NUM_HEADS, - q_lora_rank: int = Q_LORA_RANK, - kv_lora_rank: int = KV_LORA_RANK, - qk_nope_head_dim: int = QK_NOPE_HEAD_DIM, - qk_rope_head_dim: int = QK_ROPE_HEAD_DIM, - v_head_dim: int = V_HEAD_DIM, - index_heads: int = INDEX_HEADS, - index_topk: int = INDEX_TOPK, - ep_nodes: int = EP_NODES, - platform: str = "a2a3", - device_id: int = 0, - work_dir: Optional[str] = None, - dump_passes: bool = True, - runtime_profiling: bool = False, -): - from pypto.backend import BackendType - from pypto.ir.pass_manager import OptimizationStrategy - from pypto.runtime import RunConfig, run - - program = build_deepseek_v3_2_decode_front_program( - batch=batch, - max_seq_len=max_seq_len, - hidden_size=hidden_size, - num_heads=num_heads, - q_lora_rank=q_lora_rank, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - v_head_dim=v_head_dim, - index_heads=index_heads, - index_topk=index_topk, - ep_nodes=ep_nodes, - ) - tensor_specs = build_tensor_specs( - batch=batch, - 
max_seq_len=max_seq_len, - hidden_size=hidden_size, - num_heads=num_heads, - q_lora_rank=q_lora_rank, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - v_head_dim=v_head_dim, - index_heads=index_heads, - index_topk=index_topk, - ep_nodes=ep_nodes, - ) +def golden_ds32exp(tensors): + """PyTorch reference covering all 4 scopes. - result = run( - program=program, - tensor_specs=tensor_specs, - golden=None, - config=RunConfig( - platform=platform, - device_id=device_id, - rtol=2e-2, - atol=2e-2, - strategy=OptimizationStrategy.Default, - dump_passes=dump_passes, - backend_type=BackendType.Ascend910B, - runtime_profiling=runtime_profiling, - ), - ) - return result + Mirrors the kernel exactly: lo/hi half-split RoPE for both MLA q_pe/k_pe + and the indexer's q_pe/k_pe (we follow the in-tree convention rather + than the official `interleaved=True/False` view), the Hadamard step is + reduced to its scalar scale (`INDEX_HEAD_DIM ** -0.5`), and FP8 quant + is approximated by a BF16 -> FP8E4M3FN -> BF16 round-trip. + + Topk is left as a placeholder (all-zero indices) so scope4 attends only + to position 0; this matches `# TODO(topk)` behaviour in the kernel. + """ + import torch # type: ignore[import] + + hidden_states = tensors["hidden_states"].float() + seq_lens = tensors["seq_lens"] + layer_id = int(tensors["layer_id_t"][0].item()) + rope_cos = tensors["rope_cos"].float() + rope_sin = tensors["rope_sin"].float() + kv_cache = tensors["kv_cache"].clone() + pe_cache = tensors["pe_cache"].clone() + k_cache_idx = tensors["k_cache_idx"].clone() + input_rms_weight = tensors["input_rms_weight"].float() + wq_a = tensors["wq_a"].float() + q_norm_weight = tensors["q_norm_weight"].float() + wq_b = tensors["wq_b"].float() + wkv_a = tensors["wkv_a"].float() + kv_norm_weight = tensors["kv_norm_weight"].float() + wq_b_idx = tensors["wq_b_idx"].float() + wk_idx = tensors["wk_idx"].float() + k_norm_weight = tensors["k_norm_weight"].float() + k_norm_bias = tensors["k_norm_bias"].float() + weights_proj = tensors["weights_proj"].float() + w_q_nope_to_latent = tensors["w_q_nope_to_latent"].float() + w_latent_to_v = tensors["w_latent_to_v"].float() + + half = HALF_ROPE + + def rope_half(vec, cos_lo, cos_hi, sin_lo, sin_hi): + # vec: [..., QK_ROPE_HEAD_DIM]; matches kernel's lo/hi-split rotation. 
+ lo = vec[..., :half] + hi = vec[..., half:] + rot_lo = lo * cos_lo - hi * sin_lo + rot_hi = hi * cos_hi + lo * sin_hi + return torch.cat([rot_lo, rot_hi], dim=-1) + + # ── Scope 1 golden: RMSNorm + projections + q_pe / k_pe RoPE ── + sq_sum = (hidden_states * hidden_states).sum(dim=1, keepdim=True) + inv_rms = torch.rsqrt(sq_sum * HIDDEN_INV + EPS) + normed = (hidden_states * inv_rms * input_rms_weight).to(torch.bfloat16).float() + + qr = (normed @ wq_a).to(torch.bfloat16).float() * q_norm_weight + qr_bf16 = qr.to(torch.bfloat16) + q_proj = (qr_bf16.float() @ wq_b).to(torch.bfloat16).float() + kv_a = (normed @ wkv_a).to(torch.bfloat16).float() + + q_proj_view = q_proj.view(BATCH, NUM_HEADS, QK_HEAD_DIM) + for b in range(BATCH): + pos = int(seq_lens[b].item()) - 1 + cos_lo = rope_cos[pos:pos + 1, :half] + cos_hi = rope_cos[pos:pos + 1, half:] + sin_lo = rope_sin[pos:pos + 1, :half] + sin_hi = rope_sin[pos:pos + 1, half:] + + q_pe_b = q_proj_view[b, :, QK_NOPE_HEAD_DIM:] + q_proj_view[b, :, QK_NOPE_HEAD_DIM:] = rope_half(q_pe_b, cos_lo, cos_hi, sin_lo, sin_hi) + + kv_latent = kv_a[b:b + 1, :KV_LORA_RANK] + kv_normed = (kv_latent * kv_norm_weight).to(torch.bfloat16).float() + cache_row = b * MAX_SEQ + pos + kv_cache[cache_row, :] = kv_normed.squeeze(0).to(torch.bfloat16) + + k_pe_b = kv_a[b:b + 1, KV_LORA_RANK:KV_LORA_RANK + QK_ROPE_HEAD_DIM] + k_pe_rot = rope_half(k_pe_b, cos_lo, cos_hi, sin_lo, sin_hi) + pe_cache[cache_row, :] = k_pe_rot.squeeze(0).to(torch.bfloat16) + q_proj = q_proj_view.reshape(BATCH, NUM_HEADS * QK_HEAD_DIM) + q_proj_bf16 = q_proj.to(torch.bfloat16).float() + + # ── Scope 2 golden: indexer proj + RoPE + Hadamard placeholder + fp8 placeholder ── + q_idx_full = (qr_bf16.float() @ wq_b_idx).to(torch.bfloat16).float() + k_idx = (hidden_states.to(torch.bfloat16).float() @ wk_idx).to(torch.bfloat16).float() + + # LayerNorm on k_idx. + mean = k_idx.mean(dim=-1, keepdim=True) + centered = k_idx - mean + var = (centered * centered).mean(dim=-1, keepdim=True) + inv_std = torch.rsqrt(var + EPS) + k_idx = (centered * inv_std * k_norm_weight + k_norm_bias).to(torch.bfloat16).float() + + # RoPE (lo/hi half) on q_idx_full per index head and on k_idx (single head). + q_idx_full_view = q_idx_full.view(BATCH, INDEX_HEADS, INDEX_HEAD_DIM) + for b in range(BATCH): + pos = int(seq_lens[b].item()) - 1 + cos_lo = rope_cos[pos:pos + 1, :half] + cos_hi = rope_cos[pos:pos + 1, half:] + sin_lo = rope_sin[pos:pos + 1, :half] + sin_hi = rope_sin[pos:pos + 1, half:] + + q_pe_b = q_idx_full_view[b, :, :QK_ROPE_HEAD_DIM] + q_idx_full_view[b, :, :QK_ROPE_HEAD_DIM] = rope_half(q_pe_b, cos_lo, cos_hi, sin_lo, sin_hi) + + k_pe_b = k_idx[b:b + 1, :QK_ROPE_HEAD_DIM] + k_idx[b:b + 1, :QK_ROPE_HEAD_DIM] = rope_half(k_pe_b, cos_lo, cos_hi, sin_lo, sin_hi) + q_idx_full = q_idx_full_view.reshape(BATCH, INDEX_HEADS * INDEX_HEAD_DIM) + + # TODO(hadamard_transform) placeholder: scalar scale only. + q_idx_full = (q_idx_full * HADAMARD_SCALE).to(torch.bfloat16).float() + k_idx = (k_idx * HADAMARD_SCALE).to(torch.bfloat16).float() + + # TODO(fp8 quant) placeholder: BF16 -> FP8E4M3FN -> BF16 round-trip. + q_idx_full = q_idx_full.to(torch.bfloat16).to(torch.float8_e4m3fn).to(torch.bfloat16).float() + k_idx = k_idx.to(torch.bfloat16).to(torch.float8_e4m3fn).to(torch.bfloat16).float() + + # Per-head weights and aggregation. 
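The two round-trips above deliberately drop the per-block scale that the official act_quant keeps. For orientation only, an assumed sketch of what a blockwise E4M3 act_quant could look like — the name, block layout, and API here are hypothetical, not taken from the official code:

import torch

def act_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # Hypothetical per-block symmetric quant: scale each block_size-wide slice
    # by amax / FP8_MAX, cast to float8_e4m3fn, and return (q, scales).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    xb = x.float().reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
    scales = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (xb / scales).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scales.squeeze(-1)

x = torch.randn(4, 256, dtype=torch.bfloat16)
q, scales = act_quant_sketch(x)
x_hat = (q.float().reshape(4, -1, 128) * scales.unsqueeze(-1)).reshape(x.shape)
# x_hat tracks x much more closely than the unscaled round-trip used here.

Once Stage 2.5 keeps `scales`, the q_scale fold-in mentioned in Stage 2.6 can be applied to `weights`.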
+    w_scale = INDEX_HEADS_INV_SQRT * INDEX_SOFTMAX_SCALE
+    weights = (hidden_states.to(torch.bfloat16).float() @ weights_proj.to(torch.bfloat16).float()) * w_scale
+    q_idx_full_view = q_idx_full.view(BATCH, INDEX_HEADS, INDEX_HEAD_DIM)
+    q_idx = (weights.unsqueeze(-1) * q_idx_full_view).sum(dim=1)
+    q_idx_bf16 = q_idx.to(torch.bfloat16).float()
+
+    for b in range(BATCH):
+        pos = int(seq_lens[b].item()) - 1
+        k_cache_idx[b * MAX_SEQ + pos, :] = k_idx[b].to(torch.bfloat16)
+
+    # ── Scope 3 golden: tiled scoring (q_idx · k_cache_idx[b, :]) ──
+    # Topk is a placeholder (all-zero indices); kernel currently picks position 0.
+    topk_idx = torch.zeros(BATCH, INDEX_TOPK, dtype=torch.int64)
+
+    # ── Scope 4 golden: sparse MQA attention + dispatch write ──
+    attn_front = torch.zeros(BATCH, ATTN_OUT, dtype=torch.float32)
+    for b in range(BATCH):
+        ctx_len = int(seq_lens[b].item())
+        sparse_k = min(INDEX_TOPK, ctx_len)
+        for h in range(NUM_HEADS):
+            q_col = h * QK_HEAD_DIM
+            q_nope = q_proj_bf16[b, q_col:q_col + QK_NOPE_HEAD_DIM]
+            q_pe = q_proj_bf16[b, q_col + QK_NOPE_HEAD_DIM:q_col + QK_HEAD_DIM]
+            q_nope_latent = q_nope @ w_q_nope_to_latent[h]
+
+            oi = torch.zeros(KV_LORA_RANK, dtype=torch.float32)
+            li = torch.zeros(1, dtype=torch.float32)
+            mi = torch.zeros(1, dtype=torch.float32)
+            # Initialise on the first *valid* position: fillpadded idx tails are
+            # negative, so iteration 0 is not guaranteed to survive the
+            # `pos < 0` guard below.
+            first = True
+            for kk in range(sparse_k):
+                pos = int(topk_idx[b, kk].item())
+                if pos < 0:
+                    continue
+                cache_s = b * MAX_SEQ + pos
+                kv_s = kv_cache[cache_s].float()
+                pe_s = pe_cache[cache_s].float()
+                score_nope = (q_nope_latent * kv_s).sum()
+                score_pe = (q_pe * pe_s).sum()
+                cur_mi = (score_nope + score_pe) * ATTN_SCALE
+                cur_li = torch.tensor([1.0])
+                cur_mi = cur_mi.view(1)
+                if first:
+                    oi = kv_s
+                    li = cur_li
+                    mi = cur_mi
+                    first = False
+                else:
+                    mi_new = torch.maximum(mi, cur_mi)
+                    alpha = torch.exp(mi - mi_new)
+                    beta = torch.exp(cur_mi - mi_new)
+                    li = alpha * li + beta * cur_li
+                    oi = oi * alpha + kv_s * beta
+                    mi = mi_new
+            ctx_latent = (oi / li).to(torch.bfloat16).float()
+            ctx_v = ctx_latent @ w_latent_to_v[h]
+            attn_front[b, h * V_HEAD_DIM:(h + 1) * V_HEAD_DIM] = ctx_v
+
+    dispatch_buf = tensors["dispatch_buf"]
+    dispatch_buf.zero_()
+    for b in range(BATCH):
+        target_node = (b + layer_id) % EP_NODES
+        dispatch_buf[target_node, b, :] = attn_front[b].to(torch.bfloat16)


 if __name__ == "__main__":
     import argparse
+    import sys
+    from pathlib import Path
+
+    sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
+
+    from golden import RunConfig, run

     parser = argparse.ArgumentParser()
-    parser.add_argument("-p", "--platform", type=str, default="a2a3",
-                        choices=["a2a3", "a2a3sim", "a5", "a5sim"])
+    parser.add_argument(
+        "-p", "--platform", type=str, default="a2a3",
+        choices=["a2a3", "a2a3sim", "a5", "a5sim"],
+    )
     parser.add_argument("-d", "--device", type=int, default=0)
     parser.add_argument("--runtime-profiling", action="store_true", default=False)
     args = parser.parse_args()

-    result = compile_and_run(
-        platform=args.platform,
-        device_id=args.device,
-        runtime_profiling=args.runtime_profiling,
+    result = run(
+        program=build_ds32exp_program(),
+        tensor_specs=build_tensor_specs(),
+        golden_fn=golden_ds32exp,
+        config=RunConfig(
+            rtol=2e-2,
+            atol=2e-2,
+            compile=dict(dump_passes=True),
+            runtime=dict(
+                platform=args.platform,
+                device_id=args.device,
+                runtime_profiling=args.runtime_profiling,
+            ),
+        ),
+    )
     if not result.passed:
         if result.error:
diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py
new file mode 100644
index 0000000..edb80d1
--- /dev/null
+++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope3.py
@@ -0,0 +1,305 @@
+# Copyright (c) PyPTO Contributors.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+# -----------------------------------------------------------------------------------------------------------
+"""
+DeepSeek V3.2-EXP single-layer decode FRONT — Scope 3: indexer score + topk.
+
+Pipeline: scope1 (qkv proj/RoPE) → scope2 (indexer proj/RoPE + q_idx aggregate
++ write k_cache_idx) → scope3 (this file: score then topk).
+
+Scoring uses the linear reduction
+    score[b, s] = sum_h w[b, h] * (q[b, h] dot k[b, s])
+                = (sum_h w[b, h] * q[b, h]) dot k[b, s]
+                = q_idx[b] dot k_cache_idx[b, s]
+so scope2 collapses 64 heads into a single query vector, and scope3 only has
+to do a [1, INDEX_HEAD_DIM] x [ctx_len, INDEX_HEAD_DIM] matmul per batch.
+
+Topk is done by sort32 + 4-way mrgsort merges, then a gather to split the
+sorted (val, idx) pairs, then a GM reload with valid_shape + fillpad to mark
+idx slots past ctx_len. Outputs (topk_vals_out, topk_idx_out) are the
+top-INDEX_TOPK entries per batch; invalid tail idx = INT32_MIN (< 0),
+compatible with scope4's `topk_pos >= 0` filter.
+"""
+from __future__ import annotations
+
+import pypto.language as pl
+
+
+BATCH = 16
+MAX_SEQ = 8192
+INDEX_HEAD_DIM = 128
+INDEX_TOPK = 2048
+CACHE_ROWS_IDX = BATCH * MAX_SEQ
+
+SEQ_TILE = 64
+MAX_SEQ_BLOCKS = (MAX_SEQ + SEQ_TILE - 1) // SEQ_TILE
+
+# Q pad: a2a3 TExtract requires row % 16 == 0, so pad the 1-row query to 16.
+Q_VALID = 1
+Q_PAD = 16
+
+# sort32 + 4 mrgsort iterations (block_len 64, 256, 1024, 4096) sort MAX_SEQ=8192.
+MRGSORT_ITERS = 4
+
+# -inf sentinel for the score tail. FP32 lowest, since ptoas rejects literal -inf.
+FP32_NEG_INF = -3.4028234663852886e38
+
+
+def build_deepseek_v3_2_decode_front_scope3_program():
+    @pl.program
+    class DeepSeekV32DecodeFrontScope3:
+        @pl.function(type=pl.FunctionType.Opaque)
+        def deepseek_v3_2_decode_front_scope3(
+            self,
+            q_idx: pl.Tensor[[BATCH, INDEX_HEAD_DIM], pl.BF16],
+            k_cache_idx: pl.Tensor[[CACHE_ROWS_IDX, INDEX_HEAD_DIM], pl.BF16],
+            seq_lens: pl.Tensor[[BATCH], pl.INT32],
+            idx_init: pl.Tensor[[1, MAX_SEQ], pl.UINT32],
+            topk_vals_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.FP32],
+            topk_idx_out: pl.Tensor[[BATCH, INDEX_TOPK], pl.INT32],
+        ) -> tuple[
+            pl.Tensor[[BATCH, INDEX_TOPK], pl.FP32],
+            pl.Tensor[[BATCH, INDEX_TOPK], pl.INT32],
+        ]:
+            # Pad q_idx to [BATCH * Q_PAD, 128] with zero rows so the QK matmul
+            # has row=16 (required by a2a3 TExtract).
+            q_padded = pl.create_tensor([BATCH * Q_PAD, INDEX_HEAD_DIM], dtype=pl.BF16)
+            with pl.at(level=pl.Level.CORE_GROUP):
+                for b in pl.range(BATCH):
+                    q_row = pl.slice(q_idx, [1, INDEX_HEAD_DIM], [b, 0])
+                    q_padded = pl.assemble(q_padded, q_row, [b * Q_PAD, 0])
+                    q_padded = pl.assemble(
+                        q_padded,
+                        pl.cast(
+                            pl.full(
+                                [Q_PAD - Q_VALID, INDEX_HEAD_DIM],
+                                dtype=pl.FP32,
+                                value=0.0,
+                            ),
+                            target_type=pl.BF16,
+                        ),
+                        [b * Q_PAD + Q_VALID, 0],
+                    )
+
+            # Transient GM buffers.
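+            # scores:     FP32 score row per batch, -inf past ctx_len (Stage 0).
+            # sorted_gm:  interleaved (val, idx) pairs out of the mrgsort chain.
+            # raw_idx_gm: gathered topk indices before the Stage 5 tail fillpad.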
+            scores = pl.create_tensor([BATCH, MAX_SEQ], dtype=pl.FP32)
+            sorted_gm = pl.create_tensor([BATCH, 2 * MAX_SEQ], dtype=pl.FP32)
+            raw_idx_gm = pl.create_tensor([BATCH, INDEX_TOPK], dtype=pl.INT32)
+
+            for b in pl.range(0, BATCH, 1):
+                ctx_len = pl.tensor.read(seq_lens, [b])
+                ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE
+
+                # Stage 0: pre-fill scores[b] with -inf. Stage 2's parallel
+                # loop only covers [0, ctx_blocks); without this pre-fill the
+                # untouched tail slots would keep create_tensor's default of
+                # 0.0 and leak into the topk.
+                with pl.at(level=pl.Level.CORE_GROUP):
+                    neg_inf_row = pl.full([1, MAX_SEQ], dtype=pl.FP32, value=FP32_NEG_INF)
+                    scores = pl.assemble(scores, neg_inf_row, [b, 0])
+
+                # Stage 1: tiled QK matmul into all_scores[sb * Q_PAD, 0] (row 0 valid).
+                all_scores = pl.create_tensor(
+                    [MAX_SEQ_BLOCKS * Q_PAD, SEQ_TILE], dtype=pl.FP32
+                )
+                with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                    for sb in pl.parallel(ctx_blocks, chunk=MAX_SEQ_BLOCKS):
+                        s0 = sb * SEQ_TILE
+                        cache_row0 = b * MAX_SEQ + s0
+                        q_b = pl.slice(q_padded, [Q_PAD, INDEX_HEAD_DIM], [b * Q_PAD, 0])
+                        k_tile = pl.slice(
+                            k_cache_idx, [SEQ_TILE, INDEX_HEAD_DIM], [cache_row0, 0]
+                        )
+                        score_tile = pl.matmul(q_b, k_tile, b_trans=True, out_dtype=pl.FP32)
+                        all_scores = pl.assemble(all_scores, score_tile, [sb * Q_PAD, 0])
+
+                # Stage 2: fillpad each tile's tail and write row 0 to scores[b, s0].
+                with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+                    for sb in pl.parallel(ctx_blocks, chunk=MAX_SEQ_BLOCKS):
+                        s0 = sb * SEQ_TILE
+                        valid_len = pl.min(SEQ_TILE, ctx_len - s0)
+                        tile_valid = pl.slice(
+                            all_scores,
+                            [1, SEQ_TILE],
+                            [sb * Q_PAD, 0],
+                            valid_shape=[1, valid_len],
+                        )
+                        tile_padded = pl.fillpad(tile_valid, pad_value=pl.PadValue.min)
+                        scores = pl.assemble(scores, tile_padded, [b, s0])
+
+                # Stage 3: sort32 + 4 mrgsort iterations. The result is [1, 2*MAX_SEQ]
+                # interleaved (val, idx). Stored to GM because vec→vec tile slicing
+                # is not supported on a2a3, so the downstream gather re-loads from GM.
+                # Each mrgsort gets a distinct SSA name so that the final tile is
+                # the one that gets stored.
+                with pl.at(level=pl.Level.CORE_GROUP):
+                    score_row = pl.load(scores, offsets=[b, 0], shapes=[1, MAX_SEQ])
+                    idx_row = pl.load(idx_init, offsets=[0, 0], shapes=[1, MAX_SEQ])
+                    sorted0 = pl.tile.sort32(score_row, idx_row)
+                    sorted1 = pl.tile.mrgsort(sorted0, block_len=64)
+                    sorted2 = pl.tile.mrgsort(sorted1, block_len=256)
+                    sorted3 = pl.tile.mrgsort(sorted2, block_len=1024)
+                    sorted4 = pl.tile.mrgsort(sorted3, block_len=4096)
+                    sorted_gm = pl.store(
+                        sorted4, offsets=[b, 0], output_tensor=sorted_gm
+                    )
+
+                # Stage 4: GM-load the first INDEX_TOPK pairs (2*INDEX_TOPK cols),
+                # gather with P0101/P1010 to split the vals / idx, store to outputs.
+                with pl.at(level=pl.Level.CORE_GROUP):
+                    topk_pairs = pl.load(
+                        sorted_gm, offsets=[b, 0], shapes=[1, 2 * INDEX_TOPK]
+                    )
+                    topk_v = pl.tile.gather(
+                        topk_pairs, mask_pattern=pl.tile.MaskPattern.P0101
+                    )
+                    topk_i_raw = pl.tile.gather(
+                        topk_pairs,
+                        mask_pattern=pl.tile.MaskPattern.P1010,
+                        output_dtype=pl.INT32,
+                    )
+                    topk_vals_out = pl.store(
+                        topk_v, offsets=[b, 0], output_tensor=topk_vals_out
+                    )
+                    raw_idx_gm = pl.store(
+                        topk_i_raw, offsets=[b, 0], output_tensor=raw_idx_gm
+                    )
+
+                # Stage 5: GM reload + valid_shape fillpad to mark idx slots past
+                # ctx_len with PadValue.min (= INT32_MIN < 0).
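+                # Only the first min(INDEX_TOPK, ctx_len) slots hold real
+                # positions; the INT32_MIN tail is what scope4's
+                # `topk_pos >= 0` filter skips.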
+                with pl.at(level=pl.Level.CORE_GROUP):
+                    valid_topk = pl.min(INDEX_TOPK, ctx_len)
+                    idx_valid = pl.slice(
+                        raw_idx_gm,
+                        [1, INDEX_TOPK],
+                        [b, 0],
+                        valid_shape=[1, valid_topk],
+                    )
+                    idx_padded = pl.fillpad(idx_valid, pad_value=pl.PadValue.min)
+                    topk_idx_out = pl.assemble(topk_idx_out, idx_padded, [b, 0])
+
+            return topk_vals_out, topk_idx_out
+
+    return DeepSeekV32DecodeFrontScope3
+
+
+def golden_decode_front_scope3(tensors):
+    import torch  # type: ignore[import]
+
+    q_idx = tensors["q_idx"].float()
+    k_cache_idx = tensors["k_cache_idx"].float()
+    seq_lens = tensors["seq_lens"]
+    topk_vals_out = tensors["topk_vals_out"]
+    topk_idx_out = tensors["topk_idx_out"]
+
+    scores = torch.full((BATCH, MAX_SEQ), FP32_NEG_INF, dtype=torch.float32)
+    for b in range(BATCH):
+        ctx_len = int(seq_lens[b].item())
+        q_b = q_idx[b : b + 1]
+        k_b = k_cache_idx[b * MAX_SEQ : b * MAX_SEQ + ctx_len]
+        scores[b, :ctx_len] = (q_b @ k_b.T).squeeze(0)
+
+    # Rare BF16 ties can swap adjacent idx entries between the kernel's mrgsort
+    # and torch.topk; vals stay identical so downstream attention is unaffected.
+    vals, idx = torch.topk(scores, INDEX_TOPK, dim=1, largest=True, sorted=True)
+    topk_vals_out.copy_(vals)
+    # Kernel fillpads the idx tail past ctx_len with INT32_MIN.
+    idx = idx.to(torch.int32)
+    for b in range(BATCH):
+        ctx_len = int(seq_lens[b].item())
+        valid_topk = min(INDEX_TOPK, ctx_len)
+        idx[b, valid_topk:] = torch.iinfo(torch.int32).min
+    topk_idx_out.copy_(idx)
+
+
+def build_tensor_specs():
+    import torch  # type: ignore[import]
+    from golden import TensorSpec
+
+    # ctx_len in [1, MAX_SEQ]; kernel pads the idx tail past ctx_len with
+    # PadValue.min (= INT32_MIN), and scope4 filters on `topk_pos >= 0`.
+    seq_lens_data = torch.randint(1, MAX_SEQ + 1, (BATCH,), dtype=torch.int32)
+
+    def init_q_idx():
+        return torch.rand(BATCH, INDEX_HEAD_DIM) - 0.5
+
+    def init_k_cache_idx():
+        return torch.rand(CACHE_ROWS_IDX, INDEX_HEAD_DIM) - 0.5
+
+    def init_idx_init():
+        return torch.arange(MAX_SEQ, dtype=torch.int32).unsqueeze(0)
+
+    def init_topk_vals_out():
+        return torch.zeros((BATCH, INDEX_TOPK), dtype=torch.float32)
+
+    def init_topk_idx_out():
+        return torch.zeros((BATCH, INDEX_TOPK), dtype=torch.int32)
+
+    return [
+        TensorSpec("q_idx", [BATCH, INDEX_HEAD_DIM], torch.bfloat16, init_value=init_q_idx),
+        TensorSpec(
+            "k_cache_idx",
+            [CACHE_ROWS_IDX, INDEX_HEAD_DIM],
+            torch.bfloat16,
+            init_value=init_k_cache_idx,
+        ),
+        TensorSpec("seq_lens", [BATCH], torch.int32, init_value=seq_lens_data),
+        TensorSpec("idx_init", [1, MAX_SEQ], torch.int32, init_value=init_idx_init),
+        TensorSpec(
+            "topk_vals_out",
+            [BATCH, INDEX_TOPK],
+            torch.float32,
+            init_value=init_topk_vals_out,
+            is_output=True,
+        ),
+        TensorSpec(
+            "topk_idx_out",
+            [BATCH, INDEX_TOPK],
+            torch.int32,
+            init_value=init_topk_idx_out,
+            is_output=True,
+        ),
+    ]
+
+
+if __name__ == "__main__":
+    import argparse
+    import sys
+    from pathlib import Path
+
+    sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
+
+    from golden import RunConfig, run
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-p", "--platform", type=str, default="a2a3",
+                        choices=["a2a3", "a2a3sim", "a5", "a5sim"])
+    parser.add_argument("-d", "--device", type=int, default=0)
+    parser.add_argument("--runtime-profiling", action="store_true", default=False)
+    args = parser.parse_args()
+
+    result = run(
+        program=build_deepseek_v3_2_decode_front_scope3_program(),
+        tensor_specs=build_tensor_specs(),
+        golden_fn=golden_decode_front_scope3,
+        config=RunConfig(
+            rtol=1e-3,
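+            # Scores accumulate in FP32 from identical BF16 inputs, so only the
+            # matmul reduction order differs; tolerances can be much tighter
+            # than the full front kernel's 2e-2 for BF16 attention output.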
atol=1e-3, + compile=dict(dump_passes=True), + runtime=dict( + platform=args.platform, + device_id=args.device, + runtime_profiling=args.runtime_profiling, + ), + ), + ) + if not result.passed: + if result.error: + print(result.error) + raise SystemExit(1)