[Pallas] Cast bool masks to float before expanding in _mask_to codegen#2216
Conversation
938bde3 to
a1867d1
Compare
e14cb44 to
e823bb4
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
e823bb4 to
8e224a5
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
8e224a5 to
fb96c3b
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
fb96c3b to
539e8bd
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
6cdd1f0 to
332cd72
Compare
| expand = state.tile_strategy.expand_str(input_sizes, dim) | ||
| expr = f"({mask_var}{expand})" | ||
| # Cast bool mask to float before expanding — Mosaic cannot | ||
| # reshape bool vectors (e.g. vector<32xi1> → vector<32x1xi1>). |
There was a problem hiding this comment.
Can you cast it back to bool after expanding so that we can keep "&".join(mask_exprs)?
There was a problem hiding this comment.
Just gave that a try, I think casting to float, expand, and then casting back to bool, does get around this compiler error. However I'm not sure if its worth it to cast back, because we'll need to use floating point ops to mask the original float arrays anyways. (either a jnp.where as a ternary floating point op, or * as a binary floating point op)
There was a problem hiding this comment.
Tried to profile a bit, in this gist https://gist.github.com/AmesingFlank/2b3b43cc9c0ab268be0a3b14c8f8291c
not seeing any measurable difference between the two approaches. Using * is slightly less cluttering in the generated code though
There was a problem hiding this comment.
Thanks! I think this is a similar problem as in #2214. Let's keep it for now.
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
332cd72 to
464ce56
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
464ce56 to
3d82c29
Compare
Mosaic TPU compiler cannot reshape bool vectors (vector<Nxi1> -> vector<Nx1xi1>). When _mask_to generates combined masks like mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32 before dimension expansion and combined with * instead of &. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> stack-info: PR: #2216, branch: AmesingFlank/stack/38
Stacked PRs:
[Pallas] Cast bool masks to float before expanding in _mask_to codegen
Mosaic TPU compiler cannot reshape bool vectors (vector ->
vector). When _mask_to generates combined masks like
mask_q[:, None] & mask_kv[None, :], each mask is now cast to float32
before dimension expansion and combined with * instead of &.
Co-Authored-By: Claude Opus 4.6 noreply@anthropic.com