160 changes: 160 additions & 0 deletions corelib/dynamicemb/dynamicemb/output_dist.py
@@ -0,0 +1,160 @@
from typing import Dict, List, Optional, Union, cast

import torch
from torch import distributed as dist
from torchrec.distributed.dist_data import (
PooledEmbeddingsReduceScatter,
SequenceEmbeddingsAllToAll,
VariableBatchPooledEmbeddingsReduceScatter,
)
from torchrec.distributed.embedding_sharding import (
BaseEmbeddingDist,
EmbeddingShardingContext,
)
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs


class RwSequenceEmbeddingDist(
BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
):
"""
Redistributes sequence embedding tensor in RW fashion with an AlltoAll operation.

Args:
pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
num_features (int): total number of features.
device (Optional[torch.device]): device on which buffers will be allocated.
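        qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]): registry of
            QuantizedCommCodecs keyed by CommOp name; the codec registered under
            SEQUENCE_EMBEDDINGS_ALL_TO_ALL is used if present.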
"""

def __init__(
self,
pg: dist.ProcessGroup,
num_features: int,
device: Optional[torch.device] = None,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
super().__init__()
self._dist = SequenceEmbeddingsAllToAll(
pg,
[num_features] * pg.size(),
device,
codecs=(
qcomm_codecs_registry.get(
CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None
)
if qcomm_codecs_registry
else None
),
)

def forward(
self,
local_embs: torch.Tensor,
sharding_ctx: Optional[SequenceShardingContext] = None,
) -> Awaitable[torch.Tensor]:
"""
Performs AlltoAll operation on sequence embeddings tensor.

Args:
local_embs (torch.Tensor): tensor of values to distribute.
sharding_ctx (SequenceShardingContext): shared context from KJTAllToAll
operation.

Returns:
Awaitable[torch.Tensor]: awaitable of sequence embeddings.
"""
assert sharding_ctx is not None
return self._dist(
local_embs,
lengths=sharding_ctx.lengths_after_input_dist,
input_splits=sharding_ctx.input_splits,
output_splits=sharding_ctx.output_splits,
batch_size_per_rank=sharding_ctx.batch_size_per_rank,
sparse_features_recat=sharding_ctx.sparse_features_recat,
unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
)


class RwPooledEmbeddingDist(
BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]
):
"""
Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter
operation.

Args:
pg (dist.ProcessGroup): ProcessGroup for reduce-scatter communication.
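        embedding_dims (List[int]): per-feature embedding dimensions, forwarded to the
            variable-batch reduce-scatter.
        qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]): registry of
            QuantizedCommCodecs keyed by CommOp name; the codec registered under
            POOLED_EMBEDDINGS_REDUCE_SCATTER is used if present.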
"""

def __init__(
self,
pg: dist.ProcessGroup,
embedding_dims: List[int],
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
super().__init__()

self._dist: Optional[
Union[
PooledEmbeddingsReduceScatter,
VariableBatchPooledEmbeddingsReduceScatter,
]
] = None
self._pg = pg
self._qcomm_codecs_registry = qcomm_codecs_registry
self._codecs: Optional[QuantizedCommCodecs] = (
qcomm_codecs_registry.get(
CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None
)
if qcomm_codecs_registry
else None
)
self._embedding_dims = embedding_dims

def forward(
self,
local_embs: torch.Tensor,
sharding_ctx: Optional[EmbeddingShardingContext] = None,
) -> Awaitable[torch.Tensor]:
"""
Performs reduce-scatter pooled operation on pooled embeddings tensor.

Args:
local_embs (torch.Tensor): pooled embeddings tensor to distribute.
sharding_ctx (Optional[EmbeddingShardingContext]): shared context from
KJTAllToAll operation.

Returns:
Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor.
"""
if self._dist is None:
self._create_output_dist_module(sharding_ctx)

if sharding_ctx is None:
return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
elif sharding_ctx.variable_batch_per_feature:
return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
local_embs,
batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
embedding_dims=self._embedding_dims,
)
else:
return cast(PooledEmbeddingsReduceScatter, self._dist)(
local_embs,
input_splits=sharding_ctx.batch_size_per_rank,
)
Comment on lines +130 to +146
Contributor

Missing context for init

When sharding_ctx is None, forward() returns PooledEmbeddingsReduceScatter(local_embs) (lines 134-135), but _dist may have been created as a variable-batch module based on the first call's sharding_ctx (lines 131-133 / 151-155). If the first invocation had variable_batch_per_feature=True and a later call passes sharding_ctx=None, this will invoke a VariableBatchPooledEmbeddingsReduceScatter without its required arguments, causing a runtime error. Either require that sharding_ctx always be provided, or make the _dist selection independent of the first call.
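
A minimal sketch of the first option (require the context), for illustration only — it assumes variable_batch_per_feature stays consistent across calls for a given sharding and is not the PR's code:

```python
def forward(
    self,
    local_embs: torch.Tensor,
    sharding_ctx: Optional[EmbeddingShardingContext] = None,
) -> Awaitable[torch.Tensor]:
    # Require the context up front, mirroring RwSequenceEmbeddingDist.forward.
    assert sharding_ctx is not None, "RwPooledEmbeddingDist requires sharding_ctx"
    if self._dist is None:
        self._create_output_dist_module(sharding_ctx)
    if sharding_ctx.variable_batch_per_feature:
        return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
            local_embs,
            batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
            embedding_dims=self._embedding_dims,
        )
    return cast(PooledEmbeddingsReduceScatter, self._dist)(
        local_embs,
        input_splits=sharding_ctx.batch_size_per_rank,
    )
```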


def _create_output_dist_module(
self, sharding_ctx: Optional[EmbeddingShardingContext] = None
) -> None:
if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
self._dist = VariableBatchPooledEmbeddingsReduceScatter(
pg=self._pg,
codecs=self._codecs,
)
else:
self._dist = PooledEmbeddingsReduceScatter(
pg=self._pg,
codecs=self._codecs,
)
30 changes: 30 additions & 0 deletions corelib/dynamicemb/dynamicemb/planner/rw_sharding.py
@@ -29,7 +29,9 @@
GroupedPooledEmbeddingsLookup as _GroupedPooledEmbeddingsLookup,
)
from torchrec.distributed.embedding_sharding import (
BaseEmbeddingDist,
BaseSparseFeaturesDist,
EmbeddingShardingContext,
EmbeddingShardingInfo,
)
from torchrec.distributed.embedding_types import (
@@ -42,6 +44,7 @@
RwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv, ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -50,6 +53,7 @@
BatchedDynamicEmbeddingBag,
)
from ..input_dist import RwSparseFeaturesDist
from ..output_dist import RwPooledEmbeddingDist, RwSequenceEmbeddingDist


class GroupedEmbeddingsLookup(_GroupedEmbeddingsLookup):
@@ -157,6 +161,19 @@ def create_lookup(
device=device if device is not None else self._device,
)

def create_output_dist(
self,
device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
return RwSequenceEmbeddingDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
self._pg,
self._get_num_features(),
device if device is not None else self._device,
qcomm_codecs_registry=self.qcomm_codecs_registry,
)
Comment on lines +164 to +175
Contributor

Passes Optional ProcessGroup

create_output_dist() passes self._pg through to RwSequenceEmbeddingDist, but self._pg is typed as Optional[ProcessGroup] (and you've added a pyre-fixme to silence it). If self._pg is actually None at runtime (e.g., in a non-distributed / single-rank setup), RwSequenceEmbeddingDist.__init__ will call pg.size() and crash. This needs either a real guard or a guarantee that _pg is non-None before the dist module is constructed (the same pattern appears in the pooled sharding below).

Also appears in: corelib/dynamicemb/dynamicemb/planner/rw_sharding.py:264-274.
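
A minimal sketch of such a guard (illustrative only; raising on a missing process group is an assumption about the desired behavior, not the PR's code):

```python
def create_output_dist(
    self,
    device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
    # Fail loudly instead of letting pg.size() crash inside RwSequenceEmbeddingDist.
    if self._pg is None:
        raise RuntimeError(
            "create_output_dist requires a ProcessGroup; self._pg is None "
            "(non-distributed / single-rank setup?)"
        )
    return RwSequenceEmbeddingDist(
        self._pg,
        self._get_num_features(),
        device if device is not None else self._device,
        qcomm_codecs_registry=self.qcomm_codecs_registry,
    )
```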



class GroupedPooledEmbeddingsLookup(_GroupedPooledEmbeddingsLookup):
def _create_embedding_kernel(
@@ -259,3 +276,16 @@ def create_lookup(
feature_processor=feature_processor,
sharding_type=ShardingType.ROW_WISE,
)

def create_output_dist(
self,
device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]:
return RwPooledEmbeddingDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
self._pg,
qcomm_codecs_registry=self.qcomm_codecs_registry,
embedding_dims=self.embedding_dims(),
)
# TODO: confirm what qcomm_codecs_registry is