diff --git a/physicsnemo/domain_parallel/shard_utils/padding.py b/physicsnemo/domain_parallel/shard_utils/padding.py
index a5aa1d215a..f61e3e4dac 100644
--- a/physicsnemo/domain_parallel/shard_utils/padding.py
+++ b/physicsnemo/domain_parallel/shard_utils/padding.py
@@ -209,7 +209,7 @@ def generic_pad_nd_wrapper(
         mesh_coords = list(self_mesh_coords)
         mesh_coords[mesh_dim] = i
         output_shape, local_padding = compute_local_padding_and_output_shape(
-            local_input.shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
+            local_shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
         )
 
         # Catch and cache the one that applies to this rank:
diff --git a/test/domain_parallel/ops/test_padding.py b/test/domain_parallel/ops/test_padding.py
index 09ed328516..659a851adf 100644
--- a/test/domain_parallel/ops/test_padding.py
+++ b/test/domain_parallel/ops/test_padding.py
@@ -24,10 +24,11 @@
 
 import pytest
 import torch
+from torch.distributed.tensor import distribute_module
 from torch.distributed.tensor.placement_types import Shard
 
 from physicsnemo.distributed import DistributedManager
-from physicsnemo.domain_parallel import scatter_tensor
+from physicsnemo.domain_parallel import ShardTensor, scatter_tensor
 
 from .utils import generate_image_like_data, numerical_shard_tensor_check
 
@@ -161,3 +161,63 @@ def test_padded_convolution_2d_1dmesh(distributed_mesh, padding_mode, backward):
     numerical_shard_tensor_check(
         distributed_mesh, module, [sharded_image], {}, check_grads=backward
     )
+
+
+class UnevenPadFunctionals(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = torch.nn.functional.pad(x, (0, 0, 2, 0), mode="replicate")
+        x = torch.nn.functional.pad(x, (0, 0, 0, 2), mode="constant", value=1.0)
+        return x
+
+
+@pytest.mark.multigpu_static
+@pytest.mark.parametrize("backward", [False, True])
+def test_uneven_pad_2d_1dmesh(distributed_mesh, backward):
+    H = 256
+    C_in = 8
+
+    dm = DistributedManager()
+    n_gpus = dm.world_size
+
+    image = generate_image_like_data(2, C_in, (H, H)).to(dm.device)
+
+    placements = (Shard(2),)
+
+    sharded_image = scatter_tensor(
+        image, 0, distributed_mesh, placements, requires_grad=backward
+    )
+    # Local reference tensors that should be the same shape as the shards
+    if dm.rank in [0, n_gpus - 1]:
+        local_ref = torch.ones(2, C_in, H // n_gpus + 2, H).to(dm.device)
+    else:
+        local_ref = torch.ones(2, C_in, H // n_gpus, H).to(dm.device)
+
+    local_ref_sharded = ShardTensor.from_local(
+        local_ref, distributed_mesh, (Shard(dim=2),), sharding_shapes="infer"
+    )
+
+    dist_test_pad = distribute_module(
+        UnevenPadFunctionals().to(dm.device), device_mesh=distributed_mesh
+    )
+
+    sharded_image_pad = dist_test_pad(sharded_image)
+
+    _ = local_ref_sharded + sharded_image_pad
+
+    numerical_shard_tensor_check(
+        distributed_mesh,
+        UnevenPadFunctionals(),
+        [sharded_image],
+        {},
+        check_grads=backward,
+    )
+
+    full_image_pad = torch.nn.functional.pad(image, (0, 0, 2, 0), mode="replicate")
+    full_image_pad = torch.nn.functional.pad(
+        full_image_pad, (0, 0, 0, 2), mode="constant", value=1.0
+    )
+
+    assert torch.allclose(sharded_image_pad.full_tensor(), full_image_pad)