Skip to content

[Pallas] Apply tile masks at load time to zero out-of-bounds data#2214

Merged
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/36
May 4, 2026
Merged

[Pallas] Apply tile masks at load time to zero out-of-bounds data#2214
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/36

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented May 3, 2026

Stacked PRs:


[Pallas] Apply tile masks at load time to zero out-of-bounds data

Consider this jagged-sum kernel where each segment has a different
length:

for seq_index in hl.grid(num_rows):
    start = x_offsets[seq_index]
    end = x_offsets[seq_index + 1]
    row_sums = hl.zeros([1], dtype=x_data.dtype)
    for tile in hl.tile(start, end):
        vals = x_data[tile, :, :]
        row_sums += vals.sum()

If segment 0 has length 35, hl.tile(0, 35) with block_size=64
iterates once over tile range [0, 64). The host pads x_data along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage). Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's tl.load(mask=, other=0), Pallas loads via
ref[pl.ds()] return raw memory with no built-in masking. Fix this
by generating ref[idx] * mask.astype(dtype) directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size. If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied. Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com

@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/35 branch from a62d150 to a078f65 Compare May 3, 2026 21:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from 8dd075b to 9ea10cb Compare May 3, 2026 21:54
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 3, 2026
AmesingFlank added a commit that referenced this pull request May 3, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 3, 2026 23:14
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from 9ea10cb to f91ffeb Compare May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 3, 2026 23:14
@AmesingFlank AmesingFlank marked this pull request as ready for review May 3, 2026 23:14
AmesingFlank added a commit that referenced this pull request May 4, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
AmesingFlank added a commit that referenced this pull request May 4, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
AmesingFlank added a commit that referenced this pull request May 4, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
AmesingFlank added a commit that referenced this pull request May 4, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 01:46
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from f91ffeb to 38d6522 Compare May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 4, 2026 01:46
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 01:46
AmesingFlank added a commit that referenced this pull request May 4, 2026
Pallas loads (`ref[pl.ds()]`) return raw memory with no built-in
masking, unlike Triton's `tl.load(mask=, other=0)`. When a tile
overshoots a segment boundary, trailing rows contain adjacent-segment
data that corrupts reductions.

