[Pallas] Fix scratch Ref scoping bug in fori_loop/emit_pipeline codegen#2215
Merged
Conversation
938bde3 to
a1867d1
Compare
8dd075b to
9ea10cb
Compare
This was referenced May 3, 2026
Merged
AmesingFlank
added a commit
that referenced
this pull request
May 3, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
a1867d1 to
e23e600
Compare
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
e23e600 to
6fd7105
Compare
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
6fd7105 to
8ec698e
Compare
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
eb8d483 to
b2b2a2c
Compare
Contributor
|
Smart fix! |
norx1991
approved these changes
May 4, 2026
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
b2b2a2c to
f485aaa
Compare
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
AmesingFlank
added a commit
that referenced
this pull request
May 4, 2026
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Also simplifies _setup_loop_carried_state to store scratch names as
plain strings in result_vars instead of (result_name, scratch_name)
tuples, and derives the result variable name in _read_final_loop_state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
scratch = scratch[...]
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes `scratch` local to the entire function, so
earlier references like `scratch[...] = v_1[...]` fail with
UnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing `scratch = scratch[...]` inside the
closure.
The fix: emit `nonlocal <scratch_name>` declarations at the top of
fori_loop and emit_pipeline body closures via a new helper
`_emit_nonlocal_scratch_declarations`. This tells Python that scratch
variables refer to the enclosing scope, preventing them from becoming
local and allowing `scratch = scratch[...]` to work correctly.
The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
`scratch = scratch[...]` from a nested inner loop's readback.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2215, branch: AmesingFlank/stack/37
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked PRs:
[Pallas] Fix scratch Ref scoping bug in fori_loop/emit_pipeline codegen
When a Pallas kernel has hl.tile(start, end) loops with loop-carried
state (e.g. an accumulator), the generated code contained a pattern
like:
inside a fori_loop body closure. This is a Python closure scoping bug:
the assignment makes
scratchlocal to the entire function, soearlier references like
scratch[...] = v_1[...]fail withUnboundLocalError because the local hasn't been assigned yet.
The root cause: the phi merge pass (merge_variable_names + ast_rename)
collapses the initial scratch Ref name with the read-back variable
(e.g. scratch_val), producing
scratch = scratch[...]inside theclosure.
The fix: emit
nonlocal <scratch_name>declarations at the top offori_loop and emit_pipeline body closures via a new helper
_emit_nonlocal_scratch_declarations. This tells Python that scratchvariables refer to the enclosing scope, preventing them from becoming
local and allowing
scratch = scratch[...]to work correctly.The declarations cover all VMEM scratch args (not just the current
loop's carried state), because an outer loop body may contain
scratch = scratch[...]from a nested inner loop's readback.Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com