Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,12 +2566,15 @@ def _(state: CodegenState) -> ast.AST:
mask_var := state.codegen.mask_var(index)
) is not None:
expand = state.tile_strategy.expand_str(input_sizes, dim)
expr = f"({mask_var}{expand})"
# Cast bool mask to float before expanding — Mosaic cannot
# reshape bool vectors (e.g. vector<32xi1> → vector<32x1xi1>).
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue for JAX: jax-ml/jax#37370

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you cast it back to bool after expanding so that we can keep "&".join(mask_exprs)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just gave that a try. Casting to float, expanding, and then casting back to bool does get around this compiler error. However, I'm not sure it's worth casting back, because we'll need floating-point ops to mask the original float arrays anyway (either a jnp.where as a ternary floating-point op, or * as a binary floating-point op).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried profiling this a bit in this gist: https://gist.github.com/AmesingFlank/2b3b43cc9c0ab268be0a3b14c8f8291c

I'm not seeing any measurable difference between the two approaches. Using * makes the generated code slightly less cluttered, though.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think this is a similar problem as in #2214. Let's keep it for now.

expr = f"({mask_var}.astype(jnp.float32){expand})"
if expr not in mask_exprs:
mask_exprs.append(expr)
if not mask_exprs:
return state.ast_arg(0)
mask_expr = "&".join(mask_exprs)
# Combine float masks via multiplication (equivalent to bool AND).
mask_expr = " * ".join(mask_exprs)
if len(mask_exprs) < len(input_sizes):
mask_expr = backend.broadcast_to_expr(
mask_expr, state.tile_strategy.shape_str(input_sizes)
Expand Down
1 change: 0 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,6 @@ def test_attention_block_pointer(self):
lambda: _get_backend() == "cute",
"CuTe dynamic attention destabilizes later cute tests when it fails in-process",
)
@xfailIfPallas("JAX tracer error with dynamic shapes")
def test_attention_dynamic(self):
args = (
torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE),
Expand Down
40 changes: 40 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,46 @@ def nested_tile_sum(
).sum(dim=0)
torch.testing.assert_close(result, ref, rtol=1e-3, atol=1e-3)

def test_nested_tile_matmul_mask_cast(self) -> None:
"""Two nested data-dependent tiles with matmul need float mask expansion."""

# Kernel under test: built for the Pallas backend with static shapes.
@helion.kernel(backend="pallas", static_shapes=True)
def jagged_kernel(
x: torch.Tensor, y: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
# offsets holds one more entry than there are segments.
num_segs = offsets.size(0) - 1
out = torch.zeros([num_segs], dtype=x.dtype, device=x.device)
# One grid step per segment; its [start, end) row range is read from offsets.
for seg in hl.grid(num_segs):
start = offsets[seg]
end = offsets[seg + 1]
# Length-1 accumulator for this segment's total.
acc = hl.zeros([1], dtype=x.dtype)
# Both tile loops iterate over the same data-dependent [start, end) range.
for tile_i in hl.tile(start, end):
for tile_j in hl.tile(start, end):
# Block of x @ y^T for the current pair of row tiles.
gram = torch.matmul(
x[tile_i, :], y[tile_j, :].transpose(-2, -1)
)
# Reduce the block over both dims and add it to the length-1 accumulator.
acc = acc + gram.sum(dim=0).sum(dim=0).unsqueeze(0)
out[seg] = acc.squeeze(0)
return out

N, D = 128, 128
x = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
y = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
# Two segments of 64 rows each: [0, 64) and [64, 128).
offsets = torch.tensor([0, 64, 128], device=DEVICE, dtype=torch.int32)

# Only the kernel output is checked below; the generated code (_code) is not inspected.
_code, result = code_and_output(
jagged_kernel,
(x, y, offsets),
block_sizes=[32, 32],
pallas_loop_type="fori_loop",
)

# Eager reference: for each segment, the sum of all entries of x[s:e] @ y[s:e].T.
ref = torch.zeros(offsets.size(0) - 1, device=DEVICE, dtype=x.dtype)
for i in range(offsets.size(0) - 1):
s, e = int(offsets[i]), int(offsets[i + 1])
ref[i] = (x[s:e] @ y[s:e].T).sum()
# NOTE(review): loose tolerances, presumably to absorb float32 accumulation-order differences — confirm.
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)


@skipUnlessPallas("JAX/Pallas TPU not available")
class TestPallasIndirectGather(TestCase):
Expand Down