Add ragged sort kernel fallback mechanism and version guard #4187
NuojCheng wants to merge 1 commit into
74e1341 to eb7f846
eb7f846 to a0a7096
return out
if sc_info is None or enforce_fallback:
  # Sparse core is not available or fallback is enforced. Use JAX reference.
  return _fallback_implementation(x, indices, weights, has_weights)
It is a bit surprising to see a non-ragged fallback inside a function called ragged_gather. Should the decision to use ragged_gather vs. a fallback happen in a higher-level function (i.e., somewhere higher in the call stack)?
The new ragged_gather_fallback and ragged_gather_reduce_fallback flags control kernel-level fallback logic. We avoid using these flags directly in moe.py because we may need the native JAX implementation for the forward pass and the custom kernel for the backward pass. While we could theoretically use them in ragged_sort.py, a fallback mechanism already exists in the primary kernel wrapper. We leverage that existing structure instead of introducing redundant if-else conditions.
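For readers of this thread, the kernel-level fallback described above can be sketched roughly as follows. This is a minimal pure-Python illustration, not the PR's code: `sc_info`, `enforce_fallback`, and `_fallback_implementation` appear in the diff, while `_kernel_implementation`, the list-based data, passing `sc_info` as an argument, and returning the chosen path are all assumptions made to keep the example runnable.

```python
def _fallback_implementation(x, indices, weights=None, has_weights=False):
  """Reference path: gathers every index in the buffer (non-ragged)."""
  out = [list(x[i]) for i in indices]
  if has_weights and weights is not None:
    # Scale each gathered row by its per-row weight.
    out = [[w * v for v in row] for row, w in zip(out, weights)]
  return out


def _kernel_implementation(x, indices, weights=None, has_weights=False):
  """Hypothetical stand-in for the custom kernel; it reuses the reference math."""
  return _fallback_implementation(x, indices, weights, has_weights)


def ragged_gather(x, indices, weights=None, has_weights=False,
                  sc_info=None, enforce_fallback=False):
  """Returns (path_taken, result) so the dispatch decision is visible."""
  if sc_info is None or enforce_fallback:
    # Sparse core is not available or fallback is enforced.
    return "fallback", _fallback_implementation(x, indices, weights, has_weights)
  return "kernel", _kernel_implementation(x, indices, weights, has_weights)
```

Because the check lives inside the wrapper, a caller such as a custom VJP rule could, in principle, request the fallback for one pass and the kernel for the other without any branching of its own.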
self._run_ragged_sort_loss_and_grad(use_ring_of_experts=True, ragged_buffer_factor=1.5)

@pytest.mark.tpu_only
@pytest.mark.skip_on_tpu7x
I think there were some issues using ragged sort on bloom. @darisoy do we have a buganizer tracking this?
    weights: jax.Array | None = None,
    has_weights: bool = False,
) -> jax.Array:
  """Fallback to JAX implementation for ragged gather."""
Is this a ragged implementation? Doesn't this grow with the full buffer size?
There are only two options:
- JAX sort, no raggedness
- sparse core kernels, ragged
Updated the comments to better reflect that the JAX path is non-ragged.
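To make the distinction concrete, here is a small pure-Python sketch. It reflects one reading of "ragged" in this thread (work proportional to the number of valid entries rather than to the padded buffer size); the function names, the `num_valid` parameter, and the zero-filling of unused slots are assumptions for illustration and are not taken from the PR.

```python
def gather_non_ragged(x, indices):
  """Processes every slot in the index buffer, including padding slots."""
  out = []
  rows_touched = 0
  for i in indices:
    out.append(list(x[i]))
    rows_touched += 1
  return out, rows_touched


def gather_ragged(x, indices, num_valid):
  """Processes only the first num_valid slots; the rest are zero-filled."""
  width = len(x[0])
  out = [list(x[i]) for i in indices[:num_valid]]
  out += [[0.0] * width for _ in range(len(indices) - num_valid)]
  return out, num_valid
```

With a buffer of four slots of which two are valid, the non-ragged version touches four rows while the ragged one touches two, which is the cost difference the question above is getting at.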
# without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run
# inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute.
use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
ragged_gather_fallback: false # when true, unconditionally use the JAX reference implementation instead of the
Could this just be:
ragged_gather: true --> use ragged kernel
ragged_gather: false --> use non-ragged fallback
There are two operations:
1. ragged permute (fwd: ragged gather; bwd: ragged gather reduce)
2. ragged unpermute (fwd: ragged gather reduce; bwd: ragged gather)
Technically, we could introduce four flags, one per pass, each choosing between the kernel version and the JAX version. However, that is probably not necessary. Instead, we have one flag, use_ragged_sort, indicating whether to use the kernel version for everything, and ragged_gather_fallback and ragged_gather_reduce_fallback, each controlling its respective kernel.
I think you are suggesting renaming these two flags with the opposite meaning? I don't have a strong opinion on this, but I can update them if you think that is much better.
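The flag scheme described above can be summarized with a short pure-Python sketch. The three flag names come from this thread; the `RaggedSortConfig` dataclass, the `resolve` helper, the pass labels, and the assumption that `use_ragged_sort: false` selects JAX everywhere are illustrative guesses, not the PR's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RaggedSortConfig:
  use_ragged_sort: bool = False
  ragged_gather_fallback: bool = False
  ragged_gather_reduce_fallback: bool = False


def _impl(cfg, op):
  """Picks "kernel" or "jax" for one op ("gather" or "gather_reduce")."""
  if not cfg.use_ragged_sort:
    return "jax"  # Assumed: kernels are only considered when this flag is on.
  if op == "gather":
    fallback = cfg.ragged_gather_fallback
  else:
    fallback = cfg.ragged_gather_reduce_fallback
  return "jax" if fallback else "kernel"


def resolve(cfg):
  """Maps each pass of permute/unpermute to the implementation it would use."""
  return {
      "permute_fwd": _impl(cfg, "gather"),
      "permute_bwd": _impl(cfg, "gather_reduce"),
      "unpermute_fwd": _impl(cfg, "gather_reduce"),
      "unpermute_bwd": _impl(cfg, "gather"),
  }
```

For example, setting only `ragged_gather_fallback` moves the permute forward pass and the unpermute backward pass to JAX while the other two passes stay on the kernel, so two flags cover the mixed cases without needing four.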
a0a7096 to e76b469
Description
This PR adds two flags, ragged_gather_fallback and ragged_gather_reduce_fallback. When either is true, a fully JAX, non-ragged version is used instead of the corresponding kernel.
Tests
Based on the xprofs below, both flags work as intended.
Fall back ragged gather reduce only
xprof
Fall back ragged gather only
xprof
Fall back both kernels
xprof