diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index 31e6c70b1..363dfeaea 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -608,6 +608,129 @@ def _long_sum_shapes( ] +def _grpo_loss_baseline( + logits: torch.Tensor, + old_logp: torch.Tensor | None, + ref_logp: torch.Tensor | None, + completion_ids: torch.Tensor, + advantages: torch.Tensor, + completion_mask: torch.Tensor | None, + temperature: float, + beta: float, + eps_low: float, + eps_high: float, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + # Upcast logits to fp32 to match the kernel's internal precision — the + # kernel does `logits.to(torch.float32) / temperature` inside the tile + # loop, so the bf16 -> fp32 cast must happen before softmax/lse to make + # outputs comparable. + # Import via examples.grpo_loss (package) so torch.compile can trace this + # baseline without hitting importlib spec_from_file_location. + from examples.grpo_loss import torch_grpo_loss + + return torch_grpo_loss( + logits.float(), + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + +def _fused_linear_jsd_baseline( + beta: float, + ignore_index: int, + temperature: float, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, +) -> torch.Tensor: + # PyTorch reference matching fused_linear_jsd_kernel's JSD-only math + # (the JSD computation on pre-computed logits — no linear projection). + # The example file's `fused_linear_jsd_pytorch` does linear+JSD with a + # different signature, so it doesn't match this kernel's args. + s = student_logits / temperature + t = teacher_logits / temperature + sp = torch.softmax(s.float(), dim=-1) + tp = torch.softmax(t.float(), dim=-1) + slp = torch.log_softmax(s.float(), dim=-1) + tlp = torch.log_softmax(t.float(), dim=-1) + m = (1 - beta) * sp + beta * tp + lm = torch.log(m) + skl = (sp * (slp - lm)).sum(dim=-1) + tkl = (tp * (tlp - lm)).sum(dim=-1) + loss = (1 - beta) * skl + beta * tkl + return (loss / loss.size(0)).sum() + + +def _fused_linear_jsd_shapes( + num_shapes: int | None = None, +) -> list[tuple[str, tuple[Any, ...]]]: + # (m, k) is the per-side logits tensor shape (batch, vocab). + # The kernel materializes ~8 fp32 intermediates of shape [tile_b, vocab]; + # vocab=8192 with default tile=128 overflows VMEM (~32 MB needed), so keep + # vocab <= 4096 and scale batch instead. + configs = [(1024, 4096), (4096, 4096)] + if num_shapes is not None: + configs = configs[:num_shapes] + out: list[tuple[str, tuple[Any, ...]]] = [] + for m, k in configs: + student_logits = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + teacher_logits = torch.randn(m, k, device=DEVICE, dtype=torch.float32) + out.append( + ( + f"[{m},{k}]", + (0.5, -100, 1.0, student_logits, teacher_logits), + ) + ) + return out + + +def _grpo_loss_shapes( + num_shapes: int | None = None, +) -> list[tuple[str, tuple[Any, ...]]]: + # (B, L, V) — kernel takes logits of shape [B, L+1, V]. + configs = [(4, 512, 2048), (4, 1024, 32000)] + if num_shapes is not None: + configs = configs[:num_shapes] + temperature = 0.9 + beta = 0.04 + eps_low = 0.2 + eps_high = 0.4 + out: list[tuple[str, tuple[Any, ...]]] = [] + for b, length, v in configs: + logits = torch.randn(b, length + 1, v, device=DEVICE, dtype=torch.bfloat16) + completion_ids = torch.randint( + 0, v, (b, length), device=DEVICE, dtype=torch.int64 + ) + old_logp = torch.randn(b, length, device=DEVICE, dtype=torch.float32) + ref_logp = torch.randn(b, length, device=DEVICE, dtype=torch.float32) + advantages = torch.randn(b, device=DEVICE, dtype=torch.float32) + completion_mask = torch.ones(b, length, device=DEVICE, dtype=torch.float32) + out.append( + ( + f"[{b},{length},{v}]", + ( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ), + ) + ) + return out + + # Kernel mappings for TPU/Pallas benchmarks. # Format: kernel_name -> (module_file, kernel_fn_name, baseline_fn, shapes_fn, # max_mismatch_pct) @@ -790,6 +913,23 @@ def _long_sum_shapes( _squeeze_and_excitation_net_shapes, None, ), + # Use the JSD-only inner kernel (test-covered on Pallas). The example's + # `fused_linear_jsd_fwd` wrapper that pairs with `fused_linear_jsd_pytorch` + # is not Pallas-tested and currently fails autotune accuracy here. + "fused_linear_jsd": ( + "fused_linear_jsd", + "fused_linear_jsd_kernel", + _fused_linear_jsd_baseline, + _fused_linear_jsd_shapes, + None, + ), + "grpo_loss": ( + "grpo_loss", + "helion_grpo_loss", + _grpo_loss_baseline, + _grpo_loss_shapes, + None, + ), }