From bf6bd347ebac862f8d69e91d9da200edc5e10047 Mon Sep 17 00:00:00 2001
From: Runchu Zhao
Date: Fri, 6 Feb 2026 08:25:28 +0000
Subject: [PATCH 1/3] Add torchrec output dist.

---
 corelib/dynamicemb/dynamicemb/output_dist.py  | 163 ++++++++++++++++++
 .../dynamicemb/planner/rw_sharding.py         |  29 ++++
 2 files changed, 192 insertions(+)
 create mode 100644 corelib/dynamicemb/dynamicemb/output_dist.py

diff --git a/corelib/dynamicemb/dynamicemb/output_dist.py b/corelib/dynamicemb/dynamicemb/output_dist.py
new file mode 100644
index 000000000..a5a0c3a0d
--- /dev/null
+++ b/corelib/dynamicemb/dynamicemb/output_dist.py
@@ -0,0 +1,163 @@
+
+from typing import cast, Dict, List, Optional, Union
+
+import torch
+from torch import distributed as dist
+
+from torchrec.distributed.dist_data import (
+    PooledEmbeddingsReduceScatter,
+    SequenceEmbeddingsAllToAll,
+    VariableBatchPooledEmbeddingsReduceScatter,
+)
+from torchrec.distributed.embedding_sharding import (
+    BaseEmbeddingDist,
+    EmbeddingShardingContext,
+)
+from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
+from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+class RwSequenceEmbeddingDist(
+    BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
+):
+    """
+    Redistributes sequence embedding tensor in RW fashion with an AlltoAll operation.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        num_features (int): total number of features.
+        device (Optional[torch.device]): device on which buffers will be allocated.
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        num_features: int,
+        device: Optional[torch.device] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__()
+        self._dist = SequenceEmbeddingsAllToAll(
+            pg,
+            [num_features] * pg.size(),
+            device,
+            codecs=(
+                qcomm_codecs_registry.get(
+                    CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None
+                )
+                if qcomm_codecs_registry
+                else None
+            ),
+        )
+
+    def forward(
+        self,
+        local_embs: torch.Tensor,
+        sharding_ctx: Optional[SequenceShardingContext] = None,
+    ) -> Awaitable[torch.Tensor]:
+        """
+        Performs AlltoAll operation on sequence embeddings tensor.
+
+        Args:
+            local_embs (torch.Tensor): tensor of values to distribute.
+            sharding_ctx (SequenceShardingContext): shared context from KJTAllToAll
+                operation.
+
+        Returns:
+            Awaitable[torch.Tensor]: awaitable of sequence embeddings.
+        """
+        print("forward RwSequenceEmbeddingDist--------------------------------")
+        assert sharding_ctx is not None
+        return self._dist(
+            local_embs,
+            lengths=sharding_ctx.lengths_after_input_dist,
+            input_splits=sharding_ctx.input_splits,
+            output_splits=sharding_ctx.output_splits,
+            batch_size_per_rank=sharding_ctx.batch_size_per_rank,
+            sparse_features_recat=sharding_ctx.sparse_features_recat,
+            unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
+        )
+
+
+
+class RwPooledEmbeddingDist(
+    BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]
+):
+    """
+    Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter
+    operation.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for reduce-scatter communication.
+ """ + + def __init__( + self, + pg: dist.ProcessGroup, + embedding_dims: List[int], + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__() + + self._dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._pg = pg + self._qcomm_codecs_registry = qcomm_codecs_registry + self._codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None + ) + if qcomm_codecs_registry + else None + ) + self._embedding_dims = embedding_dims + + def forward( + self, + local_embs: torch.Tensor, + sharding_ctx: Optional[EmbeddingShardingContext] = None, + ) -> Awaitable[torch.Tensor]: + """ + Performs reduce-scatter pooled operation on pooled embeddings tensor. + + Args: + local_embs (torch.Tensor): pooled embeddings tensor to distribute. + sharding_ctx (Optional[EmbeddingShardingContext]): shared context from + KJTAllToAll operation. + + Returns: + Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. + """ + print("forward RwPooledEmbeddingDist--------------------------------") + if self._dist is None: + self._create_output_dist_module(sharding_ctx) + + if sharding_ctx is None: + return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs) + elif sharding_ctx.variable_batch_per_feature: + return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)( + local_embs, + batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature, + embedding_dims=self._embedding_dims, + ) + else: + return cast(PooledEmbeddingsReduceScatter, self._dist)( + local_embs, + input_splits=sharding_ctx.batch_size_per_rank, + ) + + def _create_output_dist_module( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature: + self._dist = VariableBatchPooledEmbeddingsReduceScatter( + pg=self._pg, + codecs=self._codecs, + ) + else: + self._dist = PooledEmbeddingsReduceScatter( + pg=self._pg, + codecs=self._codecs, + ) \ No newline at end of file diff --git a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py index 69b6f361b..b0a7f0b6c 100644 --- a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py +++ b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py @@ -50,7 +50,10 @@ BatchedDynamicEmbeddingBag, ) from ..input_dist import RwSparseFeaturesDist +from torchrec.distributed.embedding_sharding import BaseEmbeddingDist, EmbeddingShardingContext +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from ..output_dist import RwPooledEmbeddingDist, RwSequenceEmbeddingDist class GroupedEmbeddingsLookup(_GroupedEmbeddingsLookup): def _create_embedding_kernel( @@ -158,6 +161,19 @@ def create_lookup( ) + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]: + return RwSequenceEmbeddingDist( + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+            self._pg,
+            self._get_num_features(),
+            device if device is not None else self._device,
+            qcomm_codecs_registry=self.qcomm_codecs_registry,
+        )
 
 class GroupedPooledEmbeddingsLookup(_GroupedPooledEmbeddingsLookup):
     def _create_embedding_kernel(
         self,
@@ -259,3 +275,16 @@ def create_lookup(
             feature_processor=feature_processor,
             sharding_type=ShardingType.ROW_WISE,
         )
+
+    def create_output_dist(
+        self,
+        device: Optional[torch.device] = None,
+    ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]:
+        return RwPooledEmbeddingDist(
+            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+            # `Optional[ProcessGroup]`.
+            self._pg,
+            qcomm_codecs_registry=self.qcomm_codecs_registry,
+            embedding_dims=self.embedding_dims(),
+        )
+    # TODO: confirm what is qcomm_codecs_registry
\ No newline at end of file

From 3256224ee07b8580462cecec36308d4470d31bd8 Mon Sep 17 00:00:00 2001
From: Runchu Zhao
Date: Fri, 6 Feb 2026 08:26:30 +0000
Subject: [PATCH 2/3] Apply pre-commit formatting.

---
 corelib/dynamicemb/dynamicemb/output_dist.py         | 11 +++++------
 corelib/dynamicemb/dynamicemb/planner/rw_sharding.py | 11 ++++++-----
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/output_dist.py b/corelib/dynamicemb/dynamicemb/output_dist.py
index a5a0c3a0d..3c697fe07 100644
--- a/corelib/dynamicemb/dynamicemb/output_dist.py
+++ b/corelib/dynamicemb/dynamicemb/output_dist.py
@@ -1,9 +1,7 @@
-
-from typing import cast, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, cast
 
 import torch
 from torch import distributed as dist
-
 from torchrec.distributed.dist_data import (
     PooledEmbeddingsReduceScatter,
     SequenceEmbeddingsAllToAll,
@@ -15,6 +13,8 @@
 )
 from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+
+
 class RwSequenceEmbeddingDist(
     BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
 ):
@@ -64,7 +64,7 @@ def forward(
         Returns:
             Awaitable[torch.Tensor]: awaitable of sequence embeddings.
""" - print("forward RwSequenceEmbeddingDist--------------------------------") + print("forward RwSequenceEmbeddingDist--------------------------------") assert sharding_ctx is not None return self._dist( local_embs, @@ -77,7 +77,6 @@ def forward( ) - class RwPooledEmbeddingDist( BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor] ): @@ -160,4 +159,4 @@ def _create_output_dist_module( self._dist = PooledEmbeddingsReduceScatter( pg=self._pg, codecs=self._codecs, - ) \ No newline at end of file + ) diff --git a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py index b0a7f0b6c..cbd0b5dac 100644 --- a/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py +++ b/corelib/dynamicemb/dynamicemb/planner/rw_sharding.py @@ -29,7 +29,9 @@ GroupedPooledEmbeddingsLookup as _GroupedPooledEmbeddingsLookup, ) from torchrec.distributed.embedding_sharding import ( + BaseEmbeddingDist, BaseSparseFeaturesDist, + EmbeddingShardingContext, EmbeddingShardingInfo, ) from torchrec.distributed.embedding_types import ( @@ -42,6 +44,7 @@ RwSequenceEmbeddingSharding, ) from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv, ShardingType from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -50,11 +53,9 @@ BatchedDynamicEmbeddingBag, ) from ..input_dist import RwSparseFeaturesDist -from torchrec.distributed.embedding_sharding import BaseEmbeddingDist, EmbeddingShardingContext -from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext - from ..output_dist import RwPooledEmbeddingDist, RwSequenceEmbeddingDist + class GroupedEmbeddingsLookup(_GroupedEmbeddingsLookup): def _create_embedding_kernel( self, @@ -160,7 +161,6 @@ def create_lookup( device=device if device is not None else self._device, ) - def create_output_dist( self, device: Optional[torch.device] = None, @@ -174,6 +174,7 @@ def create_output_dist( qcomm_codecs_registry=self.qcomm_codecs_registry, ) + class GroupedPooledEmbeddingsLookup(_GroupedPooledEmbeddingsLookup): def _create_embedding_kernel( self, @@ -287,4 +288,4 @@ def create_output_dist( qcomm_codecs_registry=self.qcomm_codecs_registry, embedding_dims=self.embedding_dims(), ) - # TODO: confirm what is qcomm_codecs_registry \ No newline at end of file + # TODO: confirm what is qcomm_codecs_registry From 55c7bdc34dc72fe8b8516f24584bd8ff4e601a8a Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Fri, 6 Feb 2026 10:21:53 +0000 Subject: [PATCH 3/3] Remove some comments. --- corelib/dynamicemb/dynamicemb/output_dist.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/corelib/dynamicemb/dynamicemb/output_dist.py b/corelib/dynamicemb/dynamicemb/output_dist.py index 3c697fe07..3fd226210 100644 --- a/corelib/dynamicemb/dynamicemb/output_dist.py +++ b/corelib/dynamicemb/dynamicemb/output_dist.py @@ -64,7 +64,6 @@ def forward( Returns: Awaitable[torch.Tensor]: awaitable of sequence embeddings. """ - print("forward RwSequenceEmbeddingDist--------------------------------") assert sharding_ctx is not None return self._dist( local_embs, @@ -129,7 +128,6 @@ def forward( Returns: Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. """ - print("forward RwPooledEmbeddingDist--------------------------------") if self._dist is None: self._create_output_dist_module(sharding_ctx)