diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index 0e3706475a..3d762c054a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -71,6 +71,7 @@ NormalizeVectors, Purge, Rename, + Resize, Scale, SubsamplePoints, Transform, @@ -109,6 +110,7 @@ "CreateGrid", "KNearestNeighbors", "CenterOfMass", + "Resize", # Transforms - Utility "Rename", "Purge", diff --git a/physicsnemo/datapipes/readers/numpy.py b/physicsnemo/datapipes/readers/numpy.py index aba2221b43..b1ee5f3fa3 100644 --- a/physicsnemo/datapipes/readers/numpy.py +++ b/physicsnemo/datapipes/readers/numpy.py @@ -18,6 +18,8 @@ NumpyReader - Read data from NumPy .npz files. Supports reading from single .npz files or directories of .npz files. +In single-file mode, optional ``preload_to_cpu=True`` loads the entire +dataset into RAM at init for faster iteration with no per-sample I/O. """ from __future__ import annotations @@ -38,8 +40,12 @@ class NumpyReader(Reader): Read samples from NumPy .npz files. Supports two modes: - 1. Single .npz file: samples indexed along first dimension of each array - 2. Directory of .npz files: one sample per file + + 1. **Single .npz file**: Samples are indexed along the first dimension + of each array. Optionally, ``preload_to_cpu=True`` loads all arrays + into RAM at init so iteration does no disk I/O. + 2. **Directory of .npz files**: One sample per file; each file is opened + on demand. Example (single .npz): >>> # data.npz with arrays "positions" (N, 100, 3), "features" (N, 100) @@ -52,6 +58,10 @@ class NumpyReader(Reader): >>> # Directory with sample_0.npz, sample_1.npz, ... >>> reader = NumpyReader("data_dir/", file_pattern="sample_*.npz") # doctest: +SKIP >>> data, metadata = reader[0] # Returns (TensorDict, dict) tuple # doctest: +SKIP + + Example (single .npz with preload): + >>> reader = NumpyReader("data.npz", preload_to_cpu=True) # doctest: +SKIP + >>> # All arrays loaded into RAM at init; no disk I/O during iteration """ def __init__( @@ -65,6 +75,7 @@ def __init__( pin_memory: bool = False, include_index_in_metadata: bool = True, coordinated_subsampling: Optional[dict[str, Any]] = None, + preload_to_cpu: bool = False, ) -> None: """ Initialize the NumPy reader. @@ -93,6 +104,11 @@ def __init__( Optional dict to configure coordinated subsampling (directory mode only). If provided, must contain ``n_points`` (int) and ``target_keys`` (list of str). + preload_to_cpu : bool, default=False + If True, in single-file mode the reader loads all requested + arrays into RAM at init, closes the file, and serves samples + from memory. Use when the dataset fits in RAM and you want + to avoid disk I/O during training. Ignored in directory mode. Raises ------ @@ -100,6 +116,9 @@ def __init__( If path doesn't exist. ValueError If no files found in directory or unsupported file type. + KeyError + If preload_to_cpu is True and a required field is missing + from the file (and not in default_values). """ super().__init__( pin_memory=pin_memory, @@ -112,14 +131,17 @@ def __init__( self.default_values = default_values or {} self.file_pattern = file_pattern self.index_key = index_key + self.preload_to_cpu = preload_to_cpu if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") - # Determine mode based on path - self._mode: str # "single" or "directory" + # Mode: "single" (one .npz, samples along first dim) or "directory" + self._mode: str self._files: Optional[list[Path]] = None self._data: Optional[np.lib.npyio.NpzFile] = None + # When preload_to_cpu: in-memory arrays keyed by field name (single-file only) + self._preloaded: Optional[dict[str, np.ndarray]] = None self._available_fields: list[str] = [] if self.path.is_dir(): @@ -147,12 +169,12 @@ def _setup_directory_mode(self) -> None: self._available_fields = list(npz.files) def _setup_single_file_mode(self) -> None: - """Set up reader for single .npz file.""" + """Set up reader for a single .npz file; optionally preload all arrays to RAM.""" self._mode = "single" self._data = np.load(self.path) self._available_fields = list(self._data.files) - # Determine length from index_key or first field + # Sample count is the first dimension of index_key or of the first array if self.index_key is not None: self._length = self._data[self.index_key].shape[0] elif self._available_fields: @@ -160,6 +182,24 @@ def _setup_single_file_mode(self) -> None: else: self._length = 0 + # Optional: load entire dataset into RAM and close the file + if self.preload_to_cpu: + required = set(self.fields) - set(self.default_values.keys()) + missing = required - set(self._data.files) + if missing: + raise KeyError( + f"Required fields {missing} not found in {self.path}. " + f"Available: {list(self._data.files)}" + ) + self._preloaded = {} + for field in self.fields: + if field in self._data.files: + # .copy() forces a real array; np.array() ensures contiguous + self._preloaded[field] = np.array(self._data[field].copy()) + if hasattr(self._data, "close"): + self._data.close() + self._data = None + @property def fields(self) -> list[str]: """Fields that will be loaded (user-specified or all available).""" @@ -275,20 +315,66 @@ def _load_from_npz( # Directory mode: load full array arr = arr[:] - data[field] = torch.from_numpy(np.array(arr)) + data[field] = torch.from_numpy(np.asarray(arr, dtype=np.float32)) elif field in self.default_values: - data[field] = self.default_values[field].clone() + data[field] = self.default_values[field].clone().float() + + return data + + def _load_sample_from_preloaded(self, index: int) -> dict[str, torch.Tensor]: + """ + Load a single sample by indexing into preloaded in-memory arrays. + + Used only when ``preload_to_cpu=True`` in single-file mode. Applies + coordinated subsampling (random contiguous slice) when configured. + + Parameters + ---------- + index : int + Sample index along the first dimension of each preloaded array. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary mapping field names to CPU tensors for this sample. + """ + data = {} + fields_to_load = self.fields + target_keys_set = set() + subsample_slice = None + + # If subsampling is enabled, pick one random contiguous slice for this sample + if self._coordinated_subsampling_config is not None: + n_points = self._coordinated_subsampling_config["n_points"] + target_keys_set = set(self._coordinated_subsampling_config["target_keys"]) + for field in target_keys_set: + if field in self._preloaded: + arr = self._preloaded[field][index] + subsample_slice = self._select_random_sections_from_slice( + 0, arr.shape[0], n_points + ) + break + for field in fields_to_load: + if field in self._preloaded: + arr = np.array(self._preloaded[field][index], copy=False) + if subsample_slice is not None and field in target_keys_set: + arr = arr[subsample_slice] + data[field] = torch.from_numpy(np.asarray(arr, dtype=np.float32)) + elif field in self.default_values: + data[field] = self.default_values[field].clone().float() return data def _load_sample(self, index: int) -> dict[str, torch.Tensor]: - """Load a single sample.""" + """Load a single sample from disk or from preloaded RAM.""" if self._mode == "directory": file_path = self._files[index] with np.load(file_path) as npz: return self._load_from_npz(npz, index=None, file_path=file_path) - else: # single + elif self._preloaded is not None: + return self._load_sample_from_preloaded(index) + else: return self._load_from_npz(self._data, index=index) def __len__(self) -> int: @@ -318,12 +404,13 @@ def _supports_coordinated_subsampling(self) -> bool: return self._mode == "directory" def close(self) -> None: - """Close file handles.""" + """Close file handles and release preloaded in-memory arrays (if any).""" super().close() if self._data is not None: if hasattr(self._data, "close"): self._data.close() self._data = None + self._preloaded = None def __repr__(self) -> str: subsample_info = "" @@ -331,11 +418,13 @@ def __repr__(self) -> str: cfg = self._coordinated_subsampling_config subsample_info = f", subsampling={cfg['n_points']} points" + preload_info = ", preload_to_cpu=True" if self._preloaded is not None else "" return ( f"NumpyReader(" f"path={self.path}, " f"mode={self._mode}, " f"len={len(self)}, " f"fields={self.fields}" - f"{subsample_info})" + f"{subsample_info}" + f"{preload_info})" ) diff --git a/physicsnemo/datapipes/transforms/__init__.py b/physicsnemo/datapipes/transforms/__init__.py index 963b4b0985..8814ff5e1b 100644 --- a/physicsnemo/datapipes/transforms/__init__.py +++ b/physicsnemo/datapipes/transforms/__init__.py @@ -45,6 +45,7 @@ CenterOfMass, CreateGrid, KNearestNeighbors, + Resize, ) from physicsnemo.datapipes.transforms.subsample import ( SubsamplePoints, @@ -55,6 +56,7 @@ ConstantField, Purge, Rename, + Reshape, ) __all__ = [ @@ -83,8 +85,10 @@ "CreateGrid", "KNearestNeighbors", "CenterOfMass", + "Resize", # Utility "Rename", "Purge", "ConstantField", + "Reshape", ] diff --git a/physicsnemo/datapipes/transforms/normalize.py b/physicsnemo/datapipes/transforms/normalize.py index 97413a9688..21ea71ce45 100644 --- a/physicsnemo/datapipes/transforms/normalize.py +++ b/physicsnemo/datapipes/transforms/normalize.py @@ -20,6 +20,7 @@ from __future__ import annotations +import collections.abc import warnings from pathlib import Path from typing import Any, Literal, Optional @@ -196,7 +197,7 @@ def _process_stats_dict( """Process statistics into dict of tensors for each field.""" result: dict[str, torch.Tensor] = {} - if isinstance(stats, dict): + if isinstance(stats, collections.abc.Mapping): for key in self.input_keys: if key not in stats: raise ValueError( diff --git a/physicsnemo/datapipes/transforms/spatial.py b/physicsnemo/datapipes/transforms/spatial.py index c791d63333..57114f76e9 100644 --- a/physicsnemo/datapipes/transforms/spatial.py +++ b/physicsnemo/datapipes/transforms/spatial.py @@ -23,9 +23,10 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Tuple import torch +import torch.nn.functional as F from tensordict import TensorDict from physicsnemo.datapipes.registry import register @@ -523,3 +524,140 @@ def __repr__(self) -> str: return ( f"CenterOfMass(coords_key={self.coords_key}, output_key={self.output_key})" ) + + +@register() +class Resize(Transform): + r""" + Resize a set of grid tensors via interpolation. + + Applies spatial resizing to tensors identified by ``input_keys`` using + :func:`torch.nn.functional.interpolate`. Transforms operate on + single-sample data (no batch dimension). Supports 2D tensors + :math:`(C, H, W)` and 3D tensors :math:`(C, D, H, W)`. + + Parameters + ---------- + input_keys : list[str] + Keys of tensors to resize. Each tensor must have shape + :math:`(C, H, W)` for 2D or :math:`(C, D, H, W)` for 3D. + size : tuple[int, ...] + Target spatial size. For 2D use :math:`(H, W)`, for 3D use + :math:`(D, H, W)`. + mode : str, optional + Interpolation mode. One of ``"nearest"``, ``"bilinear"``, + ``"bicubic"`` (2D only), ``"trilinear"`` (3D only), ``"area"``. + Default is ``"bilinear"`` for 2D and ``"trilinear"`` for 3D. + align_corners : bool, optional + Used for ``"bilinear"``, ``"bicubic"``, ``"trilinear"``. + See :func:`torch.nn.functional.interpolate`. Default is ``False``. + + Examples + -------- + >>> transform = Resize( + ... input_keys=["pressure", "velocity"], + ... size=(64, 64), + ... mode="bilinear", + ... ) + >>> sample = TensorDict({ + ... "pressure": torch.randn(1, 128, 128), + ... "velocity": torch.randn(2, 128, 128), + ... }) + >>> result = transform(sample) + >>> result["pressure"].shape + torch.Size([1, 64, 64]) + >>> result["velocity"].shape + torch.Size([2, 64, 64]) + """ + + def __init__( + self, + input_keys: list[str], + size: Tuple[int, ...], + *, + mode: Optional[str] = None, + align_corners: bool = False, + ) -> None: + """ + Initialize the resize transform. + + Parameters + ---------- + input_keys : list[str] + Keys of tensors to resize. + size : tuple[int, ...] + Target spatial size, e.g. :math:`(H, W)` or :math:`(D, H, W)`. + mode : str, optional + Interpolation mode. Defaults by spatial dims: ``"bilinear"`` (2D), + ``"trilinear"`` (3D). + align_corners : bool, optional + Passed to :func:`torch.nn.functional.interpolate`. Default ``False``. + """ + super().__init__() + self.input_keys = input_keys + self.size = tuple(size) + ndim = len(self.size) + if ndim == 2: + self._default_mode: str = "bilinear" + elif ndim == 3: + self._default_mode = "trilinear" + else: + raise ValueError(f"size must have 2 or 3 spatial dimensions, got {ndim}") + self.mode = mode if mode is not None else self._default_mode + self.align_corners = align_corners + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Resize each tensor in ``input_keys`` to the target spatial size. + + Parameters + ---------- + data : TensorDict + Input TensorDict containing grid tensors to resize. + + Returns + ------- + TensorDict + TensorDict with resized tensors in place of originals. + """ + n_spatial = len(self.size) + # Single-sample only: (C, H, W) or (C, D, H, W); also accept (H, W) / (D, H, W) as single-channel + expected_ndim_with_channel = n_spatial + 1 # channel + spatial + + interp_kw: dict = {"size": self.size, "mode": self.mode} + if self.mode not in ("nearest", "area"): + interp_kw["align_corners"] = self.align_corners + + updates = {} + for key in self.input_keys: + if key not in data: + continue + t = data[key] + if not isinstance(t, torch.Tensor) or not t.is_floating_point(): + continue + ndim = t.ndim + if ndim == n_spatial: + # (H, W) or (D, H, W): treat as single-channel for interpolate + t = t.unsqueeze(0) + elif ndim != expected_ndim_with_channel: + continue + # Add batch dim for F.interpolate, then remove; restore to original ndim if we added channel + out = F.interpolate(t.unsqueeze(0), **interp_kw).squeeze(0) + if ndim == n_spatial: + out = out.squeeze(0) + updates[key] = out + return data.update(updates) + + def __repr__(self) -> str: + """ + Return string representation. + + Returns + ------- + str + String representation of the transform. + """ + return ( + f"Resize(input_keys={self.input_keys}, size={self.size}, " + f"mode={self.mode!r})" + ) diff --git a/physicsnemo/datapipes/transforms/utility.py b/physicsnemo/datapipes/transforms/utility.py index 962c1de158..9dd03da424 100644 --- a/physicsnemo/datapipes/transforms/utility.py +++ b/physicsnemo/datapipes/transforms/utility.py @@ -23,7 +23,7 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Sequence import torch from tensordict import TensorDict @@ -487,3 +487,93 @@ def extra_repr(self) -> str: f"fill_value={self.fill_value}, " f"output_dim={self.output_dim}" ) + + +@register() +class Reshape(Transform): + r""" + Reshape specified TensorDict fields to target shapes. + + Applies :func:`torch.reshape` so each specified field gets the given shape. + At most one dimension in the shape may be ``-1``, which is inferred from + the tensor's element count. Useful to unify layouts across datasets (e.g. + :math:`(1, H, W)` to :math:`(H, W)`) or to flatten/spread dimensions. + + Parameters + ---------- + keys : list[str] + TensorDict keys to reshape. Only these keys are modified; others are + left unchanged. + shape : tuple[int, ...] or list[int] + Target shape for all specified keys. Use ``-1`` for at most one + dimension to infer from the tensor size. + + Examples + -------- + Drop a leading singleton dimension (e.g. single-channel image): + + >>> transform = Reshape(keys=["y"], shape=(256, 256)) + >>> data = TensorDict({"x": torch.randn(256, 256), "y": torch.randn(1, 256, 256)}) + >>> result = transform(data) + >>> result["y"].shape + torch.Size([256, 256]) + + Flatten spatial dimensions: + + >>> transform = Reshape(keys=["features"], shape=(-1,)) + >>> data = TensorDict({"features": torch.randn(4, 8, 8)}) + >>> transform(data)["features"].shape + torch.Size([256]) + """ + + def __init__( + self, + keys: list[str], + shape: Sequence[int], + ) -> None: + """ + Initialize the reshape transform. + + Parameters + ---------- + keys : list[str] + TensorDict keys to reshape. + shape : tuple or list of int + Target shape. At most one entry may be -1 (inferred). + """ + super().__init__() + self.keys = list(keys) + self.shape = tuple(int(s) for s in shape) + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Reshape specified fields in the TensorDict. + + Parameters + ---------- + data : TensorDict + Input TensorDict. + + Returns + ------- + TensorDict + TensorDict with reshaped tensors for the specified keys. + Keys not present in the data are skipped. + """ + out = data.clone() + for key in self.keys: + if key not in out.keys(): + continue + out[key] = out[key].reshape(self.shape) + return out + + def extra_repr(self) -> str: + """ + Return extra information for repr. + + Returns + ------- + str + String with transform parameters. + """ + return f"keys={self.keys}, shape={self.shape}" diff --git a/test/datapipes/core/test_transforms.py b/test/datapipes/core/test_transforms.py index 6048ecf5c9..d3dbcbecf6 100644 --- a/test/datapipes/core/test_transforms.py +++ b/test/datapipes/core/test_transforms.py @@ -807,6 +807,23 @@ def test_normalize_repr(): # assert "100" in repr(ds) +def test_normalize_accepts_ordered_dict_stats(): + """Test that Normalize accepts collections.abc.Mapping (e.g. OrderedDict) for stats.""" + from collections import OrderedDict + + sample = TensorDict({"x": torch.tensor([10.0, 20.0, 30.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="mean_std", + means=OrderedDict([("x", 20.0)]), + stds=OrderedDict([("x", 10.0)]), + ) + + result = norm(sample) + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + def test_compose_repr(): pipeline = dp.Compose( [ diff --git a/test/datapipes/readers/test_numpy_consolidated.py b/test/datapipes/readers/test_numpy_consolidated.py index c564ef5534..52358bdaf1 100644 --- a/test/datapipes/readers/test_numpy_consolidated.py +++ b/test/datapipes/readers/test_numpy_consolidated.py @@ -283,5 +283,232 @@ def test_close_handles(self): reader2.close() +class TestNumpyReaderPreload: + """Tests for preload_to_cpu functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_preload_basic(self): + """Test that preload_to_cpu loads data into RAM and closes the file.""" + coords = np.random.randn(15, 3).astype(np.float32) + features = np.random.randn(15, 4).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=True + ) + + assert reader._data is None + assert reader._preloaded is not None + assert "coords" in reader._preloaded + assert "features" in reader._preloaded + assert len(reader) == 15 + + data, metadata = reader[0] + assert data["coords"].shape == (3,) + assert data["features"].shape == (4,) + torch.testing.assert_close( + data["coords"], torch.from_numpy(coords[0]), atol=1e-6, rtol=1e-6 + ) + + def test_preload_matches_non_preloaded(self): + """Test that preloaded data matches non-preloaded data.""" + np.random.seed(42) + coords = np.random.randn(10, 3).astype(np.float32) + features = np.random.randn(10, 5).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader_disk = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=False + ) + reader_ram = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=True + ) + + for i in range(len(reader_disk)): + data_disk, _ = reader_disk[i] + data_ram, _ = reader_ram[i] + torch.testing.assert_close(data_disk["coords"], data_ram["coords"]) + torch.testing.assert_close(data_disk["features"], data_ram["features"]) + + reader_disk.close() + reader_ram.close() + + def test_preload_with_default_values(self): + """Test preload with default values for missing fields.""" + coords = np.random.randn(10, 100, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + default_normals = torch.ones(100, 3, dtype=torch.float64) + reader = NumpyReader( + npz_path, + fields=["coords", "normals"], + default_values={"normals": default_normals}, + preload_to_cpu=True, + ) + + data, _ = reader[0] + assert "normals" in data + assert data["normals"].dtype == torch.float32 + reader.close() + + def test_preload_missing_required_field_raises(self): + """Test that preload raises KeyError for missing required fields.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + with pytest.raises(KeyError, match="Required fields"): + NumpyReader( + npz_path, + fields=["coords", "missing_field"], + preload_to_cpu=True, + ) + + def test_preload_ignored_in_directory_mode(self): + """Test that preload_to_cpu is ignored in directory mode.""" + for i in range(3): + coords = np.random.randn(50, 3).astype(np.float32) + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords"], + preload_to_cpu=True, + ) + + assert reader._preloaded is None + assert len(reader) == 3 + + data, _ = reader[0] + assert data["coords"].shape == (50, 3) + reader.close() + + def test_preload_with_coordinated_subsampling(self): + """Test preloaded reader with coordinated subsampling.""" + n_samples = 5 + n_points = 1000 + subsample_points = 100 + + coords = np.random.randn(n_samples, n_points, 3).astype(np.float32) + features = np.random.randn(n_samples, n_points, 4).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader = NumpyReader( + npz_path, + fields=["coords", "features"], + preload_to_cpu=True, + coordinated_subsampling={ + "n_points": subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + assert reader._preloaded is not None + data, _ = reader[0] + assert data["coords"].shape == (subsample_points, 3) + assert data["features"].shape == (subsample_points, 4) + reader.close() + + def test_preload_close_releases_memory(self): + """Test that close() releases preloaded arrays.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"], preload_to_cpu=True) + assert reader._preloaded is not None + + reader.close() + assert reader._preloaded is None + + def test_preload_repr(self): + """Test that repr includes preload_to_cpu info.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"], preload_to_cpu=True) + assert "preload_to_cpu=True" in repr(reader) + reader.close() + + +class TestNumpyReaderFloat32: + """Tests for float32 conversion behavior.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_float64_converted_to_float32(self): + """Test that float64 numpy arrays are returned as float32 tensors.""" + coords = np.random.randn(10, 3).astype(np.float64) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"]) + data, _ = reader[0] + assert data["coords"].dtype == torch.float32 + reader.close() + + def test_float64_converted_to_float32_directory_mode(self): + """Test float64 conversion in directory mode.""" + for i in range(3): + coords = np.random.randn(50, 3).astype(np.float64) + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader( + self.temp_path, file_pattern="sample_*.npz", fields=["coords"] + ) + data, _ = reader[0] + assert data["coords"].dtype == torch.float32 + reader.close() + + def test_default_values_converted_to_float32(self): + """Test that default values are returned as float32.""" + coords = np.random.randn(10, 100, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + default_normals = torch.zeros(100, 3, dtype=torch.float64) + reader = NumpyReader( + npz_path, + fields=["coords", "normals"], + default_values={"normals": default_normals}, + ) + + data, _ = reader[0] + assert data["normals"].dtype == torch.float32 + reader.close() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/datapipes/transforms/test_spatial.py b/test/datapipes/transforms/test_spatial.py index f33d33d206..11ad4aa774 100644 --- a/test/datapipes/transforms/test_spatial.py +++ b/test/datapipes/transforms/test_spatial.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for spatial transforms (BoundingBoxFilter, CreateGrid, KNearestNeighbors, CenterOfMass).""" +"""Tests for spatial transforms (BoundingBoxFilter, CreateGrid, KNearestNeighbors, CenterOfMass, Resize).""" import pytest import torch @@ -25,6 +25,7 @@ CenterOfMass, CreateGrid, KNearestNeighbors, + Resize, ) # ============================================================================ @@ -604,6 +605,125 @@ def test_repr(self): assert "center_of_mass" in repr_str +# ============================================================================ +# Resize Tests +# ============================================================================ + + +class TestResize: + """Tests for Resize transform.""" + + def test_resize_2d_basic(self): + """Test basic 2D resizing of (C, H, W) tensor.""" + transform = Resize(input_keys=["pressure"], size=(32, 32)) + data = TensorDict({"pressure": torch.randn(1, 128, 128)}) + + result = transform(data) + assert result["pressure"].shape == (1, 32, 32) + + def test_resize_3d_basic(self): + """Test basic 3D resizing of (C, D, H, W) tensor.""" + transform = Resize(input_keys=["field"], size=(8, 16, 16)) + data = TensorDict({"field": torch.randn(3, 32, 64, 64)}) + + result = transform(data) + assert result["field"].shape == (3, 8, 16, 16) + + def test_resize_no_channel_dim(self): + """Test resizing an (H, W) tensor without channel dimension.""" + transform = Resize(input_keys=["image"], size=(16, 16)) + data = TensorDict({"image": torch.randn(64, 64)}) + + result = transform(data) + assert result["image"].shape == (16, 16) + + def test_resize_default_mode_2d(self): + """Test that default mode for 2D size is bilinear.""" + transform = Resize(input_keys=["x"], size=(32, 32)) + assert transform.mode == "bilinear" + + def test_resize_default_mode_3d(self): + """Test that default mode for 3D size is trilinear.""" + transform = Resize(input_keys=["x"], size=(8, 8, 8)) + assert transform.mode == "trilinear" + + def test_resize_nearest_mode(self): + """Test explicit nearest mode (no align_corners needed).""" + transform = Resize(input_keys=["x"], size=(16, 16), mode="nearest") + data = TensorDict({"x": torch.randn(2, 64, 64)}) + + result = transform(data) + assert result["x"].shape == (2, 16, 16) + + def test_resize_align_corners(self): + """Test align_corners=True with bilinear mode.""" + transform = Resize( + input_keys=["x"], size=(16, 16), mode="bilinear", align_corners=True + ) + data = TensorDict({"x": torch.randn(1, 64, 64)}) + + result = transform(data) + assert result["x"].shape == (1, 16, 16) + + def test_resize_missing_key_skipped(self): + """Test that missing keys are silently skipped.""" + transform = Resize(input_keys=["missing"], size=(16, 16)) + original = torch.randn(1, 64, 64) + data = TensorDict({"present": original.clone()}) + + result = transform(data) + assert "present" in result + torch.testing.assert_close(result["present"], original) + + def test_resize_non_float_skipped(self): + """Test that integer tensors are skipped.""" + transform = Resize(input_keys=["mask"], size=(16, 16)) + int_tensor = torch.randint(0, 2, (1, 64, 64)) + data = TensorDict({"mask": int_tensor}) + + result = transform(data) + assert result["mask"].shape == (1, 64, 64) + + def test_resize_invalid_size_dims(self): + """Test that invalid size dimensions raise ValueError.""" + with pytest.raises(ValueError, match="2 or 3 spatial dimensions"): + Resize(input_keys=["x"], size=(16,)) + + with pytest.raises(ValueError, match="2 or 3 spatial dimensions"): + Resize(input_keys=["x"], size=(4, 8, 16, 32)) + + def test_resize_preserves_other_fields(self): + """Test that non-input fields are untouched.""" + transform = Resize(input_keys=["field"], size=(16, 16)) + other = torch.randn(50, 3) + data = TensorDict({"field": torch.randn(1, 64, 64), "other": other.clone()}) + + result = transform(data) + assert result["field"].shape == (1, 16, 16) + torch.testing.assert_close(result["other"], other) + + def test_resize_multiple_keys(self): + """Test resizing multiple input keys.""" + transform = Resize(input_keys=["pressure", "velocity"], size=(32, 32)) + data = TensorDict( + { + "pressure": torch.randn(1, 128, 128), + "velocity": torch.randn(2, 128, 128), + } + ) + + result = transform(data) + assert result["pressure"].shape == (1, 32, 32) + assert result["velocity"].shape == (2, 32, 32) + + def test_resize_repr(self): + """Test string representation.""" + transform = Resize(input_keys=["pressure"], size=(64, 64), mode="bilinear") + repr_str = repr(transform) + assert "Resize" in repr_str + assert "bilinear" in repr_str + + # ============================================================================ # Integration Tests # ============================================================================ diff --git a/test/datapipes/transforms/test_utility.py b/test/datapipes/transforms/test_utility.py index f542ebdb96..91f7b8a342 100644 --- a/test/datapipes/transforms/test_utility.py +++ b/test/datapipes/transforms/test_utility.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for utility transforms: Rename, Purge, ConstantField, and ZeroLike.""" +"""Tests for utility transforms: Rename, Purge, ConstantField, Reshape, and ZeroLike.""" import pytest import torch from tensordict import TensorDict -from physicsnemo.datapipes.transforms import ConstantField, Purge, Rename +from physicsnemo.datapipes.transforms import ConstantField, Purge, Rename, Reshape class TestRename: @@ -809,3 +809,64 @@ def test_extra_repr(self): assert "output_key" in repr_str assert "fill_value" in repr_str assert "output_dim" in repr_str + + +class TestReshape: + """Tests for the Reshape transform.""" + + def test_reshape_basic(self): + """Test basic reshape from (1, H, W) to (H, W).""" + data = TensorDict({"y": torch.randn(1, 256, 256)}) + transform = Reshape(keys=["y"], shape=(256, 256)) + + result = transform(data) + assert result["y"].shape == torch.Size([256, 256]) + + def test_reshape_with_inferred_dim(self): + """Test reshape with -1 for inferred dimension.""" + data = TensorDict({"features": torch.randn(4, 8, 8)}) + transform = Reshape(keys=["features"], shape=(-1,)) + + result = transform(data) + assert result["features"].shape == torch.Size([256]) + + def test_reshape_missing_key_skipped(self): + """Test that missing keys are silently skipped.""" + data = TensorDict({"x": torch.randn(10, 3)}) + transform = Reshape(keys=["missing"], shape=(30,)) + + result = transform(data) + assert "x" in result + assert result["x"].shape == (10, 3) + + def test_reshape_preserves_other_fields(self): + """Test that non-target fields are untouched.""" + original = torch.randn(50, 3) + data = TensorDict({"target": torch.randn(1, 50, 3), "other": original.clone()}) + transform = Reshape(keys=["target"], shape=(50, 3)) + + result = transform(data) + assert result["target"].shape == torch.Size([50, 3]) + torch.testing.assert_close(result["other"], original) + + def test_reshape_multiple_keys(self): + """Test reshaping multiple keys.""" + data = TensorDict( + { + "a": torch.randn(1, 64, 64), + "b": torch.randn(1, 64, 64), + } + ) + transform = Reshape(keys=["a", "b"], shape=(64, 64)) + + result = transform(data) + assert result["a"].shape == torch.Size([64, 64]) + assert result["b"].shape == torch.Size([64, 64]) + + def test_reshape_extra_repr(self): + """Test extra_repr output.""" + transform = Reshape(keys=["y"], shape=(256, 256)) + repr_str = transform.extra_repr() + + assert "keys" in repr_str + assert "shape" in repr_str