diff --git a/corelib/dynamicemb/dynamicemb/output_dist.py b/corelib/dynamicemb/dynamicemb/output_dist.py
new file mode 100644
index 00000000..3fd22621
--- /dev/null
+++ b/corelib/dynamicemb/dynamicemb/output_dist.py
@@ -0,0 +1,160 @@
+from typing import Dict, List, Optional, Union, cast
+
+import torch
+from torch import distributed as dist
+from torchrec.distributed.dist_data import (
+    PooledEmbeddingsReduceScatter,
+    SequenceEmbeddingsAllToAll,
+    VariableBatchPooledEmbeddingsReduceScatter,
+)
+from torchrec.distributed.embedding_sharding import (
+    BaseEmbeddingDist,
+    EmbeddingShardingContext,
+)
+from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
+from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+
+
+class RwSequenceEmbeddingDist(
+    BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
+):
+    """
+    Redistributes sequence embedding tensor in RW fashion with an AlltoAll operation.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        num_features (int): total number of features.
+        device (Optional[torch.device]): device on which buffers will be allocated.
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        num_features: int,
+        device: Optional[torch.device] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__()
+        self._dist = SequenceEmbeddingsAllToAll(
+            pg,
+            [num_features] * pg.size(),
+            device,
+            codecs=(
+                qcomm_codecs_registry.get(
+                    CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None
+                )
+                if qcomm_codecs_registry
+                else None
+            ),
+        )
+
+    def forward(
+        self,
+        local_embs: torch.Tensor,
+        sharding_ctx: Optional[SequenceShardingContext] = None,
+    ) -> Awaitable[torch.Tensor]:
+        """
+        Performs AlltoAll operation on sequence embeddings tensor.
+
+        Args:
+            local_embs (torch.Tensor): tensor of values to distribute.
+            sharding_ctx (SequenceShardingContext): shared context from KJTAllToAll
+                operation.
+
+        Returns:
+            Awaitable[torch.Tensor]: awaitable of sequence embeddings.
+        """
+        assert sharding_ctx is not None
+        return self._dist(
+            local_embs,
+            lengths=sharding_ctx.lengths_after_input_dist,
+            input_splits=sharding_ctx.input_splits,
+            output_splits=sharding_ctx.output_splits,
+            batch_size_per_rank=sharding_ctx.batch_size_per_rank,
+            sparse_features_recat=sharding_ctx.sparse_features_recat,
+            unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
+        )
+
+
+class RwPooledEmbeddingDist(
+    BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]
+):
+    """
+    Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter
+    operation.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for reduce-scatter communication.
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        embedding_dims: List[int],
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__()
+
+        self._dist: Optional[
+            Union[
+                PooledEmbeddingsReduceScatter,
+                VariableBatchPooledEmbeddingsReduceScatter,
+            ]
+        ] = None
+        self._pg = pg
+        self._qcomm_codecs_registry = qcomm_codecs_registry
+        self._codecs: Optional[QuantizedCommCodecs] = (
+            qcomm_codecs_registry.get(
+                CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None
+            )
+            if qcomm_codecs_registry
+            else None
+        )
+        self._embedding_dims = embedding_dims
+
+    def forward(
+        self,
+        local_embs: torch.Tensor,
+        sharding_ctx: Optional[EmbeddingShardingContext] = None,
+    ) -> Awaitable[torch.Tensor]:
+        """
+        Performs reduce-scatter pooled operation on pooled embeddings tensor.
+
+        Args:
+            local_embs (torch.Tensor): pooled embeddings tensor to distribute.
+            sharding_ctx (Optional[EmbeddingShardingContext]): shared context from
+                KJTAllToAll operation.
+
+        Returns:
+            Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor.
+        """
+        if self._dist is None:
+            self._create_output_dist_module(sharding_ctx)
+
+        if sharding_ctx is None:
+            return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
+        elif sharding_ctx.variable_batch_per_feature:
+            return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
+                local_embs,
+                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
+                embedding_dims=self._embedding_dims,
+            )
+        else:
+            return cast(PooledEmbeddingsReduceScatter, self._dist)(
+                local_embs,
+                input_splits=sharding_ctx.batch_size_per_rank,
+            )
+
+    def _create_output_dist_module(
+        self, sharding_ctx: Optional[EmbeddingShardingContext] = None
+    ) -> None:
+        if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
+            self._dist = VariableBatchPooledEmbeddingsReduceScatter(
+                pg=self._pg,
+                codecs=self._codecs,
+            )
+        else:
+            self._dist = PooledEmbeddingsReduceScatter(
+                pg=self._pg,
+                codecs=self._codecs,
+            )
diff --git a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py
index 69b6f361..cbd0b5da 100644
--- a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py
+++ b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py
@@ -29,7 +29,9 @@
     GroupedPooledEmbeddingsLookup as _GroupedPooledEmbeddingsLookup,
 )
 from torchrec.distributed.embedding_sharding import (
+    BaseEmbeddingDist,
     BaseSparseFeaturesDist,
+    EmbeddingShardingContext,
     EmbeddingShardingInfo,
 )
 from torchrec.distributed.embedding_types import (
@@ -42,6 +44,7 @@
     RwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
+from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
 from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv, ShardingType
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
@@ -50,6 +53,7 @@
     BatchedDynamicEmbeddingBag,
 )
 from ..input_dist import RwSparseFeaturesDist
+from ..output_dist import RwPooledEmbeddingDist, RwSequenceEmbeddingDist
 
 
 class GroupedEmbeddingsLookup(_GroupedEmbeddingsLookup):
@@ -157,6 +161,19 @@ def create_lookup(
             device=device if device is not None else self._device,
         )
 
+    def create_output_dist(
+        self,
+        device: Optional[torch.device] = None,
+    ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
+        return RwSequenceEmbeddingDist(
+            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            self._pg,
+            self._get_num_features(),
+            device if device is not None else self._device,
+            qcomm_codecs_registry=self.qcomm_codecs_registry,
+        )
+
 
 class GroupedPooledEmbeddingsLookup(_GroupedPooledEmbeddingsLookup):
     def _create_embedding_kernel(
@@ -259,3 +276,16 @@ def create_lookup(
             feature_processor=feature_processor,
             sharding_type=ShardingType.ROW_WISE,
         )
+
+    def create_output_dist(
+        self,
+        device: Optional[torch.device] = None,
+    ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]:
+        return RwPooledEmbeddingDist(
+            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            self._pg,
+            qcomm_codecs_registry=self.qcomm_codecs_registry,
+            embedding_dims=self.embedding_dims(),
+        )
+    # TODO: confirm what qcomm_codecs_registry is
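For reviewers, a minimal usage sketch of the new `RwPooledEmbeddingDist` (illustrative only, not part of the patch). It assumes the package imports as `dynamicemb`, that `torch.distributed` has already been initialized, and that `EmbeddingShardingContext` accepts `batch_size_per_rank` as shown; `local_pooled_embs` and `ctx` are placeholder names.

```python
# Illustrative sketch, not part of the patch. Assumes dist.init_process_group()
# has already run (e.g. via torchrun) and the package installs as `dynamicemb`.
import torch
import torch.distributed as dist
from torchrec.distributed.embedding_sharding import EmbeddingShardingContext

from dynamicemb.output_dist import RwPooledEmbeddingDist

pg = dist.new_group()  # all ranks; any ProcessGroup works here
world_size = dist.get_world_size()

# One module per pooled RW sharding; the first forward call lazily selects
# PooledEmbeddingsReduceScatter or its variable-batch variant from the context.
output_dist = RwPooledEmbeddingDist(pg, embedding_dims=[64, 32])

# Per-rank pooled lookup output: [global batch, sum of embedding dims].
ctx = EmbeddingShardingContext(batch_size_per_rank=[2] * world_size)
local_pooled_embs = torch.randn(2 * world_size, 64 + 32, device="cuda")

# Returns a torchrec Awaitable; wait() yields this rank's reduce-scattered slice.
pooled = output_dist(local_pooled_embs, sharding_ctx=ctx).wait()
```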