-
Notifications
You must be signed in to change notification settings - Fork 22
Fix DeepSeek scope1 cache prep outputs #134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,14 +9,19 @@ | |||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| DeepSeek V3.2-EXP single-layer decode FRONT part — Scope 1 only (batch=16). | ||||||||||||||||||||||
| DeepSeek V3.2-EXP single-layer decode front path (batch=16). | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Scope 1: input RMSNorm + Q/KV projection. | ||||||||||||||||||||||
| Projection stage: | ||||||||||||||||||||||
| - Compute RMSNorm of hidden_states | ||||||||||||||||||||||
| - Project to Q latent (qr) via wq_a | ||||||||||||||||||||||
| - Apply q_norm to Q latent, then project to Q heads (q_proj) via wq_b | ||||||||||||||||||||||
| - Apply q_norm to Q latent, then project to Q heads via wq_b | ||||||||||||||||||||||
| - Project to KV latent (kv_a) via wkv_a | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Decode cache preparation: | ||||||||||||||||||||||
| - Split Q heads into q_nope and q_pe outputs | ||||||||||||||||||||||
| - Apply RoPE to q_pe and k_pe | ||||||||||||||||||||||
| - Apply kv_norm to KV latent, then update KV/PE cache for the current token | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Aligned to official v3.2-exp MLA shapes: | ||||||||||||||||||||||
| - qk_nope_head_dim = 128 | ||||||||||||||||||||||
| - qk_rope_head_dim = 64 | ||||||||||||||||||||||
|
|
@@ -28,6 +33,7 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| BATCH = 16 | ||||||||||||||||||||||
| MAX_SEQ = 4096 | ||||||||||||||||||||||
| HIDDEN = 7168 | ||||||||||||||||||||||
| NUM_HEADS = 128 | ||||||||||||||||||||||
| Q_LORA_RANK = 1536 | ||||||||||||||||||||||
|
|
@@ -37,11 +43,12 @@ | |||||||||||||||||||||
| QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM | ||||||||||||||||||||||
| V_HEAD_DIM = 128 | ||||||||||||||||||||||
| KV_A_OUT = KV_LORA_RANK + QK_ROPE_HEAD_DIM | ||||||||||||||||||||||
| CACHE_ROWS = BATCH * MAX_SEQ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| EPS = 1e-6 | ||||||||||||||||||||||
| HIDDEN_INV = 1.0 / HIDDEN | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Tile sizes tuned for standalone scope-1 incore boundaries: | ||||||||||||||||||||||
| # Tile sizes tuned for the standalone front projection/cache-prep example: | ||||||||||||||||||||||
| # - PROJ_K = K-dimension chunk for projection matmuls (kept at 512). | ||||||||||||||||||||||
| # - LORA_CHUNK, KV_OUT_CHUNK = 64 so AIC Right buffer ≤ 65536 | ||||||||||||||||||||||
| # (512 * 64 * 2 = 65536). | ||||||||||||||||||||||
|
|
@@ -52,7 +59,7 @@ | |||||||||||||||||||||
| # static_assert in pto_tile.hpp. The original 3-scope pipeline used | ||||||||||||||||||||||
| # BATCH_TILE=4 because the combined scopes allowed a different split. | ||||||||||||||||||||||
| # - LOCAL_PAD_WIDTH removed; the pad tensor was a tuning hint for the | ||||||||||||||||||||||
| # combined scope1+2+3 pipeline and is not needed for scope1 alone. | ||||||||||||||||||||||
| # combined front pipeline and is not needed here. | ||||||||||||||||||||||
| RMSNORM_K = 512 | ||||||||||||||||||||||
| PROJ_K = 512 | ||||||||||||||||||||||
| Q_OUT_CHUNK = 64 | ||||||||||||||||||||||
|
|
@@ -63,6 +70,7 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| def build_deepseek_v3_2_decode_front_scope1_program( | ||||||||||||||||||||||
| batch: int = BATCH, | ||||||||||||||||||||||
| max_seq_len: int = MAX_SEQ, | ||||||||||||||||||||||
| hidden_size: int = HIDDEN, | ||||||||||||||||||||||
| num_heads: int = NUM_HEADS, | ||||||||||||||||||||||
| q_lora_rank: int = Q_LORA_RANK, | ||||||||||||||||||||||
|
|
@@ -72,6 +80,7 @@ def build_deepseek_v3_2_decode_front_scope1_program( | |||||||||||||||||||||
| v_head_dim: int = V_HEAD_DIM, | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| BATCH_CFG = batch | ||||||||||||||||||||||
| MAX_SEQ_CFG = max_seq_len | ||||||||||||||||||||||
| HIDDEN_CFG = hidden_size | ||||||||||||||||||||||
| NUM_HEADS_CFG = num_heads | ||||||||||||||||||||||
| Q_LORA_RANK_CFG = q_lora_rank | ||||||||||||||||||||||
|
|
@@ -81,6 +90,7 @@ def build_deepseek_v3_2_decode_front_scope1_program( | |||||||||||||||||||||
| QK_HEAD_DIM_CFG = qk_nope_head_dim + qk_rope_head_dim | ||||||||||||||||||||||
| V_HEAD_DIM_CFG = v_head_dim | ||||||||||||||||||||||
| KV_A_OUT_CFG = kv_lora_rank + qk_rope_head_dim | ||||||||||||||||||||||
| CACHE_ROWS_CFG = batch * max_seq_len | ||||||||||||||||||||||
| Q_LORA_INV_CFG = 1.0 / q_lora_rank | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| RMSNORM_BLOCKS = (HIDDEN_CFG + RMSNORM_K - 1) // RMSNORM_K | ||||||||||||||||||||||
|
|
@@ -100,12 +110,21 @@ def deepseek_v3_2_decode_front_scope1( | |||||||||||||||||||||
| q_norm_weight: pl.Tensor[[1, Q_LORA_RANK_CFG], pl.FP32], | ||||||||||||||||||||||
| wq_b: pl.Tensor[[Q_LORA_RANK_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16], | ||||||||||||||||||||||
| wkv_a: pl.Tensor[[HIDDEN_CFG, KV_A_OUT_CFG], pl.BF16], | ||||||||||||||||||||||
| seq_lens: pl.Tensor[[BATCH_CFG], pl.INT32], | ||||||||||||||||||||||
| rope_cos: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32], | ||||||||||||||||||||||
| rope_sin: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32], | ||||||||||||||||||||||
| kv_norm_weight: pl.Tensor[[1, KV_LORA_RANK_CFG], pl.FP32], | ||||||||||||||||||||||
| # Output buffers | ||||||||||||||||||||||
| qr_out: pl.Tensor[[BATCH_CFG, Q_LORA_RANK_CFG], pl.BF16], | ||||||||||||||||||||||
| q_proj_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16], | ||||||||||||||||||||||
| q_nope_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_NOPE_HEAD_DIM_CFG], pl.BF16], | ||||||||||||||||||||||
| q_pe_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16], | ||||||||||||||||||||||
| kv_a_out: pl.Tensor[[BATCH_CFG, KV_A_OUT_CFG], pl.BF16], | ||||||||||||||||||||||
| ) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16]: | ||||||||||||||||||||||
| # Scope 1: input RMSNorm + Q/KV projection. | ||||||||||||||||||||||
| kv_cache: pl.Tensor[[CACHE_ROWS_CFG, KV_LORA_RANK_CFG], pl.BF16], | ||||||||||||||||||||||
| pe_cache: pl.Tensor[[CACHE_ROWS_CFG, QK_ROPE_HEAD_DIM_CFG], pl.BF16], | ||||||||||||||||||||||
| ) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16]: | ||||||||||||||||||||||
| q_proj_out = pl.create_tensor([BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Front projection: input RMSNorm + Q/KV projections. | ||||||||||||||||||||||
| for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): | ||||||||||||||||||||||
| normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16) | ||||||||||||||||||||||
| qr_fp32_tile = pl.create_tensor([BATCH_TILE, Q_LORA_RANK_CFG], dtype=pl.FP32) | ||||||||||||||||||||||
|
|
@@ -228,7 +247,121 @@ def deepseek_v3_2_decode_front_scope1( | |||||||||||||||||||||
| ) | ||||||||||||||||||||||
| kv_a_out = pl.assemble(kv_a_out, kv_chunk, [b0, kv0]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return q_proj_out | ||||||||||||||||||||||
| kv_normed_out = pl.create_tensor([BATCH_CFG, KV_LORA_RANK_CFG], dtype=pl.BF16) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| with pl.at(level=pl.Level.CORE_GROUP): | ||||||||||||||||||||||
| kv_rows = pl.cast(pl.slice(kv_a_out, [BATCH_CFG, KV_LORA_RANK_CFG], [0, 0]), target_type=pl.FP32) | ||||||||||||||||||||||
| kv_partial_sq = pl.reshape(pl.row_sum(pl.mul(kv_rows, kv_rows)), [1, BATCH_CFG]) | ||||||||||||||||||||||
| kv_variance = pl.reshape( | ||||||||||||||||||||||
| pl.add(pl.mul(kv_partial_sq, 1.0 / KV_LORA_RANK_CFG), EPS), | ||||||||||||||||||||||
| [BATCH_CFG, 1], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| kv_inv_rms = pl.recip(pl.sqrt(kv_variance)) | ||||||||||||||||||||||
| kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0]) | ||||||||||||||||||||||
| kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_rows, kv_inv_rms), kv_gamma) | ||||||||||||||||||||||
| kv_normed_out = pl.assemble(kv_normed_out, pl.cast(kv_normed, target_type=pl.BF16), [0, 0]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Q split + q_rope: produce q_nope/q_pe while keeping the internal | ||||||||||||||||||||||
| # full Q layout available for the in-place RoPE writes. | ||||||||||||||||||||||
| # Cache preparation: write RMS-normalized KV latent and rotated | ||||||||||||||||||||||
| # k_pe for the current decode token. | ||||||||||||||||||||||
| with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): | ||||||||||||||||||||||
| for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): | ||||||||||||||||||||||
| ctx_len = pl.tensor.read(seq_lens, [b]) | ||||||||||||||||||||||
| pos = ctx_len - 1 | ||||||||||||||||||||||
| cache_row = b * MAX_SEQ_CFG + pos | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) | ||||||||||||||||||||||
| cos_hi = pl.slice( | ||||||||||||||||||||||
| rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) | ||||||||||||||||||||||
| sin_hi = pl.slice( | ||||||||||||||||||||||
| rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for h in pl.range(NUM_HEADS_CFG): | ||||||||||||||||||||||
| q_col = h * QK_HEAD_DIM_CFG | ||||||||||||||||||||||
| q_nope_col = h * QK_NOPE_HEAD_DIM_CFG | ||||||||||||||||||||||
| q_pe_col = h * QK_ROPE_HEAD_DIM_CFG | ||||||||||||||||||||||
| q_nope = pl.slice(q_proj_out, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col]) | ||||||||||||||||||||||
| q_nope_out = pl.assemble(q_nope_out, q_nope, [b, q_nope_col]) | ||||||||||||||||||||||
| q_lo = pl.cast( | ||||||||||||||||||||||
| pl.slice( | ||||||||||||||||||||||
| q_proj_out, | ||||||||||||||||||||||
| [1, QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| [b, q_col + QK_NOPE_HEAD_DIM_CFG], | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| target_type=pl.FP32, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| q_hi = pl.cast( | ||||||||||||||||||||||
| pl.slice( | ||||||||||||||||||||||
| q_proj_out, | ||||||||||||||||||||||
| [1, QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| target_type=pl.FP32, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)) | ||||||||||||||||||||||
| q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)) | ||||||||||||||||||||||
| q_proj_out = pl.assemble( | ||||||||||||||||||||||
| q_proj_out, | ||||||||||||||||||||||
| pl.cast(q_rot_lo, target_type=pl.BF16), | ||||||||||||||||||||||
| [b, q_col + QK_NOPE_HEAD_DIM_CFG], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| q_proj_out = pl.assemble( | ||||||||||||||||||||||
| q_proj_out, | ||||||||||||||||||||||
| pl.cast(q_rot_hi, target_type=pl.BF16), | ||||||||||||||||||||||
| [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+307
to
+316
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dead writes: rotated q_pe values stored back into local
🧹 Proposed cleanup q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
- q_proj_out = pl.assemble(
- q_proj_out,
- pl.cast(q_rot_lo, target_type=pl.BF16),
- [b, q_col + QK_NOPE_HEAD_DIM_CFG],
- )
- q_proj_out = pl.assemble(
- q_proj_out,
- pl.cast(q_rot_hi, target_type=pl.BF16),
- [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
- )
q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
q_pe_out = pl.assemble(
q_pe_out,
pl.cast(q_rot_hi, target_type=pl.BF16),
[b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
)And refresh the comment block around 264-265 accordingly. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||
| q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col]) | ||||||||||||||||||||||
| q_pe_out = pl.assemble( | ||||||||||||||||||||||
| q_pe_out, | ||||||||||||||||||||||
| pl.cast(q_rot_hi, target_type=pl.BF16), | ||||||||||||||||||||||
| [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): | ||||||||||||||||||||||
| for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): | ||||||||||||||||||||||
| ctx_len = pl.tensor.read(seq_lens, [b]) | ||||||||||||||||||||||
| pos = ctx_len - 1 | ||||||||||||||||||||||
| cache_row = b * MAX_SEQ_CFG + pos | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) | ||||||||||||||||||||||
| cos_hi = pl.slice( | ||||||||||||||||||||||
| rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0]) | ||||||||||||||||||||||
| sin_hi = pl.slice( | ||||||||||||||||||||||
| rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| kv_normed_row = pl.slice(kv_normed_out, [1, KV_LORA_RANK_CFG], [b, 0]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pe_lo = pl.cast( | ||||||||||||||||||||||
| pl.slice(kv_a_out, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, KV_LORA_RANK_CFG]), | ||||||||||||||||||||||
| target_type=pl.FP32, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| pe_hi = pl.cast( | ||||||||||||||||||||||
| pl.slice( | ||||||||||||||||||||||
| kv_a_out, | ||||||||||||||||||||||
| [1, QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| [b, KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| target_type=pl.FP32, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pe_rot_lo = pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo)) | ||||||||||||||||||||||
| pe_rot_hi = pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| kv_cache = pl.assemble(kv_cache, kv_normed_row, [cache_row, 0]) | ||||||||||||||||||||||
| pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot_lo, target_type=pl.BF16), [cache_row, 0]) | ||||||||||||||||||||||
| pe_cache = pl.assemble( | ||||||||||||||||||||||
| pe_cache, | ||||||||||||||||||||||
| pl.cast(pe_rot_hi, target_type=pl.BF16), | ||||||||||||||||||||||
| [cache_row, QK_ROPE_HEAD_DIM_CFG // 2], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Comment on lines
+268
to
+363
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||
| return q_pe_out | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return DeepSeekV32DecodeFrontScope1 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -242,6 +375,12 @@ def golden_decode_front_scope1(tensors): | |||||||||||||||||||||
| q_norm_weight = tensors["q_norm_weight"].float() | ||||||||||||||||||||||
| wq_b = tensors["wq_b"].float() | ||||||||||||||||||||||
| wkv_a = tensors["wkv_a"].float() | ||||||||||||||||||||||
| seq_lens = tensors["seq_lens"] | ||||||||||||||||||||||
| rope_cos = tensors["rope_cos"].float() | ||||||||||||||||||||||
| rope_sin = tensors["rope_sin"].float() | ||||||||||||||||||||||
| kv_norm_weight = tensors["kv_norm_weight"].float() | ||||||||||||||||||||||
| kv_cache = tensors["kv_cache"] | ||||||||||||||||||||||
| pe_cache = tensors["pe_cache"] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # RMSNorm | ||||||||||||||||||||||
| sq_sum = torch.sum(hidden_states * hidden_states, dim=1, keepdim=True) | ||||||||||||||||||||||
|
|
@@ -262,12 +401,42 @@ def golden_decode_front_scope1(tensors): | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Write into output tensor slots | ||||||||||||||||||||||
| tensors["qr_out"].copy_(qr) | ||||||||||||||||||||||
| tensors["q_proj_out"].copy_(q_proj) | ||||||||||||||||||||||
| tensors["kv_a_out"].copy_(kv_a) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| half = QK_ROPE_HEAD_DIM // 2 | ||||||||||||||||||||||
| q_proj_view = q_proj.float().view(q_proj.shape[0], NUM_HEADS, QK_HEAD_DIM) | ||||||||||||||||||||||
| for b in range(kv_a.shape[0]): | ||||||||||||||||||||||
| ctx_len = int(seq_lens[b].item()) | ||||||||||||||||||||||
| pos = ctx_len - 1 | ||||||||||||||||||||||
| cache_row = b * MAX_SEQ + pos | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| kv_row = kv_a[b : b + 1, :KV_LORA_RANK].float() | ||||||||||||||||||||||
| kv_var = torch.mean(kv_row * kv_row, dim=-1, keepdim=True) | ||||||||||||||||||||||
| kv_normed = kv_row * torch.rsqrt(kv_var + EPS) * kv_norm_weight | ||||||||||||||||||||||
| kv_cache[cache_row : cache_row + 1].copy_(kv_normed.to(torch.bfloat16)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pe_lo = kv_a[b : b + 1, KV_LORA_RANK : KV_LORA_RANK + half].float() | ||||||||||||||||||||||
| pe_hi = kv_a[b : b + 1, KV_LORA_RANK + half : KV_LORA_RANK + 2 * half].float() | ||||||||||||||||||||||
| cos_lo = rope_cos[pos : pos + 1, :half] | ||||||||||||||||||||||
| cos_hi = rope_cos[pos : pos + 1, half:] | ||||||||||||||||||||||
| sin_lo = rope_sin[pos : pos + 1, :half] | ||||||||||||||||||||||
| sin_hi = rope_sin[pos : pos + 1, half:] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| q_pe = q_proj_view[b, :, QK_NOPE_HEAD_DIM:] | ||||||||||||||||||||||
| q_lo = q_pe[:, :half].clone() | ||||||||||||||||||||||
| q_hi = q_pe[:, half:].clone() | ||||||||||||||||||||||
| q_proj_view[b, :, QK_NOPE_HEAD_DIM : QK_NOPE_HEAD_DIM + half] = q_lo * cos_lo - q_hi * sin_lo | ||||||||||||||||||||||
| q_proj_view[b, :, QK_NOPE_HEAD_DIM + half :] = q_hi * cos_hi + q_lo * sin_hi | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pe_cache[cache_row : cache_row + 1, :half].copy_((pe_lo * cos_lo - pe_hi * sin_lo).to(torch.bfloat16)) | ||||||||||||||||||||||
| pe_cache[cache_row : cache_row + 1, half:].copy_((pe_hi * cos_hi + pe_lo * sin_hi).to(torch.bfloat16)) | ||||||||||||||||||||||
| tensors["q_nope_out"].copy_(q_proj_view[:, :, :QK_NOPE_HEAD_DIM].reshape(q_proj.shape[0], -1).to(torch.bfloat16)) | ||||||||||||||||||||||
| tensors["q_pe_out"].copy_(q_proj_view[:, :, QK_NOPE_HEAD_DIM:].reshape(q_proj.shape[0], -1).to(torch.bfloat16)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def build_tensor_specs( | ||||||||||||||||||||||
| batch: int = BATCH, | ||||||||||||||||||||||
| max_seq_len: int = MAX_SEQ, | ||||||||||||||||||||||
| hidden_size: int = HIDDEN, | ||||||||||||||||||||||
| num_heads: int = NUM_HEADS, | ||||||||||||||||||||||
| q_lora_rank: int = Q_LORA_RANK, | ||||||||||||||||||||||
|
|
@@ -281,6 +450,8 @@ def build_tensor_specs( | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| qk_head_dim = qk_nope_head_dim + qk_rope_head_dim | ||||||||||||||||||||||
| kv_a_out = kv_lora_rank + qk_rope_head_dim | ||||||||||||||||||||||
| cache_rows = batch * max_seq_len | ||||||||||||||||||||||
| seq_lens_data = torch.randint(1, max_seq_len + 1, (batch,), dtype=torch.int32) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_hidden_states(): | ||||||||||||||||||||||
| return torch.rand(batch, hidden_size) - 0.5 | ||||||||||||||||||||||
|
|
@@ -300,16 +471,47 @@ def init_wq_b(): | |||||||||||||||||||||
| def init_wkv_a(): | ||||||||||||||||||||||
| return (torch.rand(hidden_size, kv_a_out) - 0.5) / hidden_size ** 0.5 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_kv_norm_weight(): | ||||||||||||||||||||||
| return torch.rand(1, kv_lora_rank) - 0.5 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_rope(): | ||||||||||||||||||||||
| return torch.rand(max_seq_len, qk_rope_head_dim) - 0.5 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_cache_kv(): | ||||||||||||||||||||||
| return torch.zeros(cache_rows, kv_lora_rank) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_cache_pe(): | ||||||||||||||||||||||
| return torch.zeros(cache_rows, qk_rope_head_dim) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return [ | ||||||||||||||||||||||
| TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, init_value=init_hidden_states), | ||||||||||||||||||||||
| TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, init_value=init_rms_weight), | ||||||||||||||||||||||
| TensorSpec("wq_a", [hidden_size, q_lora_rank], torch.bfloat16, init_value=init_wq_a), | ||||||||||||||||||||||
| TensorSpec("q_norm_weight", [1, q_lora_rank], torch.float32, init_value=init_q_norm_weight), | ||||||||||||||||||||||
| TensorSpec("wq_b", [q_lora_rank, num_heads * qk_head_dim], torch.bfloat16, init_value=init_wq_b), | ||||||||||||||||||||||
| TensorSpec("wkv_a", [hidden_size, kv_a_out], torch.bfloat16, init_value=init_wkv_a), | ||||||||||||||||||||||
| TensorSpec("seq_lens", [batch], torch.int32, init_value=seq_lens_data), | ||||||||||||||||||||||
| TensorSpec("rope_cos", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope), | ||||||||||||||||||||||
| TensorSpec("rope_sin", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope), | ||||||||||||||||||||||
| TensorSpec("kv_norm_weight", [1, kv_lora_rank], torch.float32, init_value=init_kv_norm_weight), | ||||||||||||||||||||||
| TensorSpec("qr_out", [batch, q_lora_rank], torch.bfloat16, is_output=True), | ||||||||||||||||||||||
| TensorSpec("q_proj_out", [batch, num_heads * qk_head_dim], torch.bfloat16, is_output=True), | ||||||||||||||||||||||
| TensorSpec("q_nope_out", [batch, num_heads * qk_nope_head_dim], torch.bfloat16, is_output=True), | ||||||||||||||||||||||
| TensorSpec("q_pe_out", [batch, num_heads * qk_rope_head_dim], torch.bfloat16, is_output=True), | ||||||||||||||||||||||
| TensorSpec("kv_a_out", [batch, kv_a_out], torch.bfloat16, is_output=True), | ||||||||||||||||||||||
| TensorSpec( | ||||||||||||||||||||||
| "kv_cache", | ||||||||||||||||||||||
| [cache_rows, kv_lora_rank], | ||||||||||||||||||||||
| torch.bfloat16, | ||||||||||||||||||||||
| init_value=init_cache_kv, | ||||||||||||||||||||||
| is_output=True, | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| TensorSpec( | ||||||||||||||||||||||
| "pe_cache", | ||||||||||||||||||||||
| [cache_rows, qk_rope_head_dim], | ||||||||||||||||||||||
| torch.bfloat16, | ||||||||||||||||||||||
| init_value=init_cache_pe, | ||||||||||||||||||||||
| is_output=True, | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
pl.assemblecalls updatingq_proj_outwith rotated values are redundant.q_proj_outis a local temporary tensor created at line 125 and is not used after this loop (the function returnsq_pe_out). Removing these unnecessary Global Memory writes will improve performance.