[Pallas] Apply tile masks at load time to zero out-of-bounds data#2214
Conversation
a62d150 to
a078f65
Compare
8dd075b to
9ea10cb
Compare
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
9ea10cb to
f91ffeb
Compare
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
f91ffeb to
38d6522
Compare
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile overshoots a segment boundary, trailing rows contain adjacent-segment data that corrupts reductions. Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so loaded data is zeroed for out-of-tile positions. This lets the existing `remove_unnecessary_masking` pass correctly elide the redundant `_mask_to` nodes. The float cast is applied before unsqueeze to avoid a Mosaic boolean reshape limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2214, branch: AmesingFlank/stack/36
38d6522 to
f2dc094
Compare
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
eaf93e8 to
4da9e2d
Compare
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
4da9e2d to
c7d2977
Compare
| row_sums = hl.zeros([1], dtype=x_data.dtype) | ||
| for tile in hl.tile(start, end): | ||
| vals = x_data[tile, :, :] | ||
| row_sums = row_sums + vals.sum(dim=0).sum(dim=0).sum( |
There was a problem hiding this comment.
Will a single sum() causes some issue?
There was a problem hiding this comment.
Yes, Helion's sum only allows reduction across a single dimension at a time, otherwise we throw here
There was a problem hiding this comment.
Got it. I think we can update the TODO there as we have found a concrete case that will hit this.
| result = expr_from_string(f"{name}[{idx_str}]") | ||
| mask_expr = _load_mask_expr(state, subscript, tensor) | ||
| if mask_expr is not None: | ||
| result = expr_from_string(f"{name}[{idx_str}] * ({mask_expr})") |
There was a problem hiding this comment.
Have we considered using jnp.where?
There was a problem hiding this comment.
One issue with jnp.where + bool array is that, Mosaic TPU lowering doesn't like reshaping bool arrays. See this related PR (#2216). When there is a bool array and we try to do bool_array[:, None], it fails, so we have to resort to casting to dtype and then multiplying.
There was a problem hiding this comment.
Will this cause issue for other types of reduction? Like max, min?
There was a problem hiding this comment.
I think helion's contract is that if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0. (On triton we have tl.load with masking which does that, but pltpu.load + masking doesn't work within fori_loop, which is why we do this). The user should expect to make sure there kernel works with 0-masked values
There was a problem hiding this comment.
With #2216, does the jnp.where method work? It seems like that will allow us to have one less assumption like this: "if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0".
There was a problem hiding this comment.
With #2216, does the jnp.where method work?
Not quite sure if I understood the question. 2216 is for a different place where we need masking, but it has the same issue as we do here, where reshaping bool arrays is not allowed. And jnp.where can only be applied to bool arrays.
It seems like that will allow us to have one less assumption like this: "if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0".
Why would using jnp.where help remove this assumption? jnp.where is just a different way to apply the masking, we are not truncating the loaded the tensor to have a different shape.
There was a problem hiding this comment.
2216 is for a different place where we need masking, but it has the same issue as we do here
IIUC, #2216 demonstrates that we can use casting to manipulate the boolean array. Can this technique be used here as well to enable jnp.where?
we are not truncating the loaded the tensor to have a different shape
Right, but it makes it expandable as we can decide the fill value. With multiplying 0, it only works when 0 is the fill value.
There was a problem hiding this comment.
jnp.where only accepts bool array as condition, so we'll have to generate a float mask and cast back to bool, which is extra work. In addition, I don't know if jnp.where will be faster than applying a * directly, because it is a tenary floating point op, whereas * is a binary floating point op. So I'm not sure if it has any benefits
There was a problem hiding this comment.
Tried to profile a bit, in this gist https://gist.github.com/AmesingFlank/2b3b43cc9c0ab268be0a3b14c8f8291c
not seeing any measurable difference between the two approaches. Using * is slightly less cluttering in the generated code though
There was a problem hiding this comment.
Thanks! I think we can proceed with this as it then. Suggest add a comment to documentation the assumption about the fill value though (I guess it will produce weird wrong result for other types of reduction).
| dtype_str = env.backend.dtype_str(tensor.dtype) | ||
| expand = state.tile_strategy.expand_str(output_sizes, out_dim) | ||
| expr = f"({mask_var}.astype({dtype_str}){expand})" | ||
| if expr not in mask_exprs: |
There was a problem hiding this comment.
Good catch! This conditional isn't necessary, updated the PR
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
Consider this jagged-sum kernel where each segment has a different
length:
for seq_index in hl.grid(num_rows):
start = x_offsets[seq_index]
end = x_offsets[seq_index + 1]
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums += vals.sum()
If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64). The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking. Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2214, branch: AmesingFlank/stack/36
Stacked PRs:
[Pallas] Apply tile masks at load time to zero out-of-bounds data
Consider this jagged-sum kernel where each segment has a different
length:
If segment 0 has length 35,
hl.tile(0, 35)with block_size=64iterates once over tile range [0, 64). The host pads
x_dataalongdim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.
Unlike Triton's
tl.load(mask=, other=0), Pallas loads viaref[pl.ds()]return raw memory with no built-in masking. Fix thisby generating
ref[idx] * mask.astype(dtype)directly in loadcodegen so loaded data is zeroed for out-of-tile positions.
Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.
Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com