Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1567,15 +1567,19 @@ def _compute_pad_info(
self,
sorted_args: list[Argument] | None,
config: Config,
) -> list[tuple[int, int, int]] | None:
) -> list[tuple[int, int, int, int]] | None:
"""Identify pl.ds() dims that may need padding and their block sizes.

Uses ``pallas_pad_info`` recorded during codegen to identify which
tensor dimensions use ``pl.ds()`` slicing.

Returns ``[(arg_index, tensor_dim, block_size), ...]`` or ``None``.
The launcher computes the actual pad amount at runtime as
``(-tensor.shape[dim]) % block_size``.
Returns ``[(arg_index, tensor_dim, block_size, extra_pad), ...]``
or ``None``. The launcher computes the actual pad amount at runtime
as ``(-tensor.shape[dim]) % block_size + extra_pad``.

``extra_pad`` is 0 when the tile loop starts at offset 0,
``begin % block_size`` for a constant begin offset, or
``block_size - 1`` for a data-dependent begin.
"""
if sorted_args is None:
return None
Expand All @@ -1589,17 +1593,17 @@ def _compute_pad_info(
if not device_fn.pallas_pad_info:
return None

result: list[tuple[int, int, int]] = []
result: list[tuple[int, int, int, int]] = []
for i, arg in enumerate(sorted_args):
if not isinstance(arg, TensorArg):
continue
dims_info = device_fn.pallas_pad_info.get(id(arg.fake_value))
if dims_info is not None:
for dim, block_id in dims_info.items():
for dim, (block_id, extra_pad) in dims_info.items():
bsi = env.block_sizes[block_id]
bs = bsi.from_config(config)
if isinstance(bs, int) and bs > 1:
result.append((i, dim, bs))
result.append((i, dim, bs, extra_pad))

return result or None

Expand Down
6 changes: 3 additions & 3 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,9 @@ def __init__(
# dict would then need to support multiple entries per tensor
# or the tensor would get distinct arg IDs per memory space.
self.pallas_memory_space: dict[int, PallasMemorySpace] = {}
# Pallas: id(fake_tensor) → {dim: block_id} for dims using pl.ds()
# that may need host-side padding when block size doesn't divide dim.
self.pallas_pad_info: dict[int, dict[int, int]] = {}
# Pallas: id(fake_tensor) → {dim: (block_id, extra_pad)} for dims
# using pl.ds() that may need host-side padding.
self.pallas_pad_info: dict[int, dict[int, tuple[int, int]]] = {}

def allocate_store_index(self) -> int:
"""Bump store counters and return the indexing strategy slot."""
Expand Down
36 changes: 35 additions & 1 deletion helion/_compiler/pallas/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ def _ds_expr(
if tensor is not None and tensor_dim is not None:
from helion.language.memory_ops import _record_pad_info

_record_pad_info(state, tensor, tensor_dim, block_id)
extra_pad = _loop_begin_extra_pad(block_id, state)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that we can put the logic of extra_pad into _record_pad_info so we only need a single call here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just gave this a try, but it's not straightforward because _record_pad_info is called from three different contexts with different amounts of loop state available:

  1. _ds_expr in codegen.py — the DeviceLoopState is registered in active_device_loops and LoopDimInfo.begin_expr is set, so
    _loop_begin_extra_pad works correctly here.
  2. _make_block_spec in _codegen_emit_pipeline — called before the EmitPipelineLoopState is added to active_device_loops, and its
    LoopDimInfo doesn't set begin_expr.
  3. _build_hbm_dma_slice in _codegen_fori_loop — same issue as (2).

For (2) and (3), the begin info only exists as codegen-level string expressions (begin_exprs) in the enclosing scope, not in LoopDimInfo. To make _record_pad_info self-contained, we'd need to either propagate begin_expr into the LoopDimInfo for emit_pipeline/fori_loop AND register the loop state earlier, or pass the begin info through a different channel — both add more complexity than the current approach. So I'd prefer to leave this as is.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Thanks for giving it a try!

_record_pad_info(state, tensor, tensor_dim, block_id, extra_pad)

# Skip when tile_offset is set (e.g. offset + 64) — the shift
# means the full expression may not be a multiple of block_size.
Expand Down Expand Up @@ -356,6 +357,39 @@ def _ds_expr(
return f"pl.ds({offset}, {block_size})"


def _loop_begin_extra_pad(block_id: int, state: CodegenState) -> int:
    """Compute the additional host-side padding caused by a non-zero loop begin.

    When a tile loop does not start at 0, a ``pl.ds(offset, block_size)``
    read can run past the end of the tensor by more than the
    ``(-N) % block_size`` padding covers.  The extra amount is:

    * ``0`` if the block size is not a plain ``int``, if no device loop is
      active for ``block_id``, or if the innermost loop records no begin
      expression for it;
    * ``begin % block_size`` if the begin is a constant integer;
    * ``block_size - 1`` for any other (data-dependent) begin.
    """
    import sympy

    from helion._compiler.compile_environment import CompileEnvironment

    block_size = CompileEnvironment.current().block_sizes[block_id].from_config(
        state.device_function.config
    )
    if not isinstance(block_size, int):
        return 0

    # Only the innermost active loop for this block id is consulted.
    active = state.codegen.active_device_loops.get(block_id)
    dim_info = active[-1].block_id_to_info.get(block_id) if active else None
    begin = None if dim_info is None else dim_info.begin_expr
    if begin is None:
        return 0

    if isinstance(begin, (int, sympy.Integer)):
        # Constant begin: only the misalignment relative to the block matters.
        return int(begin) % block_size

    # Begin is not a known constant, so assume the worst-case misalignment.
    return block_size - 1


def _loop_offset_alignment(
block_id: int,
state: CodegenState,
Expand Down
34 changes: 32 additions & 2 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,30 @@ def _pallas_loop_begin_and_step_exprs(
return begin_exprs, iter_step_exprs, slice_size_exprs


def _compute_pipeline_or_dma_extra_pad(
begin_expr: str,
bid: int,
env: CompileEnvironment,
state: CodegenState,
) -> int:
"""Return extra host-side padding for a pipeline/DMA dim with a non-zero begin.

When ``pl.ds(offset, block_size)`` reads from a tensor whose loop starts
at a non-zero begin, the last block can overshoot the tensor boundary
beyond what ``(-shape) % block_size`` accounts for. The worst case is
``block_size - 1`` extra elements when the begin is data-dependent.

# TODO(dunfanlu): if begin isn't "0" but is another constexpr int,
# we should be able to use a smaller padding than bs-1?
"""
if begin_expr == "0":
return 0
bs_val = env.block_sizes[bid].from_config(state.config)
if isinstance(bs_val, int):
return bs_val - 1
return 0


def _scratch_read(state: CodegenState, sname: str) -> str:
"""Read expression for a scratch buffer, slicing if padded for TPU."""
sl = state.device_function.scratch_read_slice(sname)
Expand Down Expand Up @@ -1378,7 +1402,10 @@ def _make_block_spec(fake: torch.Tensor, subscript_meta: list[object]) -> str:
block_shape_parts.append(slice_size_expr)
from .memory_ops import _record_pad_info

_record_pad_info(state, fake, dim_idx, bid)
extra_pad = _compute_pipeline_or_dma_extra_pad(
begin_expr, bid, env, state
)
_record_pad_info(state, fake, dim_idx, bid, extra_pad)
if begin_expr == "0" and iter_step_expr == slice_size_expr:
lambda_parts.append(lambda_params[bid_idx])
else:
Expand Down Expand Up @@ -1903,7 +1930,10 @@ def _build_hbm_dma_slice(
needs_slice = True
from .memory_ops import _record_pad_info

_record_pad_info(state, fake, dim_idx, bid)
extra_pad = _compute_pipeline_or_dma_extra_pad(
begin_expr, bid, env, state
)
_record_pad_info(state, fake, dim_idx, bid, extra_pad)
elif bid is not None and bid not in block_ids:
# Outer grid dim: use grid offset
grid_loops = state.codegen.active_device_loops.get(bid)
Expand Down
9 changes: 7 additions & 2 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,23 @@ def _record_pad_info(
tensor: torch.Tensor,
tensor_dim: int,
block_id: int,
extra_pad: int = 0,
) -> None:
"""Record that a tensor dimension uses pl.ds() and may need host-side padding.

Note: stores one block_id per (tensor, dim). If two inner loops tile the
*extra_pad* accounts for non-zero loop begins: 0 when the loop starts
at offset 0, ``begin % block_size`` for a constant begin, or
``block_size - 1`` for a data-dependent begin.

Note: stores one entry per (tensor, dim). If two inner loops tile the
same dim with different block_ids, the last one wins. This is fine when
both loops use the same block size (the common case).
"""
pad_info = state.device_function.pallas_pad_info
tensor_id = id(tensor)
if tensor_id not in pad_info:
pad_info[tensor_id] = {}
pad_info[tensor_id][tensor_dim] = block_id
pad_info[tensor_id][tensor_dim] = (block_id, extra_pad)


def _maybe_get_symbol_origin(idx: object) -> SymbolOrigin | None:
Expand Down
27 changes: 13 additions & 14 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def _pallas_invoke_and_return(
tensor_arg_indices: list[int],
arg_to_tensor_pos: dict[int, int],
_output_indices: list[int],
_ds_pad_dims: list[tuple[int, int, int]] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_orig_output_tensors: dict[int, torch.Tensor] | None = None,
) -> object:
"""Run the JaxCallable and return output-only results.
Expand Down Expand Up @@ -792,10 +792,10 @@ def _pallas_invoke_and_return(

# Handle padding copy-back and result slicing
if _ds_pad_dims and _orig_output_tensors:
# _ds_pad_dims contains (arg_idx, dim, block_size).
# _ds_pad_dims contains (arg_idx, dim, block_size, extra_pad).
# Build a map from arg_idx → [(dim, ...)] for padded output args.
padded_dims_by_arg: dict[int, list[int]] = {}
for arg_idx, dim, _bs in _ds_pad_dims:
for arg_idx, dim, _bs, _extra in _ds_pad_dims:
if arg_idx in _orig_output_tensors:
padded_dims_by_arg.setdefault(arg_idx, []).append(dim)

Expand Down Expand Up @@ -842,30 +842,29 @@ def _pallas_invoke_and_return(
def _pallas_apply_ds_padding(
args: tuple[object, ...],
_output_indices: list[int],
_ds_pad_dims: list[tuple[int, int, int]],
_ds_pad_dims: list[tuple[int, int, int, int]],
) -> tuple[tuple[object, ...], dict[int, torch.Tensor]]:
"""Pad tensor args along non-divisible pl.ds() dimensions.
"""Pad tensor args so ``pl.ds(offset, block_size)`` never reads OOB.

``_ds_pad_dims`` contains ``(arg_index, dim, block_size)`` tuples.
The pad amount is computed at runtime as ``(-tensor.shape[dim]) % block_size``.
``_ds_pad_dims`` contains ``(arg_index, dim, block_size, extra_pad)``
tuples. The pad amount is ``(-tensor.shape[dim]) % block_size +
extra_pad``, where *extra_pad* accounts for non-zero loop begins.

Returns the padded args tuple and a dict mapping output arg indices
to their original (unpadded) tensors for post-call copy-back.
"""
args_list = list(args)
orig_output_tensors: dict[int, torch.Tensor] = {}
output_set = set(_output_indices)
for arg_idx, dim, block_size in _ds_pad_dims:
for arg_idx, dim, block_size, extra_pad in _ds_pad_dims:
a = args_list[arg_idx]
if not isinstance(a, torch.Tensor):
continue
pad_amount = (-a.shape[dim]) % block_size
pad_amount = (-a.shape[dim]) % block_size + extra_pad
if pad_amount == 0:
continue
if arg_idx in output_set and arg_idx not in orig_output_tensors:
orig_output_tensors[arg_idx] = a
# F.pad takes (last_dim_left, last_dim_right, ..., first_dim_left, first_dim_right).
# To right-pad dimension `dim`, set index 2*(ndim-1-dim) + 1.
pad_widths = [0] * (2 * a.ndim)
pad_widths[2 * (a.ndim - 1 - dim) + 1] = pad_amount
args_list[arg_idx] = torch.nn.functional.pad(a, pad_widths)
Expand All @@ -880,7 +879,7 @@ def default_pallas_launcher(
_inplace_indices: list[int] | None = None,
_block_spec_info: _BlockSpecInfo | None = None,
_smem_arg_indices: list[int] | None = None,
_ds_pad_dims: list[tuple[int, int, int]] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
**kwargs: object,
) -> object:
"""Default launcher for Pallas kernels on TPU (or CPU with interpret=True).
Expand Down Expand Up @@ -1021,7 +1020,7 @@ def default_pallas_pipeline_launcher(
_block_spec_info: _BlockSpecInfo | None = None,
_scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None,
_pipeline_arg_indices: list[int] | None = None,
_ds_pad_dims: list[tuple[int, int, int]] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_smem_arg_indices: list[int] | None = None,
**kwargs: object,
) -> object:
Expand Down Expand Up @@ -1191,7 +1190,7 @@ def default_pallas_fori_launcher(
_inplace_indices: list[int] | None = None,
_block_spec_info: _BlockSpecInfo | None = None,
_scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None,
_ds_pad_dims: list[tuple[int, int, int]] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_smem_arg_indices: list[int] | None = None,
**kwargs: object,
) -> object:
Expand Down
46 changes: 46 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,52 @@ def data_dependent_sum(
ref = torch.stack([data[: lengths[i]].sum() for i in range(B)])
torch.testing.assert_close(result, ref, rtol=1e-4, atol=1e-4)

def test_non_zero_tile_begin(self) -> None:
    """Check kernels whose tile loop begins at a non-zero offset.

    A ``pl.ds()`` read from a non-zero begin can overshoot the tensor
    boundary.  Two kernels sum ``data[3:128]``: one with the begin written
    as a constant, one with the begin loaded from ``offsets``.  Both results
    are compared against the same eager reference.
    """

    @helion.kernel(backend="pallas", static_shapes=True)
    def sum_with_constant_offset(
        data: torch.Tensor, offsets: torch.Tensor
    ) -> torch.Tensor:
        B = offsets.size(0) - 1
        out = torch.zeros([B], dtype=data.dtype, device=data.device)
        for seg in hl.grid(B):
            acc = hl.zeros([1], dtype=data.dtype)
            # Begin (3) and end (128) are compile-time constants here.
            for tile in hl.tile(3, 128, block_size=16):
                acc = acc + data[tile, :, :].sum(dim=0).sum(dim=0).sum(
                    dim=0
                ).unsqueeze(0)
            out[seg] = acc.squeeze(0)
        return out

    @helion.kernel(backend="pallas", static_shapes=True)
    def sum_with_dynamic_offset(
        data: torch.Tensor, offsets: torch.Tensor
    ) -> torch.Tensor:
        B = offsets.size(0) - 1
        out = torch.zeros([B], dtype=data.dtype, device=data.device)
        for seg in hl.grid(B):
            # Begin and end are read from the offsets tensor at run time.
            start = offsets[seg]
            end = offsets[seg + 1]
            acc = hl.zeros([1], dtype=data.dtype)
            for tile in hl.tile(start, end, block_size=16):
                acc = acc + data[tile, :, :].sum(dim=0).sum(dim=0).sum(
                    dim=0
                ).unsqueeze(0)
            out[seg] = acc.squeeze(0)
        return out

    rows, height, width = 128, 8, 256
    data = torch.randn(rows, height, width, device=DEVICE, dtype=torch.float32)
    offsets = torch.tensor([3, 128], device=DEVICE, dtype=torch.int32)
    expected = data[3:128].sum().unsqueeze(0)

    # Constant-begin kernel first, then the data-dependent one.
    for kernel in (sum_with_constant_offset, sum_with_dynamic_offset):
        _code, actual = code_and_output(kernel, (data, offsets))
        torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)


@skipUnlessPallas("JAX/Pallas TPU not available")
class TestPallasIndirectGather(TestCase):
Expand Down
Loading