diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py
index 2ac3cca95f..bc9f2a4db8 100644
--- a/helion/language/_tracing_ops.py
+++ b/helion/language/_tracing_ops.py
@@ -1093,12 +1093,93 @@ def _rewrite_outer_subscripts_for_pre_broadcast(
             )
             reshaped.append(node)
 
+    # Insert _pre_broadcast_tile where pre-broadcast outer nodes feed wider-dim ops.
+    # First, propagate pre-broadcast status transitively through indirect consumers.
+    # After rewriting subscript[:, :, None] → subscript[:, :], downstream nodes
+    # (e.g. add, rsqrt) may still have stale meta shapes (u0, u1, 1) from trace
+    # time. We identify them by checking if any arg is pre-broadcast — if so,
+    # the node is also pre-broadcast (its real last dim is PRE_BROADCAST_SIZE).
+    all_pre_broadcast_outer: set[str] = set(pre_broadcast_outer_nodes)
+    all_pre_broadcast_outer.update(node.name for node in reshaped)
+    for node in outer_graph.nodes:
+        if node.op != "call_function" or node.name in all_pre_broadcast_outer:
+            continue
+        node_val = node.meta.get("val", None)
+        if not isinstance(node_val, torch.Tensor) or len(node_val.shape) < 2:
+            continue
+        last_dim = node_val.shape[-1]
+        if isinstance(last_dim, torch.SymInt):
+            continue
+        last_dim_int = int(last_dim)
+        if last_dim_int > PRE_BROADCAST_SIZE:
+            continue
+        has_pre_broadcast_arg = False
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and arg.name in all_pre_broadcast_outer:
+                arg_val = arg.meta.get("val", None)
+                if isinstance(arg_val, torch.Tensor) and len(arg_val.shape) >= 2:
+                    arg_last = arg_val.shape[-1]
+                    if isinstance(arg_last, int) and arg_last == PRE_BROADCAST_SIZE:
+                        has_pre_broadcast_arg = True
+                        break
+        if has_pre_broadcast_arg:
+            new_shape = [*node_val.shape[:-1], PRE_BROADCAST_SIZE]
+            node.meta["val"] = node_val.new_empty(new_shape)
+            all_pre_broadcast_outer.add(node.name)
+            reshaped.append(node)
+
+    # Second pass: for every consumer whose trailing dim is wider than
+    # PRE_BROADCAST_SIZE (or symbolic), wrap each pre-broadcast argument in
+    # _pre_broadcast_tile so its trailing dim matches the consumer's.
+    new_nodes: list[torch.fx.Node] = []
+    for node in list(outer_graph.nodes):
+        if node.op != "call_function" or node.name in all_pre_broadcast_outer:
+            continue
+        node_val = node.meta.get("val", None)
+        if not isinstance(node_val, torch.Tensor) or len(node_val.shape) < 2:
+            continue
+        last_dim = node_val.shape[-1]
+        last_dim_is_sym = isinstance(last_dim, torch.SymInt)
+        if not last_dim_is_sym and int(last_dim) <= PRE_BROADCAST_SIZE:
+            continue
+        args_list = list(node.args)
+        changed = False
+        for ai, arg in enumerate(args_list):
+            if not isinstance(arg, torch.fx.Node):
+                continue
+            if arg.name not in all_pre_broadcast_outer:
+                continue
+            arg_val = arg.meta.get("val", None)
+            # Rank-0 values have no trailing dim; indexing shape[-1] would raise.
+            if not isinstance(arg_val, torch.Tensor) or arg_val.ndim == 0:
+                continue
+            if not (
+                isinstance(arg_val.shape[-1], int)
+                and arg_val.shape[-1] == PRE_BROADCAST_SIZE
+            ):
+                continue
+            with outer_graph.inserting_before(node):
+                tiled = outer_graph.call_function(
+                    _pre_broadcast_tile,
+                    args=(arg, last_dim),
+                )
+            tiled.meta = {
+                **arg.meta,
+                "val": arg_val.new_empty([*arg_val.shape[:-1], last_dim]),
+            }
+            new_nodes.append(tiled)
+            args_list[ai] = tiled
+            changed = True
+        if changed:
+            node.args = tuple(args_list)
+
     # Re-prepare lowerings for modified outer nodes
-    if reshaped:
+    all_to_prepare = reshaped + new_nodes
+    if all_to_prepare:
         with compile_lock:
             graph_lowering = FakeGraphLowering()
             with V.set_graph_handler(graph_lowering):
-                for node in reshaped:
+                for node in all_to_prepare:
                     if node.op == "call_function":
                         with node.meta["location"]:
                             prepare_node_lowering(graph_lowering, node)
diff --git a/test/test_pallas.py b/test/test_pallas.py
index f64de2d5bd..0128379115 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -971,6 +971,111 @@ def test_attention_fori_loop_correctness(self) -> None:
         ).to(device=DEVICE)
         torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)
 
