-
Notifications
You must be signed in to change notification settings - Fork 539
Add ragged sort kernel fallback mechanism and version guard #4187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,23 +16,30 @@ | |
| # Source from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/sparse_core/ragged_gather.py | ||
|
|
||
| import functools | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax.experimental import pallas as pl | ||
| from jax.experimental.pallas import tpu as pltpu | ||
| from jax.experimental.pallas import tpu_sc as plsc | ||
| from packaging.version import Version | ||
|
|
||
|
|
||
| # JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later | ||
| # versions renamed them to `out_type`/`scratch_types`. | ||
| if Version(jax.__version__) <= Version("0.10.0"): | ||
| _OUT_KW = "out_shape" | ||
| _SCRATCH_KW = "scratch_shapes" | ||
| _COMPILER_PARAMS = { | ||
| "use_tc_tiling_on_sc": True, | ||
| "disable_bounds_checks": True, | ||
| } | ||
| else: | ||
| _OUT_KW = "out_type" | ||
| _SCRATCH_KW = "scratch_types" | ||
| _COMPILER_PARAMS = { | ||
| "use_tc_tiling_on_sc": True, | ||
| "disable_bounds_checks": True, | ||
| "needs_layout_passes": False, | ||
| } | ||
|
|
||
|
|
||
| def main_kernel( | ||
|
|
@@ -264,6 +271,19 @@ def dma_write_loop(col_vmem_start): | |
| inner_kernel() | ||
|
|
||
|
|
||
| def _fallback_implementation( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
| weights: jax.Array | None = None, | ||
| has_weights: bool = False, | ||
| ) -> jax.Array: | ||
| """Fallback to (non-ragged) JAX implementation for ragged gather.""" | ||
| out = x[indices] | ||
| if has_weights: | ||
| out = out * weights[:, None] | ||
| return out | ||
|
|
||
|
|
||
| def calculate_col_size(hidden_size: int) -> int: | ||
| """Calculate col size for ragged gather kernel.""" | ||
| tpu_info = pltpu.get_tpu_info() | ||
|
|
@@ -288,14 +308,15 @@ def calculate_col_size(hidden_size: int) -> int: | |
| return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes | ||
|
|
||
|
|
||
| @functools.partial(jax.jit, static_argnames=("has_weights",)) | ||
| @functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback")) | ||
| def ragged_gather( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
| start: jax.Array, | ||
| end: jax.Array, | ||
| weights: jax.Array | None = None, | ||
| has_weights: bool = False, | ||
| enforce_fallback: bool = False, | ||
| ) -> jax.Array: | ||
| """Perform gather on indices within dynamic array start and end. | ||
|
|
||
|
|
@@ -309,6 +330,9 @@ def ragged_gather( | |
| kernel, avoiding an extra HBM read-write pass. | ||
| has_weights: Static bool flag indicating whether ``weights`` should be | ||
| applied. Must be ``True`` when ``weights`` is not ``None``. | ||
| enforce_fallback: Static bool flag. When ``True``, unconditionally use the | ||
| JAX reference implementation instead of the SparseCore kernel. | ||
| When ``False`` (default), use the SparseCore kernel and raise any error. | ||
|
|
||
| Returns: | ||
| Gathered output of shape ``(indices_size, hidden_size)``. | ||
|
|
@@ -331,12 +355,9 @@ def ragged_gather( | |
| dtype = x.dtype | ||
|
|
||
| sc_info = pltpu.get_tpu_info().sparse_core | ||
| if sc_info is None: | ||
| # Sparse core is not available. Fallback to regular gather. | ||
| out = x[indices] | ||
| if has_weights: | ||
| out = out * weights[:, None] | ||
| return out | ||
| if sc_info is None or enforce_fallback: | ||
| # Sparse core is not available or fallback is enforced. Use JAX reference. | ||
| return _fallback_implementation(x, indices, weights, has_weights) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is a bit surprising to see a non-ragged fallback inside a function called `ragged_gather` - should the decision to use `ragged_gather` vs. a fallback happen in a higher-level function (e.g. somewhere higher in the call stack)?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new |
||
|
|
||
| hidden_size = x.shape[-1] | ||
| out_size = indices.size | ||
|
|
@@ -371,9 +392,8 @@ def ragged_gather( | |
| subcore_axis_name=vector_mesh.subcore_axis_name, | ||
| has_weights=has_weights, | ||
| ), | ||
| compiler_params=pltpu.CompilerParams( | ||
| use_tc_tiling_on_sc=True, | ||
| disable_bounds_checks=True, | ||
| compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args | ||
| **_COMPILER_PARAMS, | ||
| ), | ||
| mesh=vector_mesh, | ||
| name="sc_ragged_gather", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this just be:
ragged_gather: true --> use ragged kernel
ragged_gather: false --> use non-ragged fallback
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two operations:
1. ragged permute (fwd: ragged gather; bwd: ragged gather reduce)
2. ragged unpermute (fwd: ragged gather reduce; bwd: ragged gather)
Technically we can introduce 4 flags, controlling whether we want the kernel version or the JAX version for each of them respectively. However, it is probably not necessary. Instead, we have one flag
`use_ragged_sort` indicating whether we want to use the kernel version for all, and `ragged_gather_fallback` and `ragged_gather_reduce_fallback` controlling the kernels respectively. I think you are suggesting renaming these two flags, with opposite meanings? I don't have a strong opinion on this, but I can update if you think they are much better...