Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,27 @@ def _setup_loop_carried_state(
return scratch_names, result_vars, carried


def _emit_nonlocal_scratch_declarations(
state: CodegenState,
body_stmts: list[ast.AST],
) -> None:
"""Insert ``nonlocal <scratch>`` at the top of the closure body.

Without ``nonlocal``, an assignment like ``scratch = scratch[...]`` inside
a fori_loop/emit_pipeline closure makes ``scratch`` local to the entire
function, causing an UnboundLocalError on the RHS read.

We emit nonlocal for *all* VMEM scratch args, not just the current loop's
carried state, because an outer loop body may contain ``scratch = scratch[...]``
from a nested inner loop's ``_read_final_loop_state``.
"""
names = [
s.name for s in state.device_function._scratch_args if s.scratch_type == "vmem"
]
if names:
body_stmts.insert(0, ast.Nonlocal(names=names))


def _remap_args_to_scratch(
args: list[ast.AST],
scratch_names: list[str],
Expand Down Expand Up @@ -1615,6 +1636,8 @@ def _make_hbm_slice(
if has_loop_state:
_write_back_loop_carried(state, scratch_names, carried, graph_results)

_emit_nonlocal_scratch_declarations(state, body_stmts)

all_body_params = body_params
# emit_pipeline passes indices as a single tuple argument; the prologue
# always references _pipeline_indices, so the body always takes it.
Expand Down Expand Up @@ -2016,6 +2039,8 @@ def _build_hbm_dma_slice(
)
state.codegen.add_statement(statement_from_string(f"{copy_out_var}.wait()"))

_emit_nonlocal_scratch_declarations(state, body_stmts)

# Emit nested fori_loop calls — one per dimension.
# Build inside-out: innermost function wraps body_stmts, each outer
# function wraps the inner fori_loop call.
Expand Down
48 changes: 48 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,54 @@ def jagged_sum_3d(
)
torch.testing.assert_close(result, ref, rtol=1e-3, atol=1e-3)

def test_nested_fori_loop_scratch_scoping(self) -> None:
"""Nested hl.tile(start, end) with inner accumulator"""

@helion.kernel(backend="pallas", static_shapes=True)
def nested_tile_sum(
x: torch.Tensor, y: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
A = hl.specialize(x.size(1))
B = hl.specialize(x.size(2))
num_segs = offsets.size(0) - 1
out = torch.zeros([num_segs, A, B], dtype=x.dtype, device=x.device)
for seg in hl.grid(num_segs):
start = offsets[seg]
end = offsets[seg + 1]
acc = hl.zeros([1, A, B], dtype=x.dtype)
for tile_i in hl.tile(start, end):
inner_acc = hl.zeros([1, A, B], dtype=x.dtype)
for tile_j in hl.tile(start, end):
inner_acc = inner_acc + (x[tile_i, :, :] * y[tile_j, :, :]).sum(
dim=0
).unsqueeze(0)
acc = acc + inner_acc
out[seg, :, :] = acc.squeeze(0)
return out

N, A, B = 128, 8, 256
x = torch.randn(N, A, B, device=DEVICE, dtype=torch.float32)
y = torch.randn(N, A, B, device=DEVICE, dtype=torch.float32)
offsets = torch.tensor([0, 64, 128], device=DEVICE, dtype=torch.int32)

_code, result = code_and_output(
nested_tile_sum,
(x, y, offsets),
block_sizes=[32, 32],
pallas_loop_type="fori_loop",
)

block = 32
ref = torch.zeros(offsets.size(0) - 1, A, B, device=DEVICE, dtype=x.dtype)
for seg in range(offsets.size(0) - 1):
s, e = int(offsets[seg]), int(offsets[seg + 1])
for i in range(0, e - s, block):
for j in range(0, e - s, block):
ref[seg] += (
x[s + i : s + i + block] * y[s + j : s + j + block]
).sum(dim=0)
torch.testing.assert_close(result, ref, rtol=1e-3, atol=1e-3)


@skipUnlessPallas("JAX/Pallas TPU not available")
class TestPallasIndirectGather(TestCase):
Expand Down
Loading