diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..244730ba8d 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -222,6 +222,10 @@ use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the Mo # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run # inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute. use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops +ragged_gather_fallback: false # when true, unconditionally use the JAX reference implementation instead of the + # ragged gather SparseCore kernel. When false (default), use the SparseCore kernel. +ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the + # ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel. # tunable tiling dimensions used for mlp gmm # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`) # tokamax ragged dot - supports all 18 configs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..7f96e3784e 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -743,6 +743,16 @@ class MoEGeneral(BaseModel): False, description="Whether to use a custom mosaic kernel for token gather ops.", ) + ragged_gather_fallback: bool = Field( + False, + description="When true, unconditionally use the JAX reference implementation instead of the ragged gather " + "SparseCore kernel. When false (default), use the SparseCore kernel.", + ) + ragged_gather_reduce_fallback: bool = Field( + False, + description="When true, unconditionally use the JAX reference implementation instead of the ragged gather " + "reduce SparseCore kernel. When false (default), use the SparseCore kernel.", + ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") moe_fsdp_use_two_stage_all_gather: bool = Field( diff --git a/src/maxtext/kernels/ragged/ragged_gather.py b/src/maxtext/kernels/ragged/ragged_gather.py index 5cb18dba83..174507e2d0 100644 --- a/src/maxtext/kernels/ragged/ragged_gather.py +++ b/src/maxtext/kernels/ragged/ragged_gather.py @@ -16,7 +16,6 @@ # Source from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/sparse_core/ragged_gather.py import functools - import jax import jax.numpy as jnp from jax.experimental import pallas as pl @@ -24,15 +23,23 @@ from jax.experimental.pallas import tpu_sc as plsc from packaging.version import Version - # JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later # versions renamed them to `out_type`/`scratch_types`. if Version(jax.__version__) <= Version("0.10.0"): _OUT_KW = "out_shape" _SCRATCH_KW = "scratch_shapes" + _COMPILER_PARAMS = { + "use_tc_tiling_on_sc": True, + "disable_bounds_checks": True, + } else: _OUT_KW = "out_type" _SCRATCH_KW = "scratch_types" + _COMPILER_PARAMS = { + "use_tc_tiling_on_sc": True, + "disable_bounds_checks": True, + "needs_layout_passes": False, + } def main_kernel( @@ -264,6 +271,19 @@ def dma_write_loop(col_vmem_start): inner_kernel() +def _fallback_implementation( + x: jax.Array, + indices: jax.Array, + weights: jax.Array | None = None, + has_weights: bool = False, +) -> jax.Array: + """Fallback to (non-ragged) JAX implementation for ragged gather.""" + out = x[indices] + if has_weights: + out = out * weights[:, None] + return out + + def calculate_col_size(hidden_size: int) -> int: """Calculate col size for ragged gather kernel.""" tpu_info = pltpu.get_tpu_info() @@ -288,7 +308,7 @@ def calculate_col_size(hidden_size: int) -> int: return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes -@functools.partial(jax.jit, static_argnames=("has_weights",)) +@functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback")) def ragged_gather( x: jax.Array, indices: jax.Array, @@ -296,6 +316,7 @@ def ragged_gather( end: jax.Array, weights: jax.Array | None = None, has_weights: bool = False, + enforce_fallback: bool = False, ) -> jax.Array: """Perform gather on indices within dynamic array start and end. @@ -309,6 +330,9 @@ def ragged_gather( kernel, avoiding an extra HBM read-write pass. has_weights: Static bool flag indicating whether ``weights`` should be applied. Must be ``True`` when ``weights`` is not ``None``. + enforce_fallback: Static bool flag. When ``True``, unconditionally use the + JAX reference implementation instead of the SparseCore kernel. + When ``False`` (default), use the SparseCore kernel and raise any error. Returns: Gathered output of shape ``(indices_size, hidden_size)``. @@ -331,12 +355,9 @@ def ragged_gather( dtype = x.dtype sc_info = pltpu.get_tpu_info().sparse_core - if sc_info is None: - # Sparse core is not available. Fallback to regular gather. - out = x[indices] - if has_weights: - out = out * weights[:, None] - return out + if sc_info is None or enforce_fallback: + # Sparse core is not available or fallback is enforced. Use JAX reference. + return _fallback_implementation(x, indices, weights, has_weights) hidden_size = x.shape[-1] out_size = indices.size @@ -371,9 +392,8 @@ def ragged_gather( subcore_axis_name=vector_mesh.subcore_axis_name, has_weights=has_weights, ), - compiler_params=pltpu.CompilerParams( - use_tc_tiling_on_sc=True, - disable_bounds_checks=True, + compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args + **_COMPILER_PARAMS, ), mesh=vector_mesh, name="sc_ragged_gather", diff --git a/src/maxtext/kernels/ragged/ragged_gather_reduce.py b/src/maxtext/kernels/ragged/ragged_gather_reduce.py index 28f59996a8..8c0dca9cbb 100644 --- a/src/maxtext/kernels/ragged/ragged_gather_reduce.py +++ b/src/maxtext/kernels/ragged/ragged_gather_reduce.py @@ -24,6 +24,7 @@ import jax.numpy as jnp from packaging.version import Version + # JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later # versions renamed them to `out_type`/`scratch_types`. if Version(jax.__version__) <= Version("0.10.0"): @@ -360,26 +361,23 @@ def _preprocess( num_src_rows_per_row_partition.astype(jnp.int32), (0, num_simd_lanes - num_row_partitions), ) - # If there is no valid source row in a reduce group, we set the mask to - # False, so that the output for that group is set to zero. - mask = jnp.any(valid_rows_mask.reshape(-1, reduce_group_size), axis=-1) return ( src_indices, dst_indices, topk_weights, num_src_rows_per_row_partition, - mask, ) -@functools.partial(jax.jit, static_argnames=("reduce_group_size",)) +@functools.partial(jax.jit, static_argnames=("reduce_group_size", "enforce_fallback")) def ragged_gather_reduce( x: jax.Array, indices: jax.Array, topk_weights: jax.Array, valid_rows_mask: jax.Array, reduce_group_size: int, + enforce_fallback: bool = False, ) -> jax.Array: """Gathers `x` according to `indices`, applies weights and masks, and reduces. @@ -402,6 +400,9 @@ def ragged_gather_reduce( valid, with shape `(input_size,)`. reduce_group_size: An integer representing the number of consecutive rows to reduce (sum) together. + enforce_fallback: Static bool flag. When ``True``, unconditionally use the + JAX reference implementation instead of the SparseCore kernel. + When ``False`` (default), use the SparseCore kernel and raise any error. Returns: A 2D JAX array of reduced data with shape @@ -414,7 +415,8 @@ def ragged_gather_reduce( assert valid_rows_mask.ndim == 1, "ragged_gather_reduce only supports 1d valid_rows_mask." sc_info = pltpu.get_tpu_info().sparse_core - if sc_info is None: + if sc_info is None or enforce_fallback: + # Sparse core is not available or fallback is enforced. Use JAX reference. return _fallback_implementation(x, indices, topk_weights, valid_rows_mask, reduce_group_size) # Heuristic threshold on whether to fallback for small inputs. @@ -431,19 +433,20 @@ def ragged_gather_reduce( num_simd_lanes = sc_info.num_lanes num_cores = sc_info.num_cores * sc_info.num_subcores - # This kernel partitions the output's columns into `num_column_partitions` and - # partition the output's rows into `num_row_partitions` and run each + # This kernel partitions the output's columns into `num_column_partitions` + # and partition the output's rows into `num_row_partitions` and run each # {row_partition} x {column_partition} combination on a separate SC subcore - # for parallelism. With such work partitioning, we guarantee that there won't - # be write collision (from different subcores) to the any output row X column. + # for parallelism. With such work partitioning, we guarantee that there + # won't be write collision (from different subcores) to any output row X + # column. # # Each column partition should be multiple of 128 (number of lanes) due to # DMA requirements. Unless requiring padding on the column dimension, larger # column partitions (thus smaller row partitions given fixed num_cores) is # more preferable because large row partition may lead to imbalanced load # (valid_rows_mask may have more rows in some partitions than others). - # Most LLM's hidden size is multiple of 1024, `num_column_partitions=8` should - # work well in practice without requiring padding on the column size. + # Most LLM's hidden size is multiple of 1024, `num_column_partitions=8` + # should work well in practice without requiring padding on the column size. num_column_partitions = 8 assert num_cores % num_column_partitions == 0 num_rows_partitions = num_cores // num_column_partitions @@ -471,7 +474,6 @@ def ragged_gather_reduce( dst_indices, topk_weights, num_src_rows_per_row_partition, - mask, ) = _preprocess( indices, topk_weights, @@ -487,8 +489,8 @@ def ragged_gather_reduce( core_axis_name="core", subcore_axis_name="subcore", ) - # Each output row from `main_kernel` will be of type float32, and then casted - # to the input dtype when doing the filter operation. + # Each output row from `main_kernel` will be of type float32, and then + # casted to the input dtype when doing the filter operation. out = pl.kernel( # pytype: disable=wrong-keyword-args functools.partial( main_kernel, @@ -519,10 +521,4 @@ def ragged_gather_reduce( }, )(num_src_rows_per_row_partition, x, src_indices, dst_indices, topk_weights) - # If there is no valid source row in a reduce group, set that group's output - # to zero. - return jnp.where( - mask[:, None], - out.astype(x.dtype), - jnp.zeros_like(out, dtype=x.dtype), - )[: (input_size // reduce_group_size), :hidden_size] + return out.astype(x.dtype) diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index 6f1ccc64f6..67fc12da94 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -28,6 +28,8 @@ def ring_ragged_sort( ep_name, ep_size, buffer_size=None, + enforce_gather_fallback=False, + enforce_gather_reduce_fallback=False, ): """Ragged-gather variant for AG-RS Expert Parallelism token routing. @@ -102,6 +104,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): token_indices_sorted, shard_output_start[None], shard_output_end[None], + enforce_fallback=enforce_gather_fallback, ) else: local_buffer_size = buffer_size @@ -122,6 +125,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): sliced_indices, jnp.int32(0)[None], gather_end[None], + enforce_fallback=enforce_gather_fallback, ) out = (x, group_sizes_local, topk_argsort_revert_indices) @@ -173,6 +177,7 @@ def _ring_ragged_sort_bwd(res, g_out): topk_weights=jnp.ones((n,), dtype=jnp.float32), valid_rows_mask=valid_rows_mask, reduce_group_size=topk, + enforce_fallback=enforce_gather_reduce_fallback, ) else: # Buffering: g_x has size `local_buffer_size` (packed). @@ -195,6 +200,7 @@ def _ring_ragged_sort_bwd(res, g_out): topk_weights=jnp.ones((n,), dtype=jnp.float32), valid_rows_mask=valid_rows_mask, reduce_group_size=topk, + enforce_fallback=enforce_gather_reduce_fallback, ) return grad_hidden_states, None @@ -211,6 +217,8 @@ def ring_ragged_unsort( local_num_experts, ep_name, topk_weights, + enforce_gather_fallback=False, + enforce_gather_reduce_fallback=False, ): """Dual of :func:`ring_ragged_sort`. @@ -282,6 +290,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort topk_weights=topk_weights_flat, valid_rows_mask=valid_rows_mask, reduce_group_size=topk, + enforce_fallback=enforce_gather_reduce_fallback, ) else: # Shift indices so they map to the packed local buffer [0, local_num_tokens). @@ -297,6 +306,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort topk_weights=topk_weights_flat, valid_rows_mask=valid_rows_mask, reduce_group_size=topk, + enforce_fallback=enforce_gather_reduce_fallback, ) res = ( @@ -352,6 +362,7 @@ def _ring_ragged_unsort_bwd(res, g_out): shard_output_end[None], weights=weight_for_sorted, has_weights=True, + enforce_fallback=enforce_gather_fallback, ) else: # Slice the inverse permutation to match the packed local buffer. @@ -368,6 +379,7 @@ def _ring_ragged_unsort_bwd(res, g_out): gather_end[None], weights=sliced_weights, has_weights=True, + enforce_fallback=enforce_gather_fallback, ) return grad_sorted_tokens, None, None, None @@ -379,7 +391,7 @@ def _ring_ragged_unsort_bwd(res, g_out): return _ring_ragged_unsort(sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat) -def a2a_ragged_sort(inputs, sort_indices, valid_end): +def a2a_ragged_sort(inputs, sort_indices, valid_end, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False): """Ragged-gather variant for ``local_permute``. Unlike :func:`ring_ragged_sort`, the rows valid for this shard live in @@ -442,6 +454,7 @@ def _a2a_ragged_sort_bwd(res, g_out): topk_weights=jnp.ones((n,), dtype=jnp.float32), valid_rows_mask=valid_rows_mask[idx_inv], reduce_group_size=1, + enforce_fallback=enforce_gather_reduce_fallback, ) # custom_vjp must return one gradient per primal arg; valid_end is integer # and non-differentiable, so we return None for it. @@ -451,7 +464,9 @@ def _a2a_ragged_sort_bwd(res, g_out): return _a2a_ragged_sort(inputs, sort_indices, valid_end) -def a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end): +def a2a_ragged_unsort( + sorted_tokens, revert_indices, valid_end, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False +): """Dual of :func:`a2a_ragged_sort`. Forward: @@ -492,6 +507,7 @@ def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end): topk_weights=jnp.ones((n,), dtype=jnp.float32), valid_rows_mask=valid_rows_mask, reduce_group_size=1, + enforce_fallback=enforce_gather_reduce_fallback, ) res = (revert_indices, end, sorted_tokens.shape, start) return out, res diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..1905a55679 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -849,6 +849,8 @@ def permute( self._expert_parallelism_name, num_expert_parallelism, buffer_size=buffer_size, + enforce_gather_fallback=self.config.ragged_gather_fallback, + enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback, ) else: flatten_selected_experts = jnp.ravel(selected_experts) @@ -928,6 +930,8 @@ def unpermute( local_num_experts, self._expert_parallelism_name, topk_weights=flat_weights, + enforce_gather_fallback=self.config.ragged_gather_fallback, + enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback, ) else: unsort_intermediate = _sort_activations( diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 18c675948d..0a37bc36b6 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -600,7 +600,13 @@ def test_ring_of_expert_and_tensor_parallelism(self): actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) - def _run_ragged_sort_loss_and_grad(self, use_ring_of_experts: bool, ragged_buffer_factor: float = -1.0): + def _run_ragged_sort_loss_and_grad( + self, + use_ring_of_experts: bool, + ragged_buffer_factor: float = -1.0, + ragged_gather_fallback: bool = False, + ragged_gather_reduce_fallback: bool = False, + ): """Loss and gradient correctness for the use_ragged_sort flag. Compares an EP run with use_ragged_sort=True against the same @@ -630,6 +636,8 @@ def _build_cfg(use_ragged_sort: bool): max_target_length=128, use_ragged_sort=use_ragged_sort, ragged_buffer_factor=effective_buffer_factor, + ragged_gather_fallback=ragged_gather_fallback, + ragged_gather_reduce_fallback=ragged_gather_reduce_fallback, ) def _build_model(cfg, mesh): @@ -723,6 +731,13 @@ def test_ragged_sort_loss_and_grad_ring_of_experts(self): def test_ragged_sort_loss_and_grad_ring_of_experts_ragged_buffer(self): self._run_ragged_sort_loss_and_grad(use_ring_of_experts=True, ragged_buffer_factor=1.5) + @pytest.mark.tpu_only + @pytest.mark.skip_on_tpu7x + def test_ragged_sort_loss_and_grad_ring_of_experts_fallback(self): + self._run_ragged_sort_loss_and_grad( + use_ring_of_experts=True, ragged_gather_fallback=True, ragged_gather_reduce_fallback=True + ) + @pytest.mark.tpu_only @pytest.mark.skip_on_tpu7x def test_ragged_sort_loss_and_grad_no_ring_of_experts(self): @@ -733,6 +748,13 @@ def test_ragged_sort_loss_and_grad_no_ring_of_experts(self): def test_ragged_sort_loss_and_grad_no_ring_of_experts_ragged_buffer(self): self._run_ragged_sort_loss_and_grad(use_ring_of_experts=False, ragged_buffer_factor=1.5) + @pytest.mark.tpu_only + @pytest.mark.skip_on_tpu7x + def test_ragged_sort_loss_and_grad_no_ring_of_experts_fallback(self): + self._run_ragged_sort_loss_and_grad( + use_ring_of_experts=False, ragged_gather_fallback=True, ragged_gather_reduce_fallback=True + ) + @pytest.mark.tpu_only def test_moe_fsdp_two_stage_parallelism_tpu_only(self): # Use an imperative skip inside the test method instead of a static decorator.