Introduce SC dense gather reduce kernel for MoE layers#2979

Merged
BirdsOfAFthr merged 1 commit into vllm-project:main from BirdsOfAFthr:dense-gr
Jun 24, 2026

Conversation

@BirdsOfAFthr

Collaborator

Description

This PR introduces a SparseCore Pallas kernel implementation for a dense gather-reduce in MoE.

Details & Context

  • Problem: Under pure TP (tensor parallelism), the gather-reduce step (which collects outputs from active experts and performs the top-k weighted reduction) ran on a native JAX implementation, which performs poorly on TPU.
  • Solution: As with EP (expert parallelism), where a ragged gather-reduce kernel offloads these ops to SC (SparseCore), this PR introduces a dense gather-reduce kernel for TP.
  • Limitations / Future Work:
    • Currently limited to hardware platforms equipped with TPU SparseCore.
    • Fuse the gather-reduce ops into the gmm kernel, which is agnostic to sharding.
    • Remove the sequential dependency chain of vector additions.
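
For readers unfamiliar with the operation, the gather-reduce step described above can be sketched in plain JAX. This is an illustrative reference under stated assumptions, not the kernel from this PR: the function name `dense_gather_reduce_ref`, the argument names, and the tensor layout (a `[T, K]` index array pointing into a `[T * K, H]` expert-output buffer) are all hypothetical.

```python
import jax.numpy as jnp


def dense_gather_reduce_ref(expert_out, row_indices, topk_weights):
    """Hypothetical reference: out[t] = sum_k w[t, k] * expert_out[row_indices[t, k]].

    expert_out:   [T * K, H] rows produced by the experts (assumed layout)
    row_indices:  [T, K] row of expert_out holding token t's k-th expert output
    topk_weights: [T, K] router weights for each (token, expert) pair
    """
    # Gather: pull the K expert-output rows that belong to each token.
    gathered = expert_out[row_indices].astype(jnp.float32)      # [T, K, H]
    # Weighted reduce over the top-k axis, accumulating in float32.
    weights = topk_weights.astype(jnp.float32)[..., None]       # [T, K, 1]
    return (gathered * weights).sum(axis=1).astype(expert_out.dtype)


# Tiny example: 2 tokens, top-2 routing, hidden size 2.
expert_out = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
row_indices = jnp.array([[2, 0], [1, 3]])
topk_weights = jnp.array([[0.5, 0.5], [0.25, 0.75]])
out = dense_gather_reduce_ref(expert_out, row_indices, topk_weights)
# out == [[2., 3.], [5., 6.]]
```

The gather touches `T * K` rows at arbitrary offsets, which is the kind of irregular memory access the PR offloads to SparseCore.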

Tests

  • Verified correctness and NaN-handling behavior against a JAX reference implementation across multiple shapes, dtypes (bfloat16 and float32), and reduction group sizes (top-k).
  • Ensured the unit tests fall back gracefully.
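
A sweep of this kind might look like the sketch below. It is not the PR's test file: the names `reference`, `sequential_candidate`, and `run_sweep` are hypothetical, the shapes and tolerances are arbitrary choices, and `sequential_candidate` is only a stand-in for the kernel under test (it accumulates one top-k slot at a time). NaN handling is checked by planting a NaN in the input and comparing with `equal_nan=True`.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def reference(expert_out, row_indices, topk_weights):
    # Gather K rows per token, then take a weighted sum over the top-k axis.
    gathered = expert_out[row_indices].astype(jnp.float32)
    weights = topk_weights.astype(jnp.float32)[..., None]
    return (gathered * weights).sum(axis=1).astype(expert_out.dtype)


def sequential_candidate(expert_out, row_indices, topk_weights):
    # Stand-in for the kernel under test (hypothetical): one add per top-k slot.
    acc = jnp.zeros((row_indices.shape[0], expert_out.shape[1]), jnp.float32)
    for k in range(row_indices.shape[1]):
        rows = expert_out[row_indices[:, k]].astype(jnp.float32)
        acc = acc + topk_weights[:, k, None].astype(jnp.float32) * rows
    return acc.astype(expert_out.dtype)


def run_sweep():
    checked = 0
    shapes = [(4, 8), (16, 32)]          # (tokens, hidden), arbitrary
    dtypes = [jnp.bfloat16, jnp.float32]
    for (tokens, hidden), dtype, top_k in itertools.product(shapes, dtypes, [1, 2, 4]):
        k1, k2, k3 = jax.random.split(jax.random.PRNGKey(checked), 3)
        expert_out = jax.random.normal(k1, (tokens * top_k, hidden)).astype(dtype)
        expert_out = expert_out.at[0, 0].set(jnp.nan)   # plant one NaN
        row_indices = jax.random.permutation(k2, tokens * top_k).reshape(tokens, top_k)
        topk_weights = jax.nn.softmax(jax.random.normal(k3, (tokens, top_k)), axis=-1)

        want = np.asarray(reference(expert_out, row_indices, topk_weights).astype(jnp.float32))
        got = np.asarray(sequential_candidate(expert_out, row_indices, topk_weights).astype(jnp.float32))
        assert np.isnan(want).any()      # the planted NaN must propagate
        np.testing.assert_allclose(got, want, rtol=2e-2, atol=2e-2, equal_nan=True)
        checked += 1
    return checked
```

The loose tolerance allows for bfloat16 rounding when the two implementations sum the top-k terms in a different order.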

Commands to Run:

PYTHONPATH=. .venv/bin/python -m unittest tests.kernels.dense_gather_reduce_test

BirdsOfAFthr requested a review from guowei-dev June 24, 2026 05:06
BirdsOfAFthr added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Jun 24, 2026
BirdsOfAFthr changed the title from "Introduce Sparse Core dense gather reduce kernel for MoE layers" to "Introduce SC dense gather reduce kernel for MoE layers" Jun 24, 2026
BirdsOfAFthr force-pushed the dense-gr branch 3 times, most recently from 27ca34b to c148465 June 24, 2026 16:35
Signed-off-by: Amanda Liang <amandaliang@google.com>
BirdsOfAFthr merged commit 8f92c30 into vllm-project:main Jun 24, 2026
54 checks passed
amanseervi pushed a commit to amanseervi/tpu-inference that referenced this pull request Jun 25, 2026
…2979)

Signed-off-by: Amanda Liang <amandaliang@google.com>
Signed-off-by: Aman Seervi <amanseervi@google.com>

3 participants