diff --git a/helion/_compiler/cute/mma_support.py b/helion/_compiler/cute/mma_support.py
index 48289ead7..1b2a47d02 100644
--- a/helion/_compiler/cute/mma_support.py
+++ b/helion/_compiler/cute/mma_support.py
@@ -17,9 +17,11 @@ class CuteMmaSupport:
     warp_f16bf16: bool
     warpgroup_f16bf16: bool
     tcgen05_f16bf16: bool
+    tcgen05_f8f6f4: bool = False
     warp_error: str | None = None
     warpgroup_error: str | None = None
     tcgen05_error: str | None = None
+    tcgen05_f8f6f4_error: str | None = None
 
     @property
     def supported_impls(self) -> tuple[str, ...]:
@@ -105,6 +107,28 @@ def _probe_tcgen05_f16bf16() -> tuple[bool, str | None]:
         return False, f"{type(exc).__name__}: {exc}"
 
 
+def _probe_tcgen05_f8f6f4() -> tuple[bool, str | None]:
+    # MmaF8F6F4Op is the canonical Blackwell fp8/fp6/fp4 atom. The older
+    # MmaFP8Op is deprecated in cutlass-dsl in favor of this one.
+    try:
+        import cutlass
+        from cutlass.cute.nvgpu import tcgen05
+
+        tcgen05.MmaF8F6F4Op(
+            cutlass.Float8E4M3FN,
+            cutlass.Float8E4M3FN,
+            cutlass.Float32,
+            (128, 8, 32),
+            tcgen05.CtaGroup.ONE,
+            tcgen05.OperandSource.SMEM,
+            tcgen05.OperandMajorMode.K,
+            tcgen05.OperandMajorMode.K,
+        )
+        return True, None
+    except Exception as exc:
+        return False, f"{type(exc).__name__}: {exc}"
+
+
 def get_cute_mma_support() -> CuteMmaSupport:
     device = _current_cuda_device()
     if device is None:
@@ -116,9 +140,11 @@ def get_cute_mma_support() -> CuteMmaSupport:
             warp_f16bf16=False,
             warpgroup_f16bf16=False,
             tcgen05_f16bf16=False,
+            tcgen05_f8f6f4=False,
             warp_error="CUDA unavailable",
             warpgroup_error="CUDA unavailable",
             tcgen05_error="CUDA unavailable",
+            tcgen05_f8f6f4_error="CUDA unavailable",
         )
 
     device_name = torch.cuda.get_device_name(device)
@@ -130,6 +156,7 @@ def get_cute_mma_support() -> CuteMmaSupport:
     warp_ok, warp_error = _probe_warp_f16bf16()
     warpgroup_ok, warpgroup_error = _probe_warpgroup_f16bf16()
     tcgen05_ok, tcgen05_error = _probe_tcgen05_f16bf16()
+    tcgen05_f8_ok, tcgen05_f8_error = _probe_tcgen05_f8f6f4()
 
     return CuteMmaSupport(
         device_name=device_name,
@@ -139,9 +166,11 @@ def get_cute_mma_support() -> CuteMmaSupport:
         warp_f16bf16=warp_ok,
         warpgroup_f16bf16=warpgroup_ok,
         tcgen05_f16bf16=tcgen05_ok,
+        tcgen05_f8f6f4=tcgen05_f8_ok,
         warp_error=warp_error,
         warpgroup_error=warpgroup_error,
         tcgen05_error=tcgen05_error,
+        tcgen05_f8f6f4_error=tcgen05_f8_error,
     )
 
 
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 0b4ffe366..5b11d5a12 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -1371,6 +1371,8 @@ def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
         torch.float32: cutlass.Float32,
         torch.float64: cutlass.Float64,
         torch.bfloat16: cutlass.BFloat16,
+        torch.float8_e4m3fn: cutlass.Float8E4M3FN,
+        torch.float8_e5m2: cutlass.Float8E5M2,
         # CuTe does not support i1 global-memory tensors; torch.bool is stored
         # as one byte, so pass bool tensor pointers as uint8 and let load
         # lowering convert nonzero bytes back to cutlass.Boolean registers.
