diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index fed5c90c4..273408ab0 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -1657,11 +1657,15 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: # These must use VMEM BlockSpecs. Output-only tensors (written but # never read) get HBM in_specs to avoid VMEM pressure. inplace_indices: list[int] = [] + # Atomic-write outputs. Launcher uses these for "arbitrary" grid dim + # semantics so split-K cells serialise. + atomic_indices: list[int] = [] if sorted_args is not None: env = CompileEnvironment.current() host_fn = HostFunction.current() read_names, write_names = device_fn.get_tensor_read_write_names() mutated_params = write_names & {a.arg for a in host_fn.args.args} + atomic_target_names = env.atomic_target_host_names input_storages = {id(t.untyped_storage()) for t in env.input_sources} # Only tensors allocated with torch.empty/empty_like/new_empty can be # output-only — their initial values are undefined, so it's safe @@ -1672,16 +1676,20 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: for i, arg in enumerate(sorted_args): if not isinstance(arg, TensorArg): continue + name = arg.host_str() + is_atomic_target = name in atomic_target_names if id(arg.fake_value.untyped_storage()) not in input_storages: # Tensor created inside the function body (output) output_indices.append(i) - if arg.host_str() in read_names or arg.host_str() not in empty_vars: + if name in read_names or name not in empty_vars: # Also read by the kernel (e.g. broadcast result) inplace_indices.append(i) - elif arg.host_str() in mutated_params: - # Input tensor mutated in-place + elif name in mutated_params or is_atomic_target: + # Mutated in-place (subscript or atomic). output_indices.append(i) inplace_indices.append(i) + if is_atomic_target and i in output_indices: + atomic_indices.append(i) # Collect output-only tensor names so codegen can retarget their # allocations to ``device='meta'`` and capture the launcher return. @@ -1697,6 +1705,8 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: launcher_args = [*args, f"_output_indices={output_indices}"] launcher_args.append(f"_inplace_indices={inplace_indices}") + if atomic_indices: + launcher_args.append(f"_atomic_indices={atomic_indices}") if has_rng_ops: launcher_args.insert(-1, "_rng_seed_buffer") diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 322fc77e9..c45bfa5b1 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -200,6 +200,9 @@ def __init__( self.config_spec.max_num_sm_multiplier = newmax self.has_barrier: bool = False + # Host names of ``hl.atomic_*`` first-arg targets; written by atomic + # codegen, read by the Pallas launcher prep. + self.atomic_target_host_names: set[str] = set() def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr: """Substitute any specialized vars with their concrete values.""" diff --git a/helion/language/atomic_ops.py b/helion/language/atomic_ops.py index 99a3dae72..f14ba6fc5 100644 --- a/helion/language/atomic_ops.py +++ b/helion/language/atomic_ops.py @@ -11,6 +11,7 @@ from .. import exc from .._compiler.ast_extension import expr_from_string +from .._compiler.compile_environment import CompileEnvironment from .._compiler.compile_environment import _symint_expr from .._compiler.host_function import HostFunction from .._compiler.indexing_strategy import SubscriptIndexing @@ -716,7 +717,9 @@ def _pallas_atomic_load_prev( if target not in host_function.tensor_to_origin: raise exc.AtomicOnDeviceTensor("pallas atomic") - name = state.device_function.tensor_arg(target).name + tensor_arg = state.device_function.tensor_arg(target) + name = tensor_arg.name + CompileEnvironment.current().atomic_target_host_names.add(tensor_arg.host_str()) index_str, _ = pallas_codegen.index_str(state, index, target) prev_var = state.device_function.new_var("_prev", dce=True) @@ -962,15 +965,16 @@ def _(state: CodegenState) -> ast.AST: name, index_str, prev_var = _pallas_atomic_load_prev(state) value_ast = _to_ast_values([state.ast_args[2]])[0] - target = state.proxy_arg(0) - assert isinstance(target, torch.Tensor) + # Cast value to the ref's dtype, not the target tensor's dtype, so the + # RMW stays in scratch dtype when ``runtime._apply_atomic_accumulator`` + # promotes a bf16/f16 target to an f32 VMEM scratch. backend = CompileEnvironment.current().backend - target_dtype = backend.dtype_str(target.dtype) - # Cast the sum to the target dtype so the store doesn't fail when - # the value dtype differs (e.g. float32 accumulator into bfloat16 ref). - cast = backend.cast_expr(f"{prev_var} + {{value}}", target_dtype) + cast = backend.cast_expr("{value}", f"{prev_var}.dtype") state.codegen.add_statement( - statement_from_string(f"{name}[{index_str}] = {cast}", value=value_ast) + statement_from_string( + f"{name}[{index_str}] = {prev_var} + {cast}", + value=value_ast, + ) ) return expr_from_string(prev_var) diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 85787a759..280cdff55 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -14,6 +14,8 @@ from .. import _compat as _compat # ensure Triton compatibility patches run from .. import exc from .._utils import triton_is_available +from ._pallas_rmw import _grid_rmw_apply +from ._pallas_rmw import _grid_rmw_plan from .config import Config as Config from .kernel import Kernel as Kernel from .kernel import kernel as kernel @@ -880,6 +882,7 @@ def default_pallas_launcher( _block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, + _atomic_indices: list[int] | None = None, **kwargs: object, ) -> object: """Default launcher for Pallas kernels on TPU (or CPU with interpret=True). @@ -892,7 +895,9 @@ def default_pallas_launcher( Output-only tensors (in ``_output_indices`` but not in ``_inplace_indices``) are excluded from pallas_call inputs to save VMEM. Their results are - returned as torch tensors. + returned as torch tensors. ``_atomic_indices`` lists outputs written via + ``hl.atomic_*``; see :func:`_grid_rmw_plan` for the multi-cell VMEM + scratch path. """ if _output_indices is None: _output_indices = [] @@ -922,6 +927,7 @@ def default_pallas_launcher( inplace_positions, out_shapes, ) = _pallas_prepare_args(args, _output_indices, _inplace_indices) + output_only_set: set[int] = set(output_only_indices) in_specs, out_specs = _pallas_build_block_specs( pl, @@ -936,6 +942,10 @@ def default_pallas_launcher( output_only_indices, ) + dim_sem, scratch_tiles = _grid_rmw_plan( + grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos + ) + reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, @@ -946,6 +956,19 @@ def default_pallas_launcher( inplace_positions, arg_to_tensor_pos, _smem_arg_indices=_smem_arg_indices, + skip_inplace_copy=output_only_set | set(scratch_tiles), + ) + + reordered_kernel, scratch_shapes = _grid_rmw_apply( + reordered_kernel, + scratch_tiles, + dim_sem, + args, + arg_to_tensor_pos, + _output_indices, + n_tensor_inputs, + grid, + pltpu, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -961,7 +984,7 @@ def default_pallas_launcher( pltpu, in_specs, out_specs, - None, + scratch_shapes, args, tensor_arg_indices, _output_indices, @@ -974,16 +997,34 @@ def default_pallas_launcher( f"Estimated {estimated_vmem / 1e6:.2f}MB exceeds {vmem_limit_bytes / 1e6:.2f}MB vmem capacity." ) - pallas_call_kwargs: dict[str, object] = { - "out_shape": out_shape_arg, - "input_output_aliases": pallas_aliases, - "grid": grid, - } + if scratch_shapes: + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=list(in_specs) if in_specs is not None else [], + out_specs=out_specs if out_specs is not None else [], + scratch_shapes=scratch_shapes, # pyrefly: ignore[bad-argument-type] + grid=grid, + ) + pallas_call_kwargs: dict[str, object] = { + "out_shape": out_shape_arg, + "input_output_aliases": pallas_aliases, + "grid_spec": grid_spec, + "compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation] + dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type] + ), + } + else: + pallas_call_kwargs = { + "out_shape": out_shape_arg, + "input_output_aliases": pallas_aliases, + "grid": grid, + } + if in_specs is not None: + pallas_call_kwargs["in_specs"] = in_specs + pallas_call_kwargs["out_specs"] = out_specs + if _pallas_interpret_flag(): pallas_call_kwargs["interpret"] = True - if in_specs is not None: - pallas_call_kwargs["in_specs"] = in_specs - pallas_call_kwargs["out_specs"] = out_specs jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] @@ -1022,13 +1063,14 @@ def default_pallas_pipeline_launcher( _pipeline_arg_indices: list[int] | None = None, _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, _smem_arg_indices: list[int] | None = None, + _atomic_indices: list[int] | None = None, **kwargs: object, ) -> object: """Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory. - Used when ``pallas_loop_type='emit_pipeline'``. Pipeline-body tensors - (listed in ``_pipeline_arg_indices``) use HBM refs; all other tensors - get proper BlockSpecs for automatic VMEM prefetch. + Pipeline-body tensors (``_pipeline_arg_indices``) use HBM refs; the + rest get BlockSpecs for automatic VMEM prefetch. ``_atomic_indices``: + see :func:`default_pallas_launcher`. """ if _output_indices is None: _output_indices = [] @@ -1095,7 +1137,10 @@ def default_pallas_pipeline_launcher( smem_arg_indices=_smem_arg_indices, ) - _pipeline_set = set(_pipeline_arg_indices or []) + _pipeline_set: set[int] = set(_pipeline_arg_indices or []) + dim_sem, scratch_tiles = _grid_rmw_plan( + grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos + ) reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, @@ -1106,9 +1151,21 @@ def default_pallas_pipeline_launcher( inplace_positions, arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), - skip_inplace_copy=_pipeline_set, + skip_inplace_copy=_pipeline_set | set(scratch_tiles), _smem_arg_indices=_smem_arg_indices, ) + reordered_kernel, rmw_extras = _grid_rmw_apply( + reordered_kernel, + scratch_tiles, + dim_sem, + args, + arg_to_tensor_pos, + _output_indices, + n_tensor_inputs, + grid, + pltpu, + ) + scratch_shapes.extend(rmw_extras) # pyrefly: ignore[bad-argument-type] out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -1149,7 +1206,7 @@ def default_pallas_pipeline_launcher( "input_output_aliases": pallas_aliases, "grid_spec": grid_spec, "compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation] - dimension_semantics=tuple("parallel" for _ in grid), + dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type] ), } if _pallas_interpret_flag(): @@ -1192,15 +1249,15 @@ def default_pallas_fori_launcher( _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None, _ds_pad_dims: list[tuple[int, int, int, int]] | None = None, _smem_arg_indices: list[int] | None = None, + _atomic_indices: list[int] | None = None, **kwargs: object, ) -> object: """Launcher for Pallas kernels using fori_loop with manual DMA. - Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as - ``memory_space=pl.ANY`` (HBM refs) and adds scratch buffers as - ``pltpu.VMEM`` shapes plus ``pltpu.SemaphoreType.DMA`` for async copies. - The kernel uses ``jax.lax.fori_loop`` with ``pltpu.make_async_copy`` - internally for DMA control. + All tensors are HBM refs; the kernel drives DMA via + ``pltpu.make_async_copy`` inside ``jax.lax.fori_loop``. Scratch + buffers are ``pltpu.VMEM`` plus ``pltpu.SemaphoreType.DMA`` for the + async copies. ``_atomic_indices``: see :func:`default_pallas_launcher`. """ if _output_indices is None: _output_indices = [] @@ -1249,6 +1306,7 @@ def default_pallas_fori_launcher( # Build in_specs/out_specs: proper BlockSpecs for outer grid dims, # HBM refs for tensors used in the fori_loop body (DMA handles tiling). _fori_pipeline_indices = kwargs.get("_pipeline_arg_indices") + _fori_pipeline_set: set[int] = set(_fori_pipeline_indices or []) # type: ignore[arg-type] assert _block_spec_info is not None, ( "fori_loop launcher requires _block_spec_info from codegen" ) @@ -1266,7 +1324,9 @@ def default_pallas_fori_launcher( smem_arg_indices=_smem_arg_indices, ) - _fori_pipeline_set = set(_fori_pipeline_indices or []) # type: ignore[arg-type] + dim_sem, scratch_tiles = _grid_rmw_plan( + grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos + ) reordered_kernel = _pallas_make_reordered_kernel( pallas_kernel, args, @@ -1277,9 +1337,21 @@ def default_pallas_fori_launcher( inplace_positions, arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), - skip_inplace_copy=_fori_pipeline_set, + skip_inplace_copy=_fori_pipeline_set | set(scratch_tiles), _smem_arg_indices=_smem_arg_indices, ) + reordered_kernel, rmw_extras = _grid_rmw_apply( + reordered_kernel, + scratch_tiles, + dim_sem, + args, + arg_to_tensor_pos, + _output_indices, + n_tensor_inputs, + grid, + pltpu, + ) + scratch_shapes.extend(rmw_extras) # pyrefly: ignore[bad-argument-type] out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -1320,7 +1392,7 @@ def default_pallas_fori_launcher( "input_output_aliases": pallas_aliases, "grid_spec": grid_spec, "compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation] - dimension_semantics=tuple("parallel" for _ in grid), + dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type] ), } if _pallas_interpret_flag(): diff --git a/helion/runtime/_pallas_rmw.py b/helion/runtime/_pallas_rmw.py new file mode 100644 index 000000000..7cb7e3e7b --- /dev/null +++ b/helion/runtime/_pallas_rmw.py @@ -0,0 +1,165 @@ +"""Multi-cell read-modify-write accumulator for Pallas atomic targets. + +`hl.atomic_*` on Pallas is not a hardware atomic; the TensorCore has no +concurrent writers to a given HBM cell. The semantics we need are +"accumulate across grid cells of the same output tile". This module +plans and wraps that pattern at the launcher level. + +The flow: + +1. :func:`_grid_rmw_plan` inspects the atomic target list and the + ``_BlockSpecInfo`` from compile time. For each target it finds grid + axes that are unmapped to the target's tile and have ``grid[i] > 1``; + those drive accumulation and become ``"arbitrary"``. Returns + ``(dim_sem, scratch_tiles)``. + +2. The launcher feeds ``set(scratch_tiles)`` into ``skip_inplace_copy`` + so the reordered kernel's HBM->VMEM refresh doesn't clobber the + accumulator. + +3. :func:`_grid_rmw_apply` allocates a persistent VMEM scratch per RMW + target (f32 for bf16/f16 to avoid per-cell rounding) and wraps the + kernel: on the first ``"arbitrary"`` cell, copy HBM into the scratch; + substitute the scratch ref for the out_ref inside the kernel body; + on the last cell, commit the scratch back to HBM. + +4. The launcher passes the wrap to ``pl.pallas_call`` with the returned + scratches in ``scratch_shapes`` and ``dim_sem`` in + ``compiler_params.dimension_semantics``. + +Single TensorCore only. ``pl.core_map`` would need a cross-core barrier. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any +from typing import cast + +import torch + +if TYPE_CHECKING: + from . import _BlockSpecInfo + + +def _grid_rmw_plan( + grid: tuple[int, ...], + args: tuple[object, ...], + rmw_indices: list[int] | None, + block_spec_info: _BlockSpecInfo | None, + arg_to_tensor_pos: dict[int, int], +) -> tuple[tuple[str, ...], dict[int, tuple[int, ...]]]: + """Plan VMEM scratches for RMW outputs hit by multiple grid cells. + + Returns ``(dim_sem, scratch_tiles)``. ``scratch_tiles`` maps each + ``orig_pos`` that needs a scratch to its tile shape; empty when no + output needs accumulation, in which case ``dim_sem`` is all + ``"parallel"``. + """ + n = len(grid) + parallel = tuple("parallel" for _ in grid) + if not rmw_indices or block_spec_info is None: + return parallel, {} + + arb: set[int] = set() + scratch_tiles: dict[int, tuple[int, ...]] = {} + for orig_pos in rmw_indices: + tpos = arg_to_tensor_pos.get(orig_pos) + if tpos is None or tpos >= len(block_spec_info): + continue + info = block_spec_info[tpos] + if info is None: + continue + bshape, grid_mapping = info + mapped = {gd for gd in grid_mapping if isinstance(gd, int)} + unmapped = {i for i in range(n) if i not in mapped and grid[i] > 1} + if not unmapped: + continue + arb |= unmapped + t = args[orig_pos] + assert isinstance(t, torch.Tensor) + scratch_tiles[orig_pos] = tuple( + bs if bs is not None else t.shape[d] for d, bs in enumerate(bshape) + ) + if not scratch_tiles: + return parallel, {} + dim_sem = tuple("arbitrary" if i in arb else "parallel" for i in range(n)) + return dim_sem, scratch_tiles + + +def _grid_rmw_apply( + inner: object, + scratch_tiles: dict[int, tuple[int, ...]], + dim_sem: tuple[str, ...], + args: tuple[object, ...], + arg_to_tensor_pos: dict[int, int], + output_indices: list[int], + n_tensor_inputs: int, + grid: tuple[int, ...], + pltpu: object, +) -> tuple[object, list[object]]: + """Wrap *inner* with cross-cell scratch substitution. + + Returns ``(wrapped_kernel, extras)``. No-op when ``scratch_tiles`` + is empty: returns ``(inner, [])``. + + bf16/f16 targets get an f32 scratch so per-cell sums don't round. + The atomic codegen casts the value to ``prev.dtype`` so the in-kernel + RMW stays in scratch dtype. ``convert_element_type`` is a no-op + when dtypes match. + """ + if not scratch_tiles: + return inner, [] + + from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime + + extras: list[object] = [] + out_to_extra: dict[int, int] = {} + out_to_in: dict[int, int] = {} + for out_idx, orig_pos in enumerate(output_indices): + tile = scratch_tiles.get(orig_pos) + if tile is None: + continue + t = cast("torch.Tensor", args[orig_pos]) + sd = torch.float32 if t.dtype in (torch.bfloat16, torch.float16) else t.dtype + extras.append(pltpu.VMEM(tile, torch_dtype_to_jax_runtime(sd))) # type: ignore[union-attr] + out_to_extra[out_idx] = len(out_to_extra) + out_to_in[out_idx] = arg_to_tensor_pos[orig_pos] + + arb_dims = [i for i, s in enumerate(dim_sem) if s == "arbitrary"] + n_extras = len(extras) + + def wrapped(*refs: object) -> None: + from jax import lax + from jax.experimental import pallas as pl + import jax.numpy as jnp + + inner_end = len(refs) - n_extras + scratches = {oi: refs[inner_end + off] for oi, off in out_to_extra.items()} + patched = list(refs[:inner_end]) + originals: dict[int, object] = {} + for oi, s in scratches.items(): + slot = n_tensor_inputs + oi + originals[oi] = patched[slot] + patched[slot] = s + + is_first: Any = jnp.bool_(True) + is_last: Any = jnp.bool_(True) + for d in arb_dims: + is_first = is_first & (pl.program_id(d) == 0) + is_last = is_last & (pl.program_id(d) == (grid[d] - 1)) + + @pl.when(is_first) # type: ignore[arg-type] + def _init() -> None: + for oi, s in scratches.items(): + in_ref = refs[out_to_in[oi]] + s[...] = lax.convert_element_type(in_ref[...], s.dtype) # type: ignore[index,attr-defined] + + inner(*patched) # type: ignore[operator] + + @pl.when(is_last) # type: ignore[arg-type] + def _commit() -> None: + for oi, s in scratches.items(): + originals[oi][...] = lax.convert_element_type(s[...], originals[oi].dtype) # type: ignore[index,attr-defined] + + return wrapped, extras diff --git a/test/test_atomic_ops.py b/test/test_atomic_ops.py index 8c93c5def..c0de86019 100644 --- a/test/test_atomic_ops.py +++ b/test/test_atomic_ops.py @@ -69,6 +69,33 @@ def split_k_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +@helion.kernel(static_shapes=True) +def split_k_atomic_max_kernel(x: torch.Tensor) -> torch.Tensor: + """Split-K reduction where each K-tile contributes its max via atomic_max.""" + m, k = x.size() + out = torch.full([m], -1e30, dtype=x.dtype, device=x.device) + for tile_m, tile_k in hl.tile([m, k]): + block = x[tile_m, tile_k] + per_row = torch.amax(block, dim=1) + hl.atomic_max(out, [tile_m], per_row) + return out + + +@helion.kernel(static_shapes=True) +def split_k_multi_atomic_kernel( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Two atomic outputs in one kernel, both needing the scratch accumulator.""" + m, k = x.size() + sum_out = torch.zeros([m], dtype=x.dtype, device=x.device) + sumsq_out = torch.zeros([m], dtype=x.dtype, device=x.device) + for tile_m, tile_k in hl.tile([m, k]): + block = x[tile_m, tile_k] + hl.atomic_add(sum_out, [tile_m], torch.sum(block, dim=1)) + hl.atomic_add(sumsq_out, [tile_m], torch.sum(block * block, dim=1)) + return sum_out, sumsq_out + + @helion.kernel() def atomic_add_f32_into_bf16_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Test atomic_add where value dtype (float32) differs from output (bfloat16).""" @@ -454,6 +481,149 @@ def test_split_k_atomic_add_vmem_preload(self): ) torch.testing.assert_close(result, expected, atol=0.1, rtol=0.05) + def _run_split_k_atomic_max(self, loop_type: str | None) -> None: + m, k = 128, 512 + x = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + kwargs = {"pallas_loop_type": loop_type} if loop_type else {} + _, result = code_and_output( + split_k_atomic_max_kernel, + (x,), + block_sizes=[32, 128], + **kwargs, + ) + torch.testing.assert_close(result, torch.amax(x, dim=1)) + + @onlyBackends(["pallas"]) + def test_split_k_atomic_max_fori(self): + self._run_split_k_atomic_max("fori_loop") + + @onlyBackends(["pallas"]) + def test_split_k_atomic_max_pipeline(self): + self._run_split_k_atomic_max("emit_pipeline") + + @onlyBackends(["pallas"]) + def test_split_k_atomic_max_default(self): + self._run_split_k_atomic_max(None) + + def _run_split_k_multi_atomic(self, loop_type: str | None) -> None: + m, k = 128, 512 + x = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + kwargs = {"pallas_loop_type": loop_type} if loop_type else {} + _, (sum_out, sumsq_out) = code_and_output( + split_k_multi_atomic_kernel, + (x,), + block_sizes=[32, 128], + **kwargs, + ) + torch.testing.assert_close(sum_out, torch.sum(x, dim=1)) + torch.testing.assert_close(sumsq_out, torch.sum(x * x, dim=1)) + + @onlyBackends(["pallas"]) + def test_split_k_multi_atomic_outputs_fori(self): + self._run_split_k_multi_atomic("fori_loop") + + @onlyBackends(["pallas"]) + def test_split_k_multi_atomic_outputs_pipeline(self): + self._run_split_k_multi_atomic("emit_pipeline") + + @onlyBackends(["pallas"]) + def test_split_k_multi_atomic_outputs_default(self): + self._run_split_k_multi_atomic(None) + + def _run_split_k_add( + self, + loop_type: str | None, + m: int = 128, + k: int = 512, + n: int = 128, + block_m: int = 32, + block_n: int = 128, + block_k: int = 128, + inner_k: int = 128, + ) -> None: + """Split-K matmul via atomic_add, verify on all launchers.""" + x = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + y = torch.randn(k, n, device=DEVICE, dtype=torch.float32) + kwargs = {"pallas_loop_type": loop_type} if loop_type else {} + _, result = code_and_output( + split_k_atomic_add_kernel, + (x, y), + block_sizes=[block_m, block_n, block_k, inner_k], + **kwargs, + ) + # out starts as ones (see kernel), plus matmul result + expected = torch.ones(m, n, device=DEVICE, dtype=torch.float32) + x @ y + torch.testing.assert_close(result, expected, atol=0.5, rtol=0.05) + + @onlyBackends(["pallas"]) + def test_split_k_add_fori(self): + self._run_split_k_add("fori_loop") + + @onlyBackends(["pallas"]) + def test_split_k_add_pipeline(self): + self._run_split_k_add("emit_pipeline") + + @onlyBackends(["pallas"]) + def test_split_k_add_default(self): + self._run_split_k_add(None) + + def _run_split_k_add_various_splits(self, loop_type: str) -> None: + m, k, n = 128, 1024, 128 + x = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + y = torch.randn(k, n, device=DEVICE, dtype=torch.float32) + expected = torch.ones(m, n, device=DEVICE, dtype=torch.float32) + x @ y + for split_k in (1, 2, 4, 8): + k_block = max(128, k // split_k) + _, result = code_and_output( + split_k_atomic_add_kernel, + (x, y), + block_sizes=[32, 128, k_block, 128], + pallas_loop_type=loop_type, + ) + torch.testing.assert_close( + result, + expected, + atol=0.5, + rtol=0.05, + msg=f"Failed at split_k={split_k} loop_type={loop_type}", + ) + + @onlyBackends(["pallas"]) + def test_split_k_add_various_splits_fori(self): + self._run_split_k_add_various_splits("fori_loop") + + @onlyBackends(["pallas"]) + def test_split_k_add_various_splits_pipeline(self): + self._run_split_k_add_various_splits("emit_pipeline") + + # TODO(thcmbs): default launcher can't handle SymInt loop bounds + # (TracerIntegerConversionError). Add default variant once fixed. + + def _run_split_k_add_bf16(self, loop_type: str) -> None: + m, k, n = 128, 512, 128 + x = torch.randn(m, k, device=DEVICE).to(torch.bfloat16) + y = torch.randn(k, n, device=DEVICE).to(torch.bfloat16) + _, result = code_and_output( + split_k_atomic_add_kernel, + (x, y), + block_sizes=[128, 128, 128, 128], + pallas_loop_type=loop_type, + ) + expected = torch.ones(m, n, device=DEVICE, dtype=torch.bfloat16) + (x @ y).to( + torch.bfloat16 + ) + torch.testing.assert_close(result, expected, atol=2.0, rtol=0.1) + + @onlyBackends(["pallas"]) + def test_split_k_add_bf16_fori(self): + self._run_split_k_add_bf16("fori_loop") + + @onlyBackends(["pallas"]) + def test_split_k_add_bf16_pipeline(self): + self._run_split_k_add_bf16("emit_pipeline") + + # TODO(thcmbs): same default-launcher SymInt limitation as above. + def test_atomic_add_code_generation(self): """Test that the generated code contains atomic_add.""" x = torch.zeros(10, device=DEVICE) diff --git a/test/test_examples.py b/test/test_examples.py index 5305deb1d..44870415c 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -965,7 +965,6 @@ def test_moe_matmul_ogs(self): skip_accuracy=True, # TODO(yf225): fix unstable numerics ) - @xfailIfPallas("InductorLoweringError") @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) def test_matmul_split_k(self): args = (