Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the Mo
# without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run
# inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute.
use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
ragged_gather_fallback: false # when true, unconditionally use the JAX reference implementation instead of the

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this just be:
ragged_gather: true --> use ragged kernel
ragged_gather:false --> use non ragged fallback

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two operations:
1: ragged permute (fwd: ragged gather; bwd: ragged gather reduce)
2. ragged unpermute (fwd: ragged gather reduce, bwd: ragged gather)

Technically we can introduce 4 flags, controlling whether we want kernel version/JAX version respectively. However, it is probably not necessary. Instead, we have one flag use_ragged_sort indicating whether we want to use kernel version for all, and ragged_gather_fallback and ragged_gather_reduce_fallback controlling kernels respectively.

I think you are suggesting re-naming these two flags, with opposite meanings? I don't have strong opinion on this, but I can update if you think they are much better...

# ragged gather SparseCore kernel. When false (default), use the SparseCore kernel.
ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the
# ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel.
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
# tokamax ragged dot - supports all 18 configs
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,16 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use a custom mosaic kernel for token gather ops.",
)
ragged_gather_fallback: bool = Field(
False,
description="When true, unconditionally use the JAX reference implementation instead of the ragged gather "
"SparseCore kernel. When false (default), use the SparseCore kernel.",
)
ragged_gather_reduce_fallback: bool = Field(
False,
description="When true, unconditionally use the JAX reference implementation instead of the ragged gather "
"reduce SparseCore kernel. When false (default), use the SparseCore kernel.",
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
moe_fsdp_use_two_stage_all_gather: bool = Field(
Expand Down
44 changes: 32 additions & 12 deletions src/maxtext/kernels/ragged/ragged_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,30 @@
# Source from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/sparse_core/ragged_gather.py

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas import tpu_sc as plsc
from packaging.version import Version


# JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later
# versions renamed them to `out_type`/`scratch_types`.
if Version(jax.__version__) <= Version("0.10.0"):
_OUT_KW = "out_shape"
_SCRATCH_KW = "scratch_shapes"
_COMPILER_PARAMS = {
"use_tc_tiling_on_sc": True,
"disable_bounds_checks": True,
}
else:
_OUT_KW = "out_type"
_SCRATCH_KW = "scratch_types"
_COMPILER_PARAMS = {
"use_tc_tiling_on_sc": True,
"disable_bounds_checks": True,
"needs_layout_passes": False,
}


def main_kernel(
Expand Down Expand Up @@ -264,6 +271,19 @@ def dma_write_loop(col_vmem_start):
inner_kernel()


def _fallback_implementation(
x: jax.Array,
indices: jax.Array,
weights: jax.Array | None = None,
has_weights: bool = False,
) -> jax.Array:
"""Fallback to (non-ragged) JAX implementation for ragged gather."""
out = x[indices]
if has_weights:
out = out * weights[:, None]
return out


def calculate_col_size(hidden_size: int) -> int:
"""Calculate col size for ragged gather kernel."""
tpu_info = pltpu.get_tpu_info()
Expand All @@ -288,14 +308,15 @@ def calculate_col_size(hidden_size: int) -> int:
return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes


@functools.partial(jax.jit, static_argnames=("has_weights",))
@functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback"))
def ragged_gather(
x: jax.Array,
indices: jax.Array,
start: jax.Array,
end: jax.Array,
weights: jax.Array | None = None,
has_weights: bool = False,
enforce_fallback: bool = False,
) -> jax.Array:
"""Perform gather on indices within dynamic array start and end.

Expand All @@ -309,6 +330,9 @@ def ragged_gather(
kernel, avoiding an extra HBM read-write pass.
has_weights: Static bool flag indicating whether ``weights`` should be
applied. Must be ``True`` when ``weights`` is not ``None``.
enforce_fallback: Static bool flag. When ``True``, unconditionally use the
JAX reference implementation instead of the SparseCore kernel.
When ``False`` (default), use the SparseCore kernel and raise any error.

Returns:
Gathered output of shape ``(indices_size, hidden_size)``.
Expand All @@ -331,12 +355,9 @@ def ragged_gather(
dtype = x.dtype

sc_info = pltpu.get_tpu_info().sparse_core
if sc_info is None:
# Sparse core is not available. Fallback to regular gather.
out = x[indices]
if has_weights:
out = out * weights[:, None]
return out
if sc_info is None or enforce_fallback:
# Sparse core is not available or fallback is enforced. Use JAX reference.
return _fallback_implementation(x, indices, weights, has_weights)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit surprising to see a non ragged fallback inside of a function called ragged_gather - should the decision to use ragged_gather vs a fallback happen at a higher level function (e.g. somewhere higher in the call stack?)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new ragged_gather_fallback and ragged_gather_reduce_fallback flags control kernel-level fallback logic. We avoid using these flags directly in moe.py because we may need the native JAX implementation for the forward pass and the custom kernel for the backward pass. While we could theoretically use them in ragged_sort.py, a fallback mechanism already exists in the primary kernel wrapper. We leverage that existing structure instead of introducing redundant if-else conditions.


hidden_size = x.shape[-1]
out_size = indices.size
Expand Down Expand Up @@ -371,9 +392,8 @@ def ragged_gather(
subcore_axis_name=vector_mesh.subcore_axis_name,
has_weights=has_weights,
),
compiler_params=pltpu.CompilerParams(
use_tc_tiling_on_sc=True,
disable_bounds_checks=True,
compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args
**_COMPILER_PARAMS,
),
mesh=vector_mesh,
name="sc_ragged_gather",
Expand Down
40 changes: 18 additions & 22 deletions src/maxtext/kernels/ragged/ragged_gather_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import jax.numpy as jnp
from packaging.version import Version


# JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later
# versions renamed them to `out_type`/`scratch_types`.
if Version(jax.__version__) <= Version("0.10.0"):
Expand Down Expand Up @@ -360,26 +361,23 @@ def _preprocess(
num_src_rows_per_row_partition.astype(jnp.int32),
(0, num_simd_lanes - num_row_partitions),
)
# If there is no valid source row in a reduce group, we set the mask to
# False, so that the output for that group is set to zero.
mask = jnp.any(valid_rows_mask.reshape(-1, reduce_group_size), axis=-1)

return (
src_indices,
dst_indices,
topk_weights,
num_src_rows_per_row_partition,
mask,
)


@functools.partial(jax.jit, static_argnames=("reduce_group_size",))
@functools.partial(jax.jit, static_argnames=("reduce_group_size", "enforce_fallback"))
def ragged_gather_reduce(
x: jax.Array,
indices: jax.Array,
topk_weights: jax.Array,
valid_rows_mask: jax.Array,
reduce_group_size: int,
enforce_fallback: bool = False,
) -> jax.Array:
"""Gathers `x` according to `indices`, applies weights and masks, and reduces.

Expand All @@ -402,6 +400,9 @@ def ragged_gather_reduce(
valid, with shape `(input_size,)`.
reduce_group_size: An integer representing the number of consecutive rows to
reduce (sum) together.
enforce_fallback: Static bool flag. When ``True``, unconditionally use the
JAX reference implementation instead of the SparseCore kernel.
When ``False`` (default), use the SparseCore kernel and raise any error.

Returns:
A 2D JAX array of reduced data with shape
Expand All @@ -414,7 +415,8 @@ def ragged_gather_reduce(
assert valid_rows_mask.ndim == 1, "ragged_gather_reduce only supports 1d valid_rows_mask."

sc_info = pltpu.get_tpu_info().sparse_core
if sc_info is None:
if sc_info is None or enforce_fallback:
# Sparse core is not available or fallback is enforced. Use JAX reference.
return _fallback_implementation(x, indices, topk_weights, valid_rows_mask, reduce_group_size)

# Heuristic threshold on whether to fallback for small inputs.
Expand All @@ -431,19 +433,20 @@ def ragged_gather_reduce(
num_simd_lanes = sc_info.num_lanes
num_cores = sc_info.num_cores * sc_info.num_subcores

# This kernel partitions the output's columns into `num_column_partitions` and
# partition the output's rows into `num_row_partitions` and run each
# This kernel partitions the output's columns into `num_column_partitions`
# and partition the output's rows into `num_row_partitions` and run each
# {row_partition} x {column_partition} combination on a separate SC subcore
# for parallelism. With such work partitioning, we guarantee that there won't
# be write collision (from different subcores) to the any output row X column.
# for parallelism. With such work partitioning, we guarantee that there
# won't be write collision (from different subcores) to any output row X
# column.
#
# Each column partition should be multiple of 128 (number of lanes) due to
# DMA requirements. Unless requiring padding on the column dimension, larger
# column partitions (thus smaller row partitions given fixed num_cores) is
# more preferable because large row partition may lead to imbalanced load
# (valid_rows_mask may have more rows in some partitions than others).
# Most LLM's hidden size is multiple of 1024, `num_column_partitions=8` should
# work well in practice without requiring padding on the column size.
# Most LLM's hidden size is multiple of 1024, `num_column_partitions=8`
# should work well in practice without requiring padding on the column size.
num_column_partitions = 8
assert num_cores % num_column_partitions == 0
num_rows_partitions = num_cores // num_column_partitions
Expand Down Expand Up @@ -471,7 +474,6 @@ def ragged_gather_reduce(
dst_indices,
topk_weights,
num_src_rows_per_row_partition,
mask,
) = _preprocess(
indices,
topk_weights,
Expand All @@ -487,8 +489,8 @@ def ragged_gather_reduce(
core_axis_name="core",
subcore_axis_name="subcore",
)
# Each output row from `main_kernel` will be of type float32, and then casted
# to the input dtype when doing the filter operation.
# Each output row from `main_kernel` will be of type float32, and then
# casted to the input dtype when doing the filter operation.
out = pl.kernel( # pytype: disable=wrong-keyword-args
functools.partial(
main_kernel,
Expand Down Expand Up @@ -519,10 +521,4 @@ def ragged_gather_reduce(
},
)(num_src_rows_per_row_partition, x, src_indices, dst_indices, topk_weights)

# If there is no valid source row in a reduce group, set that group's output
# to zero.
return jnp.where(
mask[:, None],
out.astype(x.dtype),
jnp.zeros_like(out, dtype=x.dtype),
)[: (input_size // reduce_group_size), :hidden_size]
return out.astype(x.dtype)
20 changes: 18 additions & 2 deletions src/maxtext/kernels/ragged/ragged_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def ring_ragged_sort(
ep_name,
ep_size,
buffer_size=None,
enforce_gather_fallback=False,
enforce_gather_reduce_fallback=False,
):
"""Ragged-gather variant for AG-RS Expert Parallelism token routing.

Expand Down Expand Up @@ -102,6 +104,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
token_indices_sorted,
shard_output_start[None],
shard_output_end[None],
enforce_fallback=enforce_gather_fallback,
)
else:
local_buffer_size = buffer_size
Expand All @@ -122,6 +125,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
sliced_indices,
jnp.int32(0)[None],
gather_end[None],
enforce_fallback=enforce_gather_fallback,
)

out = (x, group_sizes_local, topk_argsort_revert_indices)
Expand Down Expand Up @@ -173,6 +177,7 @@ def _ring_ragged_sort_bwd(res, g_out):
topk_weights=jnp.ones((n,), dtype=jnp.float32),
valid_rows_mask=valid_rows_mask,
reduce_group_size=topk,
enforce_fallback=enforce_gather_reduce_fallback,
)
else:
# Buffering: g_x has size `local_buffer_size` (packed).
Expand All @@ -195,6 +200,7 @@ def _ring_ragged_sort_bwd(res, g_out):
topk_weights=jnp.ones((n,), dtype=jnp.float32),
valid_rows_mask=valid_rows_mask,
reduce_group_size=topk,
enforce_fallback=enforce_gather_reduce_fallback,
)
return grad_hidden_states, None

Expand All @@ -211,6 +217,8 @@ def ring_ragged_unsort(
local_num_experts,
ep_name,
topk_weights,
enforce_gather_fallback=False,
enforce_gather_reduce_fallback=False,
):
"""Dual of :func:`ring_ragged_sort`.

Expand Down Expand Up @@ -282,6 +290,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort
topk_weights=topk_weights_flat,
valid_rows_mask=valid_rows_mask,
reduce_group_size=topk,
enforce_fallback=enforce_gather_reduce_fallback,
)
else:
# Shift indices so they map to the packed local buffer [0, local_num_tokens).
Expand All @@ -297,6 +306,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort
topk_weights=topk_weights_flat,
valid_rows_mask=valid_rows_mask,
reduce_group_size=topk,
enforce_fallback=enforce_gather_reduce_fallback,
)

res = (
Expand Down Expand Up @@ -352,6 +362,7 @@ def _ring_ragged_unsort_bwd(res, g_out):
shard_output_end[None],
weights=weight_for_sorted,
has_weights=True,
enforce_fallback=enforce_gather_fallback,
)
else:
# Slice the inverse permutation to match the packed local buffer.
Expand All @@ -368,6 +379,7 @@ def _ring_ragged_unsort_bwd(res, g_out):
gather_end[None],
weights=sliced_weights,
has_weights=True,
enforce_fallback=enforce_gather_fallback,
)
return grad_sorted_tokens, None, None, None

Expand All @@ -379,7 +391,7 @@ def _ring_ragged_unsort_bwd(res, g_out):
return _ring_ragged_unsort(sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat)


def a2a_ragged_sort(inputs, sort_indices, valid_end):
def a2a_ragged_sort(inputs, sort_indices, valid_end, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False):
"""Ragged-gather variant for ``local_permute``.

Unlike :func:`ring_ragged_sort`, the rows valid for this shard live in
Expand Down Expand Up @@ -442,6 +454,7 @@ def _a2a_ragged_sort_bwd(res, g_out):
topk_weights=jnp.ones((n,), dtype=jnp.float32),
valid_rows_mask=valid_rows_mask[idx_inv],
reduce_group_size=1,
enforce_fallback=enforce_gather_reduce_fallback,
)
# custom_vjp must return one gradient per primal arg; valid_end is integer
# and non-differentiable, so we return None for it.
Expand All @@ -451,7 +464,9 @@ def _a2a_ragged_sort_bwd(res, g_out):
return _a2a_ragged_sort(inputs, sort_indices, valid_end)


def a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end):
def a2a_ragged_unsort(
sorted_tokens, revert_indices, valid_end, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False
):
"""Dual of :func:`a2a_ragged_sort`.

Forward:
Expand Down Expand Up @@ -492,6 +507,7 @@ def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end):
topk_weights=jnp.ones((n,), dtype=jnp.float32),
valid_rows_mask=valid_rows_mask,
reduce_group_size=1,
enforce_fallback=enforce_gather_reduce_fallback,
)
res = (revert_indices, end, sorted_tokens.shape, start)
return out, res
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,8 @@ def permute(
self._expert_parallelism_name,
num_expert_parallelism,
buffer_size=buffer_size,
enforce_gather_fallback=self.config.ragged_gather_fallback,
enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback,
)
else:
flatten_selected_experts = jnp.ravel(selected_experts)
Expand Down Expand Up @@ -928,6 +930,8 @@ def unpermute(
local_num_experts,
self._expert_parallelism_name,
topk_weights=flat_weights,
enforce_gather_fallback=self.config.ragged_gather_fallback,
enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback,
)
else:
unsort_intermediate = _sort_activations(
Expand Down
Loading
Loading