Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions helion/_compiler/pallas/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,68 @@ def _load_mask_expr(
return "*".join(mask_exprs)


def sliced_value_for_store(
    state: CodegenState,
    tensor: torch.Tensor,
    subscript: list[object] | tuple[object, ...],
    index_parts: list[str],
    value: ast.AST,
) -> ast.AST:
    """Trim the stored value to the extent of a clamped Pallas ref.

    The launcher clamps every BlockSpec dimension to
    ``min(block_size, tensor.shape[d])``. If ``block_size > dim_size`` the
    ref seen by the kernel has ``dim_size`` elements along that dimension
    while the computed value still has ``block_size``, so the value has to
    be sliced before the store is emitted.

    Only grid-tiled dimensions whose generated Pallas index is ``:`` are
    affected. Dimensions indexed through ``pl.ds()`` are padded rather than
    clamped and therefore keep their full block-size value.

    Args:
        state: Codegen state; ``state.fx_node`` must be set and its
            ``indexing_patterns`` metadata (if any) is consulted.
        tensor: Destination tensor of the store.
        subscript: Helion subscript of the store; ``None`` entries are
            skipped.
        index_parts: Per-dimension Pallas index strings, one entry per
            non-``None`` subscript element.
        value: AST of the value being stored.

    Returns:
        ``value`` unchanged when no dimension needs trimming (or when no
        indexing patterns are recorded), otherwise an AST for
        ``value[<slices>]``.
    """
    from helion._compiler.compile_environment import CompileEnvironment
    from helion._compiler.pallas.plan_tiling import TilePattern

    assert state.fx_node is not None
    patterns = state.fx_node.meta.get("indexing_patterns")
    if patterns is None:
        return value

    env = CompileEnvironment.current()
    value_slices: list[str] = []
    # Counts the non-None subscript entries seen so far. It addresses both
    # ``index_parts`` and ``tensor.shape`` because None entries are skipped
    # before either is consulted.
    dim = 0
    for sub, pattern in zip(subscript, patterns, strict=True):
        if sub is None:
            continue

        part = index_parts[dim]
        clamped_extent = None
        if isinstance(pattern, TilePattern) and part == ":":
            block_size = env.block_sizes[pattern.block_id].from_config(state.config)
            dim_size = tensor.shape[dim]
            # Only statically known sizes can be compared here.
            if (
                isinstance(block_size, int)
                and isinstance(dim_size, int)
                and block_size > dim_size
            ):
                clamped_extent = dim_size

        value_slices.append(":" if clamped_extent is None else f":{clamped_extent}")
        dim += 1

    # Every entry is a plain ":" unless some dimension was clamped above.
    if not any(piece != ":" for piece in value_slices):
        return value

    return expr_from_string(
        f"{{value}}[{', '.join(value_slices)}]",
        value=value,
    )


def _tile_needs_mask(
state: CodegenState,
block_id: int,
Expand Down Expand Up @@ -153,6 +215,15 @@ def index_str(
subscript: list[object] | tuple[object, ...],
tensor: torch.Tensor,
) -> tuple[str, list[int]]:
parts, none_dims = index_parts(state, subscript, tensor)
return ", ".join(parts), none_dims


def index_parts(
state: CodegenState,
subscript: list[object] | tuple[object, ...],
tensor: torch.Tensor,
) -> tuple[list[str], list[int]]:
"""Build a JAX/Pallas index string from a Helion subscript list.

Uses ``pl.ds(offset, block_size)`` only for dimensions inside a looped
Expand All @@ -171,7 +242,7 @@ def index_str(
from helion._compiler.tile_strategy import ForiLoopState

if not subscript:
return "...", []
return ["..."], []

# Check if we're inside an emit_pipeline or fori_loop that pipelines
# this specific tensor. Both loop types take a per-tensor decision:
Expand Down Expand Up @@ -214,7 +285,7 @@ def index_str(
out_pos += 1
tensor_dim += 1

return ", ".join(parts), none_dims
return parts, none_dims


def _get_indexing_patterns(state: CodegenState, tensor: torch.Tensor) -> list[object]:
Expand Down
8 changes: 6 additions & 2 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,13 @@ def _(state: CodegenState) -> None:
device_fn = state.device_function
device_fn.device_store_index += 1
device_fn.device_memory_op_index += 1
index_str, _ = pallas_codegen.index_str(state, subscript, tensor)
parts, _ = pallas_codegen.index_parts(state, subscript, tensor)
value = pallas_codegen.sliced_value_for_store(
state, tensor, subscript, parts, value
)
idx_str = ", ".join(parts)
state.codegen.add_statement(
statement_from_string(f"{name}[{index_str}] = {{value}}", value=value)
statement_from_string(f"{name}[{idx_str}] = {{value}}", value=value)
)


Expand Down
1 change: 0 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,7 +1582,6 @@ def test_jsd(self):
num_stages=3,
)

@xfailIfPallas("operation not supported on TPU")
def test_kl_div(self):
if _get_backend() == "cute" and "B200" in get_nvidia_gpu_model():
pytest.xfail("CuTe KL-div example still launches out of resources on B200")
Expand Down
62 changes: 62 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,68 @@ def test_add_large(self) -> None:
code, result = code_and_output(add_kernel, args, block_size=512)
torch.testing.assert_close(result, args[0] + args[1])

def test_store_slice_1d(self) -> None:
    """A 1D store slices its value when block_size exceeds the tensor dim."""

    @helion.kernel(backend="pallas", static_shapes=True)
    def fill_kernel(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size(0)):
            out[tile] = hl.full([tile], 1.0, dtype=x.dtype)
        return out

    # 1024 elements against a 4096-wide block: the generated store is
    # expected to cut the value down to the tensor's extent.
    inp = torch.randn(1024, device=DEVICE, dtype=torch.float32)
    generated, actual = code_and_output(fill_kernel, (inp,), block_size=4096)
    self.assertIn("[:1024]", generated)
    torch.testing.assert_close(actual, torch.ones_like(inp))

def test_store_slice_2d(self) -> None:
    """A 2D store slices only the dims where block_size exceeds the tensor dim."""

    @helion.kernel(backend="pallas", static_shapes=True)
    def fill_2d(x: torch.Tensor) -> torch.Tensor:
        m, n = x.size()
        out = torch.empty_like(x)
        for tile_m, tile_n in hl.tile([m, n]):
            out[tile_m, tile_n] = hl.full([tile_m, tile_n], 1.0, dtype=x.dtype)
        return out

    # Each case pairs an input shape with the slice expression that must
    # appear in the generated code for block sizes [128, 256].
    cases = (
        # 100 < 128, 256 == 256 → only dim 0 needs slicing
        ((100, 256), "[:100, :]"),
        # 100 < 128, 200 < 256 → both dims need slicing
        ((100, 200), "[:100, :200]"),
    )
    for shape, expected_slice in cases:
        inp = torch.randn(*shape, device=DEVICE, dtype=torch.float32)
        generated, actual = code_and_output(fill_2d, (inp,), block_size=[128, 256])
        self.assertIn(expected_slice, generated)
        torch.testing.assert_close(actual, torch.ones_like(inp))

def test_store_slice_skips_pl_ds_dim(self) -> None:
    """Dimensions indexed with pl.ds() keep their unsliced store value."""

    @helion.kernel(backend="pallas", static_shapes=True)
    def fill_inner_loop(x: torch.Tensor) -> torch.Tensor:
        m, n = x.size()
        out = torch.empty_like(x)
        for tile_m in hl.tile(m):
            for tile_n in hl.tile(n):
                out[tile_m, tile_n] = hl.full([tile_m, tile_n], 1.0, dtype=x.dtype)
        return out

    inp = torch.randn(64, 32, device=DEVICE, dtype=torch.float32)
    generated, actual = code_and_output(
        fill_inner_loop,
        (inp,),
        block_size=[128, 64],
        pallas_loop_type="fori_loop",
    )

    # Both block sizes exceed the tensor extents (128 > 64, 64 > 32), yet
    # only the first dimension may be sliced: the second must stay ":" in
    # the value slice, and the generated code must contain pl.ds().
    self.assertIn("pl.ds(", generated)
    self.assertIn("[:64, :]", generated)
    self.assertNotIn("[:64, :32]", generated)
    torch.testing.assert_close(actual, torch.ones_like(inp))

def test_add_does_not_donate_inputs(self) -> None:
"""Verify that read-only inputs are not donated by the kernel.

Expand Down
Loading