[Pallas] Fix pre-broadcasting transformation bug when non-broadcast dims exceed PRE_BROADCAST_SIZE #2223

Merged
AmesingFlank merged 1 commit into main from AmesingFlank/stack/41 on May 6, 2026

AmesingFlank (Contributor) commented May 4, 2026

Stacked PRs:


[Pallas] Fix pre-broadcasting transformation bug when non-broadcast dims exceed PRE_BROADCAST_SIZE

The pre-broadcast optimization pass (added in #2103) appends a trailing
`PRE_BROADCAST_SIZE=128` dimension to loop-carried scratch buffers like
`m_i` and `l_i` to avoid costly implicit broadcasts on TPU. When
`head_dim` equals 128, this works correctly because the pre-broadcast
dimension matches `head_dim`. However, when `head_dim > 128` (e.g. 256),
the outer graph produced after the inner loop has a shape mismatch.

Repro: Run the attention kernel from `dunfanlu_notes/tpu/attn/attn_helion.py`
with `D=256` and `pallas_pre_broadcast=True`.

Error: `helion.exc.ShapeMismatch: Shape mismatch between [u0, u1, 256] and [u0, u1, 128]` at `acc = acc / l_i[:, :, None]` in the outer graph.

Root cause: `_rewrite_outer_subscripts_for_pre_broadcast` correctly rewrites
`subscript[:, :, None]` to identity slicing `[:, :]` for pre-broadcast
results exiting the inner loop, and updates their meta shapes to include
the trailing 128 dim. But it never inserted `_pre_broadcast_tile` nodes to
expand the 128-wide values to match wider-dim consumers (e.g. `acc` with
`head_dim=256`). The inner graph handled this via Step 3 of
`_annotate_pre_broadcast`, but the outer graph lacked the equivalent logic.
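
The mismatch can be reproduced outside Helion with plain array broadcasting. The sketch below is a minimal illustration, not Helion code: it uses NumPy as a stand-in for the TPU arrays, drops one leading dim for brevity, and the tile size `M` and the helper name `divide_by_pre_broadcast` are assumptions made up for the example.

```python
import numpy as np

PRE_BROADCAST_SIZE = 128  # value quoted in the PR description
M = 8                     # illustrative tile size (assumption)


def divide_by_pre_broadcast(head_dim: int) -> tuple[int, ...]:
    """Divide an [M, head_dim] accumulator by a pre-broadcast [M, 128] buffer."""
    acc = np.ones((M, head_dim), dtype=np.float32)
    # After the pass, l_i carries a trailing 128 dim instead of being [M].
    l_i = np.full((M, PRE_BROADCAST_SIZE), 2.0, dtype=np.float32)
    # The rewritten subscript is identity slicing, so there is no size-1 axis
    # left for broadcasting to stretch.
    return (acc / l_i[:, :]).shape


print(divide_by_pre_broadcast(128))  # trailing dims line up

try:
    divide_by_pre_broadcast(256)
except ValueError as exc:
    # 256 vs 128 cannot broadcast; this mirrors the ShapeMismatch above.
    print("mismatch:", exc)
```

With `head_dim=128` the identity slice happens to match the accumulator, which is presumably why the problem only surfaces for wider heads.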

Fix: Add a tile-insertion pass to `_rewrite_outer_subscripts_for_pre_broadcast`
that mirrors the inner graph's Step 3. It scans non-pre-broadcast nodes in
the outer graph whose last dim exceeds 128, checks if any of their args are
pre-broadcast nodes with trailing dim 128, and inserts `_pre_broadcast_tile`
to expand them (e.g. `jnp.tile(subscript_2, 2)` for 128→256).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
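
To make the 128→256 expansion concrete, here is a small sketch of what tiling the pre-broadcast value buys. It is an illustration under assumptions, not the PR's code: NumPy's `np.tile` stands in for `jnp.tile` (an integer `reps` repeats along the last axis in both), the helper `pre_broadcast_tile` is a hypothetical name rather than Helion's `_pre_broadcast_tile`, and the shapes are arbitrary.

```python
import numpy as np

PRE_BROADCAST_SIZE = 128  # value quoted in the PR description


def pre_broadcast_tile(x: np.ndarray, target_dim: int) -> np.ndarray:
    """Repeat the trailing dim of x until it reaches target_dim."""
    last = x.shape[-1]
    if target_dim % last != 0:
        raise ValueError(f"{target_dim} is not a multiple of {last}")
    # An integer reps is applied to the last axis only.
    return np.tile(x, target_dim // last)


rng = np.random.default_rng(0)
acc = rng.standard_normal((4, 8, 256)).astype(np.float32)
l_i = rng.uniform(1.0, 2.0, size=(4, 8)).astype(np.float32)

# Reference: the original implicit broadcast from the kernel source.
reference = acc / l_i[:, :, None]

# Pre-broadcast form: every column of the trailing 128 dim holds the same value.
l_i_pb = np.broadcast_to(l_i[:, :, None], (4, 8, PRE_BROADCAST_SIZE))

# Tile 128 -> 256 so the divisor matches the accumulator's last dim.
tiled = pre_broadcast_tile(l_i_pb, acc.shape[-1])
result = acc / tiled

print(tiled.shape)
print(np.allclose(result, reference))
```

Tiling is safe here only because a pre-broadcast buffer is constant along its trailing dim, so repeating it reproduces the value an implicit broadcast would have produced.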

AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank force-pushed the AmesingFlank/stack/41 branch from 22ec35f to 5685e61 May 4, 2026 16:45
meta-cla (bot) added the CLA Signed label May 4, 2026
AmesingFlank requested review from jansel, norx1991 and oulgen May 4, 2026 16:48
AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank marked this pull request as draft May 4, 2026 17:55
AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 4, 2026 17:55
AmesingFlank force-pushed the AmesingFlank/stack/41 branch from 5685e61 to 835be34 May 4, 2026 17:55
AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 4, 2026 17:55
AmesingFlank marked this pull request as ready for review May 4, 2026 17:55
AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank marked this pull request as draft May 4, 2026 18:54
AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 4, 2026 18:54
AmesingFlank force-pushed the AmesingFlank/stack/41 branch from 835be34 to efaa2ba May 4, 2026 18:54
AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 4, 2026 18:54
AmesingFlank marked this pull request as ready for review May 4, 2026 18:55
AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 4, 2026
…ims exceed PRE_BROADCAST_SIZE


# Insert _pre_broadcast_tile where pre-broadcast outer nodes feed wider-dim ops
all_pre_broadcast_outer: set[str] = set(pre_broadcast_outer_nodes)
all_pre_broadcast_outer.update(node.name for node in reshaped)
norx1991 (Contributor) commented May 5, 2026

Do we cover indirect consumers? I found this test case with the help of Claude:

import torch
import helion
import helion.language as hl


@helion.kernel(backend="pallas", static_shapes=True)
def outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    batch, m, k = a.size()
    # b only supplies head_dim, specialized so it is a compile-time constant
    head_dim = hl.specialize(b.size(-1))
    out = torch.empty([batch, m, head_dim], device=a.device, dtype=a.dtype)
    for tile_b, tile_m in hl.tile([batch, m]):
        running = hl.zeros([tile_b, tile_m], dtype=torch.float32)
        acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
        for tile_k in hl.tile(k):
            chunk = a[tile_b, tile_m, tile_k]
            running = running + torch.sum(chunk, -1)
            acc = acc + running[:, :, None]
        # Indirect consumer: running passes through add and rsqrt before
        # it meets the head_dim-wide acc.
        scale = torch.rsqrt(running[:, :, None] + 1.0)
        out[tile_b, tile_m, :] = (acc * scale).to(out.dtype)
    return out

AmesingFlank (Contributor, Author) replied:

Good catch! I updated the PR to fix this and added a test for this case.
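
The indirect-consumer case in the test above can be illustrated with a toy graph pass. This is a hedged sketch, not the logic that landed in the PR (the updated diff is not shown here): the `Node` class, `insert_tiles`, and the node names are all hypothetical, and the graph is reduced to names, args, and trailing dims. The idea it demonstrates is that pre-broadcast status has to propagate through 128-wide intermediates such as the add and `rsqrt`, so the tile is inserted where the chain finally meets a wider consumer.

```python
from dataclasses import dataclass

PRE_BROADCAST_SIZE = 128  # value quoted in the PR description


@dataclass
class Node:
    name: str
    op: str
    args: list[str]
    last_dim: int


def insert_tiles(nodes: list[Node], seeds: set[str]) -> list[Node]:
    """Return a new node list with tile nodes inserted before wider consumers.

    `nodes` is assumed to be in topological order; `seeds` names the values
    that leave the inner loop in pre-broadcast form.
    """
    pre_broadcast = set(seeds)
    out: list[Node] = []
    for node in nodes:
        uses_pb = [a for a in node.args if a in pre_broadcast]
        if uses_pb and node.last_dim == PRE_BROADCAST_SIZE:
            # Still 128-wide: the result is itself pre-broadcast (indirect).
            pre_broadcast.add(node.name)
            out.append(node)
        elif uses_pb and node.last_dim > PRE_BROADCAST_SIZE:
            # Wider consumer: tile each pre-broadcast arg up to its last dim.
            new_args = list(node.args)
            for i, arg in enumerate(node.args):
                if arg in pre_broadcast:
                    tile = Node(f"{arg}_tiled", "tile", [arg], node.last_dim)
                    out.append(tile)
                    new_args[i] = tile.name
            out.append(Node(node.name, node.op, new_args, node.last_dim))
        else:
            out.append(node)
    return out


# scale = rsqrt(running + 1.0); out = acc * scale, with head_dim = 256
graph = [
    Node("add", "add", ["running"], 128),
    Node("scale", "rsqrt", ["add"], 128),
    Node("mul", "mul", ["acc", "scale"], 256),
]
rewritten = insert_tiles(graph, seeds={"running"})
for n in rewritten:
    print(n.name, n.op, n.args, n.last_dim)
```

A check that only looks at the seed set would skip `mul`, because its argument `scale` is two hops away from `running`; tracking the derived names is what catches it.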

AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank marked this pull request as draft May 5, 2026 02:47
AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 5, 2026 02:47
AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank force-pushed the AmesingFlank/stack/41 branch from 34c03fd to 81f56aa May 5, 2026 02:47
AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 5, 2026 02:47
AmesingFlank marked this pull request as ready for review May 5, 2026 02:47
AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank marked this pull request as draft May 5, 2026 03:43
AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 5, 2026 03:43
AmesingFlank force-pushed the AmesingFlank/stack/41 branch from 81f56aa to fd4cbbd May 5, 2026 03:43
AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 5, 2026 03:43
AmesingFlank marked this pull request as ready for review May 5, 2026 03:43
AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank marked this pull request as draft May 5, 2026 03:46
AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 5, 2026 03:46
AmesingFlank force-pushed the AmesingFlank/stack/41 branch from fd4cbbd to 96e6165 May 5, 2026 03:46
AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 5, 2026 03:46
AmesingFlank marked this pull request as ready for review May 5, 2026 03:46
AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE

AmesingFlank added a commit that referenced this pull request May 5, 2026
…ims exceed PRE_BROADCAST_SIZE
