diff --git a/test/test_pallas.py b/test/test_pallas.py index 9c41875e44..43ff071b26 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -1027,8 +1027,8 @@ def test_attention_emit_pipeline_correctness_head_dim_256(self) -> None: self.assertIn( "_scratch_shapes=[" "((4, 128, 128), 'jnp.float32', 'vmem'), " - "((4, 128, 256), 'jnp.float32', 'vmem'), " - "((4, 128, 128), 'jnp.float32', 'vmem')]", + "((4, 128, 128), 'jnp.float32', 'vmem'), " + "((4, 128, 256), 'jnp.float32', 'vmem')]", code, ) self.assertIn("jnp.tile(", code) @@ -1056,8 +1056,8 @@ def test_attention_fori_loop_correctness_head_dim_256(self) -> None: self.assertIn( "_scratch_shapes=[" "((4, 128, 128), 'jnp.float32', 'vmem'), " - "((4, 128, 256), 'jnp.float32', 'vmem'), " "((4, 128, 128), 'jnp.float32', 'vmem'), " + "((4, 128, 256), 'jnp.float32', 'vmem'), " "((4, 256, 128), 'jnp.float32', 'vmem'), " "((), None, 'dma_semaphore'), " "((4, 128, 256), 'jnp.float32', 'vmem'), "