Changes from 4 commits
23 changes: 23 additions & 0 deletions docs/api/datapipes/physicsnemo.datapipes.rst
@@ -153,6 +153,29 @@ the ``Dataset`` is responsible for the threaded execution of ``Reader``s and
   :members:
   :show-inheritance:

MultiDataset
^^^^^^^^^^^^

The ``MultiDataset`` composes two or more ``Dataset`` instances behind a single
index space (concatenation). Each sub-dataset can have its own ``Reader`` and
transforms. Global indices are mapped to the owning sub-dataset and a local
index; each sample's metadata is enriched with ``dataset_index`` so batches can
identify their source.
Use ``MultiDataset`` when you want to train on multiple datasets with the same
DataLoader, optionally enforcing that all outputs share the same TensorDict keys
for collation. See :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY`
for the metadata key added to each sample.
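The global-to-local index mapping described above can be sketched generically. This is a hypothetical minimal re-implementation for illustration, not the physicsnemo code; the ``ConcatIndex`` name is invented:

```python
from bisect import bisect_right
from itertools import accumulate


class ConcatIndex:
    """Map a global index to (dataset_index, local_index) across sub-datasets."""

    def __init__(self, lengths):
        # Cumulative end offsets, e.g. lengths [3, 2] -> offsets [3, 5].
        self.offsets = list(accumulate(lengths))

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def locate(self, global_index):
        # Binary-search for the owning sub-dataset, then subtract its start.
        ds = bisect_right(self.offsets, global_index)
        start = self.offsets[ds - 1] if ds > 0 else 0
        return ds, global_index - start


idx = ConcatIndex([3, 2])
# Global indices 0..4 map to (0,0), (0,1), (0,2), (1,0), (1,1).
```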

To ensure outputs from different datasets can be collated and stacked, set
``output_strict=True`` in the ``MultiDataset`` constructor. At construction
time, the ``MultiDataset`` loads the first batch from every supplied dataset
and verifies that the tensordicts produced by each ``Reader``/``Transform``
pipeline share a consistent set of keys. Because exact collation details
differ by dataset, no check stricter than output-key consistency is performed.
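The key-consistency idea behind ``output_strict`` can be sketched with plain dicts standing in for tensordicts (a hypothetical helper, not the physicsnemo implementation):

```python
def check_output_keys(samples):
    """Verify every sample (a dict standing in for a TensorDict) has identical keys.

    Mirrors the idea behind output_strict=True: only key consistency is
    checked, since collation details differ per dataset.
    """
    expected = set(samples[0].keys())
    for i, sample in enumerate(samples[1:], start=1):
        keys = set(sample.keys())
        if keys != expected:
            missing = sorted(expected - keys)
            extra = sorted(keys - expected)
            raise ValueError(
                f"sample {i} keys differ: missing={missing}, extra={extra}"
            )


# Same keys in any order pass the check.
check_output_keys([{"u": 1, "p": 2}, {"p": 4, "u": 3}])
```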

.. autoclass:: physicsnemo.datapipes.multi_dataset.MultiDataset
   :members:
   :show-inheritance:


Readers
^^^^^^^
4 changes: 4 additions & 0 deletions physicsnemo/datapipes/__init__.py
@@ -40,6 +40,7 @@
)
from physicsnemo.datapipes.dataloader import DataLoader
from physicsnemo.datapipes.dataset import Dataset
from physicsnemo.datapipes.multi_dataset import MultiDataset
from physicsnemo.datapipes.readers import (
    HDF5Reader,
    NumpyReader,
@@ -70,6 +71,7 @@
    NormalizeVectors,
    Purge,
    Rename,
    Resize,
Contributor
**Reshape missing from top-level exports**

`Reshape` is exported from `physicsnemo.datapipes.transforms` (added to both the import list and `__all__` in `transforms/__init__.py`), but it is neither imported nor listed in this top-level `datapipes/__init__.py`. As a result, users cannot do `from physicsnemo.datapipes import Reshape` or `dp.Reshape(...)`, unlike `Resize`, which was properly added. This appears to be an oversight.

from physicsnemo.datapipes.transforms import (
    ...
    Rename,
    Resize,
    Reshape,   # <-- add this
    Scale,
    ...
)

__all__ = [
    ...
    "ConstantField",
    "Reshape",   # <-- add this
    ...
]
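The re-export rule the comment relies on can be demonstrated with an in-memory module: a name must be bound on the top-level package (and, for ``import *``, listed in ``__all__``) before ``from pkg import Name`` works. All names here (`demo_pkg`, `Resize`, `Reshape`) are hypothetical stand-ins for the physicsnemo layout:

```python
import sys
import types

# A tiny fake package built in memory, mimicking the physicsnemo structure.
transforms = types.ModuleType("demo_pkg.transforms")
transforms.Resize = type("Resize", (), {})
transforms.Reshape = type("Reshape", (), {})

pkg = types.ModuleType("demo_pkg")
pkg.transforms = transforms
pkg.Resize = transforms.Resize  # re-exported at the top level
# pkg.Reshape intentionally NOT bound -> the bug the review comment describes.
pkg.__all__ = ["Resize"]

sys.modules["demo_pkg"] = pkg
sys.modules["demo_pkg.transforms"] = transforms

from demo_pkg import Resize  # works: bound on the package

try:
    from demo_pkg import Reshape  # fails: never re-exported
except ImportError as exc:
    print("ImportError:", exc)
```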

    Scale,
    SubsamplePoints,
    Transform,
@@ -84,6 +86,7 @@
    "TensorDict",  # Re-export from tensordict
    "Dataset",
    "DataLoader",
    "MultiDataset",
    # Transforms - Base
    "Transform",
    "Compose",
@@ -107,6 +110,7 @@
    "CreateGrid",
    "KNearestNeighbors",
    "CenterOfMass",
    "Resize",
    # Transforms - Utility
    "Rename",
    "Purge",
4 changes: 3 additions & 1 deletion physicsnemo/datapipes/dataloader.py
@@ -168,7 +168,9 @@ def __len__(self) -> int:
        int
            Number of batches in the dataloader.
        """
-        n_samples = len(self.dataset)
+        n_samples = (
+            len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset)
+        )
        if self.drop_last:
            return n_samples // self.batch_size
        return (n_samples + self.batch_size - 1) // self.batch_size
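The patched ``__len__`` can be sketched as a standalone function (a hypothetical `dataloader_len` helper, mirroring the diff's logic): prefer the sampler's length when it defines one, otherwise fall back to the dataset length, then apply floor or ceiling division depending on `drop_last`:

```python
def dataloader_len(dataset, batch_size, drop_last=False, sampler=None):
    # A sampler (e.g. a subset or distributed sampler) may yield fewer
    # samples than the dataset holds, so its length takes precedence.
    if sampler is not None and hasattr(sampler, "__len__"):
        n_samples = len(sampler)
    else:
        n_samples = len(dataset)
    if drop_last:
        return n_samples // batch_size  # drop the final partial batch
    return (n_samples + batch_size - 1) // batch_size  # ceiling division


# 10 samples, batch_size 3 -> 3 batches with drop_last, 4 without.
```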