diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 13a6193b25..57b5fb5958 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -153,6 +153,29 @@ the ``Dataset`` is responsible for the threaded execution of ``Reader``s and :members: :show-inheritance: +MultiDataset +^^^^^^^^^^^^ + +The ``MultiDataset`` composes two or more ``Dataset`` instances behind a single +index space (concatenation). Each sub-dataset can have its own Reader and +transforms. Global indices are mapped to the owning sub-dataset and local index; +metadata is enriched with ``dataset_index`` so batches can identify the source. +Use ``MultiDataset`` when you want to train on multiple datasets with the same +DataLoader, optionally enforcing that all outputs share the same TensorDict keys +for collation. See :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY` +for the metadata key added to each sample. + +Note that to properly collate and stack outputs from different datasets, you +can set ``output_strict=True`` in the constructor of a ``MultiDataset``. Upon +construction, it will load the first batch from every passed dataset and test +that the tensordict produced by the ``Reader`` and ``Transform`` pipeline has +consistent keys. Because the exact collation details differ by dataset, the +``MultiDataset`` does not check more aggressively than output key consistency. + +.. autoclass:: physicsnemo.datapipes.multi_dataset.MultiDataset + :members: + :show-inheritance: + Readers ^^^^^^^ diff --git a/examples/cfd/darcy-multidataset/.gitignore b/examples/cfd/darcy-multidataset/.gitignore new file mode 100644 index 0000000000..a1df66cf2e --- /dev/null +++ b/examples/cfd/darcy-multidataset/.gitignore @@ -0,0 +1,3 @@ +runs/ +outputs/ +output/ diff --git a/examples/cfd/darcy-multidataset/README.md b/examples/cfd/darcy-multidataset/README.md new file mode 100644 index 0000000000..0881290c11 --- /dev/null +++ b/examples/cfd/darcy-multidataset/README.md @@ -0,0 +1,5 @@ +# Darcy Flow with multiple datasets + +This readme is a work in progress and will be updated. + +Don't approve the PR until it's updated! diff --git a/examples/cfd/darcy-multidataset/benchmark_datasets.py b/examples/cfd/darcy-multidataset/benchmark_datasets.py new file mode 100644 index 0000000000..de6b5801d8 --- /dev/null +++ b/examples/cfd/darcy-multidataset/benchmark_datasets.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Load and benchmark each dataset (numpy and hdf5) separately. +# Usage: python benchmark_datasets.py data.numpy_path=/path/to/npz data.hdf5_path=/path/to/h5 + +import time +from collections import defaultdict + +import torch +import hydra +from omegaconf import DictConfig, OmegaConf + +from physicsnemo import datapipes + + +def _bench_dataset(name: str, dataset, n_iters: int = 2, n_samples=None) -> None: + """Run n_iters passes over the dataset (or first n_samples per pass) and report throughput.""" + n = len(dataset) + if n == 0: + print(f" {name}: empty, skip") + return + count = n if n_samples is None else min(n_samples, n) + + # Warmup + for i in range(min(3, count)): + data_dict, meta = dataset[i] + + for key, val in data_dict.items(): + print(f" Key {key} has shape {val.shape}") + + # Accumulate per-key running stats over the full dataset + sums = defaultdict(lambda: 0.0) + sq_sums = defaultdict(lambda: 0.0) + counts = defaultdict(lambda: 0) + + start = time.perf_counter() + for _ in range(n_iters): + for i in range(count): + data_dict, meta = dataset[i] + for key, val in data_dict.items(): + val_f = val.float() + sums[key] += val_f.sum().item() + sq_sums[key] += (val_f**2).sum().item() + counts[key] += val_f.numel() + elapsed = time.perf_counter() - start + + total = n_iters * count + rate = total / elapsed if elapsed > 0 else 0 + print( + f" {name}: {total} loads in {elapsed:.3f}s -> {rate:.1f} samples/s (len={n})" + ) + + for key in sums: + mean = sums[key] / counts[key] + std = ((sq_sums[key] / counts[key]) - mean**2) ** 0.5 + print( + f" {name}/{key}: mean={mean:.6g}, std={std:.6g} (over {counts[key]} elements)" + ) + + +def _path_ok(path) -> bool: + """True if path looks set (not OmegaConf missing placeholder).""" + if path is None: + return False + s = str(path).strip() + return s != "" and s != "???" + + +@hydra.main(version_base=None, config_path="./conf", config_name="config") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + print("Config (full):") + print(OmegaConf.to_yaml(cfg)) + print() + n_iters = 1 + n_samples = getattr(cfg, "bench_n_samples", 300) # optional: cap samples per pass + + print("Benchmarking individual datasets:\n") + + for i, ds_cfg in enumerate(cfg.multi_dataset.datasets): + print(f"Benchmark Dataset {i}") + ds = hydra.utils.instantiate(ds_cfg) + _bench_dataset("name", ds, n_iters=n_iters, n_samples=n_samples) + ds.close() + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/darcy-multidataset/conf/config.yaml b/examples/cfd/darcy-multidataset/conf/config.yaml new file mode 100644 index 0000000000..61a3934a7f --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/config.yaml @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Darcy GeoTransolver multi-dataset example — main config. +# Compose dataloader (numpy + hdf5). Model and training configs added later. +# +# Run: python load_and_visualize_data.py (or train.py when implemented) +# Override paths: data.numpy_path=... data.pde_bench_darcy_flow_dir=... +# +# Dataset configs live in conf/datasets/*.yaml. PDEBench Darcy betas can be +# mixed by commenting in/out lines in multi_dataset.datasets below. + +defaults: + - _self_ + - dataloader/config + - model/transolver + # - model/geotransolver + - training/default + - datasets/numpy@darcy_fno + - datasets/hdf5_beta0.01@dataset_pdebench_0.01 + - datasets/hdf5_beta0.1@dataset_pdebench_0.1 + - datasets/hdf5_beta1.0@dataset_pdebench_1.0 + - datasets/hdf5_beta10.0@dataset_pdebench_10.0 + - datasets/hdf5_beta100.0@dataset_pdebench_100.0 + +target_size: [128, 128] + + +# Transolver spatial shape must match data target_size (e.g. [256, 256]) +model: + structured_shape: ${target_size} + +# Shared data paths (override from CLI) +data: + numpy_path: /lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz # path to .npz file or directory (Darcy: permeability, darcy) + # Directory containing 2D_DarcyFlow_beta*_Train.hdf5 (PDEBench). Used by datasets/hdf5_beta*.yaml. + pde_bench_darcy_flow_dir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/pde_bench/2D/DarcyFlow + + +# MultiDataset: enable/disable PDEBench Darcy betas by commenting in or out lines below. +multi_dataset: + _target_: physicsnemo.datapipes.MultiDataset + datasets: + - ${darcy_fno} + - ${dataset_pdebench_0.01} + - ${dataset_pdebench_0.1} + - ${dataset_pdebench_1.0} + - ${dataset_pdebench_10.0} + - ${dataset_pdebench_100.0} + output_strict: true diff --git a/examples/cfd/darcy-multidataset/conf/dataloader/config.yaml b/examples/cfd/darcy-multidataset/conf/dataloader/config.yaml new file mode 100644 index 0000000000..899664d09c --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/dataloader/config.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# DataLoader: references root-level multi_dataset only. +# Dataset definitions live in conf/datasets/*.yaml; multi_dataset composes them in config.yaml. + +_target_: physicsnemo.datapipes.DataLoader +batch_size: 36 +shuffle: true +drop_last: false +prefetch_factor: 2 +num_streams: 4 +use_streams: false +collate_metadata: true + +dataset: ${multi_dataset} diff --git a/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.01.yaml b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.01.yaml new file mode 100644 index 0000000000..7ac35b6a50 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.01.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# PDEBench Darcy HDF5, beta=0.01. Path: data.pde_bench_darcy_flow_dir. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.HDF5Reader + path: ${data.pde_bench_darcy_flow_dir}/2D_DarcyFlow_beta0.01_Train.hdf5 + fields: + - "nu" + - "tensor" + file_pattern: "*.h5" + index_key: null + pin_memory: false + include_index_in_metadata: true +transforms: + - _target_: ${dp:Rename} + mapping: + nu: x + tensor: y + - _target_: ${dp:Reshape} + keys: + - y + shape: ${target_size} + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 0.536271 + y: 0.0106918 + stds: + x: 0.449791 + y: 0.0353361 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.1.yaml b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.1.yaml new file mode 100644 index 0000000000..6f9902b8b5 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta0.1.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# PDEBench Darcy HDF5, beta=0.1. Path: data.pde_bench_darcy_flow_dir. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.HDF5Reader + path: ${data.pde_bench_darcy_flow_dir}/2D_DarcyFlow_beta0.1_Train.hdf5 + fields: + - "nu" + - "tensor" + file_pattern: "*.h5" + index_key: null + pin_memory: false + include_index_in_metadata: true +transforms: + - _target_: ${dp:Rename} + mapping: + nu: x + tensor: y + - _target_: ${dp:Reshape} + keys: + - y + shape: ${target_size} + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 0.536271 + y: 0.0244916 + stds: + x: 0.449791 + y: 0.0427034 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta1.0.yaml b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta1.0.yaml new file mode 100644 index 0000000000..648df31cc9 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta1.0.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# PDEBench Darcy HDF5, beta=1.0. Path: data.pde_bench_darcy_flow_dir. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.HDF5Reader + path: ${data.pde_bench_darcy_flow_dir}/2D_DarcyFlow_beta1.0_Train.hdf5 + fields: + - "nu" + - "tensor" + file_pattern: "*.h5" + index_key: null + pin_memory: false + include_index_in_metadata: true +transforms: + - _target_: ${dp:Rename} + mapping: + nu: x + tensor: y + - _target_: ${dp:Reshape} + keys: + - y + shape: ${target_size} + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 0.536271 + y: 0.162347 + stds: + x: 0.449791 + y: 0.139494 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta10.0.yaml b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta10.0.yaml new file mode 100644 index 0000000000..e7055db854 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta10.0.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# PDEBench Darcy HDF5, beta=10.0. Path: data.pde_bench_darcy_flow_dir. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.HDF5Reader + path: ${data.pde_bench_darcy_flow_dir}/2D_DarcyFlow_beta10.0_Train.hdf5 + fields: + - "nu" + - "tensor" + file_pattern: "*.h5" + index_key: null + pin_memory: false + include_index_in_metadata: true +transforms: + - _target_: ${dp:Rename} + mapping: + nu: x + tensor: y + - _target_: ${dp:Reshape} + keys: + - y + shape: ${target_size} + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 0.536271 + y: 1.5409 + stds: + x: 0.449791 + y: 1.17804 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta100.0.yaml b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta100.0.yaml new file mode 100644 index 0000000000..e89fdd5f5f --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/hdf5_beta100.0.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# PDEBench Darcy HDF5, beta=100.0. Path: data.pde_bench_darcy_flow_dir. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.HDF5Reader + path: ${data.pde_bench_darcy_flow_dir}/2D_DarcyFlow_beta100.0_Train.hdf5 + fields: + - "nu" + - "tensor" + file_pattern: "*.h5" + index_key: null + pin_memory: false + include_index_in_metadata: true +transforms: + - _target_: ${dp:Rename} + mapping: + nu: x + tensor: y + - _target_: ${dp:Reshape} + keys: + - y + shape: ${target_size} + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 0.536271 + y: 15.3158 + stds: + x: 0.449791 + y: 11.5822 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/datasets/numpy.yaml b/examples/cfd/darcy-multidataset/conf/datasets/numpy.yaml new file mode 100644 index 0000000000..a4cf2956b0 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/datasets/numpy.yaml @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Darcy NPZ dataset: permeability (coeff) -> pressure (sol). +# Path comes from config.yaml data.numpy_path. + +_target_: physicsnemo.datapipes.Dataset +reader: + _target_: physicsnemo.datapipes.readers.NumpyReader + path: ${data.numpy_path} + fields: + - "coeff" + - "sol" + file_pattern: "*.npz" + index_key: null + pin_memory: true + include_index_in_metadata: true + preload_to_cpu: true +transforms: + - _target_: ${dp:Resize} + input_keys: + - "coeff" + - "sol" + size: ${target_size} + mode: bilinear + - _target_: ${dp:Rename} + mapping: + coeff: x + sol: y + - _target_: ${dp:Normalize} + input_keys: + - "x" + - "y" + method: mean_std + means: + x: 7.5348 + y: 0.00567372 + stds: + x: 4.49987 + y: 0.00377293 +device: auto +num_workers: 2 diff --git a/examples/cfd/darcy-multidataset/conf/model/geotransolver.yaml b/examples/cfd/darcy-multidataset/conf/model/geotransolver.yaml new file mode 100644 index 0000000000..8ff6309e7d --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/model/geotransolver.yaml @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# GeoTransolver model config for 2D Darcy (permeability -> pressure). +# Geometry: normalized 2D grid coordinates (x, y) per token. +# structured_shape: set to [H, W] to match the training grid so GALE uses Conv2d +# slice projection (same idea as Transolver). Omit for flattened mesh-style runs. + +# Coming soon! + +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + +functional_dim: 2 +out_dim: 1 +geometry_dim: 1 # 2D grid coords (x, y) as geometry +global_dim: null +# Must match root config target_size / grid (same as Transolver structured_shape) +structured_shape: ${target_size} + +n_layers: 6 +n_hidden: 128 +dropout: 0.0 +n_head: 8 +act: gelu +mlp_ratio: 4 +slice_num: 32 +use_te: false +time_input: false +plus: false +include_local_features: false diff --git a/examples/cfd/darcy-multidataset/conf/model/transolver.yaml b/examples/cfd/darcy-multidataset/conf/model/transolver.yaml new file mode 100644 index 0000000000..b327f56a65 --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/model/transolver.yaml @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Transolver model config for 2D Darcy (permeability -> pressure). +# structured_shape is set in config.yaml from target_size so it matches data spatial shape. + +_target_: physicsnemo.models.transolver.Transolver + +functional_dim: 1 +out_dim: 1 +embedding_dim: 2 +# structured_shape set in config.yaml as ${target_size} to match data +unified_pos: true +ref: 8 + +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +use_te: false +time_input: false +plus: false + + diff --git a/examples/cfd/darcy-multidataset/conf/training/default.yaml b/examples/cfd/darcy-multidataset/conf/training/default.yaml new file mode 100644 index 0000000000..a4b2e9858d --- /dev/null +++ b/examples/cfd/darcy-multidataset/conf/training/default.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Training config for Darcy Transolver multi-dataset. +# Optimizer and scheduler are built in train.py (need model.parameters()). +# Muon (pytorch>=2.9) is used for 2D weight params; Adam for the rest when use_muon=true. + +batch_size: 16 +max_epochs: 100 +validation_every_epochs: 1 +checkpoint_every_epochs: 1 +# Train/val split via sampler (same dataset, different indices) +val_fraction: 0.2 +split_seed: 42 + +optimizer: + name: Adam + lr: 3.0e-3 + weight_decay: 1.0e-4 +# Muon for matrix (2D) parameters; requires pytorch>=2.9 +use_muon: true + +scheduler: + name: CosineAnnealingLR + T_max: 100 + +# torch.compile; set compile=true to enable (modes: default | reduce-overhead | max-autotune) +compile: true +compile_mode: default diff --git a/examples/cfd/darcy-multidataset/load_and_visualize_data.py b/examples/cfd/darcy-multidataset/load_and_visualize_data.py new file mode 100644 index 0000000000..70713ae3d1 --- /dev/null +++ b/examples/cfd/darcy-multidataset/load_and_visualize_data.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Load datapipes from Hydra config, iterate batches, and visualize input (x) and output (y). +# Usage: python load_and_visualize_data.py data.numpy_path=/path/to/npz data.hdf5_path=/path/to/h5 + +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf + +from physicsnemo import datapipes + + +def _squeeze_2d(tensor): + """Return (H, W) for (H,W), (1,H,W), (C,H,W) with C==1, or a single sample from (B,C,H,W).""" + import torch + + t = tensor + if t.dim() == 2: + return t + if t.dim() == 3 and t.shape[0] == 1: + return t[0] + if t.dim() == 3: + return t[0] + if t.dim() == 4: + return t[0].squeeze(0) if t.shape[1] == 1 else t[0] + return t.squeeze() + + +@hydra.main(version_base=None, config_path="./conf", config_name="config") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + dataloader = hydra.utils.instantiate(cfg.dataloader) + + out_dir = Path(HydraConfig.get().runtime.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # Input/output keys after config transforms (Rename: coeff/nu -> x, sol/tensor -> y) + in_key = "x" + out_key = "y" + + n_batches_show = min(3, len(dataloader)) + for batch_idx, batch_out in enumerate(dataloader): + if batch_idx >= n_batches_show: + break + + if isinstance(batch_out, tuple): + data, meta_list = batch_out + else: + data = batch_out + batch_size = next( + (data[k].shape[0] for k in data.keys() if data[k].dim() >= 2), + 0, + ) + meta_list = [{}] * batch_size + + if in_key not in data or out_key not in data: + keys = [k for k in data.keys() if data[k].dim() >= 2] + in_key_use = keys[0] if len(keys) > 0 else None + out_key_use = keys[1] if len(keys) > 1 else keys[0] + else: + in_key_use = in_key + out_key_use = out_key + + if in_key_use is None: + continue + + B = data[in_key_use].shape[0] + if B == 0: + continue + fig, axes = plt.subplots(B, 2, figsize=(6, 3 * B)) + if B == 1: + axes = axes.reshape(1, -1) + for b in range(B): + in_grid = _squeeze_2d(data[in_key_use][b]) + out_grid = _squeeze_2d(data[out_key_use][b]) if out_key_use else None + if hasattr(in_grid, "numpy"): + in_grid = in_grid.detach().cpu().numpy() + meta = meta_list[b] if b < len(meta_list) else {} + ds_idx = meta.get("dataset_index", -1) + + ax_in = axes[b, 0] + ax_in.imshow(in_grid) + ax_in.set_title( + f"Input ({in_key_use}) [b={b}, batch={batch_idx}, ds={ds_idx}]" + ) + ax_in.set_axis_off() + + ax_out = axes[b, 1] + if out_grid is not None: + if hasattr(out_grid, "numpy"): + out_grid = out_grid.detach().cpu().numpy() + ax_out.imshow(out_grid) + ax_out.set_title(f"Output ({out_key_use}) [b={b}]") + ax_out.set_axis_off() + + fig.tight_layout() + fig.savefig(out_dir / f"batch_{batch_idx:02d}.png", dpi=100) + plt.close(fig) + + print(f"Saved {n_batches_show} batch figures to {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/darcy-multidataset/train.py b/examples/cfd/darcy-multidataset/train.py new file mode 100644 index 0000000000..980db8eece --- /dev/null +++ b/examples/cfd/darcy-multidataset/train.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Training entrypoint for Darcy Transolver multi-dataset. +# Uses MultiDataset dataloader and Transolver with spatial shape matching +# target_size for permeability -> pressure. Muon used for 2D params when available. + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import random_split, SubsetRandomSampler +from torch.utils.tensorboard import SummaryWriter + +# This is needed for datapipe registry via hydra instantiation +from physicsnemo import datapipes +from physicsnemo.datapipes import DataLoader + +from physicsnemo.distributed import DistributedManager +from physicsnemo.optim import CombinedOptimizer +from physicsnemo.utils import load_checkpoint, save_checkpoint +from physicsnemo.utils.logging import PythonLogger, LaunchLogger + + +class RelativeL2Loss: + """Scale-invariant relative L2 loss: mean( ||pred - y||_2 / ||y||_2 ).""" + + def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + B = pred.shape[0] + diff = torch.norm(pred.reshape(B, -1) - target.reshape(B, -1), dim=1) + ref = torch.norm(target.reshape(B, -1), dim=1) + return torch.mean(diff / ref) + + +def make_spatial_positions( + h: int, + w: int, + *, + batch_size: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + r"""Build a fixed 2D grid of normalized positions in :math:`[0,1]^2`, tiled over batch. + + Call once per run (same ``h, w`` as ``target_size`` / model ``structured_shape``). + Use :meth:`torch.Tensor.expand` in the training loop if minibatch size can differ + (e.g. last batch when ``drop_last`` is false). + + Parameters + ---------- + h, w : int + Grid height and width (row, col count), matching data spatial shape. + batch_size : int + Leading batch dimension (e.g. dataloader ``batch_size``). + device : torch.device + Where to place the tensor (e.g. training device). + dtype : torch.dtype + Floating dtype (e.g. ``torch.get_default_dtype()``). + + Returns + ------- + torch.Tensor + Shape :math:`(B, H, W, 2)` with last dimension ``(x, y)`` in index order + ``ij`` (``y`` increases with row, ``x`` with column). + """ + yy, xx = torch.meshgrid( + torch.linspace(0, 1, h, device=device, dtype=dtype), + torch.linspace(0, 1, w, device=device, dtype=dtype), + indexing="ij", + ) + # (H, W, 2) then (1, H, W, 2) -> expand to (B, H, W, 2) without extra storage + grid_hw2 = torch.stack([xx, yy], dim=-1) + return grid_hw2.unsqueeze(0).expand(batch_size, -1, -1, -1) + + +def _normalize_for_image(t: torch.Tensor) -> torch.Tensor: + """Min-max normalize a tensor to [0, 1] for TensorBoard image logging.""" + t_min, t_max = t.min(), t.max() + if t_max - t_min < 1e-8: + return torch.zeros_like(t) + return (t - t_min) / (t_max - t_min) + + +def _log_sample_images( + writer: SummaryWriter, + tag_prefix: str, + x: torch.Tensor, + y: torch.Tensor, + pred: torch.Tensor, + epoch: int, +) -> None: + """Log the first sample's x, y, and pred as grayscale images.""" + # x, y, pred arrive as (B, H, W, 1); take first sample -> (H, W, 1) -> (1, H, W) + for name, tensor in [("x", x), ("y_true", y), ("y_pred", pred)]: + img = _normalize_for_image(tensor[0].detach().squeeze(-1)) # (H, W) + writer.add_image(f"{tag_prefix}/{name}", img.unsqueeze(0), epoch) + + +@hydra.main(version_base=None, config_path="./conf", config_name="config") +def main(cfg: DictConfig) -> None: + OmegaConf.resolve(cfg) + DistributedManager.initialize() + dist = DistributedManager() + + log = PythonLogger(name="darcy_transolver_multidataset") + log.file_logging() + + # Multi-dataset Data Loader Instantiated from hydra: + dataloader = hydra.utils.instantiate(cfg.dataloader) + dataset = dataloader.dataset + n = len(dataset) + val_fraction = getattr(cfg.training, "val_fraction", 0.2) + split_seed = getattr(cfg.training, "split_seed", 42) + n_val = max(1, int(n * val_fraction)) + n_train = n - n_val + train_subset, val_subset = random_split( + dataset, [n_train, n_val], generator=torch.Generator().manual_seed(split_seed) + ) + train_sampler = SubsetRandomSampler(train_subset.indices) + val_sampler = SubsetRandomSampler(val_subset.indices) + loader_kw = { + "batch_size": dataloader.batch_size, + "collate_fn": dataloader.collate_fn, + "prefetch_factor": dataloader.prefetch_factor, + "num_streams": dataloader.num_streams, + "use_streams": dataloader.use_streams, + } + train_dataloader = DataLoader( + dataset, + shuffle=False, + sampler=train_sampler, + drop_last=True, + **loader_kw, + ) + val_dataloader = DataLoader( + dataset, + shuffle=False, + sampler=val_sampler, + drop_last=True, + **loader_kw, + ) + log.info( + f"Train/val split: {n_train} train, {n_val} val (val_fraction={val_fraction}, seed={split_seed})" + ) + + h, w = int(cfg.target_size[0]), int(cfg.target_size[1]) + batch_size = int(cfg.dataloader.batch_size) + spatial_positions = make_spatial_positions( + h, w, batch_size=batch_size, device=dist.device, dtype=torch.get_default_dtype() + ) + + # spatial_positions: (B, H, W, 2). Expand/slice per step if batch size varies. + log.info( + f"Spatial positions grid shape {tuple(spatial_positions.shape)} on {dist.device}" + ) + + # Model (structured_shape from config matches target_size) + model_cfg = OmegaConf.to_container(cfg.model, resolve=True) + model = hydra.utils.instantiate(model_cfg, _convert_="all").to(dist.device) + + def _compute_metrics(pred, y): + loss = loss_fn(pred, y) + with torch.no_grad(): + l2_err_sq = (pred - y).pow(2).sum() + l2_ref_sq = y.pow(2).sum() + return {"loss": loss, "l2_err_sq": l2_err_sq, "l2_ref_sq": l2_ref_sq} + + # Resolve forward call from config so the training loop never branches on + # model type (isinstance); required for torch.compile. + _model_target = OmegaConf.select(cfg, "model._target_", default="") + model_name = _model_target.rsplit(".", 1)[-1] if _model_target else "unknown" + if "geotransolver" in _model_target.lower(): + # Coming soon! + def _forward(model, batch, positions): + x = batch["x"].unsqueeze(-1) + y = batch["y"].unsqueeze(-1) + pred = model(local_embedding=positions, geometry=x) + return _compute_metrics(pred, y), pred, x, y + elif "transolver" in _model_target.lower(): + + def _forward(model, batch, positions): + x = batch["x"].unsqueeze(-1) + y = batch["y"].unsqueeze(-1) + pred = model(fx=x, embedding=positions) + return _compute_metrics(pred, y), pred, x, y + else: + raise ValueError( + f"Unsupported model _target_ {_model_target!r}; " + "expected a class path containing 'Transolver' or 'GeoTransolver'." + ) + + tb_log_dir = f"runs/{model_name}_{h}x{w}" + tb_tag = f"{model_name}_{h}x{w}" + train_writer = SummaryWriter(log_dir=tb_log_dir + f"/{tb_tag}/train/") + val_writer = SummaryWriter(log_dir=tb_log_dir + f"/{tb_tag}/val/") + log.info(f"TensorBoard logging to {tb_log_dir}") + + if getattr(cfg.training, "compile", False): + compile_mode = getattr(cfg.training, "compile_mode", "default") + model = torch.compile(model, mode=compile_mode) + log.info(f"Model compiled with mode={compile_mode}.") + + # Optimizer and scheduler + opt_cfg = cfg.training.optimizer + use_muon = getattr(cfg.training, "use_muon", False) + if use_muon: + muon_params = [p for p in model.parameters() if p.ndim == 2] + other_params = [p for p in model.parameters() if p.ndim != 2] + weight_decay = getattr(opt_cfg, "weight_decay", 0.0) + optimizer = CombinedOptimizer( + optimizers=[ + torch.optim.Muon( + muon_params, + lr=opt_cfg.lr, + weight_decay=weight_decay, + ), + torch.optim.Adam( + other_params, + lr=opt_cfg.lr, + weight_decay=weight_decay, + ), + ], + ) + log.info("Using Muon for 2D params, Adam for rest.") + elif opt_cfg.name == "Adam": + weight_decay = getattr(opt_cfg, "weight_decay", 0.0) + optimizer = torch.optim.Adam( + model.parameters(), lr=opt_cfg.lr, weight_decay=weight_decay + ) + else: + raise ValueError(f"Unsupported optimizer: {opt_cfg.name}") + + sch_cfg = cfg.training.scheduler + if sch_cfg.name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=cfg.training.max_epochs + ) + else: + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda _: 1.0 + ) + + loss_fn = RelativeL2Loss() + + ckpt_args = { + "path": f"./checkpoints/{model_name}", + "optimizer": optimizer, + "scheduler": scheduler, + "models": [model], + } + loaded_epoch = load_checkpoint(device=dist.device, **ckpt_args) + start_epoch = max(1, loaded_epoch + 1) if loaded_epoch else 1 + + n_batches = len(train_dataloader) + val_every = cfg.training.validation_every_epochs + ckpt_every = cfg.training.checkpoint_every_epochs + + if start_epoch == 1: + log.success("Training started...") + else: + log.warning(f"Resuming from epoch {start_epoch}.") + + for epoch in range(start_epoch, cfg.training.max_epochs + 1): + model.train() + _zero = torch.tensor(0.0, device=dist.device) + train_loss_sum = _zero.clone() + train_l2_err_sq = _zero.clone() + train_l2_ref_sq = _zero.clone() + train_n = 0 + train_sample = None + with LaunchLogger( + "train", num_mini_batch=n_batches, epoch_alert_freq=1, epoch=epoch + ) as logger: + for batch, meta in train_dataloader: + metrics, pred, x, y = _forward(model, batch, spatial_positions) + + optimizer.zero_grad() + metrics["loss"].backward() + optimizer.step() + + b = x.shape[0] + train_loss_sum += metrics["loss"].detach() * b + train_l2_err_sq += metrics["l2_err_sq"] + train_l2_ref_sq += metrics["l2_ref_sq"] + train_n += b + train_sample = (x, y, pred) + + logger.log_minibatch({"loss": metrics["loss"].detach()}) + + logger.log_epoch({"lr": optimizer.param_groups[0]["lr"]}) + scheduler.step() + + if train_n > 0: + avg_train_loss = (train_loss_sum / train_n).item() + train_rel_l2 = (train_l2_err_sq.sqrt() / train_l2_ref_sq.sqrt()).item() + train_writer.add_scalar(f"train/loss", avg_train_loss, epoch) + train_writer.add_scalar(f"train/rel_l2", train_rel_l2, epoch) + if train_sample is not None: + _log_sample_images(train_writer, f"train", *train_sample, epoch) + + if epoch % val_every == 0: + model.eval() + val_loss_sum = _zero.clone() + val_l2_err_sq = _zero.clone() + val_l2_ref_sq = _zero.clone() + val_n = 0 + val_sample = None + with torch.no_grad(): + for batch, meta in val_dataloader: + metrics, pred, x, y = _forward(model, batch, spatial_positions) + b = x.shape[0] + val_loss_sum += metrics["loss"] * b + val_l2_err_sq += metrics["l2_err_sq"] + val_l2_ref_sq += metrics["l2_ref_sq"] + val_n += b + val_sample = (x, y, pred) + if val_n > 0: + avg_val_loss = (val_loss_sum / val_n).item() + val_rel_l2 = (val_l2_err_sq.sqrt() / val_l2_ref_sq.sqrt()).item() + val_writer.add_scalar(f"val/loss", avg_val_loss, epoch) + val_writer.add_scalar(f"val/rel_l2", val_rel_l2, epoch) + log.info( + f"Epoch {epoch} val_loss={avg_val_loss:.6f} val_rel_l2={val_rel_l2:.6f}" + ) + if val_sample is not None: + _log_sample_images(val_writer, f"val", *val_sample, epoch) + model.train() + + if epoch % ckpt_every == 0: + save_checkpoint(**ckpt_args, epoch=epoch) + + train_writer.close() + val_writer.close() + save_checkpoint(**ckpt_args, epoch=cfg.training.max_epochs) + log.success("Training completed.") + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index e693a61df6..3d762c054a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -40,6 +40,7 @@ ) from physicsnemo.datapipes.dataloader import DataLoader from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.multi_dataset import MultiDataset from physicsnemo.datapipes.readers import ( HDF5Reader, NumpyReader, @@ -70,6 +71,7 @@ NormalizeVectors, Purge, Rename, + Resize, Scale, SubsamplePoints, Transform, @@ -84,6 +86,7 @@ "TensorDict", # Re-export from tensordict "Dataset", "DataLoader", + "MultiDataset", # Transforms - Base "Transform", "Compose", @@ -107,6 +110,7 @@ "CreateGrid", "KNearestNeighbors", "CenterOfMass", + "Resize", # Transforms - Utility "Rename", "Purge", diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index c6f465d425..cf3b160cd2 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -168,7 +168,9 @@ def __len__(self) -> int: int Number of batches in the dataloader. """ - n_samples = len(self.dataset) + n_samples = ( + len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset) + ) if self.drop_last: return n_samples // self.batch_size return (n_samples + self.batch_size - 1) // self.batch_size diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py new file mode 100644 index 0000000000..8b6152a10f --- /dev/null +++ b/physicsnemo/datapipes/multi_dataset.py @@ -0,0 +1,378 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MultiDataset - Compose multiple Dataset instances behind a single dataset-like interface. + +MultiDataset presents a single index space (concatenation of all constituent datasets) +and delegates __getitem__, prefetch, and close to the appropriate sub-dataset. +Each sub-dataset can have its own Reader and transforms. Optional output strictness +validates that all sub-datasets produce the same TensorDict keys (outputs) so default +collation works. +""" + +from __future__ import annotations + +from typing import Any, Iterator, Optional, Sequence + +from tensordict import TensorDict + +from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.registry import register + +# Metadata key added by MultiDataset to identify which sub-dataset produced the sample. +DATASET_INDEX_METADATA_KEY = "dataset_index" + + +def _validate_strict_outputs(datasets: Sequence[Dataset]) -> list[str]: + """ + Check that all non-empty datasets produce the same TensorDict keys; return them. + + Loads one sample from each non-empty dataset and compares output keys (after + transforms). This validates output schema, not reader field_names. + + Parameters + ---------- + datasets : Sequence[Dataset] + Datasets to validate. + + Returns + ------- + list[str] + Common output keys (sorted, from first non-empty dataset). + + Raises + ------ + ValueError + If any non-empty dataset has different output keys. + """ + if not datasets: + return [] + ref_keys: Optional[list[str]] = None + ref_index: Optional[int] = None + for i, ds in enumerate(datasets): + if len(ds) == 0: + continue + data, _ = ds[0] + keys = sorted(data.keys()) + if ref_keys is None: + ref_keys = keys + ref_index = i + elif keys != ref_keys: + raise ValueError( + "output_strict=True requires identical output keys (TensorDict keys) " + f"across datasets: dataset {ref_index} has {ref_keys}, dataset {i} has {keys}" + ) + return list(ref_keys) if ref_keys is not None else list(datasets[0].field_names) + + +@register() +class MultiDataset: + r""" + A dataset that composes multiple :class:`Dataset` instances behind one index space. + + Global indices are mapped to (dataset_index, local_index) by concatenation: + indices 0..len0-1 come from the first dataset, len0..len0+len1-1 from the second, + and so on. Each constituent can have its own Reader and transforms. Metadata + is enriched with ``dataset_index`` so batches can identify the source. + + Parameters + ---------- + datasets : Sequence[Dataset] + One or more Dataset instances (Reader + transforms each). Order defines + index mapping: first dataset occupies 0..len(ds0)-1, etc. + output_strict : bool, default=True + If True, require all datasets to produce the same TensorDict keys (output + keys after transforms) so :class:`DefaultCollator` can stack batches. If + False, no check is done; use a custom collator when keys or shapes differ. + Note that `output_strict=True` will load the first instance of all datasets + upon construction. Think of it as a debugging parameter: if you are sure + that your datasets are working properly, and want to defer loading, + you can safely disable this. + + Raises + ------ + ValueError + If ``len(datasets) < 1`` or if ``output_strict=True`` and output keys differ. + + Notes + ----- + MultiDataset implements the same interface as :class:`Dataset` (``__len__``, + ``__getitem__``, ``prefetch``, ``prefetch_batch``, ``prefetch_count``, + ``cancel_prefetch``, ``close``, ``field_names``) and can be passed to + :class:`DataLoader` in place of a single dataset. Prefetch and close are + delegated to the sub-dataset that owns the index. When ``output_strict=True``, + validation checks that each dataset's *output* TensorDict (after transforms) + has the same keys, not the reader's field_names. When ``output_strict=False``, + :attr:`field_names` returns the first dataset's field names; with heterogeneous + datasets, prefer a custom collator and use metadata ``dataset_index`` to + group or pad by source. + + Shuffling and sampling + --------------------- + The DataLoader sees one linear index space of size :math:`\\sum_k \\text{len}(\\text{datasets}[k])`. + With ``shuffle=True``, the default :class:`RandomSampler` shuffles these global + indices, so each batch is a random mix of samples from all sub-datasets. There + is no per-dataset balancing: if one dataset is much larger, its samples will + appear more often. For balanced or stratified sampling, use a custom + :class:`torch.utils.data.Sampler` (e.g. weighted or one sample per dataset per + batch) and pass it to the DataLoader. + + Metadata + -------- + Every sample returned by :meth:`__getitem__` has its metadata dict extended + with the key :const:`DATASET_INDEX_METADATA_KEY` (``"dataset_index"``), the + integer index of the sub-dataset that produced the sample (0 for the first + dataset, 1 for the second, etc.). Sub-dataset–specific metadata (e.g. file + path, index within that dataset) is unchanged. When using the DataLoader with + ``collate_metadata=True``, each batch yields a list of metadata dicts aligned + with the batch dimension; each dict includes ``dataset_index`` so you can + filter, weight, or aggregate by source in the training loop. + + Examples + -------- + >>> from physicsnemo.datapipes import Dataset, MultiDataset, HDF5Reader, Normalize + >>> ds_a = Dataset(HDF5Reader("a.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP + >>> ds_b = Dataset(HDF5Reader("b.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP + >>> multi = MultiDataset([ds_a, ds_b], output_strict=True) # doctest: +SKIP + >>> len(multi) == len(ds_a) + len(ds_b) # doctest: +SKIP + True + >>> data, meta = multi[0] # from ds_a # doctest: +SKIP + >>> meta["dataset_index"] # 0 # doctest: +SKIP + """ + + def __init__( + self, + datasets: Sequence[Dataset], + *, + output_strict: bool = True, + ) -> None: + if len(datasets) < 1: + raise ValueError( + f"MultiDataset requires at least one dataset, got {len(datasets)}" + ) + for i, ds in enumerate(datasets): + if not isinstance(ds, Dataset): + raise TypeError( + f"datasets[{i}] must be a Dataset instance, got {type(ds).__name__}" + ) + + self._datasets = list(datasets) + self._output_strict = output_strict + + # Cumulative lengths: cumul[k] = sum(len(datasets[j]) for j in range(k)) + # So index i is in dataset k when cumul[k] <= i < cumul[k+1], local = i - cumul[k] + cumul = [0] + for ds in self._datasets: + cumul.append(cumul[-1] + len(ds)) + self._cumul = cumul + + if output_strict: + self._field_names = _validate_strict_outputs(self._datasets) + else: + self._field_names = list(self._datasets[0].field_names) + + def _index_to_dataset_and_local(self, index: int) -> tuple[int, int]: + """ + Map global index to (dataset_index, local_index). + + Parameters + ---------- + index : int + Global index in [0, len(self)). + + Returns + ------- + tuple[int, int] + (dataset_index, local_index). + + Raises + ------ + IndexError + If index is out of range. + """ + n = len(self) + if index < 0: + index = n + index + if index < 0 or index >= n: + raise IndexError( + f"Index {index} out of range for MultiDataset with {n} samples" + ) + # Find k such that cumul[k] <= index < cumul[k+1] + for k in range(len(self._cumul) - 1): + if self._cumul[k] <= index < self._cumul[k + 1]: + return k, index - self._cumul[k] + # Fallback (should not be reached) + return len(self._datasets) - 1, index - self._cumul[-2] + + def _index_to_dataset_and_local_optional( + self, index: int + ) -> Optional[tuple[int, int]]: + """ + Map global index to (dataset_index, local_index), or None if out of range. + + Used by cancel_prefetch to match Dataset behavior (no-op for invalid index). + """ + n = len(self) + if index < 0: + index = n + index + if index < 0 or index >= n: + return None + for k in range(len(self._cumul) - 1): + if self._cumul[k] <= index < self._cumul[k + 1]: + return k, index - self._cumul[k] + return len(self._datasets) - 1, index - self._cumul[-2] + + def __len__(self) -> int: + """Return the total number of samples (sum of all sub-dataset lengths).""" + return self._cumul[-1] + + def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: + """ + Return the transformed sample and metadata for the given global index. + + Metadata is enriched with ``dataset_index`` (key :const:`DATASET_INDEX_METADATA_KEY`). + + Parameters + ---------- + index : int + Global sample index. Supports negative indexing. + + Returns + ------- + tuple[TensorDict, dict[str, Any]] + (TensorDict, metadata dict) from the owning sub-dataset. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + data, metadata = self._datasets[ds_id][local_i] + metadata = dict(metadata) + metadata[DATASET_INDEX_METADATA_KEY] = ds_id + return data, metadata + + def prefetch( + self, + index: int, + stream: Optional[Any] = None, + ) -> None: + """ + Start prefetching the sample at the given global index. + + Delegates to the sub-dataset that owns that index. + + Parameters + ---------- + index : int + Global sample index to prefetch. + stream : object, optional + Optional CUDA stream for the sub-dataset prefetch. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + self._datasets[ds_id].prefetch(local_i, stream=stream) + + def prefetch_batch( + self, + indices: Sequence[int], + streams: Optional[Sequence[Any]] = None, + ) -> None: + """ + Start prefetching multiple samples by global index. + + Delegates to the sub-dataset that owns each index. Streams are cycled + if shorter than indices. + + Parameters + ---------- + indices : Sequence[int] + Global sample indices to prefetch. + streams : Sequence[Any], optional + Optional CUDA streams, one per index. If shorter than indices, + streams are cycled. If None, no streams used. + """ + for i, idx in enumerate(indices): + stream = None + if streams: + stream = streams[i % len(streams)] + self.prefetch(idx, stream=stream) + + def cancel_prefetch(self, index: Optional[int] = None) -> None: + """ + Cancel prefetch for the given index or all sub-datasets. + + When index is provided, only cancels if it is in range; out-of-range + indices are ignored to match :class:`Dataset` behavior. + + Parameters + ---------- + index : int, optional + Global index to cancel, or None to cancel all. + """ + if index is None: + for ds in self._datasets: + ds.cancel_prefetch(None) + else: + mapped = self._index_to_dataset_and_local_optional(index) + if mapped is not None: + ds_id, local_i = mapped + self._datasets[ds_id].cancel_prefetch(local_i) + + @property + def prefetch_count(self) -> int: + """ + Number of items currently being prefetched across all sub-datasets. + + Returns + ------- + int + Sum of in-flight prefetch counts from each sub-dataset. + """ + return sum(ds.prefetch_count for ds in self._datasets) + + def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: + """Iterate over all samples in global index order.""" + for i in range(len(self)): + yield self[i] + + @property + def field_names(self) -> list[str]: + """ + Field names in samples. + + With ``output_strict=True``, returns the common output keys (TensorDict + keys after transforms). With ``output_strict=False``, returns the first + dataset's field names. + """ + return list(self._field_names) + + def close(self) -> None: + """Close all sub-datasets and release resources.""" + for ds in self._datasets: + ds.close() + + def __enter__(self) -> "MultiDataset": + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit.""" + self.close() + + def __repr__(self) -> str: + parts = [f" ({i}): {ds}" for i, ds in enumerate(self._datasets)] + return ( + f"MultiDataset(\n output_strict={self._output_strict},\n datasets=[\n" + + ",\n".join(parts) + + "\n ]\n)" + ) diff --git a/physicsnemo/datapipes/readers/numpy.py b/physicsnemo/datapipes/readers/numpy.py index aba2221b43..b1ee5f3fa3 100644 --- a/physicsnemo/datapipes/readers/numpy.py +++ b/physicsnemo/datapipes/readers/numpy.py @@ -18,6 +18,8 @@ NumpyReader - Read data from NumPy .npz files. Supports reading from single .npz files or directories of .npz files. +In single-file mode, optional ``preload_to_cpu=True`` loads the entire +dataset into RAM at init for faster iteration with no per-sample I/O. """ from __future__ import annotations @@ -38,8 +40,12 @@ class NumpyReader(Reader): Read samples from NumPy .npz files. Supports two modes: - 1. Single .npz file: samples indexed along first dimension of each array - 2. Directory of .npz files: one sample per file + + 1. **Single .npz file**: Samples are indexed along the first dimension + of each array. Optionally, ``preload_to_cpu=True`` loads all arrays + into RAM at init so iteration does no disk I/O. + 2. **Directory of .npz files**: One sample per file; each file is opened + on demand. Example (single .npz): >>> # data.npz with arrays "positions" (N, 100, 3), "features" (N, 100) @@ -52,6 +58,10 @@ class NumpyReader(Reader): >>> # Directory with sample_0.npz, sample_1.npz, ... >>> reader = NumpyReader("data_dir/", file_pattern="sample_*.npz") # doctest: +SKIP >>> data, metadata = reader[0] # Returns (TensorDict, dict) tuple # doctest: +SKIP + + Example (single .npz with preload): + >>> reader = NumpyReader("data.npz", preload_to_cpu=True) # doctest: +SKIP + >>> # All arrays loaded into RAM at init; no disk I/O during iteration """ def __init__( @@ -65,6 +75,7 @@ def __init__( pin_memory: bool = False, include_index_in_metadata: bool = True, coordinated_subsampling: Optional[dict[str, Any]] = None, + preload_to_cpu: bool = False, ) -> None: """ Initialize the NumPy reader. @@ -93,6 +104,11 @@ def __init__( Optional dict to configure coordinated subsampling (directory mode only). If provided, must contain ``n_points`` (int) and ``target_keys`` (list of str). + preload_to_cpu : bool, default=False + If True, in single-file mode the reader loads all requested + arrays into RAM at init, closes the file, and serves samples + from memory. Use when the dataset fits in RAM and you want + to avoid disk I/O during training. Ignored in directory mode. Raises ------ @@ -100,6 +116,9 @@ def __init__( If path doesn't exist. ValueError If no files found in directory or unsupported file type. + KeyError + If preload_to_cpu is True and a required field is missing + from the file (and not in default_values). """ super().__init__( pin_memory=pin_memory, @@ -112,14 +131,17 @@ def __init__( self.default_values = default_values or {} self.file_pattern = file_pattern self.index_key = index_key + self.preload_to_cpu = preload_to_cpu if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") - # Determine mode based on path - self._mode: str # "single" or "directory" + # Mode: "single" (one .npz, samples along first dim) or "directory" + self._mode: str self._files: Optional[list[Path]] = None self._data: Optional[np.lib.npyio.NpzFile] = None + # When preload_to_cpu: in-memory arrays keyed by field name (single-file only) + self._preloaded: Optional[dict[str, np.ndarray]] = None self._available_fields: list[str] = [] if self.path.is_dir(): @@ -147,12 +169,12 @@ def _setup_directory_mode(self) -> None: self._available_fields = list(npz.files) def _setup_single_file_mode(self) -> None: - """Set up reader for single .npz file.""" + """Set up reader for a single .npz file; optionally preload all arrays to RAM.""" self._mode = "single" self._data = np.load(self.path) self._available_fields = list(self._data.files) - # Determine length from index_key or first field + # Sample count is the first dimension of index_key or of the first array if self.index_key is not None: self._length = self._data[self.index_key].shape[0] elif self._available_fields: @@ -160,6 +182,24 @@ def _setup_single_file_mode(self) -> None: else: self._length = 0 + # Optional: load entire dataset into RAM and close the file + if self.preload_to_cpu: + required = set(self.fields) - set(self.default_values.keys()) + missing = required - set(self._data.files) + if missing: + raise KeyError( + f"Required fields {missing} not found in {self.path}. " + f"Available: {list(self._data.files)}" + ) + self._preloaded = {} + for field in self.fields: + if field in self._data.files: + # .copy() forces a real array; np.array() ensures contiguous + self._preloaded[field] = np.array(self._data[field].copy()) + if hasattr(self._data, "close"): + self._data.close() + self._data = None + @property def fields(self) -> list[str]: """Fields that will be loaded (user-specified or all available).""" @@ -275,20 +315,66 @@ def _load_from_npz( # Directory mode: load full array arr = arr[:] - data[field] = torch.from_numpy(np.array(arr)) + data[field] = torch.from_numpy(np.asarray(arr, dtype=np.float32)) elif field in self.default_values: - data[field] = self.default_values[field].clone() + data[field] = self.default_values[field].clone().float() + + return data + + def _load_sample_from_preloaded(self, index: int) -> dict[str, torch.Tensor]: + """ + Load a single sample by indexing into preloaded in-memory arrays. + + Used only when ``preload_to_cpu=True`` in single-file mode. Applies + coordinated subsampling (random contiguous slice) when configured. + + Parameters + ---------- + index : int + Sample index along the first dimension of each preloaded array. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary mapping field names to CPU tensors for this sample. + """ + data = {} + fields_to_load = self.fields + target_keys_set = set() + subsample_slice = None + + # If subsampling is enabled, pick one random contiguous slice for this sample + if self._coordinated_subsampling_config is not None: + n_points = self._coordinated_subsampling_config["n_points"] + target_keys_set = set(self._coordinated_subsampling_config["target_keys"]) + for field in target_keys_set: + if field in self._preloaded: + arr = self._preloaded[field][index] + subsample_slice = self._select_random_sections_from_slice( + 0, arr.shape[0], n_points + ) + break + for field in fields_to_load: + if field in self._preloaded: + arr = np.array(self._preloaded[field][index], copy=False) + if subsample_slice is not None and field in target_keys_set: + arr = arr[subsample_slice] + data[field] = torch.from_numpy(np.asarray(arr, dtype=np.float32)) + elif field in self.default_values: + data[field] = self.default_values[field].clone().float() return data def _load_sample(self, index: int) -> dict[str, torch.Tensor]: - """Load a single sample.""" + """Load a single sample from disk or from preloaded RAM.""" if self._mode == "directory": file_path = self._files[index] with np.load(file_path) as npz: return self._load_from_npz(npz, index=None, file_path=file_path) - else: # single + elif self._preloaded is not None: + return self._load_sample_from_preloaded(index) + else: return self._load_from_npz(self._data, index=index) def __len__(self) -> int: @@ -318,12 +404,13 @@ def _supports_coordinated_subsampling(self) -> bool: return self._mode == "directory" def close(self) -> None: - """Close file handles.""" + """Close file handles and release preloaded in-memory arrays (if any).""" super().close() if self._data is not None: if hasattr(self._data, "close"): self._data.close() self._data = None + self._preloaded = None def __repr__(self) -> str: subsample_info = "" @@ -331,11 +418,13 @@ def __repr__(self) -> str: cfg = self._coordinated_subsampling_config subsample_info = f", subsampling={cfg['n_points']} points" + preload_info = ", preload_to_cpu=True" if self._preloaded is not None else "" return ( f"NumpyReader(" f"path={self.path}, " f"mode={self._mode}, " f"len={len(self)}, " f"fields={self.fields}" - f"{subsample_info})" + f"{subsample_info}" + f"{preload_info})" ) diff --git a/physicsnemo/datapipes/transforms/__init__.py b/physicsnemo/datapipes/transforms/__init__.py index 963b4b0985..8814ff5e1b 100644 --- a/physicsnemo/datapipes/transforms/__init__.py +++ b/physicsnemo/datapipes/transforms/__init__.py @@ -45,6 +45,7 @@ CenterOfMass, CreateGrid, KNearestNeighbors, + Resize, ) from physicsnemo.datapipes.transforms.subsample import ( SubsamplePoints, @@ -55,6 +56,7 @@ ConstantField, Purge, Rename, + Reshape, ) __all__ = [ @@ -83,8 +85,10 @@ "CreateGrid", "KNearestNeighbors", "CenterOfMass", + "Resize", # Utility "Rename", "Purge", "ConstantField", + "Reshape", ] diff --git a/physicsnemo/datapipes/transforms/normalize.py b/physicsnemo/datapipes/transforms/normalize.py index 97413a9688..21ea71ce45 100644 --- a/physicsnemo/datapipes/transforms/normalize.py +++ b/physicsnemo/datapipes/transforms/normalize.py @@ -20,6 +20,7 @@ from __future__ import annotations +import collections.abc import warnings from pathlib import Path from typing import Any, Literal, Optional @@ -196,7 +197,7 @@ def _process_stats_dict( """Process statistics into dict of tensors for each field.""" result: dict[str, torch.Tensor] = {} - if isinstance(stats, dict): + if isinstance(stats, collections.abc.Mapping): for key in self.input_keys: if key not in stats: raise ValueError( diff --git a/physicsnemo/datapipes/transforms/spatial.py b/physicsnemo/datapipes/transforms/spatial.py index c791d63333..57114f76e9 100644 --- a/physicsnemo/datapipes/transforms/spatial.py +++ b/physicsnemo/datapipes/transforms/spatial.py @@ -23,9 +23,10 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Tuple import torch +import torch.nn.functional as F from tensordict import TensorDict from physicsnemo.datapipes.registry import register @@ -523,3 +524,140 @@ def __repr__(self) -> str: return ( f"CenterOfMass(coords_key={self.coords_key}, output_key={self.output_key})" ) + + +@register() +class Resize(Transform): + r""" + Resize a set of grid tensors via interpolation. + + Applies spatial resizing to tensors identified by ``input_keys`` using + :func:`torch.nn.functional.interpolate`. Transforms operate on + single-sample data (no batch dimension). Supports 2D tensors + :math:`(C, H, W)` and 3D tensors :math:`(C, D, H, W)`. + + Parameters + ---------- + input_keys : list[str] + Keys of tensors to resize. Each tensor must have shape + :math:`(C, H, W)` for 2D or :math:`(C, D, H, W)` for 3D. + size : tuple[int, ...] + Target spatial size. For 2D use :math:`(H, W)`, for 3D use + :math:`(D, H, W)`. + mode : str, optional + Interpolation mode. One of ``"nearest"``, ``"bilinear"``, + ``"bicubic"`` (2D only), ``"trilinear"`` (3D only), ``"area"``. + Default is ``"bilinear"`` for 2D and ``"trilinear"`` for 3D. + align_corners : bool, optional + Used for ``"bilinear"``, ``"bicubic"``, ``"trilinear"``. + See :func:`torch.nn.functional.interpolate`. Default is ``False``. + + Examples + -------- + >>> transform = Resize( + ... input_keys=["pressure", "velocity"], + ... size=(64, 64), + ... mode="bilinear", + ... ) + >>> sample = TensorDict({ + ... "pressure": torch.randn(1, 128, 128), + ... "velocity": torch.randn(2, 128, 128), + ... }) + >>> result = transform(sample) + >>> result["pressure"].shape + torch.Size([1, 64, 64]) + >>> result["velocity"].shape + torch.Size([2, 64, 64]) + """ + + def __init__( + self, + input_keys: list[str], + size: Tuple[int, ...], + *, + mode: Optional[str] = None, + align_corners: bool = False, + ) -> None: + """ + Initialize the resize transform. + + Parameters + ---------- + input_keys : list[str] + Keys of tensors to resize. + size : tuple[int, ...] + Target spatial size, e.g. :math:`(H, W)` or :math:`(D, H, W)`. + mode : str, optional + Interpolation mode. Defaults by spatial dims: ``"bilinear"`` (2D), + ``"trilinear"`` (3D). + align_corners : bool, optional + Passed to :func:`torch.nn.functional.interpolate`. Default ``False``. + """ + super().__init__() + self.input_keys = input_keys + self.size = tuple(size) + ndim = len(self.size) + if ndim == 2: + self._default_mode: str = "bilinear" + elif ndim == 3: + self._default_mode = "trilinear" + else: + raise ValueError(f"size must have 2 or 3 spatial dimensions, got {ndim}") + self.mode = mode if mode is not None else self._default_mode + self.align_corners = align_corners + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Resize each tensor in ``input_keys`` to the target spatial size. + + Parameters + ---------- + data : TensorDict + Input TensorDict containing grid tensors to resize. + + Returns + ------- + TensorDict + TensorDict with resized tensors in place of originals. + """ + n_spatial = len(self.size) + # Single-sample only: (C, H, W) or (C, D, H, W); also accept (H, W) / (D, H, W) as single-channel + expected_ndim_with_channel = n_spatial + 1 # channel + spatial + + interp_kw: dict = {"size": self.size, "mode": self.mode} + if self.mode not in ("nearest", "area"): + interp_kw["align_corners"] = self.align_corners + + updates = {} + for key in self.input_keys: + if key not in data: + continue + t = data[key] + if not isinstance(t, torch.Tensor) or not t.is_floating_point(): + continue + ndim = t.ndim + if ndim == n_spatial: + # (H, W) or (D, H, W): treat as single-channel for interpolate + t = t.unsqueeze(0) + elif ndim != expected_ndim_with_channel: + continue + # Add batch dim for F.interpolate, then remove; restore to original ndim if we added channel + out = F.interpolate(t.unsqueeze(0), **interp_kw).squeeze(0) + if ndim == n_spatial: + out = out.squeeze(0) + updates[key] = out + return data.update(updates) + + def __repr__(self) -> str: + """ + Return string representation. + + Returns + ------- + str + String representation of the transform. + """ + return ( + f"Resize(input_keys={self.input_keys}, size={self.size}, " + f"mode={self.mode!r})" + ) diff --git a/physicsnemo/datapipes/transforms/utility.py b/physicsnemo/datapipes/transforms/utility.py index 962c1de158..9dd03da424 100644 --- a/physicsnemo/datapipes/transforms/utility.py +++ b/physicsnemo/datapipes/transforms/utility.py @@ -23,7 +23,7 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Sequence import torch from tensordict import TensorDict @@ -487,3 +487,93 @@ def extra_repr(self) -> str: f"fill_value={self.fill_value}, " f"output_dim={self.output_dim}" ) + + +@register() +class Reshape(Transform): + r""" + Reshape specified TensorDict fields to target shapes. + + Applies :func:`torch.reshape` so each specified field gets the given shape. + At most one dimension in the shape may be ``-1``, which is inferred from + the tensor's element count. Useful to unify layouts across datasets (e.g. + :math:`(1, H, W)` to :math:`(H, W)`) or to flatten/spread dimensions. + + Parameters + ---------- + keys : list[str] + TensorDict keys to reshape. Only these keys are modified; others are + left unchanged. + shape : tuple[int, ...] or list[int] + Target shape for all specified keys. Use ``-1`` for at most one + dimension to infer from the tensor size. + + Examples + -------- + Drop a leading singleton dimension (e.g. single-channel image): + + >>> transform = Reshape(keys=["y"], shape=(256, 256)) + >>> data = TensorDict({"x": torch.randn(256, 256), "y": torch.randn(1, 256, 256)}) + >>> result = transform(data) + >>> result["y"].shape + torch.Size([256, 256]) + + Flatten spatial dimensions: + + >>> transform = Reshape(keys=["features"], shape=(-1,)) + >>> data = TensorDict({"features": torch.randn(4, 8, 8)}) + >>> transform(data)["features"].shape + torch.Size([256]) + """ + + def __init__( + self, + keys: list[str], + shape: Sequence[int], + ) -> None: + """ + Initialize the reshape transform. + + Parameters + ---------- + keys : list[str] + TensorDict keys to reshape. + shape : tuple or list of int + Target shape. At most one entry may be -1 (inferred). + """ + super().__init__() + self.keys = list(keys) + self.shape = tuple(int(s) for s in shape) + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Reshape specified fields in the TensorDict. + + Parameters + ---------- + data : TensorDict + Input TensorDict. + + Returns + ------- + TensorDict + TensorDict with reshaped tensors for the specified keys. + Keys not present in the data are skipped. + """ + out = data.clone() + for key in self.keys: + if key not in out.keys(): + continue + out[key] = out[key].reshape(self.shape) + return out + + def extra_repr(self) -> str: + """ + Return extra information for repr. + + Returns + ------- + str + String with transform parameters. + """ + return f"keys={self.keys}, shape={self.shape}" diff --git a/test/datapipes/core/test_multi_dataset.py b/test/datapipes/core/test_multi_dataset.py new file mode 100644 index 0000000000..4fde358658 --- /dev/null +++ b/test/datapipes/core/test_multi_dataset.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MultiDataset class.""" + +import pytest + +import physicsnemo.datapipes as dp +from physicsnemo.datapipes.multi_dataset import DATASET_INDEX_METADATA_KEY + + +class TestMultiDatasetBasic: + """Basic MultiDataset functionality.""" + + def test_create_multi_dataset(self, numpy_data_dir): + """MultiDataset with two datasets has combined length.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert len(multi) == len(ds_a) + len(ds_b) + + def test_create_multi_dataset_three_or_more(self, numpy_data_dir): + """MultiDataset with three+ datasets has combined length and correct index mapping.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_c = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b, ds_c], output_strict=True) + + assert len(multi) == len(ds_a) + len(ds_b) + len(ds_c) + assert multi[0][1][DATASET_INDEX_METADATA_KEY] == 0 + assert multi[10][1][DATASET_INDEX_METADATA_KEY] == 1 + assert multi[20][1][DATASET_INDEX_METADATA_KEY] == 2 + + def test_getitem_maps_to_correct_dataset(self, numpy_data_dir): + """Indices 0..len0-1 from first dataset, then second.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + # First 10 from ds_a (dataset_index 0) + data0, meta0 = multi[0] + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + assert "positions" in data0 + + data9, meta9 = multi[9] + assert meta9[DATASET_INDEX_METADATA_KEY] == 0 + + # Next 10 from ds_b (dataset_index 1) + data10, meta10 = multi[10] + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + data19, meta19 = multi[19] + assert meta19[DATASET_INDEX_METADATA_KEY] == 1 + + def test_getitem_preserves_sub_dataset_metadata(self, numpy_data_dir): + """Metadata from sub-dataset (e.g. index) is preserved alongside dataset_index.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + data0, meta0 = multi[0] + assert meta0["index"] == 0 + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + + data10, meta10 = multi[10] + assert meta10["index"] == 0 # first sample of second dataset + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + def test_negative_indexing(self, numpy_data_dir): + """Negative indices work as expected.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + data_last, meta_last = multi[-1] + assert meta_last[DATASET_INDEX_METADATA_KEY] == 1 + data_first, meta_first = multi[-20] + assert meta_first[DATASET_INDEX_METADATA_KEY] == 0 + + def test_field_names_strict(self, numpy_data_dir): + """output_strict=True returns common output keys (TensorDict keys).""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert "positions" in multi.field_names + assert "features" in multi.field_names + + def test_iteration(self, numpy_data_dir): + """Iteration yields all samples in order with dataset_index in metadata.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + seen_indices = [] + seen_dataset_indices = [] + for data, meta in multi: + seen_indices.append(meta["index"]) + seen_dataset_indices.append(meta[DATASET_INDEX_METADATA_KEY]) + + assert len(seen_indices) == 20 + assert seen_dataset_indices[:10] == [0] * 10 + assert seen_dataset_indices[10:] == [1] * 10 + + def test_context_manager(self, numpy_data_dir): + """MultiDataset as context manager closes sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + with dp.MultiDataset([ds_a, ds_b], output_strict=True) as multi: + data, meta = multi[0] + assert meta[DATASET_INDEX_METADATA_KEY] == 0 + + +class TestMultiDatasetStrictValidation: + """Output strictness validation.""" + + def test_strict_raises_when_output_keys_differ(self, numpy_data_dir): + """output_strict=True raises if datasets produce different output keys.""" + ds_full = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) + + with pytest.raises(ValueError, match="output keys"): + dp.MultiDataset([ds_full, ds_pos_only], output_strict=True) + + def test_non_strict_accepts_different_fields(self, numpy_data_dir): + """output_strict=False does not validate output keys; field_names is first dataset's.""" + ds_full = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) + + multi = dp.MultiDataset([ds_full, ds_pos_only], output_strict=False) + assert len(multi) == 20 + assert set(multi.field_names) == set(ds_full.field_names) + + def test_strict_validates_output_keys_not_reader_fields(self, numpy_data_dir): + """output_strict compares TensorDict keys after transforms, not reader field_names.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + assert len(multi.field_names) >= 2 # positions, features + + +class TestMultiDatasetPrefetchAndClose: + """Prefetch and close delegation.""" + + def test_prefetch_delegates(self, numpy_data_dir): + """Prefetch delegates to correct sub-dataset.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.prefetch(10) + data0, meta0 = multi[0] + data10, meta10 = multi[10] + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + def test_cancel_prefetch_all(self, numpy_data_dir): + """cancel_prefetch(None) clears all sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.prefetch(10) + multi.cancel_prefetch() + # Should still be able to get data synchronously + data0, _ = multi[0] + assert "positions" in data0 + + def test_cancel_prefetch_invalid_index_no_op(self, numpy_data_dir): + """cancel_prefetch(out-of-range index) does not raise (matches Dataset).""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.cancel_prefetch(999) # out of range, should no-op + multi.cancel_prefetch(-1) # also out of range + data0, _ = multi[0] + assert "positions" in data0 + + def test_prefetch_batch(self, numpy_data_dir): + """prefetch_batch delegates to sub-datasets by global index.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch_batch([0, 1, 10, 11]) + for idx in [0, 1, 10, 11]: + data, meta = multi[idx] + assert meta[DATASET_INDEX_METADATA_KEY] == (0 if idx < 10 else 1) + assert "positions" in data + + def test_prefetch_count(self, numpy_data_dir): + """prefetch_count is sum of sub-dataset prefetch counts.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert multi.prefetch_count == 0 + multi.prefetch_batch([0, 1, 2, 3]) + assert multi.prefetch_count >= 0 # may complete quickly + multi.cancel_prefetch() + assert multi.prefetch_count == 0 + + def test_close_closes_all(self, numpy_data_dir): + """close() closes all sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi.close() + # After close, sub-datasets are closed (no-op to call again) + + +class TestMultiDatasetErrors: + """Error cases.""" + + def test_requires_at_least_one_datasets(self, numpy_data_dir): + """MultiDataset requires at least one datasets.""" + with pytest.raises(ValueError, match="at least one"): + dp.MultiDataset([], output_strict=True) + + def test_requires_dataset_instances(self, numpy_data_dir): + """All elements must be Dataset instances.""" + ds = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + with pytest.raises(TypeError, match="must be a Dataset"): + dp.MultiDataset([ds, dp.NumpyReader(numpy_data_dir)], output_strict=False) + + def test_index_out_of_range(self, numpy_data_dir): + """Index out of range raises IndexError.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + with pytest.raises(IndexError, match="out of range"): + _ = multi[20] + with pytest.raises(IndexError, match="out of range"): + _ = multi[-21] + + def test_prefetch_out_of_range_raises(self, numpy_data_dir): + """prefetch with out-of-range index raises IndexError.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + with pytest.raises(IndexError, match="out of range"): + multi.prefetch(20) + + +class TestMultiDatasetWithDataLoader: + """DataLoader accepts MultiDataset (same interface as Dataset).""" + + def test_dataloader_with_multi_dataset(self, numpy_data_dir): + """DataLoader iterates over MultiDataset and collates batches.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader(multi, batch_size=4, shuffle=False) + assert len(loader) == 5 # 20 / 4 + + batches = list(loader) + assert len(batches) == 5 + assert batches[0]["positions"].shape[0] == 4 + assert batches[-1]["positions"].shape[0] == 4 + + def test_dataloader_with_multi_dataset_and_metadata(self, numpy_data_dir): + """Collate metadata includes dataset_index from MultiDataset.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader( + multi, batch_size=5, shuffle=False, collate_metadata=True + ) + batch_data, metadata_list = next(iter(loader)) + + assert len(metadata_list) == 5 + assert [m[DATASET_INDEX_METADATA_KEY] for m in metadata_list] == [ + 0, + 0, + 0, + 0, + 0, + ] + + batches = list(loader) + _, meta_batch_2 = batches[2] # indices 10-14, all from dataset 1 + assert all(m[DATASET_INDEX_METADATA_KEY] == 1 for m in meta_batch_2) + + def test_dataloader_shuffle_with_multi_dataset(self, numpy_data_dir): + """Shuffled DataLoader over MultiDataset yields all indices.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader(multi, batch_size=4, shuffle=True, collate_metadata=True) + all_dataset_indices = [] + for batch_data, metadata_list in loader: + all_dataset_indices.extend( + m[DATASET_INDEX_METADATA_KEY] for m in metadata_list + ) + assert set(all_dataset_indices) == {0, 1} + assert len(all_dataset_indices) == 20 + + +class TestMultiDatasetRepr: + """String representation.""" + + def test_repr(self, numpy_data_dir): + """Repr includes output_strict and datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + r = repr(multi) + assert "MultiDataset" in r + assert "output_strict=True" in r diff --git a/test/datapipes/core/test_transforms.py b/test/datapipes/core/test_transforms.py index 6048ecf5c9..d3dbcbecf6 100644 --- a/test/datapipes/core/test_transforms.py +++ b/test/datapipes/core/test_transforms.py @@ -807,6 +807,23 @@ def test_normalize_repr(): # assert "100" in repr(ds) +def test_normalize_accepts_ordered_dict_stats(): + """Test that Normalize accepts collections.abc.Mapping (e.g. OrderedDict) for stats.""" + from collections import OrderedDict + + sample = TensorDict({"x": torch.tensor([10.0, 20.0, 30.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="mean_std", + means=OrderedDict([("x", 20.0)]), + stds=OrderedDict([("x", 10.0)]), + ) + + result = norm(sample) + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + def test_compose_repr(): pipeline = dp.Compose( [ diff --git a/test/datapipes/readers/test_numpy_consolidated.py b/test/datapipes/readers/test_numpy_consolidated.py index c564ef5534..52358bdaf1 100644 --- a/test/datapipes/readers/test_numpy_consolidated.py +++ b/test/datapipes/readers/test_numpy_consolidated.py @@ -283,5 +283,232 @@ def test_close_handles(self): reader2.close() +class TestNumpyReaderPreload: + """Tests for preload_to_cpu functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_preload_basic(self): + """Test that preload_to_cpu loads data into RAM and closes the file.""" + coords = np.random.randn(15, 3).astype(np.float32) + features = np.random.randn(15, 4).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=True + ) + + assert reader._data is None + assert reader._preloaded is not None + assert "coords" in reader._preloaded + assert "features" in reader._preloaded + assert len(reader) == 15 + + data, metadata = reader[0] + assert data["coords"].shape == (3,) + assert data["features"].shape == (4,) + torch.testing.assert_close( + data["coords"], torch.from_numpy(coords[0]), atol=1e-6, rtol=1e-6 + ) + + def test_preload_matches_non_preloaded(self): + """Test that preloaded data matches non-preloaded data.""" + np.random.seed(42) + coords = np.random.randn(10, 3).astype(np.float32) + features = np.random.randn(10, 5).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader_disk = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=False + ) + reader_ram = NumpyReader( + npz_path, fields=["coords", "features"], preload_to_cpu=True + ) + + for i in range(len(reader_disk)): + data_disk, _ = reader_disk[i] + data_ram, _ = reader_ram[i] + torch.testing.assert_close(data_disk["coords"], data_ram["coords"]) + torch.testing.assert_close(data_disk["features"], data_ram["features"]) + + reader_disk.close() + reader_ram.close() + + def test_preload_with_default_values(self): + """Test preload with default values for missing fields.""" + coords = np.random.randn(10, 100, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + default_normals = torch.ones(100, 3, dtype=torch.float64) + reader = NumpyReader( + npz_path, + fields=["coords", "normals"], + default_values={"normals": default_normals}, + preload_to_cpu=True, + ) + + data, _ = reader[0] + assert "normals" in data + assert data["normals"].dtype == torch.float32 + reader.close() + + def test_preload_missing_required_field_raises(self): + """Test that preload raises KeyError for missing required fields.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + with pytest.raises(KeyError, match="Required fields"): + NumpyReader( + npz_path, + fields=["coords", "missing_field"], + preload_to_cpu=True, + ) + + def test_preload_ignored_in_directory_mode(self): + """Test that preload_to_cpu is ignored in directory mode.""" + for i in range(3): + coords = np.random.randn(50, 3).astype(np.float32) + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords"], + preload_to_cpu=True, + ) + + assert reader._preloaded is None + assert len(reader) == 3 + + data, _ = reader[0] + assert data["coords"].shape == (50, 3) + reader.close() + + def test_preload_with_coordinated_subsampling(self): + """Test preloaded reader with coordinated subsampling.""" + n_samples = 5 + n_points = 1000 + subsample_points = 100 + + coords = np.random.randn(n_samples, n_points, 3).astype(np.float32) + features = np.random.randn(n_samples, n_points, 4).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + reader = NumpyReader( + npz_path, + fields=["coords", "features"], + preload_to_cpu=True, + coordinated_subsampling={ + "n_points": subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + assert reader._preloaded is not None + data, _ = reader[0] + assert data["coords"].shape == (subsample_points, 3) + assert data["features"].shape == (subsample_points, 4) + reader.close() + + def test_preload_close_releases_memory(self): + """Test that close() releases preloaded arrays.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"], preload_to_cpu=True) + assert reader._preloaded is not None + + reader.close() + assert reader._preloaded is None + + def test_preload_repr(self): + """Test that repr includes preload_to_cpu info.""" + coords = np.random.randn(10, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"], preload_to_cpu=True) + assert "preload_to_cpu=True" in repr(reader) + reader.close() + + +class TestNumpyReaderFloat32: + """Tests for float32 conversion behavior.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_float64_converted_to_float32(self): + """Test that float64 numpy arrays are returned as float32 tensors.""" + coords = np.random.randn(10, 3).astype(np.float64) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path, fields=["coords"]) + data, _ = reader[0] + assert data["coords"].dtype == torch.float32 + reader.close() + + def test_float64_converted_to_float32_directory_mode(self): + """Test float64 conversion in directory mode.""" + for i in range(3): + coords = np.random.randn(50, 3).astype(np.float64) + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader( + self.temp_path, file_pattern="sample_*.npz", fields=["coords"] + ) + data, _ = reader[0] + assert data["coords"].dtype == torch.float32 + reader.close() + + def test_default_values_converted_to_float32(self): + """Test that default values are returned as float32.""" + coords = np.random.randn(10, 100, 3).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + default_normals = torch.zeros(100, 3, dtype=torch.float64) + reader = NumpyReader( + npz_path, + fields=["coords", "normals"], + default_values={"normals": default_normals}, + ) + + data, _ = reader[0] + assert data["normals"].dtype == torch.float32 + reader.close() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/datapipes/transforms/test_spatial.py b/test/datapipes/transforms/test_spatial.py index f33d33d206..11ad4aa774 100644 --- a/test/datapipes/transforms/test_spatial.py +++ b/test/datapipes/transforms/test_spatial.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for spatial transforms (BoundingBoxFilter, CreateGrid, KNearestNeighbors, CenterOfMass).""" +"""Tests for spatial transforms (BoundingBoxFilter, CreateGrid, KNearestNeighbors, CenterOfMass, Resize).""" import pytest import torch @@ -25,6 +25,7 @@ CenterOfMass, CreateGrid, KNearestNeighbors, + Resize, ) # ============================================================================ @@ -604,6 +605,125 @@ def test_repr(self): assert "center_of_mass" in repr_str +# ============================================================================ +# Resize Tests +# ============================================================================ + + +class TestResize: + """Tests for Resize transform.""" + + def test_resize_2d_basic(self): + """Test basic 2D resizing of (C, H, W) tensor.""" + transform = Resize(input_keys=["pressure"], size=(32, 32)) + data = TensorDict({"pressure": torch.randn(1, 128, 128)}) + + result = transform(data) + assert result["pressure"].shape == (1, 32, 32) + + def test_resize_3d_basic(self): + """Test basic 3D resizing of (C, D, H, W) tensor.""" + transform = Resize(input_keys=["field"], size=(8, 16, 16)) + data = TensorDict({"field": torch.randn(3, 32, 64, 64)}) + + result = transform(data) + assert result["field"].shape == (3, 8, 16, 16) + + def test_resize_no_channel_dim(self): + """Test resizing an (H, W) tensor without channel dimension.""" + transform = Resize(input_keys=["image"], size=(16, 16)) + data = TensorDict({"image": torch.randn(64, 64)}) + + result = transform(data) + assert result["image"].shape == (16, 16) + + def test_resize_default_mode_2d(self): + """Test that default mode for 2D size is bilinear.""" + transform = Resize(input_keys=["x"], size=(32, 32)) + assert transform.mode == "bilinear" + + def test_resize_default_mode_3d(self): + """Test that default mode for 3D size is trilinear.""" + transform = Resize(input_keys=["x"], size=(8, 8, 8)) + assert transform.mode == "trilinear" + + def test_resize_nearest_mode(self): + """Test explicit nearest mode (no align_corners needed).""" + transform = Resize(input_keys=["x"], size=(16, 16), mode="nearest") + data = TensorDict({"x": torch.randn(2, 64, 64)}) + + result = transform(data) + assert result["x"].shape == (2, 16, 16) + + def test_resize_align_corners(self): + """Test align_corners=True with bilinear mode.""" + transform = Resize( + input_keys=["x"], size=(16, 16), mode="bilinear", align_corners=True + ) + data = TensorDict({"x": torch.randn(1, 64, 64)}) + + result = transform(data) + assert result["x"].shape == (1, 16, 16) + + def test_resize_missing_key_skipped(self): + """Test that missing keys are silently skipped.""" + transform = Resize(input_keys=["missing"], size=(16, 16)) + original = torch.randn(1, 64, 64) + data = TensorDict({"present": original.clone()}) + + result = transform(data) + assert "present" in result + torch.testing.assert_close(result["present"], original) + + def test_resize_non_float_skipped(self): + """Test that integer tensors are skipped.""" + transform = Resize(input_keys=["mask"], size=(16, 16)) + int_tensor = torch.randint(0, 2, (1, 64, 64)) + data = TensorDict({"mask": int_tensor}) + + result = transform(data) + assert result["mask"].shape == (1, 64, 64) + + def test_resize_invalid_size_dims(self): + """Test that invalid size dimensions raise ValueError.""" + with pytest.raises(ValueError, match="2 or 3 spatial dimensions"): + Resize(input_keys=["x"], size=(16,)) + + with pytest.raises(ValueError, match="2 or 3 spatial dimensions"): + Resize(input_keys=["x"], size=(4, 8, 16, 32)) + + def test_resize_preserves_other_fields(self): + """Test that non-input fields are untouched.""" + transform = Resize(input_keys=["field"], size=(16, 16)) + other = torch.randn(50, 3) + data = TensorDict({"field": torch.randn(1, 64, 64), "other": other.clone()}) + + result = transform(data) + assert result["field"].shape == (1, 16, 16) + torch.testing.assert_close(result["other"], other) + + def test_resize_multiple_keys(self): + """Test resizing multiple input keys.""" + transform = Resize(input_keys=["pressure", "velocity"], size=(32, 32)) + data = TensorDict( + { + "pressure": torch.randn(1, 128, 128), + "velocity": torch.randn(2, 128, 128), + } + ) + + result = transform(data) + assert result["pressure"].shape == (1, 32, 32) + assert result["velocity"].shape == (2, 32, 32) + + def test_resize_repr(self): + """Test string representation.""" + transform = Resize(input_keys=["pressure"], size=(64, 64), mode="bilinear") + repr_str = repr(transform) + assert "Resize" in repr_str + assert "bilinear" in repr_str + + # ============================================================================ # Integration Tests # ============================================================================ diff --git a/test/datapipes/transforms/test_utility.py b/test/datapipes/transforms/test_utility.py index 5661168257..115b09aaeb 100644 --- a/test/datapipes/transforms/test_utility.py +++ b/test/datapipes/transforms/test_utility.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for utility transforms: Rename, Purge, ConstantField, and ZeroLike.""" +"""Tests for utility transforms: Rename, Purge, ConstantField, Reshape, and ZeroLike.""" import pytest import torch from tensordict import TensorDict -from physicsnemo.datapipes.transforms import ConstantField, Purge, Rename +from physicsnemo.datapipes.transforms import ConstantField, Purge, Rename, Reshape class TestRename: @@ -842,3 +842,64 @@ def test_extra_repr(self): assert "output_key" in repr_str assert "fill_value" in repr_str assert "output_dim" in repr_str + + +class TestReshape: + """Tests for the Reshape transform.""" + + def test_reshape_basic(self): + """Test basic reshape from (1, H, W) to (H, W).""" + data = TensorDict({"y": torch.randn(1, 256, 256)}) + transform = Reshape(keys=["y"], shape=(256, 256)) + + result = transform(data) + assert result["y"].shape == torch.Size([256, 256]) + + def test_reshape_with_inferred_dim(self): + """Test reshape with -1 for inferred dimension.""" + data = TensorDict({"features": torch.randn(4, 8, 8)}) + transform = Reshape(keys=["features"], shape=(-1,)) + + result = transform(data) + assert result["features"].shape == torch.Size([256]) + + def test_reshape_missing_key_skipped(self): + """Test that missing keys are silently skipped.""" + data = TensorDict({"x": torch.randn(10, 3)}) + transform = Reshape(keys=["missing"], shape=(30,)) + + result = transform(data) + assert "x" in result + assert result["x"].shape == (10, 3) + + def test_reshape_preserves_other_fields(self): + """Test that non-target fields are untouched.""" + original = torch.randn(50, 3) + data = TensorDict({"target": torch.randn(1, 50, 3), "other": original.clone()}) + transform = Reshape(keys=["target"], shape=(50, 3)) + + result = transform(data) + assert result["target"].shape == torch.Size([50, 3]) + torch.testing.assert_close(result["other"], original) + + def test_reshape_multiple_keys(self): + """Test reshaping multiple keys.""" + data = TensorDict( + { + "a": torch.randn(1, 64, 64), + "b": torch.randn(1, 64, 64), + } + ) + transform = Reshape(keys=["a", "b"], shape=(64, 64)) + + result = transform(data) + assert result["a"].shape == torch.Size([64, 64]) + assert result["b"].shape == torch.Size([64, 64]) + + def test_reshape_extra_repr(self): + """Test extra_repr output.""" + transform = Reshape(keys=["y"], shape=(256, 256)) + repr_str = transform.extra_repr() + + assert "keys" in repr_str + assert "shape" in repr_str