
Fix DeepSeek scope1 cache prep outputs #134

Merged
zhangqi-chen merged 1 commit into hw-native-sys:main from high-cloud:fix/deepseek-scope1-cache-prep
Apr 21, 2026

Conversation

@high-cloud
Contributor

Summary

  • Add DeepSeek decode front cache preparation to split q_nope/q_pe outputs and rotate q_pe/k_pe.
  • Align kv_norm with ds32 RMSNorm semantics before writing kv_cache.
  • Update comments to describe projection and cache preparation stages accurately.
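As background for the second bullet, standard RMSNorm semantics can be sketched in NumPy as below. This is a minimal illustration only; the exact ds32 variant, epsilon value, and dtype handling in the kernel are not taken from this PR and are assumptions here.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: divide each row by its root-mean-square, then scale
    # elementwise by a learned weight (no mean subtraction, unlike LayerNorm).
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight
```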

Related Issues

None

@coderabbitai

coderabbitai bot commented Apr 20, 2026

📝 Walkthrough

Walkthrough

The DeepSeek V3.2 decode front-scope program is updated with new inputs (sequence lengths, RoPE parameters, KV normalization weight) and restructured outputs. The control flow now performs RMS normalization on KV latents, splits Q heads into no-PE and RoPE-applied components, and writes normalized KV and rotated PE to token-indexed caches.
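The "RoPE-applied components" mentioned above follow the split-half rotation pattern visible in the diff: the lo half becomes x_lo·cos − x_hi·sin and the hi half becomes x_hi·cos + x_lo·sin. A NumPy sketch of that pattern (function and table names are illustrative, not the kernel's API):

```python
import numpy as np

def rope_split_half(x, cos, sin):
    # Rotate each pair (x_lo[i], x_hi[i]) by the angle encoded in (cos, sin),
    # with the two halves stored contiguously rather than interleaved.
    half = x.shape[-1] // 2
    x_lo, x_hi = x[..., :half], x[..., half:]
    out_lo = x_lo * cos[:half] - x_hi * sin[:half]
    out_hi = x_hi * cos[half:] + x_lo * sin[half:]
    return np.concatenate([out_lo, out_hi], axis=-1)
```

Since each (lo, hi) pair is a 2-D rotation, the per-pair norm is preserved, which is the usual sanity check for a RoPE implementation.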

Changes

Cohort / File(s) Summary
DeepSeek V3.2 Decode Front Scope
examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
Added inputs for sequence lengths, RoPE parameters, and KV normalization weight; restructured outputs by splitting Q projection into q_nope_out and q_pe_out, and adding kv_cache and pe_cache outputs; modified core logic to perform KV RMS normalization, apply RoPE to Q's PE component, and write cached tensors indexed by token position; updated golden reference implementation and tensor specifications; added max_seq_len parameter to build function.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Q's now split in halves so neat,
RoPE and norm in sync compete,
Cache hops forward, seq_lens align,
This decode scope is quite divine!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately describes the main change: fixing cache prep outputs by splitting q_nope/q_pe and rotating components for the DeepSeek scope1 module.
  • Description check (✅ Passed): The description clearly outlines the three main objectives of the PR: adding cache preparation, aligning normalization semantics, and updating comments.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request expands the DeepSeek V3.2-EXP decode front path by implementing Q/K RoPE application, KV normalization, and cache updates. Review feedback identifies opportunities for performance optimization, including the removal of redundant writes to a temporary tensor and the fusion of two parallel loops to reduce overhead from repeated tensor reads and slicing.

Comment on lines +307 to +316
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_lo, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG],
)
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_hi, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
)

medium

The pl.assemble calls updating q_proj_out with rotated values are redundant. q_proj_out is a local temporary tensor created at line 125 and is not used after this loop (the function returns q_pe_out). Removing these unnecessary Global Memory writes will improve performance.

                        q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
                        q_pe_out = pl.assemble(
                            q_pe_out,
                            pl.cast(q_rot_hi, target_type=pl.BF16),
                            [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
                        )

Comment on lines +268 to +363
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
    for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
        ctx_len = pl.tensor.read(seq_lens, [b])
        pos = ctx_len - 1
        cache_row = b * MAX_SEQ_CFG + pos

        cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
        cos_hi = pl.slice(
            rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
        )
        sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
        sin_hi = pl.slice(
            rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
        )

        for h in pl.range(NUM_HEADS_CFG):
            q_col = h * QK_HEAD_DIM_CFG
            q_nope_col = h * QK_NOPE_HEAD_DIM_CFG
            q_pe_col = h * QK_ROPE_HEAD_DIM_CFG
            q_nope = pl.slice(q_proj_out, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col])
            q_nope_out = pl.assemble(q_nope_out, q_nope, [b, q_nope_col])
            q_lo = pl.cast(
                pl.slice(
                    q_proj_out,
                    [1, QK_ROPE_HEAD_DIM_CFG // 2],
                    [b, q_col + QK_NOPE_HEAD_DIM_CFG],
                ),
                target_type=pl.FP32,
            )
            q_hi = pl.cast(
                pl.slice(
                    q_proj_out,
                    [1, QK_ROPE_HEAD_DIM_CFG // 2],
                    [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
                ),
                target_type=pl.FP32,
            )
            q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
            q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
            q_proj_out = pl.assemble(
                q_proj_out,
                pl.cast(q_rot_lo, target_type=pl.BF16),
                [b, q_col + QK_NOPE_HEAD_DIM_CFG],
            )
            q_proj_out = pl.assemble(
                q_proj_out,
                pl.cast(q_rot_hi, target_type=pl.BF16),
                [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
            )
            q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
            q_pe_out = pl.assemble(
                q_pe_out,
                pl.cast(q_rot_hi, target_type=pl.BF16),
                [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
            )

with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
    for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
        ctx_len = pl.tensor.read(seq_lens, [b])
        pos = ctx_len - 1
        cache_row = b * MAX_SEQ_CFG + pos

        cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
        cos_hi = pl.slice(
            rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
        )
        sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
        sin_hi = pl.slice(
            rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
        )
        kv_normed_row = pl.slice(kv_normed_out, [1, KV_LORA_RANK_CFG], [b, 0])

        pe_lo = pl.cast(
            pl.slice(kv_a_out, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, KV_LORA_RANK_CFG]),
            target_type=pl.FP32,
        )
        pe_hi = pl.cast(
            pl.slice(
                kv_a_out,
                [1, QK_ROPE_HEAD_DIM_CFG // 2],
                [b, KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
            ),
            target_type=pl.FP32,
        )

        pe_rot_lo = pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo))
        pe_rot_hi = pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi))

        kv_cache = pl.assemble(kv_cache, kv_normed_row, [cache_row, 0])
        pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot_lo, target_type=pl.BF16), [cache_row, 0])
        pe_cache = pl.assemble(
            pe_cache,
            pl.cast(pe_rot_hi, target_type=pl.BF16),
            [cache_row, QK_ROPE_HEAD_DIM_CFG // 2],
        )


medium

The two parallel loops over BATCH_CFG (lines 269 and 325) can be fused into a single loop. This optimization avoids redundant reads of seq_lens and redundant slicing of rope_cos/rope_sin tensors, reducing overhead and improving cache locality.
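As a rough illustration of the suggested fusion, here is plain Python/NumPy standing in for the pl DSL (function and variable names are hypothetical, not the kernel's): one batch loop reads seq_lens and slices the rope tables once per row, then feeds both the Q-PE rotation and the cache write.

```python
import numpy as np

def rotate(x, cos, sin):
    # Split-half RoPE rotation on one row.
    half = x.size // 2
    lo, hi = x[:half], x[half:]
    return np.concatenate([lo * cos[:half] - hi * sin[:half],
                           hi * cos[half:] + lo * sin[half:]])

def fused_batch_pass(q_pe, k_pe, seq_lens, cos_tab, sin_tab, max_seq_len):
    # One loop over the batch: pos, cos, and sin are computed once per b and
    # shared by the Q-PE rotation and the pe_cache write, instead of being
    # recomputed in two separate parallel loops.
    batch = q_pe.shape[0]
    q_out = np.empty_like(q_pe)
    pe_cache = np.zeros((batch * max_seq_len, k_pe.shape[-1]), dtype=k_pe.dtype)
    for b in range(batch):
        pos = int(seq_lens[b]) - 1               # seq_lens read once
        cos, sin = cos_tab[pos], sin_tab[pos]    # rope tables sliced once
        q_out[b] = rotate(q_pe[b], cos, sin)
        pe_cache[b * max_seq_len + pos] = rotate(k_pe[b], cos, sin)
    return q_out, pe_cache
```

Whether the real chunked-loop optimizer benefits from this shape is a scheduling question for the kernel, not something this sketch can establish.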


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py (1)

264-362: Consider merging the two per-batch parallel loops.

Both with pl.at(..., optimization=pl.chunked_loop_optimizer) blocks iterate pl.parallel(0, BATCH_CFG, 1, chunk=4) and independently re-read seq_lens[b], recompute pos/cache_row, and re-slice the same four rope_cos/rope_sin halves. Fusing them into a single loop would remove the duplicated reads and let the PE rotation share cos_lo/cos_hi/sin_lo/sin_hi with the Q-PE rotation. It would also let you drop the misplaced "Cache preparation: write RMS-normalized KV latent and rotated k_pe" portion of the comment at 264-267 (which currently describes the second loop but sits above the first).

Leaving this as optional in case the split is intentional for scheduling/tiling reasons — please verify whether the chunked-loop optimizer or downstream compile passes prefer them separate before applying.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py` around
lines 264 - 362, The two chunked-loop optimizer blocks both iterate pl.parallel
over BATCH_CFG and duplicate reading seq_lens/pos/cache_row and slicing
rope_cos/rope_sin; merge them into a single with
pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer) loop
that contains both the Q split/RoPE operations (q_nope/q_pe, q_rot_lo/q_rot_hi,
q_proj_out/q_pe_out assembly) and the KV cache preparation (kv_normed_row,
pe_lo/pe_hi, pe_rot_lo/pe_rot_hi, kv_cache/pe_cache assembly), reuse the
computed cos_lo/cos_hi/sin_lo/sin_hi and pos/cache_row to avoid redundant work,
and remove or relocate the misleading cache-preparation comment currently
sitting above the first block; keep the same variable names (q_proj_out,
q_pe_out, q_nope_out, kv_normed_out, kv_a_out, kv_cache, pe_cache, rope_cos,
rope_sin, seq_lens) so schedule semantics remain clear and run tests to ensure
chunked_loop_optimizer behavior is unchanged.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 39726079-cc1a-47c0-bc9d-3068464bbbfc

📥 Commits

Reviewing files that changed from the base of the PR and between 29950cd and c9642e4.

📒 Files selected for processing (1)
  • examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py

Comment on lines +307 to +316
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_lo, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG],
)
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_hi, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
)

⚠️ Potential issue | 🟡 Minor

Dead writes: rotated q_pe values stored back into local q_proj_out are never read.

q_proj_out is a local tensor (line 125) and is not consumed after this loop — return is q_pe_out, and the second parallel loop (lines 324-362) doesn't read from q_proj_out. The in-place RoPE write-back into the PE slice of q_proj_out is therefore wasted work inside a tight per-head loop. The comment at 264-265 about "keeping the internal full Q layout available for the in-place RoPE writes" no longer reflects the actual data flow — q_pe_out is assembled directly from q_rot_lo/q_rot_hi two lines below.

🧹 Proposed cleanup
                         q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
                         q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
-                        q_proj_out = pl.assemble(
-                            q_proj_out,
-                            pl.cast(q_rot_lo, target_type=pl.BF16),
-                            [b, q_col + QK_NOPE_HEAD_DIM_CFG],
-                        )
-                        q_proj_out = pl.assemble(
-                            q_proj_out,
-                            pl.cast(q_rot_hi, target_type=pl.BF16),
-                            [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
-                        )
                         q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
                         q_pe_out = pl.assemble(
                             q_pe_out,
                             pl.cast(q_rot_hi, target_type=pl.BF16),
                             [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
                         )

And refresh the comment block around 264-265 accordingly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_lo, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG],
)
q_proj_out = pl.assemble(
    q_proj_out,
    pl.cast(q_rot_hi, target_type=pl.BF16),
    [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py` around
lines 307 - 316, The writes that assemble q_rot_lo/q_rot_hi back into the local
q_proj_out buffer (via q_proj_out = pl.assemble(...)) are dead because
q_proj_out is never read later; only q_pe_out is returned and used — remove the
two in-place RoPE assemble calls that write to q_proj_out (the ones using
q_rot_lo and q_rot_hi with QK_NOPE_HEAD_DIM_CFG and QK_ROPE_HEAD_DIM_CFG
offsets) and clean up the surrounding comment near the earlier "keeping the
internal full Q layout" notes to reflect that we now directly assemble q_pe_out
from q_rot_lo/q_rot_hi instead of maintaining a full Q layout in q_proj_out.
Ensure references to q_rot_lo, q_rot_hi, q_proj_out, q_pe_out,
QK_NOPE_HEAD_DIM_CFG and QK_ROPE_HEAD_DIM_CFG are updated or removed in comments
as appropriate.

@zhangqi-chen zhangqi-chen merged commit 5d60e4e into hw-native-sys:main Apr 21, 2026
6 checks passed