Generate `ref[idx] * mask.astype(dtype)` directly in load codegen so
loaded data is zeroed for out-of-tile positions. This lets the existing
`remove_unnecessary_masking` pass correctly elide the redundant
`_mask_to` nodes. The float cast is applied before unsqueeze to avoid
a Mosaic boolean reshape limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 01:52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from 38d6522 to f2dc094 Compare May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 4, 2026 01:52
AmesingFlank added a commit that referenced this pull request May 4, 2026
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 03:32
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 03:32
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from eaf93e8 to 4da9e2d Compare May 4, 2026 03:33
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 4, 2026 03:33
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:33
@AmesingFlank AmesingFlank requested review from jansel, norx1991 and oulgen May 4, 2026 15:08
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 4, 2026 16:45
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 16:45
AmesingFlank added a commit that referenced this pull request May 4, 2026
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
AmesingFlank added a commit that referenced this pull request May 4, 2026
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 17:54
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 17:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/36 branch from 4da9e2d to c7d2977 Compare May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/35 May 4, 2026 17:55
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 17:55
Comment thread test/test_pallas.py
row_sums = hl.zeros([1], dtype=x_data.dtype)
for tile in hl.tile(start, end):
vals = x_data[tile, :, :]
row_sums = row_sums + vals.sum(dim=0).sum(dim=0).sum(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will a single sum() causes some issue?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, Helion's sum only allows reduction across a single dimension at a time, otherwise we throw here

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I think we can update the TODO there as we have found a concrete case that will hit this.

result = expr_from_string(f"{name}[{idx_str}]")
mask_expr = _load_mask_expr(state, subscript, tensor)
if mask_expr is not None:
result = expr_from_string(f"{name}[{idx_str}] * ({mask_expr})")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we considered using jnp.where?

Copy link
Copy Markdown
Contributor Author

@AmesingFlank AmesingFlank May 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One issue with jnp.where + bool array is that, Mosaic TPU lowering doesn't like reshaping bool arrays. See this related PR (#2216). When there is a bool array and we try to do bool_array[:, None], it fails, so we have to resort to casting to dtype and then multiplying.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this cause issue for other types of reduction? Like max, min?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think helion's contract is that if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0. (On triton we have tl.load with masking which does that, but pltpu.load + masking doesn't work within fori_loop, which is why we do this). The user should expect to make sure there kernel works with 0-masked values

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With #2216, does the jnp.where method work? It seems like that will allow us to have one less assumption like this: "if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0".

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With #2216, does the jnp.where method work?

Not quite sure if I understood the question. 2216 is for a different place where we need masking, but it has the same issue as we do here, where reshaping bool arrays is not allowed. And jnp.where can only be applied to bool arrays.

It seems like that will allow us to have one less assumption like this: "if you have a tile that goes out-of-bound, then the out-of-bound sections will be masked to have value 0".

Why would using jnp.where help remove this assumption? jnp.where is just a different way to apply the masking, we are not truncating the loaded the tensor to have a different shape.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2216 is for a different place where we need masking, but it has the same issue as we do here

IIUC, #2216 demonstrates that we can use casting to manipulate the boolean array. Can this technique be used here as well to enable jnp.where?

we are not truncating the loaded the tensor to have a different shape
Right, but it makes it expandable as we can decide the fill value. With multiplying 0, it only works when 0 is the fill value.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jnp.where only accepts bool array as condition, so we'll have to generate a float mask and cast back to bool, which is extra work. In addition, I don't know if jnp.where will be faster than applying a * directly, because it is a tenary floating point op, whereas * is a binary floating point op. So I'm not sure if it has any benefits

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to profile a bit, in this gist https://gist.github.com/AmesingFlank/2b3b43cc9c0ab268be0a3b14c8f8291c

not seeing any measurable difference between the two approaches. Using * is slightly less cluttering in the generated code though

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think we can proceed with this as it then. Suggest add a comment to documentation the assumption about the fill value though (I guess it will produce weird wrong result for other types of reduction).

Comment thread helion/_compiler/pallas/codegen.py Outdated
dtype_str = env.backend.dtype_str(tensor.dtype)
expand = state.tile_strategy.expand_str(output_sizes, out_dim)
expr = f"({mask_var}.astype({dtype_str}){expand})"
if expr not in mask_exprs:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will this happen?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! This conditional isn't necessary, updated the PR

AmesingFlank added a commit that referenced this pull request May 4, 2026
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
AmesingFlank added a commit that referenced this pull request May 4, 2026
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 18:54
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/35 to main May 4, 2026 18:54
Consider this jagged-sum kernel where each segment has a different
length:

    for seq_index in hl.grid(num_rows):
        start = x_offsets[seq_index]
        end = x_offsets[seq_index + 1]
        row_sums = hl.zeros([1], dtype=x_data.dtype)
        for tile in hl.tile(start, end):
            vals = x_data[tile, :, :]
            row_sums += vals.sum()

If segment 0 has length 35, `hl.tile(0, 35)` with block_size=64
iterates once over tile range [0, 64).  The host pads `x_data` along
dim 0 so the Pallas ref has at least 64 rows, but rows 35-63 contain
the next segment's data (or padding garbage).  Without masking, those
29 stale rows are included in the reduction, silently producing wrong
results.

Unlike Triton's `tl.load(mask=, other=0)`, Pallas loads via
`ref[pl.ds()]` return raw memory with no built-in masking.  Fix this
by generating `ref[idx] * mask.astype(dtype)` directly in load
codegen so loaded data is zeroed for out-of-tile positions.

Masking is driven by the indexing patterns: for each TilePattern
dimension, we check whether the loop's iteration range matches the
tensor's symbolic size.  If not (data-dependent bounds, constexpr
sub-ranges, non-zero begin), a mask is applied.  Grid/tile dimensions
where BlockSpecs size the ref to the actual remainder are not masked —
a block-sized mask would cause a shape mismatch against the smaller
ref.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2214, branch: AmesingFlank/stack/36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants