WIP Pallas atomic all launchers v2#2346

Draft
thcmbs wants to merge 2 commits into pytorch:main from thcmbs:pallas-atomic-all-launchers-v2
Conversation


@thcmbs thcmbs commented May 7, 2026

No description provided.

thcmbs added 2 commits May 6, 2026 11:41
Two bugs made hl.atomic_* (and split-K matmul) produce garbage on Pallas:

1. Pallas pipelining does not populate output VMEM from HBM even when
   input_output_aliases is set, so the atomic_add codegen reads garbage
   from the output VMEM ref before adding the new value.
2. The reordered_kernel copies in_ref -> out_ref for inplace outputs on
   every grid cell, overwriting accumulated state on multi-cell writes.
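
The second bug can be illustrated with a minimal NumPy stand-in (hypothetical code, not the actual reordered_kernel): re-copying the input into the in-place output on every grid cell discards whatever earlier cells accumulated, so only the last cell's contribution survives.

```python
import numpy as np

def buggy_inplace_grid(out_init, partials):
    """Stand-in for the broken behavior: the per-cell in_ref -> out_ref
    copy resets the output to its initial value before each cell's add."""
    out = out_init.copy()
    for k in range(len(partials)):
        out[...] = out_init          # per-cell in_ref -> out_ref copy (the bug)
        out += partials[k]           # this cell's atomic add
    return out                       # earlier cells' additions are lost
```

For partials [1, 2, 3] starting from zero, this returns 3 instead of the expected 6, which is the "overwriting accumulated state" failure mode described above.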

Fix: route atomic outputs through a VMEM scratch buffer.

  - Detect atomic targets at the FX-graph level: backend.py walks every
    DeviceIR graph for call_function nodes whose target is in
    helion.language.atomic_ops, and maps the first-arg fake tensor back
    to the launch arg index. Emitted as _atomic_indices=[...] kwarg.
  - For each atomic output, build a pltpu.VMEM scratch sized to the
    BlockSpec tile. bf16/f16 targets get an f32 scratch so per-cell sums
    don't round.
  - On the first cell of any "arbitrary" dim (unmapped grid dims with
    size > 1), preload the scratch from HBM via convert_element_type.
    On the last cell, commit it back. dimension_semantics is set to
    "arbitrary" on those dims so cells serialise.
  - The atomic-add codegen casts the value to the ref's dtype rather
    than the target tensor's dtype, so f32 accumulation into a bf16
    scratch stays in f32 until the final write-back.
  - Wired into all three launchers (default, emit_pipeline, fori_loop)
    via shared _wrap_atomic_accumulator / _apply_atomic_accumulator
    helpers. Atomic outputs are added to skip_inplace_copy so the
    reordered_kernel's HBM copy doesn't clobber the scratch.
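
The preload/accumulate/commit pattern above can be sketched in plain NumPy (hypothetical stand-in, not the actual Pallas launcher code; np.float16 stands in for bf16 since NumPy has no bfloat16): the low-precision output is read once on the first cell, summed in an f32 scratch so per-cell sums don't round, and downcast once on the last cell.

```python
import numpy as np

def split_k_add(out, partials):
    """Accumulate `partials` into `out` via an f32 scratch buffer,
    mirroring the scratch-accumulator fix for an "arbitrary" grid dim."""
    n = len(partials)
    scratch = None
    for k in range(n):                   # cells run serially ("arbitrary" semantics)
        if k == 0:                       # first cell: preload scratch from HBM
            scratch = out.astype(np.float32)
        scratch += partials[k].astype(np.float32)  # accumulate in f32, no rounding
        if k == n - 1:                   # last cell: commit back, single downcast
            out[...] = scratch.astype(out.dtype)
    return out
```

This also shows why the f32 scratch matters for half-precision outputs: at magnitude 1024 the f16 spacing is 1.0, so a naive per-cell read-modify-write of 1024 + 0.25 in f16 rounds back to 1024 every step, while the f32 scratch keeps the eight 0.25 increments and commits 1026.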

Tests:

  - test_atomic_ops.py: split_k_add / split_k_max / multi-output kernels,
    parametrised across all three launchers, plus bf16 and various
    split_k counts. Two SymInt-related TODOs left for the default
    launcher on the bf16 + various-splits variants.
  - test_examples.py: test_matmul_split_k removed from xfail; it
    previously failed with InductorLoweringError and now passes.
The meta-cla bot added the "CLA Signed" label on May 7, 2026.
