Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,11 +1657,15 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]:
# These must use VMEM BlockSpecs. Output-only tensors (written but
# never read) get HBM in_specs to avoid VMEM pressure.
inplace_indices: list[int] = []
# Indices of outputs written via atomics. The launcher uses these to select
# "arbitrary" grid-dimension semantics so that split-K cells serialise.
atomic_indices: list[int] = []
if sorted_args is not None:
env = CompileEnvironment.current()
host_fn = HostFunction.current()
read_names, write_names = device_fn.get_tensor_read_write_names()
mutated_params = write_names & {a.arg for a in host_fn.args.args}
atomic_target_names = env.atomic_target_host_names
input_storages = {id(t.untyped_storage()) for t in env.input_sources}
# Only tensors allocated with torch.empty/empty_like/new_empty can be
# output-only — their initial values are undefined, so it's safe
Expand All @@ -1672,16 +1676,20 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]:
for i, arg in enumerate(sorted_args):
if not isinstance(arg, TensorArg):
continue
name = arg.host_str()
is_atomic_target = name in atomic_target_names
if id(arg.fake_value.untyped_storage()) not in input_storages:
# Tensor created inside the function body (output)
output_indices.append(i)
if arg.host_str() in read_names or arg.host_str() not in empty_vars:
if name in read_names or name not in empty_vars:
# Also read by the kernel (e.g. broadcast result)
inplace_indices.append(i)
elif arg.host_str() in mutated_params:
# Input tensor mutated in-place
elif name in mutated_params or is_atomic_target:
# Mutated in-place (subscript or atomic).
output_indices.append(i)
inplace_indices.append(i)
if is_atomic_target and i in output_indices:
atomic_indices.append(i)

# Collect output-only tensor names so codegen can retarget their
# allocations to ``device='meta'`` and capture the launcher return.
Expand All @@ -1697,6 +1705,8 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]:

launcher_args = [*args, f"_output_indices={output_indices}"]
launcher_args.append(f"_inplace_indices={inplace_indices}")
if atomic_indices:
launcher_args.append(f"_atomic_indices={atomic_indices}")

if has_rng_ops:
launcher_args.insert(-1, "_rng_seed_buffer")
Expand Down
3 changes: 3 additions & 0 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def __init__(
self.config_spec.max_num_sm_multiplier = newmax

self.has_barrier: bool = False
# Host names of ``hl.atomic_*`` first-arg targets; written by atomic
# codegen, read by the Pallas launcher prep.
self.atomic_target_host_names: set[str] = set()

def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
"""Substitute any specialized vars with their concrete values."""
Expand Down
20 changes: 12 additions & 8 deletions helion/language/atomic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import _symint_expr
from .._compiler.host_function import HostFunction
from .._compiler.indexing_strategy import SubscriptIndexing
Expand Down Expand Up @@ -716,7 +717,9 @@ def _pallas_atomic_load_prev(
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor("pallas atomic")

name = state.device_function.tensor_arg(target).name
tensor_arg = state.device_function.tensor_arg(target)
name = tensor_arg.name
CompileEnvironment.current().atomic_target_host_names.add(tensor_arg.host_str())
index_str, _ = pallas_codegen.index_str(state, index, target)

prev_var = state.device_function.new_var("_prev", dce=True)
Expand Down Expand Up @@ -962,15 +965,16 @@ def _(state: CodegenState) -> ast.AST:

name, index_str, prev_var = _pallas_atomic_load_prev(state)
value_ast = _to_ast_values([state.ast_args[2]])[0]
target = state.proxy_arg(0)
assert isinstance(target, torch.Tensor)
# Cast value to the ref's dtype, not the target tensor's dtype, so the
# RMW stays in scratch dtype when ``runtime._apply_atomic_accumulator``
# promotes a bf16/f16 target to an f32 VMEM scratch.
backend = CompileEnvironment.current().backend
target_dtype = backend.dtype_str(target.dtype)
# Cast the sum to the target dtype so the store doesn't fail when
# the value dtype differs (e.g. float32 accumulator into bfloat16 ref).
cast = backend.cast_expr(f"{prev_var} + {{value}}", target_dtype)
cast = backend.cast_expr("{value}", f"{prev_var}.dtype")
state.codegen.add_statement(
statement_from_string(f"{name}[{index_str}] = {cast}", value=value_ast)
statement_from_string(
f"{name}[{index_str}] = {prev_var} + {cast}",
value=value_ast,
)
)
return expr_from_string(prev_var)

Expand Down
120 changes: 96 additions & 24 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .. import _compat as _compat # ensure Triton compatibility patches run
from .. import exc
from .._utils import triton_is_available
from ._pallas_rmw import _grid_rmw_apply
from ._pallas_rmw import _grid_rmw_plan
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
Expand Down Expand Up @@ -880,6 +882,7 @@ def default_pallas_launcher(
_block_spec_info: _BlockSpecInfo | None = None,
_smem_arg_indices: list[int] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_atomic_indices: list[int] | None = None,
**kwargs: object,
) -> object:
"""Default launcher for Pallas kernels on TPU (or CPU with interpret=True).
Expand All @@ -892,7 +895,9 @@ def default_pallas_launcher(

Output-only tensors (in ``_output_indices`` but not in ``_inplace_indices``)
are excluded from pallas_call inputs to save VMEM. Their results are
returned as torch tensors.
returned as torch tensors. ``_atomic_indices`` lists outputs written via
``hl.atomic_*``; see :func:`_grid_rmw_plan` for the multi-cell VMEM
scratch path.
"""
if _output_indices is None:
_output_indices = []
Expand Down Expand Up @@ -922,6 +927,7 @@ def default_pallas_launcher(
inplace_positions,
out_shapes,
) = _pallas_prepare_args(args, _output_indices, _inplace_indices)
output_only_set: set[int] = set(output_only_indices)

in_specs, out_specs = _pallas_build_block_specs(
pl,
Expand All @@ -936,6 +942,10 @@ def default_pallas_launcher(
output_only_indices,
)

dim_sem, scratch_tiles = _grid_rmw_plan(
grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos
)

reordered_kernel = _pallas_make_reordered_kernel(
pallas_kernel,
args,
Expand All @@ -946,6 +956,19 @@ def default_pallas_launcher(
inplace_positions,
arg_to_tensor_pos,
_smem_arg_indices=_smem_arg_indices,
skip_inplace_copy=output_only_set | set(scratch_tiles),
)

reordered_kernel, scratch_shapes = _grid_rmw_apply(
reordered_kernel,
scratch_tiles,
dim_sem,
args,
arg_to_tensor_pos,
_output_indices,
n_tensor_inputs,
grid,
pltpu,
)

out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]
Expand All @@ -961,7 +984,7 @@ def default_pallas_launcher(
pltpu,
in_specs,
out_specs,
None,
scratch_shapes,
args,
tensor_arg_indices,
_output_indices,
Expand All @@ -974,16 +997,34 @@ def default_pallas_launcher(
f"Estimated {estimated_vmem / 1e6:.2f}MB exceeds {vmem_limit_bytes / 1e6:.2f}MB vmem capacity."
)

pallas_call_kwargs: dict[str, object] = {
"out_shape": out_shape_arg,
"input_output_aliases": pallas_aliases,
"grid": grid,
}
if scratch_shapes:
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=list(in_specs) if in_specs is not None else [],
out_specs=out_specs if out_specs is not None else [],
scratch_shapes=scratch_shapes, # pyrefly: ignore[bad-argument-type]
grid=grid,
)
pallas_call_kwargs: dict[str, object] = {
"out_shape": out_shape_arg,
"input_output_aliases": pallas_aliases,
"grid_spec": grid_spec,
"compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation]
dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type]
),
}
else:
pallas_call_kwargs = {
"out_shape": out_shape_arg,
"input_output_aliases": pallas_aliases,
"grid": grid,
}
if in_specs is not None:
pallas_call_kwargs["in_specs"] = in_specs
pallas_call_kwargs["out_specs"] = out_specs

if _pallas_interpret_flag():
pallas_call_kwargs["interpret"] = True
if in_specs is not None:
pallas_call_kwargs["in_specs"] = in_specs
pallas_call_kwargs["out_specs"] = out_specs

jit_fn = pl.pallas_call(
reordered_kernel, # pyrefly: ignore[bad-argument-type]
Expand Down Expand Up @@ -1022,13 +1063,14 @@ def default_pallas_pipeline_launcher(
_pipeline_arg_indices: list[int] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_smem_arg_indices: list[int] | None = None,
_atomic_indices: list[int] | None = None,
**kwargs: object,
) -> object:
"""Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory.

Used when ``pallas_loop_type='emit_pipeline'``. Pipeline-body tensors
(listed in ``_pipeline_arg_indices``) use HBM refs; all other tensors
get proper BlockSpecs for automatic VMEM prefetch.
Pipeline-body tensors (``_pipeline_arg_indices``) use HBM refs; the
rest get BlockSpecs for automatic VMEM prefetch. For ``_atomic_indices``,
see :func:`default_pallas_launcher`.
"""
if _output_indices is None:
_output_indices = []
Expand Down Expand Up @@ -1095,7 +1137,10 @@ def default_pallas_pipeline_launcher(
smem_arg_indices=_smem_arg_indices,
)

_pipeline_set = set(_pipeline_arg_indices or [])
_pipeline_set: set[int] = set(_pipeline_arg_indices or [])
dim_sem, scratch_tiles = _grid_rmw_plan(
grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos
)
reordered_kernel = _pallas_make_reordered_kernel(
pallas_kernel,
args,
Expand All @@ -1106,9 +1151,21 @@ def default_pallas_pipeline_launcher(
inplace_positions,
arg_to_tensor_pos,
n_extra_refs=len(scratch_shapes),
skip_inplace_copy=_pipeline_set,
skip_inplace_copy=_pipeline_set | set(scratch_tiles),
_smem_arg_indices=_smem_arg_indices,
)
reordered_kernel, rmw_extras = _grid_rmw_apply(
reordered_kernel,
scratch_tiles,
dim_sem,
args,
arg_to_tensor_pos,
_output_indices,
n_tensor_inputs,
grid,
pltpu,
)
scratch_shapes.extend(rmw_extras) # pyrefly: ignore[bad-argument-type]

out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]

Expand Down Expand Up @@ -1149,7 +1206,7 @@ def default_pallas_pipeline_launcher(
"input_output_aliases": pallas_aliases,
"grid_spec": grid_spec,
"compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation]
dimension_semantics=tuple("parallel" for _ in grid),
dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type]
),
}
if _pallas_interpret_flag():
Expand Down Expand Up @@ -1192,15 +1249,15 @@ def default_pallas_fori_launcher(
_scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None,
_ds_pad_dims: list[tuple[int, int, int, int]] | None = None,
_smem_arg_indices: list[int] | None = None,
_atomic_indices: list[int] | None = None,
**kwargs: object,
) -> object:
"""Launcher for Pallas kernels using fori_loop with manual DMA.

Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as
``memory_space=pl.ANY`` (HBM refs) and adds scratch buffers as
``pltpu.VMEM`` shapes plus ``pltpu.SemaphoreType.DMA`` for async copies.
The kernel uses ``jax.lax.fori_loop`` with ``pltpu.make_async_copy``
internally for DMA control.
All tensors are HBM refs; the kernel drives DMA via
``pltpu.make_async_copy`` inside ``jax.lax.fori_loop``. Scratch
buffers are ``pltpu.VMEM`` plus ``pltpu.SemaphoreType.DMA`` for the
async copies. For ``_atomic_indices``, see :func:`default_pallas_launcher`.
"""
if _output_indices is None:
_output_indices = []
Expand Down Expand Up @@ -1249,6 +1306,7 @@ def default_pallas_fori_launcher(
# Build in_specs/out_specs: proper BlockSpecs for outer grid dims,
# HBM refs for tensors used in the fori_loop body (DMA handles tiling).
_fori_pipeline_indices = kwargs.get("_pipeline_arg_indices")
_fori_pipeline_set: set[int] = set(_fori_pipeline_indices or []) # type: ignore[arg-type]
assert _block_spec_info is not None, (
"fori_loop launcher requires _block_spec_info from codegen"
)
Expand All @@ -1266,7 +1324,9 @@ def default_pallas_fori_launcher(
smem_arg_indices=_smem_arg_indices,
)

_fori_pipeline_set = set(_fori_pipeline_indices or []) # type: ignore[arg-type]
dim_sem, scratch_tiles = _grid_rmw_plan(
grid, args, _atomic_indices, _block_spec_info, arg_to_tensor_pos
)
reordered_kernel = _pallas_make_reordered_kernel(
pallas_kernel,
args,
Expand All @@ -1277,9 +1337,21 @@ def default_pallas_fori_launcher(
inplace_positions,
arg_to_tensor_pos,
n_extra_refs=len(scratch_shapes),
skip_inplace_copy=_fori_pipeline_set,
skip_inplace_copy=_fori_pipeline_set | set(scratch_tiles),
_smem_arg_indices=_smem_arg_indices,
)
reordered_kernel, rmw_extras = _grid_rmw_apply(
reordered_kernel,
scratch_tiles,
dim_sem,
args,
arg_to_tensor_pos,
_output_indices,
n_tensor_inputs,
grid,
pltpu,
)
scratch_shapes.extend(rmw_extras) # pyrefly: ignore[bad-argument-type]

out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0]

Expand Down Expand Up @@ -1320,7 +1392,7 @@ def default_pallas_fori_launcher(
"input_output_aliases": pallas_aliases,
"grid_spec": grid_spec,
"compiler_params": pltpu.CompilerParams( # pyrefly: ignore[bad-instantiation]
dimension_semantics=tuple("parallel" for _ in grid),
dimension_semantics=dim_sem, # pyrefly: ignore[bad-argument-type]
),
}
if _pallas_interpret_flag():
Expand Down
Loading
Loading