# Patch summary (explanatory text placed ahead of the first "diff --git" header).
#
# 1. helion/runtime/__init__.py
#    Inside _torch_dtype_to_cutlass (named by the hunk header), two entries are
#    added to the torch-dtype -> cutlass-type table:
#        torch.float8_e4m3fn -> cutlass.Float8E4M3FN
#        torch.float8_e5m2   -> cutlass.Float8E5M2
#    They are inserted after the torch.bfloat16 entry and before the existing
#    comment about passing torch.bool tensors as uint8.
#
# 2. test/test_cute_dtype_mapping.py (new file, 22 lines)
#    - Calls pytest.importorskip("cutlass") at module level and keeps the
#      returned module as `cutlass` for the assertions below.
#    - test_fp16_bf16_fp32_unchanged: float16 / bfloat16 / float32 still map to
#      cutlass.Float16 / cutlass.BFloat16 / cutlass.Float32.
#    - test_fp8_e4m3fn_maps_to_cutlass and test_fp8_e5m2_maps_to_cutlass: each
#      new fp8 dtype maps to the cutlass type listed above.
#    All assertions compare with `is` (object identity), not `==`.
#
# NOTE(review): the patch arrived with its line breaks and leading whitespace
# collapsed. Line breaks below follow the unified-diff structure; the 8-space
# indentation of the dict entries in the first hunk is an assumption -- confirm
# it against helion/runtime/__init__.py before applying, since context lines
# must match the target file exactly.
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 0b4ffe3667..5b11d5a124 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -1371,6 +1371,8 @@ def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
         torch.float32: cutlass.Float32,
         torch.float64: cutlass.Float64,
         torch.bfloat16: cutlass.BFloat16,
+        torch.float8_e4m3fn: cutlass.Float8E4M3FN,
+        torch.float8_e5m2: cutlass.Float8E5M2,
         # CuTe does not support i1 global-memory tensors; torch.bool is stored
         # as one byte, so pass bool tensor pointers as uint8 and let load
         # lowering convert nonzero bytes back to cutlass.Boolean registers.
diff --git a/test/test_cute_dtype_mapping.py b/test/test_cute_dtype_mapping.py
new file mode 100644
index 0000000000..836ec5e14e
--- /dev/null
+++ b/test/test_cute_dtype_mapping.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+import pytest
+import torch
+
+from helion.runtime import _torch_dtype_to_cutlass
+
+cutlass = pytest.importorskip("cutlass")
+
+
+def test_fp16_bf16_fp32_unchanged() -> None:
+    assert _torch_dtype_to_cutlass(torch.float16) is cutlass.Float16
+    assert _torch_dtype_to_cutlass(torch.bfloat16) is cutlass.BFloat16
+    assert _torch_dtype_to_cutlass(torch.float32) is cutlass.Float32
+
+
+def test_fp8_e4m3fn_maps_to_cutlass() -> None:
+    assert _torch_dtype_to_cutlass(torch.float8_e4m3fn) is cutlass.Float8E4M3FN
+
+
+def test_fp8_e5m2_maps_to_cutlass() -> None:
+    assert _torch_dtype_to_cutlass(torch.float8_e5m2) is cutlass.Float8E5M2