Merged
23 changes: 23 additions & 0 deletions docs/api/datapipes/physicsnemo.datapipes.rst
@@ -153,6 +153,29 @@ the ``Dataset`` is responsible for the threaded execution of ``Reader``s and
:members:
:show-inheritance:

MultiDataset
^^^^^^^^^^^^

The ``MultiDataset`` composes two or more ``Dataset`` instances behind a single
index space (concatenation). Each sub-dataset can have its own Reader and
transforms. Global indices are mapped to the owning sub-dataset and local index;
metadata is enriched with ``dataset_index`` so batches can identify the source.
Use ``MultiDataset`` when you want to train on multiple datasets with the same
DataLoader, optionally enforcing that all outputs share the same TensorDict keys
for collation. See :const:`physicsnemo.datapipes.multi_dataset.DATASET_INDEX_METADATA_KEY`
for the metadata key added to each sample.
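
The global-to-local index mapping described above can be sketched as a standalone routine (an illustrative version using cumulative lengths and ``bisect``; hypothetical helper names, not the library's actual implementation):

```python
import bisect


def build_offsets(lengths):
    """Cumulative end offsets for each sub-dataset, e.g. [3, 5] for lengths [3, 2]."""
    offsets, total = [], 0
    for n in lengths:
        total += n
        offsets.append(total)
    return offsets


def map_global_index(offsets, global_idx):
    """Map a global index into (dataset_index, local_index)."""
    # bisect_right finds the first sub-dataset whose end offset exceeds global_idx
    dataset_index = bisect.bisect_right(offsets, global_idx)
    start = offsets[dataset_index - 1] if dataset_index > 0 else 0
    return dataset_index, global_idx - start
```

With two sub-datasets of lengths 3 and 2, global index 3 maps to the first element of the second dataset, which is exactly the concatenation semantics described above.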

To ensure that outputs from different datasets can be collated and stacked, set
``output_strict=True`` in the ``MultiDataset`` constructor. At construction
time, the first batch from every passed dataset is loaded, and the TensorDicts
produced by each ``Reader`` and ``Transform`` pipeline are checked for
consistent keys. Because the exact collation details differ by dataset,
``MultiDataset`` does not check anything more aggressive than output key
consistency.
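
The key-consistency check that ``output_strict=True`` performs can be illustrated with a minimal standalone sketch (plain dicts stand in for TensorDicts, and ``check_output_keys`` is a hypothetical name, not the library's internal function):

```python
def check_output_keys(samples):
    """Verify that every sample (one per sub-dataset) exposes the same keys.

    `samples` is a list of dict-like pipeline outputs, one from each
    sub-dataset's Reader/Transform pipeline.
    """
    reference = set(samples[0].keys())
    for i, sample in enumerate(samples[1:], start=1):
        keys = set(sample.keys())
        if keys != reference:
            raise ValueError(
                f"Dataset 0 produced keys {sorted(reference)} but dataset {i} "
                f"produced {sorted(keys)}; outputs cannot be collated."
            )
```

Raising at construction time surfaces a key mismatch immediately, rather than as an opaque collation failure deep inside a training loop.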

.. autoclass:: physicsnemo.datapipes.multi_dataset.MultiDataset
:members:
:show-inheritance:


Readers
^^^^^^^
2 changes: 2 additions & 0 deletions physicsnemo/datapipes/__init__.py
@@ -40,6 +40,7 @@
)
from physicsnemo.datapipes.dataloader import DataLoader
from physicsnemo.datapipes.dataset import Dataset
from physicsnemo.datapipes.multi_dataset import MultiDataset
from physicsnemo.datapipes.readers import (
HDF5Reader,
NumpyReader,
@@ -84,6 +85,7 @@
"TensorDict", # Re-export from tensordict
"Dataset",
"DataLoader",
"MultiDataset",
# Transforms - Base
"Transform",
"Compose",
4 changes: 3 additions & 1 deletion physicsnemo/datapipes/dataloader.py
@@ -168,7 +168,9 @@ def __len__(self) -> int:
int
Number of batches in the dataloader.
"""
n_samples = len(self.dataset)
n_samples = (
len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset)
)
if self.drop_last:
return n_samples // self.batch_size
return (n_samples + self.batch_size - 1) // self.batch_size
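
The length computation above amounts to floor division when ``drop_last`` is set and ceiling division otherwise; a minimal standalone version of the arithmetic:

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields over n_samples items."""
    if drop_last:
        # an incomplete final batch is discarded
        return n_samples // batch_size
    # ceiling division: the incomplete final batch still counts
    return (n_samples + batch_size - 1) // batch_size
```

Note that the diff also makes the count sampler-aware: when a sampler defines ``__len__``, its length is used as ``n_samples`` instead of the dataset's, so samplers that subsample or repeat the dataset report the correct number of batches.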