Add fused QK RoPE Qwen3 decode example#112

Open
high-cloud wants to merge 1 commit into hw-native-sys:main from high-cloud:feat/qwen3-fused-qkrope-decode

Conversation

@high-cloud (Contributor)

Summary

  • Adds qwen3_32b_decode_fused_qkrope.py, a Qwen3-32B decode variant whose Q projection is fused with Q RoPE and Q padding.
  • Compared with qwen3_32b_decode.py, moves Q RoPE, K RoPE/cache writes, and V cache writes into the Scope 1 projection tiling, so Scope 2 starts directly from grouped-query attention.
  • Compared with qwen3_32b_decode.py, avoids the program-scope q_proj, k_proj, and v_proj tensors before RoPE, and retunes the Scope 1 chunks from 512/64 to 256/128 for hidden/KV projection tiling.
  • Validation note: python examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py --platform a5sim compiles and runs, but the golden comparison currently fails for out, with 83472/131072 mismatched elements at rtol=0.003, atol=0.003.

Related Issues

None


coderabbitai Bot commented Apr 15, 2026

📝 Walkthrough

A single new file adds a complete Qwen3-32B single-layer decode implementation in PyPTO, featuring RMSNorm, a fused Q projection/RoPE stage, grouped-query attention with online-softmax stabilization, KV caching, and a SwiGLU MLP, together with a PyTorch reference and execution utilities.

Changes

  • Qwen3-32B Decode Example — examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py:
    Comprehensive single-layer decode implementation with three phases: (1) RMSNorm + fused Q projection with RoPE and K/V cache updates, (2) grouped-query attention with FP32 tile QK scoring, max-stabilized masked softmax, SV matmul, and online softmax accumulation, (3) output projection with residual, post-RMSNorm, and SwiGLU MLP. Includes a tensor spec builder, a PyTorch golden reference, and compile/run utilities with platform-aware backend selection.
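The fused Q-RoPE step in phase 1 applies a half-split rotary embedding: the low and high halves of each head vector are rotated against per-position cos/sin tables before the result is written out. As a sanity reference, here is a minimal NumPy sketch of that rotation — the function name and flat 1-D layout are illustrative, not taken from the example file:

```python
import numpy as np

def rope_rotate_half(q: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """Half-split RoPE for one head vector at one position (illustrative).

    q, cos, sin are all (head_dim,); cos/sin hold the per-dimension
    rotation tables for this position.
    """
    half = q.shape[-1] // 2
    q_lo, q_hi = q[:half], q[half:]
    # Low half:  q_lo * cos_lo - q_hi * sin_lo
    rot_lo = q_lo * cos[:half] - q_hi * sin[:half]
    # High half: q_hi * cos_hi + q_lo * sin_hi
    rot_hi = q_hi * cos[half:] + q_lo * sin[half:]
    return np.concatenate([rot_lo, rot_hi])
```

With cos = 1 and sin = 0 the rotation is the identity, and with a uniform angle it preserves the vector norm — two quick checks for spotting indexing mistakes in the lo/hi split.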

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input Tensors
    participant Phase1 as Phase 1: Projection
    participant Phase2 as Phase 2: Attention
    participant Phase3 as Phase 3: Output
    participant Output as Output

    Input->>Phase1: hidden_states, params
    Phase1->>Phase1: RMSNorm(hidden_states)
    Phase1->>Phase1: Q projection + RoPE + padding
    Phase1->>Phase1: K/V projection + RoPE
    Phase1->>Phase1: Write K/V to cache
    Phase1->>Phase2: Q, K, V, cached_K, cached_V

    Phase2->>Phase2: Compute QK scores (FP32 tiles)
    Phase2->>Phase2: Apply scaled masked softmax<br/>(max-subtraction, BF16 exp)
    Phase2->>Phase2: SV matmul with online<br/>softmax accumulation
    Phase2->>Phase3: attn_out

    Phase3->>Phase3: Output projection
    Phase3->>Phase3: Residual addition
    Phase3->>Phase3: Post-RMSNorm
    Phase3->>Phase3: SwiGLU MLP<br/>(gate, up, SiLU, down)
    Phase3->>Phase3: Final residual writeout
    Phase3->>Output: out
```
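Phase 2's online softmax accumulation processes the KV cache block by block, carrying a running max and denominator so earlier partial results can be rescaled whenever a larger score appears. A hedged NumPy sketch of the idea for a single query row (names are illustrative; the example file's tiling and BF16 exp are omitted):

```python
import numpy as np

def online_softmax_attention(score_blocks, value_blocks):
    """softmax(scores) @ V computed one KV block at a time (illustrative).

    score_blocks: list of 1-D score arrays for one query row
    value_blocks: list of matching (block_len, head_dim) value arrays
    """
    d = value_blocks[0].shape[1]
    m = -np.inf           # running row max
    denom = 0.0           # running softmax denominator
    acc = np.zeros(d)     # running (unnormalized) output
    for s, v in zip(score_blocks, value_blocks):
        m_new = max(m, float(s.max()))
        correction = np.exp(m - m_new)   # rescale older contributions
        p = np.exp(s - m_new)            # unnormalized block probabilities
        denom = denom * correction + p.sum()
        acc = acc * correction + p @ v
        m = m_new
    return acc / denom
```

The blocked result matches a full softmax over the concatenated scores; that equivalence is exactly the invariant the max-subtraction and rescaling maintain.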

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Poem

🐰 A decode hops forth with RoPE so grand,
Projections fused across the softmax land,
With Q and K in FP32's embrace,
KV cache gleams—swift attention's grace!
Through MLP's gate, the token finds its way,
PyPTO's magic brightens decoding day! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 5.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title 'Add fused QK RoPE Qwen3 decode example' accurately describes the main change: adding a new Qwen3 decode example with fused Q projection and RoPE optimization.
  • Description check — ✅ Passed: the description is comprehensive and directly related to the changeset, detailing the new example file, the fused Q RoPE optimization, comparisons with existing code, and validation results.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements the Qwen3-32B single-layer decode forward using the PyPTO language, structured into three scopes for projections, attention, and MLP. The review identifies a critical correctness issue where local tensors are not persistent across separate pl.at blocks, requiring block fusion. Further feedback suggests performance optimizations, including using pl.parallel to minimize kernel launches, batching DMA transfers for pl.assemble operations, and reducing redundant global memory reads of sequence lengths.

Comment on lines +155 to +187
```python
with pl.at(level=pl.Level.CORE_GROUP):
    tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
    tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [0, q0])
    q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
    for kb in pl.range(1, scope1_hidden_blocks):
        k0 = kb * SCOPE1_K_CHUNK
        tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
        tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [k0, q0])
        q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)

with pl.at(level=pl.Level.CORE_GROUP):
    gi = qh // Q_HEAD_BATCH
    qi = qh - gi * Q_HEAD_BATCH
    for bi in pl.range(BATCH_TILE):
        b = b0 + bi
        ctx_len_q = pl.tensor.read(seq_lens, [b])
        pos_q = ctx_len_q - 1
        cos_lo_q = pl.slice(rope_cos, [1, half_dim], [pos_q, 0])
        cos_hi_q = pl.slice(rope_cos, [1, half_dim], [pos_q, half_dim])
        sin_lo_q = pl.slice(rope_sin, [1, half_dim], [pos_q, 0])
        sin_hi_q = pl.slice(rope_sin, [1, half_dim], [pos_q, half_dim])
        q_lo = pl.slice(q_acc, [1, half_dim], [bi, 0])
        q_hi = pl.slice(q_acc, [1, half_dim], [bi, half_dim])
        rot_lo_bf16 = pl.cast(
            pl.sub(pl.col_expand_mul(q_lo, cos_lo_q), pl.col_expand_mul(q_hi, sin_lo_q)),
            target_type=pl.BF16,
        )
        rot_hi_bf16 = pl.cast(
            pl.add(pl.col_expand_mul(q_hi, cos_hi_q), pl.col_expand_mul(q_lo, sin_hi_q)),
            target_type=pl.BF16,
        )
        all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, 0])
        all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, half_dim])
```


critical

The local tensor q_acc is defined in one pl.at block and accessed in another. In PyPTO, local tensors are not persistent across separate pl.at blocks (which map to separate kernel launches). This results in the second block reading garbage data from local memory, causing the reported validation failure. These blocks must be fused into a single pl.at block to ensure data persistence. Similar fusion is required for the K and V projection blocks.

```python
with pl.at(level=pl.Level.CORE_GROUP):
    tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
    tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [0, q0])
    q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
    for kb in pl.range(1, scope1_hidden_blocks):
        k0 = kb * SCOPE1_K_CHUNK
        tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
        tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [k0, q0])
        q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)

    gi = qh // Q_HEAD_BATCH
    qi = qh - gi * Q_HEAD_BATCH
    for bi in pl.range(BATCH_TILE):
        b = b0 + bi
        ctx_len_q = pl.tensor.read(seq_lens, [b])
        pos_q = ctx_len_q - 1
        cos_lo_q = pl.slice(rope_cos, [1, half_dim], [pos_q, 0])
        cos_hi_q = pl.slice(rope_cos, [1, half_dim], [pos_q, half_dim])
        sin_lo_q = pl.slice(rope_sin, [1, half_dim], [pos_q, 0])
        sin_hi_q = pl.slice(rope_sin, [1, half_dim], [pos_q, half_dim])
        q_lo = pl.slice(q_acc, [1, half_dim], [bi, 0])
        q_hi = pl.slice(q_acc, [1, half_dim], [bi, half_dim])
        rot_lo_bf16 = pl.cast(
            pl.sub(pl.col_expand_mul(q_lo, cos_lo_q), pl.col_expand_mul(q_hi, sin_lo_q)),
            target_type=pl.BF16,
        )
        rot_hi_bf16 = pl.cast(
            pl.add(pl.col_expand_mul(q_hi, cos_hi_q), pl.col_expand_mul(q_lo, sin_hi_q)),
            target_type=pl.BF16,
        )
        all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, 0])
        all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, half_dim])
```

```python
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])

for qh in pl.range(num_heads):
```

high

Iterating over heads with pl.range(num_heads) at the program level causes 64 separate kernel launches per batch tile. Use pl.parallel over the head dimension within a single pl.at block to allow the compiler to generate a single optimized kernel. This applies to K and V projection loops as well (lines 189 and 225).
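As a rough illustration of this suggestion — assuming `pl.parallel` accepts the same bounds as `pl.range` and can wrap the per-head body shown in the diff (neither is confirmed here) — the head loop could move inside a single `pl.at` block:

```python
# Sketch only: pl.parallel usage is assumed from the reviewer's suggestion.
with pl.at(level=pl.Level.CORE_GROUP):
    for qh in pl.parallel(num_heads):       # one fused kernel over all heads
        q0 = qh * head_dim                  # assumed per-head column offset
        tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
        tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [0, q0])
        q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
        for kb in pl.range(1, scope1_hidden_blocks):
            k0 = kb * SCOPE1_K_CHUNK
            q_acc = pl.matmul_acc(
                q_acc,
                pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]),
                pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [k0, q0]),
            )
        # ... per-head RoPE and assemble, unchanged from the fused block ...
```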

Comment on lines +186 to +187
```python
all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, 0])
all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, half_dim])
```

high

Calling pl.assemble for every head and batch item triggers a high volume of small DMA transfers. Accumulate results into a local tile and perform larger, batched assembles to global memory to improve performance. This also applies to cache updates at lines 222 and 244.
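One hedged way to at least halve the DMA count for the Q write-out, staying within the `pl` calls already used in this diff: stitch the rotated halves together in local memory first, then issue a single assemble per (batch, head) row instead of two. Whether assembles into a locally created tensor avoid global DMA is an assumption here, not something the diff confirms.

```python
# Sketch only: assumes assembles into a locally created tensor stay in
# local memory until the final assemble into the global all_q_padded.
rot_bf16 = pl.create_tensor([1, head_dim], dtype=pl.BF16)
rot_bf16 = pl.assemble(rot_bf16, rot_lo_bf16, [0, 0])         # local stitch
rot_bf16 = pl.assemble(rot_bf16, rot_hi_bf16, [0, half_dim])  # local stitch
row = b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi
all_q_padded = pl.assemble(all_q_padded, rot_bf16, [row, 0])  # one DMA, not two
```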

```python
qi = qh - gi * Q_HEAD_BATCH
for bi in pl.range(BATCH_TILE):
    b = b0 + bi
    ctx_len_q = pl.tensor.read(seq_lens, [b])
```

medium

Redundant global memory reads of seq_lens inside nested loops over heads and batch items will degrade performance. Read the sequence lengths for the entire batch tile into a local tensor once before entering the head loop.
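A minimal sketch of the hoist, assuming `pl.slice` and `pl.create_tensor` work on 1-D tensors and that `seq_lens` is INT32 (neither is confirmed by the diff):

```python
# Read the whole batch tile's lengths once, before the head loop.
seq_lens_tile = pl.create_tensor([BATCH_TILE], dtype=pl.INT32)   # dtype assumed
seq_lens_tile = pl.assemble(
    seq_lens_tile, pl.slice(seq_lens, [BATCH_TILE], [b0]), [0]
)
for qh in pl.range(num_heads):
    for bi in pl.range(BATCH_TILE):
        pos_q = pl.tensor.read(seq_lens_tile, [bi]) - 1  # local read, no global DMA
        # ... RoPE slicing as before, using pos_q ...
```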

Comment on lines +225 to +235
```python
for ob in pl.range(kv_out_blocks):
    kv0 = ob * KV_OUT_CHUNK
    with pl.at(level=pl.Level.CORE_GROUP):
        tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
        tile_wv = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0])
        v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32)
        for kb in pl.range(1, scope1_hidden_blocks):
            k0 = kb * SCOPE1_K_CHUNK
            tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
            tile_wv_i = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
            v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)
```

medium

The K and V projection loops are separate but iterate over the same blocks and re-slice the same input data. Fusing these loops would improve data reuse and reduce overhead.


@coderabbitai (Bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py (1)

220-223: Hidden assumption: KV_OUT_CHUNK == head_dim.

The computation kvh = kv0 // head_dim where kv0 = ob * KV_OUT_CHUNK assumes KV_OUT_CHUNK == head_dim (both are 128). If these constants diverge in future refactoring, the head index calculation and cache writes will break silently.

Consider adding an assertion or comment to document this constraint.

```diff
 KV_OUT_CHUNK = 128
+
+# Note: KV_OUT_CHUNK must equal HEAD_DIM for correct cache indexing in K/V projection loops.
+assert KV_OUT_CHUNK == HEAD_DIM, "KV_OUT_CHUNK must equal HEAD_DIM"
```

📥 Commits

Reviewing files that changed from the base of the PR and between 773aab2 and b17373e.

📒 Files selected for processing (1)
  • examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py

```text
4. Online-softmax accumulation + final normalisation

Scope 3:
1. Output projection: attn_out × wo
```

⚠️ Potential issue | 🟡 Minor

Replace ambiguous × with ASCII x.

The multiplication sign × (U+00D7) could cause confusion or search issues. Use the ASCII letter x instead for consistency.

```diff
-  1. Output projection: attn_out × wo
+  1. Output projection: attn_out x wo
```
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 24-24: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


```python
for b0 in pl.range(0, batch, BATCH_TILE):
    resid1_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.FP32)

    # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
```

⚠️ Potential issue | 🟡 Minor

Replace ambiguous × with ASCII x.

Same issue as in the docstring.

```diff
-                # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
+                # Stage 1: Output projection: attn_out x wo, tiled by Q_OUT_CHUNK.
```
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 350-350: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)

