From 2433845d2869fb2f21335d8969413cdc7234770d Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 13 May 2026 17:18:36 -0700 Subject: [PATCH] [cute] add fp8 dtypes to _torch_dtype_to_cutlass mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow fp8 tensors (torch.float8_e4m3fn, torch.float8_e5m2) to flow through the cute-backend kernel binding layer. Before this change, passing fp8 tensors to a helion.kernel(backend='cute') call would hit BackendUnsupported at type-mapping time, before any MMA/codegen could be attempted. This is pure plumbing — no MMA, lowering, or codegen behavior changes. fp8 kernels will still fail downstream until follow-up PRs wire fp8 through mma_support and cute_mma codegen. The point of this PR is to move the failure boundary forward so the next PR can be reviewed in isolation. Authored with Claude. [ghstack-poisoned] --- helion/runtime/__init__.py | 2 ++ test/test_cute_dtype_mapping.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 test/test_cute_dtype_mapping.py diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 0b4ffe366..5b11d5a12 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -1371,6 +1371,8 @@ def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object: torch.float32: cutlass.Float32, torch.float64: cutlass.Float64, torch.bfloat16: cutlass.BFloat16, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, # CuTe does not support i1 global-memory tensors; torch.bool is stored # as one byte, so pass bool tensor pointers as uint8 and let load # lowering convert nonzero bytes back to cutlass.Boolean registers. diff --git a/test/test_cute_dtype_mapping.py b/test/test_cute_dtype_mapping.py new file mode 100644 index 000000000..836ec5e14 --- /dev/null +++ b/test/test_cute_dtype_mapping.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import pytest +import torch + +from helion.runtime import _torch_dtype_to_cutlass + +cutlass = pytest.importorskip("cutlass") + + +def test_fp16_bf16_fp32_unchanged() -> None: + assert _torch_dtype_to_cutlass(torch.float16) is cutlass.Float16 + assert _torch_dtype_to_cutlass(torch.bfloat16) is cutlass.BFloat16 + assert _torch_dtype_to_cutlass(torch.float32) is cutlass.Float32 + + +def test_fp8_e4m3fn_maps_to_cutlass() -> None: + assert _torch_dtype_to_cutlass(torch.float8_e4m3fn) is cutlass.Float8E4M3FN + + +def test_fp8_e5m2_maps_to_cutlass() -> None: + assert _torch_dtype_to_cutlass(torch.float8_e5m2) is cutlass.Float8E5M2