[PR1] Port output distribution classes to DynamicEmb #319
Closed
+274 −0
Changes from 1 commit
Commits (8)
df835b9 [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
7f04661 [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
0906186 [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
1aa415d [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
5b1e9ba [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
18196f2 [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
44da9db [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
c5c82d5 [PR1] Port output distribution classes to DynamicEmb (ShaobinChen-AH)
@@ -0,0 +1,215 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-strict

"""
Output distribution classes for DynamicEmb row-wise sharding.

This module provides optimized output distribution implementations for:
- RwSequenceEmbeddingDist: for sequence (unpooled) embeddings
- RwPooledEmbeddingDist: for pooled embeddings

The key optimization is in the unbucketize_permute operation, which is slow
in the original TorchRec implementation, especially for non-contiguous
distribution patterns (e.g., round-robin).
"""

from typing import Awaitable, Dict, List, Optional, Union, cast

import torch
from torch import distributed as dist
from torchrec.distributed.types import CommOp
from torchrec.distributed.dist_data import (
    PooledEmbeddingsReduceScatter,
    SequenceEmbeddingsAllToAll,
    VariableBatchPooledEmbeddingsReduceScatter,
)

from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.embedding_sharding import BaseEmbeddingDist, EmbeddingShardingContext
from torchrec.distributed.types import QuantizedCommCodecs


class RwSequenceEmbeddingDist(
    BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
):
    """
    Redistributes sequence embedding tensor in RW fashion with an AlltoAll operation.

    This is a customized version for DynamicEmb that can be optimized for
    non-contiguous distribution patterns (e.g., round-robin).

    Args:
        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
        num_features (int): total number of features.
        device (Optional[torch.device]): device on which buffers will be allocated.
        qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]):
            quantized communication codecs registry.
    """

    def __init__(
        self,
        pg: dist.ProcessGroup,
        num_features: int,
        device: Optional[torch.device] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()
        self._pg = pg
        self._num_features = num_features
        self._device = device

        self._dist = SequenceEmbeddingsAllToAll(
            pg,
            [num_features] * pg.size(),
            device,
            codecs=(
                qcomm_codecs_registry.get(
                    CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None
                )
                if qcomm_codecs_registry
                else None
            ),
        )

        # Store unbucketize_permute_tensor for potential optimization
        self._unbucketize_permute_tensor: Optional[torch.Tensor] = None

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[SequenceShardingContext] = None,
    ) -> torch.Tensor:
        """
        Performs AlltoAll operation on sequence embeddings tensor.

        Args:
            local_embs (torch.Tensor): tensor of values to distribute.
            sharding_ctx (SequenceShardingContext): shared context from KJTAllToAll
                operation.

        Returns:
            torch.Tensor: sequence embeddings after distribution.
        """
        assert sharding_ctx is not None

        # Store unbucketize_permute_tensor for potential optimization
        self._unbucketize_permute_tensor = sharding_ctx.unbucketize_permute_tensor

        # TODO: Optimize unbucketize_permute operation here
        # The unbucketize_permute_tensor is used in SequenceEmbeddingsAwaitable
        # to reorder the output. For non-contiguous distribution (round-robin),
        # this operation is slow and can be optimized with custom CUDA kernels.
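        # Roughly, that reorder amounts to a row-wise gather, e.g.
        # output = output.index_select(0, unbucketize_permute_tensor), which becomes
        # memory-bound when the permutation indices are scattered (round-robin).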

        result = self._dist(
            local_embs,
            lengths=sharding_ctx.lengths_after_input_dist,
            input_splits=sharding_ctx.input_splits,
            output_splits=sharding_ctx.output_splits,
            batch_size_per_rank=sharding_ctx.batch_size_per_rank,
            sparse_features_recat=sharding_ctx.sparse_features_recat,
            unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
        )

        return result


class RwPooledEmbeddingDist(
    BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]
):
    """
    Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter
    operation.

    Args:
        pg (dist.ProcessGroup): ProcessGroup for reduce-scatter communication.
        embedding_dims (List[int]): embedding dimensions per feature.
        qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]):
            quantized communication codecs registry.
    """

    def __init__(
        self,
        pg: dist.ProcessGroup,
        embedding_dims: List[int],
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()
        self._pg = pg
        self._embedding_dims = embedding_dims
        self._qcomm_codecs_registry = qcomm_codecs_registry

        self._dist: Optional[
            Union[
                PooledEmbeddingsReduceScatter,
                VariableBatchPooledEmbeddingsReduceScatter,
            ]
        ] = None

        self._codecs: Optional[QuantizedCommCodecs] = (
            qcomm_codecs_registry.get(
                CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None
            )
            if qcomm_codecs_registry
            else None
        )

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[EmbeddingShardingContext] = None,
    ) -> torch.Tensor:
        """
        Performs reduce-scatter pooled operation on pooled embeddings tensor.

        Args:
            local_embs (torch.Tensor): pooled embeddings tensor to distribute.
            sharding_ctx (Optional[EmbeddingShardingContext]): shared context from
                KJTAllToAll operation.

        Returns:
            torch.Tensor: pooled embeddings tensor after distribution.
        """
        if self._dist is None:
            self._create_output_dist_module(sharding_ctx)

        if sharding_ctx is None:
            return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
        elif sharding_ctx.variable_batch_per_feature:
            return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
                local_embs,
                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
                embedding_dims=self._embedding_dims,
            )
        else:
            return cast(PooledEmbeddingsReduceScatter, self._dist)(
                local_embs,
                input_splits=sharding_ctx.batch_size_per_rank,
            )

    def _create_output_dist_module(
        self, sharding_ctx: Optional[EmbeddingShardingContext] = None
    ) -> None:
        """Create the appropriate output distribution module based on context."""
        if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
            self._dist = VariableBatchPooledEmbeddingsReduceScatter(
                pg=self._pg,
                codecs=self._codecs,
            )
        else:
            self._dist = PooledEmbeddingsReduceScatter(
                pg=self._pg,
                codecs=self._codecs,
            )
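For reviewers unfamiliar with the unbucketize step the TODO above targets, the following is a minimal, self-contained sketch (not part of this diff; the toy tensors and the round-robin permutation are hypothetical) of the reorder applied to sequence embeddings after the row-wise AlltoAll:

import torch

# Toy example: 6 embedding rows of dim 4, returned grouped by bucket
# (i.e., by the rank that owned them) after the row-wise AlltoAll.
embs_in_bucket_order = torch.arange(24, dtype=torch.float32).reshape(6, 4)

# Hypothetical permutation produced by bucketization under a round-robin
# distribution; it maps bucket order back to the caller's original row order.
unbucketize_permute_tensor = torch.tensor([0, 3, 1, 4, 2, 5])

# The reorder is essentially a gather along dim 0; this memory-bound gather is
# the operation the module docstring flags as slow for non-contiguous patterns.
embs_in_original_order = embs_in_bucket_order.index_select(0, unbucketize_permute_tensor)
print(embs_in_original_order.shape)  # torch.Size([6, 4])

Exercising RwSequenceEmbeddingDist or RwPooledEmbeddingDist end to end additionally requires an initialized dist.ProcessGroup and a sharding context produced by the input dist, so that setup is left out of the sketch.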