Skip to content
18 changes: 16 additions & 2 deletions physicsnemo/domain_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,26 @@
# If minimum versions are met, we can import the shard tensor and spec.

from ._shard_tensor_spec import ShardTensorSpec
from .shard_tensor import ShardTensor, scatter_tensor
from .shard_tensor import (
FSDPOutputTensorAdapter,
ShardTensor,
distribute_over_domain_for_fsdp,
scatter_tensor,
wrap_for_fsdp,
)

def register_custom_ops():
"""Register all custom ShardTensor ops and shard-aware wrappers.

Imports are deferred to this function to avoid an import cycle between
``shard_tensor`` and the individual op modules.
"""
# These imports will register the custom ops with the ShardTensor class.
# It's done here to avoid an import cycle.
from .custom_ops import (
mean_wrapper,
sum_wrapper,
unbind_rules,
unbind_wrapper,
)
from .shard_utils import register_shard_wrappers

Expand All @@ -69,3 +80,6 @@ def register_custom_ops():
ShardTensor = None
ShardTensorSpec = None
scatter_tensor = None
distribute_over_domain_for_fsdp = None
FSDPOutputTensorAdapter = None
wrap_for_fsdp = None
2 changes: 1 addition & 1 deletion physicsnemo/domain_parallel/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@

if ST_AVAILABLE:
from ._reductions import mean_wrapper, sum_wrapper
from ._tensor_ops import unbind_rules
from ._tensor_ops import unbind_wrapper
110 changes: 78 additions & 32 deletions physicsnemo/domain_parallel/custom_ops/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@
)

import torch
from torch.distributed.tensor._dtensor_spec import TensorMeta
from torch.distributed.tensor.placement_types import (
Partial,
Shard,
)

# noqa: E402
from physicsnemo.domain_parallel._shard_tensor_spec import (
ShardTensorSpec,
_stride_from_contiguous_shape_C_style,
)
from physicsnemo.domain_parallel.shard_tensor import ShardTensor

aten = torch.ops.aten
Expand Down Expand Up @@ -248,6 +253,62 @@ def compute_result_sharding_shapes(
return result_sharding_shapes


def build_reduction_result(
    local_result: torch.Tensor,
    input_tensor: ShardTensor,
    placements: list[Partial | Shard],
    sharding_shapes: dict[int, list[torch.Size]],
) -> ShardTensor:
    r"""Wrap a locally computed reduction output into a ``ShardTensor``.

    The ``ShardTensorSpec`` is assembled directly from the precomputed
    placements and per-shard shapes, sidestepping the overhead and autograd
    side-effects of going through ``ShardTensor.from_local``.

    Parameters
    ----------
    local_result : torch.Tensor
        The locally-computed reduction result.
    input_tensor : ShardTensor
        The original input ShardTensor (used for its device mesh and
        ``requires_grad`` flag).
    placements : list[Partial | Shard]
        Result placements from :func:`compute_result_placements`.
    sharding_shapes : dict[int, list[torch.Size]]
        Result sharding shapes from :func:`compute_result_sharding_shapes`,
        keyed by mesh dimension.

    Returns
    -------
    ShardTensor
        Wrapped result carrying the correct sharding metadata.
    """
    # Recover the global shape: for every mesh dim that shards a tensor dim,
    # the global extent along that tensor dim is the sum over all shards.
    full_shape = list(local_result.shape)
    for mesh_dim, placement in enumerate(placements):
        if not isinstance(placement, Shard):
            continue
        axis = placement.dim
        full_shape[axis] = sum(
            shard_shape[axis] for shard_shape in sharding_shapes[mesh_dim]
        )

    spec = ShardTensorSpec(
        mesh=input_tensor.device_mesh,
        placements=tuple(placements),
        tensor_meta=TensorMeta(
            shape=tuple(full_shape),
            # Assumes a C-contiguous layout for the global tensor.
            stride=_stride_from_contiguous_shape_C_style(full_shape),
            dtype=local_result.dtype,
        ),
        _local_shape=local_result.shape,
        _sharding_shapes={d: tuple(shapes) for d, shapes in sharding_shapes.items()},
    )
    # Construct via __new__ so the precomputed spec is attached as-is,
    # rather than being re-derived (and autograd-tracked) by from_local.
    return ShardTensor.__new__(
        ShardTensor,
        local_tensor=local_result,
        spec=spec,
        requires_grad=input_tensor.requires_grad,
    )


def create_sharded_grad_input(
local_grad_input: torch.Tensor, original_spec: Any
) -> ShardTensor:
Expand All @@ -265,11 +326,15 @@ def create_sharded_grad_input(
ShardTensor
A distributed tensor with the same sharding as the original input.
"""
return ShardTensor.from_local(
local_grad_input,
device_mesh=original_spec.mesh,
placements=original_spec.placements,
sharding_shapes=original_spec.sharding_shapes(),
# In custom autograd backward, return the input gradient directly as a
# ShardTensor value. Avoid ``from_local`` here (which routes through a
# separate autograd Function) so the gradient is attached unambiguously to
# the original ShardTensor input.
return ShardTensor.__new__(
ShardTensor,
local_tensor=local_grad_input,
spec=original_spec,
requires_grad=False,
)


Expand Down Expand Up @@ -361,24 +426,14 @@ def forward(
"""
dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim)

# Get local tensor
local_tensor = tensor._local_tensor
# Perform local sum
local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype)
local_result = aten.sum(
tensor._local_tensor, dim=dim, keepdim=keepdim, dtype=dtype
)

# Compute placements for the result
placements = compute_result_placements(tensor, dim, "sum")
output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

# Create result ShardTensor
result = ShardTensor.from_local(
local_result,
tensor.device_mesh,
placements,
sharding_shapes=output_sharding_shapes,
)
sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

return result
return build_reduction_result(local_result, tensor, placements, sharding_shapes)

@staticmethod
def backward(
Expand Down Expand Up @@ -495,23 +550,14 @@ def forward(
for d in reduction_dims:
weight *= local_shape[d] / global_shape[d]

# Perform local mean
# Perform local mean and apply weighting for uneven shards
local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype)
# Apply weighting
local_result = local_result * weight

placements = compute_result_placements(tensor, dim, "sum")
output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

# Create result ShardTensor
result = ShardTensor.from_local(
local_result,
tensor.device_mesh,
placements,
sharding_shapes=output_sharding_shapes,
)
sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

return result
return build_reduction_result(local_result, tensor, placements, sharding_shapes)

@staticmethod
def backward(
Expand Down
Loading
Loading