+    def test_attention_emit_pipeline_correctness_head_dim_256(self) -> None:
+        """Test emit_pipeline attention pre-broadcast with head_dim > PRE_BROADCAST_SIZE."""
+        query = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        key = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        val = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        code, result = code_and_output(
+            pallas_attention,
+            (query, key, val),
+            block_sizes=[4, 128, 128],
+            pallas_loop_type="emit_pipeline",
+            pallas_pre_broadcast=True,
+        )
+        # m_i and l_i scratches get pre-broadcast trailing dim 128;
+        # acc scratch keeps head_dim=256
+        self.assertIn(
+            "_scratch_shapes=["
+            "((4, 128, 128), 'jnp.float32', 'vmem'), "
+            "((4, 128, 256), 'jnp.float32', 'vmem'), "
+            "((4, 128, 128), 'jnp.float32', 'vmem')]",
+            code,
+        )
+        self.assertIn("jnp.tile(", code)
+        ref = torch.nn.functional.scaled_dot_product_attention(
+            query.float().cpu(), key.float().cpu(), val.float().cpu()
+        ).to(device=DEVICE)
+        torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)
+
+    def test_attention_fori_loop_correctness_head_dim_256(self) -> None:
+        """Test fori_loop attention pre-broadcast with head_dim > PRE_BROADCAST_SIZE."""
+        query = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        key = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        val = torch.randn(2, 2, 128, 256, dtype=torch.float32, device=DEVICE)
+        args = (query, key, val)
+        code, result = code_and_output(
+            pallas_attention,
+            args,
+            block_sizes=[4, 128, 128],
+            pallas_loop_type="fori_loop",
+            pallas_pre_broadcast=True,
+        )
+        self.assertIn("jax.lax.fori_loop", code)
+        # m_i and l_i scratches get pre-broadcast trailing dim 128;
+        # acc scratch keeps head_dim=256; extra entries are DMA buffers/semaphores
+        self.assertIn(
+            "_scratch_shapes=["
+            "((4, 128, 128), 'jnp.float32', 'vmem'), "
+            "((4, 128, 256), 'jnp.float32', 'vmem'), "
+            "((4, 128, 128), 'jnp.float32', 'vmem'), "
+            "((4, 256, 128), 'jnp.float32', 'vmem'), "
+            "((), None, 'dma_semaphore'), "
+            "((4, 128, 256), 'jnp.float32', 'vmem'), "
+            "((), None, 'dma_semaphore')]",
+            code,
+        )
+        self.assertIn("jnp.tile(", code)
+        ref = torch.nn.functional.scaled_dot_product_attention(
+            query.float().cpu(), key.float().cpu(), val.float().cpu()
+        ).to(device=DEVICE)
+        torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)
+
+    def test_pre_broadcast_indirect_consumer(self) -> None:
+        """Pre-broadcast tile must propagate through indirect consumers.
+
+        When a pre-broadcast node (2D, trailing dim 128) feeds an intermediate
+        op (e.g. running + 1.0, rsqrt) before reaching a wider-dim consumer
+        (e.g. acc * scale where acc has head_dim=256), the tile-insertion pass
+        must tile the intermediate result, not just direct pre-broadcast nodes.
+        """
+
+        @helion.kernel(backend="pallas", static_shapes=True)
+        def outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            batch, m, k = a.size()
+            head_dim = hl.specialize(b.size(-1))
+            out = torch.empty([batch, m, head_dim], device=a.device, dtype=a.dtype)
+            for tile_b, tile_m in hl.tile([batch, m]):
+                running = hl.zeros([tile_b, tile_m], dtype=torch.float32)
+                acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    chunk = a[tile_b, tile_m, tile_k]
+                    running = running + torch.sum(chunk, -1)
+                    acc = acc + running[:, :, None]
+                scale = torch.rsqrt(running[:, :, None] + 1.0)
+                out[tile_b, tile_m, :] = (acc * scale).to(out.dtype)
+            return out
+
+        def ref_outer_chain_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            # With k=128 and block_k=128, there's 1 tile iteration:
+            # running = sum(a, dim=-1), acc = running[:,:,None] (broadcast to 256)
+            running = a.sum(-1)
+            acc = running[:, :, None].expand(-1, -1, b.shape[-1]).clone()
+            scale = torch.rsqrt(running[:, :, None] + 1.0)
+            return (acc * scale).to(a.dtype)
+
+        a = torch.rand(4, 64, 128, dtype=torch.float32, device=DEVICE)
+        b = torch.rand(4, 64, 256, dtype=torch.float32, device=DEVICE)
+        code, result = code_and_output(
+            outer_chain_scale,
+            (a, b),
+            block_sizes=[4, 64, 128],
+            pallas_loop_type="fori_loop",
+            pallas_pre_broadcast=True,
+        )
+        ref = ref_outer_chain_scale(a, b)
+        torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)
+
     def test_attention_emit_pipeline_non_divisible(self) -> None:
         """Test emit_pipeline with seq_kv not divisible by block_k.