diff --git a/helion/_compiler/pallas/codegen.py b/helion/_compiler/pallas/codegen.py index c2ab9726d..aef8f8f52 100644 --- a/helion/_compiler/pallas/codegen.py +++ b/helion/_compiler/pallas/codegen.py @@ -106,6 +106,68 @@ def _load_mask_expr( return "*".join(mask_exprs) +def sliced_value_for_store( + state: CodegenState, + tensor: torch.Tensor, + subscript: list[object] | tuple[object, ...], + index_parts: list[str], + value: ast.AST, +) -> ast.AST: + """Slice the store value when the Pallas ref is smaller than the tile. + + The launcher clamps each BlockSpec dimension to + ``min(block_size, tensor.shape[d])``. When ``block_size > dim_size`` + the kernel ref is ``dim_size``-shaped but the computed value is + ``block_size``-shaped, so we must slice the value before storing. + + This only applies to grid-tiled dimensions that produce ``:`` in the + generated Pallas index. Dimensions indexed via ``pl.ds()`` are padded + instead of clamped, so they must keep their full block-size value. + """ + from helion._compiler.compile_environment import CompileEnvironment + from helion._compiler.pallas.plan_tiling import TilePattern + + assert state.fx_node is not None + patterns = state.fx_node.meta.get("indexing_patterns") + if patterns is None: + return value + + env = CompileEnvironment.current() + slices: list[str] = [] + needs_slice = False + tensor_dim = 0 + + index_part_idx = 0 + for idx, pattern in zip(subscript, patterns, strict=True): + if idx is None: + continue + + value_slice = ":" + index_part = index_parts[index_part_idx] + index_part_idx += 1 + if isinstance(pattern, TilePattern) and index_part == ":": + block_size = env.block_sizes[pattern.block_id].from_config(state.config) + dim_size = tensor.shape[tensor_dim] + if ( + isinstance(block_size, int) + and isinstance(dim_size, int) + and dim_size < block_size + ): + value_slice = f":{dim_size}" + needs_slice = True + + slices.append(value_slice) + tensor_dim += 1 + + if not needs_slice: + return value + + return expr_from_string( + f"{{value}}[{', '.join(slices)}]", + value=value, + ) + + def _tile_needs_mask( state: CodegenState, block_id: int, @@ -153,6 +215,15 @@ def index_str( subscript: list[object] | tuple[object, ...], tensor: torch.Tensor, ) -> tuple[str, list[int]]: + parts, none_dims = index_parts(state, subscript, tensor) + return ", ".join(parts), none_dims + + +def index_parts( + state: CodegenState, + subscript: list[object] | tuple[object, ...], + tensor: torch.Tensor, +) -> tuple[list[str], list[int]]: """Build a JAX/Pallas index string from a Helion subscript list. Uses ``pl.ds(offset, block_size)`` only for dimensions inside a looped @@ -171,7 +242,7 @@ def index_str( from helion._compiler.tile_strategy import ForiLoopState if not subscript: - return "...", [] + return ["..."], [] # Check if we're inside an emit_pipeline or fori_loop that pipelines # this specific tensor. Both loop types take a per-tensor decision: @@ -214,7 +285,7 @@ def index_str( out_pos += 1 tensor_dim += 1 - return ", ".join(parts), none_dims + return parts, none_dims def _get_indexing_patterns(state: CodegenState, tensor: torch.Tensor) -> list[object]: diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 790d70161..27b51f0e8 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -274,9 +274,13 @@ def _(state: CodegenState) -> None: device_fn = state.device_function device_fn.device_store_index += 1 device_fn.device_memory_op_index += 1 - index_str, _ = pallas_codegen.index_str(state, subscript, tensor) + parts, _ = pallas_codegen.index_parts(state, subscript, tensor) + value = pallas_codegen.sliced_value_for_store( + state, tensor, subscript, parts, value + ) + idx_str = ", ".join(parts) state.codegen.add_statement( - statement_from_string(f"{name}[{index_str}] = {{value}}", value=value) + statement_from_string(f"{name}[{idx_str}] = {{value}}", value=value) ) diff --git a/test/test_examples.py b/test/test_examples.py index ad0bce85d..48d9e8859 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1582,7 +1582,6 @@ def test_jsd(self): num_stages=3, ) - @xfailIfPallas("operation not supported on TPU") def test_kl_div(self): if _get_backend() == "cute" and "B200" in get_nvidia_gpu_model(): pytest.xfail("CuTe KL-div example still launches out of resources on B200") diff --git a/test/test_pallas.py b/test/test_pallas.py index 52454363f..69647aab1 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -496,6 +496,68 @@ def test_add_large(self) -> None: code, result = code_and_output(add_kernel, args, block_size=512) torch.testing.assert_close(result, args[0] + args[1]) + def test_store_slice_1d(self) -> None: + """Store value sliced when block_size > tensor dim (1D).""" + + @helion.kernel(backend="pallas", static_shapes=True) + def fill_kernel(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size(0)): + out[tile] = hl.full([tile], 1.0, dtype=x.dtype) + return out + + x = torch.randn(1024, device=DEVICE, dtype=torch.float32) + code, result = code_and_output(fill_kernel, (x,), block_size=4096) + self.assertIn("[:1024]", code) + torch.testing.assert_close(result, torch.ones_like(x)) + + def test_store_slice_2d(self) -> None: + """Store value sliced on the dim where block_size > tensor dim (2D).""" + + @helion.kernel(backend="pallas", static_shapes=True) + def fill_2d(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.empty_like(x) + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = hl.full([tile_m, tile_n], 1.0, dtype=x.dtype) + return out + + # 100 < 128, 256 == 256 → only dim 0 needs slicing + x = torch.randn(100, 256, device=DEVICE, dtype=torch.float32) + code, result = code_and_output(fill_2d, (x,), block_size=[128, 256]) + self.assertIn("[:100, :]", code) + torch.testing.assert_close(result, torch.ones_like(x)) + + # 100 < 128, 200 < 256 → both dims need slicing + x2 = torch.randn(100, 200, device=DEVICE, dtype=torch.float32) + code2, result2 = code_and_output(fill_2d, (x2,), block_size=[128, 256]) + self.assertIn("[:100, :200]", code2) + torch.testing.assert_close(result2, torch.ones_like(x2)) + + def test_store_slice_skips_pl_ds_dim(self) -> None: + """Store value is not sliced on dimensions indexed with pl.ds().""" + + @helion.kernel(backend="pallas", static_shapes=True) + def fill_inner_loop(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.empty_like(x) + for tile_m in hl.tile(m): + for tile_n in hl.tile(n): + out[tile_m, tile_n] = hl.full([tile_m, tile_n], 1.0, dtype=x.dtype) + return out + + x = torch.randn(64, 32, device=DEVICE, dtype=torch.float32) + code, result = code_and_output( + fill_inner_loop, + (x,), + block_size=[128, 64], + pallas_loop_type="fori_loop", + ) + self.assertIn("pl.ds(", code) + self.assertIn("[:64, :]", code) + self.assertNotIn("[:64, :32]", code) + torch.testing.assert_close(result, torch.ones_like(x)) + def test_add_does_not_donate_inputs(self) -> None: """Verify that read-only inputs are not donated by the kernel.