diff --git a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
index 2f12c35..b1ad087 100644
--- a/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
+++ b/examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
@@ -9,14 +9,19 @@ from __future__ import annotations
 """
-DeepSeek V3.2-EXP single-layer decode FRONT part — Scope 1 only (batch=16).
+DeepSeek V3.2-EXP single-layer decode front path (batch=16).
 
-Scope 1: input RMSNorm + Q/KV projection.
+Projection stage:
 - Compute RMSNorm of hidden_states
 - Project to Q latent (qr) via wq_a
-- Apply q_norm to Q latent, then project to Q heads (q_proj) via wq_b
+- Apply q_norm to Q latent, then project to Q heads via wq_b
 - Project to KV latent (kv_a) via wkv_a
 
+Decode cache preparation:
+- Split Q heads into q_nope and q_pe outputs
+- Apply RoPE to q_pe and k_pe
+- Apply kv_norm to KV latent, then update KV/PE cache for the current token
+
 Aligned to official v3.2-exp MLA shapes:
 - qk_nope_head_dim = 128
 - qk_rope_head_dim = 64
@@ -28,6 +33,7 @@
 
 BATCH = 16
+MAX_SEQ = 4096
 HIDDEN = 7168
 NUM_HEADS = 128
 Q_LORA_RANK = 1536
@@ -37,11 +43,12 @@
 QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM
 V_HEAD_DIM = 128
 KV_A_OUT = KV_LORA_RANK + QK_ROPE_HEAD_DIM
+CACHE_ROWS = BATCH * MAX_SEQ
 EPS = 1e-6
 HIDDEN_INV = 1.0 / HIDDEN
 
-# Tile sizes tuned for standalone scope-1 incore boundaries:
+# Tile sizes tuned for the standalone front projection/cache-prep example:
 # - PROJ_K = K-dimension chunk for projection matmuls (kept at 512).
 # - LORA_CHUNK, KV_OUT_CHUNK = 64 so AIC Right buffer ≤ 65536
 #   (512 * 64 * 2 = 65536).
@@ -52,7 +59,7 @@
 #   static_assert in pto_tile.hpp. The original 3-scope pipeline used
 #   BATCH_TILE=4 because the combined scopes allowed a different split.
 # - LOCAL_PAD_WIDTH removed; the pad tensor was a tuning hint for the
-#   combined scope1+2+3 pipeline and is not needed for scope1 alone.
+#   combined front pipeline and is not needed here.
 RMSNORM_K = 512
 PROJ_K = 512
 Q_OUT_CHUNK = 64
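For orientation, the arithmetic of the projection stage is a few dense matmuls wrapped around two RMSNorms. The sketch below is illustrative PyTorch only; the helper name and signature are invented here, and it mirrors the golden reference further down rather than the tiled pl kernel:

    import torch

    def front_projection_sketch(hidden, rms_w, wq_a, q_norm_w, wq_b, wkv_a, eps=1e-6):
        # Input RMSNorm over the hidden dimension (7168).
        h = hidden.float()
        normed = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps) * rms_w
        # Low-rank Q path: hidden -> Q latent (1536) -> per-head Q (128 heads * 192).
        qr = normed @ wq_a.float()
        qr_n = qr * torch.rsqrt(qr.pow(2).mean(-1, keepdim=True) + eps) * q_norm_w
        q_proj = qr_n @ wq_b.float()
        # KV path: hidden -> KV latent (512) concatenated with the shared k_pe (64).
        kv_a = normed @ wkv_a.float()
        return qr, q_proj, kv_a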
@@ -63,6 +70,7 @@
 
 def build_deepseek_v3_2_decode_front_scope1_program(
     batch: int = BATCH,
+    max_seq_len: int = MAX_SEQ,
     hidden_size: int = HIDDEN,
     num_heads: int = NUM_HEADS,
     q_lora_rank: int = Q_LORA_RANK,
@@ -72,6 +80,7 @@ def build_deepseek_v3_2_decode_front_scope1_program(
     v_head_dim: int = V_HEAD_DIM,
 ):
     BATCH_CFG = batch
+    MAX_SEQ_CFG = max_seq_len
     HIDDEN_CFG = hidden_size
     NUM_HEADS_CFG = num_heads
     Q_LORA_RANK_CFG = q_lora_rank
@@ -81,6 +90,7 @@ def build_deepseek_v3_2_decode_front_scope1_program(
     QK_HEAD_DIM_CFG = qk_nope_head_dim + qk_rope_head_dim
     V_HEAD_DIM_CFG = v_head_dim
     KV_A_OUT_CFG = kv_lora_rank + qk_rope_head_dim
+    CACHE_ROWS_CFG = batch * max_seq_len
     Q_LORA_INV_CFG = 1.0 / q_lora_rank
 
     RMSNORM_BLOCKS = (HIDDEN_CFG + RMSNORM_K - 1) // RMSNORM_K
@@ -100,12 +110,21 @@ def deepseek_v3_2_decode_front_scope1(
         q_norm_weight: pl.Tensor[[1, Q_LORA_RANK_CFG], pl.FP32],
         wq_b: pl.Tensor[[Q_LORA_RANK_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16],
         wkv_a: pl.Tensor[[HIDDEN_CFG, KV_A_OUT_CFG], pl.BF16],
+        seq_lens: pl.Tensor[[BATCH_CFG], pl.INT32],
+        rope_cos: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32],
+        rope_sin: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32],
+        kv_norm_weight: pl.Tensor[[1, KV_LORA_RANK_CFG], pl.FP32],
         # Output buffers
         qr_out: pl.Tensor[[BATCH_CFG, Q_LORA_RANK_CFG], pl.BF16],
-        q_proj_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16],
+        q_nope_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_NOPE_HEAD_DIM_CFG], pl.BF16],
+        q_pe_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16],
         kv_a_out: pl.Tensor[[BATCH_CFG, KV_A_OUT_CFG], pl.BF16],
-    ) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16]:
-        # Scope 1: input RMSNorm + Q/KV projection.
+        kv_cache: pl.Tensor[[CACHE_ROWS_CFG, KV_LORA_RANK_CFG], pl.BF16],
+        pe_cache: pl.Tensor[[CACHE_ROWS_CFG, QK_ROPE_HEAD_DIM_CFG], pl.BF16],
+    ) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16]:
+        q_proj_out = pl.create_tensor([BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16)
+
+        # Front projection: input RMSNorm + Q/KV projections.
         for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
             normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16)
             qr_fp32_tile = pl.create_tensor([BATCH_TILE, Q_LORA_RANK_CFG], dtype=pl.FP32)
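The new kv_cache/pe_cache buffers are flat: batch element b owns max_seq_len consecutive rows, and seq_lens holds each element's context length including the current token, so the decode token lands at row b * max_seq_len + (seq_len - 1). A tiny sketch of that arithmetic (hypothetical helper, shown only to make the layout explicit):

    def cache_row_index(b: int, ctx_len: int, max_seq_len: int = 4096) -> int:
        # The newest token sits at zero-based position ctx_len - 1 inside the
        # rows [b * max_seq_len, (b + 1) * max_seq_len) owned by batch element b.
        return b * max_seq_len + (ctx_len - 1)

    assert cache_row_index(0, 1) == 0
    assert cache_row_index(2, 5) == 2 * 4096 + 4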
@@ -228,7 +247,121 @@ def deepseek_v3_2_decode_front_scope1(
                 )
                 kv_a_out = pl.assemble(kv_a_out, kv_chunk, [b0, kv0])
 
-        return q_proj_out
+        kv_normed_out = pl.create_tensor([BATCH_CFG, KV_LORA_RANK_CFG], dtype=pl.BF16)
+
+        with pl.at(level=pl.Level.CORE_GROUP):
+            kv_rows = pl.cast(pl.slice(kv_a_out, [BATCH_CFG, KV_LORA_RANK_CFG], [0, 0]), target_type=pl.FP32)
+            kv_partial_sq = pl.reshape(pl.row_sum(pl.mul(kv_rows, kv_rows)), [1, BATCH_CFG])
+            kv_variance = pl.reshape(
+                pl.add(pl.mul(kv_partial_sq, 1.0 / KV_LORA_RANK_CFG), EPS),
+                [BATCH_CFG, 1],
+            )
+            kv_inv_rms = pl.recip(pl.sqrt(kv_variance))
+            kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0])
+            kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_rows, kv_inv_rms), kv_gamma)
+            kv_normed_out = pl.assemble(kv_normed_out, pl.cast(kv_normed, target_type=pl.BF16), [0, 0])
+
+        # Q split + q_rope: produce q_nope/q_pe while keeping the internal
+        # full Q layout available for the in-place RoPE writes.
+        with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+            for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
+                ctx_len = pl.tensor.read(seq_lens, [b])
+                pos = ctx_len - 1
+
+                cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
+                cos_hi = pl.slice(
+                    rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
+                )
+                sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
+                sin_hi = pl.slice(
+                    rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
+                )
+
+                for h in pl.range(NUM_HEADS_CFG):
+                    q_col = h * QK_HEAD_DIM_CFG
+                    q_nope_col = h * QK_NOPE_HEAD_DIM_CFG
+                    q_pe_col = h * QK_ROPE_HEAD_DIM_CFG
+                    q_nope = pl.slice(q_proj_out, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col])
+                    q_nope_out = pl.assemble(q_nope_out, q_nope, [b, q_nope_col])
+                    q_lo = pl.cast(
+                        pl.slice(
+                            q_proj_out,
+                            [1, QK_ROPE_HEAD_DIM_CFG // 2],
+                            [b, q_col + QK_NOPE_HEAD_DIM_CFG],
+                        ),
+                        target_type=pl.FP32,
+                    )
+                    q_hi = pl.cast(
+                        pl.slice(
+                            q_proj_out,
+                            [1, QK_ROPE_HEAD_DIM_CFG // 2],
+                            [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
+                        ),
+                        target_type=pl.FP32,
+                    )
+                    q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
+                    q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
+                    q_proj_out = pl.assemble(
+                        q_proj_out,
+                        pl.cast(q_rot_lo, target_type=pl.BF16),
+                        [b, q_col + QK_NOPE_HEAD_DIM_CFG],
+                    )
+                    q_proj_out = pl.assemble(
+                        q_proj_out,
+                        pl.cast(q_rot_hi, target_type=pl.BF16),
+                        [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
+                    )
+                    q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
+                    q_pe_out = pl.assemble(
+                        q_pe_out,
+                        pl.cast(q_rot_hi, target_type=pl.BF16),
+                        [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
+                    )
+
+        # Cache preparation: write the RMS-normalized KV latent and the rotated
+        # k_pe for the current decode token.
+        with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+            for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
+                ctx_len = pl.tensor.read(seq_lens, [b])
+                pos = ctx_len - 1
+                cache_row = b * MAX_SEQ_CFG + pos
+
+                cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
+                cos_hi = pl.slice(
+                    rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
+                )
+                sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
+                sin_hi = pl.slice(
+                    rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
+                )
+                kv_normed_row = pl.slice(kv_normed_out, [1, KV_LORA_RANK_CFG], [b, 0])
+
+                pe_lo = pl.cast(
+                    pl.slice(kv_a_out, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, KV_LORA_RANK_CFG]),
+                    target_type=pl.FP32,
+                )
+                pe_hi = pl.cast(
+                    pl.slice(
+                        kv_a_out,
+                        [1, QK_ROPE_HEAD_DIM_CFG // 2],
+                        [b, KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
+                    ),
+                    target_type=pl.FP32,
+                )
+
+                pe_rot_lo = pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo))
+                pe_rot_hi = pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi))
+
+                kv_cache = pl.assemble(kv_cache, kv_normed_row, [cache_row, 0])
+                pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot_lo, target_type=pl.BF16), [cache_row, 0])
+                pe_cache = pl.assemble(
+                    pe_cache,
+                    pl.cast(pe_rot_hi, target_type=pl.BF16),
+                    [cache_row, QK_ROPE_HEAD_DIM_CFG // 2],
+                )
+
+        return q_pe_out
 
     return DeepSeekV32DecodeFrontScope1
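Both rotations above follow the rotate-half RoPE convention: the 64-wide rope dimension is split into 32-wide halves and combined as lo' = lo*cos - hi*sin, hi' = hi*cos + lo*sin. The test harness fills rope_cos/rope_sin with random values, so it only exercises this data path; with a real table, where each frequency repeats across both halves, the transform is an exact rotation. A self-contained PyTorch sketch (illustrative names, not part of the kernel):

    import torch

    def rotate_half(x, cos, sin):
        # x, cos, sin: [..., dim]; cos/sin are full-width tables.
        half = x.shape[-1] // 2
        lo, hi = x[..., :half], x[..., half:]
        return torch.cat(
            (lo * cos[..., :half] - hi * sin[..., :half],
             hi * cos[..., half:] + lo * sin[..., half:]),
            dim=-1,
        )

    theta = torch.rand(32)                # one angle per lo/hi pair
    cos = torch.cos(theta).repeat(2)      # table repeats across both halves
    sin = torch.sin(theta).repeat(2)
    x = torch.randn(64)
    x_back = rotate_half(rotate_half(x, cos, sin), cos, -sin)  # rotate by -theta
    assert torch.allclose(x, x_back, atol=1e-5)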
tensors["rope_cos"].float() + rope_sin = tensors["rope_sin"].float() + kv_norm_weight = tensors["kv_norm_weight"].float() + kv_cache = tensors["kv_cache"] + pe_cache = tensors["pe_cache"] # RMSNorm sq_sum = torch.sum(hidden_states * hidden_states, dim=1, keepdim=True) @@ -262,12 +401,42 @@ def golden_decode_front_scope1(tensors): # Write into output tensor slots tensors["qr_out"].copy_(qr) - tensors["q_proj_out"].copy_(q_proj) tensors["kv_a_out"].copy_(kv_a) + half = QK_ROPE_HEAD_DIM // 2 + q_proj_view = q_proj.float().view(q_proj.shape[0], NUM_HEADS, QK_HEAD_DIM) + for b in range(kv_a.shape[0]): + ctx_len = int(seq_lens[b].item()) + pos = ctx_len - 1 + cache_row = b * MAX_SEQ + pos + + kv_row = kv_a[b : b + 1, :KV_LORA_RANK].float() + kv_var = torch.mean(kv_row * kv_row, dim=-1, keepdim=True) + kv_normed = kv_row * torch.rsqrt(kv_var + EPS) * kv_norm_weight + kv_cache[cache_row : cache_row + 1].copy_(kv_normed.to(torch.bfloat16)) + + pe_lo = kv_a[b : b + 1, KV_LORA_RANK : KV_LORA_RANK + half].float() + pe_hi = kv_a[b : b + 1, KV_LORA_RANK + half : KV_LORA_RANK + 2 * half].float() + cos_lo = rope_cos[pos : pos + 1, :half] + cos_hi = rope_cos[pos : pos + 1, half:] + sin_lo = rope_sin[pos : pos + 1, :half] + sin_hi = rope_sin[pos : pos + 1, half:] + + q_pe = q_proj_view[b, :, QK_NOPE_HEAD_DIM:] + q_lo = q_pe[:, :half].clone() + q_hi = q_pe[:, half:].clone() + q_proj_view[b, :, QK_NOPE_HEAD_DIM : QK_NOPE_HEAD_DIM + half] = q_lo * cos_lo - q_hi * sin_lo + q_proj_view[b, :, QK_NOPE_HEAD_DIM + half :] = q_hi * cos_hi + q_lo * sin_hi + + pe_cache[cache_row : cache_row + 1, :half].copy_((pe_lo * cos_lo - pe_hi * sin_lo).to(torch.bfloat16)) + pe_cache[cache_row : cache_row + 1, half:].copy_((pe_hi * cos_hi + pe_lo * sin_hi).to(torch.bfloat16)) + tensors["q_nope_out"].copy_(q_proj_view[:, :, :QK_NOPE_HEAD_DIM].reshape(q_proj.shape[0], -1).to(torch.bfloat16)) + tensors["q_pe_out"].copy_(q_proj_view[:, :, QK_NOPE_HEAD_DIM:].reshape(q_proj.shape[0], -1).to(torch.bfloat16)) + def build_tensor_specs( batch: int = BATCH, + max_seq_len: int = MAX_SEQ, hidden_size: int = HIDDEN, num_heads: int = NUM_HEADS, q_lora_rank: int = Q_LORA_RANK, @@ -281,6 +450,8 @@ def build_tensor_specs( qk_head_dim = qk_nope_head_dim + qk_rope_head_dim kv_a_out = kv_lora_rank + qk_rope_head_dim + cache_rows = batch * max_seq_len + seq_lens_data = torch.randint(1, max_seq_len + 1, (batch,), dtype=torch.int32) def init_hidden_states(): return torch.rand(batch, hidden_size) - 0.5 @@ -300,6 +471,18 @@ def init_wq_b(): def init_wkv_a(): return (torch.rand(hidden_size, kv_a_out) - 0.5) / hidden_size ** 0.5 + def init_kv_norm_weight(): + return torch.rand(1, kv_lora_rank) - 0.5 + + def init_rope(): + return torch.rand(max_seq_len, qk_rope_head_dim) - 0.5 + + def init_cache_kv(): + return torch.zeros(cache_rows, kv_lora_rank) + + def init_cache_pe(): + return torch.zeros(cache_rows, qk_rope_head_dim) + return [ TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, init_value=init_hidden_states), TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, init_value=init_rms_weight), @@ -307,9 +490,28 @@ def init_wkv_a(): TensorSpec("q_norm_weight", [1, q_lora_rank], torch.float32, init_value=init_q_norm_weight), TensorSpec("wq_b", [q_lora_rank, num_heads * qk_head_dim], torch.bfloat16, init_value=init_wq_b), TensorSpec("wkv_a", [hidden_size, kv_a_out], torch.bfloat16, init_value=init_wkv_a), + TensorSpec("seq_lens", [batch], torch.int32, init_value=seq_lens_data), + TensorSpec("rope_cos", [max_seq_len, qk_rope_head_dim], 
@@ -307,9 +490,28 @@ def init_wkv_a():
         TensorSpec("q_norm_weight", [1, q_lora_rank], torch.float32, init_value=init_q_norm_weight),
         TensorSpec("wq_b", [q_lora_rank, num_heads * qk_head_dim], torch.bfloat16, init_value=init_wq_b),
         TensorSpec("wkv_a", [hidden_size, kv_a_out], torch.bfloat16, init_value=init_wkv_a),
+        TensorSpec("seq_lens", [batch], torch.int32, init_value=seq_lens_data),
+        TensorSpec("rope_cos", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope),
+        TensorSpec("rope_sin", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope),
+        TensorSpec("kv_norm_weight", [1, kv_lora_rank], torch.float32, init_value=init_kv_norm_weight),
         TensorSpec("qr_out", [batch, q_lora_rank], torch.bfloat16, is_output=True),
-        TensorSpec("q_proj_out", [batch, num_heads * qk_head_dim], torch.bfloat16, is_output=True),
+        TensorSpec("q_nope_out", [batch, num_heads * qk_nope_head_dim], torch.bfloat16, is_output=True),
+        TensorSpec("q_pe_out", [batch, num_heads * qk_rope_head_dim], torch.bfloat16, is_output=True),
         TensorSpec("kv_a_out", [batch, kv_a_out], torch.bfloat16, is_output=True),
+        TensorSpec(
+            "kv_cache",
+            [cache_rows, kv_lora_rank],
+            torch.bfloat16,
+            init_value=init_cache_kv,
+            is_output=True,
+        ),
+        TensorSpec(
+            "pe_cache",
+            [cache_rows, qk_rope_head_dim],
+            torch.bfloat16,
+            init_value=init_cache_pe,
+            is_output=True,
+        ),
     ]
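For a quick standalone check of the golden reference, the tensor dict can be built by hand. This is a hedged sketch: it assumes the module is importable under the name shown, uses the same shapes and initializers as build_tensor_specs above, and only checks a coarse property of the cache writes; the real harness builds these tensors from the TensorSpec list and also runs the pl kernel:

    import torch
    from deepseek_v3_2_decode_front_scope1 import golden_decode_front_scope1

    B, H, R, NH = 16, 7168, 1536, 128
    DN, DR, KVR, S = 128, 64, 512, 4096
    t = {
        "hidden_states": (torch.rand(B, H) - 0.5).to(torch.bfloat16),
        "input_rms_weight": torch.rand(1, H) - 0.5,
        "wq_a": ((torch.rand(H, R) - 0.5) / H ** 0.5).to(torch.bfloat16),
        "q_norm_weight": torch.rand(1, R) - 0.5,
        "wq_b": ((torch.rand(R, NH * (DN + DR)) - 0.5) / R ** 0.5).to(torch.bfloat16),
        "wkv_a": ((torch.rand(H, KVR + DR) - 0.5) / H ** 0.5).to(torch.bfloat16),
        "seq_lens": torch.randint(1, S + 1, (B,), dtype=torch.int32),
        "rope_cos": torch.rand(S, DR) - 0.5,
        "rope_sin": torch.rand(S, DR) - 0.5,
        "kv_norm_weight": torch.rand(1, KVR) - 0.5,
        "qr_out": torch.empty(B, R, dtype=torch.bfloat16),
        "q_nope_out": torch.empty(B, NH * DN, dtype=torch.bfloat16),
        "q_pe_out": torch.empty(B, NH * DR, dtype=torch.bfloat16),
        "kv_a_out": torch.empty(B, KVR + DR, dtype=torch.bfloat16),
        "kv_cache": torch.zeros(B * S, KVR, dtype=torch.bfloat16),
        "pe_cache": torch.zeros(B * S, DR, dtype=torch.bfloat16),
    }
    golden_decode_front_scope1(t)
    # Each batch element should populate exactly row b * S + seq_len - 1.
    rows = t["seq_lens"].long() - 1 + torch.arange(B) * S
    assert (t["kv_cache"].float().abs().sum(1)[rows] > 0).all()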