Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ def _lerp_scalar_decomp(
return start + weight * (end - start)


def _fixup_addmm_result_dtype(
result: torch.Tensor, target_dtype: torch.dtype
) -> torch.Tensor:
# Rewrite this FX node's meta["val"] dtype in place. Helion's addmm
# lowering (aten_lowering.reduce_3d_dot) reads node.meta["val"].dtype
# to choose the tl.dot out_dtype, so a trailing .to() cast would leave
# the dot output truncated and the cast unable to recover precision.
fixed = result.new_empty(result.shape, dtype=target_dtype)
mode = proxy_tensor.get_proxy_mode()
if isinstance(mode, proxy_tensor.ProxyTorchDispatchMode):
tracer = mode.tracer
slot = proxy_tensor.get_proxy_slot(result, tracer, None)
if slot is not None:
slot.proxy.node.meta["val"] = fixed
proxy_tensor.set_proxy_slot(fixed, tracer, slot)
return fixed


def _get_custom_decomp_table() -> dict[torch._ops.OpOverload, Callable[..., object]]:
from ..language._gelu_tanh_approx import install_gelu_decomp

Expand Down Expand Up @@ -1672,7 +1690,17 @@ def visit_Call(self, node: ast.Call) -> object:
func = self.visit(node.func)

# pyrefly: ignore [bad-argument-type]
return _CheckForIndexCalls.retry_call(func, args, kwargs)
result = _CheckForIndexCalls.retry_call(func, args, kwargs)
# Pin addmm/baddbmm output dtype to `self`: see type_propagation.
if (
func in (torch.addmm, torch.baddbmm)
and isinstance(result, torch.Tensor)
and args
and isinstance(args[0], torch.Tensor)
and result.dtype != args[0].dtype
):
result = _fixup_addmm_result_dtype(result, args[0].dtype)
return result

def visit_Attribute(self, node: ast.Attribute) -> object:
return getattr(self.visit(node.value), node.attr)
Expand Down
14 changes: 14 additions & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,20 @@ def to_proxy(arg: TypeInfo) -> object:
),
origin,
)
# Pin addmm/baddbmm output dtype to `self`: Helion's lowering
# treats it as authoritative, and PyTorch's FakeTensor meta has
# flip-flopped on whether to follow `self` or `mat1`.
if (
self.value in (torch.addmm, torch.baddbmm)
and isinstance(output_type, TensorType)
and args
and isinstance(args[0], TensorType)
and output_type.fake_value.dtype != args[0].fake_value.dtype
):
output_type = TensorType(
output_type.origin,
output_type.fake_value.to(args[0].fake_value.dtype),
)
output_type.tree_map(warn_wrong_device)
if (
origin.is_host()
Expand Down
23 changes: 23 additions & 0 deletions test/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,29 @@ def no_warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

# torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here

def test_addmm_mixed_dtype_acc_fp32_mats_fp16(self):
# Regression: torch.addmm(fp32_acc, fp16_x, fp16_y) is the canonical
# Helion matmul pattern (see examples/matmul.py). Some PyTorch
# nightlies regressed the FakeTensor meta to return mat1's dtype;
# Helion's compiler pins the output to `self` instead.
@helion.kernel(autotune_effort="none", static_shapes=True)
def matmul_fp32_acc(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
_, n = y.size()
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc.to(x.dtype)
return out

x = torch.randn(64, 64, device=DEVICE, dtype=torch.float16)
y = torch.randn(64, 64, device=DEVICE, dtype=torch.float16)
_, result = code_and_output(matmul_fp32_acc, (x, y))
expected = (x.float() @ y.float()).to(torch.float16)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

@skipIfNotTriton("triton-specific codegen assertions")
@skipIfRefEager("Debug dtype codegen checks rely on compiled code")
@skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772")
Expand Down
Loading