Add fused QK RoPE Qwen3 decode example #112
high-cloud wants to merge 1 commit into hw-native-sys:main
Conversation
📝 Walkthrough

A single file adds a complete Qwen3-32B single-layer decode implementation using PyPTO, featuring RMSNorm, fused Q/RoPE projection, grouped-query attention with online softmax stabilization, KV caching, and a SwiGLU MLP, together with a PyTorch reference and execution utilities.
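As a quick reference for one of the components named above, the RMSNorm step can be sketched in plain Python (the eps value and shapes here are illustrative assumptions, not values taken from the PR):

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    # x: one token's hidden vector; gamma: learned per-channel scale.
    # Normalize by the root-mean-square of x, then scale by gamma.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

x = [3.0, -4.0]                  # RMS = sqrt((9 + 16) / 2) = sqrt(12.5)
out = rms_norm(x, [1.0, 1.0])
```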
Sequence Diagram(s)

sequenceDiagram
participant Input as Input Tensors
participant Phase1 as Phase 1: Projection
participant Phase2 as Phase 2: Attention
participant Phase3 as Phase 3: Output
participant Output as Output
Input->>Phase1: hidden_states, params
Phase1->>Phase1: RMSNorm(hidden_states)
Phase1->>Phase1: Q projection + RoPE + padding
Phase1->>Phase1: K/V projection + RoPE
Phase1->>Phase1: Write K/V to cache
Phase1->>Phase2: Q, K, V, cached_K, cached_V
Phase2->>Phase2: Compute QK scores (FP32 tiles)
Phase2->>Phase2: Apply scaled masked softmax<br/>(max-subtraction, BF16 exp)
Phase2->>Phase2: SV matmul with online<br/>softmax accumulation
Phase2->>Phase3: attn_out
Phase3->>Phase3: Output projection
Phase3->>Phase3: Residual addition
Phase3->>Phase3: Post-RMSNorm
Phase3->>Phase3: SwiGLU MLP<br/>(gate, up, SiLU, down)
Phase3->>Phase3: Final residual writeout
Phase3->>Output: out
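The Phase 2 max-subtraction and online accumulation in the diagram follow the standard online-softmax recurrence: keep a running maximum, and rescale the previous normalizer and weighted sum whenever a new maximum is seen. A minimal plain-Python sketch of that recurrence (just the math, not the PyPTO tiling):

```python
import math

def online_softmax_weighted_sum(scores, values):
    # Running max m, running normalizer l, running weighted sum acc.
    # On a new maximum, rescale previous state by exp(old_max - new_max).
    m, l, acc = float("-inf"), 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new) if m != float("-inf") else 0.0
        p = math.exp(s - m_new)
        l = l * scale + p
        acc = acc * scale + p * v
        m = m_new
    return acc / l

scores = [0.5, 2.0, -1.0]
values = [1.0, 2.0, 3.0]
out = online_softmax_weighted_sum(scores, values)
```

The result matches a direct softmax-weighted sum, while never exponentiating anything larger than zero.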
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request implements the Qwen3-32B single-layer decode forward using the PyPTO language, structured into three scopes for projections, attention, and MLP. The review identifies a critical correctness issue where local tensors are not persistent across separate pl.at blocks, requiring block fusion. Further feedback suggests performance optimizations, including using pl.parallel to minimize kernel launches, batching DMA transfers for pl.assemble operations, and reducing redundant global memory reads of sequence lengths.
with pl.at(level=pl.Level.CORE_GROUP):
    tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
    tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [0, q0])
    q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
    for kb in pl.range(1, scope1_hidden_blocks):
        k0 = kb * SCOPE1_K_CHUNK
        tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
        tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [k0, q0])
        q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)
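The snippet above accumulates a matmul over chunks of the shared K dimension. The identity it relies on, splitting the reduction axis and summing partial products, can be checked in plain Python independent of PyPTO:

```python
def matmul(a, b):
    # a: [m, k], b: [k, n] as nested lists of floats.
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def matmul_chunked_k(a, b, chunk):
    # Split the shared k axis into chunks and accumulate partial products,
    # mirroring the pl.matmul / pl.matmul_acc pattern above.
    m, k, n = len(a), len(b), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, chunk):
        a_tile = [row[k0:k0 + chunk] for row in a]
        b_tile = b[k0:k0 + chunk]
        part = matmul(a_tile, b_tile)
        acc = [[acc[i][j] + part[i][j] for j in range(n)] for i in range(m)]
    return acc

a = [[1.0, 2.0, 3.0, 4.0]]
b = [[1.0], [0.5], [2.0], [-1.0]]
full = matmul(a, b)
chunked = matmul_chunked_k(a, b, 2)
```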
with pl.at(level=pl.Level.CORE_GROUP):
    gi = qh // Q_HEAD_BATCH
    qi = qh - gi * Q_HEAD_BATCH
    for bi in pl.range(BATCH_TILE):
        b = b0 + bi
        ctx_len_q = pl.tensor.read(seq_lens, [b])
        pos_q = ctx_len_q - 1
        cos_lo_q = pl.slice(rope_cos, [1, half_dim], [pos_q, 0])
        cos_hi_q = pl.slice(rope_cos, [1, half_dim], [pos_q, half_dim])
        sin_lo_q = pl.slice(rope_sin, [1, half_dim], [pos_q, 0])
        sin_hi_q = pl.slice(rope_sin, [1, half_dim], [pos_q, half_dim])
        q_lo = pl.slice(q_acc, [1, half_dim], [bi, 0])
        q_hi = pl.slice(q_acc, [1, half_dim], [bi, half_dim])
        rot_lo_bf16 = pl.cast(
            pl.sub(pl.col_expand_mul(q_lo, cos_lo_q), pl.col_expand_mul(q_hi, sin_lo_q)),
            target_type=pl.BF16,
        )
        rot_hi_bf16 = pl.cast(
            pl.add(pl.col_expand_mul(q_hi, cos_hi_q), pl.col_expand_mul(q_lo, sin_hi_q)),
            target_type=pl.BF16,
        )
        all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, 0])
        all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, half_dim])
The local tensor q_acc is defined in one pl.at block and accessed in another. In PyPTO, local tensors are not persistent across separate pl.at blocks (which map to separate kernel launches). This results in the second block reading garbage data from local memory, causing the reported validation failure. These blocks must be fused into a single pl.at block to ensure data persistence. Similar fusion is required for the K and V projection blocks.
with pl.at(level=pl.Level.CORE_GROUP):
    tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
    tile_b = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [0, q0])
    q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32)
    for kb in pl.range(1, scope1_hidden_blocks):
        k0 = kb * SCOPE1_K_CHUNK
        tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
        tile_b_i = pl.slice(wq, [SCOPE1_K_CHUNK, head_dim], [k0, q0])
        q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i)
    gi = qh // Q_HEAD_BATCH
    qi = qh - gi * Q_HEAD_BATCH
    for bi in pl.range(BATCH_TILE):
        b = b0 + bi
        ctx_len_q = pl.tensor.read(seq_lens, [b])
        pos_q = ctx_len_q - 1
        cos_lo_q = pl.slice(rope_cos, [1, half_dim], [pos_q, 0])
        cos_hi_q = pl.slice(rope_cos, [1, half_dim], [pos_q, half_dim])
        sin_lo_q = pl.slice(rope_sin, [1, half_dim], [pos_q, 0])
        sin_hi_q = pl.slice(rope_sin, [1, half_dim], [pos_q, half_dim])
        q_lo = pl.slice(q_acc, [1, half_dim], [bi, 0])
        q_hi = pl.slice(q_acc, [1, half_dim], [bi, half_dim])
        rot_lo_bf16 = pl.cast(
            pl.sub(pl.col_expand_mul(q_lo, cos_lo_q), pl.col_expand_mul(q_hi, sin_lo_q)),
            target_type=pl.BF16,
        )
        rot_hi_bf16 = pl.cast(
            pl.add(pl.col_expand_mul(q_hi, cos_hi_q), pl.col_expand_mul(q_lo, sin_hi_q)),
            target_type=pl.BF16,
        )
        all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, 0])
        all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD + qi, half_dim])

normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(normed_tile, pl.cast(normed, target_type=pl.BF16), [0, k0])

for qh in pl.range(num_heads):
Iterating over heads with pl.range(num_heads) at the program level causes 64 separate kernel launches per batch tile. Use pl.parallel over the head dimension within a single pl.at block to allow the compiler to generate a single optimized kernel. This applies to K and V projection loops as well (lines 189 and 225).
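Independent of launch granularity, the rotate-half RoPE applied per row in the projection block above (q_lo·cos − q_hi·sin, q_hi·cos + q_lo·sin) amounts to rotating each (q_lo[i], q_hi[i]) pair by its position-dependent angle. A plain-Python sketch under that reading (illustrative names, not PyPTO API):

```python
import math

def rope_rotate_half(q, angles):
    # q has head_dim entries; pair q[i] (low half) with q[i + half] (high
    # half) and rotate each pair by its per-dimension angle.
    half = len(q) // 2
    lo, hi = q[:half], q[half:]
    cos = [math.cos(a) for a in angles]
    sin = [math.sin(a) for a in angles]
    rot_lo = [l * c - h * s for l, h, c, s in zip(lo, hi, cos, sin)]
    rot_hi = [h * c + l * s for l, h, c, s in zip(lo, hi, cos, sin)]
    return rot_lo + rot_hi

q = [1.0, 0.0, 0.0, 1.0]            # pairs (1, 0) and (0, 1)
out = rope_rotate_half(q, [math.pi / 2, math.pi / 2])
```

A 90-degree rotation sends (1, 0) to (0, 1) and (0, 1) to (−1, 0), and the vector norm is preserved, which is what makes RoPE safe to apply before the QK dot product.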
for ob in pl.range(kv_out_blocks):
    kv0 = ob * KV_OUT_CHUNK
    with pl.at(level=pl.Level.CORE_GROUP):
        tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0])
        tile_wv = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [0, kv0])
        v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32)
        for kb in pl.range(1, scope1_hidden_blocks):
            k0 = kb * SCOPE1_K_CHUNK
            tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0])
            tile_wv_i = pl.slice(wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [k0, kv0])
            v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py (1)
220-223: Hidden assumption: KV_OUT_CHUNK == head_dim.

The computation kvh = kv0 // head_dim, where kv0 = ob * KV_OUT_CHUNK, assumes KV_OUT_CHUNK == head_dim (both are 128). If these constants diverge in future refactoring, the head index calculation and cache writes will break silently. Consider adding an assertion or comment to document this constraint.

 KV_OUT_CHUNK = 128
+
+# Note: KV_OUT_CHUNK must equal HEAD_DIM for correct cache indexing in K/V projection loops.
+assert KV_OUT_CHUNK == HEAD_DIM, "KV_OUT_CHUNK must equal HEAD_DIM"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py` around lines 220 - 223, The code computing kvh using kv0 = ob * KV_OUT_CHUNK and kvh = kv0 // head_dim implicitly assumes KV_OUT_CHUNK == head_dim; add a clear guard and/or compute kvh without that assumption: either assert KV_OUT_CHUNK == head_dim at the top of the block (or module) or derive head index from ob and head_dim directly (e.g., compute kvh = ob * (KV_OUT_CHUNK // head_dim) or compute ob and head index from the original index), and document the constraint; update the logic around kv0, kvh and the k_cache pl.assemble calls to use the corrected kvh so cache_row_k remains correct for k_cache writes.
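One way to drop the KV_OUT_CHUNK == head_dim assumption the comment describes is to derive both the head index and the within-head offset explicitly. A small illustrative sketch (the constant values here are assumptions mirroring the review, not read from the file):

```python
def kv_head_index(ob, kv_out_chunk, head_dim):
    # Flat output offset of this chunk, then split it into
    # (head index, offset within that head).
    kv0 = ob * kv_out_chunk
    kvh, off = divmod(kv0, head_dim)
    return kvh, off

# With the current constants (both 128) the offset is always 0 and
# kvh == ob, matching the existing kvh = kv0 // head_dim.
result_equal = kv_head_index(3, 128, 128)
# If the chunk were ever 2x head_dim, the mapping stays explicit.
result_double = kv_head_index(3, 256, 128)
```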
📒 Files selected for processing (1)
examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py
    4. Online-softmax accumulation + final normalisation

Scope 3:
    1. Output projection: attn_out × wo
Replace ambiguous × with ASCII x.
The multiplication sign × (U+00D7) could cause confusion or search issues. Use the ASCII letter x instead for consistency.
-    1. Output projection: attn_out × wo
+    1. Output projection: attn_out x wo

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    1. Output projection: attn_out x wo
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 24-24: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py` at line 24, Replace
the Unicode multiplication sign in the comment/expression "attn_out × wo" with
the ASCII letter "x": update the occurrence of "attn_out × wo" to "attn_out x
wo" (or "attn_out x wo" in any comment/string) so the variables attn_out and wo
use a plain ASCII 'x' for multiplication to avoid confusion and search issues.
for b0 in pl.range(0, batch, BATCH_TILE):
    resid1_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.FP32)

    # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
Replace ambiguous × with ASCII x.
Same issue as in the docstring.
-    # Stage 1: Output projection: attn_out × wo, tiled by Q_OUT_CHUNK.
+    # Stage 1: Output projection: attn_out x wo, tiled by Q_OUT_CHUNK.

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    # Stage 1: Output projection: attn_out x wo, tiled by Q_OUT_CHUNK.
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 350-350: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py` at line 350, Replace
the Unicode multiplication sign in the comment "Stage 1: Output projection:
attn_out × wo, tiled by Q_OUT_CHUNK." with an ASCII "x" so it reads "attn_out x
wo"; also mirror the same change in the matching docstring occurrence to keep
formatting consistent (search for the exact phrase "attn_out × wo" or "×" near
Q_OUT_CHUNK and update to "x").
Summary

- Add qwen3_32b_decode_fused_qkrope.py as a Qwen3-32B decode variant with the Q projection fused with Q RoPE and Q padding.
- Relative to qwen3_32b_decode.py, move Q RoPE, K RoPE/cache writes, and V cache writes into the Scope 1 projection tiling, so Scope 2 starts directly from grouped-query attention.
- Relative to qwen3_32b_decode.py, avoid program-scope q_proj, k_proj, and v_proj tensors before RoPE, and tune Scope 1 chunks from 512/64 to 256/128 for hidden/KV projection tiling.
- python examples/models/qwen3/qwen3_32b_decode_fused_qkrope.py --platform a5sim compiles and runs, but the golden comparison currently fails for out with 83472/131072 mismatched elements at rtol=0.003, atol=0.003.

Related Issues
None
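The golden comparison reported in the summary (83472/131072 mismatches at rtol=0.003, atol=0.003) presumably uses the usual combined-tolerance test |a − b| ≤ atol + rtol·|b|; a plain-Python sketch of counting mismatches under that rule (assumed semantics, mirroring torch.isclose):

```python
def count_mismatches(actual, golden, rtol=0.003, atol=0.003):
    # An element mismatches when it fails the combined
    # absolute/relative tolerance test.
    return sum(
        1 for a, g in zip(actual, golden)
        if abs(a - g) > atol + rtol * abs(g)
    )

golden = [1.0, 100.0, 0.0]
actual = [1.002, 100.2, 0.01]   # within, within, outside tolerance
n_bad = count_mismatches(actual, golden)
```

Note the relative term scales with the golden value, so large attention outputs tolerate larger absolute error while near-zero elements are judged almost purely by atol.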