diff --git a/examples/attention.py b/examples/attention.py index 4ca3c98c4..91225bffd 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -69,10 +69,9 @@ def attention( m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) l_i = torch.full_like(m_i, 1.0) acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32) - q = q_view[tile_b, tile_m, :] for tile_n in hl.tile(v_view.size(1)): - # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q - q_scaled = q * qk_scale + # load and scale Q in-loop on-demand reduces spillage + q_scaled = q_view[tile_b, tile_m, :] * qk_scale k = k_view[tile_b, tile_n, :] # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs. # same as hl.dot(q, k, out_dtype=torch.float32) diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index 3ead55f51..e94b33825 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -1871,6 +1871,7 @@ def _classify_pipelined_tensors( outer_access_tensor_ids.add(id(val)) pipelined_ids: set[int] = set() + inner_block_id_set = set(block_ids) for (fake, _sub_meta, _direction), vmem_shape in zip( all_tensor_info, vmem_shapes, strict=True ): @@ -1878,6 +1879,13 @@ def _classify_pipelined_tensors( continue if id(fake) in outer_access_tensor_ids: continue + # Skip loop-invariant tensors: if none of the tensor's subscript + # dimensions correspond to an inner pipeline block_id, it doesn't + # vary with the pipeline iteration and should stay on its outer + # VMEM BlockSpec (single buffer, no redundant DMA). + dim_to_bid = _get_dim_block_ids(_sub_meta, env) + if not (set(dim_to_bid.values()) & inner_block_id_set): + continue pipelined_ids.add(id(fake)) return all_tensor_info, vmem_shapes, pipelined_ids diff --git a/test/test_pallas.py b/test/test_pallas.py index 52454363f..68804edad 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -315,10 +315,9 @@ def pallas_attention( m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) l_i = torch.full_like(m_i, 1.0) acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32) - q = q_view[tile_b, tile_m, :] for tile_n in hl.tile(v_view.size(1)): - # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q - q_scaled = q * qk_scale + # load and scale Q in-loop on-demand reduces spillage + q_scaled = q_view[tile_b, tile_m, :] * qk_scale k = k_view[tile_b, tile_n, :] # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs. # same as hl.dot(q, k, out_dtype=torch.float32)