diff --git a/.markdownlint.yaml b/.markdownlint.yaml index 29ed832dfa..45533117a6 100644 --- a/.markdownlint.yaml +++ b/.markdownlint.yaml @@ -91,7 +91,7 @@ MD013: # Include code blocks code_blocks: true # Include tables - tables: true + tables: false # Include headings headings: true # Include headings diff --git a/.markdownlintignore b/.markdownlintignore index fb9dc39e47..d327cb161f 100644 --- a/.markdownlintignore +++ b/.markdownlintignore @@ -1 +1,2 @@ CODE_OF_CONDUCT.md +physicsnemo/datapipes/transforms/mesh/DISTRIBUTIONS.md diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/.gitignore b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/.gitignore new file mode 100644 index 0000000000..b88f2d2ca8 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/.gitignore @@ -0,0 +1,11 @@ +*.sh +runs/ +mlruns/ +*.err +*.out +outputs/ +stats/ +checkpoints/ +*.mdlus +*.tfevents* +*.parquet diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/README.md b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/README.md new file mode 100644 index 0000000000..ee14d7c635 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/README.md @@ -0,0 +1,537 @@ + +# Unified External Aerodynamics Recipe + +> This unified recipe is still under some final polishing but nearly +> completed. Feel free to use it and experiment. In the meantime, +> be wary of sharp edges! + +## Introduction + +External Aerodynamic recipes in physicsnemo have proliferated: we have +a number of recipes, across a range of models, all working on different models +with unique data handling, pipelines, model architectures, metrics, training +paradigms, etc. While there is nothing wrong with that, it does make comparison +challenging and development of new models somewhat challenging. 
In this folder, +we have unified the external aerodynamic recipes for most of our best models (notably +missing is our newest model, still in development for large 3D use cases: GLOBE). + +Here, you're able to train the following models: +- [Transolver](https://arxiv.org/abs/2402.02366) +- [GeoTransolver](https://arxiv.org/abs/2512.20399) +- [Flare](https://arxiv.org/abs/2508.12594) +- GeoTransolver also supports using the FLARE attention mechanism backend +- DoMINO is coming shortly + +We currently support the following datasets: +- DrivaerML + +Support for these datasets is coming imminently, with pre-processing support from +physicsnemo curator: +- ShiftSUV Estate +- ShiftSUV Fastback +- ShiftWING +- HiLiftAeroML + +## Dataset Handling + +The data processing pipeline in this example explicitly performs non-dimensionalization +of input data to unitless fields for model inputs. Check out the yaml configurations +in `conf/dataset/` to see examples: the metadata section describes the reference +parameters for each dataset. Because datasets are non-dimensionalized, and are loaded +with the physicsnemo datapipes which support a MultiDataset abstraction, it's +possible to merge datasets on-the-fly during training to perform multi-dataset +training. We at PhysicsNeMo haven't extensively explored all of the parameters +of this multi-dataset training yet, but the infrastructure can support it and +we welcome you to try it if you're interested in it. + +Dataset non-dimensionalization is handled in the `nondim.py` transformation, which +is part of the data transformation pipeline. See `src/nondim.py` in this example +for the source code. + +## Quick start + +```bash +cd examples/cfd/external_aerodynamics/unified_external_aero_recipe + +# 1. Train (single GPU, default GeoTransolver surface config) +python src/train.py + +# 1b. Train with a specific config +python src/train.py --config-name train_transolver_automotive_surface + +# 1c. 
Train (multi-GPU) +torchrun --nproc_per_node=N src/train.py + +# 2. Override config values +python src/train.py precision=bfloat16 training.num_epochs=100 +``` + +## Pipeline architecture + +Each dataset gets its own `MeshDataset` or `DomainMeshDataset` with an ordered chain of +`MeshTransform` steps defined in YAML. Multiple datasets are then +merged via `MultiDataset`. + +``` + ┌─────────────────────────────────────────────────────────────┐ + │ Per-dataset pipeline (one per YAML config) │ + │ │ + │ MeshReader / DomainMeshReader │ + │ │ Load raw Mesh from .pdmsh/.pmsh files │ + │ │ │ + │ (metadata injection) Write U_inf, rho_inf, p_inf, nu │ + │ │ from YAML metadata into │ + │ │ global_data (done by builder) │ + │ │ │ + │ (DropMeshFields) Remove unwanted fields │ + │ │ (e.g. TimeValue; drivaer only) │ + │ │ │ + │ (CenterMesh) Translate center of mass │ + │ │ to origin │ + │ │ │ + │ (RandomRotateMesh) Random yaw around vertical axis │ + │ │ (inserted after CenterMesh when │ + │ │ augment=true) │ + │ │ │ + │ (RandomTranslateMesh) Random horizontal shift │ + │ │ (inserted after CenterMesh when │ + │ │ augment=true) │ + │ │ │ + │ (NonDimensionalizeByMeta) Convert to Cp/Cf/nondim velocity │ + │ │ using q_inf = ½ρ|U∞|² │ + │ │ │ + │ (ComputeSDFFromBoundary) Compute signed distance field │ + │ │ from STL geometry (volume only) │ + │ │ │ + │ (DropBoundary) Remove auxiliary STL boundary │ + │ │ after SDF (volume only) │ + │ │ │ + │ RenameMeshFields Map dataset-specific names to │ + │ │ canonical names (pressure, wss) │ + │ │ │ + │ (NormalizeMeshFields) z-score normalize using │ + │ │ inline stats from YAML │ + │ │ │ + │ (ComputeSurfaceNormals) Compute per-cell surface normals │ + │ │ (surface pipelines only) │ + │ │ │ + │ (SubsampleMesh) Downsample to fixed point/cell │ + │ │ count (surface pipelines only; │ + │ │ volume uses reader subsampling) │ + │ │ │ + │ MeshToTensorDict Convert Mesh → TensorDict │ + │ │ │ + │ (ComputeCellCentroids) Compute cell centers from │ + │ │ 
connectivity (cell-based only) │ + │ │ │ + │ RestructureTensorDict Remap flat TensorDict into │ + │ │ input/output groups for the │ + │ │ collate function │ + └───────┼────────────────────────────────────────────────────┘ + │ + ▼ + ┌──────────────┐ + │ MultiDataset │ Concatenates index spaces, + │ │ adds dataset_index to metadata + └──────────────┘ + │ + ▼ + ┌──────────────┐ + │ Collate │ Stacks samples into batched tensors + │ │ via model-specific mapping + └──────────────┘ +``` + +### Why each step exists + +- **Metadata injection** — The dataset builder writes freestream conditions + (`U_inf`, `rho_inf`, `p_inf`, `nu`) from the YAML config's `metadata:` + block into each mesh's `global_data`. This makes physical reference + quantities available to downstream transforms without hardcoding them + in Python. + +- **DropMeshFields** — Removes fields that are not needed for training + (e.g. `TimeValue` in DrivaerML) to reduce memory and avoid schema + mismatches when merging datasets. + +- **CenterMesh** — Centers each geometry at the origin so that + rotations happen around a sensible point. DrivaerML uses point-mean + centering (`use_area_weighting: false`); SHIFT SUV uses area-weighted + cell centroid centering (`use_area_weighting: true`). + +- **RandomRotateMesh / RandomTranslateMesh** — Data augmentation, + defined in the `augmentations:` block of each dataset config and + activated at runtime by setting `augment: true` (default `false`). + Augmentations are inserted after `CenterMesh` by the dataset builder. + Rotation is restricted to the vertical axis. Translation is restricted + to horizontal axes by setting the vertical component of the offset + distribution to zero. 
+ +- **NonDimensionalizeByMetadata** — Converts raw physical fields into + non-dimensional coefficients using the injected freestream metadata: + - Pressure → Cp: `(p - p_inf) / q_inf` where `q_inf = 0.5 * rho_inf * |U_inf|²` + - Wall shear stress → Cf: `tau / q_inf` + - Velocity → `U / |U_inf|` + + Also supports temperature, density, and identity (pass-through) field + types. Provides an `inverse()` method for re-dimensionalizing + predictions. + +- **ComputeSDFFromBoundary** — Volume pipelines only. Computes a + signed distance field (and surface normals) from an auxiliary STL + boundary mesh loaded via the reader's `extra_boundaries` option. + The SDF and normals are stored into `point_data` and used as + geometry-aware input features for the model. + +- **DropBoundary** — Removes the auxiliary STL boundary mesh after + `ComputeSDFFromBoundary` has consumed it, keeping only the interior + volume and the original surface boundary. + +- **RenameMeshFields** — Maps dataset-specific field names to canonical + names (`pressure`, `wss`, `velocity`, etc.) so all downstream code + uses a single naming convention. + +- **NormalizeMeshFields** — Applies z-score normalization using + inline statistics declared in the YAML config or loaded from a `.pt` + file. Handles scalar and vector fields differently. The normalization + stats are saved alongside model checkpoints for use at inference time. + + Note that not all fields are normalized, in fact most are not. Only fields + that are particularly far from unit mean or standard deviation are normalized. + +- **ComputeSurfaceNormals** — Computes per-cell (or per-point) surface + normals from the mesh connectivity. Used in surface pipelines to + provide normal vectors as part of the model's local embedding. + +- **SubsampleMesh** — Randomly downsamples each mesh to a fixed size + (controlled by `sampling_resolution` in the training config) so that + samples can be batched. 
Different samples in the same dataset get + different random subsets each epoch. + +- **MeshToTensorDict** — Terminal transform that converts the `Mesh` + object into a flat `TensorDict`. After this step, further mesh + transforms are invalid. + +- **ComputeCellCentroids** — For cell-based datasets, computes the + centroid of each cell from the connectivity and vertex positions. + These centroids serve as the "point positions" for the model. + +- **RestructureTensorDict** — Reorganizes the flat TensorDict into + `input/` and `output/` groups expected by the collate function. Maps + point positions (or cell centroids), normals, and freestream velocity + into `input`, and target fields into `output`. + +## Non-dimensionalization and normalization + +The pipeline applies two layers of field conditioning: + +1. **Physics-based non-dimensionalization** (`NonDimensionalizeByMetadata`) + converts raw simulation outputs to standard aerodynamic coefficients + (Cp, Cf) or non-dimensional velocity. This is essential when + combining datasets that may use different freestream conditions, fluid + properties, or unit conventions. The freestream metadata (`U_inf`, + `rho_inf`, `p_inf`) is declared per-dataset in the YAML config. + +2. **Statistical normalization** (`NormalizeMeshFields`) applies z-score + scaling so that all field values fed to the model have roughly zero + mean and unit variance. Statistics are specified inline in the dataset + YAML config or loaded from a `.pt` file. + +## Model and training + +The default model is **GeoTransolver**, a transformer-based architecture +for point-cloud regression that uses multi-scale local attention with +geometric embeddings. 
+ +### Default settings (GeoTransolver automotive surface) + +| Setting | Default | +|---|---| +| Model | `GeoTransolver` (12 layers, 256 hidden, 8 heads) | +| Attention type | `GALE` (also supports `GALE_FA` for FLARE-based self-attention) | +| State mixing | `weighted` (learnable sigmoid gate; also supports `concat_project`) | +| Input | Cell centroids (N×3) + surface normals (N×3) + freestream velocity (1×3) | +| Output | Pressure (1) + wall shear stress (3) = 4 channels | +| Loss | Huber (smooth L1), normalized by total channels | +| Optimizer | Muon (2D params) + AdamW (other params) | +| Scheduler | StepLR (step=100, gamma=0.1) | +| Precision | bfloat16 (float16/float32/float8 also supported) | +| Batch size | 1 | + +### Data-to-model mapping + +The **data-to-model mapping** (`src/collate.py`) converts datapipe +outputs into the model's forward signature. Mappings are registered by +name in `MODEL_MAPPINGS`; the active mapping is selected via the +`data_mapping` config key (default: `"geotransolver_automotive_surface"`): + +```python +# geotransolver_automotive_surface mapping produces: +{ + "geometry": (B, N, 3), # cell centroids / point positions + "local_embedding": (B, N, 6), # cat(points, normals) via ["input/points", "input/normals"] + "local_positions": (B, N, 3), # point positions (for local feature builder) + "global_embedding": (B, 1, 3), # freestream velocity + "fields": (B, N, 4), # cat(pressure, wss) = prediction target +} +``` + +All available mappings: + +| Mapping name | Model | Domain | +|---|---|---| +| `geotransolver_automotive_surface` | GeoTransolver | Automotive surface (Cp, Cf) | +| `geotransolver_automotive_volume` | GeoTransolver | Automotive volume (U, p, nut) | +| `geotransolver_highlift_surface` | GeoTransolver | HighLift surface (P, T, rho, U, tau_wall) | +| `geotransolver_highlift_volume` | GeoTransolver | HighLift volume (P, T, rho, U) | +| `transolver_automotive_surface` | Transolver | Automotive surface (Cp, Cf) | +| 
`transolver_automotive_volume` | Transolver | Automotive volume (U, p, nut) | +| `flare_automotive_surface` | FLARE | Automotive surface (Cp, Cf) | +| `flare_automotive_volume` | FLARE | Automotive volume (U, p, nut) | +| `domino_automotive_surface` | DoMINO | Automotive surface (Cp, Cf) | +| `domino_automotive_volume` | DoMINO | Automotive volume (U, p, nut) | + +To add a new model, register a mapping in `MODEL_MAPPINGS` and set +`data_mapping` in your config. + +The **loss calculator** (`src/loss.py`) and **metric calculator** +(`src/metrics.py`) are both driven by the same target config +(e.g. `pressure: scalar`, `wss: vector`), so adding a new field is a +config-only change. Supported loss types: Huber, MSE, relative MSE. +Supported metrics: relative L1, relative L2, MAE. + +## Scripts + +All scripts are run from the recipe root directory: + +```bash +cd examples/cfd/external_aerodynamics/unified_external_aero_recipe +``` + +### Train + +```bash +# Single GPU (default: GeoTransolver automotive surface) +python src/train.py + +# Explicit config selection +python src/train.py --config-name train_transolver_automotive_surface + +# Multi-GPU +torchrun --nproc_per_node=N src/train.py + +# Override config values +python src/train.py precision=float32 training.num_epochs=100 training.batch_size=1 +``` + +Supports checkpointing (auto-resume), MLflow logging, mixed precision +(float16/bfloat16/float8 via Transformer Engine), `torch.compile`, and +NVIDIA profiling. + +### Benchmark datapipe throughput + +```bash +python src/train.py benchmark_io=true +python src/train.py benchmark_io=true +training.benchmark_max_steps=20 +``` + +> NOTE: If you want to profile, we recommend you set the number of epochs to 2. + +Measures per-sample load time and throughput without running the model. + +## Configuration + +The recipe uses a two-level config structure: + +- **`conf/train_*.yaml`** — Top-level training configs. 
Each specifies + the model, optimizer, scheduler, precision, and which dataset configs + to load. Six are provided (see [Training configurations](#training-configurations)). +- **`conf/dataset/*.yaml`** — Per-dataset configs. Each declares the + reader, transform pipeline, freestream metadata, target field types, + and metrics. + +### Dataset config anatomy + +```yaml +name: drivaer_ml_surface + +train_datadir: /path/to/your/PhysicsNeMo-DrivaerML/ + +# Freestream conditions (injected into global_data by the dataset builder) +metadata: + U_inf: [30.0, 0.0, 0.0] + p_inf: 0.0 + rho_inf: 1.225 + nu: 1 + L_ref: 5.0 + +# Transform pipeline — each entry is Hydra-instantiated +pipeline: + reader: + _target_: ${dp:MeshReader} + path: ${train_datadir} + pattern: "**/*.pdmsh/_tensordict/boundaries/surface" + subsample_n_cells: ${sampling_resolution} + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_cell_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:DropMeshFields} + global_data: [TimeValue] + - _target_: ${dp:CenterMesh} + use_area_weighting: false + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + pMeanTrim: pressure + wallShearStressMeanTrim: stress + section: cell_data + - _target_: ${dp:RenameMeshFields} + cell_data: + pMeanTrim: pressure + wallShearStressMeanTrim: wss + - _target_: ${dp:NormalizeMeshFields} + section: cell_data + fields: + wss: {type: vector, mean: [0.0, 0.0, 0.0], std: 0.00313} + - _target_: ${dp:ComputeSurfaceNormals} + store_as: cell_data + field_name: normals + - _target_: ${dp:SubsampleMesh} + n_cells: ${sampling_resolution} + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:ComputeCellCentroids} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: cell_centroids + normals: cell_data.normals + U_inf: global_data.U_inf 
+ output: + pressure: cell_data.pressure + wss: cell_data.wss + +targets: + pressure: scalar + wss: vector + +metrics: [l1, l2, mae] +``` + +The `${dp:ComponentName}` syntax is an OmegaConf resolver registered by +PhysicsNeMo's datapipe registry. It maps short class names to fully +qualified import paths, so Hydra can instantiate them. Each transform +entry's keys are passed directly as constructor kwargs. + +The `${sampling_resolution}` interpolation is resolved from the +top-level training config's `dataset.sampling_resolution` value. + +### Manifest-based data splitting + +DrivaerML datasets use a `manifest.json` file to define train/val/test +splits. The manifest path and split names are declared in the top-level +training config: + +```yaml +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /path/to/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val +``` + +The `ManifestSampler` in `src/datasets.py` resolves manifest entries to +dataset indices and handles distributed sampling across ranks. + +For datasets without a manifest (e.g. SHIFT SUV), separate +`train_datadir` / `val_datadir` paths are specified in the dataset YAML. + +### Adding a new dataset + +1. Create a new YAML config in `conf/dataset/` following the pattern above. +2. Set `reader.path` and `reader.pattern` for your data files. + Use `MeshReader` for single-mesh files or `DomainMeshReader` for + domain meshes that contain both interior and boundary sub-meshes. +3. Declare the correct `metadata:` block with freestream conditions. +4. Choose the right `section:` (`point_data` or `cell_data`) in + `NonDimensionalizeByMetadata`, `RenameMeshFields`, and + `NormalizeMeshFields`. +5. For cell-based surface data, add `ComputeSurfaceNormals` and + `ComputeCellCentroids` and use `cell_centroids` as the point source + in `RestructureTensorDict`. +6. 
Add inline normalization stats to `NormalizeMeshFields` (or point + `stats_file` at a `.pt` file with precomputed statistics). +7. Add an entry in the appropriate `conf/train_*.yaml` under `data:` + pointing to your new config. + +No Python code changes are needed. + +### MLflow experiment tracking + +Training metrics are logged to MLflow. By default, experiments are +stored in a local `./mlruns` directory. To use a remote tracking +server, set `mlflow.tracking_uri` in the training config: + +```yaml +mlflow: + tracking_uri: "http://YOUR_MLFLOW_SERVER:5000" + experiment_name: "unified_external_aero" + log_every_n_steps: 10 +``` + +## Source modules + +| Module | Purpose | +|---|---| +| `src/datasets.py` | Factory functions: `build_dataset`, `build_multi_surface_dataset`, `load_dataset_config`. Hydra-instantiates readers and transforms from YAML; injects metadata into `global_data`. Also provides `load_manifest`, `resolve_manifest_indices`, and `ManifestSampler` for manifest-based splitting. | +| `src/nondim.py` | Recipe-local transform: `NonDimensionalizeByMetadata`. Registered into the global datapipe registry. Supports pressure, stress, velocity, temperature, density, and identity field types. | +| `src/collate.py` | Data-to-model mapping: converts datapipe `(TensorDict, metadata)` tuples into batched model inputs via a registry of 10 named mapping specs (see [Data-to-model mapping](#data-to-model-mapping)). | +| `src/loss.py` | `LossCalculator` — config-driven loss for mixed scalar/vector fields. Supports Huber, MSE, relative MSE. Normalizes total loss by number of output channels. | +| `src/metrics.py` | `MetricCalculator` — config-driven metrics (relative L1, relative L2, MAE) with optional distributed all-reduce. Reports per-field and per-component (x/y/z) metrics for vector fields. | +| `src/utils.py` | `build_muon_optimizer` (Muon+AdamW via `CombinedOptimizer`), `parse_target_config`, `FieldSpec` dataclass, `set_seed`. 
| +| `src/train.py` | Training loop with DDP, mixed precision, checkpointing, MLflow logging, I/O benchmarking (`benchmark_io=true`), and profiling. | + +## Design decisions + +**Why cell-based representation for surfaces?** +Both DrivaerML and SHIFT SUV surface data use triangulated meshes with +fields stored in `cell_data`. The pipeline computes cell centroids as +the model's point positions and cell-based surface normals for the local +embedding. For volume data, fields live in `point_data` and vertex +positions are used directly. + +**Why two-stage field conditioning (non-dim then normalize)?** +Non-dimensionalization is physics: it removes dependence on freestream +conditions and produces standard aerodynamic coefficients (Cp, Cf) that are +comparable across datasets. Statistical normalization is numerics: it +rescales those coefficients so the model sees inputs with zero mean and unit +variance, improving training stability. Separating them means you can +change normalization strategy without touching the physics, and vice versa. + +**Why inject metadata from YAML instead of storing it in the mesh files?** +The freestream conditions are not always stored in the converted mesh +files. Rather than modifying the data conversion pipeline, we inject +them at runtime from the config. This keeps the mesh files +format-agnostic and makes it trivial to change conditions without +reconverting data. The dataset builder reads the `metadata:` block and +prepends an injection step automatically. + +**Why Hydra instantiation for the pipeline?** +The entire pipeline is expressed in YAML with no conditional Python logic. +Adding a new dataset, changing augmentation parameters, or swapping +transform order is a YAML-only change. The factory code in `src/datasets.py` +is compact and generic. The configs are self-documenting: you can read a +single YAML file and see exactly what transforms run and in what order. 
+ +**Why inline normalization stats?** +Specifying normalization statistics directly in the YAML config (or in a +`.pt` file) keeps the pipeline self-contained and avoids a separate +statistics collection step. The values are easy to inspect, update, and +version-control alongside the rest of the configuration. diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_surface.yaml new file mode 100644 index 0000000000..d224948be5 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_surface.yaml @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DrivaerML surface dataset +# Reads the boundary surface Mesh directly from DomainMesh .pdmsh files +# by navigating into the on-disk tensordict directory structure. +# Triangulated surface mesh. Fields live in cell_data. +# Splits are controlled by manifest.json + train_split/val_split in the training config. 
+ +name: drivaer_ml_surface + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/ + +metadata: + U_inf: [30.0, 0.0, 0.0] # freestream velocity + p_inf: 0.0 # reference pressure, drivaerml is gauge pressure + rho_inf: 1.225 # freestream density + nu: 1 + L_ref: 5.0 # reference length [m] + +pipeline: + reader: + _target_: ${dp:MeshReader} + path: ${train_datadir} + pattern: "**/*.pdmsh/_tensordict/boundaries/surface" + subsample_n_cells: ${sampling_resolution} + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_cell_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:DropMeshFields} + global_data: [TimeValue] + - _target_: ${dp:CenterMesh} + use_area_weighting: false + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + pMeanTrim: pressure + wallShearStressMeanTrim: stress + section: cell_data + - _target_: ${dp:RenameMeshFields} + cell_data: + pMeanTrim: pressure + wallShearStressMeanTrim: wss + - _target_: ${dp:NormalizeMeshFields} + section: cell_data + fields: + wss: {type: vector, mean: [0.0, 0.0, 0.0], std: 0.00313} + - _target_: ${dp:ComputeSurfaceNormals} + store_as: cell_data + field_name: normals + - _target_: ${dp:SubsampleMesh} + n_cells: ${sampling_resolution} + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:ComputeCellCentroids} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: cell_centroids + normals: cell_data.normals + U_inf: global_data.U_inf + output: + pressure: cell_data.pressure + wss: cell_data.wss + +targets: + pressure: scalar + wss: vector + +metrics: [l1, l2, mae] diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_volume.yaml 
b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_volume.yaml new file mode 100644 index 0000000000..6003309946 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/drivaer_ml_volume.yaml @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DrivaerML volume dataset (DomainMesh) +# Loads domain_*.pdmsh files containing interior volume mesh + surface boundary + global_data. +# Interior: tetrahedral volume mesh with fields in point_data. +# Boundary: triangulated surface mesh ("surface"). +# global_data: U_inf, rho_inf (baked into .pdmsh); p_inf, nu, L_ref injected from metadata below. +# Splits are controlled by manifest.json + train_split/val_split in the training config. 
+ +name: drivaer_ml_volume + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/ + +metadata: + U_inf: [30.0, 0.0, 0.0] # freestream velocity + p_inf: 0.0 # reference pressure, drivaerml is gauge pressure + rho_inf: 1.225 # freestream density + nu: 1 + L_ref: 5.0 # reference length [m] + +pipeline: + reader: + _target_: ${dp:DomainMeshReader} + path: ${train_datadir} + pattern: "**/domain_*.pdmsh" + subsample_n_points: ${sampling_resolution} + subsample_n_cells: ${sampling_resolution} + extra_boundaries: + stl_geometry: + pattern: "*_single_solid.stl.pmsh" + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_point_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:DropMeshFields} + global_data: [TimeValue] + - _target_: ${dp:CenterMesh} + use_area_weighting: false + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + UMeanTrim: velocity + pMeanTrim: pressure + nutMeanTrim: identity + section: point_data + - _target_: ${dp:ComputeSDFFromBoundary} + boundary_name: stl_geometry + sdf_field: sdf + normals_field: sdf_normals + use_winding_number: true + - _target_: ${dp:DropBoundary} + names: [stl_geometry] + - _target_: ${dp:RenameMeshFields} + point_data: + UMeanTrim: velocity + pMeanTrim: pressure + nutMeanTrim: nut + - _target_: ${dp:NormalizeMeshFields} + section: point_data + fields: + nut: {type: scalar, mean: 4.8e-4, std: 9.4e-4} + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: interior.points + U_inf: global_data.U_inf + sdf: interior.point_data.sdf + sdf_normals: interior.point_data.sdf_normals + output: + velocity: interior.point_data.velocity + pressure: interior.point_data.pressure + nut: interior.point_data.nut + +targets: + velocity: vector + 
pressure: scalar + nut: scalar + +metrics: [l1, l2, mae] diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_surface.yaml new file mode 100644 index 0000000000..fbb0c80c06 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_surface.yaml @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# High-lift airplane surface dataset (compressible) +# Reads the boundary surface Mesh directly from DomainMesh .pdmsh files +# by navigating into the on-disk tensordict directory structure. +# Boundary mesh from HiLiftAeroML. Fields live in point_data. +# Splits are controlled by manifest.json + train_split/val_split in the training config. 
+ +name: highlift_surface + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-HighLiftAeroML/ + +# Freestream conditions — slug-inch-second-Rankine unit system +metadata: + U_inf: [2672.95, 0.0, 186.92] # 2679.505 in/s at AoA=4° + p_inf: 176.352 # 14.696 psi × 12 → slug/(in·s²) + rho_inf: 1.3756e-6 # 0.002377 slug/ft³ → slug/in³ + T_inf: 518.67 # [°R] + L_ref: 1156.75 # [in] — span + +pipeline: + reader: + _target_: ${dp:MeshReader} + path: ${train_datadir} + pattern: "**/*.pdmsh/_tensordict/boundaries/boundary" + subsample_n_cells: ${sampling_resolution} + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_point_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:CenterMesh} + use_area_weighting: false + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + PROJ(AVG(P)): pressure + PROJ(AVG(T)): temperature + PROJ(AVG(RHO)): density + PROJ(AVG(U)): velocity + AVG(TAU_WALL): stress + section: point_data + - _target_: ${dp:RenameMeshFields} + point_data: + PROJ(AVG(P)): pressure + PROJ(AVG(T)): temperature + PROJ(AVG(RHO)): density + PROJ(AVG(U)): velocity + AVG(TAU_WALL): tau_wall + # z-score normalization (dataset-wide stats, 180 cases / 24B points) + # Stats are post-nondimensionalization values. + - _target_: ${dp:NormalizeMeshFields} + section: point_data + fields: + pressure: {type: scalar, mean: -0.26, std: 0.62} + temperature: {type: scalar, mean: 1.004, std: 0.00224} + density: {type: scalar, mean: 0.988, std: 0.0156} + velocity: {type: vector, mean: [0.413, -0.005, 0.029], std: 0.20} + tau_wall: {type: scalar, mean: 0.005068, std: 0.003735} + # Full subsampling before normals: subsample cells first, then points + # to the final count. Normals are computed on the surviving triangles. 
+ - _target_: ${dp:SubsampleMesh} + n_cells: 500000 + n_points: ${sampling_resolution} + - _target_: ${dp:ComputeSurfaceNormals} + store_as: point_data + field_name: normals + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: points + normals: point_data.normals + U_inf: global_data.U_inf + output: + pressure: point_data.pressure + temperature: point_data.temperature + density: point_data.density + velocity: point_data.velocity + tau_wall: point_data.tau_wall + +targets: + pressure: scalar + temperature: scalar + density: scalar + velocity: vector + tau_wall: scalar + +metrics: [l1, l2, mae] diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_volume.yaml new file mode 100644 index 0000000000..9f2022622f --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/highlift_volume.yaml @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# High-lift airplane volume dataset (compressible, DomainMesh) +# Loads .pdmsh files from PhysicsNeMo-HighLiftAeroML. +# Each case directory (e.g. geo_LHC001_AoA_4/) contains a DomainMesh .pdmsh. 
+# Interior: volume point cloud with fields in point_data. +# global_data: freestream conditions injected from metadata below. +# Splits are controlled by a manifest.json + train_split/val_split in the +# training config (see train_geotransolver_automotive_volume.yaml). + +name: highlift_volume + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-HighLiftAeroML/ + +# Freestream conditions — slug-inch-second-Rankine unit system +metadata: + U_inf: [2672.95, 0.0, 186.92] # 2679.505 in/s at AoA=4° + p_inf: 176.352 # 14.696 psi × 12 → slug/(in·s²) + rho_inf: 1.3756e-6 # 0.002377 slug/ft³ → slug/in³ + T_inf: 518.67 # [°R] + L_ref: 1156.75 # [in] — span + +pipeline: + reader: + _target_: ${dp:DomainMeshReader} + path: ${train_datadir} + pattern: "**/*.pdmsh" + subsample_n_points: ${sampling_resolution} + extra_boundaries: + stl_geometry: + pattern: "*.stl.pmsh" + # augmentations: + # - _target_: ${dp:RandomRotateMesh} + # axes: ["z"] + # transform_point_data: true + # transform_global_data: true + # - _target_: ${dp:RandomTranslateMesh} + # distribution: + # _target_: torch.distributions.Uniform + # low: [-1.0, -1.0, 0.0] + # high: [1.0, 1.0, 0.0] + transforms: + # - _target_: ${dp:CenterMesh} + # use_area_weighting: false + # - _target_: ${dp:NonDimensionalizeByMetadata} + # fields: + # avg(P): pressure + # avg(T): temperature + # avg(rho): density + # avg(u): velocity + # section: point_data + - _target_: ${dp:ComputeSDFFromBoundary} + boundary_name: stl_geometry + sdf_field: sdf + normals_field: sdf_normals + use_winding_number: true + - _target_: ${dp:DropBoundary} + names: [stl_geometry] + - _target_: ${dp:RenameMeshFields} + point_data: + avg(P): pressure + avg(T): temperature + avg(rho): density + avg(u): velocity + # z-score normalization — leave empty until stats are computed + # - _target_: ${dp:NormalizeMeshFields} + # section: point_data + # fields: {} + - _target_: ${dp:MeshToTensorDict} + - _target_: 
${dp:RestructureTensorDict} + groups: + input: + points: interior.points + U_inf: global_data.U_inf + sdf: interior.point_data.sdf + sdf_normals: interior.point_data.sdf_normals + output: + pressure: interior.point_data.pressure + temperature: interior.point_data.temperature + density: interior.point_data.density + velocity: interior.point_data.velocity + +targets: + pressure: scalar + temperature: scalar + density: scalar + velocity: vector + +metrics: [l1, l2, mae] diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_estate.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_estate.yaml new file mode 100644 index 0000000000..7541872d99 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_estate.yaml @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SHIFT SUV Estate surface dataset +# Triangulated surface mesh. Fields live in cell_data. 
+ +name: shift_suv_estate + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/estate/train/ +val_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/estate/val/ +test_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/estate/test/ + +metadata: + U_inf: [30.0, 0.0, 0.0] # freestream velocity + p_inf: 101325 # reference pressure (gauge) + rho_inf: 1.225 # freestream density + nu: 1.5e-5 # kinematic viscosity + L_ref: 5.0 # reference length [m] + +pipeline: + reader: + _target_: ${dp:MeshReader} + path: ${train_datadir} + pattern: "**/merged_surfaces.vtp.pmsh" + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_cell_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:CenterMesh} + use_area_weighting: true + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + pressure_average: pressure + wall_shear_stress_average: stress + section: cell_data + - _target_: ${dp:RenameMeshFields} + cell_data: + pressure_average: pressure + wall_shear_stress_average: wss + - _target_: ${dp:NormalizeMeshFields} + section: cell_data + fields: + wss: {type: vector, mean: [0.0, 0.0, 0.0], std: 0.00183} + - _target_: ${dp:ComputeSurfaceNormals} + store_as: cell_data + field_name: normals + - _target_: ${dp:SubsampleMesh} + n_cells: ${sampling_resolution} + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:ComputeCellCentroids} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: cell_centroids + normals: cell_data.normals + U_inf: global_data.U_inf + output: + pressure: cell_data.pressure + wss: cell_data.wss + +targets: + pressure: scalar + wss: vector + +metrics: [l1, l2, mae] diff --git 
a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_fastback.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_fastback.yaml new file mode 100644 index 0000000000..a2d7b31cd9 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/dataset/shift_suv_fastback.yaml @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SHIFT SUV Fastback surface dataset +# Triangulated surface mesh. Fields live in cell_data. 
+ +name: shift_suv_fastback + +train_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/fastback/train/ +val_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/fastback/val/ +test_datadir: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/shift_suv_pnm_mesh/SUV/fastback/test/ + +metadata: + U_inf: [30.0, 0.0, 0.0] # freestream velocity + p_inf: 101325 # reference pressure (gauge) + rho_inf: 1.225 # freestream density + nu: 1.5e-5 # kinematic viscosity + L_ref: 5.0 # reference length [m] + +pipeline: + reader: + _target_: ${dp:MeshReader} + path: ${train_datadir} + pattern: "**/merged_surfaces.vtp.pmsh" + augmentations: + - _target_: ${dp:RandomRotateMesh} + axes: ["z"] + transform_cell_data: true + transform_global_data: true + - _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-1.0, -1.0, 0.0] + high: [1.0, 1.0, 0.0] + transforms: + - _target_: ${dp:CenterMesh} + use_area_weighting: true + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + pressure_average: pressure + wall_shear_stress_average: stress + section: cell_data + - _target_: ${dp:RenameMeshFields} + cell_data: + pressure_average: pressure + wall_shear_stress_average: wss + - _target_: ${dp:NormalizeMeshFields} + section: cell_data + fields: + wss: {type: vector, mean: [0.0, 0.0, 0.0], std: 0.00183} + - _target_: ${dp:ComputeSurfaceNormals} + store_as: cell_data + field_name: normals + - _target_: ${dp:SubsampleMesh} + n_cells: ${sampling_resolution} + - _target_: ${dp:MeshToTensorDict} + - _target_: ${dp:ComputeCellCentroids} + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: cell_centroids + normals: cell_data.normals + U_inf: global_data.U_inf + output: + pressure: cell_data.pressure + wss: cell_data.wss + +targets: + pressure: scalar + wss: vector + +metrics: [l1, l2, mae] diff --git 
a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_surface.yaml new file mode 100644 index 0000000000..381e8d24ab --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_surface.yaml @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Surface Training Configuration (DoMINO) +# --------------------------------------------------------------------------- +# Trains a DoMINO model in surface-only mode on surface pressure (Cp) and +# wall shear stress (Cf) from DrivaerML datasets. +# +# DoMINO uses a geometry-encoding UNet + basis-function architecture that +# natively handles SDF grids and surface neighbor information. Its forward +# method accepts a single data_dict; the collate layer wraps batch tensors +# into data_dict automatically (see _wrap_as in MODEL_MAPPINGS). 
+ +output_dir: "runs" +checkpoint_dir: null +run_id: "surface_domino_drivaer_ml" + +precision: "bfloat16" +compile: false # DoMINO internals are not compile-friendly yet +profile: false +augment: false +benchmark_io: false + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "domino_automotive_surface" + +# -- Model (DoMINO, surface-only) ------------------------------------------- +model: + _target_: physicsnemo.models.domino.DoMINO + input_features: 3 # x, y, z coordinates + output_features_vol: null # surface-only mode + output_features_surf: 4 # pressure (1) + wss (3) + global_features: 3 # freestream velocity U_inf (3 components) + model_parameters: + interp_res: [128, 64, 64] + use_sdf_in_basis_func: true + surface_neighbors: true + num_neighbors_surface: 7 + use_surface_normals: true + use_surface_area: true + geometry_encoding_type: "both" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_surface + sampling_resolution: 200000 + targets: + pressure: scalar + wss: vector + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths 
-------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_volume.yaml new file mode 100644 index 0000000000..08eb5c4228 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_domino_automotive_volume.yaml @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Volume Training Configuration (DoMINO) +# --------------------------------------------------------------------------- +# Trains a DoMINO model in volume-only mode on volume fields (velocity, +# pressure, nut) from DrivaerML volume meshes. 
+# +# DoMINO uses a geometry-encoding UNet + basis-function architecture that +# natively handles SDF grids and volume neighbor information. Its forward +# method accepts a single data_dict; the collate layer wraps batch tensors +# into data_dict automatically (see _wrap_as in MODEL_MAPPINGS). + +output_dir: "runs" +checkpoint_dir: null +run_id: "volume_domino_drivaer_ml" + +precision: "bfloat16" +compile: false # DoMINO internals are not compile-friendly yet +profile: false +augment: false +benchmark_io: false + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "domino_automotive_volume" + +# -- Model (DoMINO, volume-only) -------------------------------------------- +model: + _target_: physicsnemo.models.domino.DoMINO + input_features: 3 # x, y, z coordinates + output_features_vol: 5 # velocity (3) + pressure (1) + nut (1) + output_features_surf: null # volume-only mode + global_features: 3 # freestream velocity U_inf (3 components) + model_parameters: + interp_res: [128, 64, 64] + use_sdf_in_basis_func: true + num_neighbors_volume: 10 + geometry_encoding_type: "both" + solution_calculation_mode: "two-loop" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_volume + sampling_resolution: 200000 + targets: + velocity: vector + pressure: scalar + nut: scalar + nondim_types: + velocity: velocity + nut: identity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + 
use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_surface.yaml new file mode 100644 index 0000000000..56ef1c54e5 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_surface.yaml @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Surface Training Configuration (FLARE) +# --------------------------------------------------------------------------- +# Trains a FLARE model (Transolver with FLARE attention) on surface pressure +# (Cp) and wall shear stress (Cf) from DrivaerML and/or SHIFT SUV datasets. + +output_dir: "runs" +checkpoint_dir: null +run_id: "surface_flare_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: true # broadcast U_inf (B,1,3) -> (B,N,3) for FLARE/Transolver + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "flare_automotive_surface" + +# -- Model (FLARE) ----------------------------------------------------------- +# FLARE inherits from Transolver and replaces physics attention blocks with +# FLARE (Fast Low-rank Attention Routing Engine) blocks. Transformer Engine (TE) is not supported.
+model: + _target_: physicsnemo.experimental.models.flare.FLARE + functional_dim: 3 # fx = freestream velocity U_inf (3) + out_dim: 4 # pressure (1) + wss (3) + embedding_dim: 6 # embedding = points (3) + normals (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 # n_global_queries for FLARE attention + unified_pos: false + structured_shape: null # unstructured mesh + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_surface + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + pressure: scalar + wss: vector + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: true + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. 
+data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_volume.yaml new file mode 100644 index 0000000000..9e31a37159 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_flare_automotive_volume.yaml @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Volume Training Configuration (FLARE) +# --------------------------------------------------------------------------- +# Trains a FLARE model (Transolver with FLARE attention) on volume fields +# (velocity, pressure, nut) from DrivaerML volume meshes. 
+ +output_dir: "runs" +checkpoint_dir: null +run_id: "volume_flare_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: true # broadcast U_inf (B,1,3) -> (B,N,3) for FLARE/Transolver + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "flare_automotive_volume" + +# -- Model (FLARE) ----------------------------------------------------------- +# FLARE inherits from Transolver and replaces physics attention blocks with +# FLARE (Fast Low-rank Attention Routing Engine) blocks. Transformer Engine (TE) is not supported. +model: + _target_: physicsnemo.experimental.models.flare.FLARE + functional_dim: 3 # fx = freestream velocity U_inf (3) + out_dim: 5 # velocity (3) + pressure (1) + nut (1) + embedding_dim: 7 # embedding = points (3) + sdf (1) + sdf_normals (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 # n_global_queries for FLARE attention + unified_pos: false + structured_shape: null # unstructured mesh + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_volume + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + velocity: vector + pressure: scalar + nut: scalar + nondim_types: + velocity: velocity + nut: identity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 +
use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_surface.yaml new file mode 100644 index 0000000000..da505caa06 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_surface.yaml @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Surface Training Configuration (GeoTransolver) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model on surface pressure (Cp) and wall shear +# stress (Cf) from DrivaerML (point cloud) and/or SHIFT SUV (triangulated). + +output_dir: "runs" +checkpoint_dir: null +run_id: "surface_geotransolver_shift_suv_estate" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "geotransolver_automotive_surface" + +# -- Model (GeoTransolver) -------------------------------------------------- +model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 6 # local_embedding = geometry coords (3) + normals (3) + out_dim: 4 # pressure (1) + wss (3) + geometry_dim: 3 # point coordinates + global_dim: 3 # inlet velocity (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false + plus: false + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + state_mixing_mode: "weighted" # "weighted" (default) or "concat_project" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_surface + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + pressure: scalar + wss: vector + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8
+ + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: true + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + # Default to local file tracking. Override with a remote URI when the server is ready: + # tracking_uri: "http://YOUR_MLFLOW_SERVER:5000" + tracking_uri: null # null => local ./mlruns directory + experiment_name: "unified_external_aero" + run_name: null # auto-generated from run_id if null + log_every_n_steps: 10 # per-step metric logging frequency + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_volume.yaml new file mode 100644 index 0000000000..3b46f8f109 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_automotive_volume.yaml @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Volume Training Configuration (GeoTransolver) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model on volume fields (velocity, pressure, nut) +# from DrivaerML volume meshes. + +output_dir: "runs" +checkpoint_dir: null +run_id: "volume_geotransolver_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: false + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "geotransolver_automotive_volume" + +# -- Model (GeoTransolver) -------------------------------------------------- +model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 7 # local_embedding = coords (3) + SDF (1) + SDF normals (3) + out_dim: 5 # velocity (3) + pressure (1) + nut (1) + geometry_dim: 3 # point coordinates + global_dim: 3 # inlet velocity (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false + plus: false + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + state_mixing_mode: "weighted" # "weighted" (default) or "concat_project" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_volume + sampling_resolution: 200000 # override SubsampleMesh 
n_cells/n_points in all dataset configs + targets: + velocity: vector + pressure: scalar + nut: scalar + nondim_types: + velocity: velocity + nut: identity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_surface.yaml new file mode 100644 index 0000000000..2dd1ed49a6 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_surface.yaml @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Surface Training Configuration +# (GeoTransolver with FLARE Attention / GALE_FA) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model using the GALE_FA attention backbone (FLARE- +# based self-attention + GALE cross-attention) on surface pressure (Cp) and +# wall shear stress (Cf) from DrivaerML and/or SHIFT SUV datasets. + +output_dir: "runs" +checkpoint_dir: null +run_id: "surface_geotransolver_fa_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +# Same forward interface as standard GeoTransolver. +data_mapping: "geotransolver_automotive_surface" + +# -- Model (GeoTransolver + GALE_FA) ---------------------------------------- +# Uses GALE_FA attention (FLARE self-attention with GALE cross-attention) +# instead of the default GALE physics attention. GALE_FA does not support +# Transformer Engine — use_te must be false. 
+model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 6 # local_embedding = geometry coords (3) + normals (3) + out_dim: 4 # pressure (1) + wss (3) + geometry_dim: 3 # point coordinates + global_dim: 3 # inlet velocity (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false # GALE_FA does not support TE + plus: false + attention_type: "GALE_FA" + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + state_mixing_mode: "weighted" # "weighted" (default) or "concat_project" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_surface + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + pressure: scalar + wss: vector + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: true + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. 
+data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_volume.yaml new file mode 100644 index 0000000000..18f16bedf9 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_geotransolver_fa_automotive_volume.yaml @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Volume Training Configuration +# (GeoTransolver with FLARE Attention / GALE_FA) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model using the GALE_FA attention backbone (FLARE- +# based self-attention + GALE cross-attention) on volume fields (velocity, +# pressure, nut) from DrivaerML volume meshes. 
+ +output_dir: "runs" +checkpoint_dir: null +run_id: "volume_geotransolver_fa_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: false + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +# Same forward interface as standard GeoTransolver. +data_mapping: "geotransolver_automotive_volume" + +# -- Model (GeoTransolver + GALE_FA) ---------------------------------------- +# Uses GALE_FA attention (FLARE self-attention with GALE cross-attention) +# instead of the default GALE physics attention. GALE_FA does not support +# Transformer Engine — use_te must be false. +model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 7 # local_embedding = geometry coords (3) + sdf (1) + sdf_normals (3) + out_dim: 5 # velocity (3) + pressure (1) + nut (1) + geometry_dim: 3 # point coordinates + global_dim: 3 # inlet velocity (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false # GALE_FA does not support TE + plus: false + attention_type: "GALE_FA" + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + state_mixing_mode: "weighted" # "weighted" (default) or "concat_project" + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_volume + sampling_resolution: 500000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + velocity: vector + pressure: scalar + nut: scalar + nondim_types: + velocity: velocity + nut: identity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + 
weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_surface.yaml new file mode 100644 index 0000000000..ae8bb2a321 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_surface.yaml @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# High-Lift Airplane - Surface Training Configuration (Compressible) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model on surface pressure, temperature, density, +# velocity, and wall shear stress from the HiLiftAeroML dataset (AoA 4°). + +output_dir: "runs" +checkpoint_dir: null +run_id: "highlift_surface_geotransolver" + +precision: "bfloat16" +compile: false +profile: false +augment: false +benchmark_io: false + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "geotransolver_highlift_surface" + +# -- Model (GeoTransolver) -------------------------------------------------- +model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 6 # local_embedding = geometry coords (3) + normals (3) + out_dim: 7 # pressure (1) + temperature (1) + density (1) + velocity (3) + tau_wall (1) + geometry_dim: 3 # point coordinates + global_dim: 3 # freestream velocity U_inf (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false + plus: false + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: highlift_surface + sampling_resolution: 100000 + targets: + pressure: scalar + temperature: scalar + density: scalar + velocity: vector + tau_wall: scalar + # Explicit nondim type per output field for physical-space metric inversion. 
+ nondim_types: + pressure: pressure + temperature: temperature + density: density + velocity: velocity + tau_wall: stress + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: true + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +data: + highlift: + config: conf/dataset/highlift_surface.yaml diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_volume.yaml new file mode 100644 index 0000000000..5e5988887f --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_highlift_volume.yaml @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# High-Lift Airplane - Volume Training Configuration (Compressible) +# --------------------------------------------------------------------------- +# Trains a GeoTransolver model on volume pressure, temperature, density, +# and velocity from the HiLiftAeroML dataset (AoA 4°). +# Volume data is a point cloud — no surface normals available. + +output_dir: "runs" +checkpoint_dir: null +run_id: "highlift_volume_geotransolver" + +precision: "bfloat16" +compile: false +profile: true +augment: false +benchmark_io: false + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "geotransolver_highlift_volume" + +# -- Model (GeoTransolver) -------------------------------------------------- +model: + _target_: physicsnemo.experimental.models.geotransolver.GeoTransolver + functional_dim: 7 # local_embedding = coords (3) + SDF (1) + SDF normals (3) + out_dim: 6 # pressure (1) + temperature (1) + density (1) + velocity (3) + geometry_dim: 3 # point coordinates + global_dim: 3 # freestream velocity U_inf (3) + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + use_te: false + plus: false + include_local_features: false + radii: [0.1, 0.5, 2.0] + neighbors_in_radius: [16, 32, 64] + n_hidden_local: 32 + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: highlift_volume + sampling_resolution: 100000 + targets: + pressure: scalar + 
temperature: scalar + density: scalar + velocity: vector + # Explicit nondim type per output field for physical-space metric inversion. + nondim_types: + pressure: pressure + temperature: temperature + density: density + velocity: velocity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 + num_epochs: 1 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 3.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.1 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments (e.g. geometry_train, full_train, aoa_train, ...). 
+data: + highlift: + config: conf/dataset/highlift_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-HighLiftAeroML/manifest.json + train_split: single_aoa_4_train + val_split: single_aoa_4_val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_surface.yaml new file mode 100644 index 0000000000..dd3585a35b --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_surface.yaml @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Surface Training Configuration (Transolver) +# --------------------------------------------------------------------------- +# Trains a Transolver model on surface pressure (Cp) and wall shear +# stress (Cf) from DrivaerML and/or SHIFT SUV datasets. 
+ +output_dir: "runs" +checkpoint_dir: null +run_id: "surface_transolver_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: true # broadcast U_inf (B,1,3) -> (B,N,3) for Transolver + +model_type: "surface" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "transolver_automotive_surface" + +# -- Model (Transolver) ----------------------------------------------------- +model: + _target_: physicsnemo.models.transolver.Transolver + functional_dim: 3 # fx = freestream velocity U_inf (3) + out_dim: 4 # pressure (1) + wss (3) + embedding_dim: 6 # embedding = points (3) + normals (3) + n_layers: 8 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + unified_pos: false + structured_shape: null # unstructured mesh + use_te: false + plus: false + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_surface + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + pressure: scalar + wss: vector + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: true + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + 
+# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_surface.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val + shift_suv_estate: + config: conf/dataset/shift_suv_estate.yaml + shift_suv_fastback: + config: conf/dataset/shift_suv_fastback.yaml diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_volume.yaml new file mode 100644 index 0000000000..1dfb7c4b09 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/train_transolver_automotive_volume.yaml @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# --------------------------------------------------------------------------- +# Unified External Aerodynamics - Volume Training Configuration (Transolver) +# --------------------------------------------------------------------------- +# Trains a Transolver model on volume fields (velocity, pressure, nut) +# from DrivaerML volume meshes. + +output_dir: "runs" +checkpoint_dir: null +run_id: "volume_transolver_drivaer_ml" + +precision: "bfloat16" +compile: true +profile: false +augment: false +benchmark_io: false +broadcast_global: true # broadcast U_inf (B,1,3) -> (B,N,3) for Transolver + +model_type: "volume" + +# -- Data-to-model mapping (see src/collate.py MODEL_MAPPINGS) --------------- +data_mapping: "transolver_automotive_volume" + +# -- Model (Transolver) ----------------------------------------------------- +model: + _target_: physicsnemo.models.transolver.Transolver + functional_dim: 3 # fx = freestream velocity U_inf (3) + out_dim: 5 # velocity (3) + pressure (1) + nut (1) + embedding_dim: 3 # NOTE(review): mapping "transolver_automotive_volume" concatenates points (3) + sdf (1) + sdf_normals (3) = 7; confirm this value + n_layers: 12 + n_hidden: 256 + dropout: 0.0 + n_head: 8 + act: "gelu" + mlp_ratio: 4 + slice_num: 256 + unified_pos: false + structured_shape: null # unstructured mesh + use_te: false + plus: false + +# -- Dataset ----------------------------------------------------------------- +dataset: + name: drivaer_ml_volume + sampling_resolution: 200000 # override SubsampleMesh n_cells/n_points in all dataset configs + targets: + velocity: vector + pressure: scalar + nut: scalar + nondim_types: + velocity: velocity + nut: identity + metrics: [l2] + +# -- Training ---------------------------------------------------------------- +training: + seed: 42 # set to null to leave RNGs unseeded + num_epochs: 500 + save_interval: 25 + loss_type: "huber" + batch_size: 1 + scheduler_update_mode: "epoch" + + optimizer: + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + scheduler: + _target_:
torch.optim.lr_scheduler.StepLR + step_size: 100 + gamma: 0.5 + +# -- DataLoader / Dataset Performance Tuning -------------------------------- +dataloader: + prefetch_factor: 1 + num_streams: 1 + use_streams: false # SDF computation is not stream-safe + num_workers: 1 + pin_memory: true + +# -- MLflow Experiment Tracking ---------------------------------------------- +mlflow: + tracking_uri: null + experiment_name: "unified_external_aero" + run_name: null + log_every_n_steps: 10 + tags: {} + +# -- Data paths -------------------------------------------------------------- +# Splits are read from manifest.json; change train_split / val_split to +# switch experiments. +data: + drivaer_ml: + config: conf/dataset/drivaer_ml_volume.yaml + manifest: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/PhysicsNeMo-DrivaerML/manifest.json + train_split: train + val_split: val diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/requirements.txt b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/requirements.txt new file mode 100644 index 0000000000..2d6835ff92 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/requirements.txt @@ -0,0 +1,3 @@ +tabulate +tensorboard +mlflow \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/__init__.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/__init__.py new file mode 100644 index 0000000000..af85283aa4 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/collate.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/collate.py new file mode 100644 index 0000000000..ba6129b7c7 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/collate.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data-to-model mapping for converting datapipe output to model batch format. + +The datapipe produces ``(TensorDict, metadata_dict)`` tuples. A *mapping +specification* defines how TensorDict fields are extracted, optionally +concatenated, and assembled into the ``dict[str, Tensor]`` batch expected +by ``model.forward()``. 
+ +Mapping specs are plain dictionaries registered in :data:`MODEL_MAPPINGS`:: + + MODEL_MAPPINGS = { + "geotransolver_automotive_surface": { + "geometry": "input/points", + "local_embedding": ["input/points", "input/normals"], + "local_positions": "input/points", + "global_embedding": "input/U_inf", + "fields": ["output/pressure", "output/wss"], + }, + } + +Each value is either: + +* A **string** path (``"group/key"``) — extract that tensor directly. +* A **list** of paths — extract each tensor, then concatenate along the + last dimension. + +The ``"fields"`` key is treated as the prediction target by the training +loop (popped from the batch before ``model(**batch)``). +""" + +from __future__ import annotations + +from typing import Callable + +import torch +from tensordict import TensorDict + +MappingSpec = dict[str, str | list[str]] + + +# --------------------------------------------------------------------------- +# Mapping registry — add new model mappings here. +# Each entry maps datapipe outputs to the inputs of one model. +# Entries are keyed per model and application, so supporting a +# new model or domain only requires registering another entry. +# --------------------------------------------------------------------------- + +MODEL_MAPPINGS: dict[str, MappingSpec] = { + # Automotive surface: concatenates points+normals into local_embedding + # (breaks equivariance by design — GeoTransolver learns to disentangle). + "geotransolver_automotive_surface": { + "geometry": "input/points", + "local_embedding": ["input/points", "input/normals"], + "local_positions": "input/points", + "global_embedding": "input/U_inf", + "fields": ["output/pressure", "output/wss"], + }, + # High-lift airplane surface: compressible fields (P, T, rho, U, tau_wall).
+ "geotransolver_highlift_surface": { + "geometry": "input/points", + "local_embedding": ["input/points", "input/normals"], + "local_positions": "input/points", + "global_embedding": "input/U_inf", + "fields": [ + "output/pressure", + "output/temperature", + "output/density", + "output/velocity", + "output/tau_wall", + ], + }, + # High-lift airplane volume: SDF + normals from STL surface. + "geotransolver_highlift_volume": { + "geometry": "input/points", + "local_embedding": ["input/points", "input/sdf", "input/sdf_normals"], + "local_positions": "input/points", + "global_embedding": "input/U_inf", + "fields": [ + "output/pressure", + "output/temperature", + "output/density", + "output/velocity", + ], + }, + # Automotive volume: SDF + normals from STL surface, incompressible fields. + "geotransolver_automotive_volume": { + "geometry": "input/points", + "local_embedding": ["input/points", "input/sdf", "input/sdf_normals"], + "local_positions": "input/points", + "global_embedding": "input/U_inf", + "fields": ["output/velocity", "output/pressure", "output/nut"], + }, + # Automotive surface (Transolver): embedding = points+normals, fx = freestream velocity. + # fx is broadcast from (B,1,3) to (B,N,3) via broadcast_global in train.py. + "transolver_automotive_surface": { + "embedding": ["input/points", "input/normals"], + "fx": "input/U_inf", + "fields": ["output/pressure", "output/wss"], + }, + # Automotive volume (Transolver): SDF + normals from STL surface, fx = freestream velocity. + "transolver_automotive_volume": { + "embedding": ["input/points", "input/sdf", "input/sdf_normals"], + "fx": "input/U_inf", + "fields": ["output/velocity", "output/pressure", "output/nut"], + }, + # Automotive surface (FLARE): same interface as Transolver (fx + embedding). 
+ "flare_automotive_surface": { + "embedding": ["input/points", "input/normals"], + "fx": "input/U_inf", + "fields": ["output/pressure", "output/wss"], + }, + # Automotive volume (FLARE): SDF + normals from STL surface. + "flare_automotive_volume": { + "embedding": ["input/points", "input/sdf", "input/sdf_normals"], + "fx": "input/U_inf", + "fields": ["output/velocity", "output/pressure", "output/nut"], + }, + # ----------------------------------------------------------------------- + # DoMINO mappings + # ----------------------------------------------------------------------- + # DoMINO.forward() takes a single ``data_dict`` argument instead of + # keyword-per-tensor. The ``_wrap_as`` key tells the collate layer to + # pack every non-``fields`` tensor into a nested dict with that name so + # that ``model(**batch)`` expands to ``model(data_dict={...})``. + # + # The datapipe keys below assume the dataset YAML's + # RestructureTensorDict produces DoMINO-compatible names under input/. + # ----------------------------------------------------------------------- + "domino_automotive_surface": { + "_wrap_as": "data_dict", + "geometry_coordinates": "input/geometry_coordinates", + "surf_grid": "input/surf_grid", + "sdf_surf_grid": "input/sdf_surf_grid", + "global_params_values": "input/global_params_values", + "global_params_reference": "input/global_params_reference", + "pos_surface_center_of_mass": "input/pos_surface_center_of_mass", + "surface_mesh_centers": "input/surface_mesh_centers", + "surface_mesh_neighbors": "input/surface_mesh_neighbors", + "surface_normals": "input/surface_normals", + "surface_neighbors_normals": "input/surface_neighbors_normals", + "surface_areas": "input/surface_areas", + "surface_neighbors_areas": "input/surface_neighbors_areas", + "surface_min_max": "input/surface_min_max", + "fields": ["output/pressure", "output/wss"], + }, + "domino_automotive_volume": { + "_wrap_as": "data_dict", + "geometry_coordinates": "input/geometry_coordinates", + 
"grid": "input/grid", + "surf_grid": "input/surf_grid", + "sdf_grid": "input/sdf_grid", + "sdf_surf_grid": "input/sdf_surf_grid", + "sdf_nodes": "input/sdf_nodes", + "global_params_values": "input/global_params_values", + "global_params_reference": "input/global_params_reference", + "pos_volume_closest": "input/pos_volume_closest", + "pos_volume_center_of_mass": "input/pos_volume_center_of_mass", + "volume_mesh_centers": "input/volume_mesh_centers", + "volume_min_max": "input/volume_min_max", + "fields": ["output/velocity", "output/pressure", "output/nut"], + }, +} + + +# --------------------------------------------------------------------------- +# Core helpers +# --------------------------------------------------------------------------- + + +def _extract(td: TensorDict, path: str) -> torch.Tensor: + """Extract a tensor from a TensorDict using a ``/``-separated path.""" + keys = path.split("/") + result = td + for key in keys: + result = result[key] + return result + + +def _resolve_spec(td: TensorDict, spec: str | list[str]) -> torch.Tensor: + """Resolve one mapping spec to a single tensor. + + - String spec: extract and ensure at least 2-D (adds a leading token dim). + - List spec: extract each path, align ndim, and concatenate along last dim. + """ + if isinstance(spec, str): + tensor = _extract(td, spec) + # Scalars / 1-D vectors (e.g. U_inf as (3,)) need a leading + # token dimension so they stack to (B, 1, D). + while tensor.ndim < 2: + tensor = tensor.unsqueeze(0) + return tensor + + tensors = [_extract(td, s) for s in spec] + # Align ndim before concatenation (e.g. pressure (N,) with + # wss (N, 3) — unsqueeze pressure to (N, 1)). 
def map_data_to_model(
    samples: list[tuple[TensorDict, dict]],
    mapping: MappingSpec,
    *,
    wrap_as: str | None = None,
) -> dict[str, torch.Tensor]:
    """Collate datapipe samples into one batch keyed by model argument name.

    Every sample is a ``(data, metadata)`` pair; only ``data`` is used.
    For each non-private key of *mapping* the corresponding spec is
    resolved against every sample and the per-sample tensors are stacked
    along a new leading batch dimension.

    Parameters
    ----------
    samples : list[tuple[TensorDict, dict]]
        ``(data, metadata)`` pairs produced by the datapipe.
    mapping : dict[str, str | list[str]]
        Model batch key -> datapipe path (string) or list of paths that are
        concatenated along the last dimension. Keys beginning with ``_``
        carry metadata for the collate layer and are not resolved.
    wrap_as : str or None, optional
        If given, every stacked tensor except ``"fields"`` is nested in a
        dict stored under this key, for models whose ``forward()`` takes a
        single dict argument.

    Returns
    -------
    dict[str, torch.Tensor | dict[str, torch.Tensor]]
        The flat batch, or ``{wrap_as: {...}, "fields": tensor}`` when
        *wrap_as* is set.
    """
    tensor_specs = [
        (name, spec) for name, spec in mapping.items() if not name.startswith("_")
    ]
    per_key: dict[str, list[torch.Tensor]] = {name: [] for name, _ in tensor_specs}

    # Samples form the outer loop so tensors are resolved sample by sample.
    for td, _metadata in samples:
        for name, spec in tensor_specs:
            per_key[name].append(_resolve_spec(td, spec))

    batch = {}
    for name, tensors in per_key.items():
        batch[name] = torch.stack(tensors)

    if wrap_as is None:
        return batch

    # The prediction target stays at the top level; everything else is nested.
    target = batch.pop("fields")
    return {wrap_as: batch, "fields": target}
Available: {list(MODEL_MAPPINGS.keys())}" + ) + resolved = MODEL_MAPPINGS[mapping] + else: + resolved = mapping + + wrap_as = resolved.get("_wrap_as") + + def collate_fn( + samples: list[tuple[TensorDict, dict]], + ) -> dict[str, torch.Tensor]: + return map_data_to_model(samples, resolved, wrap_as=wrap_as) + + return collate_fn diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/datasets.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/datasets.py new file mode 100644 index 0000000000..4b409a9e32 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/datasets.py @@ -0,0 +1,429 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dataset factory functions for external aerodynamics mesh pipelines. + +Builds MeshDataset instances from Hydra-instantiable YAML configs. +Each config's ``pipeline:`` block declares a ``reader:`` and ``transforms:`` +list with ``_target_: ${dp:ComponentName}`` entries, instantiated via +``hydra.utils.instantiate()``. + +The main builder (``build_surface_dataset``) is fully generic and works +for both surface and volume mesh configs -- the distinction is purely in +the YAML transform chain. ``build_dataset`` is provided as a +mesh-type-agnostic alias. 
def _resolve_transform_paths(t_cfg: DictConfig, base_dir: Path) -> DictConfig:
    """Rewrite relative path entries of a transform config as absolute paths.

    Each key listed in ``_PATH_KEYS`` is looked up in *t_cfg*. A value that
    is present, relative, and points at an existing location under
    *base_dir* is replaced by the joined path (as a string). Absolute
    values, missing keys, and relative values that do not exist under
    *base_dir* are left untouched.

    Returns the (possibly merged) config; the input is rebound, not
    mutated in place, whenever a substitution happens.
    """
    for path_key in _PATH_KEYS:
        raw = OmegaConf.select(t_cfg, path_key, default=None)
        if raw is None or Path(raw).is_absolute():
            continue
        candidate = base_dir / raw
        if not candidate.exists():
            # Keep the original value so the transform reports the problem.
            continue
        t_cfg = OmegaConf.merge(t_cfg, {path_key: str(candidate)})
    return t_cfg
def build_surface_dataset(
    cfg: DictConfig,
    base_dir: Path | None = None,
    augment: bool = False,
    device: str | torch.device | None = "auto",
    num_workers: int = 1,
    pin_memory: bool = False,
) -> MeshDataset:
    """Build a single MeshDataset from a Hydra-style pipeline config.

    Parameters
    ----------
    cfg : DictConfig
        Dataset config with a ``pipeline:`` block containing a ``reader:``
        entry and optional ``transforms:`` / ``augmentations:`` lists.
        If a top-level ``metadata:`` block is present, its values are
        injected into ``global_data`` as the first transform step.
    base_dir : Path, optional
        Root directory for resolving relative paths in transform configs
        (e.g. ``stats_file``). Defaults to the parent of the directory
        containing this file.
    augment : bool, optional
        When ``True``, the ``pipeline.augmentations`` transforms are
        inserted right after the first transform whose ``_target_`` ends
        with ``CenterMesh``; if there is no such transform (or no
        ``transforms`` list at all) they are appended at the end.
        Should be ``False`` for validation / test datasets.
    device : str or torch.device or None, optional
        Forwarded unchanged to ``MeshDataset``. Defaults to ``"auto"``.
        NOTE(review): the exact meaning of ``"auto"`` / ``None`` is defined
        by ``MeshDataset`` and is not visible here -- confirm there.
    num_workers : int, default=1
        Forwarded to ``MeshDataset`` as ``num_workers``.
    pin_memory : bool, default=False
        Forwarded to the reader constructor as ``pin_memory``.

    Returns
    -------
    MeshDataset
    """
    if base_dir is None:
        base_dir = Path(__file__).resolve().parent.parent

    metadata = OmegaConf.to_container(
        OmegaConf.select(cfg, "metadata", default=OmegaConf.create({})),
        resolve=True,
    )

    pipeline = cfg.pipeline
    reader = hydra.utils.instantiate(pipeline.reader, pin_memory=pin_memory)

    # ``transforms`` is optional (absent, null, or empty). Normalise it once
    # so that both the instantiation loop and the augmentation index search
    # below see the same list; previously the latter dereferenced
    # ``cfg.pipeline.transforms`` unguarded and failed when it was missing.
    if "transforms" in pipeline and pipeline.transforms:
        transform_cfgs = pipeline.transforms
    else:
        transform_cfgs = []

    resolved = []

    # Inject dataset metadata into global_data as the first transform
    if metadata:
        resolved.append(_make_metadata_injector(metadata))

    for t in transform_cfgs:
        t = _resolve_transform_paths(t, base_dir)
        resolved.append(hydra.utils.instantiate(t))

    if augment and "augmentations" in pipeline and pipeline.augmentations:
        aug = [hydra.utils.instantiate(a) for a in pipeline.augmentations]
        # Shift by one when the metadata injector occupies slot 0.
        offset = 1 if metadata else 0
        insert_idx = next(
            (
                offset + i + 1
                for i, t_cfg in enumerate(transform_cfgs)
                if t_cfg.get("_target_", "").endswith(_CENTER_MESH_SUFFIX)
            ),
            len(resolved),  # no CenterMesh found: append at the end
        )
        resolved[insert_idx:insert_idx] = aug

    transforms = resolved if resolved else None
    return MeshDataset(
        reader, transforms=transforms, device=device, num_workers=num_workers
    )
+ """ + p = Path(path) + text = p.read_text() + # Try JSON first + try: + data = json.loads(text) + if isinstance(data, dict): + if split is None: + raise ValueError( + f"Manifest {p.name} is a JSON dict; " + f"a 'split' key is required. " + f"Available keys: {list(data.keys())[:10]}" + ) + if split not in data: + raise KeyError( + f"Split {split!r} not found in manifest. " + f"Available: {list(data.keys())}" + ) + entries = data[split] + elif isinstance(data, list): + entries = data + else: + raise ValueError( + f"Manifest JSON must be a list or dict, got {type(data).__name__}" + ) + return sorted(str(e) for e in entries) + except json.JSONDecodeError: + pass + # Fall back to one-per-line text + entries = [] + for line in text.splitlines(): + line = line.strip() + if line and not line.startswith("#"): + entries.append(line) + return sorted(entries) + + +def resolve_manifest_indices( + reader, + manifest_entries: list[str], +) -> list[int]: + """Map manifest sub-paths to reader sample indices. + + Each manifest entry is matched against the reader's discovered paths. + A reader path matches if any of its parent directories (relative to + the reader root) equals the manifest entry. + + Parameters + ---------- + reader : MeshReader or DomainMeshReader + An instantiated reader with ``_root`` and ``_paths`` attributes. + manifest_entries : list[str] + Sub-path strings from the manifest (e.g. ``["run_1", "run_5"]``). + + Returns + ------- + list[int] + Sorted list of reader indices whose paths match the manifest. + + Raises + ------ + ValueError + If no reader paths match any manifest entry. + """ + entry_set = set(manifest_entries) + indices = [] + for idx, full_path in enumerate(reader._paths): + try: + rel = full_path.relative_to(reader._root) + except ValueError: + continue + # Check if any parent component matches a manifest entry + # e.g. 
class ManifestSampler(Sampler[int]):
    """Sampler that restricts iteration to a subset of dataset indices.

    Supports shuffling with epoch-aware seeding and distributed sharding.
    When ``world_size > 1`` every rank yields exactly ``len(self)`` indices:
    with ``drop_last=True`` the tail that does not divide evenly is
    discarded, otherwise the (shuffled) index list is padded by repeating
    itself from the start as often as needed.

    Parameters
    ----------
    indices : list[int]
        Dataset indices that belong to this split.
    shuffle : bool
        Whether to shuffle indices each epoch.
    seed : int
        Base random seed for reproducible shuffling. The effective seed is
        ``seed + epoch`` so all ranks draw the same permutation.
    rank : int
        Current process rank (for distributed sharding). 0 for single-GPU.
    world_size : int
        Total number of processes. 1 for single-GPU.
    drop_last : bool
        If True, drop tail indices so every rank gets the same count.
    """

    def __init__(
        self,
        indices: list[int],
        shuffle: bool = True,
        seed: int = 0,
        rank: int = 0,
        world_size: int = 1,
        drop_last: bool = False,
    ) -> None:
        self._indices = list(indices)
        self._shuffle = shuffle
        self._seed = seed
        self._rank = rank
        self._world_size = world_size
        self._drop_last = drop_last
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for deterministic shuffling."""
        self._epoch = epoch

    def __len__(self) -> int:
        """Return the number of indices this rank yields per epoch."""
        n = len(self._indices)
        if self._world_size <= 1:
            return n
        if self._drop_last:
            return n // self._world_size
        return math.ceil(n / self._world_size)

    def __iter__(self) -> Iterator[int]:
        """Yield this rank's share of the (optionally shuffled) indices."""
        indices = list(self._indices)
        if self._shuffle:
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            perm = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in perm]

        if self._world_size > 1:
            total = len(self) * self._world_size
            if self._drop_last:
                # Truncate so the strided shard below is the same length on
                # every rank and matches ``__len__``.
                indices = indices[:total]
            elif indices:
                # Pad by cycling through the list; a single ``indices[:pad]``
                # is too short when the split is smaller than the padding
                # (e.g. 1 sample over 4 ranks).
                repeats = math.ceil(total / len(indices))
                indices = (indices * repeats)[:total]
            # Shard
            indices = indices[self._rank :: self._world_size]

        return iter(indices)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flexible loss calculator for configurable target fields.""" + +from __future__ import annotations + +from typing import Literal + +import torch +import torch.nn.functional as F + +from utils import FieldSpec, parse_target_config + +# Default delta for Huber loss +DEFAULT_HUBER_DELTA = 1.0 + + +# --------------------------------------------------------------------------- +# Core loss functions operating on tensors +# --------------------------------------------------------------------------- + + +def compute_huber( + pred: torch.Tensor, target: torch.Tensor, delta: float = DEFAULT_HUBER_DELTA +) -> torch.Tensor: + """Huber loss (smooth L1) for scalar fields. + + Huber loss is quadratic for small errors and linear for large errors, + making it more robust to outliers than MSE. + + Args: + pred: Predictions tensor + target: Targets tensor + delta: Threshold at which to switch from quadratic to linear. + + Returns: + Mean Huber loss as a scalar tensor. 
def compute_mse_vector(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Sum of per-component mean squared errors for a vector field.

    The squared error is averaged over the first two dimensions
    (documented by the callers as batch and points), leaving one value
    per vector component; those values are then added together.

    Args:
        pred: Predictions of shape [batch, points, dim]
        target: Targets of shape [batch, points, dim]

    Returns:
        Scalar tensor holding the summed per-component MSE.
    """
    squared_error = (pred - target) ** 2.0
    per_component = squared_error.mean(dim=(0, 1))
    return per_component.sum()
class LossCalculator:
    """Configurable loss calculator for scalar and vector target fields.

    Computes loss for each configured target field separately, then sums
    them. Supports Huber, MSE, and RMSE (relative MSE) loss types.

    For vector fields, the vector variant of the selected loss is used
    (per-component losses summed). When ``normalize_by_channels`` is True
    the summed loss is divided by the total number of channels.

    Parameters
    ----------
    target_config : dict[str, str]
        Mapping of field names to types. Order determines channel indices.
        Example: {"pressure": "scalar", "velocity": "vector", "turbulence": "scalar"}
    loss_type : Literal["huber", "mse", "rmse"], optional
        Type of loss to compute. Default is "mse".
        - "huber": Huber loss (smooth L1), robust to outliers
        - "mse": Mean Squared Error
        - "rmse": Relative MSE (normalized by target magnitude)
    n_spatial_dims : int, optional
        Dimensionality of vector fields. Default is 3.
    prefix : str, optional
        Prefix for all loss names (e.g., "surface" -> "loss/surface/pressure").
        Default is empty string.
    normalize_by_channels : bool, optional
        Whether to normalize the total loss by the number of channels.
        Default is True.

    Raises
    ------
    ValueError
        If *loss_type* is not one of the registered loss functions.

    Examples
    --------
    >>> calc = LossCalculator(
    ...     target_config={"pressure": "scalar", "wall_shear": "vector"},
    ...     loss_type="huber",
    ...     prefix="surface",
    ... )
    >>> pred = torch.randn(2, 100, 4)  # [batch, points, channels]
    >>> target = torch.randn(2, 100, 4)
    >>> total_loss, loss_dict = calc(pred, target)
    """

    def __init__(
        self,
        target_config: dict[str, str],
        loss_type: Literal["huber", "mse", "rmse"] = "mse",
        n_spatial_dims: int = 3,
        prefix: str = "",
        normalize_by_channels: bool = True,
    ):
        self.loss_type = loss_type
        self.n_spatial_dims = n_spatial_dims
        self.prefix = prefix
        self.normalize_by_channels = normalize_by_channels

        # Validate loss type
        if loss_type not in LOSS_FUNCTIONS_SCALAR:
            raise ValueError(
                f"Unknown loss type '{loss_type}'. "
                f"Available: {list(LOSS_FUNCTIONS_SCALAR.keys())}"
            )

        # Parse target config to build field specifications using shared utility
        self.field_specs = parse_target_config(target_config, n_spatial_dims)
        self.total_channels = sum(spec.dim for spec in self.field_specs)

    def _make_key(self, *parts: str) -> str:
        """Construct a loss key of the form ``loss[/prefix]/part...``."""
        segments = ["loss"]
        if self.prefix:
            segments.append(self.prefix)
        segments.extend(parts)
        return "/".join(segments)

    def _compute_scalar_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        name: str,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute loss for a scalar field [batch, points].

        Returns:
            Tuple of (loss_value, {loss_key: loss_value})
        """
        loss_fn = LOSS_FUNCTIONS_SCALAR[self.loss_type]
        loss = loss_fn(pred, target)
        return loss, {self._make_key(name): loss}

    def _compute_vector_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        name: str,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute loss for a vector field [batch, points, dim].

        Returns:
            Tuple of (loss_value, {loss_key: loss_value})
        """
        loss_fn = LOSS_FUNCTIONS_VECTOR[self.loss_type]
        loss = loss_fn(pred, target)
        return loss, {self._make_key(name): loss}

    def __call__(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute losses for all configured fields.

        Args:
            pred: Predicted values, shape [batch, points, channels].
            target: Target values, shape [batch, points, channels].

        Returns:
            Tuple of:
            - total_loss: Combined loss as a scalar tensor
            - loss_dict: Dictionary of loss name -> scalar tensor value.
              Contains one entry per field plus the combined loss under
              ``loss/<prefix>`` (or ``loss/total`` when no prefix is set).

        Raises:
            ValueError: If the last dimension of *pred* or *target* does
                not equal the channel count implied by the target config.
        """
        if pred.shape[-1] != self.total_channels:
            raise ValueError(
                f"Expected {self.total_channels} channels based on target config, "
                f"but got {pred.shape[-1]}."
            )
        # A target with a different channel count would otherwise be sliced
        # with the same indices and silently produce empty or misaligned
        # fields instead of failing.
        if target.shape[-1] != self.total_channels:
            raise ValueError(
                f"Expected {self.total_channels} target channels based on "
                f"target config, but got {target.shape[-1]}."
            )

        total_loss = torch.tensor(0.0, device=pred.device, dtype=pred.dtype)
        loss_dict = {}

        for spec in self.field_specs:
            pred_field = pred[:, :, spec.start_index : spec.end_index]
            target_field = target[:, :, spec.start_index : spec.end_index]

            if spec.field_type == "scalar":
                # Scalar losses operate on [batch, points]; drop the channel dim.
                field_loss, field_dict = self._compute_scalar_loss(
                    pred_field.squeeze(-1), target_field.squeeze(-1), spec.name
                )
            else:
                field_loss, field_dict = self._compute_vector_loss(
                    pred_field, target_field, spec.name
                )

            total_loss = total_loss + field_loss
            loss_dict.update(field_dict)

        if self.normalize_by_channels:
            total_loss = total_loss / self.total_channels

        # Add total loss to dict
        total_key = f"loss/{self.prefix}" if self.prefix else "loss/total"
        loss_dict[total_key] = total_loss

        return total_loss, loss_dict

    def __repr__(self) -> str:
        fields_str = ", ".join(
            f"{s.name}:{s.field_type}[{s.start_index}:{s.end_index}]"
            for s in self.field_specs
        )
        return (
            f"LossCalculator(fields=[{fields_str}], "
            f"loss_type='{self.loss_type}', prefix='{self.prefix}')"
        )
def compute_relative_l2(
    pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Mean over samples of ``||pred - target||_2 / (||target||_2 + eps)``.

    Both norms are taken along dimension 1, so each row of the input
    contributes one ratio; the ratios are then averaged. *eps* keeps the
    division finite when a target row is all zeros.
    """
    error_norm = (pred - target).pow(2).sum(dim=1).sqrt()
    target_norm = target.pow(2).sum(dim=1).sqrt()
    ratios = error_norm / (target_norm + eps)
    return ratios.mean()
) + >>> pred = torch.randn(2, 100, 4) # [batch, points, channels] + >>> target = torch.randn(2, 100, 4) + >>> metrics = calc(pred, target) + """ + + VECTOR_COMPONENTS = ("x", "y", "z") + + def __init__( + self, + target_config: dict[str, str], + process_group: dist.ProcessGroup | None = None, + n_spatial_dims: int = 3, + metrics: list[str] | None = None, + prefix: str = "", + ): + self.process_group = process_group + self.n_spatial_dims = n_spatial_dims + self.metric_names = metrics if metrics is not None else ["l1", "l2", "mae"] + self.prefix = prefix + + # Validate metric names + for m in self.metric_names: + if m not in METRIC_FUNCTIONS: + raise ValueError( + f"Unknown metric '{m}'. Available: {list(METRIC_FUNCTIONS.keys())}" + ) + + # Parse target config to build field specifications using shared utility + self.field_specs = parse_target_config(target_config, n_spatial_dims) + self.total_channels = sum(spec.dim for spec in self.field_specs) + + def _make_key(self, *parts: str) -> str: + """Construct a metric key with optional prefix.""" + key = "_".join(parts) + return f"{self.prefix}/{key}" if self.prefix else key + + def _compute_metrics_for_field( + self, + pred: torch.Tensor, + target: torch.Tensor, + name: str, + ) -> dict[str, torch.Tensor]: + """Compute all configured metrics for a [batch, points] field.""" + return { + self._make_key(name, m): METRIC_FUNCTIONS[m](pred, target) + for m in self.metric_names + } + + def _compute_scalar_metrics( + self, + pred: torch.Tensor, + target: torch.Tensor, + name: str, + ) -> dict[str, torch.Tensor]: + """Compute metrics for a scalar field [batch, points].""" + return self._compute_metrics_for_field(pred, target, name) + + def _compute_vector_metrics( + self, + pred: torch.Tensor, + target: torch.Tensor, + name: str, + ) -> dict[str, torch.Tensor]: + """Compute metrics for a vector field [batch, points, dim]. + + Computes elementwise (per component) and aggregate (magnitude) metrics. 
+ """ + metrics = {} + + # Elementwise metrics (per component) + for i, comp in enumerate(self.VECTOR_COMPONENTS[: pred.shape[-1]]): + comp_metrics = self._compute_metrics_for_field( + pred[:, :, i], target[:, :, i], f"{name}_{comp}" + ) + metrics.update(comp_metrics) + + # Aggregate metrics (magnitude) + pred_mag = torch.linalg.vector_norm(pred, dim=-1) + target_mag = torch.linalg.vector_norm(target, dim=-1) + metrics.update(self._compute_metrics_for_field(pred_mag, target_mag, name)) + + return metrics + + def _all_reduce(self, metrics: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """All-reduce metrics across the process group.""" + if self.process_group is None: + return metrics + + world_size = dist.get_world_size(self.process_group) + if world_size == 1: + return metrics + + keys = list(metrics.keys()) + stacked = torch.stack([metrics[k] for k in keys]) + + dist.all_reduce(stacked, group=self.process_group) + stacked = stacked / world_size + + return {k: stacked[i] for i, k in enumerate(keys)} + + def __call__( + self, + pred: torch.Tensor, + target: torch.Tensor, + ) -> dict[str, torch.Tensor]: + """Compute all configured metrics. + + Args: + pred: Predicted values, shape [batch, points, channels]. + target: Target values, shape [batch, points, channels]. + + Returns: + Dictionary of metric name -> scalar tensor value. + """ + if pred.shape[-1] != self.total_channels: + raise ValueError( + f"Expected {self.total_channels} channels based on target config, " + f"but got {pred.shape[-1]}." 
+ ) + + metrics = {} + + with torch.no_grad(): + for spec in self.field_specs: + pred_field = pred[:, :, spec.start_index : spec.end_index] + target_field = target[:, :, spec.start_index : spec.end_index] + + if spec.field_type == "scalar": + field_metrics = self._compute_scalar_metrics( + pred_field.squeeze(-1), target_field.squeeze(-1), spec.name + ) + else: + field_metrics = self._compute_vector_metrics( + pred_field, target_field, spec.name + ) + + metrics.update(field_metrics) + + metrics = self._all_reduce(metrics) + + return metrics + + def __repr__(self) -> str: + fields_str = ", ".join( + f"{s.name}:{s.field_type}[{s.start_index}:{s.end_index}]" + for s in self.field_specs + ) + return ( + f"MetricCalculator(fields=[{fields_str}], " + f"metrics={self.metric_names}, prefix='{self.prefix}')" + ) diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py new file mode 100644 index 0000000000..72440d0ebc --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Physics-based non-dimensionalization transform. 
+ +Recipe-local transform registered into the global datapipe component +registry so it can be referenced via ``${dp:NonDimensionalizeByMetadata}`` +in Hydra YAML configs. + +Import this module before Hydra instantiation to register the transform. +""" + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms.mesh.base import MeshTransform +from physicsnemo.mesh import DomainMesh, Mesh + + +def _get_mesh_section(mesh: Mesh, section: str) -> TensorDict: + """Look up a Mesh data section by name.""" + if section == "point_data": + return mesh.point_data + if section == "cell_data": + return mesh.cell_data + if section == "global_data": + return mesh.global_data + raise ValueError(f"Unknown mesh section: {section!r}") + + +def _freestream_scales( + global_data: TensorDict, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Derive reference scales from freestream metadata (cast to float32 once). + + Returns ``(q_inf, p_inf, U_inf_mag, rho_inf, T_inf)`` where + ``q_inf = 0.5 * rho_inf * |U_inf|^2``. ``T_inf`` is ``None`` + when the metadata does not contain a freestream temperature (e.g. + incompressible datasets). + """ + U_inf = global_data["U_inf"].float() + rho_inf = global_data["rho_inf"].float() + p_inf = global_data["p_inf"].float() + U_inf_mag_sq = (U_inf * U_inf).sum() + q_inf = 0.5 * rho_inf * U_inf_mag_sq + U_inf_mag = U_inf_mag_sq.sqrt() + T_inf = global_data["T_inf"].float() if "T_inf" in global_data else None + return q_inf, p_inf, U_inf_mag, rho_inf, T_inf + + +_FIELD_TYPES = frozenset( + {"pressure", "stress", "velocity", "temperature", "density", "identity"} +) + +# Number of tensor channels each field type occupies. 
+_FIELD_CHANNELS = { + "pressure": 1, + "stress": 3, + "velocity": 3, + "temperature": 1, + "density": 1, + "identity": 1, +} + + +def _nondim_field( + val: torch.Tensor, + ftype: str, + q_inf: torch.Tensor, + p_inf: torch.Tensor, + U_inf_mag: torch.Tensor, + *, + rho_inf: torch.Tensor | None = None, + T_inf: torch.Tensor | None = None, +) -> torch.Tensor: + """Apply forward non-dimensionalization to a single field.""" + if ftype == "identity": + return val + if ftype == "pressure": + return (val - p_inf) / q_inf + if ftype == "stress": + return val / q_inf + if ftype == "velocity": + return val / U_inf_mag + if ftype == "temperature": + if T_inf is None: + raise ValueError("T_inf required for temperature non-dimensionalization") + return val / T_inf + if ftype == "density": + if rho_inf is None: + raise ValueError("rho_inf required for density non-dimensionalization") + return val / rho_inf + raise ValueError(f"Unknown field type: {ftype!r}") + + +def _redim_field( + val: torch.Tensor, + ftype: str, + q_inf: torch.Tensor, + p_inf: torch.Tensor, + U_inf_mag: torch.Tensor, + *, + rho_inf: torch.Tensor | None = None, + T_inf: torch.Tensor | None = None, +) -> torch.Tensor: + """Reverse non-dimensionalization for a single field.""" + if ftype == "identity": + return val + if ftype == "pressure": + return val * q_inf + p_inf + if ftype == "stress": + return val * q_inf + if ftype == "velocity": + return val * U_inf_mag + if ftype == "temperature": + if T_inf is None: + raise ValueError("T_inf required for temperature re-dimensionalization") + return val * T_inf + if ftype == "density": + if rho_inf is None: + raise ValueError("rho_inf required for density re-dimensionalization") + return val * rho_inf + raise ValueError(f"Unknown field type: {ftype!r}") + + +@register() +class NonDimensionalizeByMetadata(MeshTransform): + r"""Non-dimensionalize fields and geometry using freestream conditions from ``global_data``. 
+ + Expects ``U_inf``, ``rho_inf``, and ``p_inf`` to be present in + ``global_data`` (injected by the dataset builder). Computes + the dynamic pressure ``q_inf = 0.5 * rho_inf * |U_inf|^2`` and + applies standard non-dimensionalization formulas: + + - **pressure**: ``(p - p_inf) / q_inf`` (pressure coefficient Cp) + - **stress**: ``tau / q_inf`` (skin-friction coefficient Cf) + - **velocity**: ``U / |U_inf|`` + - **temperature**: ``T / T_inf`` (requires ``T_inf`` in ``global_data``) + - **density**: ``rho / rho_inf`` + - **identity**: pass-through (no scaling applied) + + If ``L_ref`` is present in ``global_data``, mesh points are divided + by it to produce non-dimensional coordinates: ``x* = x / L_ref``. + This normalises point clouds and cell centroids computed downstream. + + Parameters + ---------- + fields : dict[str, str] + Mapping of ``{field_name: field_type}`` where *field_type* is one + of ``"pressure"``, ``"stress"``, ``"velocity"``, ``"temperature"``, + ``"density"``, or ``"identity"``. + section : str + Mesh data section containing the fields (``"point_data"`` or + ``"cell_data"``). + + Example YAML:: + + - _target_: ${dp:NonDimensionalizeByMetadata} + fields: + pMeanTrim: pressure + wallShearStressMeanTrim: stress + section: point_data + """ + + def __init__( + self, + fields: dict[str, str], + section: str = "point_data", + ) -> None: + super().__init__() + for name, ftype in fields.items(): + if ftype not in _FIELD_TYPES: + raise ValueError( + f"Unknown field type {ftype!r} for {name!r}. " + f"Must be one of {sorted(_FIELD_TYPES)}." + ) + self._fields = fields + self._section = section + + def _transform_mesh( + self, + mesh: Mesh, + field_fn, + *, + inverse: bool, + scales: tuple | None = None, + skip_missing: bool = False, + ) -> Mesh: + """Shared implementation for forward and inverse mesh transforms. 
+ + Parameters + ---------- + scales : tuple or None + Pre-computed ``(q_inf, p_inf, U_inf_mag, rho_inf, T_inf, L_ref)`` + to use instead of deriving them from ``mesh.global_data``. + skip_missing : bool + If *True*, silently skip fields not present in the mesh section. + """ + if scales is not None: + q_inf, p_inf, U_inf_mag, rho_inf, T_inf, L_ref = scales + else: + gd = mesh.global_data + q_inf, p_inf, U_inf_mag, rho_inf, T_inf = _freestream_scales(gd) + L_ref = gd["L_ref"].float() if "L_ref" in gd else None + + td = _get_mesh_section(mesh, self._section) + new_td = td.clone() + + for field_name, ftype in self._fields.items(): + if skip_missing and field_name not in new_td.keys(): + continue + val = new_td[field_name].float() + new_td[field_name] = field_fn( + val, + ftype, + q_inf, + p_inf, + U_inf_mag, + rho_inf=rho_inf, + T_inf=T_inf, + ) + + points = mesh.points + if L_ref is not None: + points = points * L_ref if inverse else points / L_ref + + kwargs: dict = { + "points": points, + "cells": mesh.cells, + "point_data": mesh.point_data, + "cell_data": mesh.cell_data, + "global_data": mesh.global_data, + } + kwargs[self._section] = new_td + return Mesh(**kwargs) + + def __call__(self, mesh: Mesh) -> Mesh: + return self._transform_mesh(mesh, _nondim_field, inverse=False) + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Non-dimensionalize a DomainMesh using domain-level ``global_data``. + + Freestream scales are read once from ``domain.global_data`` + (where the metadata injector placed them) and applied to the + interior and every boundary mesh. Fields that are not present + on a particular sub-mesh (e.g. volume fields on a surface + boundary) are silently skipped. 
+ """ + gd = domain.global_data + q_inf, p_inf, U_inf_mag, rho_inf, T_inf = _freestream_scales(gd) + L_ref = gd["L_ref"].float() if "L_ref" in gd else None + scales = (q_inf, p_inf, U_inf_mag, rho_inf, T_inf, L_ref) + + return domain._map_meshes( + lambda m: self._transform_mesh( + m, + _nondim_field, + inverse=False, + scales=scales, + skip_missing=True, + ) + ) + + def inverse(self, mesh: Mesh) -> Mesh: + """Re-dimensionalize: reverse the non-dimensionalization. + + Uses the same ``global_data`` metadata (``U_inf``, ``rho_inf``, + ``p_inf``, and optionally ``L_ref``) to convert non-dimensional + fields and geometry back to physical units. + + Parameters + ---------- + mesh : Mesh + Mesh with non-dimensionalized fields and metadata in ``global_data``. + + Returns + ------- + Mesh + Mesh with re-dimensionalized fields. + """ + return self._transform_mesh(mesh, _redim_field, inverse=True) + + def inverse_tensor( + self, + tensor: torch.Tensor, + field_types: dict[str, str], + q_inf: torch.Tensor, + p_inf: torch.Tensor, + U_inf_mag: torch.Tensor, + *, + rho_inf: torch.Tensor | None = None, + T_inf: torch.Tensor | None = None, + ) -> torch.Tensor: + """Re-dimensionalize a concatenated output tensor. + + Operates on model output tensors (shape ``(*, C)``) where channels + are ordered according to *field_types*. This is useful at inference + time when you have a raw model prediction rather than a Mesh. + + Parameters + ---------- + tensor : Tensor + Shape ``(*, C)`` with channels ordered by *field_types*. + field_types : dict[str, str] + Ordered mapping of ``{field_name: nondim_type}`` where + *nondim_type* is one of ``"pressure"``, ``"stress"``, + ``"velocity"``, ``"temperature"``, ``"density"``, or + ``"identity"``. + Uses the model's output field names (e.g. after renaming), + not the original mesh field names. + q_inf, p_inf, U_inf_mag : Tensor + Reference quantities (scalars or broadcastable). + rho_inf : Tensor or None + Freestream density. 
Required when *field_types* contains + ``"density"``. + T_inf : Tensor or None + Freestream temperature. Required when *field_types* contains + ``"temperature"``. + + Returns + ------- + Tensor + Same shape, with each field's channels re-dimensionalized. + """ + out = tensor.clone() + idx = 0 + for name, ftype in field_types.items(): + n = _FIELD_CHANNELS[ftype] + out[..., idx : idx + n] = _redim_field( + out[..., idx : idx + n], + ftype, + q_inf, + p_inf, + U_inf_mag, + rho_inf=rho_inf, + T_inf=T_inf, + ) + idx += n + return out + + def extra_repr(self) -> str: + return f"fields={self._fields}, section={self._section}" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/sdf.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/sdf.py new file mode 100644 index 0000000000..c4f21d2614 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/sdf.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SDF (Signed Distance Field) pipeline transforms for volume meshes. + +Provides a transform that computes SDF + normals from a boundary surface +onto interior volume points, and a cleanup transform to drop temporary +boundaries before TensorDict conversion. 
+ +These work with ``DomainMeshReader``'s ``extra_boundaries`` parameter, +which loads a sibling STL mesh at full resolution alongside the domain +mesh. The SDF transform reads the injected boundary, computes the +signed distance field, and writes results into ``interior.point_data``. + +Recipe-local module registered into the global datapipe component +registry so components can be referenced via ``${dp:...}`` in Hydra +YAML configs. + +Import this module before Hydra instantiation to register the components. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms.mesh.base import MeshTransform +from physicsnemo.mesh import DomainMesh, Mesh +from physicsnemo.nn.functional import signed_distance_field + + +@register() +class ComputeSDFFromBoundary(MeshTransform): + r"""Compute SDF and optionally normals from a boundary surface onto interior points. + + Reads the surface mesh from ``domain.boundaries[boundary_name]`` and + evaluates the signed distance field at every interior point using + :func:`physicsnemo.nn.functional.signed_distance_field` (Warp-backed, + GPU-accelerated). + + The computed SDF is stored as a scalar field ``(N, 1)`` in + ``interior.point_data[sdf_field]``. If ``normals_field`` is set, + approximate surface normals ``(N, 3)`` are also stored, computed as + the normalized direction from each query point to its closest point + on the surface (with center-of-mass fallback for on-surface points). + + Parameters + ---------- + boundary_name : str + Key of the boundary mesh to use as the SDF surface. + sdf_field : str + Name for the SDF field in ``interior.point_data``. + normals_field : str or None + Optional name for the normals field. ``None`` to skip. + use_winding_number : bool + Whether to use winding-number sign computation. Required for + non-watertight meshes; slightly slower. 
+ """ + + def __init__( + self, + boundary_name: str = "stl_geometry", + sdf_field: str = "sdf", + normals_field: Optional[str] = "sdf_normals", + *, + use_winding_number: bool = True, + ) -> None: + super().__init__() + self.boundary_name = boundary_name + self.sdf_field = sdf_field + self.normals_field = normals_field + self.use_winding_number = use_winding_number + + def __call__(self, mesh: Mesh) -> Mesh: + # Single-mesh path is not meaningful for SDF (we need a separate + # surface mesh). Pass through unchanged. + return mesh + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Compute SDF from the boundary surface onto interior points. + + Parameters + ---------- + domain : DomainMesh + Must contain a boundary named ``self.boundary_name`` with + triangle cells. + + Returns + ------- + DomainMesh + Domain with SDF (and optionally normals) injected into + ``interior.point_data``. + """ + if self.boundary_name not in domain.boundaries: + raise KeyError( + f"Boundary {self.boundary_name!r} not found. " + f"Available: {domain.boundary_names}" + ) + + surface = domain.boundaries[self.boundary_name] + vertices = surface.points.float() + faces = surface.cells + + if faces is None or faces.numel() == 0: + raise ValueError( + f"Boundary {self.boundary_name!r} has no cell connectivity " + f"(required for SDF computation)" + ) + + query_points = domain.interior.points.float() + + sdf_values, closest_points = signed_distance_field( + vertices, + faces, + query_points, + use_sign_winding_number=self.use_winding_number, + ) + + # Build updated point_data with SDF (N, 1) + new_pd = domain.interior.point_data.clone() + new_pd[self.sdf_field] = sdf_values.unsqueeze(-1) + + # Optionally compute approximate normals from closest-point direction + if self.normals_field is not None: + normals = query_points - closest_points + + # Fallback for points on the surface (zero distance): + # use direction from center of mass instead. 
+ dist = torch.norm(normals, dim=-1) + on_surface = dist < 1e-6 + if on_surface.any(): + com = vertices.mean(dim=0, keepdim=True) + normals[on_surface] = query_points[on_surface] - com + + # Normalize to unit vectors + norm = torch.norm(normals, dim=-1, keepdim=True).clamp(min=1e-8) + normals = normals / norm + new_pd[self.normals_field] = normals + + new_interior = Mesh( + points=domain.interior.points, + cells=domain.interior.cells, + point_data=new_pd, + cell_data=domain.interior.cell_data, + global_data=domain.interior.global_data, + ) + + return DomainMesh( + interior=new_interior, + boundaries=domain.boundaries, + global_data=domain.global_data, + ) + + def extra_repr(self) -> str: + parts = [ + f"boundary={self.boundary_name!r}", + f"sdf_field={self.sdf_field!r}", + ] + if self.normals_field: + parts.append(f"normals_field={self.normals_field!r}") + parts.append(f"winding_number={self.use_winding_number}") + return ", ".join(parts) + + +@register() +class DropBoundary(MeshTransform): + r"""Remove one or more boundaries from a :class:`DomainMesh`. + + Useful for stripping temporary data (e.g. a full-resolution STL + boundary injected for SDF computation) before downstream transforms + like ``MeshToTensorDict`` that would otherwise serialize the large + surface into the output TensorDict. + + Parameters + ---------- + names : list[str] + Boundary names to remove. + """ + + def __init__(self, names: list[str]) -> None: + super().__init__() + self.names = set(names) + + def __call__(self, mesh: Mesh) -> Mesh: + # Single-mesh path: nothing to drop. + return mesh + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Remove the named boundaries from the domain. + + Parameters + ---------- + domain : DomainMesh + Input domain mesh. + + Returns + ------- + DomainMesh + Domain mesh without the dropped boundaries. 
+ """ + filtered = { + name: bnd + for name, bnd in domain.boundaries.items() + if name not in self.names + } + return DomainMesh( + interior=domain.interior, + boundaries=filtered, + global_data=domain.global_data, + ) + + def extra_repr(self) -> str: + return f"names={sorted(self.names)}" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py new file mode 100644 index 0000000000..57e9107d21 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py @@ -0,0 +1,1128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified External Aerodynamics Training Script + +Trains a point-cloud model (GeoTransolver, Transolver, etc.) on surface +or volume fields using the mesh datapipe infrastructure. 
+ +Usage:: + + # Single-GPU + python src/train.py + + # Multi-GPU with torchrun + torchrun --nproc_per_node=N src/train.py + + # I/O benchmark: iterate dataloaders without model logic + python src/train.py benchmark_io=true profile=true + python src/train.py benchmark_io=true +training.benchmark_max_steps=20 +""" + +import os +import sys +import time +from contextlib import nullcontext +from pathlib import Path + +import hydra +import omegaconf +from omegaconf import DictConfig, OmegaConf + +import torch +from torch.amp import autocast, GradScaler +import mlflow + +from tabulate import tabulate + +from physicsnemo.utils import load_checkpoint, save_checkpoint +from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils.profiling import profile, Profiler + +from physicsnemo import datapipes # noqa: F401 - registers ${dp:...} resolver +from physicsnemo.datapipes import DataLoader + +from datasets import ( + build_dataset, + load_dataset_config, + load_manifest, + resolve_manifest_indices, + ManifestSampler, +) +from collate import build_collate_fn +from metrics import MetricCalculator +from loss import LossCalculator +from utils import build_muon_optimizer, set_seed + +from physicsnemo.core.version_check import OptionalImport + +te = OptionalImport("transformer_engine.pytorch") +te_recipe = OptionalImport("transformer_engine.common.recipe") +TE_AVAILABLE = te.available + + +def _flatten_config(d: dict, parent: str = "", sep: str = ".") -> dict[str, str]: + """Recursively flatten a nested dict into dot-separated key/value pairs. + + MLflow has a 500-param limit; values are stringified and truncated + to 500 chars to stay within MLflow's constraints. 
+ """ + items: dict[str, str] = {} + for k, v in d.items(): + key = f"{parent}{sep}{k}" if parent else k + if isinstance(v, dict): + items.update(_flatten_config(v, key, sep)) + elif isinstance(v, (list, tuple)): + items[key] = str(v)[:500] + else: + items[key] = str(v)[:500] + return items + + +def get_autocast_context(precision: str): + """Return an autocast context manager for the given precision. + + Parameters + ---------- + precision : str + One of ``"float16"``, ``"bfloat16"``, ``"float8"``, or ``"float32"``. + For ``"float8"``, Transformer Engine must be available. + + Returns + ------- + contextlib.AbstractContextManager + An autocast context manager for the requested precision, or a + no-op ``nullcontext`` when no casting is needed. + """ + if precision == "float16": + return autocast("cuda", dtype=torch.float16) + elif precision == "bfloat16": + return autocast("cuda", dtype=torch.bfloat16) + elif precision == "float8" and TE_AVAILABLE: + fp8_format = te_recipe.Format.HYBRID + fp8_recipe = te_recipe.DelayedScaling( + fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max" + ) + return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) + else: + return nullcontext() + + +def _recursive_to(obj, *args, **kwargs): + """Apply ``.to()`` recursively through nested dicts/lists of tensors.""" + if isinstance(obj, torch.Tensor): + return obj.to(*args, **kwargs) + if isinstance(obj, dict): + return {k: _recursive_to(v, *args, **kwargs) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(_recursive_to(v, *args, **kwargs) for v in obj) + return obj + + +def forward_pass( + batch: dict[str, torch.Tensor], + model: torch.nn.Module, + precision: str, + loss_calculator: LossCalculator, + metric_calculator: MetricCalculator, + *, + broadcast_global: bool = False, +) -> tuple[torch.Tensor, dict[str, float], tuple]: + """Run forward pass, compute loss and metrics. 
+ + Parameters + ---------- + batch : dict + Model-ready batch produced by the collate function. Must contain + a ``"fields"`` key holding the prediction targets (popped before + the forward call). Values may be tensors or nested dicts of + tensors (e.g. DoMINO's ``data_dict``). + model : torch.nn.Module + Point-cloud model whose ``forward`` accepts the remaining batch + keys as keyword arguments. + precision : str + One of "float32", "float16", "bfloat16", "float8". + loss_calculator : LossCalculator + metric_calculator : MetricCalculator + broadcast_global : bool, default False + When ``True``, any tensor with spatial dimension 1 (e.g. global + features shaped ``(B, 1, C)``) is expanded to match the largest + spatial dimension in the batch. Required for Transolver, whose + ``forward`` concatenates ``[embedding, fx]`` along the last dim + and therefore needs matching spatial sizes. + + Returns + ------- + loss, metrics_dict, (outputs, targets) + """ + targets = batch.pop("fields") + + if broadcast_global: + max_n = max( + v.shape[1] + for v in batch.values() + if isinstance(v, torch.Tensor) and v.ndim >= 3 + ) + batch = { + k: v.expand(-1, max_n, -1) + if isinstance(v, torch.Tensor) and v.ndim >= 3 and v.shape[1] == 1 + else v + for k, v in batch.items() + } + + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16} + dtype = dtype_map.get(precision) + if dtype is not None: + batch = {k: _recursive_to(v, dtype) for k, v in batch.items()} + + with get_autocast_context(precision): + outputs = model(**batch) + + # Models like DoMINO return (vol_output, surf_output); extract the + # non-None element for single-mode training. 
+ if isinstance(outputs, tuple): + outputs = next(o for o in outputs if o is not None) + + loss, loss_dict = loss_calculator(outputs, targets) + + metrics = {k: v.item() for k, v in loss_dict.items()} + with torch.no_grad(): + metrics.update(metric_calculator(outputs, targets)) + + return loss, metrics, (outputs, targets) + + +@profile +def train_epoch( + dataloader, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler, + loss_calculator: LossCalculator, + metric_calculator: MetricCalculator, + logger, + epoch: int, + cfg: DictConfig, + dist_manager: DistributedManager, + scaler: GradScaler | None = None, + broadcast_global: bool = False, + log_every_n_steps: int = 10, +) -> tuple[float, dict[str, float]]: + """Run one training epoch over the dataloader. + + Iterates through all batches, computes forward pass, back-propagates + gradients, and logs per-step and per-epoch statistics to MLflow. + + Parameters + ---------- + dataloader : DataLoader + Training dataloader yielding ``dict[str, Tensor]`` batches. + model : torch.nn.Module + The model to train (already on ``dist_manager.device``). + optimizer : torch.optim.Optimizer + Optimizer instance. + scheduler : torch.optim.lr_scheduler._LRScheduler + Learning-rate scheduler. Updated per step or per epoch depending + on ``cfg.training.scheduler_update_mode``. + loss_calculator : LossCalculator + Computes the training loss from model outputs and targets. + metric_calculator : MetricCalculator + Computes evaluation metrics (L1, L2, MAE, etc.). + logger : RankZeroLoggingWrapper + Logger for console output. + epoch : int + Current epoch index (0-based). + cfg : DictConfig + Full Hydra config; uses ``cfg.profile`` and ``cfg.training``. + dist_manager : DistributedManager + Distributed training manager. + scaler : torch.amp.GradScaler or None, optional + Gradient scaler for mixed-precision (float16) training. 
+ broadcast_global : bool, default False + Expand global (B,1,C) tensors to match the spatial dimension + of other batch tensors before forwarding. + log_every_n_steps : int, default 10 + How often to log per-step metrics to MLflow. + + Returns + ------- + avg_loss : float + Mean training loss over all batches. + avg_metrics : dict[str, float] + Mean per-metric values over all batches. + """ + model.train() + total_loss = 0.0 + total_metrics: dict[str, float] = {} + precision = getattr(cfg, "precision", "float32") + n_batches = 0 + num_steps = len(dataloader) + epoch_t0 = time.perf_counter() + + step_t0 = time.perf_counter() + for i, batch in enumerate(dataloader): + batch = {k: _recursive_to(v, dist_manager.device) for k, v in batch.items()} + + loss, metrics, (outputs, targets) = forward_pass( + batch, + model, + precision, + loss_calculator, + metric_calculator, + broadcast_global=broadcast_global, + ) + + optimizer.zero_grad() + if precision == "float16" and scaler is not None: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + + if cfg.training.get("scheduler_update_mode", "epoch") == "step": + scheduler.step() + + this_loss = loss.detach().item() + total_loss += this_loss + n_batches += 1 + + for k, v in metrics.items(): + total_metrics[k] = total_metrics.get(k, 0.0) + ( + v if isinstance(v, float) else v.item() + ) + + step_dt = time.perf_counter() - step_t0 + + mem_gb = ( + torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0 + ) + logger.info( + f"Epoch {epoch} [{i + 1}/{num_steps}] " + f"Loss: {this_loss:.6f} " + f"Step: {step_dt:.3f}s " + f"Mem: {mem_gb:.2f}GB" + ) + + # Per-step MLflow logging at configured frequency + global_step = epoch * num_steps + i + if dist_manager.rank == 0 and (i + 1) % log_every_n_steps == 0: + step_metrics = { + "step/train_loss": this_loss, + "step/mem_gb": mem_gb, + "step/step_time_s": step_dt, + } + 
step_metrics.update({f"step/train_{k}": v for k, v in metrics.items()}) + mlflow.log_metrics(step_metrics, step=global_step) + + if cfg.profile and i >= 10: + break + step_t0 = time.perf_counter() + + epoch_dt = time.perf_counter() - epoch_t0 + avg_loss = total_loss / max(n_batches, 1) + avg_metrics = {k: v / max(n_batches, 1) for k, v in total_metrics.items()} + + logger.info( + f"Epoch {epoch} train done in {epoch_dt:.1f}s " + f"({n_batches} steps, {epoch_dt / max(n_batches, 1):.3f}s/step avg)" + ) + + if dist_manager.rank == 0: + epoch_metrics = {"train/loss": avg_loss} + epoch_metrics.update({f"train/{k}": v for k, v in avg_metrics.items()}) + mlflow.log_metrics(epoch_metrics, step=epoch) + + return avg_loss, avg_metrics + + +@profile +def val_epoch( + dataloader, + model: torch.nn.Module, + loss_calculator: LossCalculator, + metric_calculator: MetricCalculator, + logger, + epoch: int, + cfg: DictConfig, + dist_manager: DistributedManager, + broadcast_global: bool = False, +) -> tuple[float, dict[str, float]]: + """Run one validation epoch. + + Parameters + ---------- + dataloader : DataLoader + Validation dataloader yielding ``dict[str, Tensor]`` batches. + model : torch.nn.Module + The model to evaluate (already on ``dist_manager.device``). + loss_calculator : LossCalculator + Computes the validation loss. + metric_calculator : MetricCalculator + Computes normalised-space metrics. + logger : RankZeroLoggingWrapper + Logger for console output. + epoch : int + Current epoch index (0-based). + cfg : DictConfig + Full Hydra config; uses ``cfg.profile`` and ``cfg.precision``. + dist_manager : DistributedManager + Distributed training manager. + broadcast_global : bool, default False + Expand global (B,1,C) tensors to match the spatial dimension + of other batch tensors before forwarding. + + Returns + ------- + avg_loss : float + Mean validation loss over all batches. + avg_metrics : dict[str, float] + Mean normalised-space metrics. 
+ """ + model.eval() + total_loss = 0.0 + total_metrics: dict[str, float] = {} + precision = getattr(cfg, "precision", "float32") + n_batches = 0 + num_steps = len(dataloader) + epoch_t0 = time.perf_counter() + + with torch.no_grad(): + step_t0 = time.perf_counter() + for i, batch in enumerate(dataloader): + batch = {k: _recursive_to(v, dist_manager.device) for k, v in batch.items()} + + loss, metrics, _ = forward_pass( + batch, + model, + precision, + loss_calculator, + metric_calculator, + broadcast_global=broadcast_global, + ) + + step_dt = time.perf_counter() - step_t0 + total_loss += loss.item() + n_batches += 1 + for k, v in metrics.items(): + total_metrics[k] = total_metrics.get(k, 0.0) + ( + v if isinstance(v, float) else v.item() + ) + + logger.info( + f"Val Epoch {epoch} [{i + 1}/{num_steps}] " + f"Loss: {loss.item():.6f} " + f"Step: {step_dt:.3f}s" + ) + + if cfg.profile and i >= 10: + break + step_t0 = time.perf_counter() + + epoch_dt = time.perf_counter() - epoch_t0 + avg_loss = total_loss / max(n_batches, 1) + avg_metrics = {k: v / max(n_batches, 1) for k, v in total_metrics.items()} + + logger.info( + f"Epoch {epoch} val done in {epoch_dt:.1f}s " + f"({n_batches} steps, {epoch_dt / max(n_batches, 1):.3f}s/step avg)" + ) + + if dist_manager.rank == 0: + val_metrics_log = {"val/loss": avg_loss} + val_metrics_log.update({f"val/{k}": v for k, v in avg_metrics.items()}) + mlflow.log_metrics(val_metrics_log, step=epoch) + + return avg_loss, avg_metrics + + +@profile +def benchmark_io_epoch( + dataloader, + label: str, + logger, + max_steps: int | None = None, +) -> None: + """Iterate a dataloader without any model logic and report I/O timing. + + Parameters + ---------- + dataloader : DataLoader + Dataloader to benchmark. + label : str + Human-readable label for logging (e.g. ``"train"`` or ``"val"``). + logger : RankZeroLoggingWrapper + Logger for console output. + max_steps : int or None, optional + Stop after this many batches. 
``None`` means exhaust the loader. + """ + import statistics + + num_steps = len(dataloader) + times: list[float] = [] + + step_t0 = time.perf_counter() + for i, batch in enumerate(dataloader): + dt = time.perf_counter() - step_t0 + times.append(dt) + + mem_gb = ( + torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0 + ) + shapes = " ".join(f"{k}:{tuple(v.shape)}" for k, v in batch.items()) + logger.info( + f" [{label}] [{i + 1}/{num_steps}] " + f"dt={dt:.4f}s Mem={mem_gb:.2f}GB {shapes}" + ) + for k, v in batch.items(): + v_flat = v.float() + logger.info( + f" {k:20s} " + f"min={v_flat.min().item(): .6e} " + f"mean={v_flat.mean().item(): .6e} " + f"std={v_flat.std().item(): .6e} " + f"max={v_flat.max().item(): .6e}" + ) + + if max_steps is not None and i + 1 >= max_steps: + break + step_t0 = time.perf_counter() + + if not times: + logger.info(f" [{label}] empty dataloader") + return + + total = sum(times) + mean = statistics.mean(times) + med = statistics.median(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + p95 = sorted(times)[int(len(times) * 0.95)] if len(times) > 1 else times[0] + + logger.info( + f" [{label}] {len(times)} batches in {total:.2f}s " + f"mean={mean:.4f}s median={med:.4f}s std={std:.4f}s p95={p95:.4f}s " + f"throughput={len(times) / total:.2f} batches/sec" + ) + + +def _extract_pipeline_transforms(datasets: list) -> tuple: + """Find NormalizeMeshFields and NonDimensionalizeByMetadata in transform chains. + + Returns (normalizer, nondim) instances from the first dataset that has them, + or (None, None) if not found. 
+ """ + from physicsnemo.datapipes.transforms.mesh import NormalizeMeshFields + from nondim import NonDimensionalizeByMetadata + + normalizer = None + nondim = None + for ds in datasets: + for t in getattr(ds, "transforms", []): + if isinstance(t, NormalizeMeshFields) and normalizer is None: + normalizer = t + if isinstance(t, NonDimensionalizeByMetadata) and nondim is None: + nondim = t + return normalizer, nondim + + +def build_dataloaders(cfg: DictConfig): + """Build train and val dataloaders from dataset configs. + + Supports two split strategies: + + **Directory-based** (existing): separate ``train_datadir`` and + ``val_datadir`` in the dataset YAML. Each split gets its own reader + and dataset. + + **Manifest-based** (new): a single ``datadir`` in the dataset YAML + with ``train_manifest`` and ``val_manifest`` in the training config's + ``data.`` block. One reader/dataset covers the full directory; + ``ManifestSampler`` restricts each loader to the correct subset of + indices. + """ + recipe_root = Path(__file__).resolve().parent.parent + batch_size = cfg.training.get("batch_size", 1) + sampling_resolution = cfg.dataset.get("sampling_resolution", None) + augment = cfg.get("augment", False) + dist_manager = DistributedManager() + use_distributed = dist_manager.world_size > 1 + collate_fn = build_collate_fn( + cfg.get("data_mapping", "geotransolver_automotive_surface") + ) + + # DataLoader / MeshDataset performance tuning from cfg.dataloader + dl_cfg = cfg.get("dataloader", {}) + prefetch_factor = dl_cfg.get("prefetch_factor", 2) + num_streams = dl_cfg.get("num_streams", 4) + use_streams = dl_cfg.get("use_streams", False) + num_workers = dl_cfg.get("num_workers", 1) + pin_memory = dl_cfg.get("pin_memory", False) + device = "cuda" if torch.cuda.is_available() else "cpu" + sampler_seed = cfg.training.get("seed", 0) or 0 + + train_datasets = [] + val_datasets = [] + # When using manifest-based splits, we collect indices per dataset + # and build samplers instead 
of separate datasets. + manifest_train_indices: list[int] | None = None + manifest_val_indices: list[int] | None = None + using_manifests = False + first_metadata = None + + for ds_key in cfg.data: + ds_cfg_block = cfg.data[ds_key] + config_path = recipe_root / ds_cfg_block.config + if not config_path.exists(): + continue + train_dir = ds_cfg_block.get("train_dir", "") + if train_dir and not Path(train_dir).exists(): + continue + ds_yaml = load_dataset_config(config_path) + if sampling_resolution is not None: + ds_yaml = OmegaConf.merge( + ds_yaml, {"sampling_resolution": sampling_resolution} + ) + if first_metadata is None: + first_metadata = OmegaConf.to_container( + OmegaConf.select(ds_yaml, "metadata", default=OmegaConf.create({})), + resolve=True, + ) + + # --- Manifest-based splits --- + # Two config styles are supported: + # + # Style A (separate files): + # train_manifest: /path/to/train_runs.txt + # val_manifest: /path/to/val_runs.txt + # + # Style B (single dict manifest with split keys): + # manifest: /path/to/manifest.json + # train_split: single_aoa_4_train + # val_split: single_aoa_4_val + # + # Both styles accept an optional ``datadir`` to override + # ``train_datadir`` in the dataset YAML with the root directory + # containing all runs. + train_manifest = ds_cfg_block.get("train_manifest", None) + val_manifest = ds_cfg_block.get("val_manifest", None) + manifest = ds_cfg_block.get("manifest", None) + train_split = ds_cfg_block.get("train_split", None) + val_split = ds_cfg_block.get("val_split", None) + + has_manifest = train_manifest is not None or ( + manifest is not None and train_split is not None + ) + + if has_manifest: + using_manifests = True + # When using manifests, the reader must see ALL runs under one + # root. The config block can provide ``datadir`` to override the + # dataset YAML's ``train_datadir`` with the parent directory that + # contains every run (train + val). 
+ datadir = ds_cfg_block.get("datadir", None) + if datadir: + ds_yaml = OmegaConf.merge(ds_yaml, {"train_datadir": datadir}) + dataset = build_dataset( + ds_yaml, + augment=augment, + device=device, + num_workers=num_workers, + pin_memory=pin_memory, + ) + train_datasets.append(dataset) + + # Resolve train indices + if train_manifest is not None: + train_entries = load_manifest(train_manifest) + else: + train_entries = load_manifest(manifest, split=train_split) + manifest_train_indices = resolve_manifest_indices( + dataset.reader, train_entries + ) + + # Resolve val indices + if val_manifest is not None: + val_entries = load_manifest(val_manifest) + manifest_val_indices = resolve_manifest_indices( + dataset.reader, val_entries + ) + elif val_split is not None: + val_entries = load_manifest(manifest, split=val_split) + manifest_val_indices = resolve_manifest_indices( + dataset.reader, val_entries + ) + continue + + # --- Directory-based splits (existing path) --- + train_datasets.append( + build_dataset( + ds_yaml, + augment=augment, + device=device, + num_workers=num_workers, + pin_memory=pin_memory, + ) + ) + + val_datadir = OmegaConf.select(ds_yaml, "val_datadir", default=None) + if val_datadir and Path(val_datadir).exists(): + val_yaml = OmegaConf.merge(ds_yaml, {"train_datadir": val_datadir}) + val_datasets.append( + build_dataset( + val_yaml, + augment=False, + device=device, + num_workers=num_workers, + pin_memory=pin_memory, + ) + ) + + if not train_datasets: + raise RuntimeError("No valid datasets found. 
Check data paths in config.") + + normalizer, nondim_transform = _extract_pipeline_transforms(train_datasets) + + if len(train_datasets) == 1: + train_dataset = train_datasets[0] + else: + from physicsnemo.datapipes import MultiDataset + + train_dataset = MultiDataset(*train_datasets, output_strict=False) + + if using_manifests: + # Manifest path: single dataset, split via samplers + rank = dist_manager.rank if use_distributed else 0 + world_size = dist_manager.world_size if use_distributed else 1 + + train_sampler = ManifestSampler( + manifest_train_indices, + shuffle=True, + seed=sampler_seed, + rank=rank, + world_size=world_size, + drop_last=True, + ) + if manifest_val_indices is not None: + val_sampler = ManifestSampler( + manifest_val_indices, + shuffle=False, + seed=sampler_seed, + rank=rank, + world_size=world_size, + drop_last=False, + ) + else: + val_sampler = train_sampler + val_dataset = train_dataset + else: + # Directory-based path: separate datasets per split + if val_datasets: + if len(val_datasets) == 1: + val_dataset = val_datasets[0] + else: + from physicsnemo.datapipes import MultiDataset + + val_dataset = MultiDataset(*val_datasets, output_strict=False) + else: + val_dataset = train_dataset + + train_sampler = None + val_sampler = None + if use_distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + shuffle=True, + drop_last=True, + seed=sampler_seed, + ) + val_sampler = torch.utils.data.distributed.DistributedSampler( + val_dataset, + shuffle=False, + drop_last=False, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + collate_fn=collate_fn, + drop_last=True, + prefetch_factor=prefetch_factor, + num_streams=num_streams, + use_streams=use_streams, + seed=sampler_seed, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + sampler=val_sampler, + collate_fn=collate_fn, + 
drop_last=False, + prefetch_factor=prefetch_factor, + num_streams=num_streams, + use_streams=use_streams, + seed=sampler_seed, + ) + + return train_loader, val_loader, normalizer, nondim_transform, first_metadata or {} + + +_NONDIM_TYPE_MAP = {"scalar": "pressure", "vector": "stress"} + + +def _to_physical( + tensor: torch.Tensor, + target_config: dict[str, str], + normalizer, + nondim_transform, + metadata: dict, + nondim_type_overrides: dict[str, str] | None = None, +) -> torch.Tensor: + """Convert a model-space tensor (normalized + non-dim) back to physical units. + + Chains two inverse operations using the existing transform instances: + 1. ``NormalizeMeshFields.inverse_tensor`` -- undo z-score normalization + 2. ``NonDimensionalizeByMetadata.inverse_tensor`` -- undo non-dimensionalization + + Parameters + ---------- + nondim_type_overrides : dict or None + Optional per-field mapping of ``{field_name: nondim_type}`` (e.g. + ``{"temperature": "temperature", "density": "density"}``). When + provided, overrides the default ``_NONDIM_TYPE_MAP`` lookup for + fields that don't follow the simple scalar=pressure / vector=stress + convention. 
+ """ + if not metadata: + return tensor + + out = tensor + device, dtype = tensor.device, tensor.dtype + + # Step 1: undo z-score normalization + if normalizer is not None: + out = normalizer.inverse_tensor(out, target_config) + + # Step 2: undo non-dimensionalization + if nondim_transform is not None: + overrides = nondim_type_overrides or {} + nondim_fields = { + name: overrides.get(name, _NONDIM_TYPE_MAP.get(ftype, ftype)) + for name, ftype in target_config.items() + } + U_inf = torch.tensor(metadata["U_inf"], dtype=dtype, device=device) + rho_inf = torch.tensor(metadata["rho_inf"], dtype=dtype, device=device) + p_inf = torch.tensor(metadata["p_inf"], dtype=dtype, device=device) + q_inf = 0.5 * rho_inf * (U_inf * U_inf).sum() + U_inf_mag = (U_inf * U_inf).sum().sqrt() + + T_inf = None + if "T_inf" in metadata: + T_inf = torch.tensor(metadata["T_inf"], dtype=dtype, device=device) + + out = nondim_transform.inverse_tensor( + out, + nondim_fields, + q_inf, + p_inf, + U_inf_mag, + rho_inf=rho_inf, + T_inf=T_inf, + ) + + return out + + +@profile +def main(cfg: DictConfig): + """Run the full training loop, or I/O-only benchmark when ``benchmark_io=true``. + + Orchestrates the complete training workflow: + + 1. Initialise distributed training and MLflow experiment tracking. + 2. Build train/val dataloaders and extract pipeline transforms. + 3. If ``cfg.benchmark_io`` is true, iterate dataloaders to measure + I/O throughput and return early (no model, no optimizer). + 4. Otherwise, instantiate the model, optimizer, and run the normal + train/val epoch loop with checkpointing. + + Parameters + ---------- + cfg : DictConfig + Hydra config containing ``model``, ``training``, ``dataset``, + ``data``, ``output_dir``, ``run_id``, ``precision``, ``compile``, + ``profile``, ``benchmark_io``, ``mlflow``, and related keys. 
+ """ + DistributedManager.initialize() + dist_manager = DistributedManager() + logger = RankZeroLoggingWrapper(PythonLogger(name="training"), dist_manager) + + seed = cfg.training.get("seed", None) + set_seed(seed, rank=dist_manager.rank) + logger.info(f"Random seed: {seed} (rank offset: {dist_manager.rank})") + + checkpoint_dir = getattr(cfg, "checkpoint_dir", None) or cfg.output_dir + + # -- MLflow setup (rank 0 only) --------------------------------------------- + mlflow_cfg = cfg.get("mlflow", {}) + log_every_n_steps = mlflow_cfg.get("log_every_n_steps", 10) if mlflow_cfg else 10 + if dist_manager.rank == 0: + os.makedirs(cfg.output_dir, exist_ok=True) + os.makedirs(checkpoint_dir, exist_ok=True) + + tracking_uri = mlflow_cfg.get("tracking_uri") if mlflow_cfg else None + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + # When tracking_uri is null/omitted, MLflow defaults to local ./mlruns + experiment_name = ( + mlflow_cfg.get("experiment_name", "unified_external_aero") + if mlflow_cfg + else "unified_external_aero" + ) + mlflow.set_experiment(experiment_name) + run_name = (mlflow_cfg.get("run_name") if mlflow_cfg else None) or cfg.run_id + mlflow.start_run(run_name=run_name) + for k, v in (mlflow_cfg.get("tags") or {}).items(): + mlflow.set_tag(k, v) + + logger.info(f"Config:\n{omegaconf.OmegaConf.to_yaml(cfg, resolve=True)}") + + train_loader, val_loader, normalizer, _, ds_metadata = build_dataloaders(cfg) + logger.info(f"Train samples: {len(train_loader.sampler)}") + logger.info(f"Val samples: {len(val_loader.sampler)}") + + # -- Log dataset metadata to MLflow (rank 0) -------------------------------- + recipe_root = Path(__file__).resolve().parent.parent + if dist_manager.rank == 0: + mlflow.log_params( + { + "train_samples": len(train_loader.dataset), + "val_samples": len(val_loader.dataset), + } + ) + for ds_key, ds_block in cfg.data.items(): + mlflow.set_tag(f"dataset/{ds_key}/config", ds_block.config) + ds_config_path = recipe_root / 
ds_block.config + if ds_config_path.exists(): + mlflow.log_artifact( + str(ds_config_path), artifact_path="dataset_configs" + ) + if ds_metadata: + for mk, mv in ds_metadata.items(): + mlflow.log_param(f"metadata.{mk}", str(mv)[:500]) + + # -- I/O benchmark mode: iterate dataloaders, skip model entirely ----------- + if cfg.get("benchmark_io", False): + num_epochs = cfg.training.num_epochs + max_steps = cfg.training.get("benchmark_max_steps", None) + logger.info( + f"benchmark_io=True — benchmarking dataloader I/O only " + f"({num_epochs} epoch(s), max_steps={max_steps})" + ) + with torch.no_grad(), Profiler(): + for epoch in range(num_epochs): + logger.info(f"--- Epoch {epoch + 1}/{num_epochs} ---") + train_loader.set_epoch(epoch) + benchmark_io_epoch(train_loader, "train", logger, max_steps=max_steps) + benchmark_io_epoch(val_loader, "val", logger, max_steps=max_steps) + logger.info("benchmark_io complete!") + if dist_manager.rank == 0: + mlflow.end_run() + return + + # -- Normal training path --------------------------------------------------- + model = hydra.utils.instantiate(cfg.model, _convert_="partial") + logger.info(f"Model: {model.__class__.__name__}") + num_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Parameters: {num_params:,}") + + model.to(dist_manager.device) + + if dist_manager.world_size > 1: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_manager.local_rank], + output_device=dist_manager.device, + ) + + if normalizer is not None: + logger.info( + f"Normalization: {', '.join(f'{k}({v["type"]})' for k, v in normalizer.stats.items())}" + ) + + optimizer = build_muon_optimizer(model, cfg, compile_optimizer=cfg.compile) + logger.info(f"Optimizer: {optimizer}") + scheduler = hydra.utils.instantiate(cfg.training.scheduler, optimizer=optimizer) + + precision = cfg.precision + scaler = GradScaler() if precision == "float16" else None + + # -- Log full config + model params to MLflow (rank 0) 
---------------------- + if dist_manager.rank == 0: + flat_cfg = _flatten_config( + OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False) + ) + mlflow.log_params(flat_cfg) + mlflow.log_param("num_parameters", num_params) + mlflow.set_tag("model", model.__class__.__name__) + + # Save the full resolved config as an artifact + resolved_yaml = omegaconf.OmegaConf.to_yaml(cfg, resolve=True) + config_artifact_path = os.path.join( + cfg.output_dir, cfg.run_id, "resolved_config.yaml" + ) + os.makedirs(os.path.dirname(config_artifact_path), exist_ok=True) + with open(config_artifact_path, "w") as f: + f.write(resolved_yaml) + mlflow.log_artifact(config_artifact_path) + + ds_cfg = cfg.dataset + targets = omegaconf.OmegaConf.to_container(ds_cfg.targets, resolve=True) + metrics_list = omegaconf.OmegaConf.to_container( + ds_cfg.get("metrics", ["l1", "l2", "mae"]), resolve=True + ) + metric_calculator = MetricCalculator( + target_config=targets, + metrics=metrics_list, + ) + loss_calculator = LossCalculator( + target_config=targets, + loss_type=cfg.training.get("loss_type", "huber"), + ) + broadcast_global = cfg.get("broadcast_global", False) + logger.info(f"Loss: {loss_calculator}") + logger.info(f"Metrics: {metric_calculator}") + + ckpt_args = { + "path": os.path.join(checkpoint_dir, cfg.run_id, "checkpoints"), + "optimizer": optimizer, + "scheduler": scheduler, + "models": model, + } + loaded_epoch = load_checkpoint(device=dist_manager.device, **ckpt_args) + + if cfg.compile: + model = torch.compile(model) + + num_epochs = cfg.training.num_epochs + logger.info(f"Starting training for {num_epochs} epochs...") + + # Unless profiling is enabled, this is a null context: + with Profiler(): + for epoch in range(loaded_epoch, num_epochs): + logger.info(f"--- Epoch {epoch + 1}/{num_epochs} ---") + train_loader.set_epoch(epoch) + + train_loss, train_metrics = train_epoch( + train_loader, + model, + optimizer, + scheduler, + loss_calculator, + metric_calculator, + logger, + 
epoch, + cfg, + dist_manager, + scaler, + broadcast_global=broadcast_global, + log_every_n_steps=log_every_n_steps, + ) + + val_loss, val_metrics = val_epoch( + val_loader, + model, + loss_calculator, + metric_calculator, + logger, + epoch, + cfg, + dist_manager, + broadcast_global=broadcast_global, + ) + + # Log learning rate per epoch + if dist_manager.rank == 0: + current_lr = scheduler.get_last_lr()[0] + mlflow.log_metric("lr", current_lr, step=epoch) + + if dist_manager.rank == 0: + all_keys = list(dict.fromkeys(list(train_metrics) + list(val_metrics))) + + rows = [ + [ + k, + f"{train_metrics.get(k, float('nan')):.6f}", + f"{val_metrics.get(k, float('nan')):.6f}", + ] + for k in all_keys + ] + + table = tabulate( + rows, headers=["Metric", "Train", "Val"], tablefmt="pretty" + ) + logger.info( + f"\nEpoch [{epoch}/{cfg.training.num_epochs}] " + f"Train Loss: {train_loss:.6f} Val Loss: {val_loss:.6f}\n" + f"{table}\n" + ) + + if epoch % cfg.training.save_interval == 0 and dist_manager.rank == 0: + save_checkpoint(**ckpt_args, epoch=epoch + 1) + if normalizer is not None: + norm_path = os.path.join(ckpt_args["path"], "norm_stats.pt") + torch.save(normalizer.stats, norm_path) + mlflow.log_artifacts(ckpt_args["path"], artifact_path="checkpoints") + + if cfg.training.get("scheduler_update_mode", "epoch") == "epoch": + scheduler.step() + + if dist_manager.rank == 0: + mlflow.end_run() + + logger.info("Training completed!") + + +@hydra.main( + version_base=None, + config_path="../conf", + config_name="train_geotransolver_automotive_surface", +) +def launch(cfg: DictConfig): + """Hydra entry point: configure profiling and delegate to :func:`main`. + + Parameters + ---------- + cfg : DictConfig + Hydra-composed config (override with ``--config-name``). + When ``cfg.profile`` is truthy, torch profiling is enabled. 
+ """ + profiler = Profiler() + if cfg.profile: + profiler.enable("torch") + profiler.initialize() + main(cfg) + profiler.finalize() + + +if __name__ == "__main__": + launch() diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py new file mode 100644 index 0000000000..d5e57e7f77 --- /dev/null +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for the unified training recipe.""" + +from __future__ import annotations + +import random +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch + +from omegaconf import DictConfig +from physicsnemo.optim import CombinedOptimizer + + +def set_seed(seed: int | None, rank: int = 0) -> None: + """Pin all RNG states for reproducible training. + + When *seed* is not None, seeds Python, NumPy, and PyTorch (CPU + all + CUDA devices) with ``seed + rank`` so that different ranks diverge + deterministically. When *seed* is None this function is a no-op, + preserving the current (non-deterministic) behaviour. 
+ """ + if seed is None: + return + seed = seed + rank + random.seed(seed) + np.random.seed(seed % (1 << 31)) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def build_muon_optimizer( + model: torch.nn.Module, cfg: DictConfig, *, compile_optimizer: bool = False +) -> torch.optim.Optimizer: + """Build Muon + AdamW combined optimizer. + + Muon handles 2-D parameters (linear/attention weight matrices) while AdamW + handles everything else (biases, layer-norm, embeddings, etc.). + + Parameters + ---------- + model : torch.nn.Module + The model (may be DDP-wrapped). + cfg : DictConfig + Full Hydra config. Reads ``cfg.training.optimizer.*`` for lr, + weight_decay, betas, and eps. + compile_optimizer : bool + If True, compile the optimizer step functions with ``torch.compile``. + """ + base_model = model.module if hasattr(model, "module") else model + muon_params = [p for p in base_model.parameters() if p.ndim == 2] + other_params = [p for p in base_model.parameters() if p.ndim != 2] + + opt_cfg = cfg.training.optimizer + lr = opt_cfg.lr + weight_decay = opt_cfg.get("weight_decay", 1e-4) + betas = tuple(opt_cfg.get("betas", [0.9, 0.999])) + eps = opt_cfg.get("eps", 1e-8) + + compile_kwargs = {} if compile_optimizer else None + + if muon_params and other_params: + return CombinedOptimizer( + [ + torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ), + torch.optim.AdamW( + other_params, + lr=lr, + weight_decay=weight_decay, + betas=betas, + eps=eps, + ), + ], + torch_compile_kwargs=compile_kwargs, + ) + elif muon_params: + opt = torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ) + if compile_optimizer: + opt.step = torch.compile(opt.step) + return opt + else: + opt = torch.optim.AdamW( + other_params, lr=lr, weight_decay=weight_decay, betas=betas, eps=eps + ) + if compile_optimizer: + opt.step = torch.compile(opt.step) + return opt + + +# 
--------------------------------------------------------------------------- +# Field specification for target configurations +# --------------------------------------------------------------------------- + + +@dataclass +class FieldSpec: + """Specification for a single target field. + + Attributes: + name: Human-readable name for the field (used in metric/loss keys). + field_type: Either "scalar" or "vector". + start_index: Starting index in the channel dimension. + end_index: Ending index (exclusive) in the channel dimension. + """ + + name: str + field_type: Literal["scalar", "vector"] + start_index: int + end_index: int + + @property + def dim(self) -> int: + """Number of channels for this field.""" + return self.end_index - self.start_index + + +def parse_target_config( + target_config: dict[str, str], n_spatial_dims: int = 3 +) -> list[FieldSpec]: + """Parse target configuration to field specifications. + + Args: + target_config: Mapping of field names to types ("scalar" or "vector"). + Order determines channel indices. + n_spatial_dims: Dimensionality of vector fields. Default is 3. + + Returns: + List of FieldSpec objects describing each field. + + Raises: + ValueError: If an unknown field type is specified. + + Example: + >>> config = {"pressure": "scalar", "velocity": "vector"} + >>> specs = parse_target_config(config) + >>> specs[0] + FieldSpec(name='pressure', field_type='scalar', start_index=0, end_index=1) + >>> specs[1] + FieldSpec(name='velocity', field_type='vector', start_index=1, end_index=4) + """ + specs = [] + current_index = 0 + + for name, field_type in target_config.items(): + field_type = field_type.lower() + if field_type == "scalar": + dim = 1 + elif field_type == "vector": + dim = n_spatial_dims + else: + raise ValueError( + f"Unknown field type '{field_type}' for field '{name}'. " + "Expected 'scalar' or 'vector'." 
+ ) + + specs.append( + FieldSpec( + name=name, + field_type=field_type, + start_index=current_index, + end_index=current_index + dim, + ) + ) + current_index += dim + + return specs diff --git a/physicsnemo/datapipes/RNG.md b/physicsnemo/datapipes/RNG.md new file mode 100644 index 0000000000..f390d544de --- /dev/null +++ b/physicsnemo/datapipes/RNG.md @@ -0,0 +1,174 @@ +# Datapipe RNG & Reproducibility + +Deterministic data loading is opt-in: pass `seed=` to `DataLoader` and the +entire pipeline — sampler, reader, and every stochastic transform — becomes +reproducible across runs. + +If no seed is passed, all random operations fall back to their default +behavior, which may still be deterministic if you have set the global seeds +in PyTorch and run operations in a fixed order. In short, passing a `seed` to +`DataLoader` attaches a `torch.Generator` to every object that makes random +calls, so each object's random sequence is deterministic and the whole pipeline +becomes reproducible. Without a seed, you rely on globally set RNG state. + +## Quick start + +```python +loader = DataLoader(dataset, batch_size=16, shuffle=True, seed=42) + +for epoch in range(n_epochs): + loader.set_epoch(epoch) # vary randomness per epoch, still deterministic + for batch in loader: + ... +``` + +## How it works + +### Generator forking (`_rng.py`) + +The system derives independent `torch.Generator` streams from a single +master seed using `fork_generator(parent, n)`. Each child is seeded with +`parent.initial_seed() + i + 1`, so children are independent of each other +and stable across runs. Children are created on the **same device** as the +parent. + +### DataLoader + +When `seed` is set, the DataLoader: + +1. Creates a CPU master generator: `torch.Generator().manual_seed(seed)`. +2. Forks it into **2 children**: + - **Child 0 → sampler** — passed to `RandomSampler(generator=...)`. + - **Child 1 → dataset** — passed via `dataset.set_generator(...)`.
+ +### Dataset (TensorDict path) + +`Dataset.set_generator(generator)` flattens its transform pipeline +(unwrapping `Compose` if present) and forks into +`1 + len(flat_transforms)` children: + +- **Child 0 → reader** — passed via `reader.set_generator(...)`. +- **Children 1..N → transforms** (1-to-1 mapping; deterministic transforms + silently ignore theirs). + +If the dataset's `target_device` differs from the child generator's +device, a new generator is created on `target_device` and seeded from +the child's `initial_seed()`. + +### MeshDataset + +`MeshDataset.set_generator(generator)` follows the same pattern as +`Dataset`: forks into `1 + len(transforms)` children, distributing to +the reader and each transform with device alignment. + +### MultiDataset + +`MultiDataset.set_generator(generator)` forks into +`len(sub_datasets)` children and calls `set_generator` on each +sub-dataset. + +### Epoch reseeding + +`DataLoader.set_epoch(epoch)` propagates to the sampler and dataset. +Each component with a generator reseeds it with +`initial_seed() + epoch`, producing a different but deterministic +random sequence every epoch. + +## Generator tree + +```text +DataLoader(seed=S) +│ +├── master = Generator().manual_seed(S) +│ +├── fork_generator(master, 2) +│ ├── child[0] (seed S+1) ──► Sampler +│ └── child[1] (seed S+2) ──► Dataset / MeshDataset / MultiDataset +│ │ +│ ├── fork_generator(child[1], 1+N_transforms) +│ │ ├── child[0] (seed S+3) ──► Reader +│ │ ├── child[1] (seed S+4) ──► Transform 0 +│ │ ├── child[2] (seed S+5) ──► Transform 1 +│ │ └── ... +``` + +For `MultiDataset`, the fork distributes one child per sub-dataset, +and each sub-dataset then re-forks internally for its reader and +transforms. + +## Device management + +`torch.Generator` objects are device-bound and cannot be moved in-place. 
+Every boundary where a generator might cross devices contains explicit +re-creation logic: + +| Location | What happens | +|---|---| +| `fork_generator` | Creates children on `parent.device` | +| `Dataset.set_generator` | If `target_device != child.device`, creates a new generator on `target_device` seeded from the child | +| `MeshDataset.set_generator` | Same device-alignment logic as `Dataset` | +| `MeshTransform.to(device)` | Creates a new generator on `device`, seeded from the original's `initial_seed()` | +| `_sample_distribution` | Draws uniforms on `generator.device` | + +All random draws (`torch.rand`, `torch.randn`, `torch.randint`) pass +`device=generator.device` to stay on the correct device. + +## Stochastic transforms + +### Opting in + +Both `Transform` (TensorDict) and `MeshTransform` (Mesh) base classes +define the same generator protocol: + +- **`stochastic`** — property; `True` when `self._generator` exists. +- **`set_generator(g)`** — assigns `g` if stochastic; no-op otherwise. +- **`set_epoch(epoch)`** — reseeds with `initial_seed() + epoch`. + +To make a transform stochastic, declare +`self._generator: torch.Generator | None = None` in `__init__`. +Deterministic transforms never declare it, so all three methods are +silent no-ops. + +### TensorDict stochastic transforms + +- **`SubsamplePoints`** — declares `_generator` and passes it to + `torch.randperm`, `torch.multinomial`, and + `poisson_sample_indices_fixed`. + +### Mesh stochastic transforms + +- **`RandomScaleMesh`**, **`RandomTranslateMesh`**, + **`RandomRotateMesh`** — sample augmentation parameters from + `torch.distributions.Distribution` objects via ICDF + generator. +- **`SubsampleMesh`** — uses `torch.randperm` / `poisson_sample_indices_fixed`. + +### `Compose` + +`Compose.set_generator(generator)` forks and distributes one child per +child transform. `Compose.set_epoch(epoch)` propagates to all children. 
+When used inside `Dataset`, the dataset flattens `Compose` and assigns +forks per leaf transform directly; `Compose`'s own methods are for +standalone use. + +## Readers + +The `Reader` base class defines no-op `set_generator` / `set_epoch`. +Readers that use randomness override them: + +| Reader | Randomness | Generator support | +|---|---|---| +| `MeshReader` | `torch.randint` (contiguous block selection) | Yes | +| `DomainMeshReader` | `torch.randint` | Yes | +| `NumpyReader` | `torch.randint` (coordinated subsampling) | Yes | +| `ZarrReader` | `torch.randint` | Yes | +| `TensorStoreZarrReader` | `torch.randint` | Yes | +| `HDF5Reader` | None | No-op (inherited) | +| `VTKReader` | None | No-op (inherited) | + +## Current limitations + +- `DistributedSampler` manages its own seed internally; when using it, + pass `seed=` at `DistributedSampler` construction time rather than + relying on DataLoader's seed propagation. +- Legacy datapipes (`cae/`, `gnn/`, `climate/`, `healpix/`, + `benchmarks/`) are not wired into the generator protocol. 
diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index 0e3706475a..8036617fa2 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -40,9 +40,13 @@ ) from physicsnemo.datapipes.dataloader import DataLoader from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.mesh_dataset import MeshDataset from physicsnemo.datapipes.multi_dataset import MultiDataset +from physicsnemo.datapipes.protocols import DatasetBase from physicsnemo.datapipes.readers import ( + DomainMeshReader, HDF5Reader, + MeshReader, NumpyReader, Reader, TensorStoreZarrReader, @@ -58,23 +62,41 @@ from physicsnemo.datapipes.transforms import ( BoundingBoxFilter, BroadcastGlobalFeatures, + CenterMesh, CenterOfMass, Compose, + ComputeCellCentroids, ComputeNormals, ComputeSDF, + ComputeSurfaceNormals, ConcatFields, ConstantField, CreateGrid, + DropMeshFields, FieldSlice, KNearestNeighbors, + MeshToTensorDict, + MeshTransform, Normalize, + NormalizeMeshFields, NormalizeVectors, Purge, + RandomRotateMesh, + RandomScaleMesh, + RandomTranslateMesh, Rename, + RenameMeshFields, + RestructureTensorDict, + RotateMesh, Scale, + ScaleMesh, + SetGlobalField, + SubsampleMesh, SubsamplePoints, Transform, Translate, + TranslateMesh, + apply_to_tensordict_mesh, ) # Auto-register OmegaConf resolvers so ${dp:ComponentName} works in Hydra configs @@ -83,7 +105,9 @@ __all__ = [ # "TensorDict", # Re-export from tensordict + "DatasetBase", "Dataset", + "MeshDataset", "DataLoader", "MultiDataset", # Transforms - Base @@ -113,6 +137,25 @@ "Rename", "Purge", "ConstantField", + # Transforms - Mesh + "MeshTransform", + "ComputeCellCentroids", + "ComputeSurfaceNormals", + "MeshToTensorDict", + "apply_to_tensordict_mesh", + "ScaleMesh", + "TranslateMesh", + "RotateMesh", + "CenterMesh", + "SubsampleMesh", + "DropMeshFields", + "RenameMeshFields", + "NormalizeMeshFields", + "SetGlobalField", + "RestructureTensorDict", + "RandomScaleMesh", + 
"RandomTranslateMesh", + "RandomRotateMesh", # Readers "Reader", "HDF5Reader", @@ -120,6 +163,8 @@ "NumpyReader", "VTKReader", "TensorStoreZarrReader", + "MeshReader", + "DomainMeshReader", # Collation "Collator", "DefaultCollator", diff --git a/physicsnemo/datapipes/_rng.py b/physicsnemo/datapipes/_rng.py new file mode 100644 index 0000000000..d76a374df0 --- /dev/null +++ b/physicsnemo/datapipes/_rng.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Internal RNG utilities for deterministic generator forking. + +Used by :class:`DataLoader` and :class:`MeshDataset` to derive +independent per-component generators from a single master seed. +""" + +from __future__ import annotations + +import torch + + +def fork_generator( + parent: torch.Generator, + n: int, +) -> list[torch.Generator]: + """Deterministically derive *n* child generators from *parent*. + + Each child is seeded with ``parent.initial_seed() + i + 1``, so + children are independent of each other and stable across runs. + + Parameters + ---------- + parent : torch.Generator + Master generator whose ``initial_seed()`` is used as the base. + n : int + Number of child generators to create. + + Returns + ------- + list[torch.Generator] + *n* independent generators on the same device as *parent*. 
+ """ + + # I miss JAX ... + # https://docs.jax.dev/en/latest/jax.random.html + + base_seed = parent.initial_seed() + children: list[torch.Generator] = [] + for i in range(n): + g = torch.Generator(device=parent.device) + g.manual_seed(base_seed + i + 1) + children.append(g) + return children diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index cf3b160cd2..4e6b7bc61f 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -31,8 +31,9 @@ from tensordict import TensorDict from torch.utils.data import RandomSampler, Sampler, SequentialSampler +from physicsnemo.datapipes._rng import fork_generator from physicsnemo.datapipes.collate import Collator, get_collator -from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes.protocols import DatasetBase from physicsnemo.datapipes.registry import register @@ -79,7 +80,7 @@ class DataLoader: def __init__( self, - dataset: Dataset, + dataset: DatasetBase, *, batch_size: int = 1, shuffle: bool = False, @@ -96,14 +97,16 @@ def __init__( prefetch_factor: int = 2, num_streams: int = 4, use_streams: bool = True, + seed: int | None = None, ) -> None: """ Initialize the DataLoader. Parameters ---------- - dataset : Dataset - Dataset to load from. + dataset : DatasetBase + Dataset to load from. Any subclass of :class:`DatasetBase` + (e.g. :class:`Dataset`, :class:`MeshDataset`). batch_size : int, default=1 Number of samples per batch. shuffle : bool, default=False @@ -125,6 +128,12 @@ def __init__( use_streams : bool, default=True If True, use CUDA streams for overlap. Set False for debugging or CPU-only operation. + seed : int, optional + Master seed for all pipeline randomness. When set, the + DataLoader derives independent generators for the sampler, + reader, and every stochastic transform, making the full + pipeline reproducible. Use :meth:`set_epoch` to vary the + random sequence across epochs while staying deterministic. 
Raises ------ @@ -141,12 +150,29 @@ def __init__( self.prefetch_factor = prefetch_factor self.num_streams = num_streams self.use_streams = use_streams and torch.cuda.is_available() + self._seed = seed + + # Build master generator and fork for sampler + dataset + sampler_generator: torch.Generator | None = None + if seed is not None: + master = torch.Generator() + master.manual_seed(seed) + # Fork: child 0 → sampler, child 1 → dataset + forks = fork_generator(master, 2) + sampler_generator = forks[0] + if hasattr(dataset, "set_generator"): + dataset.set_generator(forks[1]) # Handle sampler if sampler is not None: self.sampler = sampler + # For DistributedSampler, propagate seed if available + if seed is not None and hasattr(sampler, "seed"): + # DistributedSampler exposes seed as a constructor arg + # but it's read-only; users should pass seed at construction. + pass elif shuffle: - self.sampler = RandomSampler(dataset) + self.sampler = RandomSampler(dataset, generator=sampler_generator) else: self.sampler = SequentialSampler(dataset) @@ -286,9 +312,12 @@ def _iter_prefetch( def set_epoch(self, epoch: int) -> None: """ - Set the epoch for the sampler. + Set the epoch for the sampler and the full data pipeline. - Required for DistributedSampler to shuffle properly across epochs. + Propagates the epoch to the sampler (for + :class:`~torch.utils.data.distributed.DistributedSampler`), + the dataset, reader, and every stochastic transform so all + RNG streams are reseeded deterministically. 
Parameters ---------- @@ -297,6 +326,8 @@ def set_epoch(self, epoch: int) -> None: """ if hasattr(self.sampler, "set_epoch"): self.sampler.set_epoch(epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) def enable_prefetch(self) -> None: """ diff --git a/physicsnemo/datapipes/datapipes.md b/physicsnemo/datapipes/datapipes.md new file mode 100644 index 0000000000..b41ca1d846 --- /dev/null +++ b/physicsnemo/datapipes/datapipes.md @@ -0,0 +1,285 @@ +# Datapipes -- Design Overview + +A GPU-centric, modular data pipeline for scientific machine learning. +The system uses **threads and CUDA streams** to overlap disk I/O, +host-to-device transfer, and GPU-side transforms within a single +process. The result is low latency, zero inter-process serialization, +and natural support for GPU-accelerated preprocessing -- properties +that matter when datasets are large, batches are small, and transforms +benefit from GPU execution. + +## Architecture + +The pipeline has four composable layers: + +```text +Reader --> Dataset --> DataLoader --> Training loop + (I/O) (transforms) (batching) +``` + +```text + ┌─────────────────────────────────────────────────┐ + │ DataLoader │ + ┌──────────┐ │ ┌──────────────────────────────────────────┐ │ + │ Sampler │─indices─▶ │ Dataset │ │ + └──────────┘ │ │ │ │ + │ │ Reader ──► Device transfer ──► Transforms│ │ + │ │ (CPU I/O) (non_blocking) (Compose) │ │ + │ └──────────────┬───────────────────────────┘ │ + │ │ │ + │ ▼ │ + │ Collator │ + └────────────────┬────────────────────────────────┘ + │ + ▼ + Batched TensorDict + (training loop) +``` + +Three dataset types share this pattern: + +| Type | Data model | Transform base | +|------|------------|----------------| +| `Dataset` | `TensorDict` fields | `Transform` | +| `MeshDataset` | `Mesh` / `DomainMesh` tensorclasses | `MeshTransform` | +| `MultiDataset` | Union of child `DatasetBase` instances | Delegates to children | + +All three inherit from `DatasetBase`, which 
provides thread-pool +prefetching and a `Future`-based cache (see +[Performance](#performance-threading-and-stream-based-concurrency) below). + +## Composability + +### Readers + +A `Reader` is an ABC with a single contract: + +```python +class Reader(ABC): + @abstractmethod + def _load_sample(self, index: int) -> dict[str, Tensor]: ... +``` + +`__getitem__` wraps the result in a `TensorDict` on CPU (optionally +pinned). + +### Transforms + +Transforms are pure functions on `TensorDict` (or `Mesh`): + +```python +class Transform(ABC): + @abstractmethod + def __call__(self, data: TensorDict) -> TensorDict: ... +``` + +For meshes, the `MeshTransform` ABC provides the same interface with +`__call__(Mesh) -> Mesh` plus `apply_to_domain(DomainMesh)` for +multi-region consistency. + +### Collators + +Collators combine per-sample `(TensorDict, metadata)` tuples into batches: + +| Collator | Strategy | +|----------|----------| +| `DefaultCollator` | `TensorDict.stack()` -- all samples must share shape | +| `ConcatCollator` | `torch.cat()` along an axis with optional `batch_idx` -- for variable-length point clouds | +| `FunctionCollator` | Wraps any callable | + +### Registry and Hydra integration + +All readers, transforms, datasets, and the DataLoader are decorated with +`@register()`, placing them in a global `COMPONENT_REGISTRY`. 
The helper +`register_resolvers()` (called at import time) registers an OmegaConf +resolver so Hydra configs can reference components by short name: + +```yaml +dataset: + _target_: ${dp:Dataset} + reader: + _target_: ${dp:ZarrReader} + path: /data/field.zarr + fields: [pressure, velocity] + transforms: + - _target_: ${dp:Normalize} + fields: [pressure] + method: mean_std + means: {pressure: 0.0} + stds: {pressure: 1.0} + - _target_: ${dp:SubsamplePoints} + input_keys: [pressure, velocity] + n_points: 10000 + device: cuda +``` + +The equivalent Python: + +```python +from physicsnemo.datapipes import Dataset, ZarrReader, Normalize, SubsamplePoints + +dataset = Dataset( + ZarrReader("/data/field.zarr", fields=["pressure", "velocity"]), + transforms=[ + Normalize(["pressure"], method="mean_std", + means={"pressure": 0.0}, stds={"pressure": 1.0}), + SubsamplePoints(["pressure", "velocity"], n_points=10000), + ], + device="cuda", +) +``` + +## Performance: threading and stream-based concurrency + +### Why threads + streams + +Scientific ML data loading is dominated by disk I/O and GPU-side +preprocessing. Threads are a natural fit: + +- **Shared state** -- threads share memory, file handles, and the CUDA + context within a single process, so there is no serialization or + duplication overhead. +- **I/O concurrency** -- the GIL is released during disk reads and CUDA + kernel launches, so multiple threads usefully overlap I/O with GPU work. +- **Stream parallelism** -- each prefetched sample is assigned its own + CUDA stream, allowing host-to-device transfers and GPU transforms to + run concurrently with the main training computation. + +### Thread-pool prefetch + +`DatasetBase` owns a `ThreadPoolExecutor` (configurable via +`num_workers`, default 2). 
Calling `prefetch(index)` submits the +load-and-transform pipeline to the pool and stashes the `Future`: + +```python +def prefetch(self, index, stream=None): + if index in self._prefetch_futures: + return + executor = self._ensure_executor() + self._prefetch_futures[index] = executor.submit(self._load, index) +``` + +`__getitem__` pops the `Future` if one exists, otherwise loads +synchronously: + +```python +def __getitem__(self, index): + future = self._prefetch_futures.pop(index, None) + if future is not None: + return future.result() + return self._load(index) +``` + +This means the DataLoader can keep the next batch loading in background +threads while the current batch is being consumed by the model. + +### CUDA stream overlap + +When GPU execution is available, `Dataset` (and `MeshDataset`) override +`prefetch` to run device transfer and transforms on a caller-supplied +CUDA stream, then record an event for later synchronization: + +```python +def _load_and_transform(self, index, stream=None): + result = _PrefetchResult(index=index) + data, metadata = self.reader[index] # CPU I/O in worker thread + + if stream is not None: + with torch.cuda.stream(stream): + data = data.to(device, non_blocking=True) # H2D on stream + data = self.transforms(data) # GPU transforms on stream + result.event = torch.cuda.Event() + result.event.record(stream) # mark completion + + result.data, result.metadata = data, metadata + return result +``` + +On retrieval, `Dataset.__getitem__` makes the current stream wait on the event: + +```python +if result.event is not None: + torch.cuda.current_stream().wait_event(result.event) +return result.data, result.metadata +``` + +The `DataLoader` owns a pool of `num_streams` CUDA streams (default 4) +and round-robins them across samples.
It also maintains a sliding +prefetch window of `prefetch_factor` batches (default 2) ahead of the +current yield position: + +```python +# Prefetch the next batch as we yield the current one +for sample_idx in all_batches[next_prefetch_idx]: + stream = self._streams[stream_idx % self.num_streams] + self.dataset.prefetch(sample_idx, stream=stream) + stream_idx += 1 +``` + +### Concurrency timeline + +The diagram below shows how threads and streams overlap for a two-sample +batch with `prefetch_factor=1`: + +```text +Main thread Worker 1 Worker 2 Stream 1 Stream 2 + │ │ │ │ │ + ├─prefetch(0,S1)─►│ │ │ │ + ├─prefetch(1,S2)─────────────────────►│ │ │ + │ ├─ Read (I/O) │ │ │ + │ │ ├─ Read (I/O) │ │ + │ ├─ to(device) ─────────────────────────►│ │ + │ ├─ transforms ─────────────────────────►│ │ + │ ├─ event.record() ─────────────────────►│ │ + │ │ ├─ to(device) ─────────────────►│ + │ │ ├─ transforms ─────────────────►│ + │ │ ├─ event.record() ─────────────►│ + ├─ event.synchronize() ×2 │ │ │ + ├─ collate + yield batch │ │ │ + │ │ │ │ │ +``` + +While the main thread consumes batch N, worker threads are already +loading batch N+1 on different streams. + +### Pinned memory + +Readers can set `pin_memory=True` to allocate CPU tensors in pinned +(page-locked) memory. Pinned memory enables truly asynchronous +`non_blocking` transfers to GPU, so the CUDA stream overlap described +above is most effective when the reader pins its output. + +### Debugging + +Prefetching can be toggled at runtime for debugging: + +```python +loader.disable_prefetch() # synchronous, single-stream -- easy to debug +loader.enable_prefetch() # re-enable after debugging +``` + +Setting `use_streams=False` or `prefetch_factor=0` at construction time +also forces synchronous execution. + +## RNG and reproducibility + +Deterministic data loading is opt-in. 
Passing `seed=` to `DataLoader` +creates a master `torch.Generator` that is forked into independent +streams for the sampler, the reader, and every stochastic transform. +`set_epoch(epoch)` reseeds all streams deterministically so each epoch +produces a different but reproducible random sequence. The full +generator tree, device management rules, and per-component details are +documented in **[RNG.md](RNG.md)**. + +## Augmentations + +Mesh augmentations (`RandomScaleMesh`, `RandomTranslateMesh`, +`RandomRotateMesh`) accept any `torch.distributions.Distribution` that +implements `icdf` to parametrize sampling. To preserve reproducibility with +seeded `torch.Generator` objects (which `Distribution.sample()` does not +accept), the augmentations use **inverse CDF sampling**: draw +`U ~ Uniform(0,1)` via `torch.rand(generator=g)`, then compute +`X = distribution.icdf(U)`. This gives exact samples from the target +distribution while keeping all randomness under generator control. +Full usage examples, YAML configuration, and the supported-distribution +table are in **[transforms/mesh/DISTRIBUTIONS.md](transforms/mesh/DISTRIBUTIONS.md)**.
diff --git a/physicsnemo/datapipes/dataset.py b/physicsnemo/datapipes/dataset.py index 8087c19d84..727503e800 100644 --- a/physicsnemo/datapipes/dataset.py +++ b/physicsnemo/datapipes/dataset.py @@ -25,13 +25,13 @@ from __future__ import annotations -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass -from typing import Any, Iterator, Optional, Sequence +from typing import Any, Optional, Sequence import torch from tensordict import TensorDict +from physicsnemo.datapipes._rng import fork_generator +from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.base import Transform @@ -39,19 +39,8 @@ from physicsnemo.distributed import DistributedManager -@dataclass -class _PrefetchResult: - """Result of a prefetch operation.""" - - index: int - data: Optional[TensorDict] = None - metadata: Optional[dict[str, Any]] = None - error: Optional[Exception] = None - event: Optional[torch.cuda.Event] = None # For stream sync - - @register() -class Dataset: +class Dataset(DatasetBase): """ A dataset combining a Reader with a transform pipeline. @@ -126,6 +115,8 @@ def __init__( TypeError If reader is not a Reader instance. 
""" + super().__init__(num_workers=num_workers) + if not isinstance(reader, Reader): raise TypeError( f"reader must be a Reader instance, got {type(reader).__name__}" @@ -142,7 +133,6 @@ def __init__( else: device = "cpu" - # Now, instantiate the device if not already done: match device: case torch.device(): self.target_device = device @@ -151,7 +141,6 @@ def __init__( case None: self.target_device = None - # Handle transforms if transforms is None: self.transforms: Optional[Transform] = None elif isinstance(transforms, Transform): @@ -169,30 +158,92 @@ def __init__( f"got {type(transforms).__name__}" ) - # Share device with transforms so their internal state is on the right device if self.target_device is not None and self.transforms is not None: self.transforms.to(self.target_device) - # Prefetch state - using thread-safe dict for results - # Key: index, Value: Future[_PrefetchResult] - self._prefetch_futures: dict[int, Future[_PrefetchResult]] = {} - self._executor: Optional[ThreadPoolExecutor] = None + # ------------------------------------------------------------------ + # DatasetBase implementation + # ------------------------------------------------------------------ + + def _load(self, index: int) -> tuple[TensorDict, dict[str, Any]]: + """Synchronous load: reader → device transfer → transforms.""" + data, metadata = self.reader[index] + + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + + if self.transforms is not None: + data = self.transforms(data) + + return data, metadata + + def __len__(self) -> int: + return len(self.reader) - def _ensure_executor(self) -> ThreadPoolExecutor: + # ------------------------------------------------------------------ + # RNG management + # ------------------------------------------------------------------ + + def _flat_transforms(self) -> list[Transform]: + """Return transforms as a flat list (unwrapping Compose).""" + if self.transforms is None: + return [] + if 
isinstance(self.transforms, Compose): + return list(self.transforms.transforms) + return [self.transforms] + + def set_generator(self, generator: torch.Generator) -> None: + """Distribute forked generators to the reader and every stochastic transform. + + Forks *generator* into ``1 + len(flat_transforms)`` independent + children: the first goes to the reader, the rest map 1-to-1 to + the transform list (deterministic transforms silently ignore + theirs). + + Parameters + ---------- + generator : torch.Generator + Parent generator (typically forked from the DataLoader's + master generator). """ - Lazily create the thread pool executor. + flat = self._flat_transforms() + n_children = 1 + len(flat) + children = fork_generator(generator, n_children) + + # Child 0 → reader + if hasattr(self.reader, "set_generator"): + self.reader.set_generator(children[0]) + + # Children 1..N → transforms (deterministic ones ignore via base no-op) + for child, t in zip(children[1:], flat): + if self.target_device is not None and self.target_device != child.device: + dev_gen = torch.Generator(device=self.target_device) + dev_gen.manual_seed(child.initial_seed()) + t.set_generator(dev_gen) + else: + t.set_generator(child) - Returns - ------- - ThreadPoolExecutor - The thread pool executor for prefetching. + def set_epoch(self, epoch: int) -> None: + """Propagate epoch to the reader and every transform. + + Reseeds all generators assigned via :meth:`set_generator` so + each epoch produces a different but deterministic random + sequence. + + Parameters + ---------- + epoch : int + Current epoch number. 
""" - if self._executor is None: - self._executor = ThreadPoolExecutor( - max_workers=self.num_workers, - thread_name_prefix="datapipe_prefetch", - ) - return self._executor + if hasattr(self.reader, "set_epoch"): + self.reader.set_epoch(epoch) + + for t in self._flat_transforms(): + t.set_epoch(epoch) + + # ------------------------------------------------------------------ + # Stream-aware prefetch (overrides DatasetBase defaults) + # ------------------------------------------------------------------ def _load_and_transform( self, @@ -200,7 +251,7 @@ def _load_and_transform( stream: Optional[torch.cuda.Stream] = None, ) -> _PrefetchResult: """ - Load a sample and apply transforms. Called by worker threads. + Load a sample and apply transforms with optional CUDA stream. Parameters ---------- @@ -217,10 +268,8 @@ def _load_and_transform( result = _PrefetchResult(index=index) try: - # Load from reader (CPU, potentially slow IO) data, metadata = self.reader[index] - # Auto-transfer to target device if specified if self.target_device is not None: if stream is not None: with torch.cuda.stream(stream): @@ -228,12 +277,10 @@ def _load_and_transform( else: data = data.to(self.target_device, non_blocking=True) - # Apply transforms (data is now on target device if specified) if self.transforms is not None: if stream is not None: with torch.cuda.stream(stream): data = self.transforms(data) - # Record event for synchronization result.event = torch.cuda.Event() result.event.record(stream) else: @@ -255,10 +302,8 @@ def prefetch( """ Start prefetching a sample asynchronously. - The sample will be loaded in a background thread. If a CUDA stream - is provided, GPU operations happen on that stream. - - Call __getitem__ to retrieve the result (it will wait if needed). + When a CUDA stream is provided, GPU operations (device transfer + and transforms) run on that stream for overlap with computation. 
Parameters ---------- @@ -267,7 +312,6 @@ def prefetch( stream : torch.cuda.Stream, optional Optional CUDA stream for GPU operations. """ - # Don't prefetch if already in flight if index in self._prefetch_futures: return @@ -275,28 +319,6 @@ def prefetch( future = executor.submit(self._load_and_transform, index, stream) self._prefetch_futures[index] = future - def prefetch_batch( - self, - indices: Sequence[int], - streams: Optional[Sequence[torch.cuda.Stream]] = None, - ) -> None: - """ - Start prefetching multiple samples. - - Parameters - ---------- - indices : Sequence[int] - Sample indices to prefetch. - streams : Sequence[torch.cuda.Stream], optional - Optional CUDA streams, one per index. If shorter than - indices, streams are cycled. If None, no streams used. - """ - for i, idx in enumerate(indices): - stream = None - if streams: - stream = streams[i % len(streams)] - self.prefetch(idx, stream=stream) - def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: """ Get a transformed sample by index. @@ -321,77 +343,21 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: Exception If prefetch failed, re-raises the error. 
""" - # Check if prefetched future = self._prefetch_futures.pop(index, None) if future is not None: - # Wait for prefetch to complete result = future.result() - if result.error is not None: - raise result.error + if isinstance(result, _PrefetchResult): + if result.error is not None: + raise result.error + if result.event is not None: + torch.cuda.current_stream().wait_event(result.event) + return result.data, result.metadata - # Sync stream if needed - if result.event is not None: - result.event.synchronize() + return result - return result.data, result.metadata - - # Not prefetched, load synchronously - data, metadata = self.reader[index] - - # Auto-transfer to target device if specified - if self.target_device is not None: - data = data.to(self.target_device, non_blocking=True) - - # Apply transforms - if self.transforms is not None: - data = self.transforms(data) - - return data, metadata - - def cancel_prefetch(self, index: Optional[int] = None) -> None: - """ - Cancel prefetch requests. - - Note: Already-running tasks will complete, but results are discarded. - - Parameters - ---------- - index : int, optional - Specific index to cancel. If None, cancels all. - """ - if index is None: - # Cancel all - just clear the dict, let futures complete - self._prefetch_futures.clear() - else: - self._prefetch_futures.pop(index, None) - - def __len__(self) -> int: - """ - Return the number of samples in the dataset. - - Returns - ------- - int - Number of samples. - """ - return len(self.reader) - - def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: - """ - Iterate over all samples. - - Note: This does NOT automatically prefetch. For prefetched iteration, - use the DataLoader which manages prefetching strategy. - - Yields - ------ - tuple[TensorDict, dict[str, Any]] - Tuple of (transformed data, metadata) for each sample. 
- """ - for i in range(len(self)): - yield self[i] + return self._load(index) @property def field_names(self) -> list[str]: @@ -405,18 +371,6 @@ def field_names(self) -> list[str]: """ return self.reader.field_names - @property - def prefetch_count(self) -> int: - """ - Number of items currently being prefetched. - - Returns - ------- - int - Count of in-flight prefetch operations. - """ - return len(self._prefetch_futures) - def close(self) -> None: """ Close the dataset and stop prefetching. @@ -425,31 +379,9 @@ def close(self) -> None: This prevents "cannot schedule new futures after shutdown" errors from libraries like zarr that use async I/O internally. """ - # Wait for any in-flight prefetch tasks to complete before shutdown. - # This prevents "cannot schedule new futures after shutdown" errors - # from libraries like zarr that use async I/O internally. - for future in self._prefetch_futures.values(): - try: - future.result(timeout=30.0) # Wait up to 30s per task - except Exception: # noqa: BLE001, S110 - pass # Ignore errors during shutdown - - self._prefetch_futures.clear() - - if self._executor is not None: - self._executor.shutdown(wait=True) - self._executor = None - + super().close() self.reader.close() - def __enter__(self) -> "Dataset": - """Context manager entry.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """Context manager exit.""" - self.close() - def __repr__(self) -> str: """ Return string representation. diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py new file mode 100644 index 0000000000..6ca740bb4b --- /dev/null +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MeshDataset - Combines a mesh reader (MeshReader or DomainMeshReader) with mesh transforms. + +Returns (Mesh, metadata) or (DomainMesh, metadata). No key-based filtering. +Supports CUDA stream-aware prefetching for overlapped IO and computation. +""" + +from __future__ import annotations + +from typing import Any, Optional, Sequence, Union + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes._rng import fork_generator +from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms.mesh.base import MeshTransform +from physicsnemo.mesh import DomainMesh, Mesh + + +@register() +class MeshDataset(DatasetBase): + r""" + Dataset for mesh readers and mesh-only transforms. + + Accepts :class:`MeshReader` (single-mesh) or :class:`DomainMeshReader` + (domain mesh with interior + boundaries). + + Applies a sequence of :class:`MeshTransform` (Mesh -> Mesh). + For single-mesh data each transform is called directly. + For :class:`DomainMesh` data each transform is applied via + :meth:`MeshTransform.apply_to_domain`, which handles domain-level + ``global_data``, consistent random parameter sampling, and + proper centering semantics. 
+ + Supports CUDA stream-aware prefetching: when a stream is provided to + :meth:`prefetch`, device transfer and transforms run on that stream, + allowing overlap with training computation. + + Examples + -------- + >>> from physicsnemo.datapipes import DataLoader, MeshDataset, MeshReader + >>> + >>> reader = MeshReader("data/meshes/") # doctest: +SKIP + >>> dataset = MeshDataset(reader, transforms=[...], device="cuda") # doctest: +SKIP + >>> loader = DataLoader(dataset, batch_size=1, shuffle=True) # doctest: +SKIP + + With DistributedSampler: + + >>> from torch.utils.data.distributed import DistributedSampler + >>> sampler = DistributedSampler(dataset) # doctest: +SKIP + >>> loader = DataLoader(dataset, batch_size=1, sampler=sampler) # doctest: +SKIP + """ + + def __init__( + self, + reader: MeshReader | DomainMeshReader, + *, + transforms: Sequence[MeshTransform] | None = None, + device: str | torch.device | None = None, + num_workers: int = 1, + ) -> None: + """ + Parameters + ---------- + reader : MeshReader or DomainMeshReader + Mesh reader; returns (Mesh, metadata) or (DomainMesh, metadata). + transforms : sequence of MeshTransform, optional + Transforms to apply in order. None means no transforms. + device : str or torch.device, optional + If set, move mesh data to this device after loading (before transforms). + num_workers : int, default=1 + Number of worker threads for prefetching. Defaults to 1 + because mesh transforms construct new Mesh objects internally + and tensordict's ``_device_recorder`` is not safe for + concurrent TensorDict construction across threads. 
+ """ + super().__init__(num_workers=num_workers) + self.reader = reader + self.transforms = list(transforms) if transforms else [] + self._device = torch.device(device) if isinstance(device, str) else device + + if self._device is not None: + for t in self.transforms: + if hasattr(t, "to"): + t.to(self._device) + + # ------------------------------------------------------------------ + # RNG management + # ------------------------------------------------------------------ + + def set_generator(self, generator: torch.Generator) -> None: + """Distribute forked generators to the reader and every stochastic transform. + + Forks *generator* into ``1 + len(self.transforms)`` independent + children: the first goes to the reader, the rest map 1-to-1 to + the transform list (deterministic transforms silently ignore + theirs). + + Parameters + ---------- + generator : torch.Generator + Parent generator (typically forked from the DataLoader's + master generator). + """ + n_children = 1 + len(self.transforms) + children = fork_generator(generator, n_children) + + # Child 0 → reader + if hasattr(self.reader, "set_generator"): + self.reader.set_generator(children[0]) + + # Children 1..N → transforms (deterministic ones ignore via base no-op) + for child, t in zip(children[1:], self.transforms): + if hasattr(t, "set_generator"): + if self._device is not None and self._device != child.device: + dev_gen = torch.Generator(device=self._device) + dev_gen.manual_seed(child.initial_seed()) + t.set_generator(dev_gen) + else: + t.set_generator(child) + + def set_epoch(self, epoch: int) -> None: + """Propagate epoch to the reader and every transform. + + Reseeds all generators assigned via :meth:`set_generator` so + each epoch produces a different but deterministic random + sequence. + + Parameters + ---------- + epoch : int + Current epoch number. 
+ """ + if hasattr(self.reader, "set_epoch"): + self.reader.set_epoch(epoch) + + for t in self.transforms: + if hasattr(t, "set_epoch"): + t.set_epoch(epoch) + + # ------------------------------------------------------------------ + # DatasetBase implementation + # ------------------------------------------------------------------ + + def _load( + self, index: int + ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: + """Synchronous load: reader -> device transfer -> transforms.""" + data, metadata = self.reader[index] + + if self._device is not None: + data = data.to(self._device) + + for t in self.transforms: + if isinstance(data, DomainMesh): + data = t.apply_to_domain(data) + else: + data = t(data) + + return data, metadata + + def __len__(self) -> int: + return len(self.reader) + + # ------------------------------------------------------------------ + # Stream-aware prefetch (overrides DatasetBase defaults) + # ------------------------------------------------------------------ + + def _load_and_transform( + self, + index: int, + stream: Optional[torch.cuda.Stream] = None, + ) -> _PrefetchResult: + """Load a sample and apply transforms with optional CUDA stream. + + Parameters + ---------- + index : int + Sample index. + stream : torch.cuda.Stream, optional + Optional CUDA stream for GPU operations. + + Returns + ------- + _PrefetchResult + Result with data, metadata, or error. 
+ """ + result = _PrefetchResult(index=index) + + try: + data, metadata = self.reader[index] + + if self._device is not None: + if stream is not None: + with torch.cuda.stream(stream): + data = data.to(self._device, non_blocking=True) + else: + data = data.to(self._device, non_blocking=True) + + for t in self.transforms: + if stream is not None: + with torch.cuda.stream(stream): + if isinstance(data, DomainMesh): + data = t.apply_to_domain(data) + else: + data = t(data) + else: + if isinstance(data, DomainMesh): + data = t.apply_to_domain(data) + else: + data = t(data) + + if stream is not None: + result.event = torch.cuda.Event() + result.event.record(stream) + + result.data = data + result.metadata = metadata + + except Exception as e: + result.error = e + + return result + + def prefetch( + self, + index: int, + stream: Optional[torch.cuda.Stream] = None, + ) -> None: + """Start prefetching a sample asynchronously. + + When a CUDA stream is provided, GPU operations (device transfer + and transforms) run on that stream for overlap with computation. + + Parameters + ---------- + index : int + Sample index to prefetch. + stream : torch.cuda.Stream, optional + Optional CUDA stream for GPU operations. + """ + if index in self._prefetch_futures: + return + + executor = self._ensure_executor() + future = executor.submit(self._load_and_transform, index, stream) + self._prefetch_futures[index] = future + + def __getitem__( + self, index: int + ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: + """Get a transformed sample by index. + + If the index was prefetched, returns the prefetched result + (waiting for completion if necessary). Otherwise loads synchronously. + + Parameters + ---------- + index : int + Sample index. + + Returns + ------- + tuple[Mesh | DomainMesh | TensorDict, dict[str, Any]] + Tuple of (transformed data, metadata dict). + + Raises + ------ + Exception + If prefetch failed, re-raises the error. 
+ """ + future = self._prefetch_futures.pop(index, None) + + if future is not None: + result = future.result() + + if isinstance(result, _PrefetchResult): + if result.error is not None: + raise result.error + if result.event is not None: + torch.cuda.current_stream().wait_event(result.event) + return result.data, result.metadata + + return result + + return self._load(index) + + def close(self) -> None: + """Close the dataset and stop prefetching. + + Waits for any in-flight prefetch tasks to complete before shutdown. + """ + super().close() diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py index e09c5855dc..2547bf5f95 100644 --- a/physicsnemo/datapipes/multi_dataset.py +++ b/physicsnemo/datapipes/multi_dataset.py @@ -28,16 +28,18 @@ from typing import Any, Iterator, Optional, Sequence +import torch from tensordict import TensorDict -from physicsnemo.datapipes.dataset import Dataset +from physicsnemo.datapipes._rng import fork_generator +from physicsnemo.datapipes.protocols import DatasetBase from physicsnemo.datapipes.registry import register # Metadata key added by MultiDataset to identify which sub-dataset produced the sample. DATASET_INDEX_METADATA_KEY = "dataset_index" -def _validate_strict_outputs(datasets: Sequence[Dataset]) -> list[str]: +def _validate_strict_outputs(datasets: Sequence[DatasetBase]) -> list[str]: """ Check that all non-empty datasets produce the same TensorDict keys; return them. @@ -46,7 +48,7 @@ def _validate_strict_outputs(datasets: Sequence[Dataset]) -> list[str]: Parameters ---------- - datasets : Sequence[Dataset] + datasets : Sequence[DatasetBase] Datasets to validate. 
Returns @@ -76,25 +78,30 @@ def _validate_strict_outputs(datasets: Sequence[Dataset]) -> list[str]: "output_strict=True requires identical output keys (TensorDict keys) " f"across datasets: dataset {ref_index} has {ref_keys}, dataset {i} has {keys}" ) - return list(ref_keys) if ref_keys is not None else list(datasets[0].field_names) + if ref_keys is not None: + return list(ref_keys) + first = datasets[0] + return list(first.field_names) if hasattr(first, "field_names") else [] @register() class MultiDataset: r""" - A dataset that composes multiple :class:`Dataset` instances behind one index space. + A dataset that composes multiple :class:`DatasetBase` instances behind one index space. - Global indices are mapped to (dataset_index, local_index) by concatenation: - indices 0..len0-1 come from the first dataset, len0..len0+len1-1 from the second, - and so on. Each constituent can have its own Reader and transforms. Metadata - is enriched with ``dataset_index`` so batches can identify the source. + Accepts both :class:`Dataset` (TensorDict pipelines) and :class:`MeshDataset` + (Mesh pipelines) as sub-datasets. Global indices are mapped to + (dataset_index, local_index) by concatenation: indices 0..len0-1 come from the + first dataset, len0..len0+len1-1 from the second, and so on. Each constituent + can have its own Reader and transforms. Metadata is enriched with + ``dataset_index`` so batches can identify the source. Parameters ---------- - *datasets : Dataset - One or more Dataset instances passed as positional arguments - (Reader + transforms each). Order defines index mapping: first - dataset occupies 0..len(ds0)-1, etc. + *datasets : DatasetBase + One or more Dataset or MeshDataset instances passed as positional + arguments (Reader + transforms each). Order defines index mapping: + first dataset occupies 0..len(ds0)-1, etc. 
output_strict : bool, default=True If True, require all datasets to produce the same TensorDict keys (output keys after transforms) so :class:`DefaultCollator` can stack batches. If @@ -112,7 +119,7 @@ class MultiDataset: Notes ----- MultiDataset implements the same interface as :class:`Dataset` (``__len__``, - ``__getitem__``, ``prefetch``, ``prefetch_batch``, ``prefetch_count``, + ``__getitem__``, ``prefetch``, ``cancel_prefetch``, ``close``, ``field_names``) and can be passed to :class:`DataLoader` in place of a single dataset. Prefetch and close are delegated to the sub-dataset that owns the index. When ``output_strict=True``, @@ -157,7 +164,7 @@ class MultiDataset: def __init__( self, - *datasets: Dataset, + *datasets: DatasetBase, output_strict: bool = True, ) -> None: if len(datasets) < 1: @@ -165,9 +172,10 @@ def __init__( f"MultiDataset requires at least one dataset, got {len(datasets)}" ) for i, ds in enumerate(datasets): - if not isinstance(ds, Dataset): + if not isinstance(ds, DatasetBase): raise TypeError( - f"datasets[{i}] must be a Dataset instance, got {type(ds).__name__}" + f"datasets[{i}] must be a Dataset or MeshDataset instance, " + f"got {type(ds).__name__}" ) self._datasets = list(datasets) @@ -183,7 +191,11 @@ def __init__( if output_strict: self._field_names = _validate_strict_outputs(self._datasets) else: - self._field_names = list(self._datasets[0].field_names) + first = self._datasets[0] + if hasattr(first, "field_names"): + self._field_names = list(first.field_names) + else: + self._field_names = [] def _index_to_dataset_and_local(self, index: int) -> tuple[int, int]: """ @@ -240,6 +252,36 @@ def __len__(self) -> int: """Return the total number of samples (sum of all sub-dataset lengths).""" return self._cumul[-1] + # ------------------------------------------------------------------ + # RNG management + # ------------------------------------------------------------------ + + def set_generator(self, generator: torch.Generator) -> None: + 
"""Fork *generator* and distribute one child per sub-dataset. + + Parameters + ---------- + generator : torch.Generator + Parent generator (typically forked from the DataLoader's + master generator). + """ + children = fork_generator(generator, len(self._datasets)) + for child, ds in zip(children, self._datasets): + if hasattr(ds, "set_generator"): + ds.set_generator(child) + + def set_epoch(self, epoch: int) -> None: + """Propagate epoch to every sub-dataset. + + Parameters + ---------- + epoch : int + Current epoch number. + """ + for ds in self._datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) + def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: """ Return the transformed sample and metadata for the given global index. @@ -282,31 +324,6 @@ def prefetch( ds_id, local_i = self._index_to_dataset_and_local(index) self._datasets[ds_id].prefetch(local_i, stream=stream) - def prefetch_batch( - self, - indices: Sequence[int], - streams: Optional[Sequence[Any]] = None, - ) -> None: - """ - Start prefetching multiple samples by global index. - - Delegates to the sub-dataset that owns each index. Streams are cycled - if shorter than indices. - - Parameters - ---------- - indices : Sequence[int] - Global sample indices to prefetch. - streams : Sequence[Any], optional - Optional CUDA streams, one per index. If shorter than indices, - streams are cycled. If None, no streams used. - """ - for i, idx in enumerate(indices): - stream = None - if streams: - stream = streams[i % len(streams)] - self.prefetch(idx, stream=stream) - def cancel_prefetch(self, index: Optional[int] = None) -> None: """ Cancel prefetch for the given index or all sub-datasets. @@ -328,18 +345,6 @@ def cancel_prefetch(self, index: Optional[int] = None) -> None: ds_id, local_i = mapped self._datasets[ds_id].cancel_prefetch(local_i) - @property - def prefetch_count(self) -> int: - """ - Number of items currently being prefetched across all sub-datasets. 
- - Returns - ------- - int - Sum of in-flight prefetch counts from each sub-dataset. - """ - return sum(ds.prefetch_count for ds in self._datasets) - def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: """Iterate over all samples in global index order.""" for i in range(len(self)): diff --git a/physicsnemo/datapipes/protocols.py b/physicsnemo/datapipes/protocols.py new file mode 100644 index 0000000000..69d4363f3b --- /dev/null +++ b/physicsnemo/datapipes/protocols.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base class for dataset components. + +Provides :class:`DatasetBase`, an ABC that owns the thread-based prefetch +infrastructure shared by :class:`Dataset`, :class:`MeshDataset`, and any +future dataset implementations. The user-facing extension points are +**Readers** and **Transforms**, not dataset subclasses. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Iterator, Optional + +import torch + + +@dataclass +class _PrefetchResult: + """Result of a stream-aware prefetch operation. 
+ + Used by :class:`Dataset` and :class:`MeshDataset` to carry data, + metadata, and an optional CUDA event through the prefetch pipeline. + """ + + index: int + data: Any = None + metadata: Optional[dict[str, Any]] = field(default=None) + error: Optional[Exception] = field(default=None) + event: Optional[torch.cuda.Event] = field(default=None) + + +class DatasetBase(ABC): + """Abstract base for datasets compatible with :class:`DataLoader`. + + Subclasses implement :meth:`_load` (the actual data-loading pipeline) + and :meth:`__len__`. Everything else — ``__getitem__`` with prefetch + cache lookup, thread-pool prefetching, cancellation, cleanup — is + provided here. + + Both :class:`Dataset` and :class:`MeshDataset` override + :meth:`prefetch` and :meth:`__getitem__` to add CUDA-stream + support via :class:`_PrefetchResult`. + """ + + def __init__(self, *, num_workers: int = 2) -> None: + self._prefetch_futures: dict[int, Future] = {} + self._executor: Optional[ThreadPoolExecutor] = None + self._num_workers = num_workers + + @abstractmethod + def _load(self, index: int) -> tuple[Any, dict[str, Any]]: + """Load and return a single sample ``(data, metadata)``. + + This is the hook that subclasses must implement. It is called + both synchronously (from ``__getitem__``) and asynchronously + (from the prefetch thread pool). + """ + ... + + @abstractmethod + def __len__(self) -> int: ... 
+ + # ------------------------------------------------------------------ + # Concrete interface + # ------------------------------------------------------------------ + + def __getitem__(self, index: int) -> tuple[Any, dict[str, Any]]: + """Return sample *index*, using a prefetched result when available.""" + future = self._prefetch_futures.pop(index, None) + if future is not None: + return future.result() # re-raises on error + return self._load(index) + + def prefetch( + self, + index: int, + stream: Optional[torch.cuda.Stream] = None, + ) -> None: + """Submit *index* for background loading in a worker thread. + + The ``stream`` parameter is accepted for interface compatibility + but ignored by the default implementation. :class:`Dataset` + overrides this to run GPU transfers on the given stream. + """ + if index in self._prefetch_futures: + return + executor = self._ensure_executor() + self._prefetch_futures[index] = executor.submit(self._load, index) + + def cancel_prefetch(self, index: Optional[int] = None) -> None: + """Discard prefetch results (already-running tasks still complete).""" + if index is None: + self._prefetch_futures.clear() + else: + self._prefetch_futures.pop(index, None) + + def close(self) -> None: + """Drain in-flight prefetches and shut down the thread pool.""" + for future in self._prefetch_futures.values(): + try: + future.result(timeout=30.0) + except Exception: # noqa: BLE001, S110 + pass + self._prefetch_futures.clear() + + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _ensure_executor(self) -> ThreadPoolExecutor: + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self._num_workers, + thread_name_prefix="datapipe_prefetch", + ) + return self._executor + + def __iter__(self) -> Iterator[tuple[Any, 
dict[str, Any]]]: + for i in range(len(self)): + yield self[i] + + def __enter__(self) -> "DatasetBase": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() diff --git a/physicsnemo/datapipes/readers/__init__.py b/physicsnemo/datapipes/readers/__init__.py index 89ab147e4e..992c25829e 100644 --- a/physicsnemo/datapipes/readers/__init__.py +++ b/physicsnemo/datapipes/readers/__init__.py @@ -26,6 +26,7 @@ from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.readers.hdf5 import HDF5Reader +from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader from physicsnemo.datapipes.readers.numpy import NumpyReader from physicsnemo.datapipes.readers.tensorstore_zarr import TensorStoreZarrReader from physicsnemo.datapipes.readers.vtk import VTKReader @@ -38,4 +39,6 @@ "NumpyReader", "VTKReader", "TensorStoreZarrReader", + "MeshReader", + "DomainMeshReader", ] diff --git a/physicsnemo/datapipes/readers/base.py b/physicsnemo/datapipes/readers/base.py index ab8fee3d11..17ce9b0b86 100644 --- a/physicsnemo/datapipes/readers/base.py +++ b/physicsnemo/datapipes/readers/base.py @@ -278,6 +278,30 @@ def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: logger.error(error_msg) raise RuntimeError(error_msg) from e + def set_generator(self, generator: torch.Generator) -> None: + """Assign a ``torch.Generator`` for reproducible random sampling. + + Override in subclasses that use randomness (e.g. subsampling). + The default implementation is a no-op. + + Parameters + ---------- + generator : torch.Generator + Generator to use for random draws. + """ + + def set_epoch(self, epoch: int) -> None: + """Reseed the reader's RNG for a new epoch. + + Override in subclasses that use randomness. + The default implementation is a no-op. + + Parameters + ---------- + epoch : int + Current epoch number. + """ + def close(self) -> None: """ Clean up resources (file handles, connections, etc.). 
diff --git a/physicsnemo/datapipes/readers/mesh.py b/physicsnemo/datapipes/readers/mesh.py new file mode 100644 index 0000000000..3c9527a367 --- /dev/null +++ b/physicsnemo/datapipes/readers/mesh.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mesh readers - Load physicsnemo Mesh / DomainMesh from physicsnemo mesh format (.pt). + +MeshReader returns (Mesh, metadata) per sample. +DomainMeshReader returns (DomainMesh, metadata) per sample. +Both use tensorclass .load(path) directly; no conversion from other formats. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Iterator + +import torch + +from physicsnemo.datapipes.registry import register +from physicsnemo.mesh import DomainMesh, Mesh + +logger = logging.getLogger(__name__) + +# Default extension for physicsnemo mesh format (tensordict/tensorclass layout). +# Do not hardcode elsewhere so format can evolve. +DEFAULT_MESH_EXTENSION = ".pmsh" + + +def _contiguous_block_slice( + total: int, + k: int, + generator: torch.Generator | None = None, +) -> slice: + """Return a random contiguous ``slice`` of length *k* within ``[0, total)``. 
+ + A contiguous slice produces sequential I/O on memmap-backed tensors, + which is orders of magnitude faster than scattered fancy-indexing. + For best results the on-disk point order should be pre-shuffled so + that a contiguous block is spatially representative. + """ + if total <= k: + return slice(0, total) + start = torch.randint(0, total - k, (1,), generator=generator).item() + return slice(start, start + k) + + +def _subsample_mesh_points( + mesh: Mesh, + n_points: int, + generator: torch.Generator | None = None, +) -> Mesh: + """Subsample a Mesh to *n_points* via a contiguous block read. + + Uses a contiguous slice for sequential I/O on memmap-backed data. + For point clouds (``n_cells == 0``) this avoids the heavy + cell-remapping logic in :meth:`Mesh.slice_points` which allocates + two *N*-element intermediate tensors. For meshes with cells it + falls back to ``slice_points``. + """ + if mesh.n_points <= n_points: + return mesh + sl = _contiguous_block_slice(mesh.n_points, n_points, generator=generator) + if mesh.n_cells == 0: + return Mesh( + points=mesh.points[sl], + cells=mesh.cells, + point_data=mesh.point_data[sl], + cell_data=mesh.cell_data, + global_data=mesh.global_data, + ) + return mesh.slice_points(torch.arange(sl.start, sl.stop, device=mesh.points.device)) + + +def _subsample_mesh_cells( + mesh: Mesh, + n_cells: int, + generator: torch.Generator | None = None, +) -> Mesh: + """Subsample a Mesh to *n_cells* via a contiguous block read on cells. + + Preserves cell topology: each selected cell retains its full vertex + connectivity. Unreferenced points are compacted out. Uses + ``_contiguous_block_slice`` for sequential I/O on memmap-backed + cell tensors. + + Use this instead of :func:`_subsample_mesh_points` when the mesh + has cell connectivity (triangulated surfaces, volume meshes) and + downstream transforms or outputs depend on cell topology (e.g. + surface normals, cell centroids, cell_data fields). 
+ """ + if mesh.n_cells <= n_cells: + return mesh + sl = _contiguous_block_slice(mesh.n_cells, n_cells, generator=generator) + mesh = mesh.slice_cells(sl) + # Compact: drop vertices not referenced by any surviving cell + referenced = torch.unique(mesh.cells) + if referenced.numel() < mesh.n_points: + mesh = mesh.slice_points(referenced) + return mesh + + +def _subsample_mesh( + mesh: Mesh, + n_cells: int | None = None, + n_points: int | None = None, + generator: torch.Generator | None = None, +) -> Mesh: + """Apply cell and/or point subsampling to a single Mesh. + + Cells are subsampled first (preserving topology) so that the + subsequent point subsample operates on the already-reduced mesh. + """ + if n_cells is not None: + mesh = _subsample_mesh_cells(mesh, n_cells, generator=generator) + if n_points is not None: + mesh = _subsample_mesh_points(mesh, n_points, generator=generator) + return mesh + + +@register() +class MeshReader: + r""" + Read single-mesh samples from directories of physicsnemo mesh files. + + Each sample is one Mesh. Returns (Mesh, metadata) per index. + Uses Mesh.load(path) for physicsnemo mesh format (currently .pt). + """ + + def __init__( + self, + path: Path | str, + *, + pattern: str = f"**/*{DEFAULT_MESH_EXTENSION}", + pin_memory: bool = False, + include_index_in_metadata: bool = True, + subsample_n_points: int | None = None, + subsample_n_cells: int | None = None, + ) -> None: + """ + Initialize the mesh reader. + + Parameters + ---------- + path : Path or str + Root directory containing mesh files (e.g. .pt directories). + pattern : str, optional + Glob pattern for mesh paths under ``path``. Default matches ``**/*.pmsh``. + pin_memory : bool, default=False + If True, place tensors in pinned (page-locked) memory for faster + async CPU→GPU transfers. + include_index_in_metadata : bool, default=True + If True, include sample index in metadata. 
+ subsample_n_points : int, optional + If set, subsample the mesh to this many points *before* + ``pin_memory``. Uses contiguous block reads for sequential + I/O on memmap-backed data. Appropriate for point clouds + or meshes where cell topology is not needed downstream. + For best results, pre-shuffle the on-disk point order so + that a contiguous block is spatially representative. + subsample_n_cells : int, optional + If set, subsample the mesh to this many cells *before* + ``pin_memory``. Uses contiguous block reads on the cell + tensor for sequential I/O, then compacts unreferenced + vertices. Preserves cell topology and is the correct + choice for triangulated surface meshes where downstream + transforms depend on cells (e.g. surface normals, cell + centroids, cell_data fields). Applied before + ``subsample_n_points`` when both are set. + """ + self._root = Path(path) + self._pattern = pattern + self.pin_memory = pin_memory + self.include_index_in_metadata = include_index_in_metadata + self.subsample_n_points = subsample_n_points + self.subsample_n_cells = subsample_n_cells + self._subsample_generator: torch.Generator | None = None + + if not self._root.exists(): + raise FileNotFoundError(f"Path not found: {self._root}") + if not self._root.is_dir(): + raise ValueError(f"Path must be a directory: {self._root}") + + self._paths = sorted(self._root.glob(pattern)) + if not self._paths: + raise ValueError(f"No paths matching {pattern!r} found in {self._root}") + + def _load_sample(self, index: int) -> Mesh: + """Load a single Mesh from disk.""" + mesh_path = self._paths[index] + return Mesh.load(mesh_path) + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for the sample (e.g. source path).""" + return {"source_path": str(self._paths[index])} + + def __len__(self) -> int: + return len(self._paths) + + def set_generator(self, generator: torch.Generator) -> None: + """Assign a ``torch.Generator`` for reproducible subsampling. 
def set_epoch(self, epoch: int) -> None:
    """Reseed the subsample RNG for a new epoch.

    The generator held in ``self._subsample_generator`` is reseeded with
    ``base_seed + epoch``, where ``base_seed`` is the seed the generator
    carried the first time this method saw it.  The resulting seed
    depends only on *epoch*, so repeating a call, skipping epochs, or
    resuming at a later epoch all give the same sequence of contiguous
    blocks for that epoch.

    No-op when no generator has been assigned.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    generator = self._subsample_generator
    if generator is None:
        return

    # ``initial_seed()`` reports the seed most recently passed to
    # ``manual_seed()``.  Reseeding with ``initial_seed() + epoch`` on every
    # call would therefore accumulate (base, base+1, base+3, ...).  Keep the
    # base seed and recapture it only when the generator object was replaced
    # or reseeded by someone else since our previous call.
    current_seed = generator.initial_seed()
    state = getattr(self, "_subsample_epoch_state", None)
    if state is not None and state[0] is generator and state[2] == current_seed:
        base_seed = state[1]
    else:
        base_seed = current_seed

    # manual_seed() rejects values above 2**64 - 1; wrap instead of raising.
    epoch_seed = (base_seed + epoch) % (1 << 64)
    generator.manual_seed(epoch_seed)
    self._subsample_epoch_state = (generator, base_seed, epoch_seed)
+ """ + + def __init__( + self, + path: Path | str, + *, + pattern: str = f"**/*{DEFAULT_MESH_EXTENSION}", + pin_memory: bool = False, + include_index_in_metadata: bool = True, + subsample_n_points: int | None = None, + subsample_n_cells: int | None = None, + extra_boundaries: dict[str, dict] | None = None, + ) -> None: + """ + Initialize the domain mesh reader. + + Parameters + ---------- + path : Path or str + Root directory containing DomainMesh files (e.g. .pt archives). + pattern : str, optional + Glob pattern for DomainMesh paths under ``path``. + Default matches ``**/*.pmsh``. + pin_memory : bool, default=False + If True, place tensors in pinned (page-locked) memory for faster + async CPU→GPU transfers. + include_index_in_metadata : bool, default=True + If True, include sample index in metadata. + subsample_n_points : int, optional + If set, subsample the interior and each boundary mesh to + at most this many points *before* ``pin_memory``. Uses + contiguous block reads for sequential I/O on memmap-backed + data. Appropriate for point clouds or meshes where cell + topology is not needed downstream. For best results, + pre-shuffle the on-disk point order so that a contiguous + block is spatially representative. + subsample_n_cells : int, optional + If set, subsample the interior and each boundary mesh to + at most this many cells *before* ``pin_memory``. Uses + contiguous block reads on cell tensors for sequential I/O, + then compacts unreferenced vertices. Preserves cell + topology and is the correct choice when downstream + transforms depend on cells. Applied before + ``subsample_n_points`` when both are set. + extra_boundaries : dict[str, dict] or None, optional + Load additional sibling meshes as extra boundaries on each + sample. Each key is the boundary name to assign; each value + is a dict with a ``"pattern"`` key giving a glob pattern + (relative to the sample's parent directory) to find the mesh + file. 
def set_epoch(self, epoch: int) -> None:
    """Reseed the subsample RNG for a new epoch.

    The generator held in ``self._subsample_generator`` is reseeded with
    ``base_seed + epoch``, where ``base_seed`` is the seed the generator
    carried the first time this method saw it.  The resulting seed
    depends only on *epoch*, so repeating a call, skipping epochs, or
    resuming at a later epoch all give the same sequence of contiguous
    blocks for that epoch.

    No-op when no generator has been assigned.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    generator = self._subsample_generator
    if generator is None:
        return

    # ``initial_seed()`` reports the seed most recently passed to
    # ``manual_seed()``.  Reseeding with ``initial_seed() + epoch`` on every
    # call would therefore accumulate (base, base+1, base+3, ...).  Keep the
    # base seed and recapture it only when the generator object was replaced
    # or reseeded by someone else since our previous call.
    current_seed = generator.initial_seed()
    state = getattr(self, "_subsample_epoch_state", None)
    if state is not None and state[0] is generator and state[2] == current_seed:
        base_seed = state[1]
    else:
        base_seed = current_seed

    # manual_seed() rejects values above 2**64 - 1; wrap instead of raising.
    epoch_seed = (base_seed + epoch) % (1 << 64)
    generator.manual_seed(epoch_seed)
    self._subsample_epoch_state = (generator, base_seed, epoch_seed)
def __iter__(self) -> Iterator[tuple[DomainMesh, dict[str, Any]]]:
    """Yield every sample in index order.

    Each item is whatever ``self[i]`` returns.  If producing a sample
    raises, the failure is logged and re-raised as a ``RuntimeError``
    that names the offending index and chains the original exception.

    Raises
    ------
    RuntimeError
        When loading any sample fails.
    """
    total = len(self)
    position = 0
    while position < total:
        try:
            yield self[position]
        except Exception as err:
            logger.error("Sample %s failed: %s", position, err)
            raise RuntimeError(f"Sample {position} failed: {err}") from err
        position += 1
def set_epoch(self, epoch: int) -> None:
    """Reseed the subsample RNG for a new epoch.

    The generator held in ``self._subsample_generator`` is reseeded with
    ``base_seed + epoch``, where ``base_seed`` is the seed the generator
    carried the first time this method saw it.  The resulting seed
    depends only on *epoch*, so repeating a call, skipping epochs, or
    resuming at a later epoch all give the same random sequence for that
    epoch.

    No-op when no generator has been assigned.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    generator = self._subsample_generator
    if generator is None:
        return

    # ``initial_seed()`` reports the seed most recently passed to
    # ``manual_seed()``.  Reseeding with ``initial_seed() + epoch`` on every
    # call would therefore accumulate (base, base+1, base+3, ...).  Keep the
    # base seed and recapture it only when the generator object was replaced
    # or reseeded by someone else since our previous call.
    current_seed = generator.initial_seed()
    state = getattr(self, "_subsample_epoch_state", None)
    if state is not None and state[0] is generator and state[2] == current_seed:
        base_seed = state[1]
    else:
        base_seed = current_seed

    # manual_seed() rejects values above 2**64 - 1; wrap instead of raising.
    epoch_seed = (base_seed + epoch) % (1 << 64)
    generator.manual_seed(epoch_seed)
    self._subsample_epoch_state = (generator, base_seed, epoch_seed)
def set_epoch(self, epoch: int) -> None:
    """Reseed the subsample RNG for a new epoch.

    The generator held in ``self._subsample_generator`` is reseeded with
    ``base_seed + epoch``, where ``base_seed`` is the seed the generator
    carried the first time this method saw it.  The resulting seed
    depends only on *epoch*, so repeating a call, skipping epochs, or
    resuming at a later epoch all give the same random sequence for that
    epoch.

    No-op when no generator has been assigned.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    generator = self._subsample_generator
    if generator is None:
        return

    # ``initial_seed()`` reports the seed most recently passed to
    # ``manual_seed()``.  Reseeding with ``initial_seed() + epoch`` on every
    # call would therefore accumulate (base, base+1, base+3, ...).  Keep the
    # base seed and recapture it only when the generator object was replaced
    # or reseeded by someone else since our previous call.
    current_seed = generator.initial_seed()
    state = getattr(self, "_subsample_epoch_state", None)
    if state is not None and state[0] is generator and state[2] == current_seed:
        base_seed = state[1]
    else:
        base_seed = current_seed

    # manual_seed() rejects values above 2**64 - 1; wrap instead of raising.
    epoch_seed = (base_seed + epoch) % (1 << 64)
    generator.manual_seed(epoch_seed)
    self._subsample_epoch_state = (generator, base_seed, epoch_seed)
@@ -274,7 +286,12 @@ def _select_random_sections_from_slice( f"{n_points} requested for subsampling" ) - start = np.random.randint(slice_start, slice_stop - n_points + 1) + start = torch.randint( + slice_start, + slice_stop - n_points + 1, + (1,), + generator=self._subsample_generator, + ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: diff --git a/physicsnemo/datapipes/transforms/__init__.py b/physicsnemo/datapipes/transforms/__init__.py index 963b4b0985..1b1479925f 100644 --- a/physicsnemo/datapipes/transforms/__init__.py +++ b/physicsnemo/datapipes/transforms/__init__.py @@ -39,6 +39,26 @@ Scale, Translate, ) +from physicsnemo.datapipes.transforms.mesh import ( + CenterMesh, + ComputeCellCentroids, + ComputeSurfaceNormals, + DropMeshFields, + MeshToTensorDict, + MeshTransform, + NormalizeMeshFields, + RandomRotateMesh, + RandomScaleMesh, + RandomTranslateMesh, + RenameMeshFields, + RestructureTensorDict, + RotateMesh, + ScaleMesh, + SetGlobalField, + SubsampleMesh, + TranslateMesh, + apply_to_tensordict_mesh, +) from physicsnemo.datapipes.transforms.normalize import Normalize from physicsnemo.datapipes.transforms.spatial import ( BoundingBoxFilter, @@ -87,4 +107,23 @@ "Rename", "Purge", "ConstantField", + # Mesh + "MeshTransform", + "apply_to_tensordict_mesh", + "ComputeCellCentroids", + "ComputeSurfaceNormals", + "ScaleMesh", + "TranslateMesh", + "RotateMesh", + "CenterMesh", + "SubsampleMesh", + "DropMeshFields", + "RenameMeshFields", + "NormalizeMeshFields", + "SetGlobalField", + "MeshToTensorDict", + "RestructureTensorDict", + "RandomScaleMesh", + "RandomTranslateMesh", + "RandomRotateMesh", ] diff --git a/physicsnemo/datapipes/transforms/base.py b/physicsnemo/datapipes/transforms/base.py index bd3cea7510..55d64138fa 100644 --- a/physicsnemo/datapipes/transforms/base.py +++ b/physicsnemo/datapipes/transforms/base.py @@ -78,6 +78,47 @@ def __call__(self, data: TensorDict) -> TensorDict: """ raise 
def set_epoch(self, epoch: int) -> None:
    """Reseed the generator for a new epoch.

    Reseeds ``self._generator`` with ``base_seed + epoch``, where
    ``base_seed`` is the seed the generator carried the first time this
    method saw it.  The resulting seed depends only on *epoch*, so
    repeating a call, skipping epochs, or resuming at a later epoch all
    give the same random sequence for that epoch.

    No-op when ``self.stochastic`` is false or when no generator has
    been assigned.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    if not self.stochastic:
        return
    generator = self._generator
    if generator is None:
        return

    # ``initial_seed()`` reports the seed most recently passed to
    # ``manual_seed()``.  Reseeding with ``initial_seed() + epoch`` on every
    # call would therefore accumulate (base, base+1, base+3, ...).  Keep the
    # base seed and recapture it only when the generator object was replaced
    # or reseeded by someone else since our previous call.
    current_seed = generator.initial_seed()
    state = getattr(self, "_epoch_seed_state", None)
    if state is not None and state[0] is generator and state[2] == current_seed:
        base_seed = state[1]
    else:
        base_seed = current_seed

    # manual_seed() rejects values above 2**64 - 1; wrap instead of raising.
    epoch_seed = (base_seed + epoch) % (1 << 64)
    generator.manual_seed(epoch_seed)
    self._epoch_seed_state = (generator, base_seed, epoch_seed)
def set_epoch(self, epoch: int) -> None:
    """Forward the epoch number to each child transform.

    Calls ``set_epoch(epoch)`` on every entry of ``self.transforms``,
    in order.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    """
    child_transforms = self.transforms
    for child in child_transforms:
        reseed = child.set_epoch
        reseed(epoch)
A uniform +distribution is the simplest choice but is not always the best, nor +should we lock into that design explicitly. Other alternatives exist: + +- **Gaussian** (`Normal`) concentrates samples near a center value, + making it ideal for small perturbations around an identity + transformation (e.g. scale factors near 1.0, small angles near 0). +- **Laplace** has a sharper peak than Gaussian but heavier tails, + producing most samples near the center with occasional large ones. +- **Cauchy** has very heavy tails, useful when rare extreme + augmentations are desirable. +- **LogNormal** is positive-valued by construction, making it a + natural fit for scale factors that must stay positive. +- **Exponential**, **Gumbel**, **Weibull** cover various skewed or + extreme-value scenarios. + +Rather than adding bespoke parameters for each distribution family, +which would be cumbersome and create too much code to maintain, +the augmentations accept any `torch.distributions.Distribution` +object directly. This delegates the full parametric flexibility of +PyTorch's distributions library to the user with zero custom +abstractions. + +## How It Works: Inverse cumulative distributions + +`torch.distributions.Distribution.sample()` does **not** accept a +`torch.Generator`. This is a problem because the augmentation +pipeline relies on seeded generators for reproducibility, which is +essential in ML pipelines. + +We solve this with the **inverse transform method** (a.k.a. inverse +CDF / quantile-function sampling): + +1. Draw `U ~ Uniform(0, 1)` using `torch.rand(generator=generator)`. + This step is reproducible because `torch.rand` accepts a generator, + which is seeded. +2. Compute `X = distribution.icdf(U)`. The inverse CDF transforms + the uniform variate into a sample from the target distribution. + +By the probability integral transform, `X` is exactly distributed +according to `distribution`. 
Reproducibility comes from step 1: +the same generator seed always produces the same `U`, and `icdf` is +a deterministic function. + +For distributions that do **not** implement `icdf`, the code falls +back to `distribution.sample()` and emits a warning that generator +reproducibility is lost. In practice this is only a small subset of +distributions. + +## Reproducibility + +Reproducibility flows from the `DataLoader`. The loader seeds a +master `torch.Generator` and passes it to +`MeshDataset.set_generator(parent_gen)`, which forks the parent +into independent children — one for the reader and one per +transform. `MeshDataset.set_epoch(epoch)` reseeds every child +with `initial_seed() + epoch` so each epoch is different but +deterministic. Deterministic transforms silently ignore both calls. + +For standalone usage outside a `DataLoader`, call `set_generator` +on the transform directly: + +```python +aug = RandomScaleMesh(distribution=D.Normal(1.0, 0.05)) +aug.set_generator(torch.Generator().manual_seed(42)) +result = aug(mesh) # reproducible +``` + + +## Python Usage + +```python +import torch +import torch.distributions as D +from physicsnemo.datapipes.transforms.mesh import ( + RandomScaleMesh, + RandomTranslateMesh, + RandomRotateMesh, +) +``` + +### RandomScaleMesh + +```python +# Default: Uniform(0.9, 1.1) +aug = RandomScaleMesh() + +# Gaussian perturbation around identity scale +aug = RandomScaleMesh(distribution=D.Normal(loc=1.0, scale=0.05)) +aug.set_generator(torch.Generator().manual_seed(42)) + +# LogNormal (always positive, centered near 1) +aug = RandomScaleMesh(distribution=D.LogNormal(loc=0.0, scale=0.1)) +``` + +### RandomTranslateMesh + +```python +# Default: Uniform(-0.1, 0.1) per axis (IID) +aug = RandomTranslateMesh() + +# Laplace offsets (sharp peak, heavy tails) +aug = RandomTranslateMesh(distribution=D.Laplace(loc=0.0, scale=0.02)) +aug.set_generator(torch.Generator().manual_seed(42)) + +# Per-axis control via batched distribution +aug = 
RandomTranslateMesh( + distribution=D.Uniform( + torch.tensor([-0.1, -0.2, -0.3]), + torch.tensor([ 0.1, 0.2, 0.3]), + ), +) + +# Per-axis Gaussian with different scales +aug = RandomTranslateMesh( + distribution=D.Normal( + loc=torch.zeros(3), + scale=torch.tensor([0.01, 0.02, 0.05]), + ), +) +``` + +### RandomRotateMesh + +```python +# Default: Uniform(-pi, pi) axis-aligned rotation +aug = RandomRotateMesh() + +# Concentrated small-angle perturbations +aug = RandomRotateMesh(distribution=D.Normal(loc=0.0, scale=0.1)) +aug.set_generator(torch.Generator().manual_seed(42)) + +# Only rotate about z-axis, Laplace angle distribution +aug = RandomRotateMesh( + axes=["z"], + distribution=D.Laplace(loc=0.0, scale=0.5), +) + +# Uniform SO(3) (ignores distribution, uses quaternion method) +aug = RandomRotateMesh(mode="uniform") +``` + +## YAML / Hydra Usage + +Distributions can be constructed inline using Hydra's `_target_` +syntax: + +```yaml +# Gaussian scale perturbation +- _target_: ${dp:RandomScaleMesh} + distribution: + _target_: torch.distributions.Normal + loc: 1.0 + scale: 0.05 + +# Laplace translation +- _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Laplace + loc: 0.0 + scale: 0.02 + +# Small-angle Gaussian rotation about z only +- _target_: ${dp:RandomRotateMesh} + axes: ["z"] + distribution: + _target_: torch.distributions.Normal + loc: 0.0 + scale: 0.1 +``` + +For per-axis batched distributions in YAML, pass list parameters: + +```yaml +- _target_: ${dp:RandomTranslateMesh} + distribution: + _target_: torch.distributions.Uniform + low: [-0.1, -0.2, -0.3] + high: [0.1, 0.2, 0.3] +``` + +## Supported Distributions + +The ICDF method works with any `torch.distributions.Distribution` +that implements `icdf()`. The table below summarises support for the +most commonly used distributions. 
+ +| Distribution | `icdf` | Generator-reproducible | Typical use case | +|-----------------|--------|------------------------|------------------------------------------| +| `Uniform` | Yes | Yes | Bounded ranges (default behaviour) | +| `Normal` | Yes | Yes | Small perturbations around a center | +| `Laplace` | Yes | Yes | Sharper peak, heavier tails than Gaussian | +| `Cauchy` | Yes | Yes | Very heavy tails, rare extremes | +| `LogNormal` | Yes | Yes | Positive-only (e.g. scale factors) | +| `Exponential` | Yes | Yes | One-sided positive values | +| `Gumbel` | Yes | Yes | Extreme-value modelling | +| `Weibull` | Yes | Yes | Flexible shape for positive values | +| `Poisson` | No | **No** (fallback) | Discrete; generator ignored with warning | +| `Gamma` | No | **No** (fallback) | Positive continuous; no closed-form ICDF | +| `Dirichlet` | No | **No** (fallback) | Simplex-valued; no scalar ICDF | + +Distributions without `icdf` will still work via +`distribution.sample()`, but the `torch.Generator` is **not** used +for those draws. A `UserWarning` is emitted in this case, and +datapipe reproducibility is not possible at the generator level. + +## Choosing a Distribution + +| Goal | Recommended distribution | +|----------------------------------------------|---------------------------------------| +| Bounded range `[a, b]` | `Uniform(a, b)` | +| Small perturbations around a center `c` | `Normal(loc=c, scale=sigma)` | +| Sharper peak + occasional large values | `Laplace(loc=c, scale=b)` | +| Very heavy tails (rare extreme augmentation) | `Cauchy(loc=c, scale=gamma)` | +| Strictly positive parameter (e.g. scale) | `LogNormal(loc=mu, scale=sigma)` | +| One-sided perturbation from zero | `Exponential(rate=lambda)` | +| Per-axis different parameters | Batched distribution, e.g. 
`Normal(loc=tensor([...]), scale=tensor([...]))` | diff --git a/physicsnemo/datapipes/transforms/mesh/__init__.py b/physicsnemo/datapipes/transforms/mesh/__init__.py new file mode 100644 index 0000000000..4574db1c40 --- /dev/null +++ b/physicsnemo/datapipes/transforms/mesh/__init__.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mesh transforms and augmentations. + +Transforms operate on Mesh (single-mesh) or TensorDict[str, Mesh] (multi-mesh). +Type-based only; no key-based filtering. 
+""" + +from physicsnemo.datapipes.transforms.mesh.augmentations import ( + RandomRotateMesh, + RandomScaleMesh, + RandomTranslateMesh, +) +from physicsnemo.datapipes.transforms.mesh.base import ( + MeshTransform, + apply_to_tensordict_mesh, +) +from physicsnemo.datapipes.transforms.mesh.transforms import ( + CenterMesh, + ComputeCellCentroids, + ComputeSurfaceNormals, + DropMeshFields, + MeshToTensorDict, + NormalizeMeshFields, + RenameMeshFields, + RestructureTensorDict, + RotateMesh, + ScaleMesh, + SetGlobalField, + SubsampleMesh, + TranslateMesh, +) + +__all__ = [ + "MeshTransform", + "apply_to_tensordict_mesh", + "ComputeCellCentroids", + "ComputeSurfaceNormals", + "ScaleMesh", + "TranslateMesh", + "RotateMesh", + "CenterMesh", + "SubsampleMesh", + "DropMeshFields", + "RenameMeshFields", + "SetGlobalField", + "NormalizeMeshFields", + "MeshToTensorDict", + "RestructureTensorDict", + "RandomScaleMesh", + "RandomTranslateMesh", + "RandomRotateMesh", +] diff --git a/physicsnemo/datapipes/transforms/mesh/augmentations.py b/physicsnemo/datapipes/transforms/mesh/augmentations.py new file mode 100644 index 0000000000..44f64325f3 --- /dev/null +++ b/physicsnemo/datapipes/transforms/mesh/augmentations.py @@ -0,0 +1,506 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _sample_distribution(
    distribution: torch.distributions.Distribution,
    shape: tuple[int, ...],
    generator: torch.Generator | None,
    fallback_device: torch.device | None = None,
) -> torch.Tensor:
    """Sample from a distribution using ICDF + generator for reproducibility.

    Draws ``U ~ Uniform(0, 1)`` with the provided generator, then
    transforms through ``distribution.icdf(U)``.  The uniform draw is
    clamped into the open interval ``(0, 1)`` so that distributions with
    unbounded support never return ``+/-inf``.

    For distributions that do not implement ``icdf`` (e.g. Poisson),
    a warning is emitted and the sample is drawn with the distribution's
    own sampler, without generator reproducibility.  The fallback result
    has the same shape the ICDF path would produce, i.e. *shape*
    broadcast against ``distribution.batch_shape``.

    Parameters
    ----------
    distribution : torch.distributions.Distribution
        The target distribution to sample from.
    shape : tuple[int, ...]
        Shape of the sample to draw.
    generator : torch.Generator or None
        Random generator for reproducibility.  When provided, uniform
        samples are generated on ``generator.device``.
    fallback_device : torch.device or None
        Device used for ``torch.rand`` when *generator* is ``None``.

    Returns
    -------
    torch.Tensor
        Sampled tensor, placed on the device of the uniform draw.
    """
    if generator is not None:
        u = torch.rand(shape, generator=generator, device=generator.device)
    else:
        u = torch.rand(shape, device=fallback_device)

    # torch.rand draws from [0, 1).  icdf(0) is -inf for unbounded
    # distributions (Normal, Laplace, Cauchy, ...), which would turn a scale
    # factor or offset into a non-finite value.  Keep u strictly inside (0, 1).
    eps = torch.finfo(u.dtype).eps
    u = u.clamp(min=eps, max=1.0 - eps)

    try:
        return distribution.icdf(u)
    except NotImplementedError:
        warnings.warn(
            f"{type(distribution).__name__} does not implement icdf; "
            "falling back to .sample() without generator reproducibility.",
            stacklevel=3,
        )

    # ``sample(shape)`` returns ``shape + batch_shape``, which for a batched
    # (per-axis) distribution differs from what ``icdf(u)`` yields.  Expand
    # the distribution to the broadcast shape and draw once instead, so both
    # paths agree on the output shape.
    target_shape = torch.broadcast_shapes(
        tuple(shape), tuple(distribution.batch_shape)
    )
    try:
        sample = distribution.expand(target_shape).sample()
    except NotImplementedError:
        # Distribution cannot be expanded; keep the plain sampler call.
        sample = distribution.sample(shape)
    return sample.to(device=u.device)
+ """ + return _sample_distribution( + self._distribution, (1,), self._generator, self._device + ).squeeze(0) + + def __call__(self, mesh: Mesh) -> Mesh: + """Apply a random scale to *mesh*. + + Parameters + ---------- + mesh : Mesh + Input mesh. + + Returns + ------- + Mesh + Scaled mesh. + """ + factor = self._sample_factor() + return mesh.scale( + factor, + transform_point_data=self.transform_point_data, + transform_cell_data=self.transform_cell_data, + transform_global_data=self.transform_global_data, + ) + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Apply a random scale to every mesh in *domain*. + + A single scale factor is sampled and applied consistently to the + interior and all boundary meshes. + + Parameters + ---------- + domain : DomainMesh + Input domain mesh. + + Returns + ------- + DomainMesh + Scaled domain mesh. + """ + factor = self._sample_factor() + return domain.scale( + factor, + transform_point_data=self.transform_point_data, + transform_cell_data=self.transform_cell_data, + transform_global_data=self.transform_global_data, + ) + + def extra_repr(self) -> str: + return f"distribution={self._distribution}" + + +@register() +class RandomTranslateMesh(MeshTransform): + r"""Random translation of mesh. Offset is sampled per ``__call__``. + + Each spatial axis is sampled independently from *distribution* + (default ``Uniform(-0.1, 0.1)``). Pass a batched distribution to + control each axis separately, e.g. + ``Uniform(tensor([-0.1, -0.2, -0.3]), tensor([0.1, 0.2, 0.3]))``. + """ + + def __init__( + self, + distribution: torch.distributions.Distribution | None = None, + ) -> None: + """ + Parameters + ---------- + distribution : torch.distributions.Distribution or None + Distribution from which the per-axis offsets are sampled. + A scalar distribution produces IID samples per axis; a + batched distribution (``batch_shape == (n_spatial_dims,)``) + allows different parameters per axis. + Defaults to ``Uniform(-0.1, 0.1)``. 
@register()
class RandomRotateMesh(MeshTransform):
    r"""Random rotation of mesh. Axis and angle are sampled per ``__call__``.

    Two modes are supported:

    * ``"axis_aligned"`` (default) -- picks one of the candidate *axes*
      uniformly at random and samples an angle from *distribution*.
      This limits rotations to the three cardinal planes.
    * ``"uniform"`` -- samples a rotation uniformly from SO(3) via random
      unit quaternions (3-D meshes only). *axes* and *distribution* are
      ignored in this mode.
    """

    def __init__(
        self,
        axes: list[Literal["x", "y", "z"]] | None = None,
        distribution: torch.distributions.Distribution | None = None,
        mode: Literal["axis_aligned", "uniform"] = "axis_aligned",
        transform_point_data: bool = False,
        transform_cell_data: bool = False,
        transform_global_data: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        axes : list[{"x", "y", "z"}] or None
            Candidate rotation axes. One is chosen uniformly at random
            per call. Defaults to ``["x", "y", "z"]``.
            Only used when ``mode="axis_aligned"``.
        distribution : torch.distributions.Distribution or None
            Distribution from which the rotation angle (radians) is
            sampled. Defaults to ``Uniform(-pi, pi)``.
            Only used when ``mode="axis_aligned"``.
        mode : {"axis_aligned", "uniform"}
            ``"axis_aligned"`` picks a random cardinal axis and angle
            each call. ``"uniform"`` samples a rotation uniformly from
            SO(3) via random quaternions (3-D only).
        transform_point_data : bool
            If ``True``, transform point-data fields under rotation.
        transform_cell_data : bool
            If ``True``, transform cell-data fields under rotation.
        transform_global_data : bool
            If ``True``, transform global-data fields under rotation.

        Raises
        ------
        ValueError
            If *mode* is not recognised, or if *axes* is empty while
            ``mode="axis_aligned"``.
        """
        super().__init__()
        if mode not in ("axis_aligned", "uniform"):
            raise ValueError(f"mode must be 'axis_aligned' or 'uniform', got {mode!r}")
        self.axes = axes if axes is not None else ["x", "y", "z"]
        # Fail at construction time rather than with an opaque torch.randint
        # error on the first call.
        if mode == "axis_aligned" and len(self.axes) == 0:
            raise ValueError(
                "axes must contain at least one candidate axis "
                "when mode='axis_aligned'"
            )
        if distribution is None:
            distribution = torch.distributions.Uniform(-math.pi, math.pi)
        self._distribution = distribution
        self.mode = mode
        self.transform_point_data = transform_point_data
        self.transform_cell_data = transform_cell_data
        self.transform_global_data = transform_global_data
        self._generator: torch.Generator | None = None

        # Coefficient matrix mapping outer(q,q).flatten() (16,) -> R.flatten() (9,).
        # Derived from the standard unit-quaternion rotation formula using
        # w²+x²+y²+z² = 1 to rewrite 1-2(…) terms as sums of squared components.
        #    ww  wx  wy  wz  xw  xx  xy  xz  yw  yx  yy  yz  zw  zx  zy  zz
        self._q2r_map = torch.tensor(
            [
                [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [0, 0, 0, -2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1],
                [0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0],
                [0, 0, -2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1],
            ],
            dtype=torch.float32,
        )

    # ------------------------------------------------------------------
    # axis-aligned helpers
    # ------------------------------------------------------------------

    def _sample_axis_and_angle(self) -> tuple[str, Float[torch.Tensor, ""]]:
        """Sample a random axis and rotation angle.

        The axis index is drawn via ``torch.randint`` with the generator.
        The angle is sampled from ``self._distribution`` via ICDF.

        Returns
        -------
        axis : str
            One entry of ``self.axes``.
        angle : torch.Tensor
            Scalar (0-dim) tensor with the sampled angle in radians.
        """
        gen_device = (
            self._generator.device if self._generator is not None else self._device
        )
        axis_idx = torch.randint(
            len(self.axes), (1,), generator=self._generator, device=gen_device
        )
        # Convert to a plain int explicitly: indexing a sequence with a tensor
        # only works through Tensor.__index__, which non-list sequences
        # (e.g. config containers) may not accept.
        axis = self.axes[int(axis_idx.item())]
        angle = _sample_distribution(
            self._distribution, (1,), self._generator, self._device
        ).squeeze(0)
        return axis, angle

    # ------------------------------------------------------------------
    # uniform SO(3) helpers
    # ------------------------------------------------------------------

    def _quaternion_to_rotation_matrix(
        self,
        q: Float[torch.Tensor, "4"],
    ) -> Float[torch.Tensor, "3 3"]:
        """Convert a unit quaternion to a 3x3 rotation matrix.

        Parameters
        ----------
        q : torch.Tensor
            Unit quaternion ``(w, x, y, z)``, shape ``(4,)``.

        Returns
        -------
        torch.Tensor
            Rotation matrix, shape ``(3, 3)``.
        """
        # 2 dispatches: outer product + matrix-vector multiply.
        return (self._q2r_map.to(q) @ torch.outer(q, q).reshape(16)).reshape(3, 3)

    def _sample_uniform_rotation(self) -> Float[torch.Tensor, "3 3"]:
        """Sample a rotation matrix uniformly from SO(3).

        Uses the random unit quaternion method: sample a 4-D isotropic
        Gaussian vector, normalize to the unit sphere, and convert to a
        rotation matrix.

        Returns
        -------
        torch.Tensor
            Rotation matrix, shape ``(3, 3)``.
        """
        gen_device = (
            self._generator.device if self._generator is not None else self._device
        )
        q = torch.randn(4, generator=self._generator, device=gen_device)
        q = q / q.norm()
        return self._quaternion_to_rotation_matrix(q)

    # ------------------------------------------------------------------
    # __call__ / apply_to_domain
    # ------------------------------------------------------------------

    def _rotate(self, target, n_spatial_dims: int):
        """Sample one rotation and apply it to *target*.

        *target* is either a ``Mesh`` or a ``DomainMesh``; both are used
        through the same ``transform`` / ``rotate`` keyword interface.

        Raises
        ------
        ValueError
            If ``mode="uniform"`` and *n_spatial_dims* is not 3.
        """
        data_flags = {
            "transform_point_data": self.transform_point_data,
            "transform_cell_data": self.transform_cell_data,
            "transform_global_data": self.transform_global_data,
        }
        if self.mode == "uniform":
            if n_spatial_dims != 3:
                raise ValueError(
                    f"mode='uniform' requires 3-D meshes, "
                    f"got n_spatial_dims={n_spatial_dims}"
                )
            R = self._sample_uniform_rotation()
            return target.transform(R, assume_invertible=True, **data_flags)

        axis, angle = self._sample_axis_and_angle()
        return target.rotate(angle, axis=axis, **data_flags)

    def __call__(self, mesh: Mesh[..., 3]) -> Mesh[..., 3]:
        """Apply a random rotation to *mesh*.

        Parameters
        ----------
        mesh : Mesh
            Input mesh.

        Returns
        -------
        Mesh
            Rotated mesh.
        """
        return self._rotate(mesh, mesh.n_spatial_dims)

    def apply_to_domain(self, domain: DomainMesh) -> DomainMesh:
        """Apply a random rotation to every mesh in *domain*.

        A single rotation is sampled and applied consistently to the
        interior and all boundary meshes.

        Parameters
        ----------
        domain : DomainMesh
            Input domain mesh.

        Returns
        -------
        DomainMesh
            Rotated domain mesh.
        """
        return self._rotate(domain, domain.interior.n_spatial_dims)

    def extra_repr(self) -> str:
        if self.mode == "uniform":
            return "mode='uniform'"
        return f"axes={self.axes}, distribution={self._distribution}"
class MeshTransform(ABC):
    r"""
    Base for transforms that take a Mesh and return a Mesh.

    Use for single-mesh pipelines. For multi-mesh (TensorDict[str, Mesh]),
    apply the same transform to each value or use apply_to_tensordict_mesh.
    """

    def __init__(self) -> None:
        # Target device recorded by :meth:`to`; ``None`` until it is called.
        self._device: Optional[torch.device] = None

    @abstractmethod
    def __call__(self, mesh: Mesh) -> Mesh:
        """
        Apply the transform to a mesh.

        Parameters
        ----------
        mesh : Mesh
            Input mesh.

        Returns
        -------
        Mesh
            Transformed mesh.
        """
        raise NotImplementedError

    def apply_to_domain(self, domain: DomainMesh) -> DomainMesh:
        """Apply this transform to a DomainMesh.

        Default: broadcasts ``__call__`` to interior and all boundaries
        via :meth:`DomainMesh._map_meshes`, leaving domain-level
        ``global_data`` unchanged.

        Override in subclasses that need domain-aware behavior (e.g.
        transforms that modify ``global_data``, random augmentations
        that must sample parameters once, or centering transforms).

        Parameters
        ----------
        domain : DomainMesh
            Input domain mesh.

        Returns
        -------
        DomainMesh
            Transformed domain mesh.
        """
        return domain._map_meshes(self)

    @property
    def stochastic(self) -> bool:
        """Whether this transform uses random sampling.

        Returns ``True`` if the instance has a ``_generator`` attribute
        (set by stochastic subclasses such as ``RandomScaleMesh``).
        Deterministic transforms return ``False``.
        """
        return hasattr(self, "_generator")

    def set_generator(self, generator: torch.Generator) -> None:
        """Assign a ``torch.Generator`` for reproducible random sampling.

        Only takes effect on stochastic transforms (those that declare
        ``self._generator``). Deterministic transforms silently ignore
        the call.

        Parameters
        ----------
        generator : torch.Generator
            Generator to use for all subsequent random draws.
        """
        if self.stochastic:
            self._generator = generator

    def set_epoch(self, epoch: int) -> None:
        """Reseed the generator for a new epoch.

        Reseeds ``self._generator`` with ``initial_seed() + epoch`` so
        each epoch produces a different but deterministic random
        sequence. No-op for deterministic transforms or when no
        generator has been assigned.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """
        # NOTE(review): ``initial_seed()`` reports the most recently set seed,
        # so successive calls accumulate (seed+1, then seed+1+2, ...). The
        # seed for epoch k therefore depends on the call history -- confirm
        # this is intended, e.g. when resuming training mid-run.
        if self.stochastic and self._generator is not None:
            self._generator.manual_seed(self._generator.initial_seed() + epoch)

    def _move_distribution(
        self, dist: torch.distributions.Distribution
    ) -> torch.distributions.Distribution:
        """Rebuild *dist* with its parameter tensors on ``self._device``."""
        # Function-scope import keeps the module's import block untouched.
        from torch.distributions.utils import lazy_property

        # Snapshot before any getattr: reading a lazy property caches it in
        # the instance ``__dict__``.
        stored = set(vars(dist))
        kwargs = {}
        # Access arg_constraints on the instance (not the class)
        # because the base Distribution defines it as a @property.
        for param_name in dist.arg_constraints:
            class_attr = getattr(type(dist), param_name, None)
            # An uncached lazy property is the *alternative* parametrisation
            # (e.g. ``logits`` when the user passed ``probs``). Passing both
            # to the constructor raises, so only forward what was supplied.
            if isinstance(class_attr, lazy_property) and param_name not in stored:
                continue
            p = getattr(dist, param_name)
            kwargs[param_name] = (
                p.to(self._device) if isinstance(p, torch.Tensor) else p
            )
        return type(dist)(**kwargs, validate_args=False)

    def to(self, device: torch.device | str) -> MeshTransform:
        """Move any internal tensors, generators, and distributions to *device*.

        ``torch.Generator`` objects cannot be moved in-place, so a new
        generator is created on *device* and seeded with
        :meth:`~torch.Generator.initial_seed` from the original.

        ``torch.distributions.Distribution`` objects are reconstructed
        with their parameter tensors moved to *device*, using
        ``arg_constraints`` to discover parameter names generically.
        For distributions with alternative parametrisations (``probs`` /
        ``logits``, ``covariance_matrix`` / ``scale_tril`` / ...), only
        the parametrisation stored on the instance is forwarded.

        Parameters
        ----------
        device : torch.device or str
            Target device.

        Returns
        -------
        MeshTransform
            ``self``, for chaining.
        """
        self._device = torch.device(device) if isinstance(device, str) else device
        # Iterate over a snapshot: attributes are rebound inside the loop.
        for name, value in list(self.__dict__.items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(self._device))
            elif isinstance(value, torch.Generator):
                new_gen = torch.Generator(device=self._device)
                new_gen.manual_seed(value.initial_seed())
                setattr(self, name, new_gen)
            elif isinstance(value, torch.distributions.Distribution):
                setattr(self, name, self._move_distribution(value))
        return self

    @property
    def device(self) -> torch.device | None:
        """The device that internal tensors and generators reside on.

        Returns ``None`` if :meth:`to` has not been called yet.

        Returns
        -------
        torch.device or None
            Current device, or ``None`` if unset.
        """
        return self._device

    def extra_repr(self) -> str:
        """Return a string of extra information for :meth:`__repr__`.

        Subclasses should override this to include constructor arguments
        or other state that is useful for debugging (e.g.
        ``"scale=0.1, p=0.5"``). The base implementation returns an
        empty string.

        Returns
        -------
        str
            Extra representation string.
        """
        return ""

    def __repr__(self) -> str:
        """Return a human-readable string representation of the transform.

        The format is ``ClassName(extra_repr())``, mirroring the
        convention used by :class:`torch.nn.Module`.

        Returns
        -------
        str
            String representation.
        """
        return f"{self.__class__.__name__}({self.extra_repr()})"
+""" + +from __future__ import annotations + +from typing import Literal + +import torch +from jaxtyping import Float, Int +from tensordict import TensorDict + +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms.mesh.base import MeshTransform +from physicsnemo.datapipes.transforms.subsample import poisson_sample_indices_fixed +from physicsnemo.mesh import DomainMesh, Mesh + + +@register() +class ScaleMesh(MeshTransform): + r"""Scale mesh geometry (and optionally point/cell/global data) by a uniform factor.""" + + def __init__( + self, + factor: float | Float[torch.Tensor, ""], + transform_point_data: bool = False, + transform_cell_data: bool = False, + transform_global_data: bool = False, + ) -> None: + super().__init__() + self.factor = factor + self.transform_point_data = transform_point_data + self.transform_cell_data = transform_cell_data + self.transform_global_data = transform_global_data + + def __call__(self, mesh: Mesh) -> Mesh: + return mesh.scale( + self.factor, + transform_point_data=self.transform_point_data, + transform_cell_data=self.transform_cell_data, + transform_global_data=self.transform_global_data, + ) + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Apply uniform scaling to a :class:`DomainMesh`. + + Parameters + ---------- + domain : DomainMesh + Input domain mesh (interior + boundaries). + + Returns + ------- + DomainMesh + Scaled domain mesh. 
@register()
class TranslateMesh(MeshTransform):
    r"""Shift mesh geometry by a fixed offset vector."""

    def __init__(
        self, vector: Float[torch.Tensor, " spatial_dims"] | list[float]
    ) -> None:
        super().__init__()
        # Tensors are stored as given; anything else is materialised as float32.
        self.vector = (
            vector
            if isinstance(vector, torch.Tensor)
            else torch.tensor(vector, dtype=torch.float32)
        )

    def __call__(self, mesh: Mesh) -> Mesh:
        # The offset follows the mesh onto whatever device its points live on.
        offset = self.vector.to(mesh.points.device)
        return mesh.translate(offset)

    def apply_to_domain(self, domain: DomainMesh) -> DomainMesh:
        """Translate a :class:`DomainMesh` by the stored vector.

        Parameters
        ----------
        domain : DomainMesh
            Input domain mesh (interior + boundaries).

        Returns
        -------
        DomainMesh
            Translated domain mesh.
        """
        target_device = domain.interior.points.device
        offset = self.vector.to(target_device)
        return domain.translate(offset)

    def extra_repr(self) -> str:
        components = self.vector.tolist()
        return f"vector={components}"
transform_point_data=self.transform_point_data, + transform_cell_data=self.transform_cell_data, + transform_global_data=self.transform_global_data, + ) + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Apply rotation to a :class:`DomainMesh`. + + Parameters + ---------- + domain : DomainMesh + Input domain mesh (interior + boundaries). + + Returns + ------- + DomainMesh + Rotated domain mesh. + """ + return domain.rotate( + self.angle, + axis=self.axis, + center=self.center, + transform_point_data=self.transform_point_data, + transform_cell_data=self.transform_cell_data, + transform_global_data=self.transform_global_data, + ) + + def extra_repr(self) -> str: + parts = [f"angle={self.angle}"] + if self.axis is not None: + parts.append(f"axis={self.axis}") + if self.center is not None: + parts.append(f"center={self.center}") + return ", ".join(parts) + + +@register() +class CenterMesh(MeshTransform): + r"""Translate mesh so its center of mass is at the origin.""" + + def __init__(self, use_area_weighting: bool = True) -> None: + super().__init__() + self.use_area_weighting = use_area_weighting + + def _compute_com(self, mesh: Mesh) -> Float[torch.Tensor, " spatial_dims"]: + """Compute center of mass for a single mesh.""" + if self.use_area_weighting and mesh.n_cells > 0: + areas = mesh.cell_areas # (n_cells,) + centroids = mesh.cell_centroids # (n_cells, n_spatial_dims) + total_area = areas.sum() + return (centroids * areas.unsqueeze(-1)).sum(dim=0) / total_area + return mesh.points.mean(dim=0) + + def __call__(self, mesh: Mesh) -> Mesh: + return mesh.translate(-self._compute_com(mesh)) + + def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: + """Translate a :class:`DomainMesh` so its interior center of mass is at the origin. + + The center of mass is computed from the interior mesh and the same + translation is applied to all boundaries to keep them consistent. 
def _compact_points(mesh: Mesh) -> Mesh:
    """Drop points that no cell references and renumber the cell indices.

    Returns *mesh* itself when it has no cells or when every point is
    already referenced.
    """
    if mesh.n_cells == 0:
        return mesh

    # Sorted unique point indices that appear in at least one cell.
    kept = torch.unique(mesh.cells)
    if kept.numel() == mesh.n_points:
        return mesh

    device = mesh.cells.device
    # old index -> new index; entries for dropped points stay uninitialised
    # and are never read because cells only reference kept points.
    old_to_new = torch.empty(mesh.n_points, dtype=torch.long, device=device)
    old_to_new[kept] = torch.arange(kept.numel(), device=device)

    if mesh.point_data.keys():
        compact_point_data = mesh.point_data[kept]
    else:
        compact_point_data = mesh.point_data

    return Mesh(
        points=mesh.points[kept],
        cells=old_to_new[mesh.cells],
        point_data=compact_point_data,
        cell_data=mesh.cell_data,
        global_data=mesh.global_data,
    )
None and mesh.n_cells > self.n_cells: + indices = self._random_indices( + mesh.n_cells, self.n_cells, mesh.cells.device + ) + mesh = mesh.slice_cells(indices) + if self.compact: + mesh = _compact_points(mesh) + + if self.n_points is not None and mesh.n_points > self.n_points: + indices = self._random_indices( + mesh.n_points, self.n_points, mesh.points.device + ) + mesh = mesh.slice_points(indices) + + return mesh + + def extra_repr(self) -> str: + parts = [] + if self.n_cells is not None: + parts.append(f"n_cells={self.n_cells}") + if self.n_points is not None: + parts.append(f"n_points={self.n_points}") + return ", ".join(parts) + + +def _rename_td_keys(td: TensorDict, mapping: dict[str, str]) -> TensorDict: + """Rename keys in a TensorDict, returning a new TensorDict.""" + out = td.clone() + for old_key, new_key in mapping.items(): + if old_key in out.keys(): + out[new_key] = out.pop(old_key) + return out + + +@register() +class DropMeshFields(MeshTransform): + r"""Remove fields from a Mesh's point_data, cell_data, or global_data. + + Useful for dropping fields that would interfere with downstream + transforms (e.g. removing a scalar ``TimeValue`` from ``global_data`` + before a rotation that expects all global fields to be 3-vectors). 
@register()
class RenameMeshFields(MeshTransform):
    r"""Rename fields in a Mesh's point_data, cell_data, or global_data.

    Handy for aligning field names across datasets that store the same
    physical quantity under different keys (e.g. ``pMeanTrim`` vs
    ``pressure_average``).
    """

    def __init__(
        self,
        point_data: dict[str, str] | None = None,
        cell_data: dict[str, str] | None = None,
        global_data: dict[str, str] | None = None,
    ) -> None:
        super().__init__()
        # A missing (or empty) mapping means "leave that section untouched".
        self._point_data_map = point_data or {}
        self._cell_data_map = cell_data or {}
        self._global_data_map = global_data or {}

    def __call__(self, mesh: Mesh) -> Mesh:
        per_section = (
            ("point_data", self._point_data_map),
            ("cell_data", self._cell_data_map),
            ("global_data", self._global_data_map),
        )
        sections = {}
        for attr, renames in per_section:
            current = getattr(mesh, attr)
            # Sections without renames are passed through as the same object.
            sections[attr] = _rename_td_keys(current, renames) if renames else current
        return Mesh(points=mesh.points, cells=mesh.cells, **sections)

    def extra_repr(self) -> str:
        labelled = (
            ("point_data", self._point_data_map),
            ("cell_data", self._cell_data_map),
            ("global_data", self._global_data_map),
        )
        return ", ".join(f"{label}={renames}" for label, renames in labelled if renames)
@register()
class NormalizeMeshFields(MeshTransform):
    r"""Standardize mesh data fields with direction-preserving vector support.

    For **scalar** fields: ``(x - mean) / std``.

    For **vector** fields: ``(x - mean_vec) / std_shared`` where
    ``mean_vec`` is a per-component mean and ``std_shared`` is a single
    scalar applied uniformly to all components. This preserves relative
    component magnitudes (and therefore vector direction) while bringing
    the overall field scale to O(1).

    Statistics may come from two sources (checked in order):

    1. **stats_file** — path to a ``.pt`` file mapping field names to
       dicts with keys ``type``, ``mean``, ``std``.
    2. **fields** — inline dict supplied directly in YAML.

    Example YAML (inline)::

        - _target_: ${dp:NormalizeMeshFields}
          section: point_data
          fields:
            pressure: {type: scalar, mean: -0.15, std: 0.45}
            wss: {type: vector, mean: [0.003, 0.0, 0.0], std: 0.005}

    Example YAML (from .pt file)::

        - _target_: ${dp:NormalizeMeshFields}
          section: point_data
          stats_file: /path/to/norm_stats.pt
    """

    # Mesh sections this transform can rewrite; these are exactly the keyword
    # names substituted in ``__call__``.
    _SECTIONS = ("point_data", "cell_data", "global_data")

    def __init__(
        self,
        section: str = "point_data",
        fields: dict[str, dict] | None = None,
        stats_file: str | None = None,
        eps: float = 1e-8,
    ) -> None:
        """
        Parameters
        ----------
        section : str
            One of ``"point_data"``, ``"cell_data"``, ``"global_data"``.
        fields : dict[str, dict] or None
            Inline statistics, ``{name: {"type", "mean", "std"}}``.
            Ignored when *stats_file* is given.
        stats_file : str or None
            Path to a ``.pt`` file holding the statistics dict.
        eps : float
            Added to ``std`` before dividing (and when inverting).

        Raises
        ------
        ValueError
            If *section* is unknown, or if neither *stats_file* nor
            *fields* is provided.
        """
        super().__init__()
        # Reject a bad section name here instead of on the first sample.
        if section not in self._SECTIONS:
            raise ValueError(f"Unknown mesh section: {section!r}")
        self._section = section
        self._eps = eps

        if stats_file is not None:
            self._stats: dict[str, dict[str, Float[torch.Tensor, " *shape"] | str]] = (
                torch.load(stats_file, weights_only=True)
            )
        elif fields is not None:
            self._stats = {}
            for name, cfg in fields.items():
                self._stats[name] = {
                    "type": cfg["type"],
                    "mean": torch.as_tensor(cfg["mean"], dtype=torch.float32),
                    "std": torch.as_tensor(cfg["std"], dtype=torch.float32),
                }
        else:
            raise ValueError("Provide one of 'stats_file' or 'fields'")

    def __call__(self, mesh: Mesh) -> Mesh:
        td = _get_mesh_section(mesh, self._section)
        new_td = td.clone()

        for field_name, stats in self._stats.items():
            # Fields with stats but absent from this mesh are skipped.
            if field_name not in new_td.keys():
                continue
            val = new_td[field_name].float()
            mean = stats["mean"].to(dtype=val.dtype, device=val.device)
            std = stats["std"].to(dtype=val.dtype, device=val.device)
            new_td[field_name] = (val - mean) / (std + self._eps)

        kwargs: dict = {
            "points": mesh.points,
            "cells": mesh.cells,
            "point_data": mesh.point_data,
            "cell_data": mesh.cell_data,
            "global_data": mesh.global_data,
        }
        # Swap in the normalized section; the others are passed through.
        kwargs[self._section] = new_td
        return Mesh(**kwargs)

    def inverse_tensor(
        self,
        tensor: Float[torch.Tensor, "*batch channels"],
        target_config: dict[str, str],
        n_spatial_dims: int = 3,
    ) -> Float[torch.Tensor, "*batch channels"]:
        """Un-normalize a concatenated output tensor back to physical units.

        Fields present in ``target_config`` but absent from the stored
        normalization stats are passed through unchanged (their channels
        are skipped). This allows partial normalization (e.g. only WSS)
        without requiring every field to have stats.

        Parameters
        ----------
        tensor : Tensor
            Shape ``(*, C)`` where channels are ordered according to
            *target_config*.
        target_config : dict[str, str]
            Ordered mapping of ``{field_name: field_type}`` matching the
            channel layout, e.g. ``{"pressure": "scalar", "wss": "vector"}``.
        n_spatial_dims : int, optional
            Dimensionality of vector fields. Default is 3.

        Returns
        -------
        Tensor
            Same shape, with each normalized field's channels un-normalized.
        """
        out = tensor.clone()
        idx = 0
        for name, ftype in target_config.items():
            # Any type other than "scalar" occupies n_spatial_dims channels.
            dim = 1 if ftype == "scalar" else n_spatial_dims
            if name in self._stats:
                stats = self._stats[name]
                mean = stats["mean"].to(dtype=tensor.dtype, device=tensor.device)
                std = stats["std"].to(dtype=tensor.dtype, device=tensor.device)
                out[..., idx : idx + dim] = (
                    out[..., idx : idx + dim] * (std + self._eps) + mean
                )
            idx += dim
        return out

    @property
    def stats(self) -> dict:
        """Normalization statistics dict (for serialization)."""
        return self._stats

    def extra_repr(self) -> str:
        # Join everything in one pass so an empty stats dict does not leave a
        # dangling ", " after the section name.
        parts = [f"section={self._section}"]
        parts.extend(
            f"{name}({s['type']}): mean={s['mean']}, std={s['std']}"
            for name, s in self._stats.items()
        )
        return ", ".join(parts)
+ + Parameters + ---------- + store_as : {"cell_data", "point_data"} + Where to store the computed normals. ``"cell_data"`` stores one + normal per cell (the face normal). ``"point_data"`` stores one + normal per vertex (angle-area weighted average of adjacent face + normals). Both modes require the mesh to have cells. + field_name : str + Key under which to store the normals. Default ``"normals"``. + """ + + def __init__( + self, + store_as: Literal["cell_data", "point_data"] = "cell_data", + field_name: str = "normals", + ) -> None: + super().__init__() + if store_as not in ("cell_data", "point_data"): + raise ValueError( + f"store_as must be 'cell_data' or 'point_data', got {store_as!r}" + ) + self.store_as = store_as + self.field_name = field_name + + def __call__(self, mesh: Mesh) -> Mesh: + if self.store_as == "cell_data": + normals = mesh.cell_normals + new_cd = mesh.cell_data.clone() + new_cd[self.field_name] = normals + return Mesh( + points=mesh.points, + cells=mesh.cells, + point_data=mesh.point_data, + cell_data=new_cd, + global_data=mesh.global_data, + ) + else: + normals = mesh.point_normals + new_pd = mesh.point_data.clone() + new_pd[self.field_name] = normals + return Mesh( + points=mesh.points, + cells=mesh.cells, + point_data=new_pd, + cell_data=mesh.cell_data, + global_data=mesh.global_data, + ) + + def extra_repr(self) -> str: + return f"store_as={self.store_as!r}, field_name={self.field_name!r}" + + +def _mesh_to_tensordict(mesh: Mesh) -> TensorDict: + """Convert a single Mesh into a flat TensorDict (no cache, no tensorclass).""" + out: dict = { + "points": mesh.points, + "cells": mesh.cells, + } + if mesh.point_data.keys(): + out["point_data"] = mesh.point_data.clone() + if mesh.cell_data.keys(): + out["cell_data"] = mesh.cell_data.clone() + if mesh.global_data.keys(): + out["global_data"] = mesh.global_data.clone() + return TensorDict(out, batch_size=[]) + + +@register() +class MeshToTensorDict(MeshTransform): + r"""Convert a Mesh or 
DomainMesh into a plain TensorDict. + + This is a terminal transform -- place it last in the transform chain. + After conversion the data is no longer a Mesh and cannot be passed to + other MeshTransform instances. + + For a single :class:`Mesh` the output layout is:: + + TensorDict({ + "points": (N_p, D_s), + "cells": (N_c, D_m+1), + "point_data": TensorDict({field: tensor, ...}), + "cell_data": TensorDict({field: tensor, ...}), + "global_data": TensorDict({field: tensor, ...}), + }) + + For a :class:`DomainMesh` the output layout is:: + + TensorDict({ + "interior": TensorDict({points, cells, ...}), + "boundaries": TensorDict({ + "wall": TensorDict({points, cells, ...}), + ... + }), + "global_data": TensorDict({field: tensor, ...}), + }) + """ + + def __call__(self, mesh: Mesh) -> TensorDict: # type: ignore[override] + return _mesh_to_tensordict(mesh) + + def apply_to_domain(self, domain: DomainMesh) -> TensorDict: # type: ignore[override] + """Convert a :class:`DomainMesh` into a nested :class:`TensorDict`. + + The output contains an ``"interior"`` key with the interior mesh + converted via :func:`_mesh_to_tensordict`, an optional + ``"boundaries"`` sub-dict keyed by boundary name, and an optional + ``"global_data"`` entry. + + Parameters + ---------- + domain : DomainMesh + Input domain mesh (interior + boundaries). + + Returns + ------- + TensorDict + Nested TensorDict representation of the domain. 
+ """ + out: dict = { + "interior": _mesh_to_tensordict(domain.interior), + } + if domain.n_boundaries > 0: + out["boundaries"] = TensorDict( + { + name: _mesh_to_tensordict(domain.boundaries[name]) + for name in domain.boundary_names + }, + batch_size=[], + ) + if domain.global_data.keys(): + out["global_data"] = domain.global_data.clone() + return TensorDict(out, batch_size=[]) + + +def _resolve_td_path(td: TensorDict, dotted_key: str) -> Float[torch.Tensor, " *shape"]: + """Resolve a dot-separated key path into a tensor from a TensorDict.""" + parts = dotted_key.split(".") + current = td + for part in parts: + current = current[part] + return current + + +@register() +class ComputeCellCentroids(MeshTransform): + r"""Compute cell centroids from points and cells in a TensorDict. + + Placed after :class:`MeshToTensorDict`, this adds a ``cell_centroids`` + key of shape :math:`(N_c, D_s)` computed as the mean of each cell's + vertex positions. Requires ``points`` and ``cells`` to be present. + """ + + def __call__(self, td: TensorDict) -> TensorDict: # type: ignore[override] + points = td["points"] + cells = td["cells"] + centroids = points[cells].mean(dim=1) + td = td.clone() + td["cell_centroids"] = centroids + return td + + +@register() +class RestructureTensorDict(MeshTransform): + r"""Reorganize a flat TensorDict into named groups. + + Placed after :class:`MeshToTensorDict`, this transform picks fields + from the flat layout and assembles them into a structured dict + (e.g. separate ``input`` and ``output`` groups for model training). + + Each group is defined as ``{dest_key: source_path}`` where + ``source_path`` uses dots for nesting (e.g. ``point_data.pressure``). 
+ + Example YAML:: + + - _target_: ${dp:RestructureTensorDict} + groups: + input: + points: points + inlet_velocity: global_data.inlet_velocity + output: + pressure: point_data.pressure + wss: point_data.wss + """ + + def __init__(self, groups: dict[str, dict[str, str]]) -> None: + super().__init__() + self._groups = groups + + def __call__(self, td: TensorDict) -> TensorDict: # type: ignore[override] + out: dict = {} + for group_name, mapping in self._groups.items(): + group: dict = {} + for dest_key, source_path in mapping.items(): + group[dest_key] = _resolve_td_path(td, source_path) + out[group_name] = TensorDict(group, batch_size=[]) + return TensorDict(out, batch_size=[]) + + def extra_repr(self) -> str: + lines = [] + for group, mapping in self._groups.items(): + sources = ", ".join(f"{k}<-{v}" for k, v in mapping.items()) + lines.append(f"{group}: {{{sources}}}") + return "; ".join(lines) diff --git a/physicsnemo/datapipes/transforms/subsample.py b/physicsnemo/datapipes/transforms/subsample.py index 51d446fbeb..035cf36b2f 100644 --- a/physicsnemo/datapipes/transforms/subsample.py +++ b/physicsnemo/datapipes/transforms/subsample.py @@ -32,7 +32,12 @@ from physicsnemo.datapipes.transforms.base import Transform -def poisson_sample_indices_fixed(N: int, k: int, device=None) -> torch.Tensor: +def poisson_sample_indices_fixed( + N: int, + k: int, + device=None, + generator: torch.Generator | None = None, +) -> torch.Tensor: """ Near-uniform sampler of indices for very large arrays. @@ -52,6 +57,8 @@ def poisson_sample_indices_fixed(N: int, k: int, device=None) -> torch.Tensor: Number of indices to sample. device : torch.device, optional Device for the output tensor. + generator : torch.Generator, optional + Random generator for reproducibility. 
Returns ------- @@ -65,7 +72,7 @@ def poisson_sample_indices_fixed(N: int, k: int, device=None) -> torch.Tensor: torch.Size([10000]) """ # Draw exponential gaps off of random initializations - gaps = torch.rand(k, device=device).exponential_() + gaps = torch.rand(k, device=device, generator=generator).exponential_() summed = gaps.sum() @@ -88,6 +95,7 @@ def shuffle_array( points: torch.Tensor, n_points: int, weights: Optional[torch.Tensor] = None, + generator: torch.Generator | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Sample points with or without weights. @@ -101,6 +109,8 @@ def shuffle_array( weights : torch.Tensor, optional Optional weights for sampling, shape :math:`(N,)`. If None, uses uniform sampling. + generator : torch.Generator, optional + Random generator for reproducibility. Returns ------- @@ -119,15 +129,19 @@ def shuffle_array( if weights is not None: # Weighted sampling - indices = torch.multinomial(weights, n_points, replacement=False) + indices = torch.multinomial( + weights, n_points, replacement=False, generator=generator + ) else: # Uniform sampling if N > 2**24: # Use Poisson sampling for very large arrays - indices = poisson_sample_indices_fixed(N, n_points, device=device) + indices = poisson_sample_indices_fixed( + N, n_points, device=device, generator=generator + ) else: # Use standard multinomial for smaller arrays - indices = torch.randperm(N, device=device)[:n_points] + indices = torch.randperm(N, device=device, generator=generator)[:n_points] sampled_points = points[indices] return sampled_points, indices @@ -236,6 +250,7 @@ def __init__( self.n_points = n_points self.algorithm = algorithm self.weights_key = weights_key + self._generator: torch.Generator | None = None def __call__(self, data: TensorDict) -> TensorDict: """ @@ -300,12 +315,26 @@ def __call__(self, data: TensorDict) -> TensorDict: device = first_tensor.device if weights is not None: # Weighted sampling - _, indices = shuffle_array(first_tensor, 
self.n_points, weights=weights) + _, indices = shuffle_array( + first_tensor, + self.n_points, + weights=weights, + generator=self._generator, + ) elif self.algorithm == "poisson_fixed" and N > 2**24: - indices = poisson_sample_indices_fixed(N, self.n_points, device=device) + indices = poisson_sample_indices_fixed( + N, + self.n_points, + device=device, + generator=self._generator, + ) else: # Use uniform sampling - indices = torch.randperm(N, device=device)[: self.n_points] + indices = torch.randperm( + N, + device=device, + generator=self._generator, + )[: self.n_points] # Apply indices to all keys updates = {} diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index 254d690e31..e87eed41b2 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -69,6 +69,11 @@ class GALE(PhysicsAttentionIrregularMesh): Whether to use Transolver++ features. Default is False. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + a learnable sigmoid-gated weighted sum. ``"concat_project"`` + concatenates the two along the head dimension and projects back with a + linear layer. Default is ``"weighted"``. Forward ------- @@ -121,9 +126,17 @@ def __init__( plus: bool = False, context_dim: int = 0, concrete_dropout: bool = False, + state_mixing_mode: str = "weighted", ) -> None: super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) + if state_mixing_mode not in ("weighted", "concat_project"): + raise ValueError( + f"Invalid state_mixing_mode: {state_mixing_mode!r}. " + f"Expected 'weighted' or 'concat_project'." 
+ ) + self.state_mixing_mode = state_mixing_mode + linear_layer = te.Linear if self.use_te else nn.Linear # Cross-attention projection layers for context integration @@ -131,9 +144,17 @@ def __init__( self.cross_k = linear_layer(context_dim, dim_head) self.cross_v = linear_layer(context_dim, dim_head) - # Learnable mixing weight between self and cross attention - # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing - self.state_mixing = nn.Parameter(torch.tensor(0.0)) + # Mixing layers for blending self-attention and cross-attention + if state_mixing_mode == "weighted": + # Learnable mixing weight between self and cross attention + # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + else: + # Concatenate self and cross attention and project back to dim_head + self.concat_project = nn.Sequential( + linear_layer(2 * dim_head, dim_head), + nn.GELU(), + ) # Replace inherited out_dropout with ConcreteDropout when enabled if concrete_dropout: @@ -277,12 +298,18 @@ def forward( for _slice_token in slice_tokens ] - # Blend self-attention and cross-attention with learnable mixing weight - mixing_weight = torch.sigmoid(self.state_mixing) - out_slice_token = [ - mixing_weight * sst + (1 - mixing_weight) * cst - for sst, cst in zip(self_slice_token, cross_slice_token) - ] + # Blend self-attention and cross-attention + if self.state_mixing_mode == "weighted": + mixing_weight = torch.sigmoid(self.state_mixing) + out_slice_token = [ + mixing_weight * sst + (1 - mixing_weight) * cst + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + out_slice_token = [ + self.concat_project(torch.cat([sst, cst], dim=-1)) + for sst, cst in zip(self_slice_token, cross_slice_token) + ] else: # Use only self-attention when no context is provided out_slice_token = self_slice_token @@ -330,6 +357,11 @@ class GALE_block(nn.Module): attention_type : str, optional attention_type is 
used to choose the attention type (GALE or GALE_FA). Default is ``"GALE"``. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + a learnable sigmoid-gated weighted sum. ``"concat_project"`` + concatenates the two along the head dimension and projects back with a + linear layer. Default is ``"weighted"``. Forward ------- @@ -384,6 +416,7 @@ def __init__( context_dim: int = 0, attention_type: str = "GALE", concrete_dropout: bool = False, + state_mixing_mode: str = "weighted", ) -> None: super().__init__() @@ -414,6 +447,7 @@ def __init__( plus=plus, context_dim=context_dim, concrete_dropout=concrete_dropout, + state_mixing_mode=state_mixing_mode, ) case 'GALE_FA': self.Attn = GALE_FA( @@ -425,6 +459,7 @@ def __init__( use_te=use_te, context_dim=context_dim, concrete_dropout=concrete_dropout, + state_mixing_mode=state_mixing_mode, ) case _: raise ValueError( diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py index ebdebd4e68..eb9f8a3478 100644 --- a/physicsnemo/experimental/models/geotransolver/gale_fa.py +++ b/physicsnemo/experimental/models/geotransolver/gale_fa.py @@ -68,6 +68,11 @@ class GALE_FA(nn.Module): concrete_dropout : bool, optional Whether to use learned concrete dropout instead of standard dropout. Default is ``False``. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + a learnable sigmoid-gated weighted sum. ``"concat_project"`` + concatenates the two along the head dimension and projects back with a + linear layer. Default is ``"weighted"``. 
Forward ------- @@ -119,6 +124,7 @@ def __init__( use_te: bool = True, context_dim: int = 0, concrete_dropout: bool = False, + state_mixing_mode: str = "weighted", ): if use_te: raise ValueError( @@ -126,6 +132,12 @@ def __init__( "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." ) super().__init__() + if state_mixing_mode not in ("weighted", "concat_project"): + raise ValueError( + f"Invalid state_mixing_mode: {state_mixing_mode!r}. " + f"Expected 'weighted' or 'concat_project'." + ) + self.state_mixing_mode = state_mixing_mode self.use_te = use_te self.heads = heads self.dim_head = dim_head @@ -150,8 +162,16 @@ def __init__( self.cross_k = linear_layer(context_dim, dim_head) self.cross_v = linear_layer(context_dim, dim_head) - # Learnable mixing weight between self and cross attention - self.state_mixing = nn.Parameter(torch.tensor(0.0)) + # Mixing layers for blending self-attention and cross-attention + if state_mixing_mode == "weighted": + # Learnable mixing weight between self and cross attention + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + else: + # Concatenate self and cross attention and project back to dim_head + self.concat_project = nn.Sequential( + linear_layer(2 * dim_head, dim_head), + nn.GELU(), + ) # te attention if self.use_te: @@ -249,9 +269,12 @@ def forward( else: cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=self.scale) for _q in q] - # Apply learnable mixing: - mixing_weight = torch.sigmoid(self.state_mixing) - outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)] + # Blend self-attention and cross-attention + if self.state_mixing_mode == "weighted": + mixing_weight = torch.sigmoid(self.state_mixing) + outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)] + else: + outputs = [self.concat_project(torch.cat([_ys, _yc], dim=-1)) for _ys, _yc in 
zip(self_attention, cross_attention)] else: outputs = self_attention diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 6460960d08..9fe2e87fb2 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -207,6 +207,11 @@ class GeoTransolver(Module): attention_type : str, optional attention_type is used to choose the attention type (GALE or GALE_FA). Default is ``"GALE"``. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs in GALE layers. + ``"weighted"`` uses a learnable sigmoid-gated weighted sum. + ``"concat_project"`` concatenates the two along the head dimension and + projects back with a linear layer. Default is ``"weighted"``. Forward ------- @@ -337,6 +342,7 @@ def __init__( n_hidden_local: int = 32, attention_type: str = "GALE", concrete_dropout: bool = False, + state_mixing_mode: str = "weighted", ) -> None: super().__init__(meta=GeoTransolverMetaData()) self.__name__ = "GeoTransolver" @@ -429,6 +435,7 @@ def __init__( context_dim=context_dim, attention_type=attention_type, concrete_dropout=concrete_dropout, + state_mixing_mode=state_mixing_mode, ) for layer_idx in range(n_layers) ] diff --git a/physicsnemo/mesh/domain_mesh.py b/physicsnemo/mesh/domain_mesh.py index e477cbcf98..6b7298a4dd 100644 --- a/physicsnemo/mesh/domain_mesh.py +++ b/physicsnemo/mesh/domain_mesh.py @@ -814,7 +814,7 @@ def n_boundaries(self) -> int: int The number of entries in ``boundaries``. 
""" - return len(self.boundaries) + return len(list(self.boundaries.keys())) ### Methods diff --git a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py index 6698be028d..a1e89c74e1 100644 --- a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py @@ -424,33 +424,38 @@ def radius_search_impl_fake( indices = torch.empty( queries.shape[0], max_points, dtype=torch.int32, device=queries.device ) - if max_points is not None: - num_neighbors = torch.empty( - queries.shape[0], dtype=torch.int32, device=queries.device - ) - else: - num_neighbors = torch.empty(0, dtype=torch.int32, device=queries.device) + num_neighbors = torch.empty( + queries.shape[0], dtype=torch.int32, device=queries.device + ) + # Dtype must match the real op, which returns `distances.to(points.dtype)` + # and `points.to(points.dtype)`. Hard-coding fp32 here causes Inductor to + # emit kernels with the wrong strides/byte-counts under bf16/fp16 and + # triggers cudaErrorIllegalAddress downstream. if return_dists: distances = torch.empty( queries.shape[0], max_points, - dtype=torch.float32, + dtype=points.dtype, device=queries.device, ) else: - distances = torch.empty(0, dtype=torch.float32, device=queries.device) + distances = torch.empty(0, dtype=points.dtype, device=queries.device) if return_points: out_points = torch.empty( queries.shape[0], max_points, 3, - dtype=torch.float32, + dtype=points.dtype, device=queries.device, ) else: - out_points = torch.empty(0, 3, dtype=torch.float32, device=queries.device) + # Real op returns rank-3 (0, max_points, 3); keep the rank consistent + # so downstream shape/stride assumptions in compiled graphs hold. 
+ out_points = torch.empty( + 0, max_points, 3, dtype=points.dtype, device=queries.device + ) return indices, out_points, distances, num_neighbors @@ -601,6 +606,7 @@ def apply_grad_to_points( @apply_grad_to_points.register_fake def apply_grad_to_points_fake( indexes: torch.Tensor, + num_neighbors: torch.Tensor, grad_points_out: torch.Tensor, points_shape: List[int], max_points: int | None = None, @@ -610,6 +616,8 @@ def apply_grad_to_points_fake( Args: indexes (torch.Tensor): The indices mapping output points to input points. + num_neighbors (torch.Tensor): The per-query neighbor counts (only used when + ``max_points`` is not None, but always present to match the real op signature). grad_points_out (torch.Tensor): The gradient of the output points. points_shape (torch.Size): The shape of the input points tensor. @@ -636,6 +644,23 @@ def radius_search( return_dists: bool = False, return_points: bool = False, ): + """ + Perform a radius search between points and queries. + + Args: + points (torch.Tensor): The input points tensor. + queries (torch.Tensor): The query points tensor. + radius (float): The search radius. + max_points (int | None): The maximum number of neighbors per query, or + None for unlimited. + return_dists (bool): Whether to return distances between query and + neighbor points. + return_points (bool): Whether to return the neighbor points themselves. + + Returns: + The formatted radius search results, whose contents depend on + ``return_dists`` and ``return_points``. 
+ """ indices, points_out, distances, _ = radius_search_impl( points, queries, radius, max_points, return_dists, return_points ) diff --git a/test/datapipes/core/test_dataset.py b/test/datapipes/core/test_dataset.py index 79286dc480..2010893dea 100644 --- a/test/datapipes/core/test_dataset.py +++ b/test/datapipes/core/test_dataset.py @@ -277,26 +277,10 @@ def test_prefetch_single(self, numpy_data_dir): # Prefetch index 0 dataset.prefetch(0) - # Should have 1 prefetch in flight (may complete quickly) - assert dataset.prefetch_count >= 0 - # Get should use prefetched result data, metadata = dataset[0] assert "positions" in data - def test_prefetch_batch(self, numpy_data_dir): - """Test prefetching multiple samples.""" - reader = dp.NumpyReader(numpy_data_dir) - dataset = dp.Dataset(reader) - - # Prefetch multiple indices - dataset.prefetch_batch([0, 1, 2, 3]) - - # Get samples - for i in range(4): - data, metadata = dataset[i] - assert metadata["index"] == i - def test_prefetch_non_prefetched_index(self, numpy_data_dir): """Test getting a non-prefetched index loads synchronously.""" reader = dp.NumpyReader(numpy_data_dir) @@ -316,10 +300,10 @@ def test_prefetch_skips_if_already_in_flight(self, numpy_data_dir): # Prefetch same index twice dataset.prefetch(0) - initial_count = dataset.prefetch_count + initial_count = len(dataset._prefetch_futures) dataset.prefetch(0) # Should be a no-op - assert dataset.prefetch_count == initial_count + assert len(dataset._prefetch_futures) == initial_count # Still should be able to get the data data, metadata = dataset[0] @@ -356,7 +340,7 @@ def test_prefetch_then_getitem_workflow(self, numpy_data_dir): assert metadata["index"] == i # Prefetch count should be 0 after retrieving all - assert dataset.prefetch_count == 0 + assert len(dataset._prefetch_futures) == 0 # ============================================================================ @@ -381,39 +365,6 @@ def test_prefetch_with_stream(self, numpy_data_dir): 
torch.cuda.synchronize() - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_prefetch_batch_with_streams(self, numpy_data_dir): - """Test prefetch_batch with multiple CUDA streams.""" - reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) - dataset = dp.Dataset(reader, device="cuda:0") - - streams = [torch.cuda.Stream() for _ in range(4)] - dataset.prefetch_batch([0, 1, 2, 3], streams=streams) - - for i in range(4): - data, metadata = dataset[i] - assert data["positions"].device.type == "cuda" - - torch.cuda.synchronize() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_prefetch_batch_with_stream_cycling(self, numpy_data_dir): - """Test prefetch_batch cycles through streams correctly.""" - reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) - dataset = dp.Dataset(reader, device="cuda:0") - - # Use fewer streams than indices to test cycling - streams = [torch.cuda.Stream() for _ in range(2)] - dataset.prefetch_batch([0, 1, 2, 3, 4], streams=streams) - - # All samples should be retrievable - for i in range(5): - data, metadata = dataset[i] - assert metadata["index"] == i - assert data["positions"].device.type == "cuda" - - torch.cuda.synchronize() - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_prefetch_with_stream_and_transforms(self, numpy_data_dir): """Test prefetching with CUDA stream and transforms.""" @@ -474,11 +425,11 @@ def test_prefetch_cancel_all(self, numpy_data_dir): reader = dp.NumpyReader(numpy_data_dir) dataset = dp.Dataset(reader) - dataset.prefetch_batch([0, 1, 2, 3]) + for i in range(4): + dataset.prefetch(i) dataset.cancel_prefetch() - # Prefetch count should be 0 after cancel - assert dataset.prefetch_count == 0 + assert len(dataset._prefetch_futures) == 0 def test_prefetch_cancel_specific(self, numpy_data_dir): """Test canceling a specific prefetch.""" @@ -525,11 +476,11 @@ def test_close_stops_prefetch(self, 
numpy_data_dir): reader = dp.NumpyReader(numpy_data_dir) dataset = dp.Dataset(reader) - dataset.prefetch_batch([0, 1, 2, 3]) + for i in range(4): + dataset.prefetch(i) dataset.close() - # Should not raise, prefetch should be stopped - assert dataset.prefetch_count == 0 + assert len(dataset._prefetch_futures) == 0 def test_close_shuts_down_executor(self, numpy_data_dir): """Test that close shuts down the executor.""" @@ -566,7 +517,7 @@ def test_context_manager_cleans_up(self, numpy_data_dir): # After context exit, executor should be shut down assert dataset._executor is None - assert dataset.prefetch_count == 0 + assert len(dataset._prefetch_futures) == 0 # ============================================================================ @@ -679,7 +630,8 @@ def test_full_gpu_pipeline(self, numpy_data_dir): # Prefetch with streams streams = [torch.cuda.Stream() for _ in range(2)] - dataset.prefetch_batch([0, 1, 2, 3], streams=streams) + for i in range(4): + dataset.prefetch(i, stream=streams[i % len(streams)]) # Retrieve results for i in range(4): diff --git a/test/datapipes/core/test_multi_dataset.py b/test/datapipes/core/test_multi_dataset.py index 1c25c39b9f..29a418a6ba 100644 --- a/test/datapipes/core/test_multi_dataset.py +++ b/test/datapipes/core/test_multi_dataset.py @@ -194,30 +194,6 @@ def test_cancel_prefetch_invalid_index_no_op(self, numpy_data_dir): data0, _ = multi[0] assert "positions" in data0 - def test_prefetch_batch(self, numpy_data_dir): - """prefetch_batch delegates to sub-datasets by global index.""" - ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) - - multi.prefetch_batch([0, 1, 10, 11]) - for idx in [0, 1, 10, 11]: - data, meta = multi[idx] - assert meta[DATASET_INDEX_METADATA_KEY] == (0 if idx < 10 else 1) - assert "positions" in data - - def test_prefetch_count(self, numpy_data_dir): - """prefetch_count is sum of sub-dataset prefetch 
counts.""" - ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - ds_b = dp.Dataset(dp.NumpyReader(numpy_data_dir)) - multi = dp.MultiDataset(ds_a, ds_b, output_strict=True) - - assert multi.prefetch_count == 0 - multi.prefetch_batch([0, 1, 2, 3]) - assert multi.prefetch_count >= 0 # may complete quickly - multi.cancel_prefetch() - assert multi.prefetch_count == 0 - def test_close_closes_all(self, numpy_data_dir): """close() closes all sub-datasets.""" ds_a = dp.Dataset(dp.NumpyReader(numpy_data_dir)) diff --git a/test/datapipes/readers/test_mesh_readers.py b/test/datapipes/readers/test_mesh_readers.py new file mode 100644 index 0000000000..d69cea02b4 --- /dev/null +++ b/test/datapipes/readers/test_mesh_readers.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for MeshReader, DomainMeshReader, and DomainMesh transform integration.""" + +import pytest +import torch + +from physicsnemo.datapipes.mesh_dataset import MeshDataset +from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader +from physicsnemo.datapipes.transforms.mesh import ( + CenterMesh, + RandomScaleMesh, + ScaleMesh, + apply_to_tensordict_mesh, +) +from physicsnemo.mesh import DomainMesh, Mesh +from physicsnemo.mesh.primitives.basic import ( + single_triangle_3d, + two_triangles_2d, +) + + +class TestMeshReader: + """Tests for MeshReader (single-mesh).""" + + def test_len_and_getitem(self, tmp_path): + mesh = two_triangles_2d.load() + mesh.save(tmp_path / "a.pt") + mesh.save(tmp_path / "b.pt") + reader = MeshReader(tmp_path, pattern="*.pt") + assert len(reader) == 2 + m, meta = reader[0] + assert isinstance(m, Mesh) + assert m.n_points == mesh.n_points + assert "source_path" in meta + assert "index" in meta + assert meta["index"] == 0 + + def test_negative_index(self, tmp_path): + mesh = two_triangles_2d.load() + mesh.save(tmp_path / "single.pt") + reader = MeshReader(tmp_path, pattern="*.pt") + m1, _ = reader[0] + m2, _ = reader[-1] + assert m1.n_points == m2.n_points + + def test_iter(self, tmp_path): + mesh = two_triangles_2d.load() + for i in range(3): + mesh.save(tmp_path / f"m{i}.pt") + reader = MeshReader(tmp_path, pattern="*.pt") + samples = list(reader) + assert len(samples) == 3 + for m, meta in samples: + assert isinstance(m, Mesh) + assert isinstance(meta, dict) + + +class TestDomainMeshReader: + """Tests for DomainMeshReader (DomainMesh per sample).""" + + def _make_domain_mesh(self): + """Create a simple DomainMesh for testing.""" + interior = Mesh(points=torch.randn(10, 3)) + wall = single_triangle_3d.load() + inlet = single_triangle_3d.load() + return DomainMesh( + interior=interior, + boundaries={"wall": wall, "inlet": inlet}, + global_data={"Re": torch.tensor(1e6)}, + ) + + def test_len_and_getitem(self, 
tmp_path): + dm = self._make_domain_mesh() + dm.save(tmp_path / "sample_a.pt") + dm.save(tmp_path / "sample_b.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + assert len(reader) == 2 + loaded, meta = reader[0] + assert isinstance(loaded, DomainMesh) + assert loaded.interior.n_points == dm.interior.n_points + assert "source_path" in meta + assert "index" in meta + assert meta["index"] == 0 + + def test_boundary_names_in_metadata(self, tmp_path): + dm = self._make_domain_mesh() + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + _, meta = reader[0] + assert sorted(meta["boundary_names"]) == ["inlet", "wall"] + + def test_no_boundaries(self, tmp_path): + dm = DomainMesh(interior=Mesh(points=torch.randn(5, 3))) + dm.save(tmp_path / "bare.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + loaded, meta = reader[0] + assert loaded.n_boundaries == 0 + assert meta["boundary_names"] == [] + + def test_iter(self, tmp_path): + dm = self._make_domain_mesh() + for i in range(3): + dm.save(tmp_path / f"dm{i}.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + samples = list(reader) + assert len(samples) == 3 + for loaded, meta in samples: + assert isinstance(loaded, DomainMesh) + assert isinstance(meta, dict) + + def test_global_data_preserved(self, tmp_path): + dm = self._make_domain_mesh() + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + loaded, _ = reader[0] + assert "Re" in loaded.global_data.keys() + + +class TestMeshDataset: + """Tests for MeshDataset with mesh transforms.""" + + def test_single_mesh_with_transform(self, tmp_path): + mesh = two_triangles_2d.load() + mesh.save(tmp_path / "m.pt") + reader = MeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset(reader, transforms=[ScaleMesh(2.0)]) + m, meta = ds[0] + assert isinstance(m, Mesh) + assert m.n_points == mesh.n_points + + def test_domain_mesh_with_transform(self, tmp_path): + interior = 
Mesh(points=torch.randn(10, 3)) + wall = single_triangle_3d.load() + dm = DomainMesh( + interior=interior, + boundaries={"wall": wall}, + ) + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset(reader, transforms=[ScaleMesh(0.5)]) + loaded, meta = ds[0] + assert isinstance(loaded, DomainMesh) + assert loaded.interior.n_points == interior.n_points + assert loaded.n_boundaries == 1 + + def test_domain_mesh_transform_applies_to_all(self, tmp_path): + interior = Mesh( + points=torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]), + ) + wall = Mesh( + points=torch.tensor([[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]), + ) + dm = DomainMesh(interior=interior, boundaries={"wall": wall}) + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset(reader, transforms=[ScaleMesh(2.0)]) + loaded, _ = ds[0] + assert torch.allclose( + loaded.interior.points, + torch.tensor([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0]]), + ) + assert torch.allclose( + loaded.boundaries["wall"].points, + torch.tensor([[2.0, 0.0, 0.0], [6.0, 0.0, 0.0]]), + ) + + +class TestDomainMeshTransforms: + """Tests for DomainMesh-aware transform behavior via apply_to_domain.""" + + def test_scale_transforms_domain_global_data(self, tmp_path): + """ScaleMesh with transform_global_data=True should scale domain global_data.""" + dm = DomainMesh( + interior=Mesh( + points=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), + ), + global_data={"velocity": torch.tensor([1.0, 0.0, 0.0])}, + ) + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset( + reader, + transforms=[ScaleMesh(2.0, transform_global_data=True)], + ) + loaded, _ = ds[0] + assert torch.allclose( + loaded.global_data["velocity"], + torch.tensor([2.0, 0.0, 0.0]), + ) + + def test_scale_preserves_domain_global_data_by_default(self, tmp_path): + """ScaleMesh without transform_global_data leaves domain global_data unchanged.""" + dm = 
DomainMesh( + interior=Mesh( + points=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), + ), + global_data={"velocity": torch.tensor([1.0, 0.0, 0.0])}, + ) + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset(reader, transforms=[ScaleMesh(2.0)]) + loaded, _ = ds[0] + assert torch.allclose( + loaded.global_data["velocity"], + torch.tensor([1.0, 0.0, 0.0]), + ) + + def test_random_scale_consistent_across_meshes(self, tmp_path): + """RandomScaleMesh should apply the same factor to interior and boundaries.""" + interior = Mesh( + points=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), + ) + wall = Mesh( + points=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), + ) + dm = DomainMesh(interior=interior, boundaries={"wall": wall}) + dm.save(tmp_path / "dm.pt") + + aug = RandomScaleMesh( + distribution=torch.distributions.Uniform(0.5, 2.0), + ) + aug.set_generator(torch.Generator().manual_seed(42)) + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset( + reader, + transforms=[aug], + ) + loaded, _ = ds[0] + + interior_factor = loaded.interior.points[1, 0].item() + wall_factor = loaded.boundaries["wall"].points[1, 0].item() + assert interior_factor == pytest.approx(wall_factor) + + def test_center_mesh_uses_interior_com(self, tmp_path): + """CenterMesh should center by interior COM, not per-mesh COM.""" + interior = Mesh( + points=torch.tensor( + [ + [2.0, 0.0, 0.0], + [4.0, 0.0, 0.0], + ] + ), + ) + wall = Mesh( + points=torch.tensor( + [ + [10.0, 0.0, 0.0], + [12.0, 0.0, 0.0], + ] + ), + ) + dm = DomainMesh(interior=interior, boundaries={"wall": wall}) + dm.save(tmp_path / "dm.pt") + reader = DomainMeshReader(tmp_path, pattern="*.pt") + ds = MeshDataset( + reader, + transforms=[CenterMesh(use_area_weighting=False)], + ) + loaded, _ = ds[0] + + interior_com = loaded.interior.points.mean(dim=0) + assert torch.allclose(interior_com, torch.zeros(3), atol=1e-6) + + expected_wall = torch.tensor( + [ + [10.0 - 
3.0, 0.0, 0.0], + [12.0 - 3.0, 0.0, 0.0], + ] + ) + assert torch.allclose(loaded.boundaries["wall"].points, expected_wall) + + +class TestApplyToTensorDictMesh: + """Tests for apply_to_tensordict_mesh helper (standalone utility).""" + + def test_scale_each(self): + from tensordict import TensorDict + + mesh = two_triangles_2d.load() + td = TensorDict({"x": mesh, "y": mesh.clone()}, batch_size=[]) + out = apply_to_tensordict_mesh(td, ScaleMesh(3.0)) + assert out["x"].n_points == mesh.n_points + assert "x" in out + assert "y" in out diff --git a/test/datapipes/transforms/test_mesh_augmentations.py b/test/datapipes/transforms/test_mesh_augmentations.py new file mode 100644 index 0000000000..6f270ba7d7 --- /dev/null +++ b/test/datapipes/transforms/test_mesh_augmentations.py @@ -0,0 +1,809 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for distribution-parametrized mesh augmentations.""" + +import math +import warnings + +import pytest +import torch +import torch.distributions as D + +from physicsnemo.datapipes.transforms.mesh.augmentations import ( + RandomRotateMesh, + RandomScaleMesh, + RandomTranslateMesh, + _sample_distribution, +) +from physicsnemo.mesh import DomainMesh, Mesh + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _simple_mesh_3d() -> Mesh: + """A minimal 3-D mesh (single triangle).""" + points = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=torch.float32, + ) + cells = torch.tensor([[0, 1, 2]], dtype=torch.int64) + return Mesh(points=points, cells=cells) + + +def _simple_domain_3d() -> DomainMesh: + """A minimal 3-D DomainMesh with interior + one boundary.""" + interior = Mesh( + points=torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=torch.float32, + ), + cells=torch.tensor([[0, 1, 2]], dtype=torch.int64), + ) + wall = Mesh( + points=torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=torch.float32, + ), + cells=torch.tensor([[0, 1, 2]], dtype=torch.int64), + ) + return DomainMesh(interior=interior, boundaries={"wall": wall}) + + +def _seed(aug, seed: int): + """Assign a seeded generator to an augmentation transform.""" + aug.set_generator(torch.Generator().manual_seed(seed)) + return aug + + +# --------------------------------------------------------------------------- +# _sample_distribution +# --------------------------------------------------------------------------- + + +class TestSampleDistribution: + """Tests for the _sample_distribution() helper function.""" + + def test_uniform_statistics(self): + """Samples from Uniform(2, 5) should have mean ~3.5 and lie in [2, 5].""" + gen = torch.Generator().manual_seed(0) + dist = D.Uniform(2.0, 5.0) + 
samples = torch.stack( + [_sample_distribution(dist, (1,), gen).squeeze(0) for _ in range(5000)] + ) + assert samples.min() >= 2.0 + assert samples.max() <= 5.0 + assert samples.mean().item() == pytest.approx(3.5, abs=0.1) + + def test_normal_statistics(self): + """Samples from Normal(10, 0.5) should cluster near 10.""" + gen = torch.Generator().manual_seed(0) + dist = D.Normal(10.0, 0.5) + samples = torch.stack( + [_sample_distribution(dist, (1,), gen).squeeze(0) for _ in range(5000)] + ) + assert samples.mean().item() == pytest.approx(10.0, abs=0.05) + assert samples.std().item() == pytest.approx(0.5, abs=0.05) + + def test_cauchy_median(self): + """Cauchy(3, 1) should have median ~3 (mean is undefined).""" + gen = torch.Generator().manual_seed(0) + dist = D.Cauchy(3.0, 1.0) + samples = torch.stack( + [_sample_distribution(dist, (1,), gen).squeeze(0) for _ in range(5000)] + ) + assert samples.median().item() == pytest.approx(3.0, abs=0.1) + + def test_laplace_statistics(self): + """Laplace(0, 0.5) should have mean ~0 and scale ~0.5.""" + gen = torch.Generator().manual_seed(0) + dist = D.Laplace(0.0, 0.5) + samples = torch.stack( + [_sample_distribution(dist, (1,), gen).squeeze(0) for _ in range(5000)] + ) + assert samples.mean().item() == pytest.approx(0.0, abs=0.05) + # Mean absolute deviation of Laplace(0, b) is b + assert samples.abs().mean().item() == pytest.approx(0.5, abs=0.05) + + def test_reproducibility_with_generator(self): + """Same seed should produce identical samples.""" + dist = D.Normal(0.0, 1.0) + + gen1 = torch.Generator().manual_seed(123) + s1 = _sample_distribution(dist, (10,), gen1) + + gen2 = torch.Generator().manual_seed(123) + s2 = _sample_distribution(dist, (10,), gen2) + + assert torch.allclose(s1, s2) + + def test_different_seeds_differ(self): + """Different seeds should produce different samples.""" + dist = D.Normal(0.0, 1.0) + + gen1 = torch.Generator().manual_seed(0) + s1 = _sample_distribution(dist, (10,), gen1) + + gen2 = 
torch.Generator().manual_seed(999) + s2 = _sample_distribution(dist, (10,), gen2) + + assert not torch.allclose(s1, s2) + + def test_multidimensional_shape(self): + """Should return the requested shape.""" + gen = torch.Generator().manual_seed(0) + dist = D.Normal(0.0, 1.0) + s = _sample_distribution(dist, (3, 4), gen) + assert s.shape == (3, 4) + + def test_fallback_warning_for_poisson(self): + """Poisson has no icdf; should fall back with a warning.""" + gen = torch.Generator().manual_seed(0) + dist = D.Poisson(3.0) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + s = _sample_distribution(dist, (100,), gen) + assert len(w) == 1 + assert "icdf" in str(w[0].message).lower() + assert s.shape == (100,) + + def test_fallback_device(self): + """When generator is None, samples should land on fallback_device.""" + dist = D.Normal(0.0, 1.0) + s = _sample_distribution(dist, (5,), generator=None, fallback_device="cpu") + assert s.device == torch.device("cpu") + assert s.shape == (5,) + + def test_batched_distribution(self): + """A batched distribution should produce per-element samples.""" + gen = torch.Generator().manual_seed(0) + low = torch.tensor([-1.0, -2.0, -3.0]) + high = torch.tensor([1.0, 2.0, 3.0]) + dist = D.Uniform(low, high) + samples = torch.stack( + [_sample_distribution(dist, (3,), gen) for _ in range(2000)] + ) + # Each column should stay in its respective range + assert samples[:, 0].min() >= -1.0 + assert samples[:, 0].max() <= 1.0 + assert samples[:, 1].min() >= -2.0 + assert samples[:, 1].max() <= 2.0 + assert samples[:, 2].min() >= -3.0 + assert samples[:, 2].max() <= 3.0 + + +# --------------------------------------------------------------------------- +# MeshTransform.to() distribution handling +# --------------------------------------------------------------------------- + + +class TestDistributionDeviceTransfer: + """Tests for MeshTransform.to() moving distribution parameters.""" + + def 
test_to_cpu_preserves_distribution(self): + """to('cpu') should produce a working distribution on CPU.""" + aug = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.05)), 0) + aug.to("cpu") + factor = aug._sample_factor() + assert factor.device == torch.device("cpu") + assert factor.item() == pytest.approx(1.0, abs=0.3) + + def test_to_preserves_distribution_type(self): + """to() should keep the same distribution class.""" + aug = RandomScaleMesh(distribution=D.Laplace(0.0, 1.0)) + aug.to("cpu") + assert isinstance(aug._distribution, D.Laplace) + + def test_to_moves_batched_distribution(self): + """to() should move batched distribution params.""" + dist = D.Uniform( + torch.tensor([-1.0, -2.0, -3.0]), + torch.tensor([1.0, 2.0, 3.0]), + ) + aug = RandomTranslateMesh(distribution=dist) + aug.to("cpu") + assert aug._distribution.low.device == torch.device("cpu") + assert aug._distribution.high.device == torch.device("cpu") + + +# --------------------------------------------------------------------------- +# RandomScaleMesh +# --------------------------------------------------------------------------- + + +class TestRandomScaleMesh: + """Tests for RandomScaleMesh with distribution-based sampling.""" + + def test_default_distribution(self): + """Default distribution should be Uniform(0.9, 1.1).""" + aug = RandomScaleMesh() + assert isinstance(aug._distribution, D.Uniform) + + def test_normal_distribution_clusters_near_center(self): + """Normal(1.0, 0.05) should produce scale factors near 1.0.""" + aug = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.05)), 42) + mesh = _simple_mesh_3d() + factors = [] + for _ in range(500): + scaled = aug(mesh) + # point (1,0,0) scaled by factor -> (factor, 0, 0) + factors.append(scaled.points[1, 0].item()) + factors_t = torch.tensor(factors) + assert factors_t.mean().item() == pytest.approx(1.0, abs=0.02) + assert factors_t.std().item() == pytest.approx(0.05, abs=0.02) + + def test_lognormal_always_positive(self): + """LogNormal 
should always produce positive scale factors.""" + aug = _seed(RandomScaleMesh(distribution=D.LogNormal(0.0, 0.1)), 0) + mesh = _simple_mesh_3d() + for _ in range(100): + scaled = aug(mesh) + factor = scaled.points[1, 0].item() + assert factor > 0.0 + + def test_reproducibility(self): + """Same seed should produce identical results.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.5)), 7) + r1 = aug1(mesh) + + aug2 = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.5)), 7) + r2 = aug2(mesh) + + assert torch.allclose(r1.points, r2.points) + + def test_apply_to_domain_consistent(self): + """apply_to_domain should use the same factor for interior and boundary.""" + domain = _simple_domain_3d() + aug = _seed(RandomScaleMesh(distribution=D.Uniform(0.5, 2.0)), 0) + scaled = aug.apply_to_domain(domain) + # Interior and wall started with the same points, so after scaling + # by the same factor they should still match. + assert torch.allclose(scaled.interior.points, scaled.boundaries["wall"].points) + + def test_sequence_reproducibility(self): + """Same seed should produce identical results over multiple calls.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.5)), 7) + seq1 = [aug1(mesh).points.clone() for _ in range(10)] + + aug2 = _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.5)), 7) + seq2 = [aug2(mesh).points.clone() for _ in range(10)] + + for s1, s2 in zip(seq1, seq2): + assert torch.allclose(s1, s2) + + def test_apply_to_domain_reproducibility(self): + """Same seed should produce identical domain results.""" + domain = _simple_domain_3d() + + aug1 = _seed(RandomScaleMesh(distribution=D.Uniform(0.5, 2.0)), 42) + d1 = aug1.apply_to_domain(domain) + + aug2 = _seed(RandomScaleMesh(distribution=D.Uniform(0.5, 2.0)), 42) + d2 = aug2.apply_to_domain(domain) + + assert torch.allclose(d1.interior.points, d2.interior.points) + assert torch.allclose( + d1.boundaries["wall"].points, 
d2.boundaries["wall"].points + ) + + def test_extra_repr(self): + """extra_repr should mention the distribution.""" + aug = RandomScaleMesh(distribution=D.Normal(1.0, 0.05)) + assert "Normal" in aug.extra_repr() + + +# --------------------------------------------------------------------------- +# RandomTranslateMesh +# --------------------------------------------------------------------------- + + +class TestRandomTranslateMesh: + """Tests for RandomTranslateMesh with distribution-based sampling.""" + + def test_default_distribution(self): + """Default distribution should be Uniform(-0.1, 0.1).""" + aug = RandomTranslateMesh() + assert isinstance(aug._distribution, D.Uniform) + + def test_laplace_distribution(self): + """Laplace(0, 0.02) should produce offsets concentrated near zero.""" + aug = _seed(RandomTranslateMesh(distribution=D.Laplace(0.0, 0.02)), 0) + mesh = _simple_mesh_3d() + offsets = [] + for _ in range(500): + translated = aug(mesh) + # The first point starts at (0,0,0), so its position = offset + offsets.append(translated.points[0].clone()) + offsets_t = torch.stack(offsets) + assert offsets_t.mean(dim=0).abs().max().item() < 0.01 + assert offsets_t.abs().mean().item() == pytest.approx(0.02, abs=0.005) + + def test_batched_per_axis_distribution(self): + """Batched Uniform should sample different ranges per axis.""" + dist = D.Uniform( + torch.tensor([-1.0, -2.0, -3.0]), + torch.tensor([1.0, 2.0, 3.0]), + ) + aug = _seed(RandomTranslateMesh(distribution=dist), 0) + mesh = _simple_mesh_3d() + offsets = [] + for _ in range(1000): + translated = aug(mesh) + offsets.append(translated.points[0].clone()) + offsets_t = torch.stack(offsets) + # Axis 0: range [-1, 1] + assert offsets_t[:, 0].min() >= -1.0 + assert offsets_t[:, 0].max() <= 1.0 + # Axis 1: range [-2, 2] + assert offsets_t[:, 1].min() >= -2.0 + assert offsets_t[:, 1].max() <= 2.0 + # Axis 2: range [-3, 3] + assert offsets_t[:, 2].min() >= -3.0 + assert offsets_t[:, 2].max() <= 3.0 + + def 
test_reproducibility(self): + """Same seed should produce identical translations.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomTranslateMesh(distribution=D.Normal(0.0, 1.0)), 99) + r1 = aug1(mesh) + + aug2 = _seed(RandomTranslateMesh(distribution=D.Normal(0.0, 1.0)), 99) + r2 = aug2(mesh) + + assert torch.allclose(r1.points, r2.points) + + def test_sequence_reproducibility(self): + """Same seed should produce identical results over multiple calls.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomTranslateMesh(distribution=D.Normal(0.0, 1.0)), 99) + seq1 = [aug1(mesh).points.clone() for _ in range(10)] + + aug2 = _seed(RandomTranslateMesh(distribution=D.Normal(0.0, 1.0)), 99) + seq2 = [aug2(mesh).points.clone() for _ in range(10)] + + for s1, s2 in zip(seq1, seq2): + assert torch.allclose(s1, s2) + + def test_apply_to_domain_consistent(self): + """apply_to_domain should use the same offset for all meshes.""" + domain = _simple_domain_3d() + aug = _seed(RandomTranslateMesh(distribution=D.Uniform(-1.0, 1.0)), 0) + translated = aug.apply_to_domain(domain) + assert torch.allclose( + translated.interior.points, translated.boundaries["wall"].points + ) + + def test_apply_to_domain_reproducibility(self): + """Same seed should produce identical domain results.""" + domain = _simple_domain_3d() + + aug1 = _seed(RandomTranslateMesh(distribution=D.Laplace(0.0, 0.5)), 77) + d1 = aug1.apply_to_domain(domain) + + aug2 = _seed(RandomTranslateMesh(distribution=D.Laplace(0.0, 0.5)), 77) + d2 = aug2.apply_to_domain(domain) + + assert torch.allclose(d1.interior.points, d2.interior.points) + assert torch.allclose( + d1.boundaries["wall"].points, d2.boundaries["wall"].points + ) + + +# --------------------------------------------------------------------------- +# RandomRotateMesh +# --------------------------------------------------------------------------- + + +class TestRandomRotateMesh: + """Tests for RandomRotateMesh with distribution-based sampling.""" + + def 
test_default_distribution(self): + """Default distribution should be Uniform(-pi, pi).""" + aug = RandomRotateMesh() + assert isinstance(aug._distribution, D.Uniform) + + def test_small_angle_gaussian(self): + """Normal(0, 0.1) should produce small rotations near identity.""" + aug = _seed(RandomRotateMesh(distribution=D.Normal(0.0, 0.1)), 0) + mesh = _simple_mesh_3d() + original_points = mesh.points.clone() + displacements = [] + for _ in range(200): + rotated = aug(mesh) + disp = (rotated.points - original_points).norm(dim=-1).max() + displacements.append(disp.item()) + # Small angles -> small displacements (max point is at distance 1 from origin) + avg_disp = sum(displacements) / len(displacements) + assert avg_disp < 0.2 + + def test_axis_restriction(self): + """Rotation about z-axis only should not change z coordinates.""" + aug = _seed( + RandomRotateMesh( + axes=["z"], + distribution=D.Uniform(-math.pi, math.pi), + ), + 0, + ) + mesh = _simple_mesh_3d() + for _ in range(20): + rotated = aug(mesh) + assert torch.allclose(rotated.points[:, 2], mesh.points[:, 2], atol=1e-6) + + def test_uniform_mode_ignores_distribution(self): + """mode='uniform' should work regardless of distribution parameter.""" + aug = _seed( + RandomRotateMesh(mode="uniform", distribution=D.Normal(0.0, 0.01)), + 0, + ) + mesh = _simple_mesh_3d() + rotated = aug(mesh) + # Should produce a valid rotation (points should have same norms) + orig_norms = mesh.points.norm(dim=-1) + rot_norms = rotated.points.norm(dim=-1) + assert torch.allclose(orig_norms, rot_norms, atol=1e-5) + + def test_uniform_mode_orthogonal_matrix(self): + """mode='uniform' should produce orthogonal rotation matrices (det=+1).""" + aug = _seed(RandomRotateMesh(mode="uniform"), 42) + # Sample several rotation matrices and check orthogonality + for _ in range(20): + R = aug._sample_uniform_rotation() + assert R.shape == (3, 3) + assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5) + assert torch.det(R).item() == 
pytest.approx(1.0, abs=1e-5) + + def test_reproducibility(self): + """Same seed should produce identical rotations.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomRotateMesh(distribution=D.Normal(0.0, 1.0)), 55) + r1 = aug1(mesh) + + aug2 = _seed(RandomRotateMesh(distribution=D.Normal(0.0, 1.0)), 55) + r2 = aug2(mesh) + + assert torch.allclose(r1.points, r2.points) + + def test_apply_to_domain_consistent(self): + """apply_to_domain should use the same rotation for all meshes.""" + domain = _simple_domain_3d() + aug = _seed( + RandomRotateMesh(distribution=D.Uniform(-math.pi, math.pi)), + 0, + ) + rotated = aug.apply_to_domain(domain) + assert torch.allclose( + rotated.interior.points, + rotated.boundaries["wall"].points, + atol=1e-6, + ) + + def test_invalid_mode_raises(self): + """Invalid mode should raise ValueError.""" + with pytest.raises(ValueError, match="mode must be"): + RandomRotateMesh(mode="bogus") + + def test_uniform_mode_3d_only(self): + """mode='uniform' should reject non-3D meshes.""" + aug = RandomRotateMesh(mode="uniform") + mesh_2d = Mesh( + points=torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]), + cells=torch.tensor([[0, 1, 2]]), + ) + with pytest.raises(ValueError, match="3-D meshes"): + aug(mesh_2d) + + def test_extra_repr_axis_aligned(self): + """extra_repr should mention axes and distribution.""" + aug = RandomRotateMesh(axes=["z"], distribution=D.Normal(0.0, 0.1)) + r = aug.extra_repr() + assert "z" in r + assert "Normal" in r + + def test_sequence_reproducibility(self): + """Same seed should produce identical results over multiple calls.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomRotateMesh(distribution=D.Normal(0.0, 1.0)), 55) + seq1 = [aug1(mesh).points.clone() for _ in range(10)] + + aug2 = _seed(RandomRotateMesh(distribution=D.Normal(0.0, 1.0)), 55) + seq2 = [aug2(mesh).points.clone() for _ in range(10)] + + for s1, s2 in zip(seq1, seq2): + assert torch.allclose(s1, s2) + + def 
test_uniform_mode_reproducibility(self): + """Same seed should produce identical uniform SO(3) rotations.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomRotateMesh(mode="uniform"), 12) + r1 = aug1(mesh) + + aug2 = _seed(RandomRotateMesh(mode="uniform"), 12) + r2 = aug2(mesh) + + assert torch.allclose(r1.points, r2.points) + + def test_uniform_mode_sequence_reproducibility(self): + """Same seed should produce identical uniform rotation sequences.""" + mesh = _simple_mesh_3d() + + aug1 = _seed(RandomRotateMesh(mode="uniform"), 12) + seq1 = [aug1(mesh).points.clone() for _ in range(10)] + + aug2 = _seed(RandomRotateMesh(mode="uniform"), 12) + seq2 = [aug2(mesh).points.clone() for _ in range(10)] + + for s1, s2 in zip(seq1, seq2): + assert torch.allclose(s1, s2) + + def test_apply_to_domain_reproducibility(self): + """Same seed should produce identical domain rotations.""" + domain = _simple_domain_3d() + + aug1 = _seed( + RandomRotateMesh(distribution=D.Uniform(-math.pi, math.pi)), + 0, + ) + d1 = aug1.apply_to_domain(domain) + + aug2 = _seed( + RandomRotateMesh(distribution=D.Uniform(-math.pi, math.pi)), + 0, + ) + d2 = aug2.apply_to_domain(domain) + + assert torch.allclose(d1.interior.points, d2.interior.points, atol=1e-6) + assert torch.allclose( + d1.boundaries["wall"].points, + d2.boundaries["wall"].points, + atol=1e-6, + ) + + def test_extra_repr_uniform(self): + """extra_repr for uniform mode should mention 'uniform'.""" + aug = RandomRotateMesh(mode="uniform") + assert "uniform" in aug.extra_repr() + + +# --------------------------------------------------------------------------- +# Composed pipeline reproducibility +# --------------------------------------------------------------------------- + + +class TestPipelineReproducibility: + """Tests that a composed augmentation pipeline is reproducible end-to-end.""" + + def test_scale_translate_rotate_pipeline(self): + """Chaining scale -> translate -> rotate with matched seeds is reproducible.""" + mesh = 
_simple_mesh_3d() + + def _build_pipeline(seed): + return [ + _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.1)), seed), + _seed(RandomTranslateMesh(distribution=D.Laplace(0.0, 0.05)), seed + 1), + _seed(RandomRotateMesh(distribution=D.Normal(0.0, 0.3)), seed + 2), + ] + + pipeline1 = _build_pipeline(42) + pipeline2 = _build_pipeline(42) + + # Apply each pipeline for several iterations + for _ in range(5): + m1 = mesh + for aug in pipeline1: + m1 = aug(m1) + m2 = mesh + for aug in pipeline2: + m2 = aug(m2) + assert torch.allclose(m1.points, m2.points) + + def test_pipeline_different_seeds_differ(self): + """Different seeds should produce different pipeline results.""" + mesh = _simple_mesh_3d() + + def _build_pipeline(seed): + return [ + _seed(RandomScaleMesh(distribution=D.Normal(1.0, 0.1)), seed), + _seed(RandomTranslateMesh(distribution=D.Laplace(0.0, 0.05)), seed + 1), + _seed(RandomRotateMesh(distribution=D.Normal(0.0, 0.3)), seed + 2), + ] + + pipeline1 = _build_pipeline(0) + pipeline2 = _build_pipeline(999) + + m1 = mesh + for aug in pipeline1: + m1 = aug(m1) + m2 = mesh + for aug in pipeline2: + m2 = aug(m2) + assert not torch.allclose(m1.points, m2.points) + + def test_domain_pipeline_reproducibility(self): + """Composed pipeline should be reproducible when applied to DomainMesh.""" + domain = _simple_domain_3d() + + def _build_pipeline(seed): + return [ + _seed(RandomScaleMesh(distribution=D.Uniform(0.8, 1.2)), seed), + _seed(RandomTranslateMesh(distribution=D.Normal(0.0, 0.1)), seed + 1), + _seed( + RandomRotateMesh(distribution=D.Uniform(-math.pi, math.pi)), + seed + 2, + ), + ] + + pipeline1 = _build_pipeline(7) + pipeline2 = _build_pipeline(7) + + d1 = domain + for aug in pipeline1: + d1 = aug.apply_to_domain(d1) + d2 = domain + for aug in pipeline2: + d2 = aug.apply_to_domain(d2) + + assert torch.allclose(d1.interior.points, d2.interior.points, atol=1e-6) + assert torch.allclose( + d1.boundaries["wall"].points, + d2.boundaries["wall"].points, + 
atol=1e-6, + ) + + +# --------------------------------------------------------------------------- +# DataLoader-driven pipeline reproducibility +# --------------------------------------------------------------------------- + + +class TestDataLoaderDrivenReproducibility: + """Tests that the DataLoader seed drives the full pipeline reproducibly.""" + + def test_mesh_dataset_set_generator_distributes(self, tmp_path): + """set_generator should give independent generators to reader + transforms.""" + from physicsnemo.datapipes.mesh_dataset import MeshDataset + from physicsnemo.datapipes.readers.mesh import MeshReader + + mesh = _simple_mesh_3d() + mesh.save(tmp_path / "a.pt") + mesh.save(tmp_path / "b.pt") + + reader = MeshReader(tmp_path, pattern="*.pt") + transforms = [ + RandomScaleMesh(distribution=D.Uniform(0.5, 2.0)), + RandomTranslateMesh(distribution=D.Uniform(-0.5, 0.5)), + ] + ds = MeshDataset(reader, transforms=transforms) + + master = torch.Generator().manual_seed(42) + ds.set_generator(master) + + # Reader and both transforms should have received generators + assert reader._subsample_generator is not None + assert transforms[0]._generator is not None + assert transforms[1]._generator is not None + + # Generators should have different seeds (independent forks) + seeds = { + reader._subsample_generator.initial_seed(), + transforms[0]._generator.initial_seed(), + transforms[1]._generator.initial_seed(), + } + assert len(seeds) == 3 + + def test_dataloader_seed_produces_identical_sequences(self, tmp_path): + """Two MeshDatasets seeded identically produce identical transform results.""" + from physicsnemo.datapipes.mesh_dataset import MeshDataset + from physicsnemo.datapipes.readers.mesh import MeshReader + + mesh = _simple_mesh_3d() + for i in range(4): + mesh.save(tmp_path / f"s{i}.pt") + + def _build(seed): + reader = MeshReader(tmp_path, pattern="*.pt") + transforms = [RandomScaleMesh(distribution=D.Uniform(0.5, 2.0))] + ds = MeshDataset(reader, 
transforms=transforms) + gen = torch.Generator().manual_seed(seed) + ds.set_generator(gen) + return ds + + ds1 = _build(123) + ds2 = _build(123) + + for i in range(len(ds1)): + m1, _ = ds1[i] + m2, _ = ds2[i] + assert torch.allclose(m1.points, m2.points) + + def test_different_seeds_produce_different_results(self, tmp_path): + """Two MeshDatasets with different seeds produce different results.""" + from physicsnemo.datapipes.mesh_dataset import MeshDataset + from physicsnemo.datapipes.readers.mesh import MeshReader + + mesh = _simple_mesh_3d() + mesh.save(tmp_path / "s0.pt") + + def _build(seed): + reader = MeshReader(tmp_path, pattern="*.pt") + transforms = [RandomScaleMesh(distribution=D.Uniform(0.5, 2.0))] + ds = MeshDataset(reader, transforms=transforms) + gen = torch.Generator().manual_seed(seed) + ds.set_generator(gen) + return ds + + ds1 = _build(0) + ds2 = _build(999) + + m1, _ = ds1[0] + m2, _ = ds2[0] + assert not torch.allclose(m1.points, m2.points) + + def test_set_epoch_changes_randomness(self, tmp_path): + """set_epoch reseeds transforms so different epochs differ.""" + from physicsnemo.datapipes.mesh_dataset import MeshDataset + from physicsnemo.datapipes.readers.mesh import MeshReader + + mesh = _simple_mesh_3d() + mesh.save(tmp_path / "s0.pt") + + reader = MeshReader(tmp_path, pattern="*.pt") + transforms = [RandomScaleMesh(distribution=D.Uniform(0.5, 2.0))] + ds = MeshDataset(reader, transforms=transforms) + gen = torch.Generator().manual_seed(42) + ds.set_generator(gen) + + ds.set_epoch(0) + m0, _ = ds[0] + + ds.set_epoch(1) + m1, _ = ds[0] + + # Different epochs should produce different scale factors + assert not torch.allclose(m0.points, m1.points) + + def test_set_epoch_is_deterministic(self, tmp_path): + """Resetting to the same epoch reproduces the same result.""" + from physicsnemo.datapipes.mesh_dataset import MeshDataset + from physicsnemo.datapipes.readers.mesh import MeshReader + + mesh = _simple_mesh_3d() + mesh.save(tmp_path / 
"s0.pt") + + def _run_epoch(seed, epoch): + reader = MeshReader(tmp_path, pattern="*.pt") + transforms = [RandomScaleMesh(distribution=D.Uniform(0.5, 2.0))] + ds = MeshDataset(reader, transforms=transforms) + gen = torch.Generator().manual_seed(seed) + ds.set_generator(gen) + ds.set_epoch(epoch) + m, _ = ds[0] + return m.points.clone() + + pts_a = _run_epoch(42, 5) + pts_b = _run_epoch(42, 5) + assert torch.allclose(pts_a, pts_b)