[Pallas] Make per-tile element cap a backend hook, disable for Pallas#2282

Merged
norx1991 merged 1 commit into
mainfrom
yifeixu/pallas-no-numel-cap
May 8, 2026

@norx1991
Contributor

@norx1991 norx1991 commented May 5, 2026

Summary

TRITON_MAX_TENSOR_NUMEL = 2**20 is a Triton-specific compile-time ceiling: past that, Triton's codegen pipeline rejects the kernel, and register pressure on a single block becomes prohibitive. Helion applied this cap to all backends in _extract_tensor_numel_constraints and shrink_config, which artificially narrowed the autotuner search space on Pallas, where the real bound is VMEM bytes (already guarded at runtime by #2024).

Concretely, on layer_norm_bwd at (4096, 10240) the inner-tile constraint

block_m * 10240 <= 1_048_576   →   block_m <= 102

clashed with the TPU 128-alignment requirement on 1D fp32 mean/rstd loads (block_m >= 128). Result:

WARNING: tensor numel constraint unsatisfiable at minimum block sizes: 10240*u3 <= 1048576
helion.exc.NoConfigFound: No working config found from autotuning
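The conflict can be checked with plain arithmetic. The sketch below uses only the numbers quoted above; the variable names are illustrative and are not taken from the Helion source.

```python
# Values quoted in the PR description; names are illustrative only.
TRITON_MAX_TENSOR_NUMEL = 2**20  # 1_048_576 elements per tile
INNER_DIM = 10240                # second dimension of the (4096, 10240) input
TPU_MIN_BLOCK_M = 128            # alignment floor for the 1D fp32 mean/rstd loads

# Largest block_m that satisfies block_m * INNER_DIM <= TRITON_MAX_TENSOR_NUMEL.
max_block_m = TRITON_MAX_TENSOR_NUMEL // INNER_DIM

# The feasible range [TPU_MIN_BLOCK_M, max_block_m] is empty when the floor
# exceeds the cap-derived maximum.
feasible = TPU_MIN_BLOCK_M <= max_block_m

print(max_block_m)  # 102
print(feasible)     # False
```

Because the alignment floor (128) is above the cap-derived maximum (102), no value of block_m satisfies both requirements, which matches the "unsatisfiable at minimum block sizes" warning.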

This PR:

  • Adds Backend.max_tensor_numel (default TRITON_MAX_TENSOR_NUMEL) in helion/_compiler/backend.py.
  • Overrides it to None on PallasBackend. Documented rationale: TPU/Mosaic has no analogous compile-time cap; tile size is bounded by VMEM bytes which the runtime already enforces.
  • _extract_tensor_numel_constraints short-circuits when the cap is None (no constraints emitted).
  • shrink_config skips the ceiling clamp when the cap is None. Falls back to TRITON_MAX_TENSOR_NUMEL when no CompileEnvironment is active (unit-test contexts that construct ConfigGeneration directly).
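As a rough illustration of the hook described in the list above, here is a minimal, self-contained sketch. It is not the code from this PR: the classes are simplified stand-ins, exposing max_tensor_numel as a property is an assumption, and extract_numel_constraints is a hypothetical helper that returns constraints as strings purely for readability.

```python
from __future__ import annotations

TRITON_MAX_TENSOR_NUMEL = 2**20


class Backend:
    """Simplified stand-in for the backend base class (assumed shape)."""

    @property
    def max_tensor_numel(self) -> int | None:
        # Default: every backend keeps the Triton ceiling unless it opts out.
        return TRITON_MAX_TENSOR_NUMEL


class PallasBackend(Backend):
    """Simplified stand-in for the Pallas backend."""

    @property
    def max_tensor_numel(self) -> int | None:
        # No compile-time per-tile element cap; the VMEM byte budget is
        # enforced separately at runtime.
        return None


def extract_numel_constraints(backend: Backend, fixed_numel: int) -> list[str]:
    """Hypothetical helper: constraints for a tile of block_m * fixed_numel elements."""
    max_numel = backend.max_tensor_numel
    if max_numel is None:
        # Short-circuit: emit no constraints when the backend has no cap.
        return []
    return [f"{fixed_numel}*block_m <= {max_numel}"]


print(extract_numel_constraints(Backend(), 10240))        # ['10240*block_m <= 1048576']
print(extract_numel_constraints(PallasBackend(), 10240))  # []
```

Putting the default on the base class means a backend only has to override the hook when it wants different behavior, which is why only the Pallas stand-in overrides it here.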

Triton and CuTe paths are unchanged: their max_tensor_numel stays TRITON_MAX_TENSOR_NUMEL, so they keep applying both the constraint extraction and the shrink ceiling.
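The shrink-side behavior can be sketched in the same spirit. Everything below is an assumption made for illustration: the NO_ENV sentinel, the function names, and the halve-the-largest-block strategy are hypothetical and do not reflect how shrink_config is actually implemented. The sketch only shows the three cases named above: a capped backend, an uncapped backend, and the fallback when no compile environment is active.

```python
from __future__ import annotations

import math

TRITON_MAX_TENSOR_NUMEL = 2**20

# Hypothetical sentinel meaning "no compile environment is active".
NO_ENV = object()


def resolve_cap(backend_cap: object) -> object:
    """Fall back to the Triton ceiling when no environment supplied a cap."""
    if backend_cap is NO_ENV:
        return TRITON_MAX_TENSOR_NUMEL
    return backend_cap


def shrink_block_sizes(block_sizes: list[int], backend_cap: object = NO_ENV) -> list[int]:
    """Illustrative clamp: halve the largest block until the tile fits the cap."""
    cap = resolve_cap(backend_cap)
    sizes = list(block_sizes)
    if cap is None:
        # Uncapped backend (the Pallas case): skip the ceiling clamp entirely.
        return sizes
    while math.prod(sizes) > cap and max(sizes) > 1:
        largest = sizes.index(max(sizes))
        sizes[largest] //= 2
    return sizes


print(shrink_block_sizes([128, 10240]))                           # [128, 5120]
print(shrink_block_sizes([128, 10240], TRITON_MAX_TENSOR_NUMEL))  # [128, 5120]
print(shrink_block_sizes([128, 10240], None))                     # [128, 10240]
```

The first two calls give the same result, mirroring the statement that capped backends and the no-environment fallback both keep the Triton ceiling, while the None case leaves the block sizes untouched.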

@meta-cla meta-cla Bot added the CLA Signed label May 5, 2026
@norx1991
Contributor Author

norx1991 commented May 6, 2026

Hi @yarongmu-google, does this PR align with what you mentioned last time? I think I saw this blocker in layernorm_bwd.

@norx1991 norx1991 force-pushed the yifeixu/pallas-keep-concrete-dims branch from 22a3941 to 51c1183 Compare May 6, 2026 20:28
@norx1991 norx1991 force-pushed the yifeixu/pallas-no-numel-cap branch 3 times, most recently from 5d1370e to 6d766ec Compare May 6, 2026 21:22
@norx1991 norx1991 changed the base branch from yifeixu/pallas-keep-concrete-dims to yifeixu/pallas-skip-factory-padding May 6, 2026 21:23
@norx1991 norx1991 force-pushed the yifeixu/pallas-no-numel-cap branch from 6d766ec to 96810bc Compare May 6, 2026 22:07
@norx1991 norx1991 changed the base branch from yifeixu/pallas-skip-factory-padding to main May 6, 2026 22:07
@norx1991 norx1991 force-pushed the yifeixu/pallas-no-numel-cap branch 2 times, most recently from 2ef8eaa to 82f4262 Compare May 7, 2026 23:47
@norx1991 norx1991 marked this pull request as ready for review May 7, 2026 23:50
@norx1991 norx1991 requested review from AmesingFlank, jansel and oulgen May 7, 2026 23:50
Comment thread helion/_compiler/compile_environment.py Outdated
if max_numel is None:
    # Backend (e.g. Pallas) has no compile-time per-tile element cap;
    # VMEM byte budget is enforced separately at runtime.
    return
Contributor

nit, return None

Contributor Author

Thanks! Done.

`TRITON_MAX_TENSOR_NUMEL = 2**20` is a Triton-specific compile-time
ceiling: past that, Triton's codegen pipeline rejects the kernel and
register pressure becomes prohibitive. Helion applied this cap to all
backends in `_extract_tensor_numel_constraints` and `shrink_config`,
artificially narrowing the autotuner search space on Pallas where the
real bound is VMEM bytes (already guarded at runtime).

Concretely, on layer_norm_bwd at (4096, 10240) the inner tile constraint
`block_m * 10240 <= 1_048_576` forced `block_m <= 102`, but TPU
128-alignment for 1D fp32 mean/rstd loads required `block_m >= 128`.
Result: "tensor numel constraint unsatisfiable at minimum block sizes"
and the autotuner returned no working config.

Add `Backend.max_tensor_numel` (default `TRITON_MAX_TENSOR_NUMEL`) and
override to `None` on `PallasBackend`. `_extract_tensor_numel_constraints`
short-circuits when the cap is `None`; `shrink_config` skips the
ceiling clamp the same way. Triton/CuTe behavior is unchanged.
@norx1991 norx1991 force-pushed the yifeixu/pallas-no-numel-cap branch from 82f4262 to 5e2bc0d Compare May 8, 2026 00:01
@norx1991 norx1991 merged commit b88fd31 into main May 8, 2026
21 of 24 checks passed