diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index 6e26136ec9..d79ebcdc88 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -1567,15 +1567,19 @@ def _compute_pad_info( self, sorted_args: list[Argument] | None, config: Config, - ) -> list[tuple[int, int, int]] | None: + ) -> list[tuple[int, int, int, int]] | None: """Identify pl.ds() dims that may need padding and their block sizes. Uses ``pallas_pad_info`` recorded during codegen to identify which tensor dimensions use ``pl.ds()`` slicing. - Returns ``[(arg_index, tensor_dim, block_size), ...]`` or ``None``. - The launcher computes the actual pad amount at runtime as - ``(-tensor.shape[dim]) % block_size``. + Returns ``[(arg_index, tensor_dim, block_size, extra_pad), ...]`` + or ``None``. The launcher computes the actual pad amount at runtime + as ``(-tensor.shape[dim]) % block_size + extra_pad``. + + ``extra_pad`` is 0 when the tile loop starts at offset 0, + ``begin % block_size`` for a constant begin offset, or + ``block_size - 1`` for a data-dependent begin. """ if sorted_args is None: return None @@ -1589,17 +1593,17 @@ def _compute_pad_info( if not device_fn.pallas_pad_info: return None - result: list[tuple[int, int, int]] = [] + result: list[tuple[int, int, int, int]] = [] for i, arg in enumerate(sorted_args): if not isinstance(arg, TensorArg): continue dims_info = device_fn.pallas_pad_info.get(id(arg.fake_value)) if dims_info is not None: - for dim, block_id in dims_info.items(): + for dim, (block_id, extra_pad) in dims_info.items(): bsi = env.block_sizes[block_id] bs = bsi.from_config(config) if isinstance(bs, int) and bs > 1: - result.append((i, dim, bs)) + result.append((i, dim, bs, extra_pad)) return result or None diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 2f1b8ca103..e62224dc0b 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -446,9 +446,9 @@ def __init__( # dict would then need to support multiple entries per tensor # or the tensor would get distinct arg IDs per memory space. self.pallas_memory_space: dict[int, PallasMemorySpace] = {} - # Pallas: id(fake_tensor) → {dim: block_id} for dims using pl.ds() - # that may need host-side padding when block size doesn't divide dim. - self.pallas_pad_info: dict[int, dict[int, int]] = {} + # Pallas: id(fake_tensor) → {dim: (block_id, extra_pad)} for dims + # using pl.ds() that may need host-side padding. + self.pallas_pad_info: dict[int, dict[int, tuple[int, int]]] = {} def allocate_store_index(self) -> int: """Bump store counters and return the indexing strategy slot.""" diff --git a/helion/_compiler/pallas/codegen.py b/helion/_compiler/pallas/codegen.py index 43c988ad5f..e1439e7bed 100644 --- a/helion/_compiler/pallas/codegen.py +++ b/helion/_compiler/pallas/codegen.py @@ -327,7 +327,8 @@ def _ds_expr( if tensor is not None and tensor_dim is not None: from helion.language.memory_ops import _record_pad_info - _record_pad_info(state, tensor, tensor_dim, block_id) + extra_pad = _loop_begin_extra_pad(block_id, state) + _record_pad_info(state, tensor, tensor_dim, block_id, extra_pad) # Skip when tile_offset is set (e.g. offset + 64) — the shift # means the full expression may not be a multiple of block_size. @@ -356,6 +357,39 @@ def _ds_expr( return f"pl.ds({offset}, {block_size})" +def _loop_begin_extra_pad(block_id: int, state: CodegenState) -> int: + """Return extra padding needed for a non-zero loop begin. + + A ``pl.ds(offset, block_size)`` read starting at a non-zero begin can + overshoot the tensor boundary by up to ``begin % block_size`` elements + beyond what ``(-N) % block_size`` accounts for. Returns 0 when the + loop starts at 0, ``begin % block_size`` for a provably constant begin, + or ``block_size - 1`` for a data-dependent begin. + """ + import sympy + + from helion._compiler.compile_environment import CompileEnvironment + + env = CompileEnvironment.current() + bs_value = env.block_sizes[block_id].from_config(state.device_function.config) + if not isinstance(bs_value, int): + return 0 + + loops = state.codegen.active_device_loops.get(block_id) + if not loops: + return 0 + + info = loops[-1].block_id_to_info.get(block_id) + if info is None or info.begin_expr is None: + return 0 + + begin = info.begin_expr + if isinstance(begin, (int, sympy.Integer)): + return int(begin) % bs_value + + return bs_value - 1 + + def _loop_offset_alignment( block_id: int, state: CodegenState, diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index 9a3315346c..4f52c99308 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -433,6 +433,30 @@ def _pallas_loop_begin_and_step_exprs( return begin_exprs, iter_step_exprs, slice_size_exprs +def _compute_pipeline_or_dma_extra_pad( + begin_expr: str, + bid: int, + env: CompileEnvironment, + state: CodegenState, +) -> int: + """Return extra host-side padding for a pipeline/DMA dim with a non-zero begin. + + When ``pl.ds(offset, block_size)`` reads from a tensor whose loop starts + at a non-zero begin, the last block can overshoot the tensor boundary + beyond what ``(-shape) % block_size`` accounts for. The worst case is + ``block_size - 1`` extra elements when the begin is data-dependent. + + # TODO(dunfanlu): if begin isn't "0" but is another constexpr int, + # we should be able to use a smaller padding than bs-1? + """ + if begin_expr == "0": + return 0 + bs_val = env.block_sizes[bid].from_config(state.config) + if isinstance(bs_val, int): + return bs_val - 1 + return 0 + + def _scratch_read(state: CodegenState, sname: str) -> str: """Read expression for a scratch buffer, slicing if padded for TPU.""" sl = state.device_function.scratch_read_slice(sname) @@ -1378,7 +1402,10 @@ def _make_block_spec(fake: torch.Tensor, subscript_meta: list[object]) -> str: block_shape_parts.append(slice_size_expr) from .memory_ops import _record_pad_info - _record_pad_info(state, fake, dim_idx, bid) + extra_pad = _compute_pipeline_or_dma_extra_pad( + begin_expr, bid, env, state + ) + _record_pad_info(state, fake, dim_idx, bid, extra_pad) if begin_expr == "0" and iter_step_expr == slice_size_expr: lambda_parts.append(lambda_params[bid_idx]) else: @@ -1903,7 +1930,10 @@ def _build_hbm_dma_slice( needs_slice = True from .memory_ops import _record_pad_info - _record_pad_info(state, fake, dim_idx, bid) + extra_pad = _compute_pipeline_or_dma_extra_pad( + begin_expr, bid, env, state + ) + _record_pad_info(state, fake, dim_idx, bid, extra_pad) elif bid is not None and bid not in block_ids: # Outer grid dim: use grid offset grid_loops = state.codegen.active_device_loops.get(bid) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 8a626ca5b5..6e5602df77 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -172,10 +172,15 @@ def _record_pad_info( tensor: torch.Tensor, tensor_dim: int, block_id: int, + extra_pad: int = 0, ) -> None: """Record that a tensor dimension uses pl.ds() and may need host-side padding. - Note: stores one block_id per (tensor, dim). If two inner loops tile the + *extra_pad* accounts for non-zero loop begins: 0 when the loop starts + at offset 0, ``begin % block_size`` for a constant begin, or + ``block_size - 1`` for a data-dependent begin. + + Note: stores one entry per (tensor, dim). If two inner loops tile the same dim with different block_ids, the last one wins. This is fine when both loops use the same block size (the common case). """ @@ -183,7 +188,7 @@ def _record_pad_info( tensor_id = id(tensor) if tensor_id not in pad_info: pad_info[tensor_id] = {} - pad_info[tensor_id][tensor_dim] = block_id + pad_info[tensor_id][tensor_dim] = (block_id, extra_pad) def _maybe_get_symbol_origin(idx: object) -> SymbolOrigin | None: diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 14e8fc0441..ad75bcd84d 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -748,7 +748,7 @@ def _pallas_invoke_and_return( tensor_arg_indices: list[int], arg_to_tensor_pos: dict[int, int], _output_indices: list[int], - _ds_pad_dims: list[tuple[int, int, int]] | None = None, + _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, _orig_output_tensors: dict[int, torch.Tensor] | None = None, ) -> object: """Run the JaxCallable and return output-only results. @@ -792,10 +792,10 @@ def _pallas_invoke_and_return( # Handle padding copy-back and result slicing if _ds_pad_dims and _orig_output_tensors: - # _ds_pad_dims contains (arg_idx, dim, block_size). + # _ds_pad_dims contains (arg_idx, dim, block_size, extra_pad). # Build a map from arg_idx → [(dim, ...)] for padded output args. padded_dims_by_arg: dict[int, list[int]] = {} - for arg_idx, dim, _bs in _ds_pad_dims: + for arg_idx, dim, _bs, _extra in _ds_pad_dims: if arg_idx in _orig_output_tensors: padded_dims_by_arg.setdefault(arg_idx, []).append(dim) @@ -842,12 +842,13 @@ def _pallas_invoke_and_return( def _pallas_apply_ds_padding( args: tuple[object, ...], _output_indices: list[int], - _ds_pad_dims: list[tuple[int, int, int]], + _ds_pad_dims: list[tuple[int, int, int, int]], ) -> tuple[tuple[object, ...], dict[int, torch.Tensor]]: - """Pad tensor args along non-divisible pl.ds() dimensions. + """Pad tensor args so ``pl.ds(offset, block_size)`` never reads OOB. - ``_ds_pad_dims`` contains ``(arg_index, dim, block_size)`` tuples. - The pad amount is computed at runtime as ``(-tensor.shape[dim]) % block_size``. + ``_ds_pad_dims`` contains ``(arg_index, dim, block_size, extra_pad)`` + tuples. The pad amount is ``(-tensor.shape[dim]) % block_size + + extra_pad``, where *extra_pad* accounts for non-zero loop begins. Returns the padded args tuple and a dict mapping output arg indices to their original (unpadded) tensors for post-call copy-back. @@ -855,17 +856,15 @@ def _pallas_apply_ds_padding( args_list = list(args) orig_output_tensors: dict[int, torch.Tensor] = {} output_set = set(_output_indices) - for arg_idx, dim, block_size in _ds_pad_dims: + for arg_idx, dim, block_size, extra_pad in _ds_pad_dims: a = args_list[arg_idx] if not isinstance(a, torch.Tensor): continue - pad_amount = (-a.shape[dim]) % block_size + pad_amount = (-a.shape[dim]) % block_size + extra_pad if pad_amount == 0: continue if arg_idx in output_set and arg_idx not in orig_output_tensors: orig_output_tensors[arg_idx] = a - # F.pad takes (last_dim_left, last_dim_right, ..., first_dim_left, first_dim_right). - # To right-pad dimension `dim`, set index 2*(ndim-1-dim) + 1. pad_widths = [0] * (2 * a.ndim) pad_widths[2 * (a.ndim - 1 - dim) + 1] = pad_amount args_list[arg_idx] = torch.nn.functional.pad(a, pad_widths) @@ -880,7 +879,7 @@ def default_pallas_launcher( _inplace_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, - _ds_pad_dims: list[tuple[int, int, int]] | None = None, + _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, **kwargs: object, ) -> object: """Default launcher for Pallas kernels on TPU (or CPU with interpret=True). @@ -1021,7 +1020,7 @@ def default_pallas_pipeline_launcher( _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None, _pipeline_arg_indices: list[int] | None = None, - _ds_pad_dims: list[tuple[int, int, int]] | None = None, + _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, _smem_arg_indices: list[int] | None = None, **kwargs: object, ) -> object: @@ -1191,7 +1190,7 @@ def default_pallas_fori_launcher( _inplace_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None, - _ds_pad_dims: list[tuple[int, int, int]] | None = None, + _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, _smem_arg_indices: list[int] | None = None, **kwargs: object, ) -> object: diff --git a/test/test_pallas.py b/test/test_pallas.py index dd399e904f..d2b706a2b5 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -2015,6 +2015,52 @@ def data_dependent_sum( ref = torch.stack([data[: lengths[i]].sum() for i in range(B)]) torch.testing.assert_close(result, ref, rtol=1e-4, atol=1e-4) + def test_non_zero_tile_begin(self) -> None: + """pl.ds() reads from a non-zero begin can overshoot the tensor boundary.""" + + @helion.kernel(backend="pallas", static_shapes=True) + def sum_with_constant_offset( + data: torch.Tensor, offsets: torch.Tensor + ) -> torch.Tensor: + B = offsets.size(0) - 1 + out = torch.zeros([B], dtype=data.dtype, device=data.device) + for seg in hl.grid(B): + acc = hl.zeros([1], dtype=data.dtype) + for tile in hl.tile(3, 128, block_size=16): + acc = acc + data[tile, :, :].sum(dim=0).sum(dim=0).sum( + dim=0 + ).unsqueeze(0) + out[seg] = acc.squeeze(0) + return out + + @helion.kernel(backend="pallas", static_shapes=True) + def sum_with_dynamic_offset( + data: torch.Tensor, offsets: torch.Tensor + ) -> torch.Tensor: + B = offsets.size(0) - 1 + out = torch.zeros([B], dtype=data.dtype, device=data.device) + for seg in hl.grid(B): + start = offsets[seg] + end = offsets[seg + 1] + acc = hl.zeros([1], dtype=data.dtype) + for tile in hl.tile(start, end, block_size=16): + acc = acc + data[tile, :, :].sum(dim=0).sum(dim=0).sum( + dim=0 + ).unsqueeze(0) + out[seg] = acc.squeeze(0) + return out + + N, A, B = 128, 8, 256 + data = torch.randn(N, A, B, device=DEVICE, dtype=torch.float32) + offsets = torch.tensor([3, 128], device=DEVICE, dtype=torch.int32) + ref = data[3:128].sum().unsqueeze(0) + + _code1, result1 = code_and_output(sum_with_constant_offset, (data, offsets)) + torch.testing.assert_close(result1, ref, rtol=1e-3, atol=1e-3) + + _code2, result2 = code_and_output(sum_with_dynamic_offset, (data, offsets)) + torch.testing.assert_close(result2, ref, rtol=1e-3, atol=1e-3) + @skipUnlessPallas("JAX/Pallas TPU not available") class TestPallasIndirectGather(TestCase):