[Pallas] Make per-tile element cap a backend hook, disable for Pallas#2282
Merged
Conversation
Contributor
Author
|
Hi @yarongmu-google , does this PR align with what you mentioned last time? I think I saw this blocker in layernorm_bwd. |
22a3941 to
51c1183
Compare
5d1370e to
6d766ec
Compare
6d766ec to
96810bc
Compare
2ef8eaa to
82f4262
Compare
AmesingFlank
approved these changes
May 7, 2026
| if max_numel is None: | ||
| # Backend (e.g. Pallas) has no compile-time per-tile element cap; | ||
| # VMEM byte budget is enforced separately at runtime. | ||
| return |
`TRITON_MAX_TENSOR_NUMEL = 2**20` is a Triton-specific compile-time ceiling: past that, Triton's codegen pipeline rejects the kernel and register pressure becomes prohibitive. Helion applied this cap to all backends in `_extract_tensor_numel_constraints` and `shrink_config`, artificially narrowing the autotuner search space on Pallas where the real bound is VMEM bytes (already guarded at runtime). Concretely, on layer_norm_bwd at (4096, 10240) the inner tile constraint `block_m * 10240 <= 1_048_576` forced `block_m <= 102`, but TPU 128-alignment for 1D fp32 mean/rstd loads required `block_m >= 128`. Result: "tensor numel constraint unsatisfiable at minimum block sizes" and the autotuner returned no working config. Add `Backend.max_tensor_numel` (default `TRITON_MAX_TENSOR_NUMEL`) and override to `None` on `PallasBackend`. `_extract_tensor_numel_constraints` short-circuits when the cap is `None`; `shrink_config` skips the ceiling clamp the same way. Triton/CuTe behavior is unchanged.
82f4262 to
5e2bc0d
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
TRITON_MAX_TENSOR_NUMEL = 2**20is a Triton-specific compile-time ceiling — past that, Triton's codegen pipeline rejects the kernel and register pressure on a single block becomes prohibitive. Helion applied this cap to all backends in_extract_tensor_numel_constraintsandshrink_config, artificially narrowing the autotuner search space on Pallas where the real bound is VMEM bytes (already guarded at runtime by #2024).Concretely, on
layer_norm_bwdat(4096, 10240)the inner-tile constraintclashed with the TPU 128-alignment requirement on 1D fp32
mean/rstdloads (block_m >= 128). Result:This PR:
Backend.max_tensor_numel(defaultTRITON_MAX_TENSOR_NUMEL) inhelion/_compiler/backend.py.NoneonPallasBackend. Documented rationale: TPU/Mosaic has no analogous compile-time cap; tile size is bounded by VMEM bytes which the runtime already enforces._extract_tensor_numel_constraintsshort-circuits when the cap isNone(no constraints emitted).shrink_configskips the ceiling clamp when the cap isNone. Falls back toTRITON_MAX_TENSOR_NUMELwhen noCompileEnvironmentis active (unit-test contexts that constructConfigGenerationdirectly).Triton and CuTe paths are unchanged: their
max_tensor_numelstaysTRITON_MAX_TENSOR_NUMEL, so they keep applying both the constraint extraction and the shrink ceiling.