5 changes: 2 additions & 3 deletions examples/attention.py
@@ -69,10 +69,9 @@ def attention(
         m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
         l_i = torch.full_like(m_i, 1.0)
         acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
-        q = q_view[tile_b, tile_m, :]
         for tile_n in hl.tile(v_view.size(1)):
-            # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
-            q_scaled = q * qk_scale
+            # loading and scaling Q in-loop on demand reduces spillage
+            q_scaled = q_view[tile_b, tile_m, :] * qk_scale
             k = k_view[tile_b, tile_n, :]
             # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
             # same as hl.dot(q, k, out_dtype=torch.float32)
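Note on the change above: Q is no longer loaded once and kept live (pre-scaled) across the whole tile_n loop; the load and scale are re-issued each iteration, so the backend can re-materialize them instead of spilling. A minimal stand-alone sketch of that live-range difference, using hypothetical helpers (load_q_tile, inner_step) and toy sizes rather than the kernel's real API:

import torch

qk_scale = 0.125
num_tiles = 4

def load_q_tile() -> torch.Tensor:
    # Stand-in for q_view[tile_b, tile_m, :] in the kernel above.
    return torch.randn(16, 64)

def inner_step(q_scaled: torch.Tensor, n: int) -> None:
    # Stand-in for the rest of the attention inner loop.
    _ = q_scaled.sum()

# Before: `q` must stay resident (registers/VMEM) across every iteration.
q = load_q_tile()
for n in range(num_tiles):
    inner_step(q * qk_scale, n)

# After: nothing Q-related has to survive between iterations, so the
# compiler is free to re-materialize the load instead of spilling it.
for n in range(num_tiles):
    inner_step(load_q_tile() * qk_scale, n)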
8 changes: 8 additions & 0 deletions helion/language/_tracing_ops.py
@@ -1871,13 +1871,21 @@ def _classify_pipelined_tensors(
             outer_access_tensor_ids.add(id(val))
 
     pipelined_ids: set[int] = set()
+    inner_block_id_set = set(block_ids)
     for (fake, _sub_meta, _direction), vmem_shape in zip(
         all_tensor_info, vmem_shapes, strict=True
     ):
         if not _check_dma_alignment(vmem_shape):
             continue
         if id(fake) in outer_access_tensor_ids:
             continue
+        # Skip loop-invariant tensors: if none of the tensor's subscript
+        # dimensions correspond to an inner pipeline block_id, it doesn't
+        # vary with the pipeline iteration and should stay on its outer
+        # VMEM BlockSpec (single buffer, no redundant DMA).
+        dim_to_bid = _get_dim_block_ids(_sub_meta, env)
+        if not (set(dim_to_bid.values()) & inner_block_id_set):
+            continue
         pipelined_ids.add(id(fake))
     return all_tensor_info, vmem_shapes, pipelined_ids
 
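The new check skips tensors whose subscripts never touch an inner pipeline block_id: such tensors are loop-invariant, so double-buffering them through the pipeline would only add redundant DMAs. A toy sketch of the same set-intersection predicate, with made-up block ids and dim-to-block_id mappings (the real ones come from _get_dim_block_ids and the environment):

# Illustrative only: made-up ids, same predicate as the diff above.
inner_block_ids = {3, 4}          # block_ids of the inner pipeline loop

def should_pipeline(dim_to_bid: dict[int, int]) -> bool:
    # Pipeline the tensor only if some subscript dimension advances with an
    # inner block_id; otherwise it stays on its outer VMEM BlockSpec.
    return bool(set(dim_to_bid.values()) & inner_block_ids)

loop_varying = {0: 2, 1: 3}       # dim 1 is indexed by inner block_id 3
loop_invariant = {0: 2, 1: 7}     # no dim tied to an inner block_id

assert should_pipeline(loop_varying)
assert not should_pipeline(loop_invariant)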
5 changes: 2 additions & 3 deletions test/test_pallas.py
@@ -315,10 +315,9 @@ def pallas_attention(
         m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
         l_i = torch.full_like(m_i, 1.0)
         acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
-        q = q_view[tile_b, tile_m, :]
         for tile_n in hl.tile(v_view.size(1)):
-            # scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
-            q_scaled = q * qk_scale
+            # loading and scaling Q in-loop on demand reduces spillage
+            q_scaled = q_view[tile_b, tile_m, :] * qk_scale
             k = k_view[tile_b, tile_n, :]
             # Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
             # same as hl.dot(q, k, out_dtype=torch.float32)