Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions helion/_compiler/cute/mma_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ class CuteMmaSupport:
warp_f16bf16: bool
warpgroup_f16bf16: bool
tcgen05_f16bf16: bool
tcgen05_f8f6f4: bool = False
warp_error: str | None = None
warpgroup_error: str | None = None
tcgen05_error: str | None = None
tcgen05_f8f6f4_error: str | None = None

@property
def supported_impls(self) -> tuple[str, ...]:
Expand Down Expand Up @@ -105,6 +107,28 @@ def _probe_tcgen05_f16bf16() -> tuple[bool, str | None]:
return False, f"{type(exc).__name__}: {exc}"


def _probe_tcgen05_f8f6f4() -> tuple[bool, str | None]:
# MmaF8F6F4Op is the canonical Blackwell fp8/fp6/fp4 atom. The older
# MmaFP8Op is deprecated in cutlass-dsl in favor of this one.
try:
import cutlass
from cutlass.cute.nvgpu import tcgen05

tcgen05.MmaF8F6F4Op(
cutlass.Float8E4M3FN,
cutlass.Float8E4M3FN,
cutlass.Float32,
(128, 8, 32),
tcgen05.CtaGroup.ONE,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
)
return True, None
except Exception as exc:
return False, f"{type(exc).__name__}: {exc}"


def get_cute_mma_support() -> CuteMmaSupport:
device = _current_cuda_device()
if device is None:
Expand All @@ -116,9 +140,11 @@ def get_cute_mma_support() -> CuteMmaSupport:
warp_f16bf16=False,
warpgroup_f16bf16=False,
tcgen05_f16bf16=False,
tcgen05_f8f6f4=False,
warp_error="CUDA unavailable",
warpgroup_error="CUDA unavailable",
tcgen05_error="CUDA unavailable",
tcgen05_f8f6f4_error="CUDA unavailable",
)

device_name = torch.cuda.get_device_name(device)
Expand All @@ -130,6 +156,7 @@ def get_cute_mma_support() -> CuteMmaSupport:
warp_ok, warp_error = _probe_warp_f16bf16()
warpgroup_ok, warpgroup_error = _probe_warpgroup_f16bf16()
tcgen05_ok, tcgen05_error = _probe_tcgen05_f16bf16()
tcgen05_f8_ok, tcgen05_f8_error = _probe_tcgen05_f8f6f4()

return CuteMmaSupport(
device_name=device_name,
Expand All @@ -139,9 +166,11 @@ def get_cute_mma_support() -> CuteMmaSupport:
warp_f16bf16=warp_ok,
warpgroup_f16bf16=warpgroup_ok,
tcgen05_f16bf16=tcgen05_ok,
tcgen05_f8f6f4=tcgen05_f8_ok,
warp_error=warp_error,
warpgroup_error=warpgroup_error,
tcgen05_error=tcgen05_error,
tcgen05_f8f6f4_error=tcgen05_f8_error,
)


Expand Down
2 changes: 2 additions & 0 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
torch.float32: cutlass.Float32,
torch.float64: cutlass.Float64,
torch.bfloat16: cutlass.BFloat16,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e5m2: cutlass.Float8E5M2,
# CuTe does not support i1 global-memory tensors; torch.bool is stored
# as one byte, so pass bool tensor pointers as uint8 and let load
# lowering convert nonzero bytes back to cutlass.Boolean registers.
Expand Down
129 changes: 129 additions & 0 deletions test/test_cute_fp8_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from __future__ import annotations

from unittest.mock import patch

import pytest
import torch

from helion._compiler.cute.mma_support import CuteMmaSupport
from helion._compiler.cute.mma_support import _probe_tcgen05_f8f6f4
from helion._compiler.cute.mma_support import get_cute_mma_support
from helion.runtime import _torch_dtype_to_cutlass


def test_dtype_mapping_fp8() -> None:
    """Both torch fp8 dtypes must map to the matching cutlass dtype object."""
    cutlass = pytest.importorskip("cutlass")

    expected_pairs = (
        (torch.float8_e4m3fn, cutlass.Float8E4M3FN),
        (torch.float8_e5m2, cutlass.Float8E5M2),
    )
    for torch_dtype, cutlass_dtype in expected_pairs:
        assert _torch_dtype_to_cutlass(torch_dtype) is cutlass_dtype


def test_dtype_mapping_existing_unchanged() -> None:
    """The pre-existing f16/bf16/f32 mappings must still resolve as before."""
    cutlass = pytest.importorskip("cutlass")

    expected_pairs = (
        (torch.float16, cutlass.Float16),
        (torch.bfloat16, cutlass.BFloat16),
        (torch.float32, cutlass.Float32),
    )
    for torch_dtype, cutlass_dtype in expected_pairs:
        assert _torch_dtype_to_cutlass(torch_dtype) is cutlass_dtype


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="probe requires CUDA",
)
def test_tcgen05_f8f6f4_probe_on_b200() -> None:
    """Run the fp8/fp6/fp4 probe on real hardware and check its result shape.

    On compute capability 10.x the probe is expected to succeed; on any other
    device it only has to return a well-formed ``(bool, str | None)`` result
    without raising.
    """
    # Without cutlass-dsl the probe can only report an import failure, so a
    # "must succeed" expectation would fail for an environment reason rather
    # than a code defect. Skip, as the dtype-mapping tests in this file do.
    pytest.importorskip("cutlass")

    cap = torch.cuda.get_device_capability(0)
    ok, err = _probe_tcgen05_f8f6f4()
    # Only major version 10 is held to the strict expectation. A plain
    # ``cap >= (10, 0)`` check would also match capability 12.x devices.
    # NOTE(review): 12.x parts are understood not to expose tcgen05 -- confirm
    # against the PTX ISA target list before widening this condition.
    if cap[0] == 10:
        assert ok, f"expected fp8/f6/f4 MMA on capability {cap}, got error: {err}"
        assert err is None
    else:
        # Other arches may or may not support the atom depending on the
        # cutlass-dsl backend; either way the probe must not raise.
        assert isinstance(ok, bool)
        if not ok:
            assert isinstance(err, str) and err


