-
Notifications
You must be signed in to change notification settings - Fork 540
Refactor moe.p: gmm and a2a unsort #4170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1640,6 +1640,84 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): | |
| layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") | ||
| return self.apply_ffn_activation(layer_w0, layer_w1) | ||
|
|
||
| def get_gmm_for_local_experts(x, routing, route_metadata): | ||
| """Return a partial GMM function with preconfigured routing params.""" | ||
| num_ep = self.get_expert_parallelism_size() | ||
| num_experts_per_shard = self.config.num_experts // num_ep | ||
| if self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0]: | ||
| local_group_sizes = routing.local_group_sizes | ||
| return functools.partial( | ||
| gmm, | ||
| group_sizes=local_group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=0, | ||
| ) | ||
| if self.config.use_ragged_sort and self.config.use_ring_of_experts: | ||
| experts_start = route_metadata.expert_shard_id * num_experts_per_shard | ||
| else: | ||
| experts_start = 0 | ||
| return functools.partial( | ||
| gmm, | ||
| group_sizes=routing.group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=experts_start, | ||
| ) | ||
|
|
||
| def unsort_output_and_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert): | ||
| """Unsort tokens and return them to original shards using ragged all-to-all.""" | ||
| if is_batch_sharded_by_expert: | ||
| # locally unpermute back to the original order | ||
| if self.config.use_ragged_sort: | ||
| # Mirror the ragged-prefix gather used in `local_permute`. The | ||
| # un-permute can use the same valid-prefix length because the | ||
| # routed token count is identical for forward and backward. | ||
| valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) | ||
| local_output = a2a_ragged_unsort( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable | ||
| valid_end, | ||
| ) | ||
| else: | ||
| local_output = _sort_activations( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), | ||
| self.config.use_custom_sort_vjp, | ||
| ) | ||
|
|
||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| jnp.transpose(route_metadata.all_shards_group_sizes), | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| ) | ||
| return jax.lax.ragged_all_to_all( | ||
| local_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
|
||
| # If batch is replicated across EP shards then each shard should send | ||
| # 0..local_shard_size data to the other shards and receive the | ||
| # local_shard data from all of the other shards using ragged_all_to_all. | ||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| route_metadata.reshaped_group_sizes, | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| is_batch_sharded=False, | ||
| ) | ||
| return jax.lax.ragged_all_to_all( | ||
| intermediate_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
|
||
| @functools.partial( | ||
| jax.shard_map, | ||
| mesh=self.mesh, | ||
|
|
@@ -1663,36 +1741,16 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): | |
| ), | ||
| check_vma=self.config.check_vma, | ||
| ) | ||
| def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs): | ||
| def sparse_matmul_route_and_compute( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yayay a descriptive name lets go |
||
| x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs | ||
| ): | ||
| batch_size, sequence_length, _ = x.shape | ||
| x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs, input_ids=sharded_input_ids) | ||
|
|
||
| if self.config.mlp_bias: | ||
| w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias) | ||
|
|
||
| num_ep = self.get_expert_parallelism_size() | ||
| num_experts_per_shard = self.config.num_experts // num_ep | ||
|
|
||
| use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] | ||
| if use_truncated_buffer: | ||
| local_group_sizes = routing.local_group_sizes | ||
| gmm_fn = functools.partial( | ||
| gmm, | ||
| group_sizes=local_group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=0, | ||
| ) | ||
| else: | ||
| if self.config.use_ragged_sort and self.config.use_ring_of_experts: | ||
| experts_start = route_metadata.expert_shard_id * num_experts_per_shard | ||
| else: | ||
| experts_start = 0 | ||
| gmm_fn = functools.partial( | ||
| gmm, | ||
| group_sizes=routing.group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=experts_start, | ||
| ) | ||
| gmm_fn = get_gmm_for_local_experts(x, routing, route_metadata) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks great! Ideally this main function (sparse_matmul_route_and_compute) is as small and easy to read as possible - all annoying details hidden in small functions like this one! |
||
| intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather) | ||
|
|
||
| wo_gather_axes, wo_tile_size = get_wo_gmm_params() | ||
|
|
@@ -1727,83 +1785,38 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s | |
| output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size()) | ||
| ) | ||
| output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) | ||
| return output, routing.lb_loss, routing.bias_updates | ||
|
|
||
| if self.get_expert_parallelism_size() > 1: | ||
| original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok | ||
| if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: | ||
| raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") | ||
| output_shape = jax.lax.empty( | ||
| ( | ||
| original_inputs_first_dim, | ||
| self.moe_expert_input_dim // self.get_tensor_parallelism_size(), | ||
| ), | ||
| dtype=intermediate_output.dtype, | ||
| ) | ||
|
|
||
| else: | ||
| if self.get_expert_parallelism_size() > 1: | ||
| original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok | ||
| if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: | ||
| raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") | ||
| output_shape = jax.lax.empty( | ||
| ( | ||
| original_inputs_first_dim, | ||
| self.moe_expert_input_dim // self.get_tensor_parallelism_size(), | ||
| ), | ||
| dtype=intermediate_output.dtype, | ||
| ) | ||
|
|
||
| if is_batch_sharded_by_expert: | ||
| # locally unpermute back to the original order | ||
| if self.config.use_ragged_sort: | ||
| # Mirror the ragged-prefix gather used in `local_permute`. The | ||
| # un-permute can use the same valid-prefix length because the | ||
| # routed token count is identical for forward and backward. | ||
| valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) | ||
| local_output = a2a_ragged_unsort( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable | ||
| valid_end, | ||
| ) | ||
| else: | ||
| local_output = _sort_activations( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), | ||
| self.config.use_custom_sort_vjp, | ||
| ) | ||
|
|
||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| jnp.transpose(route_metadata.all_shards_group_sizes), | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| ) | ||
| intermediate_output = jax.lax.ragged_all_to_all( | ||
| local_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
| else: | ||
| # If batch is replicated across EP shards then each shard should send | ||
| # 0..local_shard_size data to the other shards and receive the | ||
| # local_shard data from all of the other shards using ragged_all_to_all. | ||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| route_metadata.reshaped_group_sizes, | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| is_batch_sharded=False, | ||
| ) | ||
| intermediate_output = jax.lax.ragged_all_to_all( | ||
| intermediate_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
|
||
| output = self.unpermute( | ||
| intermediate_output = unsort_output_and_ra2a( | ||
| intermediate_output, | ||
| routing.sorted_selected_experts, | ||
| routing.weights, | ||
| batch_size=batch_size, | ||
| sequence_length=sequence_length, | ||
| use_custom_sort_vjp=self.config.use_custom_sort_vjp, | ||
| group_sizes=routing.group_sizes, | ||
| routing, | ||
| route_metadata, | ||
| output_shape, | ||
| is_batch_sharded_by_expert, | ||
| ) | ||
|
|
||
| output = self.unpermute( | ||
| intermediate_output, | ||
| routing.sorted_selected_experts, | ||
| routing.weights, | ||
| batch_size=batch_size, | ||
| sequence_length=sequence_length, | ||
| use_custom_sort_vjp=self.config.use_custom_sort_vjp, | ||
| group_sizes=routing.group_sizes, | ||
| ) | ||
|
|
||
| return output, routing.lb_loss, routing.bias_updates | ||
|
|
||
| if self.config.moe_fsdp_use_two_stage_all_gather: | ||
|
|
@@ -1851,7 +1864,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s | |
| if wo_bias is not None: | ||
| wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) | ||
|
|
||
| return wrapper( | ||
| return sparse_matmul_route_and_compute( | ||
| inputs, | ||
| gate_logits, | ||
| pre_bias_logits, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
emmm why ragged_all_to_all show up twice, one in a function and one outside the function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the second call? Do you mean line 1693? This ra2a is still within the
unsort_output_with_ra2afunction and is only invoked when is_batch_sharded_by_expert is not true