diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index 2db0dcaac9..2ac3cca95f 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -2566,12 +2566,15 @@ def _(state: CodegenState) -> ast.AST: mask_var := state.codegen.mask_var(index) ) is not None: expand = state.tile_strategy.expand_str(input_sizes, dim) - expr = f"({mask_var}{expand})" + # Cast bool mask to float before expanding — Mosaic cannot + # reshape bool vectors (e.g. vector<32xi1> → vector<32x1xi1>). + expr = f"({mask_var}.astype(jnp.float32){expand})" if expr not in mask_exprs: mask_exprs.append(expr) if not mask_exprs: return state.ast_arg(0) - mask_expr = "&".join(mask_exprs) + # Combine float masks via multiplication (equivalent to bool AND). + mask_expr = " * ".join(mask_exprs) if len(mask_exprs) < len(input_sizes): mask_expr = backend.broadcast_to_expr( mask_expr, state.tile_strategy.shape_str(input_sizes) diff --git a/test/test_examples.py b/test/test_examples.py index ed7def063d..5305deb1db 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -866,7 +866,6 @@ def test_attention_block_pointer(self): lambda: _get_backend() == "cute", "CuTe dynamic attention destabilizes later cute tests when it fails in-process", ) - @xfailIfPallas("JAX tracer error with dynamic shapes") def test_attention_dynamic(self): args = ( torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE), diff --git a/test/test_pallas.py b/test/test_pallas.py index 5d50824531..c3d9bbb089 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -2200,6 +2200,46 @@ def nested_tile_sum( ).sum(dim=0) torch.testing.assert_close(result, ref, rtol=1e-3, atol=1e-3) + def test_nested_tile_matmul_mask_cast(self) -> None: + """Two nested data-dependent tiles with matmul need float mask expansion.""" + + @helion.kernel(backend="pallas", static_shapes=True) + def jagged_kernel( + x: torch.Tensor, y: torch.Tensor, offsets: torch.Tensor + ) -> torch.Tensor: + num_segs = offsets.size(0) - 1 + out = torch.zeros([num_segs], dtype=x.dtype, device=x.device) + for seg in hl.grid(num_segs): + start = offsets[seg] + end = offsets[seg + 1] + acc = hl.zeros([1], dtype=x.dtype) + for tile_i in hl.tile(start, end): + for tile_j in hl.tile(start, end): + gram = torch.matmul( + x[tile_i, :], y[tile_j, :].transpose(-2, -1) + ) + acc = acc + gram.sum(dim=0).sum(dim=0).unsqueeze(0) + out[seg] = acc.squeeze(0) + return out + + N, D = 128, 128 + x = torch.randn(N, D, device=DEVICE, dtype=torch.float32) + y = torch.randn(N, D, device=DEVICE, dtype=torch.float32) + offsets = torch.tensor([0, 64, 128], device=DEVICE, dtype=torch.int32) + + _code, result = code_and_output( + jagged_kernel, + (x, y, offsets), + block_sizes=[32, 32], + pallas_loop_type="fori_loop", + ) + + ref = torch.zeros(offsets.size(0) - 1, device=DEVICE, dtype=x.dtype) + for i in range(offsets.size(0) - 1): + s, e = int(offsets[i]), int(offsets[i + 1]) + ref[i] = (x[s:e] @ y[s:e].T).sum() + torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2) + @skipUnlessPallas("JAX/Pallas TPU not available") class TestPallasIndirectGather(TestCase):