Port output distribution classes to DynamicEmb #297
@@ -0,0 +1,160 @@
from typing import Dict, List, Optional, Union, cast

import torch
from torch import distributed as dist
from torchrec.distributed.dist_data import (
    PooledEmbeddingsReduceScatter,
    SequenceEmbeddingsAllToAll,
    VariableBatchPooledEmbeddingsReduceScatter,
)
from torchrec.distributed.embedding_sharding import (
    BaseEmbeddingDist,
    EmbeddingShardingContext,
)
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs


class RwSequenceEmbeddingDist(
    BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
):
    """
    Redistributes sequence embedding tensor in RW fashion with an AlltoAll operation.

    Args:
        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
        num_features (int): total number of features.
        device (Optional[torch.device]): device on which buffers will be allocated.
    """

    def __init__(
        self,
        pg: dist.ProcessGroup,
        num_features: int,
        device: Optional[torch.device] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()
        self._dist = SequenceEmbeddingsAllToAll(
            pg,
            [num_features] * pg.size(),
            device,
            codecs=(
                qcomm_codecs_registry.get(
                    CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None
                )
                if qcomm_codecs_registry
                else None
            ),
        )

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[SequenceShardingContext] = None,
    ) -> Awaitable[torch.Tensor]:
        """
        Performs AlltoAll operation on sequence embeddings tensor.

        Args:
            local_embs (torch.Tensor): tensor of values to distribute.
            sharding_ctx (SequenceShardingContext): shared context from KJTAllToAll
                operation.

        Returns:
            Awaitable[torch.Tensor]: awaitable of sequence embeddings.
        """
        assert sharding_ctx is not None
        return self._dist(
            local_embs,
            lengths=sharding_ctx.lengths_after_input_dist,
            input_splits=sharding_ctx.input_splits,
            output_splits=sharding_ctx.output_splits,
            batch_size_per_rank=sharding_ctx.batch_size_per_rank,
            sparse_features_recat=sharding_ctx.sparse_features_recat,
            unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
        )


class RwPooledEmbeddingDist(
    BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]
):
| """ | ||
| Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter | ||
| operation. | ||
|
|
||
| Args: | ||
| pg (dist.ProcessGroup): ProcessGroup for reduce-scatter communication. | ||
| """ | ||
|
|
||
    def __init__(
        self,
        pg: dist.ProcessGroup,
        embedding_dims: List[int],
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()

        self._dist: Optional[
            Union[
                PooledEmbeddingsReduceScatter,
                VariableBatchPooledEmbeddingsReduceScatter,
            ]
        ] = None
        self._pg = pg
        self._qcomm_codecs_registry = qcomm_codecs_registry
        self._codecs: Optional[QuantizedCommCodecs] = (
            qcomm_codecs_registry.get(
                CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None
            )
            if qcomm_codecs_registry
            else None
        )
        self._embedding_dims = embedding_dims

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[EmbeddingShardingContext] = None,
    ) -> Awaitable[torch.Tensor]:
        """
        Performs reduce-scatter pooled operation on pooled embeddings tensor.

        Args:
            local_embs (torch.Tensor): pooled embeddings tensor to distribute.
            sharding_ctx (Optional[EmbeddingShardingContext]): shared context from
                KJTAllToAll operation.

        Returns:
            Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor.
        """
        if self._dist is None:
            self._create_output_dist_module(sharding_ctx)

        if sharding_ctx is None:
            return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
        elif sharding_ctx.variable_batch_per_feature:
            return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
                local_embs,
                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
                embedding_dims=self._embedding_dims,
            )
        else:
            return cast(PooledEmbeddingsReduceScatter, self._dist)(
                local_embs,
                input_splits=sharding_ctx.batch_size_per_rank,
            )

    def _create_output_dist_module(
        self, sharding_ctx: Optional[EmbeddingShardingContext] = None
    ) -> None:
        if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
            self._dist = VariableBatchPooledEmbeddingsReduceScatter(
                pg=self._pg,
                codecs=self._codecs,
            )
        else:
            self._dist = PooledEmbeddingsReduceScatter(
                pg=self._pg,
                codecs=self._codecs,
            )
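For orientation, here is a minimal usage sketch of the `RwPooledEmbeddingDist` class added above. It is illustrative only: the process group choice (`dist.group.WORLD`), the embedding dims, and the batch size are assumptions, and the snippet presumes `torch.distributed` has already been initialized (e.g. via `torchrun`), with one rank per GPU.

# Hypothetical sketch: exercising RwPooledEmbeddingDist on the default process group.
# Shapes and dims are illustrative; real values come from the sharding plan.
import torch
import torch.distributed as dist

pg = dist.group.WORLD                 # assumes dist.init_process_group() already ran
embedding_dims = [8, 16]              # two features with dims 8 and 16 (illustrative)
output_dist = RwPooledEmbeddingDist(pg=pg, embedding_dims=embedding_dims)

local_batch = 4
# Each rank holds pooled embeddings of its row-wise shard for the global batch;
# the reduce-scatter sums shards across ranks and returns this rank's batch slice.
local_embs = torch.randn(
    local_batch * pg.size(), sum(embedding_dims), device=torch.device("cuda")
)
pooled = output_dist(local_embs, sharding_ctx=None).wait()
# pooled has shape (local_batch, sum(embedding_dims)) on every rank.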
@@ -29,7 +29,9 @@
    GroupedPooledEmbeddingsLookup as _GroupedPooledEmbeddingsLookup,
)
from torchrec.distributed.embedding_sharding import (
    BaseEmbeddingDist,
    BaseSparseFeaturesDist,
    EmbeddingShardingContext,
    EmbeddingShardingInfo,
)
from torchrec.distributed.embedding_types import (
@@ -42,6 +44,7 @@
    RwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv, ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -50,6 +53,7 @@
    BatchedDynamicEmbeddingBag,
)
from ..input_dist import RwSparseFeaturesDist
from ..output_dist import RwPooledEmbeddingDist, RwSequenceEmbeddingDist


class GroupedEmbeddingsLookup(_GroupedEmbeddingsLookup):
@@ -157,6 +161,19 @@ def create_lookup(
            device=device if device is not None else self._device,
        )

    def create_output_dist(
        self,
        device: Optional[torch.device] = None,
    ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
        return RwSequenceEmbeddingDist(
            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
            # `Optional[ProcessGroup]`.
            self._pg,
            self._get_num_features(),
            device if device is not None else self._device,
            qcomm_codecs_registry=self.qcomm_codecs_registry,
        )
Comment on lines +164 to +175 (Contributor):

Passes Optional ProcessGroup.

Also appears in:
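One possible way to address this comment (a sketch, not necessarily what the PR should do): fail fast when the process group has not been populated, instead of suppressing the `Optional[ProcessGroup]` mismatch with a pyre-fixme. The method below assumes the surrounding class from the diff above; only the assert is new.

    def create_output_dist(
        self,
        device: Optional[torch.device] = None,
    ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
        # Guard instead of pyre-fixme: _pg must be set by the sharding env before
        # any output dist is created.
        assert self._pg is not None, "process group is required for RW output dist"
        return RwSequenceEmbeddingDist(
            self._pg,
            self._get_num_features(),
            device if device is not None else self._device,
            qcomm_codecs_registry=self.qcomm_codecs_registry,
        )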
class GroupedPooledEmbeddingsLookup(_GroupedPooledEmbeddingsLookup):
    def _create_embedding_kernel(
@@ -259,3 +276,16 @@ def create_lookup(
            feature_processor=feature_processor,
            sharding_type=ShardingType.ROW_WISE,
        )

    def create_output_dist(
        self,
        device: Optional[torch.device] = None,
    ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]:
        return RwPooledEmbeddingDist(
            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
            # `Optional[ProcessGroup]`.
            self._pg,
            qcomm_codecs_registry=self.qcomm_codecs_registry,
            embedding_dims=self.embedding_dims(),
        )
        # TODO: confirm what is qcomm_codecs_registry
Missing context for init

When `sharding_ctx` is None, `forward()` returns `PooledEmbeddingsReduceScatter(local_embs)` (lines 134-135), but `_dist` may have been created as a variable-batch module based on the first call's `sharding_ctx` (lines 131-133 / 151-155). If the first invocation had `variable_batch_per_feature=True` and a later call passes `sharding_ctx=None`, this will invoke a `VariableBatchPooledEmbeddingsReduceScatter` without its required arguments, causing a runtime error. Either require that `sharding_ctx` always be provided, or make the `_dist` selection independent of the first call.
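A sketch of the second option (making the selection independent of the first call), assuming two lazily created modules cached under the hypothetical attribute names `_rs_dist` and `_vb_dist`, both initialized to `None` in `__init__`:

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[EmbeddingShardingContext] = None,
    ) -> Awaitable[torch.Tensor]:
        # Choose the reduce-scatter variant per call, so a first call with
        # variable_batch_per_feature=True cannot break later fixed-batch calls.
        if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
            if self._vb_dist is None:
                self._vb_dist = VariableBatchPooledEmbeddingsReduceScatter(
                    pg=self._pg, codecs=self._codecs
                )
            return self._vb_dist(
                local_embs,
                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
                embedding_dims=self._embedding_dims,
            )
        if self._rs_dist is None:
            self._rs_dist = PooledEmbeddingsReduceScatter(
                pg=self._pg, codecs=self._codecs
            )
        return self._rs_dist(
            local_embs,
            input_splits=sharding_ctx.batch_size_per_rank if sharding_ctx else None,
        )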