diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index ab8ff658d..800a04545 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -5,7 +5,6 @@ import dataclasses import functools import itertools -import math import operator from typing import TYPE_CHECKING from typing import NamedTuple @@ -297,31 +296,6 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]: range_num_stages = env.config_spec.range_num_stages.config_get( config.range_num_stages, block_idx, 0 ) - num_stages = config.num_stages - - if "tensor_descriptor" in config.indexing: - # Tensor descriptor + multi-stage pipelines in addition to unrolling tend to cause - # CUDA "misaligned address" or "unspecified launch failure" errors. - if range_num_stages > 0: - range_num_stages = 0 - if range_unroll_factor > 0 and num_stages > 1: - range_unroll_factor = 0 - elif ( - range_num_stages > 1 - and range_unroll_factor > 1 - and env.block_sizes[block_idx].size - and env.block_sizes[block_idx].numel.is_number - ): - # Unrolling can cause CUDA IMA with pipelining - # We want to ensure new step size + pipeline is within bounds - loop_numel = int(env.block_sizes[block_idx].numel) - block_size = int(env.block_sizes[block_idx].from_config_assert(config)) - step = range_unroll_factor * block_size - last_offset = ((loop_numel - 1) // block_size) * block_size - remainder = loop_numel - last_offset - range_num_stages = min( - max(1, int(math.ceil(remainder / step))), range_num_stages - ) if range_unroll_factor > 0: kwargs.append(f"loop_unroll_factor={range_unroll_factor}") diff --git a/test/test_loops.py b/test/test_loops.py index 0da72f7db..5d3f49de8 100644 --- a/test/test_loops.py +++ b/test/test_loops.py @@ -1397,50 +1397,6 @@ def three_pass_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) - @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) - @skipIfTileIR("tileir backend will ignore `range_unroll_factors` hint") - @skipIfNotTriton("range loop hints are Triton-specific") - @skipIfXPU("Accuracy issue on XPU backend") - def test_unroll_with_pipelining(self): - @helion.kernel(static_shapes=True) - def matmul( - x: torch.Tensor, - y: torch.Tensor, - ) -> torch.Tensor: - m, k = x.size() - k2, n = y.size() - assert k == k2, f"size mismatch {k} != {k2}" - out = torch.empty( - [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device - ) - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = acc - return out - - a = torch.randn(256, 256, device=DEVICE, dtype=torch.bfloat16) - b = torch.randn(256, 256, device=DEVICE, dtype=torch.bfloat16) - - code, result = code_and_output( - matmul, - (a, b), - block_sizes=[64, 16, 16], - indexing="block_ptr", - loop_orders=[[1, 0]], - pid_type="persistent_blocked", - range_num_stages=[4, 2], - range_unroll_factors=[4, 4], - ) - - expected = torch.matmul(a, b) - torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2) - - # Logic for modifying num_stages and loop unrolling factors should - # change num_stages=1 - self.assertIn("num_stages=1", code) - def test_loop_with_symbolic_bounds(self): @helion.kernel( config=helion.Config( diff --git a/test/test_tensor_descriptor.py b/test/test_tensor_descriptor.py index 303753071..7b95dd065 100644 --- a/test/test_tensor_descriptor.py +++ b/test/test_tensor_descriptor.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import unittest import torch @@ -303,14 +302,6 @@ def jsd_forward_kernel( torch.testing.assert_close(loss, baseline_loss, rtol=5e-2, atol=5e-3) self.assertIn(get_tensor_descriptor_fn_name(), code) - range_stage_values = [ - int(match) - for line in code.splitlines() - if "tl.range" in line - for match in re.findall(r"num_stages=(\d+)", line) - ] - # range_num_stages=4 is clamped to 0, so doesn't show up as num_stages in the tl.range call - self.assertEqual(len(range_stage_values), 0) @skipUnlessTensorDescriptor("Tensor descriptor support is required") def test_tiny_matmul_tile_fallback(self) -> None: