[Pallas] Fix pre-broadcasting transformation bug when non-broadcast dims exceed PRE_BROADCAST_SIZE #2223
Merged
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
…ims exceed PRE_BROADCAST_SIZE The pre-broadcast optimization pass (added in #2103) appends a trailing `PRE_BROADCAST_SIZE=128` dimension to loop-carried scratch buffers like `m_i` and `l_i` to avoid costly implicit broadcasts on TPU. When `head_dim` equals 128, this works correctly because the pre-broadcast dimension matches `head_dim`. However, when `head_dim > 128` (e.g. 256), the outer graph produced after the inner loop has a shape mismatch. Repro: Run the attention kernel from `dunfanlu_notes/tpu/attn/attn_helion.py` with `D=256` and `pallas_pre_broadcast=True`. Error: `helion.exc.ShapeMismatch: Shape mismatch between [u0, u1, 256] and [u0, u1, 128]` at `acc = acc / l_i[:, :, None]` in the outer graph. Root cause: `_rewrite_outer_subscripts_for_pre_broadcast` correctly rewrites `subscript[:, :, None]` to identity slicing `[:, :]` for pre-broadcast results exiting the inner loop, and updates their meta shapes to include the trailing 128 dim. But it never inserted `_pre_broadcast_tile` nodes to expand the 128-wide values to match wider-dim consumers (e.g. `acc` with `head_dim=256`). The inner graph handled this via Step 3 of `_annotate_pre_broadcast`, but the outer graph lacked the equivalent logic. Fix: Add a tile-insertion pass to `_rewrite_outer_subscripts_for_pre_broadcast` that mirrors the inner graph's Step 3. It scans non-pre-broadcast nodes in the outer graph whose last dim exceeds 128, checks if any of their args are pre-broadcast nodes with trailing dim 128, and inserts `_pre_broadcast_tile` to expand them (e.g. `jnp.tile(subscript_2, 2)` for 128→256). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2223, branch: AmesingFlank/stack/41
22ec35f to
5685e61
Compare
This was referenced May 4, 2026
Merged
5685e61 to
835be34
Compare
835be34 to
efaa2ba
Compare
norx1991
reviewed
May 5, 2026
    # Insert _pre_broadcast_tile where pre-broadcast outer nodes feed wider-dim ops
    all_pre_broadcast_outer: set[str] = set(pre_broadcast_outer_nodes)
    all_pre_broadcast_outer.update(node.name for node in reshaped)
Contributor
Do we cover indirect consumers? I found this test case with the help of Claude:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(backend="pallas", static_shapes=True)
    def outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        batch, m, k = a.size()
        head_dim = hl.specialize(b.size(-1))
        out = torch.empty([batch, m, head_dim], device=a.device, dtype=a.dtype)
        for tile_b, tile_m in hl.tile([batch, m]):
            # Loop-carried values: `running` is the pre-broadcast candidate.
            running = hl.zeros([tile_b, tile_m], dtype=torch.float32)
            acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
            for tile_k in hl.tile(k):
                chunk = a[tile_b, tile_m, tile_k]
                running = running + torch.sum(chunk, -1)
                acc = acc + running[:, :, None]
            # Outer graph: `scale` is derived from `running`, so `acc * scale`
            # consumes the pre-broadcast value only indirectly.
            scale = torch.rsqrt(running[:, :, None] + 1.0)
            out[tile_b, tile_m, :] = (acc * scale).to(out.dtype)
        return out
Contributor
Author
Good catch! I updated the PR to fix this and added a test for this case.
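To make the indirect-consumer case concrete, here is a minimal NumPy sketch of the shapes involved. It is not Helion or Pallas code: NumPy stands in for `jnp`, `tile_last_dim` is a hypothetical helper, and the small `batch`/`m` sizes are arbitrary. It only illustrates why a value derived from a 128-wide pre-broadcast buffer still needs tiling before it meets a 256-wide `acc`.

```python
import numpy as np

PRE_BROADCAST_SIZE = 128  # value quoted in the PR description
HEAD_DIM = 256            # wider than PRE_BROADCAST_SIZE, as in the repro


def tile_last_dim(x: np.ndarray, target: int) -> np.ndarray:
    """Hypothetical helper: repeat the last axis until it has `target` columns."""
    reps, rem = divmod(target, x.shape[-1])
    if rem != 0:
        raise ValueError("target must be a multiple of the current last dim")
    return np.tile(x, reps)  # an integer `reps` tiles along the last axis


batch, m = 2, 4
running = np.arange(batch * m, dtype=np.float32).reshape(batch, m)
acc = np.ones((batch, m, HEAD_DIM), dtype=np.float32)

# Pre-broadcast form of `running`: every column holds the same value.
running_pb = np.repeat(running[:, :, None], PRE_BROADCAST_SIZE, axis=-1)

# `scale` is computed from the pre-broadcast value, so it is also 128 wide
# even though it is not itself a loop output.
scale_pb = 1.0 / np.sqrt(running_pb + 1.0)

# Multiplying directly reproduces the 256-vs-128 mismatch.
try:
    acc * scale_pb
    mismatch = False
except ValueError:
    mismatch = True

# Tiling the derived value to the consumer's width restores agreement
# with the ordinary implicit-broadcast computation.
result = acc * tile_last_dim(scale_pb, HEAD_DIM)
expected = acc * (1.0 / np.sqrt(running[:, :, None] + 1.0))
```

Because every column of a pre-broadcast value is identical, tiling it along the last axis gives the same numbers as the implicit broadcast would.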
34c03fd to
81f56aa
Compare
81f56aa to
fd4cbbd
Compare
fd4cbbd to
96e6165
Compare
norx1991
approved these changes
May 5, 2026
jansel
approved these changes
May 5, 2026
Stacked PRs:
[Pallas] Fix pre-broadcasting transformation bug when non-broadcast dims exceed PRE_BROADCAST_SIZE
The pre-broadcast optimization pass (added in #2103) appends a trailing `PRE_BROADCAST_SIZE=128` dimension to loop-carried scratch buffers like `m_i` and `l_i` to avoid costly implicit broadcasts on TPU. When `head_dim` equals 128, this works correctly because the pre-broadcast dimension matches `head_dim`. However, when `head_dim > 128` (e.g. 256), the outer graph produced after the inner loop has a shape mismatch.

Repro: Run the attention kernel from `dunfanlu_notes/tpu/attn/attn_helion.py` with `D=256` and `pallas_pre_broadcast=True`.

Error: `helion.exc.ShapeMismatch: Shape mismatch between [u0, u1, 256] and [u0, u1, 128]` at `acc = acc / l_i[:, :, None]` in the outer graph.

Root cause: `_rewrite_outer_subscripts_for_pre_broadcast` correctly rewrites `subscript[:, :, None]` to identity slicing `[:, :]` for pre-broadcast results exiting the inner loop, and updates their meta shapes to include the trailing 128 dim. But it never inserted `_pre_broadcast_tile` nodes to expand the 128-wide values to match wider-dim consumers (e.g. `acc` with `head_dim=256`). The inner graph handled this via Step 3 of `_annotate_pre_broadcast`, but the outer graph lacked the equivalent logic.

Fix: Add a tile-insertion pass to `_rewrite_outer_subscripts_for_pre_broadcast` that mirrors the inner graph's Step 3. It scans non-pre-broadcast nodes in the outer graph whose last dim exceeds 128, checks if any of their args are pre-broadcast nodes with trailing dim 128, and inserts `_pre_broadcast_tile` to expand them (e.g. `jnp.tile(subscript_2, 2)` for 128→256).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
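The tile-insertion idea described above can be sketched on a toy graph. This is an illustration only, not the Helion implementation: `Node`, `insert_tiles`, and the `_tiled256` naming are hypothetical, the graph is a plain list in topological order rather than an FX graph, and propagating the "narrow" marking through 128-wide intermediate nodes is one assumed way to reach indirect consumers such as the `scale` chain raised in review.

```python
from dataclasses import dataclass, replace

PRE_BROADCAST_SIZE = 128  # value quoted in the PR description


@dataclass(frozen=True)
class Node:
    """Toy stand-in for a graph node; only the trailing dim is tracked."""
    name: str
    op: str
    args: tuple[str, ...]
    last_dim: int


def insert_tiles(nodes: list[Node], pre_broadcast: set[str]) -> list[Node]:
    """Insert a tile node wherever a 128-wide value feeds a wider consumer.

    `nodes` is assumed to be in topological order.
    """
    narrow = set(pre_broadcast)            # names carrying 128-wide values
    tiled: dict[tuple[str, int], str] = {}  # (source, width) -> tile node name
    result: list[Node] = []
    for node in nodes:
        if node.name in narrow:
            result.append(node)
        elif node.last_dim > PRE_BROADCAST_SIZE:
            new_args = []
            for arg in node.args:
                if arg in narrow:
                    key = (arg, node.last_dim)
                    if key not in tiled:
                        tiled[key] = f"{arg}_tiled{node.last_dim}"
                        result.append(
                            Node(tiled[key], "tile", (arg,), node.last_dim)
                        )
                    arg = tiled[key]
                new_args.append(arg)
            result.append(replace(node, args=tuple(new_args)))
        else:
            # A 128-wide node fed by a narrow value stays narrow itself, so
            # its own wider consumers are also caught (the indirect case).
            if node.last_dim == PRE_BROADCAST_SIZE and any(
                a in narrow for a in node.args
            ):
                narrow.add(node.name)
            result.append(node)
    return result


graph = [
    Node("running", "loop_out", (), 128),
    Node("acc", "loop_out", (), 256),
    Node("shifted", "add", ("running",), 128),
    Node("scale", "rsqrt", ("shifted",), 128),
    Node("out", "mul", ("acc", "scale"), 256),
]
rewritten = insert_tiles(graph, pre_broadcast={"running"})
names = [n.name for n in rewritten]
```

In this toy run the only wide consumer of a narrow value is `out`, so a single tile node is placed immediately before it and `out` is rewired to read from that node.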