
[TPU][Pallas] relax tolerances and fix Pallas autotuning OOM in layer_norm #2272

Draft

yarongmu-google wants to merge 1 commit into pytorch:main from yarongmu-google:fix_layer_norm_tol

[TPU][Pallas] relax tolerances and fix Pallas autotuning OOM in layer_norm#2272
yarongmu-google wants to merge 1 commit into
pytorch:mainfrom
yarongmu-google:fix_layer_norm_tol

Conversation

@yarongmu-google
Collaborator

  1. Fixes bfloat16 accuracy validation: Relaxes rtol and atol from 1e-3 to 1e-2 in examples/layer_norm.py to prevent spurious validation failures caused by bfloat16 precision variations on TPUs.
  2. Fixes VMEM out-of-memory during autotuning: The autotuner's default baseline config generation (block_sizes=[32, 128]) attempted to allocate overly large chunks of the 10,240-element feature dimension into VMEM, crashing the process before tuning could even begin. This is bypassed by providing autotune_baseline_fn for both the forward and backward kernels.
  3. Adds a safe fallback config: Adds config=helion.Config(block_sizes=[32, 1024]) to the backward-pass kernel so that a memory-safe configuration is available even if the autotuning search fails to converge. A rough sketch of these changes is shown after this list.
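For illustration only, here is a minimal sketch of what the three changes could look like. The kernel signature, the eager reference function used as the baseline, and passing autotune_baseline_fn as a decorator argument are assumptions for this sketch, not the actual diff in examples/layer_norm.py.

```python
import torch
import helion

# (1) Accuracy check with the relaxed bfloat16 tolerances (1e-3 -> 1e-2).
def check_accuracy(result: torch.Tensor, expected: torch.Tensor) -> None:
    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

# (2) Hypothetical eager reference used to bypass the autotuner's default
# baseline config (block_sizes=[32, 128]), which ran out of VMEM on the
# 10,240-element feature dimension.
def layer_norm_reference(x, weight, bias, eps=1e-5):
    return torch.nn.functional.layer_norm(x, x.shape[-1:], weight, bias, eps)

# (3) Backward kernel carrying the memory-safe fallback config. Whether
# autotune_baseline_fn is accepted as a decorator argument like this is an
# assumption made for illustration.
@helion.kernel(
    config=helion.Config(block_sizes=[32, 1024]),
    autotune_baseline_fn=layer_norm_reference,
)
def layer_norm_bwd(grad_out, x, weight, bias):
    ...  # the actual kernel body lives in examples/layer_norm.py
```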

meta-cla bot added the CLA Signed label May 5, 2026
yarongmu-google marked this pull request as draft May 5, 2026 01:42