diff --git a/physicsnemo/domain_parallel/__init__.py b/physicsnemo/domain_parallel/__init__.py index c3c16d14fa..270ce3f6df 100644 --- a/physicsnemo/domain_parallel/__init__.py +++ b/physicsnemo/domain_parallel/__init__.py @@ -47,15 +47,26 @@ # In minumum versions are met, we can import the shard tensor and spec. from ._shard_tensor_spec import ShardTensorSpec - from .shard_tensor import ShardTensor, scatter_tensor + from .shard_tensor import ( + FSDPOutputTensorAdapter, + ShardTensor, + distribute_over_domain_for_fsdp, + scatter_tensor, + wrap_for_fsdp, + ) def register_custom_ops(): + """Register all custom ShardTensor ops and shard-aware wrappers. + + Imports are deferred to this function to avoid an import cycle between + ``shard_tensor`` and the individual op modules. + """ # These imports will register the custom ops with the ShardTensor class. # It's done here to avoid an import cycle. from .custom_ops import ( mean_wrapper, sum_wrapper, - unbind_rules, + unbind_wrapper, ) from .shard_utils import register_shard_wrappers @@ -69,3 +80,6 @@ def register_custom_ops(): ShardTensor = None ShardTensorSpec = None scatter_tensor = None + distribute_over_domain_for_fsdp = None + FSDPOutputTensorAdapter = None + wrap_for_fsdp = None diff --git a/physicsnemo/domain_parallel/custom_ops/__init__.py b/physicsnemo/domain_parallel/custom_ops/__init__.py index 2dc160d6de..bfd7e1c641 100644 --- a/physicsnemo/domain_parallel/custom_ops/__init__.py +++ b/physicsnemo/domain_parallel/custom_ops/__init__.py @@ -21,4 +21,4 @@ if ST_AVAILABLE: from ._reductions import mean_wrapper, sum_wrapper - from ._tensor_ops import unbind_rules + from ._tensor_ops import unbind_wrapper diff --git a/physicsnemo/domain_parallel/custom_ops/_reductions.py b/physicsnemo/domain_parallel/custom_ops/_reductions.py index b673e6bec4..7c509c0eb3 100644 --- a/physicsnemo/domain_parallel/custom_ops/_reductions.py +++ b/physicsnemo/domain_parallel/custom_ops/_reductions.py @@ -44,12 +44,17 @@ ) import torch 
+from torch.distributed.tensor._dtensor_spec import TensorMeta from torch.distributed.tensor.placement_types import ( Partial, Shard, ) # noqa: E402 +from physicsnemo.domain_parallel._shard_tensor_spec import ( + ShardTensorSpec, + _stride_from_contiguous_shape_C_style, +) from physicsnemo.domain_parallel.shard_tensor import ShardTensor aten = torch.ops.aten @@ -248,6 +253,62 @@ def compute_result_sharding_shapes( return result_sharding_shapes +def build_reduction_result( + local_result: torch.Tensor, + input_tensor: ShardTensor, + placements: list[Partial | Shard], + sharding_shapes: dict[int, list[torch.Size]], +) -> ShardTensor: + r"""Construct a ShardTensor result from a local reduction output. + + Builds the ``ShardTensorSpec`` directly from the already-computed placements + and sharding shapes, avoiding the overhead and autograd side-effects of + ``ShardTensor.from_local``. + + Parameters + ---------- + local_result : torch.Tensor + The locally-computed reduction result. + input_tensor : ShardTensor + The original input ShardTensor (used for device mesh). + placements : List[Union[Partial, Shard]] + Result placements from :func:`compute_result_placements`. + sharding_shapes : Dict[int, List[torch.Size]] + Result sharding shapes from :func:`compute_result_sharding_shapes`. + + Returns + ------- + ShardTensor + Wrapped result with correct sharding metadata. 
+ """ + global_shape = list(local_result.shape) + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + tensor_dim = placement.dim + global_shape[tensor_dim] = sum( + s[tensor_dim] for s in sharding_shapes[mesh_dim] + ) + + stride = _stride_from_contiguous_shape_C_style(global_shape) + spec = ShardTensorSpec( + mesh=input_tensor.device_mesh, + placements=tuple(placements), + tensor_meta=TensorMeta( + shape=tuple(global_shape), + stride=stride, + dtype=local_result.dtype, + ), + _local_shape=local_result.shape, + _sharding_shapes={dim: tuple(s) for dim, s in sharding_shapes.items()}, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_result, + spec=spec, + requires_grad=input_tensor.requires_grad, + ) + + def create_sharded_grad_input( local_grad_input: torch.Tensor, original_spec: Any ) -> ShardTensor: @@ -265,11 +326,15 @@ def create_sharded_grad_input( ShardTensor A distributed tensor with the same sharding as the original input. """ - return ShardTensor.from_local( - local_grad_input, - device_mesh=original_spec.mesh, - placements=original_spec.placements, - sharding_shapes=original_spec.sharding_shapes(), + # In custom autograd backward, return the input gradient directly as a + # ShardTensor value. Avoid ``from_local`` here (which routes through a + # separate autograd Function) so the gradient is attached unambiguously to + # the original ShardTensor input. 
+ return ShardTensor.__new__( + ShardTensor, + local_tensor=local_grad_input, + spec=original_spec, + requires_grad=False, ) @@ -361,24 +426,14 @@ def forward( """ dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) - # Get local tensor - local_tensor = tensor._local_tensor - # Perform local sum - local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) + local_result = aten.sum( + tensor._local_tensor, dim=dim, keepdim=keepdim, dtype=dtype + ) - # Compute placements for the result placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - return result + return build_reduction_result(local_result, tensor, placements, sharding_shapes) @staticmethod def backward( @@ -495,23 +550,14 @@ def forward( for d in reduction_dims: weight *= local_shape[d] / global_shape[d] - # Perform local mean + # Perform local mean and apply weighting for uneven shards local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) - # Apply weighting local_result = local_result * weight placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - return result + return build_reduction_result(local_result, tensor, placements, sharding_shapes) @staticmethod def backward( diff --git a/physicsnemo/domain_parallel/shard_tensor.py b/physicsnemo/domain_parallel/shard_tensor.py index 0a31ca7253..8ab22f67e1 100644 
--- a/physicsnemo/domain_parallel/shard_tensor.py +++ b/physicsnemo/domain_parallel/shard_tensor.py @@ -16,14 +16,16 @@ from __future__ import annotations +import threading from collections.abc import Iterable, Mapping +from contextlib import contextmanager from typing import Callable, Sequence, cast -from warnings import warn import torch import torch.distributed as dist +from torch import nn from torch.distributed.device_mesh import DeviceMesh, _mesh_resources -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, distribute_module from torch.distributed.tensor._dtensor_spec import ( TensorMeta, ) @@ -42,66 +44,279 @@ _infer_shard_tensor_spec_from_local_chunks, _stride_from_contiguous_shape_C_style, ) -from physicsnemo.utils.profiling import annotate, profile aten = torch.ops.aten -def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: - r"""Convert a ShardTensor to a plain DTensor for dispatch. +# ====================================================================== - Creates a DTensor with the same internal state as the ShardTensor, - which allows DTensor's dispatch to handle it correctly. +# ============================================================================ +# Layer 1 -- Semi-private conversions (no autograd, no spec inference) +# ============================================================================ - Parameters - ---------- - st : ShardTensor - The ShardTensor to convert. - Returns - ------- - DTensor - A DTensor sharing the same ``_local_tensor`` and ``_spec``. +def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: + r"""Convert a ShardTensor to a plain DTensor (no autograd). + + Creates a DTensor sharing the same ``_local_tensor`` and ``_spec``. + Use for dispatch or inside backward when building a DTensor gradient. 
""" - dtensor = torch.Tensor._make_wrapper_subclass( - DTensor, - st._spec.tensor_meta.shape, - strides=st._spec.tensor_meta.stride, - dtype=st.dtype, - device=st.device, - layout=st.layout, - requires_grad=st.requires_grad, - ) + if hasattr(torch.Tensor, "_dtensor__new__"): + dtensor = torch.Tensor._dtensor__new__( + DTensor, st._local_tensor, st._spec, requires_grad=st.requires_grad + ) + else: + dtensor = torch.Tensor._make_wrapper_subclass( + DTensor, + st._spec.tensor_meta.shape, + strides=st._spec.tensor_meta.stride, + dtype=st.dtype, + device=st.device, + layout=st.layout, + requires_grad=st.requires_grad, + ) dtensor._local_tensor = st._local_tensor dtensor._spec = st._spec return dtensor -def _convert_args_to_dtensor(arg: object) -> object: - r"""Recursively convert ShardTensors in args to DTensors. +def _dtensor_to_shard_tensor(dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + r"""Promote a DTensor to a ShardTensor (no autograd). - Parameters - ---------- - arg : object - A single argument that may be a ShardTensor, an iterable of - arguments (e.g. list, tuple), a mapping (e.g. dict) whose - values are converted, or any other value. + Callers must supply a resolved ``spec``. Use inside backward (with spec + from ctx) or after resolving a spec via :func:`_resolve_spec_for_dtensor`. + """ + if isinstance(dtensor, ShardTensor): + # Shortcut if we're already a ShardTensor: + return dtensor + st = ShardTensor.__new__( + ShardTensor, + local_tensor=dtensor._local_tensor, + spec=spec, + requires_grad=dtensor.requires_grad, + ) + return st - Returns - ------- - object - The argument with any ShardTensors replaced by DTensors. 
+ +# ============================================================================ +# Layer 2 -- Autograd Functions (use Layer 1 inside fwd / bwd) +# ============================================================================ + + +class _DTensorToShardTensor(torch.autograd.Function): + r"""Differentiable promotion: DTensor -> ShardTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_dtensor_to_shard_tensor`. + Backward: :func:`_shard_tensor_to_dtensor`. """ - # ShardTensor is defined later in this module; the isinstance check - # is safe because this function is only called at runtime. - if isinstance(arg, ShardTensor): - return _shard_tensor_to_dtensor(arg) - elif isinstance(arg, Mapping): - return type(arg)({k: _convert_args_to_dtensor(v) for k, v in arg.items()}) - elif isinstance(arg, Iterable) and not isinstance(arg, (str, bytes)): - converted = [_convert_args_to_dtensor(a) for a in arg] - return type(arg)(converted) - return arg + + @staticmethod + def forward(ctx, dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + return _dtensor_to_shard_tensor(dtensor, spec) + + @staticmethod + def backward(ctx, grad_output: "ShardTensor"): + return _shard_tensor_to_dtensor(grad_output), None + + +class _ShardTensorToDTensor(torch.autograd.Function): + r"""Differentiable conversion: ShardTensor -> DTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_shard_tensor_to_dtensor` (caches spec). + Backward: :func:`_dtensor_to_shard_tensor` (reuses cached spec). 
+ """ + + @staticmethod + def forward(ctx, st: "ShardTensor") -> DTensor: + ctx.shard_tensor_spec = st._spec + return _shard_tensor_to_dtensor(st) + + @staticmethod + def backward(ctx, grad_output: DTensor): + return (_dtensor_to_shard_tensor(grad_output, ctx.shard_tensor_spec),) + + +# ============================================================================ +# Layer 3 -- Smart single-tensor converters (auto-diff when grad_fn present) +# ============================================================================ + + +def _resolve_spec_for_dtensor( + dtensor: DTensor, input_args: tuple = () +) -> ShardTensorSpec: + r"""Resolve a ShardTensorSpec for *dtensor*. + + Tries to reuse a spec from a ShardTensor in *input_args* whose + ``tensor_meta`` and ``placements`` match. Falls back to chunk-based + inference (no communication). + """ + for arg in input_args: + if ( + isinstance(arg, ShardTensor) + and dtensor._spec.tensor_meta == arg._spec.tensor_meta + and dtensor._spec.placements == arg._spec.placements + ): + return arg._spec + return _infer_shard_tensor_spec_from_local_chunks( + dtensor._local_tensor, + dtensor._spec.mesh, + dtensor._spec.placements, + sharding_shapes="chunk", + global_shape=dtensor.shape, + ) + + +# This is a thread-safe reentry guard. +# Goal is to prevent recursion into the fallback conversion paths. 
+_conversion_guard = threading.local() + + +def _conversion_active() -> bool: + r"""Return whether ShardTensor<->DTensor conversion is currently active.""" + return getattr(_conversion_guard, "depth", 0) > 0 + + +@contextmanager +def _conversion_scope(): + r"""Re-entrant conversion guard for cast-down/cast-up paths.""" + previous_depth = getattr(_conversion_guard, "depth", 0) + _conversion_guard.depth = previous_depth + 1 + try: + yield + finally: + if previous_depth == 0: + delattr(_conversion_guard, "depth") + else: + _conversion_guard.depth = previous_depth + + +def _dispatch_fallback_via_dtensor( + func: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute an ATen op through DTensor fallback using PURE data conversion. + + Native Autograd wraps this hook, so we must NOT build an internal graph + using .apply(). We just do the math and let PyTorch track the outer graph. + """ + with _conversion_scope(): + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=False) for arg in args + ) + converted_kwargs = { + k: _convert_args_to_dtensor(v, use_autograd=False) + for k, v in (kwargs or {}).items() + } + + dispatch_res = func(*converted_args, **(converted_kwargs or {})) + + with _conversion_scope(): + return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False) + + +def _torch_function_fallback_via_dtensor( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute a __torch_function__ fallback through DTensor safely. + + Because this executes at the Python API level (above Autograd), we MUST + use autograd functions (.apply) to bridge the tracking manually. 
+ """ + with _conversion_scope(): + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=True) for arg in args + ) + converted_kwargs = { + k: _convert_args_to_dtensor(v, use_autograd=True) + for k, v in (kwargs or {}).items() + } + + with torch._C.DisableTorchFunctionSubclass(): + result = func(*converted_args, **converted_kwargs) + + with _conversion_scope(): + return _convert_results_to_shard_tensor(result, args, use_autograd=True) + + +# ============================================================================ +# Layer 4 -- Recurse utilities (walk args / kwargs / results) +# ============================================================================ + + +def _convert_args_to_dtensor(arg: object, use_autograd: bool = False) -> object: + r"""Recursively replace ShardTensors with DTensors. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. + """ + match arg: + case ShardTensor(): + if use_autograd and arg.requires_grad and torch.is_grad_enabled(): + return _ShardTensorToDTensor.apply(arg) + return _shard_tensor_to_dtensor(arg) + case DTensor(): + # DTensor can be iterable; exit early deliberately + return arg + case Mapping(): + return type(arg)( + {k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()} + ) + case tuple(): + return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg) + case list(): + return [_convert_args_to_dtensor(a, use_autograd) for a in arg] + case _: + return arg + + +def _convert_results_to_shard_tensor( + result: object, input_args: tuple, use_autograd: bool = False +) -> object: + r"""Recursively replace DTensors with ShardTensors in an op result. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. + Handles None returns gracefully for inplace ATen operations. 
+ """ + if result is None: + return None + + if isinstance(result, DTensor): + spec = _resolve_spec_for_dtensor(result, input_args) + + # If autograd graph connection is requested AND the DTensor actually + # requires tracking (it has a grad_fn or requires_grad is active) + if ( + use_autograd + and torch.is_grad_enabled() + and (result.grad_fn is not None or result.requires_grad) + ): + return _DTensorToShardTensor.apply(result, spec) + + return _dtensor_to_shard_tensor(result, spec) + + if isinstance(result, Mapping): + return type(result)( + { + k: _convert_results_to_shard_tensor(v, input_args, use_autograd) + for k, v in result.items() + } + ) + + if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): + return type(result)( + _convert_results_to_shard_tensor(d, input_args, use_autograd) + for d in result + ) + + return result class _ToTorchTensor(torch.autograd.Function): @@ -136,13 +351,17 @@ def forward( """ ctx.shard_tensor_spec = input._spec ctx.grad_placements = grad_placements - local_tensor = input._local_tensor + # # JUST LIKE DTENSOR: + # # We need to return a fresh Tensor object there as autograd metadata + # # will be inplaced into it. So we don't want to pollute the Tensor + # # object stored in the _local_tensor of this ShardTensor. + # return local_tensor.view_as(local_tensor) - # JUST LIKE DTENSOR: - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this ShardTensor. 
- return local_tensor.view_as(local_tensor) + # Force the local view to inherit the requires_grad state of the ShardTensor + local_tensor = input._local_tensor + res = local_tensor.view_as(local_tensor) + res.requires_grad_(input.requires_grad) + return res @staticmethod def backward( @@ -296,74 +515,7 @@ def backward( return grad_output.to_local(), None, None, None -class _PromoteDTensorToShardTensor(torch.autograd.Function): - r"""Autograd function to promote a DTensor to a ShardTensor while preserving ``grad_fn``. - - When DTensor's ``__torch_function__`` returns a non-leaf DTensor (one that - has a ``grad_fn``), creating a new ShardTensor via ``_make_wrapper_subclass`` - always produces a leaf — disconnecting it from the autograd graph. - - This function bridges that gap: the forward creates the ShardTensor wrapper, - and ``apply`` attaches a ``grad_fn`` that connects it back to the original - DTensor's graph. The backward simply passes gradients through unchanged. - - This is only used at the ``__torch_function__`` level where the DTensor - result already carries autograd state. At the ``__torch_dispatch__`` level, - promotion is safe without this because autograd wraps the result afterwards. - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - dtensor: DTensor, - spec: "ShardTensorSpec", - ) -> "ShardTensor": - r"""Create a ShardTensor from a DTensor, preserving autograd via ``apply``. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused — no state needed for backward). - dtensor : DTensor - The DTensor to promote. - spec : ShardTensorSpec - The ShardTensorSpec to use for the new ShardTensor. - - Returns - ------- - ShardTensor - A new ShardTensor wrapping the same local data. 
- """ - return ShardTensor.__new__( - ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=False, # autograd.Function.apply handles this - ) - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - grad_output: "ShardTensor", - ) -> tuple[DTensor, None]: - r"""Pass gradient through unchanged. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused). - grad_output : ShardTensor - Gradient with respect to the ShardTensor output. - - Returns - ------- - Tuple[DTensor, None] - The gradient for the DTensor input, and ``None`` for the spec. - """ - return grad_output, None - - -class ShardTensor(DTensor): +class ShardTensor(torch.Tensor): r"""A distributed tensor class with support for uneven data sharding. Similar to PyTorch's native ``DTensor`` but with more flexibility for @@ -496,41 +648,6 @@ def __new__( *, requires_grad: bool, ) -> "ShardTensor": - r"""Construct a new ShardTensor from a local tensor and specification. - - Note that unlike ``DTensor``, ShardTensor will automatically collect - the shard size information from all participating devices. This enables - uneven and dynamic sharding. - - Parameters - ---------- - local_tensor : torch.Tensor - Local tensor to use as the data. - spec : ShardTensorSpec - ShardTensorSpec defining the sharding scheme. - requires_grad : bool - Whether the tensor requires gradients. - - Returns - ------- - ShardTensor - A new ShardTensor instance. - - Note - ---- - This implementation is heavily derived from ``torch.distributed.tensor.DTensor``. - """ - if local_tensor.requires_grad and not requires_grad: - warn( - "To construct a new ShardTensor from torch.Tensor, " - "it's recommended to use local_tensor.detach() and " - "make requires_grad consistent." 
- ) - - if spec.tensor_meta is None: - raise ValueError("TensorMeta should not be None!") - - # Check the sharding information is known: ret = torch.Tensor._make_wrapper_subclass( cls, spec.tensor_meta.shape, @@ -538,178 +655,211 @@ def __new__( dtype=local_tensor.dtype, device=local_tensor.device, layout=local_tensor.layout, - requires_grad=requires_grad, + requires_grad=False, ) ret._spec = spec ret._local_tensor = local_tensor - cls._enable_shard_patches = True + # Set requires_grad AFTER _spec/_local_tensor are assigned, using + # the C-level setter directly (bypassing __torch_function__ which + # would convert to DTensor and set on a temporary). + if requires_grad: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(ret, True) + cls._enable_shard_patches = True return ret def __repr__(self) -> str: - return f"ShardTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + return ( + "ShardTensor(" + f"local_tensor={repr(self._local_tensor)}, " + f"device_mesh={repr(self._spec.mesh)}, " + f"placements={repr(self._spec.placements)}" + ")" + ) - @classmethod - def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": - r"""Convert a DTensor to a ShardTensor. + def __str__(self) -> str: + # Avoid Tensor/DTensor string formatting paths that can re-enter dispatch. + return self.__repr__() - Assumes the DTensor is properly constructed. Since DTensor is locked - to sharding a tensor according to chunk format, the sharding sizes - can be inferred with no communication. + def __format__(self, format_spec: str) -> str: + # Format as plain Python string to bypass tensor formatting internals. + return format(str(self), format_spec) - If the DTensor is a non-leaf (has a ``grad_fn``), the autograd graph - is preserved via :class:`_PromoteDTensorToShardTensor`. 
+ @property + def device_mesh(self) -> DeviceMesh: + """Return the :class:`DeviceMesh` that this tensor is distributed over.""" + return self._spec.mesh - Parameters - ---------- - dtensor : DTensor - DTensor to convert. + @property + def placements(self) -> tuple[Placement, ...]: + """Return the placement strategy for each mesh dimension.""" + return self._spec.placements - Returns - ------- - ShardTensor - Equivalent ShardTensor with the same local tensor and inferred spec. - """ - return cls._maybe_promote_dtensor(dtensor, ()) + def __tensor_flatten__(self): + return ["_local_tensor"], (self._spec, self.requires_grad) @staticmethod - def _maybe_promote_dtensor(dtensor, input_args): - r"""Promote a single DTensor back to ShardTensor if it matches input criteria. + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + spec, requires_grad = flatten_spec + local_tensor = inner_tensors["_local_tensor"] + unflatten_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = ShardTensorSpec( + mesh=spec.mesh, + placements=spec.placements, + tensor_meta=unflatten_meta, + _local_shape=local_tensor.shape, + _sharding_shapes=spec._sharding_shapes, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_tensor.requires_grad_(requires_grad), + spec=unflatten_spec, + requires_grad=requires_grad, + ) + + # -- Autograd property overrides ------------------------------------------- + # The C-level requires_grad is authoritative for autograd engine + # decisions; we read it first and fall back to _local_tensor for the + # case where _make_wrapper_subclass didn't propagate it correctly. + # For grad, the autograd engine accumulates at the C level, so we + # check there first then fall back to _local_tensor.grad. - If ``dtensor`` is already a ShardTensor, it is returned as-is. 
Otherwise, - determines a ``ShardTensorSpec`` (reusing an input's spec when possible, - otherwise inferring one) and creates a new ShardTensor. + @property # type: ignore[override] + def requires_grad(self) -> bool: # type: ignore[override] + """Whether this tensor requires gradient computation. - When the DTensor is a non-leaf (has a ``grad_fn``), the promotion goes - through :class:`_PromoteDTensorToShardTensor` so that the autograd graph - is preserved. For leaf DTensors, direct construction is used since there - is no graph to preserve. + Returns ``True`` if either the wrapper tensor or the underlying local + tensor has ``requires_grad`` set. + """ + with torch._C.DisableTorchFunctionSubclass(): + if torch.Tensor.requires_grad.__get__(self): + return True + return self._local_tensor.requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + """Set ``requires_grad`` on both the wrapper and the local tensor.""" + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, value) + self._local_tensor.requires_grad = value + + def requires_grad_(self, requires_grad: bool = True) -> "ShardTensor": + """Set ``requires_grad`` in-place on both the wrapper and local tensor. Parameters ---------- - dtensor : DTensor - The DTensor result to promote. - input_args : tuple - Original input arguments to search for matching ShardTensors. + requires_grad : bool, optional + Whether to enable gradient tracking. Default is ``True``. Returns ------- ShardTensor - Promoted ShardTensor (or the original if already a ShardTensor). + ``self``, for method chaining. """ - if isinstance(dtensor, ShardTensor): - return dtensor - - # Determine the ShardTensorSpec — reuse an input's spec when the - # tensor_meta and placements match (avoids communication). 
- spec = None - for arg in input_args: - if ( - isinstance(arg, ShardTensor) - and dtensor._spec.tensor_meta == arg._spec.tensor_meta - and dtensor._spec.placements == arg._spec.placements - ): - spec = arg._spec - break - - if spec is None: - # Infer from DTensor (no communication for chunk-based sharding). - spec = _infer_shard_tensor_spec_from_local_chunks( - dtensor._local_tensor, - dtensor._spec.mesh, - dtensor._spec.placements, - sharding_shapes="chunk", - global_shape=dtensor.shape, + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, requires_grad) + self._local_tensor.requires_grad_(requires_grad) + return self + + @property # type: ignore[override] + def is_leaf(self) -> bool: # type: ignore[override] + """Whether this tensor is a leaf in the autograd graph.""" + with torch._C.DisableTorchFunctionSubclass(): + return torch.Tensor.is_leaf.__get__(self) + + @property # type: ignore[override] + def grad(self) -> "ShardTensor | None": # type: ignore[override] + """Return the accumulated gradient, wrapped as a :class:`ShardTensor`. + + If no gradient has been accumulated yet, returns ``None``. + """ + with torch._C.DisableTorchFunctionSubclass(): + c_grad = torch.Tensor.grad.__get__(self) + if c_grad is not None: + if isinstance(c_grad, ShardTensor): + return c_grad + return ShardTensor.__new__( + ShardTensor, + local_tensor=c_grad._local_tensor + if isinstance(c_grad, DTensor) + else c_grad, + spec=self._spec, + requires_grad=False, ) - - # Non-leaf DTensors carry a grad_fn from the operation that produced - # them. Creating a new ShardTensor via _make_wrapper_subclass would - # discard that grad_fn (producing a leaf). Go through the autograd - # function so that apply() connects the new ShardTensor back to the - # original graph. 
- if dtensor.grad_fn is not None: - return _PromoteDTensorToShardTensor.apply(dtensor, spec) - - # Leaf DTensors (parameters, buffers, detached tensors) can be - # constructed directly — there is no autograd graph to preserve. + local_grad = self._local_tensor.grad + if local_grad is None: + return None return ShardTensor.__new__( ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=dtensor.requires_grad, + local_tensor=local_grad, + spec=self._spec, + requires_grad=False, ) - @staticmethod - def _promote_dtensor_results(result, input_args): - r"""Promote DTensor(s) in a dispatch/function result back to ShardTensor. + @grad.setter + def grad(self, value: "ShardTensor | torch.Tensor | None") -> None: + """Set or clear the gradient on both the wrapper and local tensor.""" + if value is None: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, None) + self._local_tensor.grad = None + elif isinstance(value, ShardTensor): + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value._local_tensor + else: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value - Handles four cases: + @classmethod + def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": + r"""Convert a DTensor to a ShardTensor. - 1. Single DTensor — promoted via :meth:`_maybe_promote_dtensor`. - 2. Mapping (e.g. dict) — each value is promoted if it is a DTensor. - 3. Iterable of results — each DTensor element is promoted individually. - 4. Anything else — returned as-is. + Differentiable when *dtensor* is non-leaf (has a ``grad_fn``). + Spec is inferred from the DTensor (chunk-based, no communication). Parameters ---------- - result : object - The result returned by DTensor dispatch or ``__torch_function__``. - input_args : tuple - Original input arguments used for matching specs. 
+ dtensor : DTensor + DTensor to convert. Returns ------- - object - The result with any DTensors promoted to ShardTensors. + ShardTensor + Equivalent ShardTensor with the same local tensor and inferred spec. """ - if isinstance(result, DTensor): - return ShardTensor._maybe_promote_dtensor(result, input_args) - - if isinstance(result, Mapping): - return type(result)( - { - k: ShardTensor._maybe_promote_dtensor(v, input_args) - if isinstance(v, DTensor) - else v - for k, v in result.items() - } - ) - - # Exclude str/bytes so we don't iterate over characters. - if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): - return type(result)( - ShardTensor._maybe_promote_dtensor(d, input_args) - if isinstance(d, DTensor) - else d - for d in result - ) - - return result + if isinstance(dtensor, ShardTensor): + return dtensor + spec = _resolve_spec_for_dtensor(dtensor) + if dtensor.grad_fn is not None: + return _DTensorToShardTensor.apply(dtensor, spec) + return _dtensor_to_shard_tensor(dtensor, spec) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - with annotate(f"__torch_function___{func.__name__}"): - # Check for overrides: - if func in cls._function_registry and cls._enable_shard_patches: - res = cls._function_registry[func](func, types, args, kwargs) - return res - elif ( - str(func) in cls._named_function_registry and cls._enable_shard_patches - ): - res = cls._named_function_registry[str(func)](func, types, args, kwargs) - return res - # Fall back to the default behavior, but promote any DTensor - # results back to ShardTensor (matching dispatch behavior): - result = super().__torch_function__(func, types, args, kwargs) - return cls._promote_dtensor_results(result, args) + if _conversion_active(): + # When converting shard tensor to dtensor, or dtensor to shard tensor, + # we just run the function without ShardTensor dispatch. 
+ with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + if func in cls._function_registry and cls._enable_shard_patches: + return cls._function_registry[func](func, types, args, kwargs) + if str(func) in cls._named_function_registry and cls._enable_shard_patches: + return cls._named_function_registry[str(func)](func, types, args, kwargs) + res = _torch_function_fallback_via_dtensor(func, args, kwargs) + return res @classmethod - @torch._disable_dynamo - @profile def __torch_dispatch__( cls, func: torch._ops.OpOverload, @@ -717,33 +867,14 @@ def __torch_dispatch__( args: tuple[object, ...] = (), kwargs: dict[str, object] | None = None, ) -> "ShardTensor" | Iterable["ShardTensor"] | object: - with annotate(f"__torch_dispatch___{func.__name__}"): - # Leverage DTensor Dispatch as much as possible, but, enable - # the ability to operate on this output in the future: - handler = cls._dispatch_registry.get(func) - if handler is None: - handler = cls._dispatch_registry_by_name.get(str(func)) - if handler is not None: - res = handler(*args, **kwargs) - return res - - # We assume that if we reach this point, the operator has not been - # intercepted by a wrapper or in the registry. So the DTensor - # default behavior is likely to be correct. - - # Convert ShardTensors to DTensors so DTensor's dispatcher - # receives the types it expects. - converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) - converted_kwargs = { - k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() - } - - dispatch_res = DTensor._op_dispatcher.dispatch( - func, converted_args, converted_kwargs - ) - - # Promote any DTensor results back to ShardTensor. 
- return cls._promote_dtensor_results(dispatch_res, args) + # Use a handler, if we have one: + handler = cls._dispatch_registry.get(func) + if handler is None: + handler = cls._dispatch_registry_by_name.get(str(func)) + if handler is not None: + return handler(*args, **kwargs) + # Otherwise, try the dtensor route: + return _dispatch_fallback_via_dtensor(func, args, kwargs) @staticmethod def from_local( @@ -962,9 +1093,43 @@ def backward(self, *args, **kwargs): if needs_redistribute: self = self.redistribute(placements=new_placements) + if self.grad_fn is not None: + return torch.Tensor.backward(self, *args, **kwargs) + return self.to_local().backward(*args, **kwargs) +class FSDPOutputTensorAdapter(nn.Module): + """Wrap a module and convert ShardTensor outputs to torch.Tensor.""" + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + out = self.module(*args, **kwargs) + return out.to_local() if isinstance(out, ShardTensor) else out + + +def wrap_for_fsdp(module: nn.Module) -> nn.Module: + """Return a module wrapper that exposes tensor outputs for FSDP hooks.""" + return FSDPOutputTensorAdapter(module) + + +def distribute_over_domain_for_fsdp( + module: nn.Module, + device_mesh: DeviceMesh, + partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None, +) -> nn.Module: + """Distribute a module over a domain mesh and adapt outputs for FSDP.""" + distributed_module = distribute_module( + module, + device_mesh=device_mesh, + partition_fn=partition_fn, + ) + return wrap_for_fsdp(distributed_module) + + def scatter_tensor( tensor: torch.Tensor, global_src: int, @@ -1044,12 +1209,15 @@ def scatter_tensor( # scatter along Shard dimensions. BUT, the focus is on performance of full applications # and this is a once-per-iteration cost. - # Broadcast the tensor to all ranks + # Broadcast the tensor to all ranks. 
+ # scatter_tensor is an input-boundary utility; keep internal collectives/layout + # transforms out of autograd and construct the requested leaf explicitly. if tensor is None and not is_src: # Tensor is allowed to be none if not on the root rank tensor = torch.empty(local_meta.shape, dtype=local_meta.dtype, device=dm.device) - dist.broadcast(tensor, src=global_src, group=mesh_group) + with torch.no_grad(): + dist.broadcast(tensor, src=global_src, group=mesh_group) # Create a fully-replicated spec: spec = ShardTensorSpec( @@ -1059,18 +1227,30 @@ def scatter_tensor( _sharding_shapes={}, ) - # Make a "fully-replicated" tensor on all ranks: - st = ShardTensor.__new__( - ShardTensor, - local_tensor=tensor, - spec=spec, - requires_grad=requires_grad, - ) + with torch.no_grad(): + # Build a replicated ShardTensor and redistribute to the requested + # placements without recording autograd history. + st = ShardTensor.__new__( + ShardTensor, + local_tensor=tensor, + spec=spec, + requires_grad=False, + ) + st = st.redistribute(mesh, placements, async_op=False) - # Redistribute the tensor to the desired placements: - st = st.redistribute(mesh, placements, async_op=False) - # This is an unoptimal step but is functional: if requires_grad: - st = st.detach() - st.requires_grad = True + # 1. Ensure the local data is a clean leaf + local_leaf = st._local_tensor.detach().requires_grad_(True) + + # 2. Create the ShardTensor wrapper + st = ShardTensor.__new__( + ShardTensor, + local_tensor=local_leaf, + spec=st._spec, + requires_grad=True, + ) + + # 3. 
CRITICAL: Force the wrapper itself to be a leaf in the autograd graph + st = st.detach().requires_grad_(True) + return st diff --git a/physicsnemo/domain_parallel/shard_utils/__init__.py b/physicsnemo/domain_parallel/shard_utils/__init__.py index 4ce1bfc714..69b7370cb2 100644 --- a/physicsnemo/domain_parallel/shard_utils/__init__.py +++ b/physicsnemo/domain_parallel/shard_utils/__init__.py @@ -25,6 +25,11 @@ from physicsnemo.domain_parallel.shard_tensor import ShardTensor def register_shard_wrappers(): + """Import and register all shard-aware operation wrappers with ShardTensor. + + Each imported module registers its wrapper via + :meth:`ShardTensor.register_op` at import time. + """ from .attention_patches import sdpa_wrapper from .conv_patches import generic_conv_nd_wrapper from .index_ops import ( diff --git a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py index 4685a9a49f..5593a3d1ad 100644 --- a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py @@ -282,13 +282,13 @@ def backward( grad_weight = None grad_bias = None - if weight is not None and weight.requires_grad: + if weight is not None and ctx.needs_input_grad[3]: # grad_weight_c = sum_{n, spatial} grad_output * y (per-channel) y_c = y.view(N, C, HxW_local) grad_out_c = local_grad_output.view(N, C, HxW_local) grad_weight = (grad_out_c * y_c).sum(dim=(0, 2)) # (C,) - if bias is not None and bias.requires_grad: + if bias is not None and ctx.needs_input_grad[4]: grad_out_c = local_grad_output.view(N, C, HxW_local) grad_bias = grad_out_c.sum(dim=(0, 2)) # (C,) diff --git a/physicsnemo/domain_parallel/shard_utils/view_ops.py b/physicsnemo/domain_parallel/shard_utils/view_ops.py index 42120272ab..8d4fb50a5d 100644 --- a/physicsnemo/domain_parallel/shard_utils/view_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/view_ops.py @@ -625,7 +625,8 
@@ def forward( Viewed ShardTensor. """ ctx.input_global_shape = tuple(tensor.shape) - return _sharded_view_forward(tensor, target_shape) + out = _sharded_view_forward(tensor, target_shape) + return out @staticmethod def backward( @@ -646,6 +647,7 @@ def backward( tuple[ShardTensor, None] Gradient for the input tensor, and ``None`` for ``target_shape``. """ + return ( _sharded_view_forward(grad_output, ctx.input_global_shape), None, diff --git a/test/domain_parallel/ops/test_convolution.py b/test/domain_parallel/ops/test_convolution.py index a20ea6efe6..410e6b9134 100644 --- a/test/domain_parallel/ops/test_convolution.py +++ b/test/domain_parallel/ops/test_convolution.py @@ -167,7 +167,7 @@ def test_conv_transpose_1d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -265,7 +265,7 @@ def test_conv_transpose_2d_1dmesh( 2, C_in, ( - H, + 2 * H, H, ), device=dm.device, @@ -293,7 +293,7 @@ def test_conv_transpose_2d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -405,8 +405,8 @@ def test_conv_transpose_2d_2dmesh( 2, C_in, ( - H, - H, + 2 * H, + 2 * H, ), device=dm.device, ) diff --git a/test/domain_parallel/ops/test_view_ops.py b/test/domain_parallel/ops/test_view_ops.py index 7094074f2c..dd1a2a7d8d 100644 --- a/test/domain_parallel/ops/test_view_ops.py +++ b/test/domain_parallel/ops/test_view_ops.py @@ -534,20 +534,23 @@ def test_view_trailing_dims_1d_to_3d( distributed_mesh, backward, ): - """Test view (6,) -> (2, 3, 1) with Shard(0): trailing dim must stay in group. + """Test view (48,) -> (8, 6, 1) with Shard(0): trailing singleton in target. - With the shard on dim 0, each rank has a contiguous chunk of the 1D tensor. - The target shape has a trailing singleton (2, 3, 1). 
The trailing dimension - must be included in the same dimension group so that the local element - count is correct (product of local shape equals chunk_size). Without that, - the old code produced wrong local shapes (e.g. product 4 instead of 2 or 3). + The 1D tensor is sharded on dim 0. The target shape has a trailing + singleton ``(8, 6, 1)`` that falls outside the dimension group matched + by ``_match_view_dim_groups`` (which pairs ``(48,)`` with ``(8, 6)``). + The trailing ``1`` must be carried through unchanged in the local shape + so that ``product(local_shape) == chunk_size``. + + We use a tensor size (48) that divides cleanly across 2-, 4-, and 8-GPU + meshes so that every rank's chunk aligns to a row boundary in ``(8, 6)``. """ if not torch.cuda.is_available(): pytest.skip("CUDA is not available") dm = DistributedManager() - shape = (6,) - target_shape = (2, 3, 1) + shape = (48,) + target_shape = (8, 6, 1) original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) diff --git a/test/domain_parallel/ops/utils.py b/test/domain_parallel/ops/utils.py index de8052c93c..fcf04f9c92 100644 --- a/test/domain_parallel/ops/utils.py +++ b/test/domain_parallel/ops/utils.py @@ -169,7 +169,9 @@ def unparallelize_module(module): This function is for testing purposes only. Do not use in production code. 
""" for name, param in list(module._parameters.items()): - if isinstance(param, torch.nn.Parameter) and isinstance(param.data, DTensor): + if isinstance(param, torch.nn.Parameter) and isinstance( + param.data, (ShardTensor, DTensor) + ): # gather to replicated then unwrap local_tensor = param.data.full_tensor() # replace with a normal Parameter diff --git a/test/domain_parallel/test_grad_sharding.py b/test/domain_parallel/test_grad_sharding.py index 871782459b..167044159d 100644 --- a/test/domain_parallel/test_grad_sharding.py +++ b/test/domain_parallel/test_grad_sharding.py @@ -275,7 +275,7 @@ def run_dtensor_to_shard_tensor_non_leaf_gradient(mesh): loss_ref.backward() assert dt.grad is not None - assert isinstance(dt.grad, DTensor) + assert isinstance(dt.grad, (ShardTensor, DTensor)) assert torch.allclose(dt.grad.full_tensor(), ref.grad) diff --git a/test/domain_parallel/test_initialization.py b/test/domain_parallel/test_initialization.py index 5c5c8fbf02..d6cd7d054b 100644 --- a/test/domain_parallel/test_initialization.py +++ b/test/domain_parallel/test_initialization.py @@ -121,6 +121,101 @@ def init_from_data_rank_worker(mesh): assert dim == local_data.shape[i] +def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool): + r"""Validate scatter_tensor construction contract for requires_grad modes.""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=requires_grad, + ) + + assert st.requires_grad is requires_grad + if requires_grad: + assert st.is_leaf + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def 
test_scatter_tensor_requires_grad_contract_1d(distributed_mesh, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh, requires_grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_scatter_tensor_requires_grad_contract_2d(distributed_mesh_2d, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh_2d, requires_grad) + + +def scatter_tensor_grad_population_worker(mesh): + r"""Validate that gradients populate for scatter_tensor(..., requires_grad=True).""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=True, + ) + assert st.is_leaf + assert st.requires_grad + + reference = st.full_tensor().detach().requires_grad_(True) + reference_loss = (reference**2).sum() + reference_loss.backward() + + st2 = st**2 + sharded_loss = st2.sum() + sharded_loss.backward() + + assert st.grad is not None + assert st.grad._spec.placements == st._spec.placements + assert st.grad._spec.sharding_shapes() == st._spec.sharding_shapes() + assert torch.allclose(st.grad.full_tensor(), reference.grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_1d(distributed_mesh): + scatter_tensor_grad_population_worker(distributed_mesh) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_2d(distributed_mesh_2d): + scatter_tensor_grad_population_worker(distributed_mesh_2d) + + @pytest.mark.timeout(10) @pytest.mark.multigpu_static def test_shard_tensor_initialization_from_data_rank_1d(distributed_mesh, 
verbose=False): @@ -162,8 +257,6 @@ def shard_tensor_initialization_from_all_dtensor_worker(mesh): st = ShardTensor.from_dtensor(dt) - print(f"Rank {dm.rank} made shard tensors.") - dt_full = dt.full_tensor() st_full = st.full_tensor() diff --git a/test/domain_parallel/test_reductions.py b/test/domain_parallel/test_reductions.py index 8cb8931e45..2145f48af1 100644 --- a/test/domain_parallel/test_reductions.py +++ b/test/domain_parallel/test_reductions.py @@ -118,6 +118,10 @@ def test_shard_tensor_reduction( requires_grad=backward, ) + # NOTE(review): leaf/requires_grad assertions for the backward path are + # intentionally deferred here — presumably until the scatter_tensor + # requires_grad contract (test_initialization.py) is finalized; confirm. + if verbose: print( f"Shard tensor global shape: {shard_tensor.shape} and local shape: {shard_tensor._local_tensor.shape}"