Add deepseek_v3_2_prefill_back_scope1 & 2 & 12 #113
xzhxzhxzh123 wants to merge 3 commits into hw-native-sys:main from
Conversation
Add deepseek_v3_2_prefill_back_scope1&2
📝 Walkthrough
Adds three new DeepSeek V3.2 PyPTO example programs implementing prefill-back scopes: scope1 (combine + matmul + residual), scope12 (scope1 projection + post-RMSNorm), and scope2 (post-RMSNorm), each with tensor specs, PyTorch golden references, and compile/run CLIs.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request adds implementation scripts for the DeepSeek V3.2 prefill back-end scopes, covering the output projection, residual addition, and RMSNorm operations. The review identifies race conditions caused by tensors shared across parallel loop iterations, performance inefficiencies from redundant memory reads and global-memory round-trips, and incorrect handling of partial token blocks, which requires adding shape validation and padding logic.
```python
combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)

for b in pl.parallel(0, BATCH_CFG, 1):
```
The tensor combine_local is created outside the pl.parallel loop. In pypto, tensors created outside a parallel loop are shared across all parallel iterations. Since each iteration (representing a batch element) writes to and reads from this tensor independently, this will cause a race condition. Move the tensor creation inside the pl.parallel loop to ensure it is thread-local.
```diff
- combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)
- for b in pl.parallel(0, BATCH_CFG, 1):
+ for b in pl.parallel(0, BATCH_CFG, 1):
+     combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)
```
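The hazard is not specific to PyPTO. This plain-Python sketch (hypothetical; `threading` stands in for `pl.parallel`, and a one-slot list stands in for the scratch tensor) shows why a buffer shared across parallel iterations corrupts per-batch data while a per-iteration buffer does not:

```python
import threading

def run_batches(shared_scratch: bool):
    """Each 'batch' writes its own value to scratch, then reads it back."""
    buf = [None]                      # one scratch buffer, shared by all workers
    results = [None, None]
    barrier = threading.Barrier(2)    # force both writes to land before any read

    def worker(b):
        # shared buffer vs. a fresh per-iteration (thread-local) buffer
        scratch = buf if shared_scratch else [None]
        scratch[0] = b                # write this batch's tile
        barrier.wait()                # the other batch writes concurrently
        results[b] = scratch[0]       # read the tile back

    threads = [threading.Thread(target=worker, args=(b,)) for b in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(run_batches(shared_scratch=False))  # [0, 1]: each batch sees its own data
print(run_batches(shared_scratch=True))   # [0, 0] or [1, 1]: one batch reads the other's write
```

With the shared buffer, both writes land before either read, so whichever write arrives last is seen by both iterations; moving the allocation inside the loop makes each iteration's scratch private.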
```python
combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)

for b in pl.parallel(0, BATCH_CFG, 1):
```
The tensor combine_local is created outside the pl.parallel loop, which will lead to race conditions across parallel iterations. Move the tensor creation inside the pl.parallel loop.
```diff
- combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)
- for b in pl.parallel(0, BATCH_CFG, 1):
+ for b in pl.parallel(0, BATCH_CFG, 1):
+     combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)
```
```python
with pl.incore():
    for cb in pl.range(COMBINE_BLOCKS):
        c0 = cb * COMBINE_CHUNK
        chunk = pl.reshape(
            pl.slice(combine_buf, [1, 1, TOK_TILE, COMBINE_CHUNK], [node_id, b, p0, c0]),
            [TOK_TILE, COMBINE_CHUNK],
        )
        combine_local = pl.assemble(combine_local, chunk, [0, c0])
```
The population of combine_local is currently inside the ob loop (over Q_OUT_BLOCKS). However, the content of combine_local only depends on the batch index b and the token block p0. Moving this logic outside the ob loop will avoid redundant memory reads and significantly improve performance. Additionally, valid_shape and fillpad should be used to correctly handle partial token blocks at the end of sequences and prevent data leakage from previous iterations.
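The saving can be quantified with a plain-Python sketch (hypothetical block counts; `load_chunk` stands in for the `pl.slice`/`pl.assemble` pair). Hoisting the loop-invariant assembly out of the output-block loop cuts the chunk reads from `TOK_BLOCKS * Q_OUT_BLOCKS * COMBINE_BLOCKS` to `TOK_BLOCKS * COMBINE_BLOCKS`:

```python
TOK_BLOCKS, Q_OUT_BLOCKS, COMBINE_BLOCKS = 2, 4, 3
loads = 0

def load_chunk(p0, cb):
    global loads
    loads += 1            # count one memory read per chunk
    return (p0, cb)

# Current: combine_local is rebuilt inside the output-block (ob) loop
loads = 0
for p0 in range(TOK_BLOCKS):
    for ob in range(Q_OUT_BLOCKS):
        combine_local = [load_chunk(p0, cb) for cb in range(COMBINE_BLOCKS)]
naive_loads = loads       # 2 * 4 * 3 = 24 reads

# Hoisted: combine_local depends only on p0, so build it once per token tile
loads = 0
for p0 in range(TOK_BLOCKS):
    combine_local = [load_chunk(p0, cb) for cb in range(COMBINE_BLOCKS)]
    for ob in range(Q_OUT_BLOCKS):
        _ = combine_local  # reuse across all output blocks
hoisted_loads = loads     # 2 * 3 = 6 reads

print(naive_loads, hoisted_loads)
```

With the real shapes the factor is the same: the per-tile assembly cost is multiplied by Q_OUT_BLOCKS when left inside the inner loop.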
```python
x_chunk = pl.reshape(
    pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0]),
    [TOK_TILE, K_CHUNK]
)
```
The pl.slice operation here is missing the valid_shape parameter, and the resulting chunk is not padded. Since valid_tok is calculated at line 110, it should be used here to correctly handle partial token blocks, similar to the implementation in deepseek_v3_2_prefill_back_scope2.py.
```diff
- x_chunk = pl.reshape(
-     pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0]),
-     [TOK_TILE, K_CHUNK]
- )
+ x_chunk_raw = pl.reshape(
+     pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0],
+              valid_shape=[1, valid_tok, K_CHUNK]),
+     [TOK_TILE, K_CHUNK]
+ )
+ x_chunk = pl.fillpad(x_chunk_raw, pad_value=pl.PadValue.zero)
```
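Why the padding matters can be shown with a small NumPy sketch (hypothetical tile sizes): without zeroing rows beyond `valid_tok`, stale values in the padded region of the tile leak into any reduction taken over the full tile, such as the row-sum-of-squares behind the RMS statistic:

```python
import numpy as np

TOK_TILE, K_CHUNK = 4, 3
seq_len_b, p0 = 10, 8
valid_tok = min(TOK_TILE, seq_len_b - p0)   # only 2 of the 4 rows are real tokens

# A full tile sliced from resid1; rows >= valid_tok hold stale data
tile = np.arange(TOK_TILE * K_CHUNK, dtype=np.float32).reshape(TOK_TILE, K_CHUNK)

# Without masking: the stale rows contribute to the sum of squares
sq_sum_unmasked = float((tile ** 2).sum())

# With valid_shape + fillpad semantics: zero the padded rows first
padded = tile.copy()
padded[valid_tok:, :] = 0.0
sq_sum_masked = float((padded ** 2).sum())

print(sq_sum_unmasked, sq_sum_masked)  # the masked sum excludes the stale rows
```

Only the masked variant matches a golden reference that sums over exactly `seq_len_b` tokens.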
```python
x_chunk = pl.reshape(
    pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0]),
    [TOK_TILE, K_CHUNK]
)
```
Missing valid_shape and fillpad for partial tile handling. Use valid_tok to mask the slice and pad the resulting chunk.
```diff
- x_chunk = pl.reshape(
-     pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0]),
-     [TOK_TILE, K_CHUNK]
- )
+ x_chunk_raw = pl.reshape(
+     pl.slice(resid1, [1, TOK_TILE, K_CHUNK], [b, p0, k0],
+              valid_shape=[1, valid_tok, K_CHUNK]),
+     [TOK_TILE, K_CHUNK]
+ )
+ x_chunk = pl.fillpad(x_chunk_raw, pad_value=pl.PadValue.zero)
```
```python
resid1 = pl.assemble(resid1, o_acc, [b, p0, o0])
with pl.incore():
    proj = pl.slice(resid1, [1, TOK_TILE, Q_OUT_CHUNK], [b, p0, o0])
    resid = pl.cast(
        pl.slice(hidden_states, [1, TOK_TILE, Q_OUT_CHUNK], [b, p0, o0]),
        target_type=pl.FP32,
    )
    resid1 = pl.assemble(resid1, pl.add(proj, resid), [b, p0, o0])
```
This implementation performs an inefficient round-trip to global memory. It assembles the intermediate matmul result o_acc into the output tensor resid1, then immediately slices it back to perform the residual addition. It is much more efficient to perform the addition in-core before assembling the final result to the output tensor.
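The equivalence of the two orderings is easy to check with a NumPy sketch (hypothetical shapes; plain arrays stand in for the global `resid1` tensor): adding the residual before the single assemble produces the same values while replacing a write, a read-back, and a second write with one write:

```python
import numpy as np

TOK_TILE, Q_OUT_CHUNK = 4, 8
o_acc = np.random.default_rng(0).standard_normal((TOK_TILE, Q_OUT_CHUNK)).astype(np.float32)
hidden = np.random.default_rng(1).standard_normal((TOK_TILE, Q_OUT_CHUNK)).astype(np.float32)

# Round trip: assemble o_acc to global memory, slice it back, add, assemble again
resid1 = np.zeros((TOK_TILE, Q_OUT_CHUNK), dtype=np.float32)
resid1[:, :] = o_acc            # write 1 (assemble)
proj = resid1[:, :].copy()      # read back (slice)
resid1[:, :] = proj + hidden    # write 2 (assemble)

# Fused: do the addition in-core and assemble once
resid1_fused = np.zeros((TOK_TILE, Q_OUT_CHUNK), dtype=np.float32)
resid1_fused[:, :] = o_acc + hidden  # single write

print(np.allclose(resid1, resid1_fused))  # True
```

The fused form also removes the read-after-write dependency on the global tensor between the two `pl.incore()` regions.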
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope2.py (1)

Lines 199-201: Inconsistent CLI defaults compared to scope1/scope12. This file uses full production-size defaults (BATCH=16, MAX_SEQ=4096, HIDDEN=7168), while scope1 and scope12 use smaller defaults (batch=2, max_seq=64, hidden=512) for quick testing. Consider aligning for consistency.

♻️ Proposed fix:

```diff
- parser.add_argument("--batch", type=int, default=BATCH)
- parser.add_argument("--max-seq", type=int, default=MAX_SEQ)
- parser.add_argument("--hidden", type=int, default=HIDDEN)
+ parser.add_argument("--batch", type=int, default=2)
+ parser.add_argument("--max-seq", type=int, default=64)
+ parser.add_argument("--hidden", type=int, default=512)
```

examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope12.py (3)
Line 110: Unused variable valid_tok. valid_tok is computed but never used in the post-normalization phase. This appears to be dead code: either remove it or use it for padding as done in scope2.

♻️ Proposed fix:

```diff
- valid_tok = pl.min(TOK_TILE, seq_len_b - p0)
```
Lines 107-123: Inconsistent padding behavior compared to scope2. The post-normalization phase in scope12 computes RMS statistics without zero-padding invalid tokens beyond seq_len_b, while scope2 explicitly uses fillpad with valid_shape for partial tiles. This could cause numerical differences when the sequence length is not a multiple of TOK_TILE. Both the program here and golden_scope12 omit padding, so they are internally consistent, but this differs from scope2's approach. If this is intentional (e.g., scope12 assumes full tiles), consider adding a comment; otherwise, align the padding behavior for consistency.
Lines 224-225: Global state for _NODE_ID is acceptable for examples but fragile. Using a module-level global for node_id works but isn't thread-safe. For example code this is fine; for production, consider passing it via the params argument to the golden function.
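A minimal sketch of the suggested refactor (hypothetical function and dict shapes, not the actual module code): read node_id from the params dict, falling back to the module global only for backward compatibility:

```python
_NODE_ID = 0  # legacy module-level state, kept only as a fallback

def golden_scope12(tensors, params):
    # Prefer the explicitly passed value; fall back to the global for compatibility
    node_id = params.get("node_id", _NODE_ID)
    return {"node_id_used": node_id, "shard": tensors["combine_buf"][node_id]}

tensors = {"combine_buf": ["shard0", "shard1"]}
print(golden_scope12(tensors, {"node_id": 1}))  # uses the passed-in shard
print(golden_scope12(tensors, {}))              # falls back to _NODE_ID == 0
```

With the value threaded through params, two invocations with different node_ids in the same process validate against their own shards.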
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 30faa6ab-8141-4121-9402-c0b86fdb350a
📒 Files selected for processing (3)
- examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope1.py
- examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope12.py
- examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope2.py
```python
resid1_tile_full = resid1[b, p0:p0 + TOK_TILE, :]
resid1_tile_full[valid_tok:, :] = 0.0
```
Bug: Golden function mutates input tensor resid1.
resid1_tile_full is a view into tensors["resid1"]. Line 138 zeros rows in-place, corrupting the input tensor. This could cause issues if the tensor is inspected after the golden function runs (e.g., for debugging).
🐛 Proposed fix: Clone the tile before modifying

```diff
- resid1_tile_full = resid1[b, p0:p0 + TOK_TILE, :]
- resid1_tile_full[valid_tok:, :] = 0.0
+ resid1_tile_full = resid1[b, p0:p0 + TOK_TILE, :].clone()
+ resid1_tile_full[valid_tok:, :] = 0.0
```
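The underlying aliasing behavior is easy to demonstrate with NumPy, which shares PyTorch's view semantics for basic slicing (hypothetical shapes; `.copy()` plays the role of torch's `.clone()`):

```python
import numpy as np

b, p0, TOK_TILE, valid_tok = 0, 0, 4, 2

# Buggy: basic slicing returns a view, so zeroing mutates the input itself
resid1 = np.ones((1, TOK_TILE, 3), dtype=np.float32)
tile_view = resid1[b, p0:p0 + TOK_TILE, :]
tile_view[valid_tok:, :] = 0.0
print(float(resid1[0, valid_tok:].sum()))  # 0.0, the input tensor was corrupted

# Fixed: copy the tile first, then zero the padded rows
resid1 = np.ones((1, TOK_TILE, 3), dtype=np.float32)
tile_copy = resid1[b, p0:p0 + TOK_TILE, :].copy()
tile_copy[valid_tok:, :] = 0.0
print(float(resid1[0, valid_tok:].sum()))  # 6.0, the input tensor is untouched
```

The copy costs one small tile of memory per iteration, which is negligible next to silently corrupting `tensors["resid1"]` for later inspection.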
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope1.py (1)

Lines 68-78: Avoid rebuilding combine_local for every output block. combine_local does not depend on ob, but lines 71-77 re-read and reassemble the same combine_buf tile inside every output-block iteration. With the default shapes that is a large amount of redundant memory traffic. Build it once per p0 tile, then reuse it across all ob iterations.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 08feab95-66b9-4ed0-9c6f-d35f568f34ad
📒 Files selected for processing (1)
examples/models/deepseek_v3_2/deepseek_v3_2_prefill_back_scope1.py
```python
_NODE_ID = 0  # Module-level variable for golden function
```
Make the golden path stateless.
golden_scope1() reads node_id from the module-global _NODE_ID, and compile_and_run() mutates that global right before run(). That makes validation non-reentrant: concurrent runs, or two invocations with different node_ids in the same process, can validate against the wrong shard. Pass node_id through params or capture it in a closure instead of relying on shared module state.
Also applies to: 122-128, 166-167
```python
tok_blocks = (seq_len_b + TOK_TILE - 1) // TOK_TILE
for p0_idx in pl.range(tok_blocks):
    p0 = p0_idx * TOK_TILE
    combine_local = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.BF16)

    for ob in pl.range(Q_OUT_BLOCKS):
        o0 = ob * Q_OUT_CHUNK
        with pl.incore():
            for cb in pl.range(COMBINE_BLOCKS):
                c0 = cb * COMBINE_CHUNK
                chunk = pl.reshape(
                    pl.slice(combine_buf, [1, 1, TOK_TILE, COMBINE_CHUNK], [node_id, b, p0, c0]),
                    [TOK_TILE, COMBINE_CHUNK],
                )
                combine_local = pl.assemble(combine_local, chunk, [0, c0])
        with pl.incore():
            c_tile_0 = pl.slice(combine_local, [TOK_TILE, COMBINE_CHUNK], [0, 0])
            w_tile_0 = pl.slice(wo, [COMBINE_CHUNK, Q_OUT_CHUNK], [0, o0])
            o_acc = pl.matmul(c_tile_0, w_tile_0, out_dtype=pl.FP32)
            for cb in pl.range(1, COMBINE_BLOCKS):
                c0 = cb * COMBINE_CHUNK
                c_tile_i = pl.slice(combine_local, [TOK_TILE, COMBINE_CHUNK], [0, c0])
                w_tile_i = pl.slice(wo, [COMBINE_CHUNK, Q_OUT_CHUNK], [c0, o0])
                o_acc = pl.matmul_acc(o_acc, c_tile_i, w_tile_i)
        resid1 = pl.assemble(resid1, o_acc, [b, p0, o0])
        with pl.incore():
            proj = pl.slice(resid1, [1, TOK_TILE, Q_OUT_CHUNK], [b, p0, o0])
            resid = pl.cast(
                pl.slice(hidden_states, [1, TOK_TILE, Q_OUT_CHUNK], [b, p0, o0]),
                target_type=pl.FP32,
            )
            resid1 = pl.assemble(resid1, pl.add(proj, resid), [b, p0, o0])
```
Handle tail tiles explicitly.
This loop counts token/output/combine blocks with ceil division, but every slice/assemble still uses full TOK_TILE, COMBINE_CHUNK, and Q_OUT_CHUNK. For non-aligned seq_len_b, max_seq_len, attn_out_size, or hidden_size, the last iteration will process past the valid region and can go out of bounds; even when it stays in bounds, it still writes padded tokens past seq_len_b. golden_scope1() has the same assumption, so the reference masks the bug instead of catching it.
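A quick sketch of the tail-size arithmetic (plain Python, hypothetical sizes): ceil division gives the block count, and a min() per block gives the valid extent that the slice/assemble calls should use on the final iteration:

```python
TOK_TILE = 64
seq_len_b = 150  # not a multiple of TOK_TILE

# Ceil division, as in the original loop
tok_blocks = (seq_len_b + TOK_TILE - 1) // TOK_TILE

# Per-block valid extent: full tiles, then a short tail
valid = [min(TOK_TILE, seq_len_b - i * TOK_TILE) for i in range(tok_blocks)]
print(tok_blocks, valid)  # 3 [64, 64, 22]

# Using the full TOK_TILE on the last block would touch 64 - 22 = 42 invalid rows
assert sum(valid) == seq_len_b
```

The same pattern applies to the combine and output dimensions (remaining_combine, remaining_q), so every trailing chunk stays within the valid region.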
Add deepseek_v3_2_prefill_back_scope1 & 2 & 12
All accuracy checks pass.