def test_capability_dataclass_no_cuda_fields() -> None:
    """A fully "unsupported" record keeps the fp8 fields and lists no tcgen05."""
    # Mirrors the CUDA-unavailable shape: every probe flag False, every error
    # field populated.
    no_cuda_fields = {
        "device_name": None,
        "capability": None,
        "cutlass_arch": None,
        "universal": False,
        "warp_f16bf16": False,
        "warpgroup_f16bf16": False,
        "tcgen05_f16bf16": False,
        "tcgen05_f8f6f4": False,
        "warp_error": "x",
        "warpgroup_error": "x",
        "tcgen05_error": "x",
        "tcgen05_f8f6f4_error": "x",
    }
    support = CuteMmaSupport(**no_cuda_fields)

    assert support.tcgen05_f8f6f4 is False
    assert support.tcgen05_f8f6f4_error == "x"
    # supported_impls should not list tcgen05 if nothing under it is supported.
    assert "tcgen05" not in support.supported_impls


def test_supported_impls_does_not_list_tcgen05_when_only_fp8_supported() -> None:
    """The fp8 capability bit alone must not surface the "tcgen05" impl."""
    # The generic "tcgen05" impl is the current f16/bf16 path. The fp8 probe is
    # a separate capability bit until fp8 codegen consumes it directly.
    fp8_only_fields = {
        "device_name": "X",
        "capability": (10, 0),
        "cutlass_arch": "SM_100",
        "universal": True,
        "warp_f16bf16": False,
        "warpgroup_f16bf16": False,
        "tcgen05_f16bf16": False,
        "tcgen05_f8f6f4": True,
    }
    impls = CuteMmaSupport(**fp8_only_fields).supported_impls
    assert "tcgen05" not in impls


def test_supported_impls_lists_tcgen05_when_f16bf16_supported() -> None:
    """With the f16/bf16 tcgen05 bit set, "tcgen05" must appear in the impls."""
    f16bf16_fields = {
        "device_name": "X",
        "capability": (10, 0),
        "cutlass_arch": "SM_100",
        "universal": True,
        "warp_f16bf16": False,
        "warpgroup_f16bf16": False,
        "tcgen05_f16bf16": True,
        "tcgen05_f8f6f4": False,
    }
    impls = CuteMmaSupport(**f16bf16_fields).supported_impls
    assert "tcgen05" in impls


def test_probe_does_not_raise_when_cutlass_missing() -> None:
    """A failing cutlass import must come back as ``(False, error_str)``.

    The import machinery is intercepted so that ``cutlass`` and anything under
    ``cutlass.cute.nvgpu`` raise ``ImportError``; the probe has to convert that
    into its error tuple instead of propagating.
    """
    import builtins

    original_import = builtins.__import__

    def _blocking_import(name, *args, **kwargs):
        is_blocked = name == "cutlass" or name.startswith("cutlass.cute.nvgpu")
        if is_blocked:
            raise ImportError(f"simulated missing module: {name}")
        # Everything else goes through the real import machinery.
        return original_import(name, *args, **kwargs)

    with patch.object(builtins, "__import__", side_effect=_blocking_import):
        result = _probe_tcgen05_f8f6f4()

    ok, err = result
    assert ok is False
    assert isinstance(err, str)
    assert "ImportError" in err


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="aggregate probe requires CUDA",
)
def test_get_cute_mma_support_populates_fp8_field() -> None:
    """The aggregate support record must carry the fp8 probe result.

    The two fp8 attributes must always exist. On a compute capability 10.x
    device with a resolved ``cutlass_arch`` the fp8 flag must also be ``True``.
    """
    support = get_cute_mma_support()
    cap = torch.cuda.get_device_capability(0)
    # Field is always present (no AttributeError) on supported builds.
    assert hasattr(support, "tcgen05_f8f6f4")
    assert hasattr(support, "tcgen05_f8f6f4_error")
    # Only major version 10 is held to the strict expectation. A plain
    # ``cap >= (10, 0)`` check would also match capability 12.x devices.
    # NOTE(review): 12.x parts are understood not to expose tcgen05 -- confirm
    # against the PTX ISA target list before widening this condition.
    if cap[0] == 10 and support.cutlass_arch is not None:
        # Surface the recorded probe error so a failure is diagnosable.
        assert support.tcgen05_f8f6f4 is True, support.tcgen05_f8f6f4_error
Loading