
Pnm mesh datapipes #1512

Open
coreyjadams wants to merge 35 commits into NVIDIA:main from coreyjadams:pnm_mesh_datapipes

Conversation

Collaborator

@coreyjadams coreyjadams commented Mar 17, 2026

PhysicsNeMo Pull Request

This brings in a new training recipe for external aerodynamics. We will build on this in the coming weeks, but right now it contains the following functionality:

  • Mesh-based datapipes. We can read, transform, and augment data directly from physicsnemo.mesh objects, and they are quite fast. This is working for the drivaerML and shift_suv datasets.
  • Multi-dataset datapipes. Using the mesh datapipes, we mix and interleave drivaerML and shift_suv at train time, and we plan to add more datasets.
  • Extending datapipes in the example. The example built here injects non-dimensionalization information into each mesh and applies it to the data, which makes merging the datasets feasible. We also bring in a custom collation function.

Note that the dataloader built here is a hybrid of Hydra and Python initialization: the easiest way to make a flexible multi-dataset loader was to instantiate each dataset with Hydra and then package the whole business up in Python.
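For context, the hybrid pattern described above might look roughly like this sketch - the top-level `datasets` key, the `_target_` paths, and the `reader`/`data_dir` keys are illustrative, not the recipe's actual config schema:

```yaml
# Hypothetical multi-dataset config: each entry is a Hydra-instantiable
# dataset spec; a thin Python wrapper then calls hydra.utils.instantiate()
# on each entry and packages the results into a single loader.
datasets:
  drivaerml:
    _target_: physicsnemo.datapipes.MeshDataset    # illustrative path
    reader:
      _target_: physicsnemo.datapipes.readers.MeshReader
      data_dir: /path/to/your/PhysicsNeMo-DrivaerML/
  shift_suv:
    _target_: physicsnemo.datapipes.MeshDataset    # illustrative path
    reader:
      _target_: physicsnemo.datapipes.readers.MeshReader
      data_dir: /path/to/your/shift_suv/
```

The Python side stays small: iterate the config entries, instantiate each, and hand the list to the multi-dataset wrapper.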

There aren't significant changes to the GeoTransolver recipe beyond these, and eventually I'll get Transolver in here too. Next I want to intercept DomainMesh for volumetric learning. I've got a prototype datapipe reader for it, and the mesh-based datapipes should handle it, but it's not tested yet. I'm willing to prune this from the PR, if desired, until we stand up an example.

One thing I want to add to this example is the ability to set and log a random seed for the training runs. I think it will be necessary for the experiments.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score. This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@coreyjadams coreyjadams marked this pull request as ready for review March 25, 2026 21:38
Contributor

greptile-apps Bot commented Mar 25, 2026

Greptile Summary

This PR introduces a mesh-centric datapipe infrastructure (MeshReader, DomainMeshReader, MeshDataset, MultiDataset, a suite of MeshTransform subclasses, and a shared DatasetBase ABC) along with a reference external aerodynamics training recipe. The design — separating mesh I/O from transforms and providing a unified index space via MultiDataset — is clean and well-documented, but there are a few actionable issues before merge:

  • NormalizeMeshFields is missing from the public API: the class is defined, decorated with @register(), and listed in transforms/mesh/__init__.py, but is not imported in physicsnemo/datapipes/transforms/__init__.py or physicsnemo/datapipes/__init__.py. Users cannot do from physicsnemo.datapipes import NormalizeMeshFields; only the Hydra registry path ${dp:NormalizeMeshFields} works.
  • torch.load(..., weights_only=False) in NormalizeMeshFields: allows arbitrary Python pickle deserialization from the stats_file path, which is a security concern for user-supplied paths.
  • _validate_strict_outputs fails ungracefully when a MeshDataset returns raw Mesh objects: calling data.keys() on a Mesh tensorclass returns internal structure keys (['cells', 'cell_data', ...]), not user-defined field names, producing a confusing error rather than a useful message when output_strict=True is combined with a MeshDataset that lacks a terminal MeshToTensorDict transform.
  • Docstring inaccuracy in MeshReader and DomainMeshReader: the pattern parameter is documented as defaulting to **/*.pt but the actual default is **/*.pmsh (DEFAULT_MESH_EXTENSION).
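On the `weights_only` point above, the safer loading path can be sketched as follows; the per-field mean/std stats layout here is an assumption for illustration, not the PR's actual format:

```python
import os
import tempfile

import torch

# Hedged sketch: persist and reload normalization statistics with
# weights_only=True, so torch.load only deserializes tensors and plain
# containers instead of executing arbitrary pickled Python objects.
stats = {"pressure": {"mean": torch.tensor(0.0), "std": torch.tensor(1.0)}}

with tempfile.TemporaryDirectory() as tmpdir:
    stats_file = os.path.join(tmpdir, "stats.pt")
    torch.save(stats, stats_file)
    # weights_only=True raises on anything outside the allowed type set,
    # which is exactly what you want for a user-supplied stats_file path.
    loaded = torch.load(stats_file, weights_only=True)
```

Since the stats are plain tensors in nested dicts, nothing is lost by the restriction.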

Important Files Changed

Filename Overview
physicsnemo/datapipes/transforms/mesh/transforms.py New deterministic mesh transforms (ScaleMesh, TranslateMesh, RotateMesh, CenterMesh, SubsampleMesh, DropMeshFields, RenameMeshFields, SetGlobalField, NormalizeMeshFields, ComputeSurfaceNormals, MeshToTensorDict, ComputeCellCentroids, RestructureTensorDict). Logic is mostly sound, but NormalizeMeshFields uses weights_only=False in torch.load (security risk) and is not exported from the public API (transforms/__init__.py and datapipes/__init__.py).
physicsnemo/datapipes/readers/mesh.py New MeshReader and DomainMeshReader classes. Both have docstring inaccuracies: the pattern parameter is documented as matching **/*.pt but the actual default is **/*.pmsh (DEFAULT_MESH_EXTENSION). pin_memory, metadata, and the iterator implementation are correct.
physicsnemo/datapipes/mesh_dataset.py New MeshDataset class extending DatasetBase with CUDA stream-aware prefetching and DomainMesh-aware transform dispatch. The _load and _load_and_transform logic correctly handles both Mesh and DomainMesh. No significant issues.
physicsnemo/datapipes/multi_dataset.py New MultiDataset composing multiple DatasetBase instances. The index mapping, prefetch delegation, and close propagation are correct. Two issues: _validate_strict_outputs may produce confusing errors when a MeshDataset returns raw Mesh objects (not TensorDict), and __getitem__'s return type annotation is too narrow (TensorDict instead of Any).
physicsnemo/datapipes/protocols.py New DatasetBase ABC providing shared prefetch infrastructure (_PrefetchResult, ThreadPoolExecutor management, cancel_prefetch, close). Clean design and well-documented.
physicsnemo/datapipes/transforms/__init__.py Updated to export new mesh transforms, but NormalizeMeshFields is missing from both the import block and the __all__ list, leaving it inaccessible via standard Python imports (only accessible via the Hydra registry or a deep import path).
physicsnemo/datapipes/__init__.py Updated public API exports - adds MeshDataset, DatasetBase, and the new readers and mesh transforms. NormalizeMeshFields is absent from both the import block and the __all__ list, making it unreachable via from physicsnemo.datapipes import NormalizeMeshFields.
test/datapipes/readers/test_mesh_readers.py Good coverage of MeshReader, DomainMeshReader, MeshDataset, and apply_to_tensordict_mesh. Tests use real Mesh primitives. However, tests use pattern="*.pt" explicitly, which works but may mislead readers about the default extension (.pmsh).
test/datapipes/core/test_multi_dataset.py Thorough MultiDataset tests covering index mapping, negative indexing, strict validation, prefetch delegation, and DataLoader integration. No test coverage for mixing Dataset and MeshDataset (which is the primary motivation for MultiDataset per the PR description).

Comments Outside Diff (2)

  1. physicsnemo/datapipes/multi_dataset.py, line 66-82 (link)

P1 _validate_strict_outputs produces a confusing error for MeshDataset returning raw Mesh

    data.keys() is called on whatever ds[0] returns. For a MeshDataset whose last transform is not MeshToTensorDict, data will be a Mesh or DomainMesh tensorclass. Those do have .keys() (inherited from tensorclass), but they return internal structure keys — ['cells', 'cell_data', 'global_data', 'point_data', 'points'] — rather than user-defined field names.

    Consequence: mixing a Dataset (which returns a TensorDict with user keys like ['pressure', 'velocity']) and a MeshDataset (returning raw Mesh) under output_strict=True always raises a ValueError with a confusing message listing structural keys, even when the user intends output_strict=False. At minimum, a guard or a descriptive error should be added:

    from physicsnemo.mesh import Mesh, DomainMesh
    
    data, _ = ds[0]
    if isinstance(data, (Mesh, DomainMesh)):
        raise ValueError(
            f"output_strict=True requires all datasets to return TensorDict, but "
            f"dataset {i} returned {type(data).__name__}. "
            f"Add MeshToTensorDict as the last transform or use output_strict=False."
        )
    keys = sorted(data.keys())
  2. physicsnemo/datapipes/multi_dataset.py, line 253 (link)

    P2 Return type annotation is too narrow

    The return type tuple[TensorDict, dict[str, Any]] is correct only when sub-datasets are Dataset instances (which return TensorDict). When sub-datasets are MeshDataset instances (which may return Mesh or DomainMesh), the annotation is wrong.

    Consider widening the return type to match DatasetBase.__getitem__:

Reviews (1): Last reviewed commit: "Merge branch 'main' into pnm_mesh_datapi..."

Collaborator

@peterdsharpe peterdsharpe left a comment


Adding another round of comments!

# Unified External Aerodynamics Recipe

NOTE THIS README IS AI GENERATED AND YOU SHOULD USE IT WITH EXTREME CAUTION.
I WILL GO THROUGH IT CAREFULLY FOR ACCURACY BEFORE FINAL REVIEW OR MERGE.
Collaborator

Dropping a TODO here, just so we don't forget

def __init__(
    self,
    axes: list[Literal["x", "y", "z"]] | None = None,
    angle_range: tuple[float, float] = (-math.pi, math.pi),
Collaborator

Rather than parameterizing the rotation via:

  • Random axis from ["x", "y", "z"]
  • Random angle

It might be nicer to have a truly random rotation? With the current setup, the diversity of rotations is quite limited.

One possible algorithm:

  1. Generate a random vector -> normalize it -> that's your new x'.
  2. Generate another random vector -> normalize it -> project it into the plane normal to x'; that's your y'
  3. z' is the cross product of x' and y'

This generates a random rotation, while preserving angles + right-handed coordinate system (i.e., it's orthonormal).

And then you arrange all three of these (x', y', z') as columns of a 3x3 matrix, and do mesh.transform(matrix).

Collaborator Author

I went with this alg, basically: https://kya8.github.io/p/uniform-sampling-of-so3-rotations/

But there are definitely cases where we may want to restrict to physical rotations: flipping a car left to right is fine, but cars don't drive upside down, for example. So I put in a mode switch too.


points = mesh.points
if "L_ref" in gd:
    points = points / gd["L_ref"].float()
Collaborator

One thing here:

  • Nondimensionalization of the points is sometimes done in aerospace, but it's not super common. (By contrast, something like $C_p$ nondimensionalization or $C_f$ nondimensionalization is pretty ubiquitous.) This is a good idea to do, but we should be sure to call this out in downstream use cases. This PR already does a good job of noting this in the docstring; just noting this for any downstream consumers of this.

Collaborator Author

We can note it in the readme, yeah. And, I think it's ok to just set L_ref to 1.0 and "turn it off" so to speak, too.
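To make the convention concrete, the ubiquitous $C_p$ nondimensionalization mentioned above looks like this; the variable names follow the usual aero conventions, and the function is a pure-Python sketch rather than the recipe's code:

```python
def pressure_coefficient(p: float, p_inf: float, rho_inf: float, v_inf: float) -> float:
    # C_p = (p - p_inf) / q_inf, where q_inf is the freestream
    # dynamic pressure 0.5 * rho_inf * v_inf**2.
    q_inf = 0.5 * rho_inf * v_inf**2
    return (p - p_inf) / q_inf
```

Setting L_ref to 1.0, as noted above, similarly turns the point nondimensionalization into a no-op without changing the code path.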

target = torch.cat([pressure, wss], dim=-1) # (N, 4)

points_list.append(pts)
embedding_list.append(torch.cat([pts, normals], dim=-1)) # (N, 6)
Collaborator

This collate function seems to be Transolver/GeoTransolver-specific, and in particular this part with the embeddings breaks equivariance. Do we want to have these controlled by a kwarg? Or, more broadly, is there a way to make this collate function less Transolver/GeoTransolver-specific?

Collaborator Author

Yes, this is absolutely specific to GeoTransolver, and it's something we could improve. Most datapipes use a generic collation, and a challenge with the mesh datapipes is that it doesn't totally make sense to "collate" a few mesh objects ... IMO we might want to leave mesh datapipes at batch_size=1 locally.

Then we're viewing this more as "how do I massage the output keys from the datapipe into input keys to the network," which is what's really going on here, honestly.

Collaborator Author

tl;dr we should just call this map_data_to_model, not collate, and have a dictionary of mappings for the models we want to support?

idx = 0
for name, ftype in field_types.items():
    if ftype == "pressure":
        out[..., idx] = out[..., idx] * q_inf + p_inf
Collaborator

This logic is technically duplicated with line 171-ish. It feels like there should be a way to deduplicate this with an abstraction, which would help long-term maintainability. This is just a suggestion though!

Comment thread physicsnemo/datapipes/protocols.py
coreyjadams and others added 5 commits April 9, 2026 15:26
name: drivaer_ml_surface

train_datadir: /path/to/your/PhysicsNeMo-DrivaerML/
train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/
Collaborator

Accidental paths?

Collaborator Author

Yes I had these purged, and then I broke all my training scripts. Arg.

name: drivaer_ml_volume

train_datadir: /path/to/your/PhysicsNeMo-DrivaerML/
train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/
Collaborator

Accidental paths?


name: highlift_surface

train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-HighLiftAeroML/
Collaborator

Accidental paths?

Collaborator

@peterdsharpe peterdsharpe left a comment

Good work overall! This is a thoughtful and large contribution and really cleans up the datapipes stack. I also think the unified training recipe will be well-received.

I'm approving to not hold this up further, and we've discussed most of this together. Some of the comments are suggestions, others are things that we should fix before merging.

Tests

Separately, if we can improve test coverage of some of these new features, that would be great. Here are some parts missing coverage - don't feel like you need to cover it all, but if there are any high-risk items that jump out, let's cover those:

  • No tests for **MultiDataset.set_generator / set_epoch** in [test/datapipes/core/test_multi_dataset.py](test/datapipes/core/test_multi_dataset.py).
  • No tests for **Dataset.set_generator / set_epoch** in [test/datapipes/core/test_dataset.py](test/datapipes/core/test_dataset.py).
  • No tests for **MeshDataset + CUDA streams** path (analogue of the stream tests in test_dataset.py).
  • No tests for **MultiDataset mixing MeshDataset sub-datasets** (the docstring says it is supported).
  • No tests for **MeshReader subsampling + RNG reproducibility end-to-end**, or for **DomainMeshReader.extra_boundaries**.
  • No tests for **state_mixing_mode="concat_project"**.
  • test_close_closes_all in test_multi_dataset.py:197-203 has no assertions after multi.close() (passes trivially).
  • test_scale_each in test_mesh_readers.py:291-299 does not assert scaling actually changed coordinates (would pass if ScaleMesh were a no-op).
  • Several tests depend on private attributes (dataset._prefetch_futures, _ensure_executor()).

    self,
    axes: list[Literal["x", "y", "z"]] | None = None,
    distribution: torch.distributions.Distribution | None = None,
    mode: Literal["axis_aligned", "uniform"] = "axis_aligned",
Collaborator

I thought we were defaulting to uniform?

"""Apply this transform to a DomainMesh.

Default: broadcasts ``__call__`` to interior and all boundaries
via :meth:`DomainMesh._map_meshes`, leaving domain-level
Collaborator

This was changed from the private-method DomainMesh._map_meshes to the public-method DomainMesh.apply() as part of the now-merged #1558: https://github.com/NVIDIA/physicsnemo/pull/1558/changes#diff-11f619a0c0840c0bb17aec96535e01fce9c5ea67831bfe2e6d70f9345f09d6f4

For this PR, we probably want to: 1. merge from main (which will contain this), 2. update to use DomainMesh.apply instead, and 3. re-run tests just to double-check.

Collaborator Author

Thanks for pointing this out. I reviewed your PR, so I knew the API changed, but I hadn't adapted to it yet. I'll absolutely update to align with the new API before merge.

# Derived from the standard unit-quaternion rotation formula using
# w²+x²+y²+z² = 1 to rewrite 1-2(…) terms as sums of squared components.
# ww wx wy wz xw xx xy xz yw yx yy yz zw zx zy zz
self._q2r_map = torch.tensor(
Collaborator

Soft-recommendation (not a requirement) to reduce the "magic" of this hardcoded mapping, and instead use the Rodrigues formula (which exists in physicsnemo/mesh/transformations/geometric.py:202-216). The current 16x16 tensor is a little hard to verify correctness for by inspection.

Here, that might look like a) deleting this mapping, and then b) later on using something like:

@staticmethod
def _quaternion_to_rotation_matrix(
    q: Float[torch.Tensor, "4"],
) -> Float[torch.Tensor, "3 3"]:
    r"""Unit quaternion :math:`(w, \vec v)` to rotation matrix via Rodrigues' formula.

    :math:`R = (2w^2 - 1)\,I + 2\,\vec v\vec v^\top + 2w\,[\vec v]_\times`,
    where :math:`[\vec v]_\times` is the skew-symmetric cross-product matrix of
    :math:`\vec v`.

    Parameters
    ----------
    q : torch.Tensor
        Unit quaternion :math:`(w, x, y, z)`, shape :math:`(4,)`.

    Returns
    -------
    torch.Tensor
        Rotation matrix, shape :math:`(3, 3)`.
    """
    w, x, y, z = q.unbind()
    zero = torch.zeros_like(w)
    v_cross = torch.stack(
        [
            torch.stack([zero, -z, y]),
            torch.stack([z, zero, -x]),
            torch.stack([-y, x, zero]),
        ]
    )
    return (
        (2 * w * w - 1) * torch.eye(3, dtype=q.dtype, device=q.device)
        + 2 * torch.outer(q[1:], q[1:])
        + 2 * w * v_cross
    )

transform_global_data=self.transform_global_data,
assume_invertible=True,
)

Collaborator

An explicit elif self.mode == "axis_aligned": followed by else: raise ..., or a match-case block, would be safer here - if the user accidentally typos "uniform", we silently fall back to axis_aligned rather than loudly warning.

transform_global_data=self.transform_global_data,
assume_invertible=True,
)

Collaborator

Same elif else here as comment near L453

_NONDIM_TYPE_MAP = {"scalar": "pressure", "vector": "stress"}


def _to_physical(
Collaborator

Dead code?

return torch.mean((pred - target) ** 2.0)


def compute_rmse(
Collaborator

Typically RMSE is used for root mean squared error; could we name this relative_mse to reduce confusion perhaps?
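For clarity, the distinction suggested above could look like this; the bodies are illustrative sketches, not the recipe's exact code:

```python
import torch

def compute_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Root mean squared error, the conventional meaning of "RMSE".
    return torch.sqrt(torch.mean((pred - target) ** 2))

def relative_mse(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Squared error normalized by the target's mean square; the name
    # relative_mse avoids colliding with the conventional RMSE above.
    return torch.mean((pred - target) ** 2) / (torch.mean(target**2) + eps)
```

Keeping both, each under its conventional name, sidesteps the confusion entirely.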

Mapping of field names to types. Order determines channel indices.
Example: {"pressure": "scalar", "velocity": "vector", "turbulence": "scalar"}
loss_type : Literal["huber", "mse", "rmse"], optional
Type of loss to compute. Default is "huber".
Collaborator

This says we default to Huber, below in the code we actually default to MSE

Collaborator Author

Ah ok sure. Pretty sure all my yamls are set to huber so it's not too wrong but I'll fix the default below.

# limitations under the License.

"""
Mesh readers - Load physicsnemo Mesh / DomainMesh from physicsnemo mesh format (.pt).
Collaborator

Throughout this file (~5 places) we refer to the Mesh format as using .pt extensions rather than .pmsh.

Comment on lines 122 to 123
def to(self, device: torch.device | str) -> Transform:
"""
Collaborator

Transform.to does not migrate Generators, but MeshTransform.to does - seems like a potential footgun?
