Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
torch.float32: cutlass.Float32,
torch.float64: cutlass.Float64,
torch.bfloat16: cutlass.BFloat16,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e5m2: cutlass.Float8E5M2,
# CuTe does not support i1 global-memory tensors; torch.bool is stored
# as one byte, so pass bool tensor pointers as uint8 and let load
# lowering convert nonzero bytes back to cutlass.Boolean registers.
Expand Down
22 changes: 22 additions & 0 deletions test/test_cute_dtype_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import pytest
import torch

from helion.runtime import _torch_dtype_to_cutlass

cutlass = pytest.importorskip("cutlass")


def test_fp16_bf16_fp32_unchanged() -> None:
assert _torch_dtype_to_cutlass(torch.float16) is cutlass.Float16
assert _torch_dtype_to_cutlass(torch.bfloat16) is cutlass.BFloat16
assert _torch_dtype_to_cutlass(torch.float32) is cutlass.Float32


def test_fp8_e4m3fn_maps_to_cutlass() -> None:
assert _torch_dtype_to_cutlass(torch.float8_e4m3fn) is cutlass.Float8E4M3FN


def test_fp8_e5m2_maps_to_cutlass() -> None:
assert _torch_dtype_to_cutlass(torch.float8_e5m2) is cutlass.Float8E5M2
Loading