Skip to content

[Pallas] Cast bool masks to float before expanding in _mask_to codegen#2216

Merged
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/38
May 6, 2026
Merged

[Pallas] Cast bool masks to float before expanding in _mask_to codegen#2216
AmesingFlank merged 1 commit into
mainfrom
AmesingFlank/stack/38

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented May 3, 2026

Stacked PRs:


[Pallas] Cast bool masks to float before expanding in _mask_to codegen

Mosaic TPU compiler cannot reshape bool vectors (vector ->
vector). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com

@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/37 branch from 938bde3 to a1867d1 Compare May 3, 2026 21:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from e14cb44 to e823bb4 Compare May 3, 2026 21:54
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 3, 2026
AmesingFlank added a commit that referenced this pull request May 3, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 3, 2026 23:14
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from e823bb4 to 8e224a5 Compare May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 3, 2026 23:14
@AmesingFlank AmesingFlank marked this pull request as ready for review May 3, 2026 23:14
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 01:46
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from 8e224a5 to fb96c3b Compare May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 4, 2026 01:46
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 01:46
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 01:52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from fb96c3b to 539e8bd Compare May 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 4, 2026 01:52
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:22
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 03:32
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 03:32
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from 6cdd1f0 to 332cd72 Compare May 4, 2026 03:33
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 4, 2026 03:33
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:33
@AmesingFlank AmesingFlank requested review from norx1991 and oulgen May 4, 2026 15:09
expand = state.tile_strategy.expand_str(input_sizes, dim)
expr = f"({mask_var}{expand})"
# Cast bool mask to float before expanding — Mosaic cannot
# reshape bool vectors (e.g. vector<32xi1> → vector<32x1xi1>).
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue for Jax: jax-ml/jax#37370

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you cast it back to bool after expanding so that we can keep "&".join(mask_exprs)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just gave that a try, I think casting to float, expand, and then casting back to bool, does get around this compiler error. However I'm not sure if its worth it to cast back, because we'll need to use floating point ops to mask the original float arrays anyways. (either a jnp.where as a ternary floating point op, or * as a binary floating point op)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to profile a bit, in this gist https://gist.github.com/AmesingFlank/2b3b43cc9c0ab268be0a3b14c8f8291c

not seeing any measurable difference between the two approaches. Using * is slightly less cluttering in the generated code though

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think this is a similar problem as in #2214. Let's keep it for now.

@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 16:44
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 4, 2026 16:45
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 16:45
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 17:55
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from 332cd72 to 464ce56 Compare May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/37 May 4, 2026 17:55
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 17:55
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
AmesingFlank added a commit that referenced this pull request May 4, 2026
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 18:54
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/37 to main May 4, 2026 18:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/38 branch from 464ce56 to 3d82c29 Compare May 4, 2026 18:54
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> ->
vector<Nx1xi1>). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2216, branch: AmesingFlank/stack/38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants