diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 1b4d417520..26a7059993 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -96,6 +96,24 @@ def _lerp_scalar_decomp( return start + weight * (end - start) +def _fixup_addmm_result_dtype( + result: torch.Tensor, target_dtype: torch.dtype +) -> torch.Tensor: + # Replace this FX node's meta["val"] with a same-shape tensor of + # target_dtype and register it under the same proxy slot. Helion's addmm + # lowering (aten_lowering.reduce_3d_dot) reads node.meta["val"].dtype to + # choose the tl.dot out_dtype; a trailing .to() cannot recover precision. + fixed = result.new_empty(result.shape, dtype=target_dtype) + mode = proxy_tensor.get_proxy_mode() + if isinstance(mode, proxy_tensor.ProxyTorchDispatchMode): + tracer = mode.tracer + slot = proxy_tensor.get_proxy_slot(result, tracer, None) + if slot is not None: + slot.proxy.node.meta["val"] = fixed + proxy_tensor.set_proxy_slot(fixed, tracer, slot) + return fixed + + def _get_custom_decomp_table() -> dict[torch._ops.OpOverload, Callable[..., object]]: from ..language._gelu_tanh_approx import install_gelu_decomp @@ -1672,7 +1690,17 @@ def visit_Call(self, node: ast.Call) -> object: func = self.visit(node.func) # pyrefly: ignore [bad-argument-type] - return _CheckForIndexCalls.retry_call(func, args, kwargs) + result = _CheckForIndexCalls.retry_call(func, args, kwargs) + # Pin addmm/baddbmm result dtype to `self` (args[0]), as type_propagation does. 
+ if ( + func in (torch.addmm, torch.baddbmm) + and isinstance(result, torch.Tensor) + and args + and isinstance(args[0], torch.Tensor) + and result.dtype != args[0].dtype + ): + result = _fixup_addmm_result_dtype(result, args[0].dtype) + return result def visit_Attribute(self, node: ast.Attribute) -> object: return getattr(self.visit(node.value), node.attr) diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 8d5c03a623..e89b3c78c5 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -881,6 +881,20 @@ def to_proxy(arg: TypeInfo) -> object: ), origin, ) + # Pin addmm/baddbmm output dtype to `self` (args[0]): Helion's lowering + # treats it as authoritative, and PyTorch's FakeTensor meta has not + # been consistent about whether the result follows `self` or `mat1`. + if ( + self.value in (torch.addmm, torch.baddbmm) + and isinstance(output_type, TensorType) + and args + and isinstance(args[0], TensorType) + and output_type.fake_value.dtype != args[0].fake_value.dtype + ): + output_type = TensorType( + output_type.origin, + output_type.fake_value.to(args[0].fake_value.dtype), + ) output_type.tree_map(warn_wrong_device) if ( origin.is_host() diff --git a/test/test_dot.py b/test/test_dot.py index d4077ce5da..dbf30416d4 100644 --- a/test/test_dot.py +++ b/test/test_dot.py @@ -501,6 +501,29 @@ def no_warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here + def test_addmm_mixed_dtype_acc_fp32_mats_fp16(self): + # Regression: torch.addmm(fp32_acc, fp16_x, fp16_y) is the canonical + # Helion matmul pattern (see examples/matmul.py). Some PyTorch + # nightlies regressed the FakeTensor meta to return mat1's dtype; + # Helion's compiler pins the output to `self` instead. 
+ @helion.kernel(autotune_effort="none", static_shapes=True) + def matmul_fp32_acc(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + x = torch.randn(64, 64, device=DEVICE, dtype=torch.float16) + y = torch.randn(64, 64, device=DEVICE, dtype=torch.float16) + _, result = code_and_output(matmul_fp32_acc, (x, y)) + expected = (x.float() @ y.float()).to(torch.float16) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + @skipIfNotTriton("triton-specific codegen assertions") @skipIfRefEager("Debug dtype codegen checks rely on compiled code") @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772")