Add ragged sort kernel fallback mechanism and version guard#4187

Open
NuojCheng wants to merge 1 commit into main from chengnuojin-ragged-guard

@NuojCheng NuojCheng commented Jun 17, 2026

Collaborator

Description

This PR

  • Introduces two flags, ragged_gather_fallback and ragged_gather_reduce_fallback. When a flag is true, the corresponding kernel is replaced by a pure-JAX, non-ragged implementation.
  • Adds a version guard protecting the ragged gather kernel.
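
As a rough illustration of how a fallback flag and a version guard could combine, the sketch below picks an implementation from three inputs. It is not the code in this PR: the minimum version `(0, 6, 0)`, the helper names `parse_version` and `select_gather_impl`, the returned strings, and the choice to fall back (rather than raise) on an old version are all assumptions made for the example.

```python
# Hypothetical sketch of a version guard combined with a fallback flag.
# Nothing here is taken from the PR's diff; names and the minimum
# version are placeholders.

HYPOTHETICAL_MIN_VERSION = (0, 6, 0)


def parse_version(text: str) -> tuple[int, ...]:
    """Return the leading numeric components of a version string.

    '0.6.2' -> (0, 6, 2); '0.6.2.dev20260101' -> (0, 6, 2).
    """
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def select_gather_impl(
    installed_version: str,
    sparse_core_available: bool,
    enforce_fallback: bool,
) -> str:
    """Pick 'ragged_kernel' only when every precondition holds."""
    if enforce_fallback or not sparse_core_available:
        return "jax_reference"
    if parse_version(installed_version) < HYPOTHETICAL_MIN_VERSION:
        # Assumed behavior: an old library version falls back instead of failing.
        return "jax_reference"
    return "ragged_kernel"
```

Keeping the decision in one small function makes it easy to see, and to unit test, which conditions route a call to the kernel.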

Tests

Based on xprofs, both flags work effectively.

Fall back ragged gather reduce only

smoke_train model_name=deepseek2-16b ici_expert_parallelism=4 per_device_batch_size=1 max_target_length=4096 use_random_routing=true use_ring_of_experts=true use_ragged_sort=true ragged_gather_reduce_fallback=true debug_sharding=false profiler=xplane ragged_gather_fallback=false enable_tpu_profiling_options=true

xprof

Fall back ragged gather only

smoke_train model_name=deepseek2-16b ici_expert_parallelism=4 per_device_batch_size=1 max_target_length=4096 use_random_routing=true use_ring_of_experts=true use_ragged_sort=true ragged_gather_reduce_fallback=false debug_sharding=false profiler=xplane ragged_gather_fallback=true enable_tpu_profiling_options=true

xprof

Fall back both kernels

smoke_train model_name=deepseek2-16b ici_expert_parallelism=4 per_device_batch_size=1 max_target_length=4096 use_random_routing=true use_ring_of_experts=true use_ragged_sort=true ragged_gather_reduce_fallback=true debug_sharding=false profiler=xplane ragged_gather_fallback=true enable_tpu_profiling_options=true

xprof

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 17, 2026

Codecov Report

❌ Patch coverage is 93.33333% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/kernels/ragged/ragged_gather_reduce.py 66.66% 1 Missing ⚠️

return out
if sc_info is None or enforce_fallback:
  # Sparse core is not available or fallback is enforced. Use JAX reference.
  return _fallback_implementation(x, indices, weights, has_weights)

Collaborator

It is a bit surprising to see a non-ragged fallback inside a function called ragged_gather. Should the decision to use ragged_gather vs. a fallback happen in a higher-level function (i.e. somewhere higher in the call stack)?

Collaborator Author

The new ragged_gather_fallback and ragged_gather_reduce_fallback flags control kernel-level fallback logic. We avoid using these flags directly in moe.py because we may need the native JAX implementation for the forward pass and the custom kernel for the backward pass. While we could theoretically use them in ragged_sort.py, a fallback mechanism already exists in the primary kernel wrapper. We leverage that existing structure instead of introducing redundant if-else conditions.

Comment thread tests/unit/moe_test.py
self._run_ragged_sort_loss_and_grad(use_ring_of_experts=True, ragged_buffer_factor=1.5)

@pytest.mark.tpu_only
@pytest.mark.skip_on_tpu7x

Collaborator

why skip on tpu7x?

Collaborator Author

I think there were some issues using ragged sort on bloom. @darisoy do we have a buganizer tracking this?

    weights: jax.Array | None = None,
    has_weights: bool = False,
) -> jax.Array:
  """Fallback to JAX implementation for ragged gather."""

Collaborator

is this a ragged implementation? Doesn't this grow with the full buffer size or no?

Collaborator Author

There are only two options:

  1. JAX sort, no raggedness
  2. Sparse core kernels, ragged

Updated the comments to better reflect that the JAX implementation is non-ragged.
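
To make the "non-ragged" point concrete, here is a minimal sketch of a dense reference gather. It is not the PR's `_fallback_implementation`; the name `dense_gather_reference` and the padding convention (unused slots point at row 0 with weight 0) are assumptions for illustration. NumPy is used so the snippet runs anywhere; `jax.numpy.take` accepts the same `(array, indices, axis=0)` arguments.

```python
import numpy as np


def dense_gather_reference(x, indices, weights=None):
    """Gather one row of x per slot of indices, including padding slots."""
    out = np.take(x, indices, axis=0)  # shape: (len(indices), x.shape[1])
    if weights is not None:
        out = out * weights[:, None]  # scale each gathered row by its weight
    return out


x = np.arange(12, dtype=np.float32).reshape(4, 3)

# Buffer of 6 slots, of which only the first 2 carry real tokens
# (hypothetical padding convention: index 0, weight 0).
indices = np.array([2, 0, 0, 0, 0, 0])
weights = np.array([0.5, 2.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)

out = dense_gather_reference(x, indices, weights)
print(out.shape)  # (6, 3)
```

The output has one row per buffer slot, so the work scales with the full buffer length (6 here) rather than with the number of valid entries (2). A ragged kernel would aim to touch only the valid prefix.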

# without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run
# inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute.
use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
ragged_gather_fallback: false # when true, unconditionally use the JAX reference implementation instead of the

Collaborator

Could this just be:
ragged_gather: true --> use ragged kernel
ragged_gather: false --> use non-ragged fallback

Collaborator Author

There are two operations:
1. ragged permute (fwd: ragged gather; bwd: ragged gather reduce)
2. ragged unpermute (fwd: ragged gather reduce; bwd: ragged gather)

Technically, we could introduce four flags, one for each pass of each operation, each choosing between the kernel version and the JAX version. That is probably unnecessary. Instead, one flag, use_ragged_sort, indicates whether to use the kernel versions everywhere, while ragged_gather_fallback and ragged_gather_reduce_fallback control the two kernels individually.

I think you are suggesting renaming these two flags with the opposite meaning? I don't have a strong opinion on this, but I can update them if you think the new names are much better.
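
The flag scheme described in this reply can be summarized with a small lookup. This is a sketch under assumptions, not MaxText code: `RaggedFlags` and `implementation_plan` are hypothetical names, hardware availability is ignored, and `use_ragged_sort=False` is assumed to mean no kernel is used at all.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RaggedFlags:
    """Hypothetical container mirroring the three config flags."""
    use_ragged_sort: bool = True
    ragged_gather_fallback: bool = False
    ragged_gather_reduce_fallback: bool = False


def implementation_plan(flags: RaggedFlags) -> dict:
    """Map each operation and pass to 'kernel' or 'jax'."""

    def pick(fallback: bool) -> str:
        # Assumption: the kernel is used only when ragged sort is on
        # and the per-kernel fallback flag is off.
        return "kernel" if flags.use_ragged_sort and not fallback else "jax"

    gather = pick(flags.ragged_gather_fallback)
    gather_reduce = pick(flags.ragged_gather_reduce_fallback)
    return {
        # permute: forward is a gather, backward is a gather-reduce
        "permute": {"fwd": gather, "bwd": gather_reduce},
        # unpermute: forward is a gather-reduce, backward is a gather
        "unpermute": {"fwd": gather_reduce, "bwd": gather},
    }


plan = implementation_plan(RaggedFlags(ragged_gather_fallback=True))
print(plan["permute"])    # {'fwd': 'jax', 'bwd': 'kernel'}
print(plan["unpermute"])  # {'fwd': 'kernel', 'bwd': 'jax'}
```

With only ragged_gather_fallback set, the permute forward pass runs in JAX while its backward pass still uses the kernel. That is the mixed forward/backward case a single per-operation switch could not express.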

@NuojCheng NuojCheng force-pushed the chengnuojin-ragged-guard branch from a0a7096 to e76b469 Compare June 19, 2026 00:07

3 participants