diff --git a/test/test_cute_fp8_support.py b/test/test_cute_fp8_support.py
new file mode 100644
index 000000000..d4cd72569
--- /dev/null
+++ b/test/test_cute_fp8_support.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from helion._compiler.cute.mma_support import CuteMmaSupport
+from helion._compiler.cute.mma_support import _probe_tcgen05_f8f6f4
+from helion._compiler.cute.mma_support import get_cute_mma_support
+from helion.runtime import _torch_dtype_to_cutlass
+
+
+def test_dtype_mapping_fp8() -> None:
+    cutlass = pytest.importorskip("cutlass")
+
+    assert _torch_dtype_to_cutlass(torch.float8_e4m3fn) is cutlass.Float8E4M3FN
+    assert _torch_dtype_to_cutlass(torch.float8_e5m2) is cutlass.Float8E5M2
+
+
+def test_dtype_mapping_existing_unchanged() -> None:
+    cutlass = pytest.importorskip("cutlass")
+
+    assert _torch_dtype_to_cutlass(torch.float16) is cutlass.Float16
+    assert _torch_dtype_to_cutlass(torch.bfloat16) is cutlass.BFloat16
+    assert _torch_dtype_to_cutlass(torch.float32) is cutlass.Float32
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="probe requires CUDA",
+)
+def test_tcgen05_f8f6f4_probe_on_b200() -> None:
+    cap = torch.cuda.get_device_capability(0)
+    ok, err = _probe_tcgen05_f8f6f4()
+    if (10, 0) <= cap < (11, 0):
+        assert ok, f"expected fp8/f6/f4 MMA on capability {cap}, got error: {err}"
+        assert err is None
+    else:
+        # Other arches (pre-Blackwell, or consumer Blackwell SM 12.x, which has
+        # no tcgen05) may or may not support the atom; the probe must not raise.
+        assert isinstance(ok, bool)
+        if not ok:
+            assert isinstance(err, str) and err
+
+
+def test_capability_dataclass_no_cuda_fields() -> None:
+    # When CUDA is unavailable, all probe fields are False and error fields populated.
+    support = CuteMmaSupport(
+        device_name=None,
+        capability=None,
+        cutlass_arch=None,
+        universal=False,
+        warp_f16bf16=False,
+        warpgroup_f16bf16=False,
+        tcgen05_f16bf16=False,
+        tcgen05_f8f6f4=False,
+        warp_error="x",
+        warpgroup_error="x",
+        tcgen05_error="x",
+        tcgen05_f8f6f4_error="x",
+    )
+    assert support.tcgen05_f8f6f4 is False
+    assert support.tcgen05_f8f6f4_error == "x"
+    # supported_impls should not list tcgen05 if nothing under it is supported.
+    assert "tcgen05" not in support.supported_impls
+
+
+def test_supported_impls_does_not_list_tcgen05_when_only_fp8_supported() -> None:
+    # The generic "tcgen05" impl is the current f16/bf16 path. The fp8 probe is
+    # a separate capability bit until fp8 codegen consumes it directly.
+    support = CuteMmaSupport(
+        device_name="X",
+        capability=(10, 0),
+        cutlass_arch="SM_100",
+        universal=True,
+        warp_f16bf16=False,
+        warpgroup_f16bf16=False,
+        tcgen05_f16bf16=False,
+        tcgen05_f8f6f4=True,
+    )
+    assert "tcgen05" not in support.supported_impls
+
+
+def test_supported_impls_lists_tcgen05_when_f16bf16_supported() -> None:
+    support = CuteMmaSupport(
+        device_name="X",
+        capability=(10, 0),
+        cutlass_arch="SM_100",
+        universal=True,
+        warp_f16bf16=False,
+        warpgroup_f16bf16=False,
+        tcgen05_f16bf16=True,
+        tcgen05_f8f6f4=False,
+    )
+    assert "tcgen05" in support.supported_impls
+
+
+def test_probe_does_not_raise_when_cutlass_missing() -> None:
+    # If cutlass.cute.nvgpu.tcgen05 cannot be imported, the probe must return
+    # (False, error_str) rather than propagating.
+    import builtins
+
+    real_import = builtins.__import__
+
+    def fake_import(name, *args, **kwargs):
+        if name.startswith("cutlass.cute.nvgpu") or name == "cutlass":
+            raise ImportError(f"simulated missing module: {name}")
+        return real_import(name, *args, **kwargs)
+
+    with patch("builtins.__import__", side_effect=fake_import):
+        ok, err = _probe_tcgen05_f8f6f4()
+
+    assert ok is False
+    assert isinstance(err, str) and "ImportError" in err
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="aggregate probe requires CUDA",
+)
+def test_get_cute_mma_support_populates_fp8_field() -> None:
+    support = get_cute_mma_support()
+    cap = torch.cuda.get_device_capability(0)
+    # Field is always present (no AttributeError) on supported builds.
+    assert hasattr(support, "tcgen05_f8f6f4")
+    assert hasattr(support, "tcgen05_f8f6f4_error")
+    if (10, 0) <= cap < (11, 0) and support.cutlass_arch is not None:
+        assert support.tcgen05_f8f6f4 is True