Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,12 +1093,89 @@ def _rewrite_outer_subscripts_for_pre_broadcast(
)
reshaped.append(node)

# Insert _pre_broadcast_tile where pre-broadcast outer nodes feed wider-dim ops.
# First, propagate pre-broadcast status transitively through indirect consumers.
# After rewriting subscript[:, :, None] → subscript[:, :], downstream nodes
# (e.g. add, rsqrt) may still have stale meta shapes (u0, u1, 1) from trace
# time. We identify them by checking if any arg is pre-broadcast — if so,
# the node is also pre-broadcast (its real last dim is PRE_BROADCAST_SIZE).
all_pre_broadcast_outer: set[str] = set(pre_broadcast_outer_nodes)
all_pre_broadcast_outer.update(node.name for node in reshaped)
Copy link
Copy Markdown
Contributor

@norx1991 norx1991 May 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we cover indirect consumer? I found this test case with help of claude:

@helion.kernel(backend="pallas", static_shapes=True)
  def outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
      batch, m, k = a.size()
      head_dim = hl.specialize(b.size(-1))
      out = torch.empty([batch, m, head_dim], device=a.device, dtype=a.dtype)
      for tile_b, tile_m in hl.tile([batch, m]):
          running = hl.zeros([tile_b, tile_m], dtype=torch.float32)
          acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
          for tile_k in hl.tile(k):
              chunk = a[tile_b, tile_m, tile_k]
              running = running + torch.sum(chunk, -1)
              acc = acc + running[:, :, None]
          scale = torch.rsqrt(running[:, :, None] + 1.0)
          out[tile_b, tile_m, :] = (acc * scale).to(out.dtype)
      return out

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch! Update the PR to fix this and added a test for this case

for node in outer_graph.nodes:
if node.op != "call_function" or node.name in all_pre_broadcast_outer:
continue
node_val = node.meta.get("val", None)
if not isinstance(node_val, torch.Tensor) or len(node_val.shape) < 2:
continue
last_dim = node_val.shape[-1]
if isinstance(last_dim, torch.SymInt):
continue
last_dim_int = int(last_dim)
if last_dim_int > PRE_BROADCAST_SIZE:
continue
has_pre_broadcast_arg = False
for arg in node.args:
if isinstance(arg, torch.fx.Node) and arg.name in all_pre_broadcast_outer:
arg_val = arg.meta.get("val", None)
if isinstance(arg_val, torch.Tensor) and len(arg_val.shape) >= 2:
arg_last = arg_val.shape[-1]
if isinstance(arg_last, int) and arg_last == PRE_BROADCAST_SIZE:
has_pre_broadcast_arg = True
break
if has_pre_broadcast_arg:
new_shape = [*node_val.shape[:-1], PRE_BROADCAST_SIZE]
node.meta["val"] = node_val.new_empty(new_shape)
all_pre_broadcast_outer.add(node.name)
reshaped.append(node)

new_nodes: list[torch.fx.Node] = []
for node in list(outer_graph.nodes):
if node.op != "call_function" or node.name in all_pre_broadcast_outer:
continue
node_val = node.meta.get("val", None)
if not isinstance(node_val, torch.Tensor) or len(node_val.shape) < 2:
continue
last_dim = node_val.shape[-1]
last_dim_is_sym = isinstance(last_dim, torch.SymInt)
if not last_dim_is_sym and int(last_dim) <= PRE_BROADCAST_SIZE:
continue
args_list = list(node.args)
changed = False
for ai, arg in enumerate(args_list):
if not isinstance(arg, torch.fx.Node):
continue
if arg.name not in all_pre_broadcast_outer:
continue
arg_val = arg.meta.get("val", None)
if not isinstance(arg_val, torch.Tensor):
continue
if not (
isinstance(arg_val.shape[-1], int)
and arg_val.shape[-1] == PRE_BROADCAST_SIZE
):
continue
with outer_graph.inserting_before(node):
tiled = outer_graph.call_function(
_pre_broadcast_tile,
args=(arg, last_dim),
)
tiled.meta = {
**arg.meta,
"val": arg_val.new_empty([*arg_val.shape[:-1], last_dim]),
}
new_nodes.append(tiled)
args_list[ai] = tiled
changed = True
if changed:
node.args = tuple(args_list)

