scope1: qr_proj and kv_proj precision is correct, q_proj precision is wrong #111
Inspiron-st wants to merge 1 commit into hw-native-sys:main
Conversation
📝 Walkthrough

This PR adds a new example script demonstrating a PyPTO program that computes the DeepSeek V3.2 "decode Scope 1" projections, including RMSNorm normalization and Q/KV LoRA compression. The implementation performs sequential stages: variance computation over input chunks, Q projection via matmul accumulation, and KV projection compression, returning assembled FP32 outputs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input<br/>(hidden_states)
    participant RMSNorm as Stage 1<br/>(RMSNorm)
    participant QComp as Stage 2<br/>(Q Compression)
    participant KVComp as Stage 3<br/>(KV Compression)
    participant Output as Output<br/>(Projections)
    Input->>RMSNorm: Pass hidden_states tiles
    activate RMSNorm
    RMSNorm->>RMSNorm: Accumulate squared chunks
    RMSNorm->>RMSNorm: Compute variance + apply RMSNorm weight
    RMSNorm-->>QComp: Normalized BF16 tile
    deactivate RMSNorm
    activate QComp
    QComp->>QComp: Chunked matmul with wq_a<br/>(FP32 accumulation)
    QComp->>QComp: Assemble qr_proj
    QComp-->>KVComp: qr_proj
    deactivate QComp
    activate KVComp
    KVComp->>KVComp: Chunked matmul with wkv_a<br/>(FP32 accumulation)
    KVComp->>KVComp: Assemble kv_proj
    KVComp-->>Output: kv_proj
    deactivate KVComp
    QComp-->>Output: qr_proj + wq_b → q_proj
    Output->>Output: Wire output specs
    Output-->>Input: Return (qr_proj, q_proj, kv_proj)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py`:
- Around line 89-102: The RMSNorm implementation is using multiplication by
variance instead of dividing by sqrt(variance); update the computation so
x_chunk is scaled by rsqrt(variance + EPS) then multiplied by the weight
(gamma). Concretely, replace the expression building normed in the loop
(currently using pl.row_expand_mul/pl.col_expand_mul with variance) to compute
inv_root = rsqrt(variance) or pl.rsqrt(pl.add(variance, EPS)) and then do normed
= pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_root), gamma); apply the same
change to the golden function that mirrors this logic so both implementations
use x_chunk * rsqrt(variance + EPS) * input_rms_weight.
- Around line 201-203: The three tensors qr_proj, q_proj, and kv_proj are
created with module-level constants (Q_LORA_RANK, NUM_HEADS * QK_HEAD_DIM,
KV_A_OUT) which ignores the passed-in params; update the allocation to derive
their shapes from the function parameters/tensors (e.g., use values from params
or the input tensor shapes available in decode_front_scope1 / compile_and_run)
so they reflect the current run configuration: replace hardcoded sizes with
params.q_lora_rank (or equivalent), params.num_heads * params.qk_head_dim (or
computing from qkv input shape), and params.kv_a_out (or infer from kv tensor
shape) respectively so the function works with non-default values.
- Around line 104-138: The code never computes q_proj and leaves wq_b unused;
add a "Stage 4: Q expansion" after the KV compression loop that computes q_proj
= qr_proj @ wq_b using the same tiled matmul pattern as prior stages: iterate
over output blocks (matching q_proj layout), within pl.incore() slice qr_proj
and wq_b tiles, use pl.matmul to start and pl.matmul_acc to accumulate across
hidden_blocks, then pl.assemble into q_proj; ensure you reference and update the
existing symbols qr_proj, wq_b, q_proj and use
pl.matmul/pl.matmul_acc/pl.assemble exactly as in the QR/KV stages.
📒 Files selected for processing (1)
examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
```python
variance = pl.reshape(
    pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
    [BATCH_TILE, 1],
)

for kb in pl.range(hidden_blocks):
    k0 = kb * K_CHUNK
    x_chunk = pl.cast(
        pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]),
        target_type=pl.FP32,
    )
    gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0])
    normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, variance), gamma)
    normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])
```
RMSNorm formula is incorrect — multiplies by variance instead of dividing by sqrt(variance).
Standard RMSNorm computes: output = x * rsqrt(variance + eps) * weight
Current implementation (line 101):

```python
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, variance), gamma)
```

This multiplies x_chunk * variance * gamma, but should be x_chunk / sqrt(variance) * gamma.
Note: The golden function (line 214) has the same incorrect formula, so tests pass but both implementations are mathematically wrong.
🐛 Proposed fix for RMSNorm formula

```diff
  variance = pl.reshape(
-     pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
+     pl.rsqrt(pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS)),
      [BATCH_TILE, 1],
  )
```

And correspondingly in the golden function (lines 213-214):

```diff
- variance = sq_sum / hidden_size + EPS
- normed = (x_tile * variance * input_rms_weight.float()).bfloat16()
+ rsqrt_var = torch.rsqrt(sq_sum / hidden_size + EPS)
+ normed = (x_tile * rsqrt_var * input_rms_weight.float()).bfloat16()
```
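As a sanity check on the corrected formula, here is a minimal pure-Python sketch of standard RMSNorm (a hypothetical reference helper, not part of the example file): with an all-ones weight, the output's root-mean-square lands at ~1.0, which the variance-multiplication version would not produce.

```python
import math

EPS = 1e-6

def rms_norm(x, weight, eps=EPS):
    """Reference RMSNorm: x * rsqrt(mean(x^2) + eps) * weight."""
    variance = sum(v * v for v in x) / len(x)
    inv_root = 1.0 / math.sqrt(variance + eps)  # the rsqrt step the review asks for
    return [v * inv_root * w for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
weight = [1.0, 1.0, 1.0, 1.0]
normed = rms_norm(x, weight)

# With unit weights the normalized vector has RMS close to 1.0.
rms_out = math.sqrt(sum(v * v for v in normed) / len(normed))
```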
```python
# Stage 2: Q compression (matmul + matmul_acc in single incore).
for ob in pl.range(lora_blocks):
    q0 = ob * LORA_CHUNK

    with pl.incore():
        tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
        tile_b = pl.slice(wq_a, [K_CHUNK, LORA_CHUNK], [0, q0])
        qr_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)

        for kb in pl.range(1, hidden_blocks):
            k0 = kb * K_CHUNK
            tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
            tile_b_i = pl.slice(wq_a, [K_CHUNK, LORA_CHUNK], [k0, q0])
            qr_acc = pl.matmul_acc(qr_acc, tile_a_i, tile_b_i)

    qr_proj = pl.assemble(qr_proj, qr_acc, [b0, q0])

# Stage 3: KV compression (matmul + matmul_acc in single incore).
for ob in pl.range(kv_out_blocks):
    kv0 = ob * KV_OUT_CHUNK

    with pl.incore():
        tile_a = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, 0])
        tile_b = pl.slice(wkv_a, [K_CHUNK, KV_OUT_CHUNK], [0, kv0])
        kv_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)

        for kb in pl.range(1, hidden_blocks):
            k0 = kb * K_CHUNK
            tile_a_i = pl.slice(normed_tile, [BATCH_TILE, K_CHUNK], [0, k0])
            tile_b_i = pl.slice(wkv_a, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
            kv_acc = pl.matmul_acc(kv_acc, tile_a_i, tile_b_i)

    kv_proj = pl.assemble(kv_proj, kv_acc, [b0, kv0])

return qr_proj, q_proj, kv_proj
```
Critical: q_proj is never computed — wq_b is unused and Q expansion stage is missing.
This is the root cause of the precision error mentioned in the PR title. The function:
- Declares wq_b as input (line 63) but never uses it
- Returns q_proj (line 138) without ever computing or assembling it

The golden function correctly computes q_proj = qr_proj @ wq_b (line 220), but the PyPTO program is missing this "Stage 4: Q expansion" entirely.
🐛 Proposed fix — add Stage 4: Q expansion after KV compression

```diff
  kv_proj = pl.assemble(kv_proj, kv_acc, [b0, kv0])

+ # Stage 4: Q expansion (qr_proj @ wq_b -> q_proj).
+ for ob in pl.range(q_out_blocks):
+     q0 = ob * Q_OUT_CHUNK
+
+     with pl.incore():
+         tile_a = pl.slice(qr_proj, [BATCH_TILE, LORA_CHUNK], [b0, 0])
+         tile_b = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [0, q0])
+         q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
+
+         for lb in pl.range(1, lora_blocks):
+             l0 = lb * LORA_CHUNK
+             tile_a_i = pl.slice(qr_proj, [BATCH_TILE, LORA_CHUNK], [b0, l0])
+             tile_b_i = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [l0, q0])
+             q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)
+
+     q_proj = pl.assemble(q_proj, q_acc, [b0, q0])
+
  return qr_proj, q_proj, kv_proj
```
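The tiled start-then-accumulate pattern that the proposed Stage 4 reuses can be sketched in plain Python (a toy stand-in for pl.matmul / pl.matmul_acc, not the real API): chunking the reduction dimension and accumulating partial products gives the same answer as the untiled matmul.

```python
def tiled_matmul(a, b, k_chunk):
    """Matmul with the K dimension split into chunks: the first chunk
    initializes the accumulator (like pl.matmul), later chunks add into
    it (like pl.matmul_acc)."""
    m, k, n = len(a), len(a[0]), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, k_chunk):
        for i in range(m):
            for j in range(n):
                acc[i][j] += sum(a[i][k0 + t] * b[k0 + t][j] for t in range(k_chunk))
    return acc

a = [[1.0, 2.0, 3.0, 4.0]]          # 1x4
b = [[1.0], [1.0], [1.0], [1.0]]    # 4x1
full = tiled_matmul(a, b, 4)        # single chunk, i.e. untiled
tiled = tiled_matmul(a, b, 2)       # two K-chunks, accumulated
```

Whatever chunk size is chosen, the accumulated result matches the single-shot product, which is why the missing Stage 4 can safely follow the same pattern as the QR/KV stages.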
```python
qr_proj = torch.zeros(batch, Q_LORA_RANK, dtype=torch.float32)
q_proj = torch.zeros(batch, NUM_HEADS * QK_HEAD_DIM, dtype=torch.float32)
kv_proj = torch.zeros(batch, KV_A_OUT, dtype=torch.float32)
```
Golden function uses module-level constants instead of function parameters.
Lines 201-203 use hardcoded Q_LORA_RANK, NUM_HEADS * QK_HEAD_DIM, and KV_A_OUT instead of deriving from params or the tensor shapes. This will produce incorrect results if compile_and_run is called with non-default values.
🔧 Proposed fix — derive dimensions from params or tensors

```diff
- qr_proj = torch.zeros(batch, Q_LORA_RANK, dtype=torch.float32)
- q_proj = torch.zeros(batch, NUM_HEADS * QK_HEAD_DIM, dtype=torch.float32)
- kv_proj = torch.zeros(batch, KV_A_OUT, dtype=torch.float32)
+ q_lora_rank = wq_a.shape[1]
+ q_out_dim = wq_b.shape[1]
+ kv_a_out_dim = wkv_a.shape[1]
+
+ qr_proj = torch.zeros(batch, q_lora_rank, dtype=torch.float32)
+ q_proj = torch.zeros(batch, q_out_dim, dtype=torch.float32)
+ kv_proj = torch.zeros(batch, kv_a_out_dim, dtype=torch.float32)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
q_lora_rank = wq_a.shape[1]
q_out_dim = wq_b.shape[1]
kv_a_out_dim = wkv_a.shape[1]

qr_proj = torch.zeros(batch, q_lora_rank, dtype=torch.float32)
q_proj = torch.zeros(batch, q_out_dim, dtype=torch.float32)
kv_proj = torch.zeros(batch, kv_a_out_dim, dtype=torch.float32)
```
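The shape-derivation idea behind this fix can be illustrated with a small pure-Python sketch (hypothetical helper name and nested-list "tensors"; the real code would read torch `.shape` attributes): every output shape falls out of a weight matrix's column count rather than a module-level constant.

```python
def output_shapes(batch, wq_a, wq_b, wkv_a):
    """Derive projection output shapes from the weights themselves."""
    q_lora_rank = len(wq_a[0])   # wq_a: [hidden, q_lora_rank]
    q_out_dim = len(wq_b[0])     # wq_b: [q_lora_rank, num_heads * qk_head_dim]
    kv_a_out = len(wkv_a[0])     # wkv_a: [hidden, kv_a_out]
    return (batch, q_lora_rank), (batch, q_out_dim), (batch, kv_a_out)

# Toy weights with non-default sizes: hidden=8, rank=3, q_out=6, kv_out=5.
wq_a = [[0.0] * 3 for _ in range(8)]
wq_b = [[0.0] * 6 for _ in range(3)]
wkv_a = [[0.0] * 5 for _ in range(8)]
shapes = output_shapes(4, wq_a, wq_b, wkv_a)
```

Because nothing is hardcoded, the golden function would keep matching the PyPTO program even when compile_and_run is invoked with non-default dimensions.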
Code Review
This pull request introduces the implementation of DeepSeek V3.2 decode scope 1, covering input RMSNorm and Q/K/V projections. The review feedback identifies critical logic omissions: the Q expansion stage is entirely missing, meaning the q_proj output is never calculated, and the RMSNorm implementation in both the main program and the golden reference function lacks the necessary reciprocal square root (rsqrt) operation. Suggestions have been provided to add the missing computation stage and correct the normalization formula to ensure numerical accuracy.
```python
kv_proj = pl.assemble(kv_proj, kv_acc, [b0, kv0])

return qr_proj, q_proj, kv_proj
```
In the deepseek_v3_2_decode_scope1 function, q_proj is passed in as an output parameter and returned, but the code never computes it. This explains the q_proj precision error mentioned in the PR title: the tensor is never updated and keeps its initial value. You need to add a Q expansion stage that computes q_proj from qr_proj and wq_b; the suggested place is a Stage 4 after Stage 3. Also note that, per project convention, 'pypto.runtime' (pl) should be lazily imported inside the function to reduce overhead.
```diff
- return qr_proj, q_proj, kv_proj
+ # Stage 4: Q expansion (matmul + matmul_acc in single incore).
+ for ob in pl.range(q_out_blocks):
+     q0 = ob * Q_OUT_CHUNK
+     with pl.incore():
+         tile_a = pl.cast(pl.slice(qr_proj, [BATCH_TILE, LORA_CHUNK], [b0, 0]), pl.BF16)
+         tile_b = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [0, q0])
+         q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
+         for lb in pl.range(1, lora_blocks):
+             l0 = lb * LORA_CHUNK
+             tile_a_i = pl.cast(pl.slice(qr_proj, [BATCH_TILE, LORA_CHUNK], [b0, l0]), pl.BF16)
+             tile_b_i = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [l0, q0])
+             q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)
+     q_proj = pl.assemble(q_proj, q_acc, [b0, q0])
+
+ return qr_proj, q_proj, kv_proj
```
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
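The lazy-import convention cited above can be sketched as follows (math stands in for heavy dependencies such as torch or pypto.runtime; the function name is hypothetical):

```python
def scope1_entry():
    # Import inside the function body, so modules that only use the
    # builder functions never pay this import cost at module load time.
    import math
    return math.sqrt(2.0)
```

Module import stays cheap; the dependency is loaded only on the first call.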
```python
variance = pl.reshape(
    pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
    [BATCH_TILE, 1],
)
```
The RMSNorm implementation is missing the reciprocal square root (rsqrt) step. The standard RMSNorm formula scales by rsqrt(variance), so apply pl.rsqrt when computing variance. Also note that, per project convention, 'pypto.runtime' (pl) should be lazily imported inside the function to reduce overhead.
```diff
  variance = pl.reshape(
-     pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
+     pl.rsqrt(pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS)),
      [BATCH_TILE, 1],
  )
```
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
```python
for k0 in range(0, hidden_size, K_CHUNK):
    x_chunk = x_tile[:, k0:k0 + K_CHUNK]
    sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True)
variance = sq_sum / hidden_size + EPS
```
The RMSNorm logic in the golden reference function needs the same fix: add torch.rsqrt to match the standard definition. Also note that, per project convention, 'torch' should be lazily imported inside the function to reduce overhead.
```diff
- variance = sq_sum / hidden_size + EPS
+ variance = torch.rsqrt(sq_sum / hidden_size + EPS)
```
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
No description provided.