Refactor moe.p: gmm and a2a unsort#4170
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| num_ep = self.get_expert_parallelism_size() | ||
| num_experts_per_shard = self.config.num_experts // num_ep | ||
| use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] | ||
| if use_truncated_buffer: |
There was a problem hiding this comment.
maybe we can merge line 1647 and 1648?
| group_offset=experts_start, | ||
| ) | ||
|
|
||
| def unsort_output_with_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert): |
There was a problem hiding this comment.
This function includes both unsort and ra2a.. Maybe we should name it unsort_output_and_ra2a?
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) |
There was a problem hiding this comment.
emmm why ragged_all_to_all show up twice, one in a function and one outside the function?
There was a problem hiding this comment.
Where is the second call? Do you mean line 1693? This ra2a is still within the unsort_output_with_ra2a function and is only invoked when is_batch_sharded_by_expert is not true
| check_vma=self.config.check_vma, | ||
| ) | ||
| def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs): | ||
| def sparse_matmul_route_and_compute( |
There was a problem hiding this comment.
yayay a descriptive name lets go
| expert_assignments=routing.selected_experts, | ||
| group_offset=experts_start, | ||
| ) | ||
| gmm_fn = get_gmm_for_local_experts(x, routing, route_metadata) |
There was a problem hiding this comment.
This looks great! Ideally this main function (sparse_matmul_route_and_compute) is as small and easy to read as possible - all annoying details hidden in small functions like this one!
gobbleturk
left a comment
There was a problem hiding this comment.
This looks great! Just need another 5 or 6 refactors like this
Description
This is the second PR refactoring sparse_matmul to make chunking activations and future features easier to implement.
Major changes:
get_gmm_for_local_expertsTests
Verified loss and perplexity is identical on main vs refactor branch after 20 train steps: loss: 12.259, perplexity: 210794.859 for both.
commands to reproduce:
Full table of correctness test results:
| Mode | Ring of Experts | EP Size | FSDP Size | Loss (
main) | Loss (refactor) | Perp (main) | Perp (refactor) | Tok/s/Dev(
main) | Tok/s/Dev (refactor) | Command Differences ||---|---|---|---|---|---|---|---|---|---|---|
| Sparse | True | 8 | 1 | 12.259 | 12.259 | 210794.859 | 210794.859 | 134,312 | 135,845 |
sparse_matmul=True use_ring_of_experts=True ici_expert_parallelism=-1|| Sparse | False | 8 | 1 | 12.259 | 12.259 | 210800.438 | 210800.438 | 165,695 | 166,883 |
sparse_matmul=True use_ring_of_experts=False ici_expert_parallelism=-1|| Sparse | False | 1 | 8 | 12.259 | 12.259 | 210867.609 | 210867.609 | 166,802 | 179,712 |
sparse_matmul=True use_ring_of_experts=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1|| Dense | False | 1 | 8 | 12.259 | 12.259 | 210843.078 | 210843.078 | 178,025 | 161,234 |
sparse_matmul=False use_ring_of_experts=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1|| Dense | False | 8 | 1 | 12.259 | 12.259 | 210809.906 | 210809.906 | 177,469 | 164,076 |
sparse_matmul=False use_ring_of_experts=False ici_expert_parallelism=-1 ici_fsdp_parallelism=1|Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.