# Re-prepare lowerings for modified outer nodes
if reshaped:
all_to_prepare = reshaped + new_nodes
if all_to_prepare:
with compile_lock:
graph_lowering = FakeGraphLowering()
with V.set_graph_handler(graph_lowering):
for node in reshaped:
for node in all_to_prepare:
if node.op == "call_function":
with node.meta["location"]:
prepare_node_lowering(graph_lowering, node)
Expand Down
105 changes: 105 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,111 @@ def test_attention_fori_loop_correctness(self) -> None:
).to(device=DEVICE)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_attention_emit_pipeline_correctness_head_dim_256(self) -> None:
"""Test emit_pipeline attention pre-broadcast with head_dim > PRE_BROADCAST_SIZE."""
query = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
key = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
val = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
code, result = code_and_output(
pallas_attention,
(query, key, val),
block_sizes=[4, 128, 128],
pallas_loop_type="emit_pipeline",
pallas_pre_broadcast=True,
)
# m_i and l_i scratches get pre-broadcast trailing dim 128;
# acc scratch keeps head_dim=256
self.assertIn(
"_scratch_shapes=["
"((4, 128, 128), 'jnp.float32', 'vmem'), "
"((4, 128, 256), 'jnp.float32', 'vmem'), "
"((4, 128, 128), 'jnp.float32', 'vmem')]",
code,
)
self.assertIn("jnp.tile(", code)
ref = torch.nn.functional.scaled_dot_product_attention(
query.float().cpu(), key.float().cpu(), val.float().cpu()
).to(device=DEVICE)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_attention_fori_loop_correctness_head_dim_256(self) -> None:
"""Test fori_loop attention pre-broadcast with head_dim > PRE_BROADCAST_SIZE."""
query = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
key = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
val = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
args = (query, key, val)
code, result = code_and_output(
pallas_attention,
args,
block_sizes=[4, 128, 128],
pallas_loop_type="fori_loop",
pallas_pre_broadcast=True,
)
self.assertIn("jax.lax.fori_loop", code)
# m_i and l_i scratches get pre-broadcast trailing dim 128;
# acc scratch keeps head_dim=256; extra entries are DMA buffers/semaphores
self.assertIn(
"_scratch_shapes=["
"((4, 128, 128), 'jnp.float32', 'vmem'), "
"((4, 128, 256), 'jnp.float32', 'vmem'), "
"((4, 128, 128), 'jnp.float32', 'vmem'), "
"((4, 256, 128), 'jnp.float32', 'vmem'), "
"((), None, 'dma_semaphore'), "
"((4, 128, 256), 'jnp.float32', 'vmem'), "
"((), None, 'dma_semaphore')]",
code,
)
self.assertIn("jnp.tile(", code)
ref = torch.nn.functional.scaled_dot_product_attention(
query.float().cpu(), key.float().cpu(), val.float().cpu()
).to(device=DEVICE)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_pre_broadcast_indirect_consumer(self) -> None:
"""Pre-broadcast tile must propagate through indirect consumers.

When a pre-broadcast node (2D, trailing dim 128) feeds an intermediate
op (e.g. running + 1.0, rsqrt) before reaching a wider-dim consumer
(e.g. acc * scale where acc has head_dim=256), the tile-insertion pass
must tile the intermediate result, not just direct pre-broadcast nodes.
"""

@helion.kernel(backend="pallas", static_shapes=True)
def outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
batch, m, k = a.size()
head_dim = hl.specialize(b.size(-1))
out = torch.empty([batch, m, head_dim], device=a.device, dtype=a.dtype)
for tile_b, tile_m in hl.tile([batch, m]):
running = hl.zeros([tile_b, tile_m], dtype=torch.float32)
acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
for tile_k in hl.tile(k):
chunk = a[tile_b, tile_m, tile_k]
running = running + torch.sum(chunk, -1)
acc = acc + running[:, :, None]
scale = torch.rsqrt(running[:, :, None] + 1.0)
out[tile_b, tile_m, :] = (acc * scale).to(out.dtype)
return out

def ref_outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# With k=128 and block_k=128, there's 1 tile iteration:
# running = sum(a, dim=-1), acc = running[:,:,None] (broadcast to 256)
running = a.sum(-1)
acc = running[:, :, None].expand(-1, -1, b.shape[-1]).clone()
scale = torch.rsqrt(running[:, :, None] + 1.0)
return (acc * scale).to(a.dtype)

a = torch.rand(4, 64, 128, dtype=torch.float32, device=DEVICE)
b = torch.rand(4, 64, 256, dtype=torch.float32, device=DEVICE)
code, result = code_and_output(
outer_chain_scale,
(a, b),
block_sizes=[4, 64, 128],
pallas_loop_type="fori_loop",
pallas_pre_broadcast=True,
)
ref = ref_outer_chain_scale(a, b)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_attention_emit_pipeline_non_divisible(self) -> None:
"""Test emit_pipeline with seq_kv not divisible by block_k.

Expand Down
Loading