Skip to content

[Pallas] Fix scratch Ref scoping bug in fori_loop/emit_pipeline codegen#2215

Merged
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/37
May 5, 2026
Merged

[Pallas] Fix scratch Ref scoping bug in fori_loop/emit_pipeline codegen#2215
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/37

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented May 3, 2026

Stacked PRs:


[Pallas] Fix scratch Ref scoping bug in fori_loop/emit_pipeline codegen

When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes scratch local to the entire function, so
earlier references like scratch[...] = v_1[...] fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing scratch = scratch[...] inside the
closure.

The fix: emit nonlocal <scratch_name> declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
_emit_nonlocal_scratch_declarations. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing scratch = scratch[...] to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
scratch = scratch[...] from a nested inner loop's readback.

Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com

@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from 938bde3 to a1867d1 Compare May 3, 2026 21:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from 8dd075b to 9ea10cb Compare May 3, 2026 21:54
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 3, 2026
AmesingFlank added a commit that referenced this pull request May 3, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 3, 2026 23:14
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from a1867d1 to e23e600 Compare May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 3, 2026 23:14
@AmesingFlank AmesingFlank marked this pull request as ready for review May 3, 2026 23:14
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 01:46
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from e23e600 to 6fd7105 Compare May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 01:46
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 01:46
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 01:52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from 6fd7105 to 8ec698e Compare May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 03:22
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:22
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 03:32
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 03:32
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from eb8d483 to b2b2a2c Compare May 4, 2026 03:33
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 03:33
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:33
@AmesingFlank AmesingFlank requested review from jansel, norx1991 and oulgen May 4, 2026 15:08
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 16:45
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 16:45
@norx1991
Copy link
Copy Markdown
Contributor

norx1991 commented May 4, 2026

Smart fix!

AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 17:55
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from b2b2a2c to f485aaa Compare May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/36 May 4, 2026 17:55
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 17:55
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank added a commit that referenced this pull request May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 18:54
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/36 to main May 4, 2026 18:54
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:

    scratch = scratch[...]

inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.

The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.

The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.

The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2215, branch: AmesingFlank/stack/37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants