From 10123e9d6e13b4439dd81da7531811537913afe0 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Wed, 6 May 2026 05:08:38 +0000 Subject: [PATCH] [Pallas] Fix failing scratch shapes asserts due to land-time race when #2278 caused scratch shapes to be re-ordered The PRs #2278 and #2223 had a land-time race. #2223 introduced new scratch shape asserts, while #2278 caused scratch shapes to be re-ordered. No real breakage, but the asserted shapes needs to be re-ordered. stack-info: PR: https://github.com/pytorch/helion/pull/2302, branch: AmesingFlank/stack/44 --- test/test_pallas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 9c41875e44..43ff071b26 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -1027,8 +1027,8 @@ def test_attention_emit_pipeline_correctness_head_dim_256(self) -> None: self.assertIn( "_scratch_shapes=[" "((4, 128, 128), 'jnp.float32', 'vmem'), " - "((4, 128, 256), 'jnp.float32', 'vmem'), " - "((4, 128, 128), 'jnp.float32', 'vmem')]", + "((4, 128, 128), 'jnp.float32', 'vmem'), " + "((4, 128, 256), 'jnp.float32', 'vmem')]", code, ) self.assertIn("jnp.tile(", code) @@ -1056,8 +1056,8 @@ def test_attention_fori_loop_correctness_head_dim_256(self) -> None: self.assertIn( "_scratch_shapes=[" "((4, 128, 128), 'jnp.float32', 'vmem'), " - "((4, 128, 256), 'jnp.float32', 'vmem'), " "((4, 128, 128), 'jnp.float32', 'vmem'), " + "((4, 128, 256), 'jnp.float32', 'vmem'), " "((4, 256, 128), 'jnp.float32', 'vmem'), " "((), None, 'dma_semaphore'), " "((4, 128, 256), 'jnp.float32', 'vmem'), "