From 916ea036123c17f63eaf01252e2b0cd4d582c204 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Sat, 14 Mar 2026 09:57:48 -0500 Subject: [PATCH 1/4] Adding most of the multi-dataset cleanly --- docs/api/datapipes/physicsnemo.datapipes.rst | 23 ++ physicsnemo/datapipes/__init__.py | 2 + physicsnemo/datapipes/dataloader.py | 4 +- physicsnemo/datapipes/multi_dataset.py | 378 +++++++++++++++++++ test/datapipes/core/test_multi_dataset.py | 333 ++++++++++++++++ 5 files changed, 739 insertions(+), 1 deletion(-) create mode 100644 physicsnemo/datapipes/multi_dataset.py create mode 100644 test/datapipes/core/test_multi_dataset.py diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 13a6193b25..57b5fb5958 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -153,6 +153,29 @@ the ``Dataset`` is responsible for the threaded execution of ``Reader``s and :members: :show-inheritance: +MultiDataset +^^^^^^^^^^^^ + +The ``MultiDataset`` composes two or more ``Dataset`` instances behind a single +index space (concatenation). Each sub-dataset can have its own Reader and +transforms. Global indices are mapped to the owning sub-dataset and local index; +metadata is enriched with ``dataset_index`` so batches can identify the source. +Use ``MultiDataset`` when you want to train on multiple datasets with the same +DataLoader, optionally enforcing that all outputs share the same TensorDict keys +for collation. See :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY` +for the metadata key added to each sample. + +Note that to properly collate and stack outputs from different datasets, you +can set ``output_strict=True`` in the constructor of a ``MultiDataset``. 
Upon +construction, it will load the first batch from every passed dataset and test +that the tensordict produced by the ``Reader`` and ``Transform`` pipeline has +consistent keys. Because the exact collation details differ by dataset, the +``MultiDataset`` does not check more aggressively than output key consistency. + +.. autoclass:: physicsnemo.datapipes.multi_dataset.MultiDataset + :members: + :show-inheritance: + Readers ^^^^^^^ diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index e693a61df6..0e3706475a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -40,6 +40,7 @@ ) from physicsnemo.datapipes.dataloader import DataLoader from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.multi_dataset import MultiDataset from physicsnemo.datapipes.readers import ( HDF5Reader, NumpyReader, @@ -84,6 +85,7 @@ "TensorDict", # Re-export from tensordict "Dataset", "DataLoader", + "MultiDataset", # Transforms - Base "Transform", "Compose", diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index c6f465d425..cf3b160cd2 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -168,7 +168,9 @@ def __len__(self) -> int: int Number of batches in the dataloader. """ - n_samples = len(self.dataset) + n_samples = ( + len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset) + ) if self.drop_last: return n_samples // self.batch_size return (n_samples + self.batch_size - 1) // self.batch_size diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py new file mode 100644 index 0000000000..8b6152a10f --- /dev/null +++ b/physicsnemo/datapipes/multi_dataset.py @@ -0,0 +1,378 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MultiDataset - Compose multiple Dataset instances behind a single dataset-like interface. + +MultiDataset presents a single index space (concatenation of all constituent datasets) +and delegates __getitem__, prefetch, and close to the appropriate sub-dataset. +Each sub-dataset can have its own Reader and transforms. Optional output strictness +validates that all sub-datasets produce the same TensorDict keys (outputs) so default +collation works. +""" + +from __future__ import annotations + +from typing import Any, Iterator, Optional, Sequence + +from tensordict import TensorDict + +from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.registry import register + +# Metadata key added by MultiDataset to identify which sub-dataset produced the sample. +DATASET_INDEX_METADATA_KEY = "dataset_index" + + +def _validate_strict_outputs(datasets: Sequence[Dataset]) -> list[str]: + """ + Check that all non-empty datasets produce the same TensorDict keys; return them. + + Loads one sample from each non-empty dataset and compares output keys (after + transforms). This validates output schema, not reader field_names. + + Parameters + ---------- + datasets : Sequence[Dataset] + Datasets to validate. + + Returns + ------- + list[str] + Common output keys (sorted, from first non-empty dataset). 
+ + Raises + ------ + ValueError + If any non-empty dataset has different output keys. + """ + if not datasets: + return [] + ref_keys: Optional[list[str]] = None + ref_index: Optional[int] = None + for i, ds in enumerate(datasets): + if len(ds) == 0: + continue + data, _ = ds[0] + keys = sorted(data.keys()) + if ref_keys is None: + ref_keys = keys + ref_index = i + elif keys != ref_keys: + raise ValueError( + "output_strict=True requires identical output keys (TensorDict keys) " + f"across datasets: dataset {ref_index} has {ref_keys}, dataset {i} has {keys}" + ) + return list(ref_keys) if ref_keys is not None else list(datasets[0].field_names) + + +@register() +class MultiDataset: + r""" + A dataset that composes multiple :class:`Dataset` instances behind one index space. + + Global indices are mapped to (dataset_index, local_index) by concatenation: + indices 0..len0-1 come from the first dataset, len0..len0+len1-1 from the second, + and so on. Each constituent can have its own Reader and transforms. Metadata + is enriched with ``dataset_index`` so batches can identify the source. + + Parameters + ---------- + datasets : Sequence[Dataset] + One or more Dataset instances (Reader + transforms each). Order defines + index mapping: first dataset occupies 0..len(ds0)-1, etc. + output_strict : bool, default=True + If True, require all datasets to produce the same TensorDict keys (output + keys after transforms) so :class:`DefaultCollator` can stack batches. If + False, no check is done; use a custom collator when keys or shapes differ. + Note that `output_strict=True` will load the first instance of all datasets + upon construction. Think of it as a debugging parameter: if you are sure + that your datasets are working properly, and want to defer loading, + you can safely disable this. + + Raises + ------ + ValueError + If ``len(datasets) < 1`` or if ``output_strict=True`` and output keys differ. 
+ + Notes + ----- + MultiDataset implements the same interface as :class:`Dataset` (``__len__``, + ``__getitem__``, ``prefetch``, ``prefetch_batch``, ``prefetch_count``, + ``cancel_prefetch``, ``close``, ``field_names``) and can be passed to + :class:`DataLoader` in place of a single dataset. Prefetch and close are + delegated to the sub-dataset that owns the index. When ``output_strict=True``, + validation checks that each dataset's *output* TensorDict (after transforms) + has the same keys, not the reader's field_names. When ``output_strict=False``, + :attr:`field_names` returns the first dataset's field names; with heterogeneous + datasets, prefer a custom collator and use metadata ``dataset_index`` to + group or pad by source. + + Shuffling and sampling + ---------------------- + The DataLoader sees one linear index space of size :math:`\sum_k \text{len}(\text{datasets}[k])`. + With ``shuffle=True``, the default :class:`RandomSampler` shuffles these global + indices, so each batch is a random mix of samples from all sub-datasets. There + is no per-dataset balancing: if one dataset is much larger, its samples will + appear more often. For balanced or stratified sampling, use a custom + :class:`torch.utils.data.Sampler` (e.g. weighted or one sample per dataset per + batch) and pass it to the DataLoader. + + Metadata + -------- + Every sample returned by :meth:`__getitem__` has its metadata dict extended + with the key :const:`DATASET_INDEX_METADATA_KEY` (``"dataset_index"``), the + integer index of the sub-dataset that produced the sample (0 for the first + dataset, 1 for the second, etc.). Sub-dataset–specific metadata (e.g. file + path, index within that dataset) is unchanged. When using the DataLoader with + ``collate_metadata=True``, each batch yields a list of metadata dicts aligned + with the batch dimension; each dict includes ``dataset_index`` so you can + filter, weight, or aggregate by source in the training loop.
+ + Examples + -------- + >>> from physicsnemo.datapipes import Dataset, MultiDataset, HDF5Reader, Normalize + >>> ds_a = Dataset(HDF5Reader("a.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP + >>> ds_b = Dataset(HDF5Reader("b.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP + >>> multi = MultiDataset([ds_a, ds_b], output_strict=True) # doctest: +SKIP + >>> len(multi) == len(ds_a) + len(ds_b) # doctest: +SKIP + True + >>> data, meta = multi[0] # from ds_a # doctest: +SKIP + >>> meta["dataset_index"] # 0 # doctest: +SKIP + """ + + def __init__( + self, + datasets: Sequence[Dataset], + *, + output_strict: bool = True, + ) -> None: + if len(datasets) < 1: + raise ValueError( + f"MultiDataset requires at least one dataset, got {len(datasets)}" + ) + for i, ds in enumerate(datasets): + if not isinstance(ds, Dataset): + raise TypeError( + f"datasets[{i}] must be a Dataset instance, got {type(ds).__name__}" + ) + + self._datasets = list(datasets) + self._output_strict = output_strict + + # Cumulative lengths: cumul[k] = sum(len(datasets[j]) for j in range(k)) + # So index i is in dataset k when cumul[k] <= i < cumul[k+1], local = i - cumul[k] + cumul = [0] + for ds in self._datasets: + cumul.append(cumul[-1] + len(ds)) + self._cumul = cumul + + if output_strict: + self._field_names = _validate_strict_outputs(self._datasets) + else: + self._field_names = list(self._datasets[0].field_names) + + def _index_to_dataset_and_local(self, index: int) -> tuple[int, int]: + """ + Map global index to (dataset_index, local_index). + + Parameters + ---------- + index : int + Global index in [0, len(self)). + + Returns + ------- + tuple[int, int] + (dataset_index, local_index). + + Raises + ------ + IndexError + If index is out of range. 
+ """ + n = len(self) + if index < 0: + index = n + index + if index < 0 or index >= n: + raise IndexError( + f"Index {index} out of range for MultiDataset with {n} samples" + ) + # Find k such that cumul[k] <= index < cumul[k+1] + for k in range(len(self._cumul) - 1): + if self._cumul[k] <= index < self._cumul[k + 1]: + return k, index - self._cumul[k] + # Fallback (should not be reached) + return len(self._datasets) - 1, index - self._cumul[-2] + + def _index_to_dataset_and_local_optional( + self, index: int + ) -> Optional[tuple[int, int]]: + """ + Map global index to (dataset_index, local_index), or None if out of range. + + Used by cancel_prefetch to match Dataset behavior (no-op for invalid index). + """ + n = len(self) + if index < 0: + index = n + index + if index < 0 or index >= n: + return None + for k in range(len(self._cumul) - 1): + if self._cumul[k] <= index < self._cumul[k + 1]: + return k, index - self._cumul[k] + return len(self._datasets) - 1, index - self._cumul[-2] + + def __len__(self) -> int: + """Return the total number of samples (sum of all sub-dataset lengths).""" + return self._cumul[-1] + + def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: + """ + Return the transformed sample and metadata for the given global index. + + Metadata is enriched with ``dataset_index`` (key :const:`DATASET_INDEX_METADATA_KEY`). + + Parameters + ---------- + index : int + Global sample index. Supports negative indexing. + + Returns + ------- + tuple[TensorDict, dict[str, Any]] + (TensorDict, metadata dict) from the owning sub-dataset. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + data, metadata = self._datasets[ds_id][local_i] + metadata = dict(metadata) + metadata[DATASET_INDEX_METADATA_KEY] = ds_id + return data, metadata + + def prefetch( + self, + index: int, + stream: Optional[Any] = None, + ) -> None: + """ + Start prefetching the sample at the given global index. 
+ + Delegates to the sub-dataset that owns that index. + + Parameters + ---------- + index : int + Global sample index to prefetch. + stream : object, optional + Optional CUDA stream for the sub-dataset prefetch. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + self._datasets[ds_id].prefetch(local_i, stream=stream) + + def prefetch_batch( + self, + indices: Sequence[int], + streams: Optional[Sequence[Any]] = None, + ) -> None: + """ + Start prefetching multiple samples by global index. + + Delegates to the sub-dataset that owns each index. Streams are cycled + if shorter than indices. + + Parameters + ---------- + indices : Sequence[int] + Global sample indices to prefetch. + streams : Sequence[Any], optional + Optional CUDA streams, one per index. If shorter than indices, + streams are cycled. If None, no streams used. + """ + for i, idx in enumerate(indices): + stream = None + if streams: + stream = streams[i % len(streams)] + self.prefetch(idx, stream=stream) + + def cancel_prefetch(self, index: Optional[int] = None) -> None: + """ + Cancel prefetch for the given index or all sub-datasets. + + When index is provided, only cancels if it is in range; out-of-range + indices are ignored to match :class:`Dataset` behavior. + + Parameters + ---------- + index : int, optional + Global index to cancel, or None to cancel all. + """ + if index is None: + for ds in self._datasets: + ds.cancel_prefetch(None) + else: + mapped = self._index_to_dataset_and_local_optional(index) + if mapped is not None: + ds_id, local_i = mapped + self._datasets[ds_id].cancel_prefetch(local_i) + + @property + def prefetch_count(self) -> int: + """ + Number of items currently being prefetched across all sub-datasets. + + Returns + ------- + int + Sum of in-flight prefetch counts from each sub-dataset. 
+ """ + return sum(ds.prefetch_count for ds in self._datasets) + + def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: + """Iterate over all samples in global index order.""" + for i in range(len(self)): + yield self[i] + + @property + def field_names(self) -> list[str]: + """ + Field names in samples. + + With ``output_strict=True``, returns the common output keys (TensorDict + keys after transforms). With ``output_strict=False``, returns the first + dataset's field names. + """ + return list(self._field_names) + + def close(self) -> None: + """Close all sub-datasets and release resources.""" + for ds in self._datasets: + ds.close() + + def __enter__(self) -> "MultiDataset": + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit.""" + self.close() + + def __repr__(self) -> str: + parts = [f" ({i}): {ds}" for i, ds in enumerate(self._datasets)] + return ( + f"MultiDataset(\n output_strict={self._output_strict},\n datasets=[\n" + + ",\n".join(parts) + + "\n ]\n)" + ) diff --git a/test/datapipes/core/test_multi_dataset.py b/test/datapipes/core/test_multi_dataset.py new file mode 100644 index 0000000000..4fde358658 --- /dev/null +++ b/test/datapipes/core/test_multi_dataset.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MultiDataset class.""" + +import pytest + +import physicsnemo.datapipes as dp +from physicsnemo.datapipes.multi_dataset import DATASET_INDEX_METADATA_KEY + + +class TestMultiDatasetBasic: + """Basic MultiDataset functionality.""" + + def test_create_multi_dataset(self, numpy_data_dir): + """MultiDataset with two datasets has combined length.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert len(multi) == len(ds_a) + len(ds_b) + + def test_create_multi_dataset_three_or_more(self, numpy_data_dir): + """MultiDataset with three+ datasets has combined length and correct index mapping.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_c = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b, ds_c], output_strict=True) + + assert len(multi) == len(ds_a) + len(ds_b) + len(ds_c) + assert multi[0][1][DATASET_INDEX_METADATA_KEY] == 0 + assert multi[10][1][DATASET_INDEX_METADATA_KEY] == 1 + assert multi[20][1][DATASET_INDEX_METADATA_KEY] == 2 + + def test_getitem_maps_to_correct_dataset(self, numpy_data_dir): + """Indices 0..len0-1 from first dataset, then second.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + # First 10 from ds_a (dataset_index 0) + data0, meta0 = multi[0] + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + assert "positions" in data0 + + data9, meta9 = multi[9] + assert meta9[DATASET_INDEX_METADATA_KEY] == 0 + + # Next 10 from ds_b (dataset_index 1) + data10, meta10 = multi[10] + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + data19, meta19 = multi[19] + assert meta19[DATASET_INDEX_METADATA_KEY] == 1 + 
+ def test_getitem_preserves_sub_dataset_metadata(self, numpy_data_dir): + """Metadata from sub-dataset (e.g. index) is preserved alongside dataset_index.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + data0, meta0 = multi[0] + assert meta0["index"] == 0 + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + + data10, meta10 = multi[10] + assert meta10["index"] == 0 # first sample of second dataset + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + def test_negative_indexing(self, numpy_data_dir): + """Negative indices work as expected.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + data_last, meta_last = multi[-1] + assert meta_last[DATASET_INDEX_METADATA_KEY] == 1 + data_first, meta_first = multi[-20] + assert meta_first[DATASET_INDEX_METADATA_KEY] == 0 + + def test_field_names_strict(self, numpy_data_dir): + """output_strict=True returns common output keys (TensorDict keys).""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert "positions" in multi.field_names + assert "features" in multi.field_names + + def test_iteration(self, numpy_data_dir): + """Iteration yields all samples in order with dataset_index in metadata.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + seen_indices = [] + seen_dataset_indices = [] + for data, meta in multi: + seen_indices.append(meta["index"]) + seen_dataset_indices.append(meta[DATASET_INDEX_METADATA_KEY]) + + assert len(seen_indices) == 20 + assert seen_dataset_indices[:10] == [0] * 10 + assert seen_dataset_indices[10:] == [1] * 10 + + def 
test_context_manager(self, numpy_data_dir): + """MultiDataset as context manager closes sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + with dp.MultiDataset([ds_a, ds_b], output_strict=True) as multi: + data, meta = multi[0] + assert meta[DATASET_INDEX_METADATA_KEY] == 0 + + +class TestMultiDatasetStrictValidation: + """Output strictness validation.""" + + def test_strict_raises_when_output_keys_differ(self, numpy_data_dir): + """output_strict=True raises if datasets produce different output keys.""" + ds_full = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) + + with pytest.raises(ValueError, match="output keys"): + dp.MultiDataset([ds_full, ds_pos_only], output_strict=True) + + def test_non_strict_accepts_different_fields(self, numpy_data_dir): + """output_strict=False does not validate output keys; field_names is first dataset's.""" + ds_full = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) + + multi = dp.MultiDataset([ds_full, ds_pos_only], output_strict=False) + assert len(multi) == 20 + assert set(multi.field_names) == set(ds_full.field_names) + + def test_strict_validates_output_keys_not_reader_fields(self, numpy_data_dir): + """output_strict compares TensorDict keys after transforms, not reader field_names.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + assert len(multi.field_names) >= 2 # positions, features + + +class TestMultiDatasetPrefetchAndClose: + """Prefetch and close delegation.""" + + def test_prefetch_delegates(self, numpy_data_dir): + """Prefetch delegates to correct sub-dataset.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = 
dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.prefetch(10) + data0, meta0 = multi[0] + data10, meta10 = multi[10] + assert meta0[DATASET_INDEX_METADATA_KEY] == 0 + assert meta10[DATASET_INDEX_METADATA_KEY] == 1 + + def test_cancel_prefetch_all(self, numpy_data_dir): + """cancel_prefetch(None) clears all sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.prefetch(10) + multi.cancel_prefetch() + # Should still be able to get data synchronously + data0, _ = multi[0] + assert "positions" in data0 + + def test_cancel_prefetch_invalid_index_no_op(self, numpy_data_dir): + """cancel_prefetch(out-of-range index) does not raise (matches Dataset).""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch(0) + multi.cancel_prefetch(999) # out of range, should no-op + multi.cancel_prefetch(-1) # also out of range + data0, _ = multi[0] + assert "positions" in data0 + + def test_prefetch_batch(self, numpy_data_dir): + """prefetch_batch delegates to sub-datasets by global index.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + multi.prefetch_batch([0, 1, 10, 11]) + for idx in [0, 1, 10, 11]: + data, meta = multi[idx] + assert meta[DATASET_INDEX_METADATA_KEY] == (0 if idx < 10 else 1) + assert "positions" in data + + def test_prefetch_count(self, numpy_data_dir): + """prefetch_count is sum of sub-dataset prefetch counts.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + assert multi.prefetch_count == 0 + 
multi.prefetch_batch([0, 1, 2, 3]) + assert multi.prefetch_count >= 0 # may complete quickly + multi.cancel_prefetch() + assert multi.prefetch_count == 0 + + def test_close_closes_all(self, numpy_data_dir): + """close() closes all sub-datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi.close() + # After close, sub-datasets are closed (no-op to call again) + + +class TestMultiDatasetErrors: + """Error cases.""" + + def test_requires_at_least_one_datasets(self, numpy_data_dir): + """MultiDataset requires at least one datasets.""" + with pytest.raises(ValueError, match="at least one"): + dp.MultiDataset([], output_strict=True) + + def test_requires_dataset_instances(self, numpy_data_dir): + """All elements must be Dataset instances.""" + ds = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + with pytest.raises(TypeError, match="must be a Dataset"): + dp.MultiDataset([ds, dp.NumpyReader(numpy_data_dir)], output_strict=False) + + def test_index_out_of_range(self, numpy_data_dir): + """Index out of range raises IndexError.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + with pytest.raises(IndexError, match="out of range"): + _ = multi[20] + with pytest.raises(IndexError, match="out of range"): + _ = multi[-21] + + def test_prefetch_out_of_range_raises(self, numpy_data_dir): + """prefetch with out-of-range index raises IndexError.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + with pytest.raises(IndexError, match="out of range"): + multi.prefetch(20) + + +class TestMultiDatasetWithDataLoader: + """DataLoader accepts MultiDataset (same interface as Dataset).""" + + def 
test_dataloader_with_multi_dataset(self, numpy_data_dir): + """DataLoader iterates over MultiDataset and collates batches.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader(multi, batch_size=4, shuffle=False) + assert len(loader) == 5 # 20 / 4 + + batches = list(loader) + assert len(batches) == 5 + assert batches[0]["positions"].shape[0] == 4 + assert batches[-1]["positions"].shape[0] == 4 + + def test_dataloader_with_multi_dataset_and_metadata(self, numpy_data_dir): + """Collate metadata includes dataset_index from MultiDataset.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader( + multi, batch_size=5, shuffle=False, collate_metadata=True + ) + batch_data, metadata_list = next(iter(loader)) + + assert len(metadata_list) == 5 + assert [m[DATASET_INDEX_METADATA_KEY] for m in metadata_list] == [ + 0, + 0, + 0, + 0, + 0, + ] + + batches = list(loader) + _, meta_batch_2 = batches[2] # indices 10-14, all from dataset 1 + assert all(m[DATASET_INDEX_METADATA_KEY] == 1 for m in meta_batch_2) + + def test_dataloader_shuffle_with_multi_dataset(self, numpy_data_dir): + """Shuffled DataLoader over MultiDataset yields all indices.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + loader = dp.DataLoader(multi, batch_size=4, shuffle=True, collate_metadata=True) + all_dataset_indices = [] + for batch_data, metadata_list in loader: + all_dataset_indices.extend( + m[DATASET_INDEX_METADATA_KEY] for m in metadata_list + ) + assert set(all_dataset_indices) == {0, 1} + assert len(all_dataset_indices) == 20 + + +class TestMultiDatasetRepr: + """String representation.""" + + def 
test_repr(self, numpy_data_dir): + """Repr includes output_strict and datasets.""" + ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) + multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + + r = repr(multi) + assert "MultiDataset" in r + assert "output_strict=True" in r From 1b98bec9fe7079c35d57cb016142491562bbbce9 Mon Sep 17 00:00:00 2001 From: megnvidia Date: Tue, 17 Mar 2026 14:24:03 -0700 Subject: [PATCH 2/4] Refine documentation for PhysicsNeMo Datapipes Updated the documentation for PhysicsNeMo Datapipes to improve clarity and consistency. Adjusted wording and structure for better readability. --- docs/api/datapipes/physicsnemo.datapipes.rst | 60 +++++++++++--------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 57b5fb5958..2d8cbf3738 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -4,8 +4,9 @@ PhysicsNeMo Datapipes .. automodule:: physicsnemo.datapipes .. currentmodule:: physicsnemo.datapipes -The PhysicsNeMo Datapipes consists largely of two separate components, both -described here. Prior to version 2.0 of PhysicsNeMo, each datapipe was largely +The PhysicsNeMo Datapipes consists largely of two separate components. + +Prior to version 2.0 of PhysicsNeMo, each datapipe was largely independent from all others, targeted for very specific datasets and applications, and broadly not extensible. Those datapipes, preserved in v2.0 for compatibility, are described in the climate, cae, gnn, and @@ -13,12 +14,13 @@ benchmark subsections. In PhysicsNeMo v2.0, the datapipes API has been redesigned from scratch to focus on key factors to enable scientific machine learning training and inference. 
-These documentation pages describe the architecture and design philosophy, while -in the examples of PhysicsNeMo there are runnable datapipe tutorials for -getting started. +This document describes the architecture and design philosophy. + +Refer to the examples of PhysicsNeMo for runnable datapipe tutorials to +get started. -Datapipes philosophy +Datapipes Philosophy -------------------- The PhysicsNeMo datapipe structure is built on several key design decisions @@ -27,18 +29,18 @@ that are specifically made to enable diverse scientific machine learning dataset - GPU First: data preprocessing is done on the GPU, not the CPU. - Isolation of roles: reading data is separate from transforming data, which is separate from pipelining data for training, which is separate from threading - and stream management, etc. Changing data sources, or preprocessing pipelines, + and stream management. Changing data sources, or preprocessing pipelines, should require no intervention in other areas. - Composability and Extensibility: We aim to provide a tool kit and examples that lets you build what you need yourself, easily, if it's not here. - Datapipes as configuration: Changing a pipeline shouldn't require source code modification; the registry system in PhysicsNeMo datapipes enables hydra instantiation of datapipes at runtime for version-controlled, runtime-configured datapipes. - You can register and instantiate custom components, of course. + You can also register and instantiate custom components. Data flows through a PhysicsNeMo datapipe in a consistent path: -1. A ``reader`` will bring the data from storage to CPU memory +1. A ``reader`` will bring the data from storage to CPU memory. 2. An optional series of one or more transformations will apply on-the-fly manipulations of that data, per instance of data. 3. Several instances of data will be collated into a batch (customizable, @@ -46,8 +48,8 @@ Data flows through a PhysicsNeMo datapipe in a consistent path: 4.
The batched data is ready for use in a model. At the highest level, ``physicsnemo.datapipes.DataLoader`` has a similar API and -model as ``pytorch.utils.data.DataLoader``, enabling a drop-in replacement in many -cases. Under the hood, physicsnemo follows a very different computation orchestration. +model as ``torch.utils.data.DataLoader``, which enables a drop-in replacement for many +cases. However, PhysicsNeMo has a very different computation orchestration. Quick Start ----------- @@ -92,8 +94,8 @@ Quick Start predictions = model(batch["pressure"], batch["coordinates"]) -The best place to see the PhysicsNeMo datapipes in action, and get a sense of -how they work and use them, is to start with the examples located in the +The best place to see the PhysicsNeMo datapipes in action, get a sense of +how they work, and use them, is to start with the examples located in the `examples directory `_. @@ -156,19 +158,19 @@ the ``Dataset`` is responsible for the threaded execution of ``Reader``s and MultiDataset ^^^^^^^^^^^^ -The ``MultiDataset`` composes two or more ``Dataset`` instances behind a single +The ``MultiDataset`` combines two or more ``Dataset`` instances behind a single index space (concatenation). Each sub-dataset can have its own Reader and transforms. Global indices are mapped to the owning sub-dataset and local index; -metadata is enriched with ``dataset_index`` so batches can identify the source. +metadata is enriched with ``dataset_index`` so that batches can identify the source. Use ``MultiDataset`` when you want to train on multiple datasets with the same -DataLoader, optionally enforcing that all outputs share the same TensorDict keys -for collation. See :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY` +DataLoader, and, optionally, enforce all outputs to share the same TensorDict keys +for collation. Refer to :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY` for the metadata key added to each sample. 
-Note that to properly collate and stack outputs from different datasets, you -can set ``output_strict=True`` in the constructor of a ``MultiDataset``. Upon +To properly collate and stack outputs from different datasets, you +can set ``output_strict=True`` in the constructor of a ``MultiDataset``. During construction, it will load the first batch from every passed dataset and test -that the tensordict produced by the ``Reader`` and ``Transform`` pipeline has +that the TensorDict produced by the ``Reader`` and ``Transform`` pipeline has consistent keys. Because the exact collation details differ by dataset, the ``MultiDataset`` does not check more aggressively than output key consistency. @@ -180,9 +182,9 @@ consistent keys. Because the exact collation details differ by dataset, the Readers ^^^^^^^ -Readers are the data-ingestion layer: each one loads individual samples from a -specific storage format (HDF5, Zarr, NumPy, VTK, etc.) and returns CPU tensors -in a uniform dict interface. See :doc:`physicsnemo.datapipes.readers` for the +Readers are the data-ingestion layer. Each one loads individual samples from a +specific storage format (HDF5, Zarr, NumPy, VTK) and returns CPU tensors +in a uniform dict interface. Refer to :doc:`physicsnemo.datapipes.readers` for the base class API and all built-in readers. Transforms @@ -190,21 +192,23 @@ Transforms Transforms are composable, device-agnostic operations applied to each sample after it is loaded and transferred to the target device. The ``Compose`` -container chains multiple transforms into a single callable. See +container chains multiple transforms into a single callable. Refer to :doc:`physicsnemo.datapipes.transforms` for the base class API, ``Compose``, and all built-in transforms. Collation ^^^^^^^^^ -Combining a set of tensordict objects into a batch of data can, at times, +Combining a set of TensorDict objects into a batch of data can, at times, +result in difficulties or errors. 
+ Some datasets, like graph datasets, require special care. For this reason, PhysicsNeMo datapipes offers custom collation functions as well as an interface to write your own collator. If the dataset you are -trying to collate can not be accommodated here, please open an issue on github. +trying to collate can not be accommodated here, open an issue on github. -For an example of a custom collation function to produce a batch of PyG graph data, -see the examples on github for the datapipes. +For an example of a custom collation function that produces a batch of PyG graph data, +refer to the examples on github for the datapipes. .. autoclass:: physicsnemo.datapipes.collate.Collator :members: From 56593a30df602201919470deea5623ee0f4d0555 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:51:48 -0500 Subject: [PATCH 3/4] update api docs. --- docs/api/datapipes/physicsnemo.datapipes.rst | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 2d8cbf3738..82b438f25b 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -199,11 +199,10 @@ and all built-in transforms. Collation ^^^^^^^^^ -Combining a set of TensorDict objects into a batch of data can, at times, -result in difficulties or errors. - -Some datasets, like graph datasets, require special care. For -this reason, PhysicsNeMo datapipes offers custom collation functions +Combining a set of TensorDict objects into a batch of data can, at times, +require special care. For example, collating graph datasets for Graph Neural +Networks requires different merging of batches than concatenation along a batch +dimension. For this reason, PhysicsNeMo datapipes offers custom collation functions as well as an interface to write your own collator. 
If the dataset you are trying to collate can not be accommodated here, open an issue on github. From f91db31fba9ff1c9632f37a362c02f7e7568458f Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:34:47 -0500 Subject: [PATCH 4/4] Update multidata set interface to accept an unpacked tuple instead of a list, etc. --- physicsnemo/datapipes/multi_dataset.py | 14 +++---- test/datapipes/core/test_multi_dataset.py | 50 +++++++++++------------ 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py index 8b6152a10f..e09c5855dc 100644 --- a/physicsnemo/datapipes/multi_dataset.py +++ b/physicsnemo/datapipes/multi_dataset.py @@ -91,9 +91,10 @@ class MultiDataset: Parameters ---------- - datasets : Sequence[Dataset] - One or more Dataset instances (Reader + transforms each). Order defines - index mapping: first dataset occupies 0..len(ds0)-1, etc. + *datasets : Dataset + One or more Dataset instances passed as positional arguments + (Reader + transforms each). Order defines index mapping: first + dataset occupies 0..len(ds0)-1, etc. output_strict : bool, default=True If True, require all datasets to produce the same TensorDict keys (output keys after transforms) so :class:`DefaultCollator` can stack batches. If @@ -106,7 +107,7 @@ class MultiDataset: Raises ------ ValueError - If ``len(datasets) < 1`` or if ``output_strict=True`` and output keys differ. + If no datasets are provided or if ``output_strict=True`` and output keys differ. 
Notes ----- @@ -147,7 +148,7 @@ class MultiDataset: >>> from physicsnemo.datapipes import Dataset, MultiDataset, HDF5Reader, Normalize >>> ds_a = Dataset(HDF5Reader("a.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP >>> ds_b = Dataset(HDF5Reader("b.h5", fields=["x", "y"]), transforms=None) # doctest: +SKIP - >>> multi = MultiDataset([ds_a, ds_b], output_strict=True) # doctest: +SKIP + >>> multi = MultiDataset(ds_a, ds_b, output_strict=True) # doctest: +SKIP >>> len(multi) == len(ds_a) + len(ds_b) # doctest: +SKIP True >>> data, meta = multi[0] # from ds_a # doctest: +SKIP @@ -156,8 +157,7 @@ class MultiDataset: def __init__( self, - datasets: Sequence[Dataset], - *, + *datasets: Dataset, output_strict: bool = True, ) -> None: if len(datasets) < 1: diff --git a/test/datapipes/core/test_multi_dataset.py b/test/datapipes/core/test_multi_dataset.py index 4fde358658..1c25c39b9f 100644 --- a/test/datapipes/core/test_multi_dataset.py +++ b/test/datapipes/core/test_multi_dataset.py @@ -29,7 +29,7 @@ def test_create_multi_dataset(self, numpy_data_dir): """MultiDataset with two datasets has combined length.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) assert len(multi) == len(ds_a) + len(ds_b) @@ -38,7 +38,7 @@ def test_create_multi_dataset_three_or_more(self, numpy_data_dir): ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_c = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b, ds_c], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, ds_c, output_strict=True) assert len(multi) == len(ds_a) + len(ds_b) + len(ds_c) assert multi[0][1][DATASET_INDEX_METADATA_KEY] == 0 @@ -49,7 +49,7 @@ def test_getitem_maps_to_correct_dataset(self, numpy_data_dir): """Indices 0..len0-1 from first dataset, then 
second.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) # First 10 from ds_a (dataset_index 0) data0, meta0 = multi[0] @@ -70,7 +70,7 @@ def test_getitem_preserves_sub_dataset_metadata(self, numpy_data_dir): """Metadata from sub-dataset (e.g. index) is preserved alongside dataset_index.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) data0, meta0 = multi[0] assert meta0["index"] == 0 @@ -84,7 +84,7 @@ def test_negative_indexing(self, numpy_data_dir): """Negative indices work as expected.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) data_last, meta_last = multi[-1] assert meta_last[DATASET_INDEX_METADATA_KEY] == 1 @@ -95,7 +95,7 @@ def test_field_names_strict(self, numpy_data_dir): """output_strict=True returns common output keys (TensorDict keys).""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) assert "positions" in multi.field_names assert "features" in multi.field_names @@ -104,7 +104,7 @@ def test_iteration(self, numpy_data_dir): """Iteration yields all samples in order with dataset_index in metadata.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) seen_indices = [] seen_dataset_indices = [] @@ -120,7 +120,7 @@ 
def test_context_manager(self, numpy_data_dir): """MultiDataset as context manager closes sub-datasets.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - with dp.MultiDataset([ds_a, ds_b], output_strict=True) as multi: + with dp.MultiDataset(ds_a, ds_b, output_strict=True) as multi: data, meta = multi[0] assert meta[DATASET_INDEX_METADATA_KEY] == 0 @@ -134,14 +134,14 @@ def test_strict_raises_when_output_keys_differ(self, numpy_data_dir): ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) with pytest.raises(ValueError, match="output keys"): - dp.MultiDataset([ds_full, ds_pos_only], output_strict=True) + dp.MultiDataset(ds_full, ds_pos_only, output_strict=True) def test_non_strict_accepts_different_fields(self, numpy_data_dir): """output_strict=False does not validate output keys; field_names is first dataset's.""" ds_full = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_pos_only = dp.Dataset(dp.NumpyReader(numpy_data_dir, fields=["positions"])) - multi = dp.MultiDataset([ds_full, ds_pos_only], output_strict=False) + multi = dp.MultiDataset(ds_full, ds_pos_only, output_strict=False) assert len(multi) == 20 assert set(multi.field_names) == set(ds_full.field_names) @@ -149,7 +149,7 @@ def test_strict_validates_output_keys_not_reader_fields(self, numpy_data_dir): """output_strict compares TensorDict keys after transforms, not reader field_names.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) assert len(multi.field_names) >= 2 # positions, features @@ -160,7 +160,7 @@ def test_prefetch_delegates(self, numpy_data_dir): """Prefetch delegates to correct sub-dataset.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], 
output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) multi.prefetch(0) multi.prefetch(10) @@ -173,7 +173,7 @@ def test_cancel_prefetch_all(self, numpy_data_dir): """cancel_prefetch(None) clears all sub-datasets.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) multi.prefetch(0) multi.prefetch(10) @@ -186,7 +186,7 @@ def test_cancel_prefetch_invalid_index_no_op(self, numpy_data_dir): """cancel_prefetch(out-of-range index) does not raise (matches Dataset).""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) multi.prefetch(0) multi.cancel_prefetch(999) # out of range, should no-op @@ -198,7 +198,7 @@ def test_prefetch_batch(self, numpy_data_dir): """prefetch_batch delegates to sub-datasets by global index.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) multi.prefetch_batch([0, 1, 10, 11]) for idx in [0, 1, 10, 11]: @@ -210,7 +210,7 @@ def test_prefetch_count(self, numpy_data_dir): """prefetch_count is sum of sub-dataset prefetch counts.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) assert multi.prefetch_count == 0 multi.prefetch_batch([0, 1, 2, 3]) @@ -222,7 +222,7 @@ def test_close_closes_all(self, numpy_data_dir): """close() closes all sub-datasets.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = 
dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) multi.close() # After close, sub-datasets are closed (no-op to call again) @@ -233,19 +233,19 @@ class TestMultiDatasetErrors: def test_requires_at_least_one_datasets(self, numpy_data_dir): """MultiDataset requires at least one datasets.""" with pytest.raises(ValueError, match="at least one"): - dp.MultiDataset([], output_strict=True) + dp.MultiDataset(output_strict=True) def test_requires_dataset_instances(self, numpy_data_dir): """All elements must be Dataset instances.""" ds = dp.Dataset(dp.NumpyReader(numpy_data_dir)) with pytest.raises(TypeError, match="must be a Dataset"): - dp.MultiDataset([ds, dp.NumpyReader(numpy_data_dir)], output_strict=False) + dp.MultiDataset(ds, dp.NumpyReader(numpy_data_dir), output_strict=False) def test_index_out_of_range(self, numpy_data_dir): """Index out of range raises IndexError.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) with pytest.raises(IndexError, match="out of range"): _ = multi[20] @@ -256,7 +256,7 @@ def test_prefetch_out_of_range_raises(self, numpy_data_dir): """prefetch with out-of-range index raises IndexError.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) with pytest.raises(IndexError, match="out of range"): multi.prefetch(20) @@ -269,7 +269,7 @@ def test_dataloader_with_multi_dataset(self, numpy_data_dir): """DataLoader iterates over MultiDataset and collates batches.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, 
ds_b, output_strict=True) loader = dp.DataLoader(multi, batch_size=4, shuffle=False) assert len(loader) == 5 # 20 / 4 @@ -283,7 +283,7 @@ def test_dataloader_with_multi_dataset_and_metadata(self, numpy_data_dir): """Collate metadata includes dataset_index from MultiDataset.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) loader = dp.DataLoader( multi, batch_size=5, shuffle=False, collate_metadata=True @@ -307,7 +307,7 @@ def test_dataloader_shuffle_with_multi_dataset(self, numpy_data_dir): """Shuffled DataLoader over MultiDataset yields all indices.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) loader = dp.DataLoader(multi, batch_size=4, shuffle=True, collate_metadata=True) all_dataset_indices = [] @@ -326,7 +326,7 @@ def test_repr(self, numpy_data_dir): """Repr includes output_strict and datasets.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset([ds_a, ds_b], output_strict=True) + multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) r = repr(multi) assert "MultiDataset" in r