diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index 943692cdf6..9a17bd7dc9 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -22,7 +22,20 @@
 # %%
-@helion.kernel
+def baseline_ln_fwd(
+    x: torch.Tensor,
+    normalized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    eps: float = 1e-5,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    mean = x.to(torch.float32).mean(dim=-1)
+    var = x.to(torch.float32).var(dim=-1, unbiased=False)
+    rstd = torch.rsqrt(var + eps)
+    out = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
+    return out, mean, rstd
+
+@helion.kernel(autotune_baseline_fn=baseline_ln_fwd)
 def layer_norm_fwd(
     x: torch.Tensor,
     normalized_shape: list[int],
@@ -83,7 +96,30 @@ def layer_norm_fwd(
 # %%
-@helion.kernel
+def baseline_ln_bwd(
+    grad_out: torch.Tensor,
+    x: torch.Tensor,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+    weight: torch.Tensor,
+    compute_bias_grad: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+    x_hat = (x - mean[:, None]) * rstd[:, None]
+    grad_weight = (grad_out * x_hat).sum(dim=0).to(weight.dtype)
+    grad_bias = grad_out.sum(dim=0).to(weight.dtype) if compute_bias_grad else None
+
+    n = x.size(1)
+    wdy = weight * grad_out
+    c1 = (x_hat * wdy).sum(dim=-1, keepdim=True) / n
+    c2 = wdy.sum(dim=-1, keepdim=True) / n
+    grad_x = ((wdy - (x_hat * c1 + c2)) * rstd[:, None]).to(x.dtype)
+
+    return grad_x, grad_weight, grad_bias
+
+@helion.kernel(
+    autotune_baseline_fn=baseline_ln_bwd,
+    config=helion.Config(block_sizes=[32, 1024]),
+)
 def layer_norm_bwd(
     grad_out: torch.Tensor,
     x: torch.Tensor,
@@ -259,8 +295,8 @@ def main() -> None:
         layer_norm,
         torch.nn.functional.layer_norm,
         (x, [dim], weight, b, eps),
-        rtol=1e-3,
-        atol=1e-3,
+        rtol=1e-2,
+        atol=1e-2,
     )
 
     # Test forward + backward pass
@@ -277,8 +313,8 @@ def main() -> None:
         layer_norm,
         torch.nn.functional.layer_norm,
         (x_grad, [dim], weight_grad, b, eps),
-        rtol=1e-3,
-        atol=1e-3,
+        rtol=1e-2,
+        atol=1e-2,
         bwd=True,
     )
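
Not part of the diff above, but a minimal standalone sketch of why the c1/c2 decomposition in baseline_ln_bwd is a valid autotuning baseline: it recomputes the layer norm input gradient the same way and compares it against PyTorch autograd. The shapes, seed, and tolerances are illustrative assumptions, not values from this PR.

import torch

# Hypothetical sizes and seed, for illustration only.
torch.manual_seed(0)
m, n = 64, 256
eps = 1e-5
x = torch.randn(m, n, requires_grad=True)
weight = torch.randn(n, requires_grad=True)
grad_out = torch.randn(m, n)

# Reference gradients from autograd through F.layer_norm.
out = torch.nn.functional.layer_norm(x, [n], weight, eps=eps)
grad_x_ref, grad_w_ref = torch.autograd.grad(out, (x, weight), grad_out)

# Row statistics that baseline_ln_fwd saves for the backward pass.
xd = x.detach()
mean = xd.mean(dim=-1)
rstd = torch.rsqrt(xd.var(dim=-1, unbiased=False) + eps)

# Same c1/c2 decomposition as baseline_ln_bwd in the diff.
x_hat = (xd - mean[:, None]) * rstd[:, None]
wdy = weight.detach() * grad_out
c1 = (x_hat * wdy).sum(dim=-1, keepdim=True) / n
c2 = wdy.sum(dim=-1, keepdim=True) / n
grad_x = (wdy - (x_hat * c1 + c2)) * rstd[:, None]
grad_w = (grad_out * x_hat).sum(dim=0)

torch.testing.assert_close(grad_x, grad_x_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(grad_w, grad_w_ref, rtol=1e-3, atol=1e-3)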