Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions benchmarks/run_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,129 @@ def _long_sum_shapes(
]


def _grpo_loss_baseline(
    logits: torch.Tensor,
    old_logp: torch.Tensor | None,
    ref_logp: torch.Tensor | None,
    completion_ids: torch.Tensor,
    advantages: torch.Tensor,
    completion_mask: torch.Tensor | None,
    temperature: float,
    beta: float,
    eps_low: float,
    eps_high: float,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Reference GRPO loss: delegate to ``torch_grpo_loss`` on fp32 logits.

    Every argument except ``logits`` is forwarded positionally and unchanged.

    Args:
        logits: Model logits; converted to float32 before the call.
        old_logp: Forwarded as-is (may be ``None``).
        ref_logp: Forwarded as-is (may be ``None``).
        completion_ids: Forwarded as-is.
        advantages: Forwarded as-is.
        completion_mask: Forwarded as-is (may be ``None``).
        temperature: Forwarded as-is.
        beta: Forwarded as-is.
        eps_low: Forwarded as-is.
        eps_high: Forwarded as-is.

    Returns:
        Whatever ``torch_grpo_loss`` returns for these arguments.
    """
    # The import goes through the `examples.grpo_loss` package path so that
    # torch.compile can trace this baseline without running into
    # importlib's spec_from_file_location.
    from examples.grpo_loss import torch_grpo_loss

    # The kernel computes `logits.to(torch.float32) / temperature` inside its
    # tile loop. Doing the bf16 -> fp32 upcast here, ahead of softmax/lse,
    # keeps the baseline's precision in line with the kernel so the two
    # outputs are comparable.
    logits_fp32 = logits.float()

    return torch_grpo_loss(
        logits_fp32,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
    )


def _fused_linear_jsd_baseline(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
) -> torch.Tensor:
    """Eager PyTorch reference for the JSD-only math of ``fused_linear_jsd_kernel``.

    Works on pre-computed logits; no linear projection is applied here. The
    example file's ``fused_linear_jsd_pytorch`` does linear+JSD with a
    different signature, so it cannot stand in for this kernel's arguments.

    Args:
        beta: Weight of the teacher distribution in the mixture; the student
            gets ``1 - beta``.
        ignore_index: Accepted for positional compatibility with the kernel's
            argument list; this reference does not read it.
        temperature: Divisor applied to both logit tensors before softmax.
        student_logits: Student logits; the last dimension is normalized.
        teacher_logits: Teacher logits; same layout as ``student_logits``.

    Returns:
        Scalar tensor: the per-row losses, each divided by the size of the
        first dimension, summed.
    """
    # Temperature scaling is done in the input dtype; everything after the
    # cast runs in fp32.
    student_scaled = (student_logits / temperature).float()
    teacher_scaled = (teacher_logits / temperature).float()

    student_prob = torch.softmax(student_scaled, dim=-1)
    teacher_prob = torch.softmax(teacher_scaled, dim=-1)
    student_logprob = torch.log_softmax(student_scaled, dim=-1)
    teacher_logprob = torch.log_softmax(teacher_scaled, dim=-1)

    # Log of the beta-weighted mixture of the two distributions.
    student_weight = 1 - beta
    log_mixture = torch.log(student_weight * student_prob + beta * teacher_prob)

    # KL of each side against the mixture, reduced over the last dimension.
    student_kl = (student_prob * (student_logprob - log_mixture)).sum(dim=-1)
    teacher_kl = (teacher_prob * (teacher_logprob - log_mixture)).sum(dim=-1)

    per_row = student_weight * student_kl + beta * teacher_kl
    return (per_row / per_row.size(0)).sum()


def _fused_linear_jsd_shapes(
    num_shapes: int | None = None,
) -> list[tuple[str, tuple[Any, ...]]]:
    """Build benchmark inputs for the ``fused_linear_jsd`` entry.

    Args:
        num_shapes: If given, only the first ``num_shapes`` configurations
            are generated; ``None`` generates all of them.

    Returns:
        One ``(label, args)`` pair per configuration, where ``label`` is
        ``"[m,k]"`` and ``args`` is
        ``(0.5, -100, 1.0, student_logits, teacher_logits)`` with both logit
        tensors drawn from ``torch.randn`` as float32 on ``DEVICE``.
    """
    # Each pair is the per-side logits tensor shape (batch, vocab).
    # The kernel materializes ~8 fp32 intermediates of shape [tile_b, vocab];
    # vocab=8192 with default tile=128 overflows VMEM (~32 MB needed), so keep
    # vocab <= 4096 and scale batch instead.
    # Slicing with `None` keeps the whole list, so no explicit branch is needed.
    batch_vocab_pairs = [(1024, 4096), (4096, 4096)][:num_shapes]

    def _random_logits(rows: int, cols: int) -> torch.Tensor:
        return torch.randn(rows, cols, device=DEVICE, dtype=torch.float32)

    # Tuple elements evaluate left to right: student is drawn before teacher.
    return [
        (
            f"[{rows},{cols}]",
            (
                0.5,
                -100,
                1.0,
                _random_logits(rows, cols),
                _random_logits(rows, cols),
            ),
        )
        for rows, cols in batch_vocab_pairs
    ]


def _grpo_loss_shapes(
    num_shapes: int | None = None,
) -> list[tuple[str, tuple[Any, ...]]]:
    """Build benchmark inputs for the ``grpo_loss`` entry.

    Args:
        num_shapes: If given, only the first ``num_shapes`` configurations
            are generated; ``None`` generates all of them.

    Returns:
        One ``(label, args)`` pair per configuration, where ``label`` is
        ``"[B,L,V]"`` and ``args`` is ``(logits, old_logp, ref_logp,
        completion_ids, advantages, completion_mask, temperature, beta,
        eps_low, eps_high)``.
    """
    temperature, beta = 0.9, 0.04
    eps_low, eps_high = 0.2, 0.4

    # (B, L, V) — kernel takes logits of shape [B, L+1, V].
    # Slicing with `None` keeps the whole list, so no explicit branch is needed.
    shape_configs = [(4, 512, 2048), (4, 1024, 32000)][:num_shapes]

    cases: list[tuple[str, tuple[Any, ...]]] = []
    for batch, seq_len, vocab in shape_configs:
        per_token = (batch, seq_len)
        # Tensors are created in a fixed order so the random stream consumed
        # for each configuration stays the same from run to run.
        logits = torch.randn(
            batch, seq_len + 1, vocab, device=DEVICE, dtype=torch.bfloat16
        )
        completion_ids = torch.randint(
            0, vocab, per_token, device=DEVICE, dtype=torch.int64
        )
        old_logp = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32)
        ref_logp = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32)
        advantages = torch.randn(batch, device=DEVICE, dtype=torch.float32)
        # All-ones mask: every completion position is included.
        completion_mask = torch.ones(
            batch, seq_len, device=DEVICE, dtype=torch.float32
        )

        kernel_args = (
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            advantages,
            completion_mask,
            temperature,
            beta,
            eps_low,
            eps_high,
        )
        cases.append((f"[{batch},{seq_len},{vocab}]", kernel_args))
    return cases


# Kernel mappings for TPU/Pallas benchmarks.
# Format: kernel_name -> (module_file, kernel_fn_name, baseline_fn, shapes_fn,
# max_mismatch_pct)
Expand Down Expand Up @@ -790,6 +913,23 @@ def _long_sum_shapes(
_squeeze_and_excitation_net_shapes,
None,
),
# Use the JSD-only inner kernel (test-covered on Pallas). The example's
# `fused_linear_jsd_fwd` wrapper that pairs with `fused_linear_jsd_pytorch`
# is not Pallas-tested and currently fails autotune accuracy here.
"fused_linear_jsd": (
"fused_linear_jsd",
"fused_linear_jsd_kernel",
_fused_linear_jsd_baseline,
_fused_linear_jsd_shapes,
None,
),
"grpo_loss": (
"grpo_loss",
"helion_grpo_loss",
_grpo_loss_baseline,
_grpo_loss_shapes,
None,
),
}


Expand Down