diff --git a/CHANGELOG.md b/CHANGELOG.md index 090fb6a4a1..578f86af45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Examples: Added Neural Operator Factory for reservoir simulation in + `examples/reservoir_simulation/neural_operator_factory/`. Config-driven + framework supporting FNO, U-FNO, Conv-FNO, FNO4D, DeepONet (7 variants), + and TNO architectures on 2D and 3D spatial datasets. Includes physics-informed + losses (derivative regularization, mass conservation), three-stage + autoregressive training (teacher forcing, pushforward, rollout), per-sample + domain masking, multi-GPU DDP, and reproducible examples for U-FNO, U-DeepONet, + Fourier-MIONet, and TNO papers on the CO2 sequestration and Norne field + datasets. - Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`) - Adds GLOBE AirFRANS example case (`examples/cfd/external_aerodynamics/globe/airfrans`) - Adds automatic support for `FSDP` and/or `ShardTensor` models in checkpoint save/load diff --git a/examples/reservoir_simulation/neural_operator_factory/.gitignore b/examples/reservoir_simulation/neural_operator_factory/.gitignore new file mode 100644 index 0000000000..7e5c216b3a --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/.gitignore @@ -0,0 +1,30 @@ +# Temporary directories +logs/ +outputs/ +mlruns/ +visualizations/ +__pycache__/ +checkpoints/ + +# NFS lock files +.nfs* + +# Cluster-specific files +*.sbatch + +# Temporary scripts +check_gpu_memory.sh +launch.log + +# Python cache +*.pyc +*.pyo +*.pyd +.Python + +# Jupyter +.ipynb_checkpoints/ + +# IDE +.vscode/ +.idea/ diff --git a/examples/reservoir_simulation/neural_operator_factory/README.md b/examples/reservoir_simulation/neural_operator_factory/README.md new file mode 100644 index 0000000000..9717c55e18 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/README.md @@ -0,0 +1,173 @@ +# Neural Operator Factory + +A config-driven framework for training neural operator surrogates +for reservoir simulation, built on +[PhysicsNeMo](https://github.com/NVIDIA/physicsnemo). +Switch between 165 model architectures, 6 training regimes, and +physics-informed loss combinations — all from YAML, zero code changes. + +## Why NOF + +Pick a model, a training strategy, and a loss function. +The framework handles everything else: multi-GPU distribution, +autoregressive rollout, inactive-cell masking, checkpointing, +and metric tracking. + +**165 architectures** from composable building blocks: + +| Family | How it works | Count | +|--------|-------------|-------| +| **xFNO** | Fourier base +/- UNet +/- Conv | 5 | +| **DeepONet** | 1 branch (8 types) + trunk | 16 | +| **MIONet** | 2 branches + trunk | 16 | +| **TNO** | 2 independent branches + trunk | 128 | + +Spatial branches are assembled from **three independent layer types** +(Fourier, UNet, Conv) that can be freely combined — a Fourier-UNet +branch, a Conv-only branch, or a triple Fourier-UNet-Conv hybrid +are all one config change apart. + +**6 training regimes**, all model-agnostic: + +| Regime | Description | +|--------|-------------| +| Full-mapping | Entire trajectory in one forward pass | +| AR: teacher forcing | Ground-truth input at each window | +| AR: TF + rollout (detached) | Free-running with per-window gradients | +| AR: TF + rollout (live) | Full-trajectory backprop through the chain | +| AR: TF + pushforward + rollout | Curriculum unroll bridging TF and rollout | + +**Physics-informed losses** compose on top: + +- **Data**: MSE, L1, relative L2, Huber (combinable with weights) +- **Derivative regularization**: central differences on any + spatial axis, using actual cell widths from the grid +- **Mass conservation**: volume-weighted spatial integration + constraint at each timestep + +Together this gives **~46,500** unique configurations from YAML alone. + +## What You Get for Free + +Every configuration — from a basic FNO to a triple-hybrid TNO — +automatically inherits: + +- **Multi-GPU DDP** with proper gradient sync, distributed + sampling, and rank-0-only I/O +- **DDP-safe autoregressive rollout** (`no_sync` + manual + AllReduce for multi-forward AR steps) +- **Automatic mask detection** (ACTNUM, non-zero fallback, + per-sample) propagated through all losses and metrics +- **Dimension-agnostic pipeline** — the same code handles 2D and + 3D spatial data; losses, AR utilities, and masking adapt + automatically +- **Lazy module initialization** — dummy forward pass materializes + all layers before DDP wrapping, with correct branch2 handling + for TNO and MIONet +- **Self-describing checkpoints** — architecture config saved + alongside weights for one-line model reconstruction +- **Optimizer and scheduler resume** — seamless training + continuation from any checkpoint +- **BatchNorm freeze** during live-gradient rollout to prevent + autograd graph invalidation +- **Mixed precision** (AMP) for any model via a single flag (works with selected models) +- **Configurable validation metrics** (RMSE, MAE, MRE, MPE, + relative L2) with automatic per-sample masking + +## Quick Start + +```bash +# Train a TNO on Norne (8-GPU DDP via SLURM) +sbatch examples/pi_norne/train.sbatch pressure_training_config + +# Train a U-FNO on CO2 saturation +sbatch examples/ufno_co2/train.sbatch U-FNO saturation_training_config + +# Evaluate +sbatch examples/tno_co2/eval.sbatch saturation +``` + +All commands run from the `neural_operator_factory/` directory. + +## Reproduced Papers + +Each example ships with configs that reproduce published results: + +| Example | Paper | Architecture | Dataset | +|---------|-------|-------------|---------| +| [pi_norne](examples/pi_norne/) | — | Fourier-DeepONet + feedback | Norne field (3D) | +| [ufno_co2](examples/ufno_co2/) | Wen et al. 2022 | FNO, Conv-FNO, U-FNO | CO2 sequestration | +| [udeeponet_co2](examples/udeeponet_co2/) | Diab & Al Kobaisi 2024 | U-DeepONet | CO2 sequestration | +| [fourier_mionet_co2](examples/fourier_mionet_co2/) | Jiang et al. 2024 | MIONet, Fourier-MIONet | CO2 sequestration | +| [tno_co2](examples/tno_co2/) | Diab & Al Kobaisi 2025 | TNO | CO2 sequestration | + +## Dataset Format + +Input and output tensors in `.pt` format: + +- **2D spatial**: input `(N, H, W, T, C)`, output `(N, H, W, T)` +- **3D spatial**: input `(N, X, Y, Z, T, C)`, output `(N, X, Y, Z, T)` + +The last input channels must follow the NOF grid convention: + +| 2D (last 3 channels) | 3D (last 4 channels) | Used by | +|-----------------------|----------------------|---------| +| grid\_x (W widths) | grid\_x (X widths) | Derivative loss | +| grid\_y (H widths) | grid\_y (Y widths) | Derivative loss | +| — | grid\_z (Z widths) | Derivative loss | +| grid\_t (time) | grid\_t (time) | DeepONet trunk | + +Inactive cells must be zero in all channels. The framework +auto-detects binary ACTNUM masks, falls back to non-zero +pattern detection, and supports per-sample masks when reservoir +geometry varies across realizations. + +The CO2 sequestration dataset used by the included examples is +publicly available at: + + +## Project Structure + +```text +neural_operator_factory/ +├── models/ xfno.py, xdeeponet.py, unet.py, physicsnemo_unet.py +├── data/ dataloader.py, validation.py, scalar_utils.py +├── training/ train.py, losses.py, physics_losses.py, ar_utils.py, metrics.py +├── utils/ checkpoint.py, padding.py, co2_normalization.py +├── conf/ model_config.yaml, training_config.yaml +├── examples/ pi_norne/, ufno_co2/, udeeponet_co2/, fourier_mionet_co2/, tno_co2/ +└── tests/ 375 unit tests +``` + +## References + +1. Wen, G. et al. (2022). "U-FNO — An enhanced Fourier neural + operator-based deep-learning model for multiphase flow." + *Advances in Water Resources*, 163, 104180. +2. Diab, W. & Al Kobaisi, M. (2024). "U-DeepONet: U-Net + enhanced deep operator network for geologic carbon + sequestration." *Scientific Reports*, 14, 21298. +3. Jiang, Z. et al. (2024). "Fourier-MIONet: Fourier-enhanced + multiple-input neural operators for multiphase modeling of + geological carbon sequestration." + *Reliability Eng. & System Safety*, 251, 110392. +4. Diab, W. & Al Kobaisi, M. (2025). "Temporal neural operator + for modeling time-dependent physical phenomena." + *Scientific Reports*, 15. +5. Li, Z. et al. (2021). "Fourier Neural Operator for Parametric + Partial Differential Equations." *ICLR 2021*. +6. Lu, L. et al. (2021). "Learning nonlinear operators via + DeepONet." *Nature Machine Intelligence*, 3, 218-229. +7. Jin, P., Meng, S. & Lu, L. (2022). "MIONet: Learning + multiple-input operators via tensor product." + *SIAM J. Scientific Computing*, 44(6), A3490-A3514. +8. Zhu, M. et al. (2023). "Fourier-DeepONet: Fourier-enhanced + deep operator networks for full waveform inversion." + *arXiv:2305.17289*. +9. Chandra, A. et al. (2025). "Neural operators for + accelerating scientific simulations and design." + *arXiv:2503.11031*. + +## License + +Apache License 2.0. See [LICENSE](../../../LICENSE.txt). diff --git a/examples/reservoir_simulation/neural_operator_factory/conf/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/conf/model_config.yaml new file mode 100644 index 0000000000..c93fc8f971 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/conf/model_config.yaml @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Model Architecture Configuration +# ============================================================================= +# This config is used with training_config.yaml (Norne 4D TNO by default). +# For examples with self-contained configs, see examples/ufno_co2/. +# +# The loss section requires input channels to follow the NOF convention: +# 3d: last 3 channels = [grid_x, grid_y, grid_t] +# 4d: last 4 channels = [grid_x, grid_y, grid_z, grid_t] +# See README.md for full dataset requirements. + +arch: + # --------------------------------------------------------------------------- + # Step 1: Problem Dimensions + # --------------------------------------------------------------------------- + # '3d' = 2D spatial + time, input shape (B, H, W, T, C) - e.g., CO2 dataset + # '4d' = 3D spatial + time, input shape (B, X, Y, Z, T, C) - e.g., Norne dataset + # + # IMPORTANT: This must match your dataset! The data loader will validate + # that the loaded data matches this dimension setting. + # Configure dataset files in training_config.yaml -> data section. + dimensions: 4d + + # --------------------------------------------------------------------------- + # Step 2: Model Architecture + # --------------------------------------------------------------------------- + # 'xfno' = Fourier Neural Operator variants + # 'xdeeponet' = DeepONet variants + model: xdeeponet + + + + # =========================================================================== + # xFNO Configuration + # =========================================================================== + # Active when: model = xfno + # + # Resulting model class: + # - 3d: UFNO/UFNONet (SpectralConv3d + optional UNet/Conv) + # - 4d: FNO4D/FNO4DNet (SpectralConv4d, pure FNO only) + + xfno: + out_channels: 1 + width: 36 + padding: 8 + + # ── Encoder (lifting) ────────────────────────────────────────────── + # Projects raw input channels to the latent width. + # type: mlp — Linear or FullyConnected (channels-last, then permute) + # type: conv — Conv3dFCLayer pointwise convolution (channels-first) + # 4D only supports mlp lifting (ConvNdFCLayer used internally). + lifting_type: mlp # 'mlp' or 'conv' (3d only) + lifting_layers: 1 # Number of layers in lifting network + lifting_width: 36 # Hidden width factor for multi-layer lifting + + # ── Processing Layers ────────────────────────────────────────────── + # Fourier layers: SpectralConv + 1x1 Conv (always present) + # UNet layers: Hybrid Fourier + UNet skip (3d only, >0 = U-FNO) + # Conv layers: Hybrid Fourier + Conv skip (3d only, >0 = Conv-FNO) + num_fno_layers: 4 + num_unet_layers: 0 # 3d only — >0 enables U-FNO + num_conv_layers: 0 # 3d only — >0 enables Conv-FNO + activation_fn: relu # Activation for all processing layers + + # Fourier modes per dimension + modes1: 10 # H or X dimension + modes2: 10 # W or Y dimension + modes3: 10 # T dimension (3d) or Z dimension (4d) + modes4: 6 # T dimension (4d only, ignored for 3d) + + # UNet / Conv settings (3d only) + unet_type: physicsnemo # 'custom' or 'physicsnemo' + unet_kernel_size: 3 + unet_dropout: 0.0 + conv_kernel_size: 3 + + # ── Decoder ──────────────────────────────────────────────────────── + # Projects latent width back to output channels. + # type: mlp — FullyConnected (channels-last) + # type: conv — Conv3dFCLayer pointwise convolution (channels-first) + decoder_type: mlp # 'mlp' or 'conv' (3d only) + decoder_layers: 1 # Number of hidden layers in decoder + decoder_width: 128 # Hidden layer width in decoder + decoder_activation_fn: relu # Activation for decoder hidden layers (last layer is linear) + + # ── 4d only options ──────────────────────────────────────────────── + coord_features: true # Add (x, y, z, t) coordinate features + + + + + # =========================================================================== + # xDeepONet Configuration + # =========================================================================== + # Active when: model = xdeeponet + # + # Resulting model class: + # - 3d: DeepONet/DeepONetWrapper (2D spatial branch) + # - 4d: DeepONet3D/DeepONet3DWrapper (3D spatial branch) + # + # Variants: + # - deeponet: Basic DeepONet (MLP branch only) + # - u_deeponet: UNet-enhanced spatial branch + # - fourier_deeponet: Fourier layers in spatial branch + # - conv_deeponet: Conv layers in spatial branch + # - hybrid_deeponet: Fourier + UNet + Conv combination + # - mionet: Two-branch architecture + # - fourier_mionet: Two-branch with Fourier layers + # - tno: Temporal Neural Operator (branch2 = previous solution, AR only) + xdeeponet: + variant: tno + width: 128 + padding: 8 + + # Branch 1: Primary spatial encoder + # + # encoder: How to lift raw input channels to the latent width. + # type: linear — single LazyLinear(in_channels → width). Default for spatial processing. + # type: mlp — multi-layer MLP(in_channels → hidden → width). Use alone for pure MLP branch + # (set all layers counts to 0), or as a learned encoder before Fourier/UNet layers. + # type: conv — multi-layer pointwise conv encoder (Linear layers with activations). + # Equivalent to 1x1 convolutions at each spatial point. Use for deeper + # lifting before Fourier/UNet layers. + # + # layers: Processing layers applied after the encoder. + # Fourier, UNet, and Conv layers can be combined. When num_fourier_layers > 0, UNet/Conv layers + # become hybrid: each gets a parallel Fourier component added to its output. + # Set all to 0 for a pure MLP branch (encoder.type must be 'mlp'). + # + branch1: + encoder: + type: linear # 'linear' or 'mlp' + hidden_width: 64 # mlp-only: hidden layer width + num_layers: 2 # mlp-only: number of layers + activation_fn: tanh # mlp-only: activation function + layers: + num_fourier_layers: 1 # Number of pure Fourier (spectral conv + 1x1 conv) layers + num_unet_layers: 1 # Number of UNet layers (hybrid with Fourier when fourier > 0) + num_conv_layers: 0 # Number of Conv layers (hybrid with Fourier when fourier > 0) + modes1: 10 # Fourier modes in H/X direction + modes2: 10 # Fourier modes in W/Y direction + modes3: 10 # Fourier modes in Z direction (4D only) + kernel_size: 3 # Kernel size for UNet/Conv layers + dropout: 0.0 # Dropout for UNet (custom impl only) + unet_impl: physicsnemo # 'custom' or 'physicsnemo' + activation_fn: tanh # Activation for processing layers + # Adaptive pool to fixed resolution before processing, upsample back after. + # Decouples model complexity from grid size. null = native resolution. + # Set to [H, W] (3d) or [X, Y, Z] (4d). + internal_resolution: null + + # Branch 2: Secondary encoder (for mionet/fourier_mionet/tno) + # Same structure as branch1. For TNO, receives the previous solution state. + branch2: + encoder: + type: linear + hidden_width: 128 + num_layers: 1 + activation_fn: tanh + layers: + num_fourier_layers: 1 + num_unet_layers: 1 + num_conv_layers: 0 + modes1: 10 + modes2: 10 + modes3: 10 + kernel_size: 3 + dropout: 0.0 + unet_impl: physicsnemo + activation_fn: tanh + + # Trunk: Temporal/coordinate encoder + trunk: + input_type: time # 'time' (t only) or 'grid' (x,y,t for 3d / x,y,z,t for 4d) + hidden_width: 128 + num_layers: 8 + activation_fn: tanh # activation for hidden layers (sin, tanh, relu, gelu, etc.) + # output_activation: whether to apply activation_fn to the trunk's output layer. + # true (default) = activated output; used by DeepONet, U-DeepONet, MIONet. + # false = linear output; used by TNO — avoids squashing the Hadamard + # product's dynamic range, allowing the trunk to produce + # arbitrary magnitudes. + output_activation: true + + # Decoder: Output projection + # 'mlp': per-timestep trunk query → MLP → 1. Standard DeepONet decoding. + # The trunk is queried at each target timestep separately. + # 'conv': per-timestep trunk query → conv decoder → 1. + # 'temporal_projection': single trunk query → Hadamard product → MLP → + # linear head (width → K timesteps). The trunk is queried once and a + # learned linear projection produces all K output timesteps directly. + # Faster inference for autoregressive bundling. Requires + # set_output_window(K) at runtime (handled automatically by train.py). + decoder_type: mlp + decoder_width: 128 + decoder_layers: 2 + decoder_activation_fn: relu + + +# ============================================================================= +# Loss Function Configuration +# ============================================================================= +# When masking is enabled (training_config.yaml -> mask_enabled: true), +# only active cells contribute to the loss. For relative_l2, active +# cells are selected per-sample so norms are not diluted by zeros. +# Derivative and physics losses use a union mask across the batch. +loss: + # Data losses: mse, l1, relative_l2, huber, usage exmaple: [relative_l2, l1] + # Multiple types can be combined with weights: [w1, w2, ...] + types: [relative_l2] + weights: [1.0] + + # Spatial derivative regularization + # Usage: dims: [dx] or [dx, dy] (2D) or [dx, dy, dz] (3D) + derivative: + enabled: true + weight: 0.5 + dims: [dx, dy, dz] + metric: null # null = inherit first data loss type + + # Physics-informed losses + physics: + mass_conservation: + enabled: true + weight: 0.5 + use_cell_volumes: true + metric: null # null = inherit first data loss; or mse, l1, relative_l2, huber diff --git a/examples/reservoir_simulation/neural_operator_factory/conf/training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/conf/training_config.yaml new file mode 100644 index 0000000000..601664f525 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/conf/training_config.yaml @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Neural Operator Factory — Training Configuration +# +# This is the base config for the Norne 4D TNO experiment. +# For CO2 examples, use the self-contained configs under examples/ufno_co2/. +# +# Input channel ordering convention (last channels): +# 3d: [..., grid_x, grid_y, grid_t] channels [-3, -2, -1] +# 4d: [..., grid_x, grid_y, grid_z, grid_t] channels [-4, -3, -2, -1] +# Inactive cells must be zero in all input channels and the output. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + # --------------------------------------------------------------------------- + # Dataset Configuration + # --------------------------------------------------------------------------- + # The dataset files should match the dimensions specified in arch.dimensions + # - 3d: expects (N, H, W, T, C) input, (N, H, W, T) output + # - 4d: expects (N, X, Y, Z, T, C) input, (N, X, Y, Z, T) output + + data_path: /data/norne + + # Variable name — used for checkpoint naming and metric selection, not file resolution. + # Choices: 'pressure', 'saturation' + variable: saturation + + # Input/output file patterns ({mode} is replaced with train/val/test). + # Available Norne dataset files: + # Inputs: norne_{mode}_a.pt (or symlink norne_{mode}_input.pt) + # Outputs: norne_{mode}_pressure.pt (pressure, single channel) + # norne_{mode}_swat.pt (water saturation, single channel) + # norne_{mode}_sgas.pt (gas saturation, single channel) + # norne_{mode}_u.pt (all 3 channels — requires multi-output support) + # For CO2 dataset: set input_file/output_file to null and use variable instead + # (resolves to dP_{mode}_a.pt / dP_{mode}_u.pt or sg_* automatically) + input_file: norne_{mode}_a.pt + output_file: norne_{mode}_swat.pt + + # z-score normalize inputs and outputs at load time. + # Set false if the .pt files are already pre-normalized. + # The dataloader computes mean/std from the training set and shares with val/test. + normalize: false + + # Number of DataLoader worker processes per GPU for parallel data loading. + # Higher values speed up I/O-bound training but use more CPU memory. + # Recommended: 2-4 per GPU. Set to 0 for debugging (loads in main process). + num_workers: 2 + + # --------------------------------------------------------------------------- + # Validation Metric Configuration (dataset-agnostic) + # --------------------------------------------------------------------------- + # val_metric: Which metric to compute during validation. + # Choices: 'rmse', 'mae', 'mre', 'mpe', 'relative_l2' + # Default: 'rmse' (works for any dataset without assumptions). + # CO2 examples: 'mre' for pressure, 'mpe' for saturation. + val_metric: rmse + + # mask_enabled: Exclude inactive cells from loss and metrics. + # Auto-detects the best mask channel from inputs (priority order): + # 1) ACTNUM channel (binary {0,1}, static across time) + # 2) Non-zero channel fallback (static zero pattern across time) + # 3) No mask (all cells active) + # If the mask varies across samples, it is constructed per-batch. + mask_enabled: true + + # mask_channel: Explicit input channel index for the active-cell mask. + # null = auto-detect (recommended). Set to an integer to override. + # Example: 5 for Norne ACTNUM, 0 for CO2 permeability. + mask_channel: null + + # num_timesteps: Truncate the time axis to the first N timesteps for + # train and val splits. The test split always keeps ALL timesteps, + # enabling temporal extrapolation evaluation. + # null (default) = use all timesteps in the .pt files. + # Example: 16 for CO2 (train on 1.8 years, test extrapolation to 30 years). + num_timesteps: 50 + + +training: + batch_size: 2 + epochs: 100 # used by full_mapping only; ignored if regime is autoregressive below + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + # Path to a checkpoint to resume training from. null = start fresh. + resume_from_checkpoint: null + + # --------------------------------------------------------------------------- + # Training Regime + # --------------------------------------------------------------------------- + # full_mapping: Predict entire trajectory in one forward pass (default). + # autoregressive: Predict K timesteps from L context timesteps. + # Three-stage training: teacher forcing -> pushforward -> rollout + # Works with any model type and any spatial dimensionality: 3d (CO2) and 4d (Norne) + regime: autoregressive + + autoregressive: + # L: number of context timesteps fed to the model as input. + input_window: 1 + + # K: number of timesteps the model predicts per forward pass. + # DeepONet/TNO (mlp decoder): K trunk queries per forward pass. + # DeepONet/TNO (temporal_projection decoder): single trunk query, + # linear head projects to K outputs (faster, set_output_window(K)). + # xFNO: requires L + K >= 2 * time_modes for proper spectral resolution. + output_window: 3 + + # ----------------------------------------------------------------------- + # Stage 1 -- Teacher Forcing + # ----------------------------------------------------------------------- + # Model always sees ground-truth input. Teaches the underlying physics + # from clean signals. Loss is computed per-window with gradient + # accumulation (memory efficient). + teacher_forcing_epochs: 5 + + # ----------------------------------------------------------------------- + # Stage 2 -- Pushforward (optional, set to 0 to skip) + # ----------------------------------------------------------------------- + # Short unrolled chains with live gradients (uses live_rollout_step). + # Gradually increases chain length via curriculum (1 → max_unroll) to + # bridge the gap between teacher forcing and free-running rollout. + pushforward_epochs: 0 + + # Maximum number of pushforward unroll steps. + # Curriculum linearly ramps from 1 to this value over pushforward_epochs. + max_unroll: 5 + + # ----------------------------------------------------------------------- + # Stage 3 -- Rollout + # ----------------------------------------------------------------------- + # Free-running autoregressive sweep where the model receives its own + # predictions as input, training it to self-correct approximation errors. + rollout_epochs: 175 + + # rollout_mode: Controls how gradients are handled during rollout. + # + # 'detached' (default): + # Per-window gradient accumulation. Predictions are detached between + # windows, so each window trains independently. Memory efficient + # (one compute graph at a time). Good default for most models. + # + # 'live_gradients': + # Full-trajectory rollout with live gradients through the entire chain. + # Predictions are collected, then a single loss is computed on the + # concatenated trajectory. Gradients flow from the final loss back + # through all intermediate predictions, providing stronger gradient + # signal for temporal coherence. More memory intensive. Matches the + # original TNO paper training procedure. + rollout_mode: detached + + # Noise standard deviation injected into branch2 / feedback during training. + # Positive values add Gaussian noise for robustness; 0 disables. + # Set to 0.01 to 0.05 during rollout training to regularize against error accumulation. + noise_std: 0.0 + + # Feedback channel: append previous prediction as an extra input channel. + # With TNO, the feedback goes through branch2 (t-branch) automatically. + # Set to true to enable for other models in AR training (appends to branch1). + use_feedback_channel: false + + # Factor to reset learning rate at each stage transition. + # 1.0 = no reset; < 1.0 = scale LR up relative to current schedule. + lr_reset_factor: 1.0 + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: nof_training + +seed: 42 + +compute: + # cudnn.benchmark: auto-tune convolution algorithms for the input + # size. Faster training but non-deterministic. + benchmark: true + + # cudnn.deterministic: force deterministic algorithms for exact + # reproducibility. Slower than benchmark. When true, benchmark + # is overridden to false. + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/data/__init__.py b/examples/reservoir_simulation/neural_operator_factory/data/__init__.py new file mode 100644 index 0000000000..75ccc5d617 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/__init__.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data loading, validation, and preprocessing utilities.""" + +from data.dataloader import ( + ReservoirDataset, + collate_fn, + create_dataloaders, + get_dataset_info, +) +from data.file_resolution import resolve_data_files +from data.gpu_prefetch import GPUPrefetcher +from data.mask_detection import MaskResult, detect_mask +from data.normalization import NormStats, compute_norm_stats, normalize_sample +from data.scalar_utils import ( + create_mionet_collate_fn, + detect_scalar_channels, + log_scalar_detection_results, + verify_scalar_consistency, +) +from data.validation import ( + detect_dimensions, + get_dimension_info, + print_validation_summary, + validate_batch_dimensions, + validate_sample_dimensions, +) + +__all__ = [ + # Dataloader + "ReservoirDataset", + "collate_fn", + "create_dataloaders", + "get_dataset_info", + # File resolution + "resolve_data_files", + # Mask detection + "detect_mask", + "MaskResult", + # Normalization + "compute_norm_stats", + "normalize_sample", + "NormStats", + # GPU prefetch + "GPUPrefetcher", + # Validation + "detect_dimensions", + "validate_batch_dimensions", + "validate_sample_dimensions", + "print_validation_summary", + "get_dimension_info", + # Scalar utils + "detect_scalar_channels", + "verify_scalar_consistency", + "create_mionet_collate_fn", + "log_scalar_detection_results", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/data/dataloader.py b/examples/reservoir_simulation/neural_operator_factory/data/dataloader.py new file mode 100644 index 0000000000..16dad22c90 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/dataloader.py @@ -0,0 +1,543 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified dataset loaders for reservoir simulation neural operators. + +Supports both 3D (2D spatial + time) and 4D (3D spatial + time) datasets: +- 3D: Input (N, H, W, T, C), Output (N, H, W, T) - e.g., CO2 sequestration +- 4D: Input (N, X, Y, Z, T, C), Output (N, X, Y, Z, T) - e.g., Norne field +""" + +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +import torch +from torch.utils.data import Dataset + +from data.file_resolution import resolve_data_files +from data.mask_detection import MaskResult, detect_mask +from data.normalization import ( + NormStats, + compute_norm_stats, + identity_norm_stats, + normalize_sample, +) + + +def _log_message(msg: str, rank_zero_only: bool = True): + """Print message, optionally only on rank 0 in distributed mode.""" + try: + from physicsnemo.distributed import DistributedManager + + dist = DistributedManager() + if not rank_zero_only or dist.rank == 0: + print(msg) + except Exception: + print(msg) + + +def _load_tensor(path: Path) -> torch.Tensor: + """Load a ``.pt`` tensor into CPU memory. + + Uses a standard bulk read which is optimal when the full tensor + will be scanned (e.g. for normalization statistics). + """ + return torch.load(path, map_location="cpu") + + +class ReservoirDataset(Dataset): + """Unified dataset for reservoir simulation modeling. + + Automatically detects and handles both 3D and 4D data: + - 3D: (N, H, W, T, C) input, (N, H, W, T) output + - 4D: (N, X, Y, Z, T, C) input, (N, X, Y, Z, T) output + + Parameters + ---------- + data_path : Union[str, Path] + Path to the data directory. + mode : str + Dataset split: ``'train'``, ``'val'``, or ``'test'``. + input_file : str, optional + Input filename pattern (supports ``{mode}`` placeholder). + output_file : str, optional + Output filename pattern (supports ``{mode}`` placeholder). + variable : str, optional + ``'pressure'`` or ``'saturation'`` for CO2 naming convention. + normalize : bool + Z-score normalize using training-set statistics (default ``True``). + expected_dimensions : str, optional + ``'3d'`` or ``'4d'``. Raises on mismatch with loaded data. + use_mask : bool + Enable inactive-cell mask detection (default ``False``). + mask_channel : int, optional + Explicit mask channel index (overrides auto-detection). + num_timesteps : int, optional + Truncate the time axis to the first *N* steps (train/val only). + """ + + def __init__( + self, + data_path: Union[str, Path], + mode: str = "train", + input_file: Optional[str] = None, + output_file: Optional[str] = None, + variable: Optional[str] = None, + normalize: bool = True, + expected_dimensions: Optional[str] = None, + use_mask: bool = False, + mask_channel: Optional[int] = None, + num_timesteps: Optional[int] = None, + ): + super().__init__() + + self.data_path = Path(data_path) + self.mode = mode.lower() + self.normalize = normalize + self.variable = variable + self.expected_dimensions = ( + expected_dimensions.lower() if expected_dimensions else None + ) + self.use_mask = use_mask + self._config_mask_channel = mask_channel + self._num_timesteps = num_timesteps + + if self.mode not in ("train", "val", "test"): + raise ValueError(f"Mode must be 'train', 'val', or 'test', got {mode}") + + # --- File resolution (delegated) --- + self.input_file, self.output_file = resolve_data_files( + self.data_path, self.mode, input_file, output_file, variable + ) + + # --- Load data --- + self._load_data() + + if self._num_timesteps is not None: + T = self._num_timesteps + self.input_data = self.input_data[..., :T, :] + self.output_data = self.output_data[..., :T] + _log_message(f" Truncated to {T} timesteps") + + # --- Dimension detection --- + self._detect_dimensions() + + # --- Mask detection (delegated) --- + self.mask_channel: Optional[int] = None + self.mask_per_sample: bool = False + self.static_mask: Optional[torch.Tensor] = None + if self.use_mask: + self._apply_mask_detection() + + # --- Normalization (delegated) --- + self._norm_stats: Optional[NormStats] = None + if self.normalize: + self._init_normalization() + + # ------------------------------------------------------------------ + # Data loading + # ------------------------------------------------------------------ + + def _load_data(self): + """Load input/output tensors from disk.""" + if not self.input_file.exists(): + raise FileNotFoundError(f"Input file not found: {self.input_file}") + if not self.output_file.exists(): + raise FileNotFoundError(f"Output file not found: {self.output_file}") + + _log_message( + f"Loading {self.mode} data: " + f"{self.input_file.name} -> {self.output_file.name}" + ) + + self.input_data = _load_tensor(self.input_file) + self.output_data = _load_tensor(self.output_file) + + _log_message( + f" Loaded {len(self.input_data)} samples | " + f"Input: {tuple(self.input_data.shape)} | " + f"Output: {tuple(self.output_data.shape)}" + ) + + # ------------------------------------------------------------------ + # Dimension detection + # ------------------------------------------------------------------ + + def _detect_dimensions(self): + """Detect 3D vs 4D from tensor shapes and validate against config.""" + input_ndim = self.input_data.dim() + output_ndim = self.output_data.dim() + + if input_ndim == 5 and output_ndim == 4: + self.dimensions = "3d" + self.spatial_dims = 2 + self.dim_names = ("H", "W", "T") + elif input_ndim == 6 and output_ndim == 5: + self.dimensions = "4d" + self.spatial_dims = 3 + self.dim_names = ("X", "Y", "Z", "T") + else: + raise ValueError( + f"Unsupported data dimensions!\n" + f" Input: {input_ndim}D {tuple(self.input_data.shape)}\n" + f" Output: {output_ndim}D {tuple(self.output_data.shape)}\n" + f"Expected:\n" + f" 3D: Input (N, H, W, T, C), Output (N, H, W, T)\n" + f" 4D: Input (N, X, Y, Z, T, C), Output (N, X, Y, Z, T)" + ) + + if ( + self.expected_dimensions is not None + and self.dimensions != self.expected_dimensions + ): + raise ValueError( + f"Dimension mismatch!\n" + f" Config expects: {self.expected_dimensions}\n" + f" Data has: {self.dimensions}\n" + f" Input shape: {tuple(self.input_data.shape)}\n" + f" Please update arch.dimensions in config to " + f"'{self.dimensions}' " + f"or use a dataset with {self.expected_dimensions} data." + ) + + self.num_samples = self.input_data.shape[0] + self.spatial_shape = tuple(self.input_data.shape[1:-2]) + self.time_steps = self.input_data.shape[-2] + self.num_channels = self.input_data.shape[-1] + + _log_message( + f" Detected: {self.dimensions.upper()} | " + f"Spatial: {self.spatial_shape} | " + f"T: {self.time_steps} | C: {self.num_channels}" + ) + + # ------------------------------------------------------------------ + # Mask detection (delegates to data.mask_detection) + # ------------------------------------------------------------------ + + def _apply_mask_detection(self): + """Run mask detection and store results on self.""" + result: MaskResult = detect_mask( + self.input_data, self.output_data, self._config_mask_channel + ) + self.mask_channel = result.channel + self.mask_per_sample = result.per_sample + self.static_mask = result.static_mask + + if result.method == "none": + _log_message(" Mask: none (all cells active)") + else: + pct = 100 * result.n_active / result.n_total if result.n_total else 0 + ps = " (per-sample)" if result.per_sample else "" + _log_message( + f" Mask [{result.method} ch {result.channel}]: " + f"{result.n_active}/{result.n_total} active " + f"({pct:.1f}%){ps}" + ) + + def get_static_mask(self): + """Return static spatial mask or None.""" + return self.static_mask + + # ------------------------------------------------------------------ + # Normalization (delegates to data.normalization) + # ------------------------------------------------------------------ + + def _init_normalization(self): + """Compute or prepare normalization statistics.""" + if self.mode == "train": + self._norm_stats = compute_norm_stats(self.input_data, self.output_data) + _log_message( + f" Normalization: Output " + f"mean={self._norm_stats.output_mean.item():.4f}, " + f"std={self._norm_stats.output_std.item():.4f}" + ) + else: + self._norm_stats = identity_norm_stats( + self.input_data.dim(), self.num_channels + ) + + # Backward-compatible properties so existing code that reads + # ds.input_mean / ds.input_std / ds.output_mean / ds.output_std + # continues to work. + + @property + def input_mean(self): + """Input channel means (broadcastable).""" + return self._norm_stats.input_mean if self._norm_stats else None + + @input_mean.setter + def input_mean(self, value): + """Set input channel means.""" + if self._norm_stats is None: + self._norm_stats = NormStats(value, value, value, value) + self._norm_stats.input_mean = value + + @property + def input_std(self): + """Input channel standard deviations (broadcastable).""" + return self._norm_stats.input_std if self._norm_stats else None + + @input_std.setter + def input_std(self, value): + """Set input channel standard deviations.""" + if self._norm_stats is not None: + self._norm_stats.input_std = value + + @property + def output_mean(self): + """Scalar output mean.""" + return self._norm_stats.output_mean if self._norm_stats else None + + @output_mean.setter + def output_mean(self, value): + """Set scalar output mean.""" + if self._norm_stats is not None: + self._norm_stats.output_mean = value + + @property + def output_std(self): + """Scalar output standard deviation.""" + return self._norm_stats.output_std if self._norm_stats else None + + @output_std.setter + def output_std(self, value): + """Set scalar output standard deviation.""" + if self._norm_stats is not None: + self._norm_stats.output_std = value + + def set_normalization( + self, + input_mean: torch.Tensor, + input_std: torch.Tensor, + output_mean: torch.Tensor, + output_std: torch.Tensor, + ): + """Set normalization parameters from an external source.""" + self._norm_stats = NormStats(input_mean, input_std, output_mean, output_std) + + def get_normalization_stats(self) -> Tuple[torch.Tensor, ...]: + """Return ``(input_mean, input_std, output_mean, output_std)``.""" + if self._norm_stats is None: + raise RuntimeError("Normalization not initialized") + return self._norm_stats.as_tuple() + + # ------------------------------------------------------------------ + # Dataset interface + # ------------------------------------------------------------------ + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Return a single ``(input, output)`` sample. + + Returns + ------- + Tuple[Tensor, Tensor] + 3D: ``(H, W, T, C)``, ``(H, W, T)`` + 4D: ``(X, Y, Z, T, C)``, ``(X, Y, Z, T)`` + """ + inp = self.input_data[idx] + out = self.output_data[idx] + + if self.normalize and self._norm_stats is not None: + inp, out = normalize_sample(inp, out, self._norm_stats) + + return inp, out + + +# ===================================================================== +# Collate +# ===================================================================== + + +def collate_fn(batch): + """Stack samples along the batch dimension (3D and 4D agnostic).""" + inputs = torch.stack([item[0] for item in batch], dim=0) + targets = torch.stack([item[1] for item in batch], dim=0) + return inputs, targets + + +# ===================================================================== +# Dataloader factory +# ===================================================================== + + +def create_dataloaders( + data_path: Union[str, Path], + batch_size: int = 4, + normalize: bool = True, + num_workers: int = 4, + device: Union[str, torch.device] = "cuda", + input_file: Optional[str] = None, + output_file: Optional[str] = None, + variable: Optional[str] = None, + expected_dimensions: Optional[str] = None, + use_mask: bool = False, + mask_channel: Optional[int] = None, + num_timesteps: Optional[int] = None, +) -> Tuple[torch.utils.data.DataLoader, ...]: + """Create train, validation, and test dataloaders. + + Parameters + ---------- + data_path : Union[str, Path] + Path to the data directory. + batch_size : int + Batch size per GPU (default 4). + normalize : bool + Z-score normalize (default ``True``). + num_workers : int + DataLoader worker processes (default 4). + device : Union[str, torch.device] + Target device for ``pin_memory`` (default ``"cuda"``). + input_file, output_file : str, optional + Filename patterns with ``{mode}`` placeholder. + variable : str, optional + ``'pressure'`` or ``'saturation'`` for CO2 convention. + expected_dimensions : str, optional + ``'3d'`` or ``'4d'``; raises on mismatch. + use_mask : bool + Enable mask detection (default ``False``). + mask_channel : int, optional + Explicit mask channel (overrides auto-detect). + num_timesteps : int, optional + Truncate train/val time axis; test keeps all. + + Returns + ------- + Tuple[DataLoader, DataLoader, DataLoader] + ``(train_loader, val_loader, test_loader)`` + """ + from torch.utils.data import DataLoader + + try: + from physicsnemo.distributed import DistributedManager + + dist = DistributedManager() + is_distributed = dist.world_size > 1 + except Exception: + is_distributed = False + + dataset_kwargs: dict = { + "data_path": data_path, + "input_file": input_file, + "output_file": output_file, + "variable": variable, + "normalize": normalize, + "expected_dimensions": expected_dimensions, + "use_mask": use_mask, + "mask_channel": mask_channel, + } + + train_dataset = ReservoirDataset( + mode="train", num_timesteps=num_timesteps, **dataset_kwargs + ) + val_dataset = ReservoirDataset( + mode="val", num_timesteps=num_timesteps, **dataset_kwargs + ) + test_dataset = ReservoirDataset(mode="test", **dataset_kwargs) + + if normalize: + norm_stats = train_dataset.get_normalization_stats() + + if is_distributed: + import torch.distributed as dist_torch + + gpu_stats = [] + for stat in norm_stats: + s = stat.cuda() + dist_torch.broadcast(s, src=0) + gpu_stats.append(s.cpu()) + norm_stats = tuple(gpu_stats) + + val_dataset.set_normalization(*norm_stats) + test_dataset.set_normalization(*norm_stats) + + use_pin_memory = (isinstance(device, torch.device) and device.type == "cuda") or ( + isinstance(device, str) and device == "cuda" + ) + + train_sampler = val_sampler = test_sampler = None + if is_distributed: + from torch.utils.data.distributed import DistributedSampler + + train_sampler = DistributedSampler(train_dataset, shuffle=True, drop_last=False) + val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=False) + test_sampler = DistributedSampler(test_dataset, shuffle=False, drop_last=False) + + loader_kwargs: dict = { + "batch_size": batch_size, + "num_workers": num_workers, + "pin_memory": use_pin_memory, + "persistent_workers": num_workers > 0, + "collate_fn": collate_fn, + } + + train_loader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + **loader_kwargs, + ) + val_loader = DataLoader( + val_dataset, shuffle=False, sampler=val_sampler, **loader_kwargs + ) + test_loader = DataLoader( + test_dataset, shuffle=False, sampler=test_sampler, **loader_kwargs + ) + + _log_message( + f"Created dataloaders: {train_dataset.dimensions.upper()} data | " + f"Train: {len(train_dataset)}, " + f"Val: {len(val_dataset)}, " + f"Test: {len(test_dataset)}" + ) + + return train_loader, val_loader, test_loader + + +# ===================================================================== +# Utility +# ===================================================================== + + +def get_dataset_info(data_path: Union[str, Path], **kwargs) -> Dict: + """Quick dataset introspection without full loading overhead. + + Returns + ------- + dict + Keys: dimensions, spatial_shape, time_steps, num_channels, num_samples. + """ + ds = ReservoirDataset(data_path, mode="train", normalize=False, **kwargs) + return { + "dimensions": ds.dimensions, + "spatial_shape": ds.spatial_shape, + "time_steps": ds.time_steps, + "num_channels": ds.num_channels, + "num_samples": { + "train": len(ds), + "val": len( + ReservoirDataset(data_path, mode="val", normalize=False, **kwargs) + ), + "test": len( + ReservoirDataset(data_path, mode="test", normalize=False, **kwargs) + ), + }, + } diff --git a/examples/reservoir_simulation/neural_operator_factory/data/file_resolution.py b/examples/reservoir_simulation/neural_operator_factory/data/file_resolution.py new file mode 100644 index 0000000000..4608f93be1 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/file_resolution.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Resolve input/output file paths for reservoir simulation datasets.""" + +from pathlib import Path +from typing import Optional, Tuple + + +def resolve_data_files( + data_path: Path, + mode: str, + input_file: Optional[str] = None, + output_file: Optional[str] = None, + variable: Optional[str] = None, +) -> Tuple[Path, Path]: + """Resolve input and output file paths with flexible naming. + + Supports three resolution strategies in priority order: + + 1. **Explicit files** — ``input_file`` and ``output_file`` are both + provided (may contain a ``{mode}`` placeholder). + 2. **Variable-based** — ``variable`` maps to CO2 naming convention + (``dP_*`` for pressure, ``sg_*`` for saturation). + 3. **Auto-detect** — scans the directory for common patterns. + + Parameters + ---------- + data_path : Path + Root data directory. + mode : str + Dataset split: ``'train'``, ``'val'``, or ``'test'``. + input_file : str, optional + Input filename or pattern with ``{mode}`` placeholder. + output_file : str, optional + Output filename or pattern with ``{mode}`` placeholder. + variable : str, optional + ``'pressure'`` or ``'saturation'`` for CO2 convention. + + Returns + ------- + Tuple[Path, Path] + ``(input_path, output_path)`` + + Raises + ------ + FileNotFoundError + If auto-detection fails (no known pattern matches). + ValueError + If *variable* is not a recognised name. + """ + if input_file is not None and output_file is not None: + return ( + data_path / input_file.format(mode=mode), + data_path / output_file.format(mode=mode), + ) + + if variable is not None: + var_map = {"pressure": "dP", "saturation": "sg", "dP": "dP", "sg": "sg"} + if variable.lower() not in var_map: + raise ValueError( + f"Variable must be 'pressure' or 'saturation', got {variable}" + ) + prefix = var_map[variable.lower()] + return ( + data_path / f"{prefix}_{mode}_a.pt", + data_path / f"{prefix}_{mode}_u.pt", + ) + + return _auto_detect_files(data_path, mode) + + +def _auto_detect_files(data_path: Path, mode: str) -> Tuple[Path, Path]: + """Scan *data_path* for known filename patterns.""" + patterns = [ + (f"{mode}_input.pt", f"{mode}_output.pt"), + (f"input_{mode}.pt", f"output_{mode}.pt"), + (f"{mode}_x.pt", f"{mode}_y.pt"), + (f"x_{mode}.pt", f"y_{mode}.pt"), + (f"dP_{mode}_a.pt", f"dP_{mode}_u.pt"), + (f"sg_{mode}_a.pt", f"sg_{mode}_u.pt"), + ] + for inp_name, out_name in patterns: + inp_path = data_path / inp_name + out_path = data_path / out_name + if inp_path.exists() and out_path.exists(): + return inp_path, out_path + + pt_files = sorted(f.name for f in data_path.glob("*.pt")) + raise FileNotFoundError( + f"Could not auto-detect data files in {data_path}\n" + f"Available .pt files: {pt_files}\n" + f"Please specify input_file and output_file explicitly." + ) diff --git a/examples/reservoir_simulation/neural_operator_factory/data/gpu_prefetch.py b/examples/reservoir_simulation/neural_operator_factory/data/gpu_prefetch.py new file mode 100644 index 0000000000..d68b461974 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/gpu_prefetch.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU prefetch wrapper for PyTorch DataLoaders. + +Overlaps host-to-device transfers with GPU compute by using a +dedicated CUDA stream. Batches arrive on-device; the training +loop no longer needs explicit ``.to(device)`` calls. + +Usage +----- +>>> loader = DataLoader(dataset, batch_size=4) +>>> prefetched = GPUPrefetcher(loader, device="cuda:0") +>>> for inputs, targets in prefetched: +... # inputs and targets are already on GPU +... pred = model(inputs) +""" + +from typing import Iterator, Tuple, Union + +import torch +from torch import Tensor +from torch.utils.data import DataLoader + + +class GPUPrefetcher: + """Prefetches the next batch to GPU while the current batch trains. + + Parameters + ---------- + loader : DataLoader + Source dataloader (CPU-side). + device : str or torch.device + Target GPU device. + """ + + def __init__( + self, + loader: DataLoader, + device: Union[str, torch.device] = "cuda", + ): + self.loader = loader + self.device = torch.device(device) + self.stream = torch.cuda.Stream(device=self.device) + + def __iter__(self) -> Iterator[Tuple[Tensor, ...]]: + """Yield batches that are already on *device*.""" + it = iter(self.loader) + + try: + batch = next(it) + except StopIteration: + return + + batch = self._transfer(batch) + + for next_batch in it: + with torch.cuda.stream(self.stream): + next_batch = self._transfer(next_batch) + + yield batch + torch.cuda.current_stream(self.device).wait_stream(self.stream) + batch = next_batch + + yield batch + + def __len__(self) -> int: + """Number of batches (delegated to the wrapped loader).""" + return len(self.loader) + + # ------------------------------------------------------------------ + # Forwarded attributes so the prefetcher is a drop-in for DataLoader + # ------------------------------------------------------------------ + + @property + def dataset(self): + """Underlying dataset.""" + return self.loader.dataset + + @property + def sampler(self): + """Underlying sampler.""" + return self.loader.sampler + + def _transfer(self, batch: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + return tuple(t.to(self.device, non_blocking=True) for t in batch) diff --git a/examples/reservoir_simulation/neural_operator_factory/data/mask_detection.py b/examples/reservoir_simulation/neural_operator_factory/data/mask_detection.py new file mode 100644 index 0000000000..4d1f2def6b --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/mask_detection.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Automatic inactive-cell mask detection for reservoir grids. + +Detection priority: + +1. **Explicit channel** — caller supplies a channel index. +2. **ACTNUM** — binary {0, 1} channel, static across time, whose zeros + coincide with output-inactive cells. +3. **Non-zero fallback** — any channel with a static zero pattern + matching the output's inactive cells. +4. **No mask** — all cells treated as active. +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor + + +@dataclass +class MaskResult: + """Outcome of mask detection. + + Attributes + ---------- + channel : int or None + Input channel index used as the mask source. + per_sample : bool + ``True`` when the mask varies across samples (built per-batch). + static_mask : Tensor or None + ``(*spatial)`` boolean mask when the mask is the same for all + samples; ``None`` when *per_sample* is ``True`` or no mask. + method : str + Human-readable label: ``'config'``, ``'actnum'``, + ``'nonzero'``, or ``'none'``. + n_active : int + Number of active cells in sample 0 (for logging). + n_total : int + Total number of spatial cells (for logging). + """ + + channel: Optional[int] + per_sample: bool + static_mask: Optional[Tensor] + method: str + n_active: int + n_total: int + + +def detect_mask( + input_data: Tensor, + output_data: Tensor, + config_channel: Optional[int] = None, +) -> MaskResult: + """Detect the best inactive-cell mask for a reservoir dataset. + + Parameters + ---------- + input_data : Tensor + Full input tensor ``(N, *spatial, T, C)``. + output_data : Tensor + Full output tensor ``(N, *spatial, T)``. + config_channel : int or None + Explicit mask channel from config (highest priority). + + Returns + ------- + MaskResult + """ + if config_channel is not None: + return _from_explicit_channel(input_data, config_channel) + + result = _find_actnum_channel(input_data, output_data) + if result is not None: + return result + + result = _find_nonzero_channel(input_data, output_data) + if result is not None: + return result + + return MaskResult( + channel=None, + per_sample=False, + static_mask=None, + method="none", + n_active=0, + n_total=0, + ) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +_N_CROSS_CHECK = 3 + + +def _output_inactive(output_data: Tensor) -> Tensor: + """Boolean mask of cells that are zero across all timesteps in sample 0.""" + return output_data[0].abs().sum(dim=-1) == 0 + + +def _check_consistency(input_data: Tensor, reference: Tensor, channel: int) -> bool: + """True if the mask is identical across the first few samples.""" + n = min(input_data.shape[0], _N_CROSS_CHECK) + return all( + torch.equal(reference, input_data[si][..., 0, channel]) for si in range(1, n) + ) + + +def _build_result( + input_data: Tensor, + channel: int, + consistent: bool, + method: str, +) -> MaskResult: + mask_s0 = input_data[0][..., 0, channel] != 0 + static = mask_s0 if consistent else None + return MaskResult( + channel=channel, + per_sample=not consistent, + static_mask=static, + method=method, + n_active=int(mask_s0.sum().item()), + n_total=int(mask_s0.numel()), + ) + + +def _from_explicit_channel(input_data: Tensor, channel: int) -> MaskResult: + col = input_data[0][..., 0, channel] + n = min(input_data.shape[0], _N_CROSS_CHECK) + consistent = all( + torch.equal(col != 0, input_data[si][..., 0, channel] != 0) + for si in range(1, n) + ) + return _build_result(input_data, channel, consistent, "config") + + +def _find_actnum_channel( + input_data: Tensor, output_data: Tensor +) -> Optional[MaskResult]: + """Binary {0,1} channel, static across time, zeros subset of inactive.""" + s0 = input_data[0] + out_inactive = _output_inactive(output_data) + candidates = [] + for ch in range(s0.shape[-1]): + col = s0[..., 0, ch] + vals = col.unique() + if not (vals.numel() <= 2 and all(v in (0.0, 1.0) for v in vals.tolist())): + continue + if not torch.equal(s0[..., 0, ch], s0[..., -1, ch]): + continue + zeros = col == 0 + if (zeros & ~out_inactive).any(): + continue + consistent = _check_consistency(input_data, col, ch) + candidates.append((ch, zeros.sum().item(), consistent)) + + if not candidates: + return None + best_ch, _, best_consistent = max(candidates, key=lambda x: x[1]) + return _build_result(input_data, best_ch, best_consistent, "actnum") + + +def _find_nonzero_channel( + input_data: Tensor, output_data: Tensor +) -> Optional[MaskResult]: + """Channel with a static zero pattern matching output-inactive cells.""" + s0 = input_data[0] + out_inactive = _output_inactive(output_data) + candidates = [] + for ch in range(s0.shape[-1]): + zeros_t0 = s0[..., 0, ch] == 0 + n_zeros = zeros_t0.sum().item() + if n_zeros == 0 or n_zeros == zeros_t0.numel(): + continue + if not torch.equal(zeros_t0, s0[..., -1, ch] == 0): + continue + if (zeros_t0 & ~out_inactive).any(): + continue + n = min(input_data.shape[0], _N_CROSS_CHECK) + consistent = all( + torch.equal(zeros_t0, input_data[si][..., 0, ch] == 0) for si in range(1, n) + ) + candidates.append((ch, n_zeros, consistent)) + + if not candidates: + return None + best_ch, _, best_consistent = max(candidates, key=lambda x: x[1]) + return _build_result(input_data, best_ch, best_consistent, "nonzero") diff --git a/examples/reservoir_simulation/neural_operator_factory/data/normalization.py b/examples/reservoir_simulation/neural_operator_factory/data/normalization.py new file mode 100644 index 0000000000..7fc0e94148 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/normalization.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Z-score normalization utilities for reservoir simulation tensors.""" + +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch import Tensor + + +@dataclass +class NormStats: + """Per-channel z-score statistics. + + Attributes + ---------- + input_mean, input_std : Tensor + Shape ``(1, *[1]*spatial, 1, C)`` — broadcastable to input samples. + output_mean, output_std : Tensor + Scalar tensors for the output variable. + """ + + input_mean: Tensor + input_std: Tensor + output_mean: Tensor + output_std: Tensor + + def as_tuple(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Unpack for ``set_normalization(*stats.as_tuple())``.""" + return self.input_mean, self.input_std, self.output_mean, self.output_std + + +def compute_norm_stats(input_data: Tensor, output_data: Tensor) -> NormStats: + """Compute z-score statistics from a training split. + + Parameters + ---------- + input_data : Tensor + ``(N, *spatial, T, C)`` — full training input tensor. + output_data : Tensor + ``(N, *spatial, T)`` — full training output tensor. + + Returns + ------- + NormStats + """ + reduce_dims = tuple(range(input_data.dim() - 1)) + + input_mean = input_data.mean(dim=reduce_dims, keepdim=True) + input_std = input_data.std(dim=reduce_dims, keepdim=True) + output_mean = output_data.mean() + output_std = output_data.std() + + input_std = torch.where(input_std > 1e-6, input_std, torch.ones_like(input_std)) + if output_std < 1e-6: + output_std = torch.tensor(1.0) + + return NormStats(input_mean, input_std, output_mean, output_std) + + +def identity_norm_stats(input_ndim: int, num_channels: int) -> NormStats: + """Identity (no-op) statistics for val/test before sharing. + + Parameters + ---------- + input_ndim : int + Dimensionality of a full input tensor (5 for 3D, 6 for 4D). + num_channels : int + Number of input channels (last dimension). + """ + shape = [1] * (input_ndim - 1) + [num_channels] + return NormStats( + input_mean=torch.zeros(shape), + input_std=torch.ones(shape), + output_mean=torch.tensor(0.0), + output_std=torch.tensor(1.0), + ) + + +def normalize_sample( + input_sample: Tensor, + output_sample: Tensor, + stats: NormStats, +) -> Tuple[Tensor, Tensor]: + """Apply z-score normalization to a single sample. + + Parameters + ---------- + input_sample : Tensor + ``(*spatial, T, C)`` — one sample without batch dim. + output_sample : Tensor + ``(*spatial, T)`` — one sample without batch dim. + stats : NormStats + + Returns + ------- + Tuple[Tensor, Tensor] + Normalized ``(input, output)``. + """ + mean = stats.input_mean.squeeze(0).to(input_sample.device) + std = stats.input_std.squeeze(0).to(input_sample.device) + o_mean = stats.output_mean.to(output_sample.device) + o_std = stats.output_std.to(output_sample.device) + + return (input_sample - mean) / std, (output_sample - o_mean) / o_std diff --git a/examples/reservoir_simulation/neural_operator_factory/data/scalar_utils.py b/examples/reservoir_simulation/neural_operator_factory/data/scalar_utils.py new file mode 100644 index 0000000000..fcfb808900 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/scalar_utils.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Scalar channel detection utilities for MIONet. + +This module provides utilities to automatically detect which input channels +contain constant (scalar) values that should be processed separately by +the MIONet's second branch. +""" + +from functools import partial +from typing import Dict, List, Optional, Tuple + +import torch + + +def detect_scalar_channels( + sample_input: torch.Tensor, + threshold: float = 1e-6, + num_samples_to_check: int = 1, +) -> Dict: + """ + Detect which channels contain constant (scalar) values. + + A channel is considered "scalar" if it contains the same value at every + spatial and temporal location. This is detected using multiple checks: + 1. Standard deviation < threshold + 2. Min value equals max value + 3. All values are close to the first value + + Args: + sample_input: Single sample tensor of shape (H, W, T, C) or batch (B, H, W, T, C) + threshold: Tolerance for detecting constants (default: 1e-6) + num_samples_to_check: Number of samples to verify consistency (if batch provided) + + Returns: + Dictionary containing: + - 'scalar_indices': List of channel indices that are constant + - 'spatial_indices': List of channel indices that vary spatially + - 'scalar_values': Tensor of shape (C_scalar,) with the constant value for each scalar channel + - 'num_scalar_channels': Number of scalar channels detected + - 'num_spatial_channels': Number of spatial channels detected + + Example: + >>> sample = dataset[0][0] # Get first input sample + >>> result = detect_scalar_channels(sample) + >>> print(f"Scalar channels: {result['scalar_indices']}") + >>> print(f"Spatial channels: {result['spatial_indices']}") + """ + # Handle batch dimension if present + if sample_input.dim() == 5: + # (B, H, W, T, C) -> use first sample + sample_input = sample_input[0] + + if sample_input.dim() != 4: + raise ValueError( + f"Expected input shape (H, W, T, C) or (B, H, W, T, C), " + f"got shape {sample_input.shape}" + ) + + num_channels = sample_input.shape[-1] + scalar_indices = [] + spatial_indices = [] + scalar_values = [] + + for c in range(num_channels): + channel_data = sample_input[..., c] # (H, W, T) + + # Flatten for easier analysis + flat_data = channel_data.flatten() + first_value = flat_data[0] + + # Multiple checks for robustness: + # 1. Standard deviation check + std_check = channel_data.std().item() < threshold + + # 2. Min equals max check + min_val = channel_data.min().item() + max_val = channel_data.max().item() + minmax_check = abs(max_val - min_val) < threshold + + # 3. All values close to first value + allclose_check = torch.allclose( + channel_data, + torch.full_like(channel_data, first_value.item()), + atol=threshold, + rtol=0, + ) + + # 4. Unique values check (should be 1 for scalar) + unique_check = len(torch.unique(flat_data)) == 1 + + # Channel is scalar if ALL checks pass + is_constant = std_check and minmax_check and allclose_check and unique_check + + if is_constant: + scalar_indices.append(c) + scalar_values.append(first_value.item()) + else: + spatial_indices.append(c) + + # Convert scalar values to tensor + scalar_values_tensor = torch.tensor(scalar_values, dtype=sample_input.dtype) + + return { + "scalar_indices": scalar_indices, + "spatial_indices": spatial_indices, + "scalar_values": scalar_values_tensor, + "num_scalar_channels": len(scalar_indices), + "num_spatial_channels": len(spatial_indices), + } + + +def verify_scalar_consistency( + dataset, + scalar_indices: List[int], + num_samples: int = 10, + threshold: float = 1e-6, +) -> Tuple[bool, Optional[str]]: + """ + Verify that detected scalar channels are consistent across multiple samples. + + This checks that the same channels are scalar across different samples + in the dataset (though the scalar values may differ per sample). + + Args: + dataset: PyTorch dataset to check + scalar_indices: List of channel indices detected as scalar + num_samples: Number of samples to check + threshold: Tolerance for scalar detection + + Returns: + Tuple of (is_consistent, error_message) + - is_consistent: True if all samples have same scalar channels + - error_message: None if consistent, otherwise describes the inconsistency + """ + num_to_check = min(num_samples, len(dataset)) + + for i in range(num_to_check): + sample_input, _ = dataset[i] + result = detect_scalar_channels(sample_input, threshold=threshold) + + if set(result["scalar_indices"]) != set(scalar_indices): + return False, ( + f"Inconsistent scalar channels at sample {i}. " + f"Expected {scalar_indices}, got {result['scalar_indices']}" + ) + + return True, None + + +def create_mionet_collate_fn( + scalar_indices: List[int], + spatial_indices: List[int], +): + """ + Create a custom collate function for MIONet that separates scalar and spatial inputs. + + Args: + scalar_indices: List of channel indices that are scalar + spatial_indices: List of channel indices that are spatial + + Returns: + Collate function for use with DataLoader + + Example: + >>> collate_fn = create_mionet_collate_fn([3, 5, 7], [0, 1, 2, 4, 6, 8, 9, 10, 11]) + >>> dataloader = DataLoader(dataset, collate_fn=collate_fn, ...) + """ + + def mionet_collate_fn(batch, scalar_idx, spatial_idx): + """ + Custom collate that separates scalar and spatial inputs. + + Returns: + Tuple of (spatial_inputs, scalar_inputs, targets) + - spatial_inputs: (B, H, W, T, C_spatial) - channels that vary spatially + - scalar_inputs: (B, C_scalar) - scalar values for each sample + - targets: (B, H, W, T) - target outputs + """ + inputs, targets = zip(*batch) + inputs = torch.stack(inputs) # (B, H, W, T, C) + targets = torch.stack(targets) # (B, H, W, T) + + # Separate spatial and scalar channels + spatial_inputs = inputs[..., spatial_idx] # (B, H, W, T, C_spatial) + + # For scalar channels, extract the constant value (take from position [0,0,0]) + # since it's the same everywhere in the spatial/temporal domain + scalar_inputs = inputs[:, 0, 0, 0, scalar_idx] # (B, C_scalar) + + return spatial_inputs, scalar_inputs, targets + + # Return partial function with indices bound + return partial( + mionet_collate_fn, scalar_idx=scalar_indices, spatial_idx=spatial_indices + ) + + +def log_scalar_detection_results( + result: Dict, + logger=None, + channel_names: Optional[List[str]] = None, +): + """ + Log the results of scalar channel detection. + + Args: + result: Dictionary from detect_scalar_channels() + logger: Optional logger object (uses print if None) + channel_names: Optional list of channel names for better logging + """ + log_fn = logger.info if logger else print + success_fn = logger.success if logger and hasattr(logger, "success") else log_fn + + log_fn("=" * 60) + log_fn("SCALAR CHANNEL DETECTION RESULTS") + log_fn("=" * 60) + + # Scalar channels + log_fn(f"Scalar channels ({result['num_scalar_channels']}):") + for i, idx in enumerate(result["scalar_indices"]): + name = channel_names[idx] if channel_names else f"Channel {idx}" + value = result["scalar_values"][i].item() + log_fn(f" [{idx}] {name}: value = {value:.6f}") + + # Spatial channels + log_fn(f"Spatial channels ({result['num_spatial_channels']}):") + for idx in result["spatial_indices"]: + name = channel_names[idx] if channel_names else f"Channel {idx}" + log_fn(f" [{idx}] {name}") + + log_fn("=" * 60) + success_fn( + f"✅ Detection complete: {result['num_scalar_channels']} scalar, " + f"{result['num_spatial_channels']} spatial channels" + ) diff --git a/examples/reservoir_simulation/neural_operator_factory/data/validation.py b/examples/reservoir_simulation/neural_operator_factory/data/validation.py new file mode 100644 index 0000000000..e0dd782077 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/data/validation.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data validation utilities for reservoir simulation models. + +Supports both 3D (2D spatial + time) and 4D (3D spatial + time) datasets: +- 3D: Input (B, H, W, T, C), Output (B, H, W, T) +- 4D: Input (B, X, Y, Z, T, C), Output (B, X, Y, Z, T) +""" + +from typing import Dict, Optional, Tuple + +import torch + + +def detect_dimensions(inputs: torch.Tensor) -> str: + """ + Detect data dimensions from input tensor. + + Parameters + ---------- + inputs : torch.Tensor + Input tensor (batched) + + Returns + ------- + str + '3d' for (B, H, W, T, C) or '4d' for (B, X, Y, Z, T, C) + """ + ndim = inputs.dim() + if ndim == 5: + return "3d" + elif ndim == 6: + return "4d" + else: + raise ValueError( + f"Cannot detect dimensions from {ndim}D tensor. " + f"Expected 5D (3d) or 6D (4d), got shape {tuple(inputs.shape)}" + ) + + +def validate_batch_dimensions( + inputs: torch.Tensor, targets: torch.Tensor, variable: str = "unknown" +) -> Dict: + """ + Validate dimensions of a batch of training/evaluation data. + + Automatically detects 3D or 4D data format. + + Parameters + ---------- + inputs : torch.Tensor + Input tensor + - 3D: (batch, H, W, T, channels) + - 4D: (batch, X, Y, Z, T, channels) + targets : torch.Tensor + Target tensor + - 3D: (batch, H, W, T) + - 4D: (batch, X, Y, Z, T) + variable : str + Name of the variable being predicted + + Returns + ------- + Dict + Dictionary with keys: + - dimensions: '3d' or '4d' + - batch_size: int + - spatial_shape: tuple + - time_steps: int + - num_channels: int + + Raises + ------ + ValueError + If data dimensions are invalid or inconsistent + """ + input_ndim = inputs.dim() + output_ndim = targets.dim() + + # Determine expected format based on input dimensions + if input_ndim == 5: + # 3D format: (B, H, W, T, C) + dimensions = "3d" + expected_output_ndim = 4 + dim_names = "H, W, T" + spatial_slice = slice(1, 3) + time_idx = 3 + channel_idx = 4 + + elif input_ndim == 6: + # 4D format: (B, X, Y, Z, T, C) + dimensions = "4d" + expected_output_ndim = 5 + dim_names = "X, Y, Z, T" + spatial_slice = slice(1, 4) + time_idx = 4 + channel_idx = 5 + + else: + raise ValueError( + f"❌ Invalid input shape! Expected 5D (B, H, W, T, C) for 3D data " + f"or 6D (B, X, Y, Z, T, C) for 4D data, " + f"got {input_ndim}D tensor with shape {tuple(inputs.shape)}." + ) + + # Validate output dimensions + if output_ndim != expected_output_ndim: + raise ValueError( + f"❌ Invalid target shape for {dimensions.upper()} data! " + f"Expected {expected_output_ndim}D tensor (B, {dim_names}), " + f"got {output_ndim}D tensor with shape {tuple(targets.shape)}." + ) + + # Extract dimensions + batch_size = inputs.shape[0] + spatial_shape = tuple(inputs.shape[spatial_slice]) + time_steps = inputs.shape[time_idx] + num_channels = inputs.shape[channel_idx] + + # Validate batch size match + if targets.shape[0] != batch_size: + raise ValueError( + f"❌ Batch size mismatch! Input: {batch_size}, Target: {targets.shape[0]}" + ) + + # Validate spatial+temporal dimensions match + input_spatiotemporal = tuple( + inputs.shape[1:-1] + ) # All dims except batch and channels + target_spatiotemporal = tuple(targets.shape[1:]) # All dims except batch + + if input_spatiotemporal != target_spatiotemporal: + raise ValueError( + f"❌ Spatial/temporal dimension mismatch! " + f"Input: {input_spatiotemporal}, Target: {target_spatiotemporal}" + ) + + # Sanity checks + if any(d < 1 for d in spatial_shape): + raise ValueError(f"❌ Invalid spatial dimensions: {spatial_shape}") + if time_steps < 1: + raise ValueError(f"❌ Invalid time steps: {time_steps}") + if num_channels < 1: + raise ValueError(f"❌ Invalid number of channels: {num_channels}") + + return { + "dimensions": dimensions, + "batch_size": batch_size, + "spatial_shape": spatial_shape, + "time_steps": time_steps, + "num_channels": num_channels, + } + + +def validate_sample_dimensions( + input_sample: torch.Tensor, target_sample: torch.Tensor, variable: str = "unknown" +) -> Dict: + """ + Validate dimensions of a single sample (unbatched). + + Parameters + ---------- + input_sample : torch.Tensor + Input tensor (single sample) + - 3D: (H, W, T, channels) + - 4D: (X, Y, Z, T, channels) + target_sample : torch.Tensor + Target tensor (single sample) + - 3D: (H, W, T) + - 4D: (X, Y, Z, T) + variable : str + Name of the variable being predicted + + Returns + ------- + Dict + Dictionary with keys: dimensions, spatial_shape, time_steps, num_channels + """ + input_ndim = input_sample.dim() + output_ndim = target_sample.dim() + + if input_ndim == 4: + # 3D: (H, W, T, C) + dimensions = "3d" + expected_output_ndim = 3 + dim_names = "H, W, T" + spatial_slice = slice(0, 2) + time_idx = 2 + channel_idx = 3 + + elif input_ndim == 5: + # 4D: (X, Y, Z, T, C) + dimensions = "4d" + expected_output_ndim = 4 + dim_names = "X, Y, Z, T" + spatial_slice = slice(0, 3) + time_idx = 3 + channel_idx = 4 + + else: + raise ValueError( + f"❌ Invalid input shape! Expected 4D (H, W, T, C) for 3D data " + f"or 5D (X, Y, Z, T, C) for 4D data, " + f"got {input_ndim}D tensor with shape {tuple(input_sample.shape)}." + ) + + if output_ndim != expected_output_ndim: + raise ValueError( + f"❌ Invalid target shape for {dimensions.upper()} data! " + f"Expected {expected_output_ndim}D tensor ({dim_names}), " + f"got {output_ndim}D tensor with shape {tuple(target_sample.shape)}." + ) + + spatial_shape = tuple(input_sample.shape[spatial_slice]) + time_steps = input_sample.shape[time_idx] + num_channels = input_sample.shape[channel_idx] + + # Validate spatial+temporal match + input_spatiotemporal = tuple(input_sample.shape[:-1]) + target_spatiotemporal = tuple(target_sample.shape) + + if input_spatiotemporal != target_spatiotemporal: + raise ValueError( + f"❌ Spatial/temporal dimension mismatch! " + f"Input: {input_spatiotemporal}, Target: {target_spatiotemporal}" + ) + + return { + "dimensions": dimensions, + "spatial_shape": spatial_shape, + "time_steps": time_steps, + "num_channels": num_channels, + } + + +def print_validation_summary( + input_shape: Tuple, + target_shape: Tuple, + variable: str, + is_batch: bool = True, + logger: Optional[object] = None, +): + """ + Print a formatted summary of validation results. + + Automatically detects 3D or 4D format from shapes. + + Parameters + ---------- + input_shape : Tuple + Shape of input tensor + target_shape : Tuple + Shape of target tensor + variable : str + Name of the variable being predicted + is_batch : bool + Whether shapes include batch dimension + logger : object, optional + Logger object with .success() and .info() methods + """ + log_func = logger.success if logger else print + info_func = logger.info if logger else print + + # Detect dimensions + if is_batch: + is_4d = len(input_shape) == 6 + batch_size = input_shape[0] + spatial_start = 1 + else: + is_4d = len(input_shape) == 5 + batch_size = None + spatial_start = 0 + + if is_4d: + dim_label = "4D" + dim_names = ("X", "Y", "Z", "T") + spatial_end = spatial_start + 3 + else: + dim_label = "3D" + dim_names = ("H", "W", "T") + spatial_end = spatial_start + 2 + + spatial_shape = input_shape[spatial_start:spatial_end] + time_steps = input_shape[spatial_end] + num_channels = input_shape[-1] + + log_func(f"✅ Data validation passed! ({dim_label})") + + if is_batch: + spatial_str = " × ".join( + f"{dim_names[i]}={spatial_shape[i]}" for i in range(len(spatial_shape)) + ) + info_func( + f" Input shape: {input_shape} → (batch, {', '.join(dim_names)}, channels)" + ) + info_func(f" Target shape: {target_shape} → (batch, {', '.join(dim_names)})") + info_func(f" Batch size: {batch_size}") + else: + info_func(f" Input shape: {input_shape} → ({', '.join(dim_names)}, channels)") + info_func(f" Target shape: {target_shape} → ({', '.join(dim_names)})") + + spatial_str = " × ".join(str(s) for s in spatial_shape) + info_func(f" Spatial dimensions: {spatial_str} ({' × '.join(dim_names[:-1])})") + info_func(f" Time steps: {time_steps}") + info_func(f" Input channels: {num_channels}") + info_func(f" Variable: {variable}") + + +def get_dimension_info(tensor: torch.Tensor, is_batch: bool = True) -> Dict: + """ + Extract dimension information from a tensor. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor + is_batch : bool + Whether tensor includes batch dimension + + Returns + ------- + Dict with dimension information + """ + ndim = tensor.dim() + + if is_batch: + if ndim == 5: + return { + "dimensions": "3d", + "batch_size": tensor.shape[0], + "spatial_shape": tuple(tensor.shape[1:3]), + "time_steps": tensor.shape[3], + "num_channels": tensor.shape[4], + } + elif ndim == 6: + return { + "dimensions": "4d", + "batch_size": tensor.shape[0], + "spatial_shape": tuple(tensor.shape[1:4]), + "time_steps": tensor.shape[4], + "num_channels": tensor.shape[5], + } + elif ndim == 4: + # Output tensor (3D) + return { + "dimensions": "3d", + "batch_size": tensor.shape[0], + "spatial_shape": tuple(tensor.shape[1:3]), + "time_steps": tensor.shape[3], + "num_channels": 1, + } + elif ndim == 5: + # Output tensor (4D) + return { + "dimensions": "4d", + "batch_size": tensor.shape[0], + "spatial_shape": tuple(tensor.shape[1:4]), + "time_steps": tensor.shape[4], + "num_channels": 1, + } + else: + if ndim == 4: + return { + "dimensions": "3d", + "spatial_shape": tuple(tensor.shape[0:2]), + "time_steps": tensor.shape[2], + "num_channels": tensor.shape[3], + } + elif ndim == 5: + return { + "dimensions": "4d", + "spatial_shape": tuple(tensor.shape[0:3]), + "time_steps": tensor.shape[3], + "num_channels": tensor.shape[4], + } + + raise ValueError(f"Cannot extract dimension info from {ndim}D tensor") diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/README.md new file mode 100644 index 0000000000..840502d51f --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/README.md @@ -0,0 +1,28 @@ +# Neural Operator Factory — Examples + +Each subdirectory is a self-contained example with its own Hydra config files +and SLURM batch scripts. All examples use the shared training script at +`training/train.py` and per-example evaluation scripts. + +## Running an Example + +All jobs are submitted from the `neural_operator_factory/` directory: + +```bash +cd examples/reservoir_simulation/neural_operator_factory + +# Training (8-GPU via SLURM) +sbatch examples/ufno_co2/train.sbatch saturation_ufno + +# Evaluation (single GPU via SLURM) +sbatch examples/ufno_co2/eval.sbatch saturation +``` + +## Available Examples + +| Example | Description | Dataset | Architectures | +|---------|-------------|---------|---------------| +| [ufno_co2](ufno_co2/) | Reproduce U-FNO paper (Wen et al. 2022) | CO2 sequestration (3D) | FNO, Conv-FNO, U-FNO | +| [udeeponet_co2](udeeponet_co2/) | Reproduce U-DeepONet paper (Diab & Al Kobaisi 2024) | CO2 sequestration (3D) | U-DeepONet | +| [fourier_mionet_co2](fourier_mionet_co2/) | Reproduce Fourier-MIONet paper (Jiang et al. 2024) | CO2 sequestration (3D) | MIONet, MIONet-FNN, Fourier-MIONet | +| [tno_co2](tno_co2/) | Reproduce TNO paper (Diab & Al Kobaisi 2025) | CO2 sequestration (3D) | TNO (autoregressive) | diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/README.md new file mode 100644 index 0000000000..1bbd1c2fcf --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/README.md @@ -0,0 +1,133 @@ +# Fourier-MIONet for CO2 Sequestration + +Reproduces the main results from: + +> **Fourier-MIONet: Fourier-enhanced multiple-input neural** +> **operators for multiphase modeling of geological carbon** +> **sequestration** +> Jiang, Z., Zhu, M. & Lu, L. (2024). +> *Reliability Engineering & System Safety*, 251, 110392. +> [arXiv:2303.04778](https://arxiv.org/abs/2303.04778) + +## Overview + +Three MIONet-based architectures on the CO2 dataset: + +| Model | Branch1 | Branch2 | Decoder | Paper R² (sat) | +|-------|---------|---------|---------|---------------| +| Vanilla MIONet | Spatial (conv) | MLP (scalars) | Direct | 0.948 | +| MIONet-FNN | Spatial (conv) | MLP (scalars) | Deep FNN | 0.971 | +| Fourier-MIONet | Spatial (Fourier+UNet) | MLP (scalars) | MLP | 0.985 | + +All models use a sinusoidal trunk for time encoding and +full-mapping training (all 24 timesteps at once). + +## Dataset + +Same CO2 sequestration dataset as the U-FNO and +U-DeepONet examples: + +- **Grid**: 96 x 200 x 24 time steps +- **Input channels**: 12 +- **Samples**: 4,500 train / 500 val / 500 test + +The dataset is publicly available at: + + +## Usage + +All commands from `neural_operator_factory/`: + +### Training + +```bash +# Fourier-MIONet saturation (flagship) +sbatch examples/fourier_mionet_co2/train.sbatch \ + fourier_mionet saturation_training_config + +# Vanilla MIONet pressure +sbatch examples/fourier_mionet_co2/train.sbatch \ + vanilla_mionet pressure_training_config + +# MIONet-FNN saturation +sbatch examples/fourier_mionet_co2/train.sbatch \ + mionet_fnn saturation_training_config + +# All 6 experiments +for arch in vanilla_mionet mionet_fnn fourier_mionet; do + for var in saturation pressure; do + sbatch examples/fourier_mionet_co2/train.sbatch \ + $arch ${var}_training_config + done +done +``` + +### Evaluation + +```bash +sbatch examples/fourier_mionet_co2/eval.sbatch saturation +sbatch examples/fourier_mionet_co2/eval.sbatch pressure +``` + +## Results + +### Paper Reference (Table 6) + +| Model | Sat R² | Pres R² | +|-------|--------|---------| +| Vanilla MIONet | 0.948 | 0.961 | +| MIONet-FNN | 0.971 | 0.979 | +| Fourier-MIONet | 0.985 | 0.986 | + +### NOF Reproduction (500 test samples, 8× GPU) + +| Model | Sat MPE | Sat R² | Pres MRE | Pres R² | +|-------|---------|--------|----------|---------| +| Vanilla MIONet | — | — | — | — | +| MIONet-FNN | 0.1454 | 0.705 | 0.0614 | 0.592 | +| **Fourier-MIONet** | **0.0227** | **0.990** | **0.0082** | **0.987** | + +Fourier-MIONet exceeds the paper's R² targets. +Baselines underperform due to minimal branch architecture +(single conv layer) — they demonstrate the ranking +Fourier-MIONet >> MIONet-FNN >> Vanilla MIONet, +consistent with the paper. + +## Files + +```text +fourier_mionet_co2/ +├── README.md +├── train.sbatch +├── eval.sbatch +├── evaluate_pressure.py +├── evaluate_saturation.py +└── conf/ + ├── vanilla_mionet/ + │ ├── model_config.yaml + │ ├── saturation_training_config.yaml + │ └── pressure_training_config.yaml + ├── mionet_fnn/ + │ ├── model_config.yaml + │ ├── saturation_training_config.yaml + │ └── pressure_training_config.yaml + └── fourier_mionet/ + ├── model_config.yaml + ├── saturation_training_config.yaml + └── pressure_training_config.yaml +``` + +## Reference + +```bibtex +@article{jiang2024fourier, + title={{Fourier-MIONet}: Fourier-enhanced multiple-input + neural operators for multiphase modeling of geological + carbon sequestration}, + author={Jiang, Zhongyi and Zhu, Min and Lu, Lu}, + journal={Reliability Engineering \& System Safety}, + volume={251}, + pages={110392}, + year={2024} +} +``` diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/model_config.yaml new file mode 100644 index 0000000000..e4c7c13d9f --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/model_config.yaml @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Fourier-MIONet — Model Configuration (CO2 Sequestration) +# ============================================================================= +# MIONet with Fourier + U-Net enhanced spatial branch. +# Branch1 uses Fourier layers + U-Net layers for spatial encoding, +# branch2 uses MLP for scalar inputs. +# Width = 36, matching the paper's ~3.7M parameters. +# Paper reference R2: sat=0.985, pres=0.986 +# ============================================================================= + +arch: + dimensions: 3d + model: xdeeponet + + xdeeponet: + variant: fourier_mionet + width: 36 + padding: 8 + + branch1: + encoder: + type: linear + activation_fn: relu + layers: + num_fourier_layers: 3 + num_unet_layers: 3 + num_conv_layers: 0 + modes1: 12 + modes2: 12 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: relu + internal_resolution: null + + branch2: + encoder: + type: mlp + hidden_width: 36 + num_layers: 3 + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 0 + num_conv_layers: 0 + + trunk: + input_type: time + hidden_width: 36 + num_layers: 6 + activation_fn: sin + output_activation: true + + decoder_type: mlp + decoder_width: 128 + decoder_layers: 1 + decoder_activation_fn: relu + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/pressure_training_config.yaml new file mode 100644 index 0000000000..156e32a199 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/pressure_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_fourier_mionet_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/saturation_training_config.yaml new file mode 100644 index 0000000000..adb9e1b912 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/fourier_mionet/saturation_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_fourier_mionet_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/model_config.yaml new file mode 100644 index 0000000000..0eb7b63d40 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/model_config.yaml @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# MIONet-FNN — Model Configuration (CO2 Sequestration) +# ============================================================================= +# Same as vanilla MIONet but with a deeper FNN decoder +# (merger net) after branch-trunk product. +# Paper reference R2: sat=0.971, pres=0.979 +# ============================================================================= + +arch: + dimensions: 3d + model: xdeeponet + + xdeeponet: + variant: mionet + width: 36 + padding: 8 + + branch1: + encoder: + type: linear + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 0 + num_conv_layers: 1 + modes1: 12 + modes2: 12 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: relu + internal_resolution: null + + branch2: + encoder: + type: mlp + hidden_width: 36 + num_layers: 3 + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 0 + num_conv_layers: 0 + + trunk: + input_type: time + hidden_width: 36 + num_layers: 6 + activation_fn: sin + output_activation: true + + decoder_type: mlp + decoder_width: 128 + decoder_layers: 2 + decoder_activation_fn: relu + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/pressure_training_config.yaml new file mode 100644 index 0000000000..d8ef4ecd27 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/pressure_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_mionet_fnn_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/saturation_training_config.yaml new file mode 100644 index 0000000000..baa803a4b5 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/mionet_fnn/saturation_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_mionet_fnn_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/model_config.yaml new file mode 100644 index 0000000000..fc21f70c31 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/model_config.yaml @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Vanilla MIONet — Model Configuration (CO2 Sequestration) +# ============================================================================= +# Two-branch MIONet: branch1 = spatial (minimal), branch2 = MLP +# for scalar inputs, trunk = sinusoidal FNN. No Fourier layers. +# Paper reference R2: sat=0.948, pres=0.961 +# ============================================================================= + +arch: + dimensions: 3d + model: xdeeponet + + xdeeponet: + variant: mionet + width: 36 + padding: 8 + + branch1: + encoder: + type: linear + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 0 + num_conv_layers: 1 + modes1: 12 + modes2: 12 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: relu + internal_resolution: null + + branch2: + encoder: + type: mlp + hidden_width: 36 + num_layers: 3 + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 0 + num_conv_layers: 0 + + trunk: + input_type: time + hidden_width: 36 + num_layers: 6 + activation_fn: sin + output_activation: true + + decoder_type: mlp + decoder_width: 36 + decoder_layers: 0 + decoder_activation_fn: relu + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/pressure_training_config.yaml new file mode 100644 index 0000000000..b3476ed7b4 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/pressure_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_vanilla_mionet_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/saturation_training_config.yaml new file mode 100644 index 0000000000..04cbc29cc3 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/conf/vanilla_mionet/saturation_training_config.yaml @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/fourier_mionet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: fmionet_co2_vanilla_mionet_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_pressure.py b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_pressure.py new file mode 100644 index 0000000000..4194800c40 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_pressure.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 pressure buildup evaluation — matches the validated U-FNO evaluation algorithm. + +dnorm_dP is applied to prediction only. MRE and R² are computed with +pred in physical bar and target in raw (normalized) space, matching the +original paper evaluation. R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def dnorm_dP(dP): + """Denormalize pressure predictions from normalized to physical units.""" + dP = dP * 18.772821433027488 + dP = dP + 4.172939172019009 + return dP + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def mean_relative_error(y_pred, y_true): + """Compute mean relative error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true) / (y_true.max() - y_true.min())) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 pressure model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "pressure" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mre_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + _has_b2 = model_config.get("variant", "") in ("mionet", "fourier_mionet", "tno") + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + fwd = {"x_branch2": x_batch[:, 0, 0, 0, :]} if _has_b2 else {} + pred_batch = model(x_batch, **fwd) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = dnorm_dP(pred_batch.cpu().numpy()) + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mre = mean_relative_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mre_list.append(np.mean(mre)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mre = np.mean(mre_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Pressure Buildup — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MRE: {overall_mre:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + vfwd = {"x_branch2": x[:, 0, 0, 0, :]} if _has_b2 else {} + pred = model(x, **vfwd) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = dnorm_dP(pred.squeeze(0).cpu().numpy()) + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True dP, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred dP (bar), t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"pressure_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_saturation.py b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_saturation.py new file mode 100644 index 0000000000..e6fa278a79 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/fourier_mionet_co2/evaluate_saturation.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 gas saturation evaluation — matches the validated U-FNO evaluation algorithm. + +No denormalization. MPE computed within the plume region. +R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def mean_plume_error(y_pred, y_true): + """Compute mean plume error for saturation predictions.""" + mask = (y_pred != 0) & (y_true != 0) + y_pred_masked = y_pred[mask] + y_true_masked = y_true[mask] + if len(y_pred_masked) == 0: + return 0.0 + return np.mean(np.abs(y_pred_masked - y_true_masked)) + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 saturation model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "saturation" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mpe_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + _has_b2 = model_config.get("variant", "") in ("mionet", "fourier_mionet", "tno") + + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + fwd = {"x_branch2": x_batch[:, 0, 0, 0, :]} if _has_b2 else {} + pred_batch = model(x_batch, **fwd) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = pred_batch.cpu().numpy() + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mpe = mean_plume_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mpe_list.append(np.mean(mpe)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mpe = np.mean(mpe_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Gas Saturation — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MPE: {overall_mpe:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + vfwd = {"x_branch2": x[:, 0, 0, 0, :]} if _has_b2 else {} + pred = model(x, **vfwd) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = pred.squeeze(0).cpu().numpy() + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"saturation_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/README.md new file mode 100644 index 0000000000..28c15a44e3 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/README.md @@ -0,0 +1,183 @@ +# Physics-Informed Fourier-DeepONet for Norne Field Simulation + +Physics-informed neural operator surrogate for the Norne reservoir +simulation dataset using a 4D Fourier-DeepONet with derivative +regularization, mass conservation losses, and autoregressive feedback. + +## Overview + +This example trains a Fourier-DeepONet on the Norne field dataset — a +real-world 3D reservoir model based on the publicly available +[Norne Field](https://github.com/OPM/opm-data/tree/master/norne) dataset. +Norne requires volumetric 4D operators +(3D spatial + time) and handles complex geological features including +numerous faults with Non-Neighbor Connections (NNCs), pinch-outs, and +39% inactive cells. + +A Design of Experiment study identified fault transmissibility and +KVKH multipliers as key uncertain parameters, which were varied using +Latin Hypercube Sampling (LHS) to generate 500 realizations. The +primary LHS variable is PERMZ (vertical permeability), controlling the +Kv/Kh ratio. All simulations were generated using the open-source +[OPM](https://opm-project.org/) reservoir simulator. + +### Pressure Architecture (Fourier-DeepONet) + +| Component | Configuration | +|-----------|--------------| +| Branch1 | Linear encoder → 6 Fourier layers, gelu, modes 10×10×6 | +| Trunk | 12-layer tanh FNN, time input, linear output | +| Decoder | Temporal projection (width → K=1) | +| Feedback | Previous prediction appended as extra input channel | +| Width | 64 | +| Parameters | 118M | + +### Saturation Architecture (Fourier-DeepONet — SWAT / SGAS) + +| Component | Configuration | +|-----------|--------------| +| Branch1 | Linear encoder → 6 Fourier layers, gelu, modes 10×10×6 | +| Trunk | 12-layer tanh FNN, time input, linear output | +| Decoder | Temporal projection (width → K=1) | +| Feedback | Previous prediction appended as extra input channel | +| Width | 64 | +| Parameters | 118M | + +### Losses + +| Variable | Data Loss | Derivative | Mass Conservation | +|----------|----------|-----------|------------------| +| Pressure | Relative L2 (w=1.0) | dx, dy, dz (w=0.5) | Disabled | +| SWAT | L1 (w=1.0) | dx, dy, dz (w=0.5) | Enabled (w=0.5) | +| SGAS | L1 (w=1.0) | dx, dy, dz (w=0.5) | Enabled (w=0.5) | + +### Training Configuration + +| Setting | Pressure | Saturations (SWAT / SGAS) | +|---------|----------|--------------------------| +| Regime | AR: 10 TF + 90 rollout | AR: 10 TF + 90 rollout | +| Rollout mode | `detached` | `detached` | +| L / K | 3 / 1 | 3 / 1 | +| Feedback channel | enabled | enabled | +| Batch size | 1 per GPU × 8 GPUs | 1 per GPU × 8 GPUs | +| Optimizer | Adam, lr=1e-3, wd=1e-4 | Adam, lr=1e-3, wd=1e-4 | +| Scheduler | StepLR(10, 0.85) | StepLR(10, 0.85) | +| Masking | ACTNUM ch 5 (39.2%) | ACTNUM auto-detect | +| Normalize | true | false | + +## Dataset + +Norne field reservoir simulation (500 LHS realizations, OPM simulator): + +- **Grid**: 46 × 112 × 22 (113,344 total cells, 44,431 active — 39.2%) +- **Time steps**: 65 (0 to 3,260 days, ~9 years of operation) +- **Wells**: 36 (producers and injectors, multi-layer horizontal completions) +- **LHS variable**: PERMZ (vertical permeability / Kv/Kh ratio) +- **Input channels**: 11 + - Static: PERMX (log10), PERMZ (log10), PORO, PORV, NTG, ACTNUM + - Coordinates: grid_x, grid_y, grid_z (normalized) + - Dynamic: grid_t (normalized), WCID (+1 injector / -1 producer / 0 none) +- **Output variables** (separate models): + - `pressure` — cell pressure (bar) + - `swat` — water saturation (fraction) + - `sgas` — gas saturation (fraction) +- **Samples**: 400 train / 50 val / 50 test + +### Comparison with CO2 Dataset + +| Aspect | CO2 | Norne | +|--------|-----|-------| +| Spatial dims | 2D (96 × 200) | 3D (46 × 112 × 22) | +| Total cells | 19,200 | 113,344 | +| Active cells | variable (~53%) | 44,431 (39.2%) | +| Timesteps | 24 | 65 | +| Samples | 5,500 | 500 | +| Input channels | 12 | 11 | +| Simulator | Custom | OPM | + +## Usage + +All commands from `neural_operator_factory/`: + +### Training + +```bash +# Pressure +sbatch examples/pi_norne/train.sbatch pressure_training_config + +# Water saturation (SWAT) +sbatch examples/pi_norne/train.sbatch swat_training_config + +# Gas saturation (SGAS) +sbatch examples/pi_norne/train.sbatch sgas_training_config +``` + +### Evaluation + +```bash +# Pressure (normalize + feedback) +NORMALIZE=1 FEEDBACK=1 sbatch examples/pi_norne/eval.sbatch pressure + +# Saturations (feedback, no normalization) +FEEDBACK=1 sbatch examples/pi_norne/eval.sbatch swat +FEEDBACK=1 sbatch examples/pi_norne/eval.sbatch sgas + +# With explicit checkpoint +NORMALIZE=1 FEEDBACK=1 \ + CHECKPOINT=checkpoints/best_model_pressure_deeponet3d_fourier_deeponet_linear.pth \ + sbatch examples/pi_norne/eval.sbatch pressure +``` + +## Results + +### Pressure (Fourier-DeepONet) + +| Metric | Value | +|--------|-------| +| MAE | 0.92 bar | +| RMSE | 1.54 bar | +| Relative L2 | 0.54% | +| R² | 0.999 | +| Parameters | 118M | +| Training time | 3 hr 24 min (8× H100, 100 epochs) | + +### Water Saturation (SWAT) + +| Metric | Value | +|--------|-------| +| MAE | 2.91e-3 | +| RMSE | 7.89e-3 | +| Relative L2 | 1.05% | +| R² | 0.9996 | +| Parameters | 118M | +| Training time | 3 hr 17 min (8× H100, 100 epochs) | + +### Gas Saturation (SGAS) + +| Metric | Value | +|--------|-------| +| MAE | 1.09e-2 | +| RMSE | 4.17e-2 | +| Relative L2 | 14.1% | +| R² | 0.977 | +| Parameters | 118M | +| Training time | 3 hr 15 min (8× H100, 100 epochs) | + +## Files + +```text +pi_norne/ +├── README.md +├── train.sbatch # 8 GPU, 4 hour time limit +├── eval.sbatch # 1 GPU evaluation +└── conf/ + ├── pressure_model_config.yaml # Pressure architecture + loss (no mass conservation) + ├── swat_model_config.yaml # SWAT architecture + loss (mass conservation enabled) + ├── sgas_model_config.yaml # SGAS architecture + loss (mass conservation enabled) + ├── pressure_training_config.yaml + ├── swat_training_config.yaml + └── sgas_training_config.yaml + +Each variable has its own model config. All three currently use the same +Fourier-DeepONet architecture with per-variable loss configuration. +``` diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_model_config.yaml new file mode 100644 index 0000000000..e38d33ff18 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_model_config.yaml @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# PI-Norne — Pressure Fourier-DeepONet Model Configuration +# ============================================================================= +# 4D Fourier-DeepONet for pressure prediction on the Norne field. +# Branch: 6 Fourier layers (pure spectral processing) +# Trunk: 12-layer deep MLP for temporal encoding +# Feedback channel appends previous prediction as extra input. +# +# Norne grid: 46 x 112 x 22, 44,431 active cells (39.2%), 65 timesteps +# Input channels: 11 + 1 feedback = 12 during AR training +# Output: pressure (bar) +# ============================================================================= + +arch: + dimensions: 4d + model: xdeeponet + + xdeeponet: + variant: fourier_deeponet + width: 64 + padding: 8 + + # Branch1: 6 pure Fourier layers for spatial encoding + branch1: + encoder: + type: linear + activation_fn: gelu + layers: + num_fourier_layers: 6 + num_unet_layers: 0 + num_conv_layers: 0 + modes1: 10 + modes2: 10 + modes3: 6 + kernel_size: 3 + dropout: 0.0 + activation_fn: gelu + + # Trunk: 12-layer deep MLP for time coordinates + trunk: + input_type: time + hidden_width: 128 + num_layers: 12 + activation_fn: tanh + output_activation: false + + # Decoder: temporal projection for single-step output + decoder_type: temporal_projection + decoder_width: 64 + decoder_layers: 1 + decoder_activation_fn: relu + +# ============================================================================= +# Pressure Loss Configuration +# ============================================================================= +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx, dy, dz] + metric: null + + physics: + mass_conservation: + enabled: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_training_config.yaml new file mode 100644 index 0000000000..e1abc4f722 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/pressure_training_config.yaml @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# PI-Norne — Pressure Fourier-DeepONet (Norne Field) + +defaults: + - pressure_model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/pi_norne/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/norne + variable: pressure + input_file: norne_{mode}_a.pt + output_file: norne_{mode}_pressure.pt + normalize: true + num_workers: 4 + val_metric: rmse + mask_enabled: true + mask_channel: 5 + denormalize_fn: null + num_timesteps: null + +training: + batch_size: 1 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: autoregressive + + autoregressive: + input_window: 3 + output_window: 1 + teacher_forcing_epochs: 10 + pushforward_epochs: 0 + rollout_epochs: 90 + max_unroll: 5 + noise_std: 0.0 + use_feedback_channel: true + lr_reset_factor: 1.0 + rollout_mode: detached + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: pi_norne_pressure_fdon + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_model_config.yaml new file mode 100644 index 0000000000..d6891f4a84 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_model_config.yaml @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# PI-Norne — Gas Saturation (SGAS) Fourier-DeepONet Model Configuration +# ============================================================================= +# 4D Fourier-DeepONet for gas saturation prediction on the Norne field. +# Branch: 6 Fourier layers (pure spectral processing) +# Trunk: 12-layer deep MLP for temporal encoding +# Feedback channel appends previous prediction as extra input. +# +# Norne grid: 46 x 112 x 22, 44,431 active cells (39.2%), 65 timesteps +# Input channels: 11 + 1 feedback = 12 during AR training +# Output: gas saturation (fraction, [0, 1]) +# ============================================================================= + +arch: + dimensions: 4d + model: xdeeponet + + xdeeponet: + variant: fourier_deeponet + width: 64 + padding: 8 + + # Branch1: 6 pure Fourier layers for spatial encoding + branch1: + encoder: + type: linear + activation_fn: gelu + layers: + num_fourier_layers: 6 + num_unet_layers: 0 + num_conv_layers: 0 + modes1: 10 + modes2: 10 + modes3: 6 + kernel_size: 3 + dropout: 0.0 + activation_fn: gelu + + # Trunk: 12-layer deep MLP for time coordinates + trunk: + input_type: time + hidden_width: 128 + num_layers: 12 + activation_fn: tanh + output_activation: false + + # Decoder: temporal projection for single-step output + decoder_type: temporal_projection + decoder_width: 64 + decoder_layers: 1 + decoder_activation_fn: relu + +# ============================================================================= +# SGAS Loss Configuration +# ============================================================================= +loss: + types: [l1] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx, dy, dz] + metric: null + + physics: + mass_conservation: + enabled: true + weight: 0.5 + use_cell_volumes: true + metric: null diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_training_config.yaml new file mode 100644 index 0000000000..50a7d38875 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/sgas_training_config.yaml @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# PI-Norne — Gas Saturation (Norne Field) + +defaults: + - sgas_model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/pi_norne/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/norne + variable: sgas + input_file: norne_{mode}_a.pt + output_file: norne_{mode}_sgas.pt + normalize: false + num_workers: 4 + val_metric: rmse + mask_enabled: true + mask_channel: null + denormalize_fn: null + num_timesteps: null + +training: + batch_size: 1 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: autoregressive + + autoregressive: + input_window: 3 + output_window: 1 + teacher_forcing_epochs: 10 + pushforward_epochs: 0 + rollout_epochs: 90 + max_unroll: 5 + noise_std: 0.0 + use_feedback_channel: true + lr_reset_factor: 1.0 + rollout_mode: detached + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: pi_norne_sgas + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_model_config.yaml new file mode 100644 index 0000000000..bcb94bec0c --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_model_config.yaml @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# PI-Norne — Water Saturation (SWAT) Fourier-DeepONet Model Configuration +# ============================================================================= +# 4D Fourier-DeepONet for water saturation prediction on the Norne field. +# Branch: 6 Fourier layers (pure spectral processing) +# Trunk: 12-layer deep MLP for temporal encoding +# Feedback channel appends previous prediction as extra input. +# +# Norne grid: 46 x 112 x 22, 44,431 active cells (39.2%), 65 timesteps +# Input channels: 11 + 1 feedback = 12 during AR training +# Output: water saturation (fraction, [0, 1]) +# ============================================================================= + +arch: + dimensions: 4d + model: xdeeponet + + xdeeponet: + variant: fourier_deeponet + width: 64 + padding: 8 + + # Branch1: 6 pure Fourier layers for spatial encoding + branch1: + encoder: + type: linear + activation_fn: gelu + layers: + num_fourier_layers: 6 + num_unet_layers: 0 + num_conv_layers: 0 + modes1: 10 + modes2: 10 + modes3: 6 + kernel_size: 3 + dropout: 0.0 + activation_fn: gelu + + # Trunk: 12-layer deep MLP for time coordinates + trunk: + input_type: time + hidden_width: 128 + num_layers: 12 + activation_fn: tanh + output_activation: false + + # Decoder: temporal projection for single-step output + decoder_type: temporal_projection + decoder_width: 64 + decoder_layers: 1 + decoder_activation_fn: relu + +# ============================================================================= +# SWAT Loss Configuration +# ============================================================================= +loss: + types: [l1] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx, dy, dz] + metric: null + + physics: + mass_conservation: + enabled: true + weight: 0.5 + use_cell_volumes: true + metric: null diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_training_config.yaml new file mode 100644 index 0000000000..885999222b --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/conf/swat_training_config.yaml @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# PI-Norne — Water Saturation (Norne Field) + +defaults: + - swat_model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/pi_norne/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/norne + variable: swat + input_file: norne_{mode}_a.pt + output_file: norne_{mode}_swat.pt + normalize: false + num_workers: 4 + val_metric: rmse + mask_enabled: true + mask_channel: null + denormalize_fn: null + num_timesteps: null + +training: + batch_size: 1 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: autoregressive + + autoregressive: + input_window: 3 + output_window: 1 + teacher_forcing_epochs: 10 + pushforward_epochs: 0 + rollout_epochs: 90 + max_unroll: 5 + noise_std: 0.0 + use_feedback_channel: true + lr_reset_factor: 1.0 + rollout_mode: detached + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: pi_norne_swat + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/evaluate_norne.py b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/evaluate_norne.py new file mode 100644 index 0000000000..892a1ff8da --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/pi_norne/evaluate_norne.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluate a trained model on the Norne field test dataset. + +Supports two evaluation modes: + full_mapping: Feed all timesteps at once (single forward pass). + autoregressive: Roll out step-by-step (L context -> K predicted), + matching the XMGN autoregressive inference protocol. + +Automatically detects TNO variant and feedback-channel usage from the +saved checkpoint config so you don't need to specify them manually. + +Usage: + # Pressure evaluation (auto-detect model variant from checkpoint) + python scripts/evaluate_norne.py --variable pressure +or sbatch eval_norne.sbatch pressure + + # Saturation (SWAT) + python scripts/evaluate_norne.py --variable swat +or sbatch eval_norne.sbatch swat + + # Custom checkpoint + mode + python scripts/evaluate_norne.py --checkpoint path/to/model.pth --mode autoregressive --L 1 --K 3 + CHECKPOINT=checkpoints/best_model_swat_deeponet3d_tno_spatial.pth sbatch eval_norne.sbatch swat + + NORMALIZE=0 sbatch eval_norne.sbatch pressure +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import argparse + +import numpy as np +import torch +from data.dataloader import ReservoirDataset +from training.ar_utils import ar_validate_full_rollout +from training.metrics import ( + compute_r2_score, + compute_relative_l2_error, +) +from utils.checkpoint import build_model_from_config + + +def load_model(model_config, device): + """Reconstruct model from saved config.""" + model, _ = build_model_from_config(model_config, device=device) + return model + + +def print_metrics( + all_predictions, all_targets, variable, num_timesteps, spatial_mask=None +): + """Compute and print all metrics in XMGN-compatible format.""" + if spatial_mask is not None: + m = spatial_mask + for _ in range(all_predictions.ndim - m.ndim): + m = ( + m[np.newaxis] + if m.ndim < all_predictions.ndim - 1 + else m[..., np.newaxis] + ) + m = np.broadcast_to(m, all_predictions.shape) + pf, gf = all_predictions[m], all_targets[m] + else: + pf, gf = all_predictions.ravel(), all_targets.ravel() + overall_mae = np.mean(np.abs(pf - gf)) + overall_mse = np.mean((pf - gf) ** 2) + overall_rmse = np.sqrt(overall_mse) + overall_rel_l2 = compute_relative_l2_error(pf, gf) + overall_r2 = compute_r2_score(pf, gf) + + print("=" * 70) + print("EVALUATION RESULTS") + print("=" * 70) + print(f"Test samples processed: {all_predictions.shape[0]}") + print(f"Overall MAE: {overall_mae:.6e}") + print(f"Overall MSE: {overall_mse:.6e}") + print(f"Overall RMSE: {overall_rmse:.6e}") + print(f"Relative L2 Error: {overall_rel_l2:.6e}") + print(f"R2 Score: {overall_r2:.6f}") + print() + + print("Per-Variable Metrics:") + print("-" * 70) + print( + f" {variable:>12s} | MAE: {overall_mae:>12.6e} | RMSE: {overall_rmse:>12.6e}" + ) + print() + + print("Per-Timestep Metrics (averaged over all samples):") + print(f"{'t':>4s} | {'MAE':>12s} | {'MSE':>12s} | {'RMSE':>12s}") + print("-" * 50) + for t in range(num_timesteps): + pred_t = all_predictions[..., t] + gt_t = all_targets[..., t] + if spatial_mask is not None: + mt = np.broadcast_to(spatial_mask[np.newaxis], pred_t.shape) + pt, gt = pred_t[mt], gt_t[mt] + else: + pt, gt = pred_t.ravel(), gt_t.ravel() + t_mae = np.mean(np.abs(pt - gt)) + t_mse = np.mean((pt - gt) ** 2) + t_rmse = np.sqrt(t_mse) + print(f"{t:4d} | {t_mae:12.6e} | {t_mse:12.6e} | {t_rmse:12.6e}") + + print() + print("Per-Sample Summary (first 10 and last 5):") + print( + f"{'sample':>6s} | {'MAE':>12s} | {'RMSE':>12s} | {'RelL2':>12s} | {'R2':>8s}" + ) + print("-" * 60) + n = all_predictions.shape[0] + show_idx = list(range(min(10, n))) + list(range(max(10, n - 5), n)) + for i in show_idx: + pi, gi = all_predictions[i], all_targets[i] + if spatial_mask is not None: + ms = np.broadcast_to(spatial_mask[..., np.newaxis], pi.shape) + pi_f, gi_f = pi[ms], gi[ms] + else: + pi_f, gi_f = pi.ravel(), gi.ravel() + s_mae = np.mean(np.abs(pi_f - gi_f)) + s_rmse = np.sqrt(np.mean((pi_f - gi_f) ** 2)) + s_rel_l2 = compute_relative_l2_error(pi_f, gi_f) + s_r2 = compute_r2_score(pi_f, gi_f) + print( + f"{i:6d} | {s_mae:12.6e} | {s_rmse:12.6e} | {s_rel_l2:12.6e} | {s_r2:8.4f}" + ) + + print() + print("=" * 70) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate model on Norne test set (XMGN-compatible metrics)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/norne", + ) + parser.add_argument( + "--input_file", + type=str, + default="norne_{mode}_a.pt", + help="Input file pattern ({mode} replaced with train/val/test)", + ) + parser.add_argument( + "--output_file", + type=str, + default=None, + help="Output file pattern. If not set, inferred from --variable: " + "pressure->norne_{mode}_pressure.pt, swat->norne_{mode}_swat.pt, sgas->norne_{mode}_sgas.pt", + ) + parser.add_argument( + "--variable", + type=str, + default="pressure", + choices=["pressure", "swat", "sgas"], + help="Variable to evaluate (default: pressure)", + ) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument( + "--tno", action="store_true", help="TNO mode: feed predictions back as branch2" + ) + parser.add_argument( + "--mask", + action="store_true", + help="Auto-detect ACTNUM and evaluate on active cells only", + ) + parser.add_argument( + "--feedback", + action="store_true", + help="Enable feedback channel (append previous prediction as extra input). " + "Auto-detected from checkpoint if feedback_channel is saved.", + ) + parser.add_argument( + "--mode", + type=str, + default="full_mapping", + choices=["full_mapping", "autoregressive"], + help="full_mapping: single forward pass. autoregressive: AR rollout.", + ) + parser.add_argument( + "--L", type=int, default=1, help="AR input window (context timesteps)" + ) + parser.add_argument( + "--K", + type=int, + default=3, + help="AR output window (predicted timesteps per step)", + ) + parser.add_argument( + "--normalize", + action="store_true", + help="Load data normalized (for models trained with normalize=true). " + "Metrics are reported on denormalized (physical) values.", + ) + args = parser.parse_args() + + # Infer output file from variable if not explicitly set + VARIABLE_FILE_MAP = { + "pressure": "norne_{mode}_pressure.pt", + "swat": "norne_{mode}_swat.pt", + "sgas": "norne_{mode}_sgas.pt", + } + if args.output_file is None: + args.output_file = VARIABLE_FILE_MAP.get( + args.variable, f"norne_{{mode}}_{args.variable}.pt" + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # -- Checkpoint -- + if args.checkpoint: + ckpt_path = Path(args.checkpoint) + else: + ckpt_dir = Path("checkpoints") + candidates = sorted( + ckpt_dir.glob("best_model_*.pth"), key=lambda p: p.stat().st_mtime + ) + if not candidates: + raise FileNotFoundError("No checkpoints found. Specify --checkpoint.") + ckpt_path = candidates[-1] + + checkpoint = torch.load(ckpt_path, map_location=device) + model_config = checkpoint["model_config"] + + # Auto-detect TNO variant and feedback channel from checkpoint + is_tno = model_config.get("variant", "") == "tno" + feedback_channel = model_config.get("feedback_channel", None) + if args.tno: + is_tno = True + if args.feedback: + feedback_channel = 1 + + print("=" * 70) + print("NEURAL OPERATOR FACTORY - NORNE EVALUATION") + print("=" * 70) + print(f"Checkpoint: {ckpt_path}") + print(f"Model: {model_config['model_arch_name']}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6e}") + print(f"Variable: {args.variable}") + print(f"Mode: {args.mode}") + if args.mode == "autoregressive": + print(f" L={args.L}, K={args.K}") + total_steps = "unknown (determined from data)" + print(f" AR steps to cover trajectory: ~{total_steps}") + print() + + # -- Data -- + # When normalize=True, we need train stats to denormalize predictions + # Resolve {mode} patterns for file names + input_pattern = args.input_file + output_pattern = args.output_file + + output_mean, output_std = 0.0, 1.0 + if args.normalize: + train_dataset = ReservoirDataset( + data_path=args.data_path, + mode="train", + input_file=input_pattern, + output_file=output_pattern, + normalize=True, + ) + norm_stats = train_dataset.get_normalization_stats() + output_mean = norm_stats[2].item() + output_std = norm_stats[3].item() + print( + f"Normalization: output_mean={output_mean:.4f}, output_std={output_std:.4f}" + ) + del train_dataset + + test_dataset = ReservoirDataset( + data_path=args.data_path, + mode="test", + input_file=input_pattern, + output_file=output_pattern, + normalize=args.normalize, + ) + if args.normalize: + # Re-load train stats for the test dataset normalization + tmp_train = ReservoirDataset( + data_path=args.data_path, + mode="train", + input_file=input_pattern, + output_file=output_pattern, + normalize=True, + ) + test_dataset.set_normalization(*tmp_train.get_normalization_stats()) + del tmp_train + sample_x, sample_y = test_dataset[0] + num_timesteps = sample_y.shape[-1] + + spatial_mask = None + if args.mask: + mask_ds = ReservoirDataset( + data_path=args.data_path, + mode="test", + input_file=args.input_file, + output_file=args.output_file, + normalize=False, + use_mask=True, + ) + spatial_mask = mask_ds.get_static_mask() + if spatial_mask is not None: + sm = spatial_mask.numpy() + print( + f"Mask: {sm.sum()}/{sm.size} active ({100 * sm.mean():.1f}%)" + ) + else: + print("Mask: no ACTNUM detected") + del mask_ds + + print(f"Test samples: {len(test_dataset)}") + print(f"Input shape: {tuple(sample_x.shape)}") + print(f"Output shape: {tuple(sample_y.shape)}") + print(f"Timesteps: {num_timesteps}") + if args.mode == "autoregressive": + actual_ar_steps = (num_timesteps - args.L) // args.K + print( + f"AR steps: {actual_ar_steps} (from {num_timesteps} timesteps, L={args.L}, K={args.K})" + ) + print() + + # -- Model -- + model = load_model(model_config, device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + num_params = sum(p.numel() for p in model.parameters()) + print(f"Parameters: {num_params:,}") + print() + + # -- Inference -- + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + all_predictions = [] + all_targets = [] + + if args.mode == "autoregressive": + print(f"Running AUTOREGRESSIVE inference (L={args.L}, K={args.K})...") + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + y_batch_dev = y_batch.to(device) + + pred_batch = ar_validate_full_rollout( + model, + x_batch, + y_batch_dev, + L=args.L, + K=args.K, + is_tno=is_tno, + feedback_channel=feedback_channel, + ) + + all_predictions.append(pred_batch.cpu().numpy()) + all_targets.append(y_batch.numpy()) + + if (batch_idx + 1) % 10 == 0: + print( + f" {(batch_idx + 1) * args.batch_size}/{len(test_dataset)} samples" + ) + else: + print("Running FULL-MAPPING inference...") + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + pred_batch = model(x_batch).cpu().numpy() + + all_predictions.append(pred_batch) + all_targets.append(y_batch.numpy()) + + if (batch_idx + 1) % 10 == 0: + print( + f" {(batch_idx + 1) * args.batch_size}/{len(test_dataset)} samples" + ) + + all_predictions = np.concatenate(all_predictions, axis=0) + all_targets = np.concatenate(all_targets, axis=0) + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.\n") + + # Denormalize to physical units for fair metric comparison + if args.normalize: + all_predictions = all_predictions * output_std + output_mean + all_targets = all_targets * output_std + output_mean + print( + f"Denormalized to physical units (mean={output_mean:.4f}, std={output_std:.4f})" + ) + print( + f" Pred range: [{all_predictions.min():.4f}, {all_predictions.max():.4f}]" + ) + print(f" GT range: [{all_targets.min():.4f}, {all_targets.max():.4f}]") + print() + + # -- Metrics -- + sm_np = spatial_mask.numpy() if spatial_mask is not None else None + print_metrics(all_predictions, all_targets, args.variable, num_timesteps, sm_np) + print("Evaluation complete.") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/README.md new file mode 100644 index 0000000000..3b30f88d3b --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/README.md @@ -0,0 +1,153 @@ +# TNO for CO2 Sequestration + +Reproduces the CO2 sequestration results from the Temporal Neural Operator +(TNO) paper: + +> **Temporal Neural Operator for Modeling Time-Dependent Physical Phenomena** +> Diab, W. & Al Kobaisi, M. (2025). +> [arXiv:2504.20249](https://arxiv.org/abs/2504.20249) + +## Overview + +The Temporal Neural Operator (TNO) extends DeepONet with a temporal branch +(t-branch) that processes the previous solution state, enabling +autoregressive temporal predictions with temporal bundling. The +architecture uses a Hadamard product to combine branch, t-branch, and +trunk outputs, followed by a temporal projection decoder that maps from +the latent space directly to K output timesteps. + +### Architecture + +| Component | Configuration | +|-----------|--------------| +| Branch1 | Lifting + 1 x UNet2D + ReLU (input fields) | +| Branch2 (t-branch) | Lifting + 1 x UNet2D + ReLU (previous solution) | +| Trunk | 14-layer tanh FNN, linear output (time only) | +| Decoder | Temporal projection: MLP + Linear(p → K) | +| Width (p) | 96 (saturation), 128 (pressure) | +| L / K | 1 / 3 (temporal bundling) | + +### Training Configuration + +| Setting | Value | +|---------|-------| +| Regime | Autoregressive: 90 TF + 40 rollout epochs | +| Rollout mode | `live_gradients` (full-trajectory loss) | +| Loss | Relative L2 | +| Optimizer | Adam, lr=6e-4, weight_decay=1e-4 | +| Scheduler | StepLR(step_size=2, gamma=0.92) | +| Batch size | 4 per GPU × 8 GPUs | +| Training timesteps | First 16 of 24 (up to 1.8 years) | + +### Temporal Extrapolation + +The model is trained on the first 16 timesteps (up to 1.8 years) and +tested on all 24 timesteps. Timesteps 17-24 (2.6 to 30 years) are +temporal extrapolation — the model must predict dynamics it never saw +during training, while also generalizing to 500 unseen geological +realizations. + +## Dataset + +Same CO2 sequestration dataset as other NOF examples (Wen et al. 2022): + +- **Grid**: 96 × 200 (variable thickness per realization) +- **Time steps**: 24 (1 day to 30 years, logarithmic spacing) +- **Input channels**: 12 (4 spatial fields + 5 scalars + grid coordinates) +- **Samples**: 4,500 train / 500 val / 500 test + +The dataset is publicly available at: + + +## Usage + +All commands from `neural_operator_factory/`: + +### Training + +```bash +# Gas saturation (p=96, ~2.7M params) +sbatch examples/tno_co2/train.sbatch saturation_training_config + +# Pressure buildup (p=128, ~7.7M params) +sbatch examples/tno_co2/train.sbatch pressure_training_config +``` + +### Evaluation + +```bash +sbatch examples/tno_co2/eval.sbatch saturation +sbatch examples/tno_co2/eval.sbatch pressure +``` + +Note: the pressure evaluation script applies `r_dnorm_dP` to convert the +test targets from physical bar units to normalized space before feeding +to the AR rollout, and `dnorm_dP` to convert predictions back to bar for +metric computation. + +## Results + +### Test Set (500 samples, 24 timesteps including 8 extrapolation) + +| Variable | Metric | NOF | Paper | +|----------|--------|-----|-------| +| **Saturation** | MPE | 4.89% | ~3-5% | +| | MAE | 0.0079 | ~0.005-0.01 | +| | R² | 0.958 | ~0.96 | +| **Pressure** | MRE | 1.27% | ~1-2% | +| | MAE | 1.11 bar | ~1-2 bar | +| | R² | 0.985 | ~0.98 | + +### Validation (16 timesteps, generalization only) + +| Variable | Best Val Loss | Best Metric | Epoch | +|----------|--------------|-------------|-------| +| Saturation | 0.0798 | MPE = 1.81% | 130 | +| Pressure | 0.0577 | MRE = 0.65% | 130 | + +## Key Implementation Details + +- **Temporal projection decoder**: The trunk is queried once; a linear + head projects from width to K=3 output timesteps directly (paper Eq. 10). + Faster than per-timestep trunk queries. + +- **Trunk output activation**: Set to `false` (linear output) to avoid + squashing the Hadamard product's dynamic range. Other DeepONet variants + use activated trunk output (`true`). + +- **Live-gradient rollout**: During the rollout training stage, predictions + are collected with live gradients through the entire chain, and a single + loss is computed on the concatenated trajectory. DDP gradient sync is + handled via `model.no_sync()` + manual AllReduce. + +- **BatchNorm freeze**: During live-gradient rollout, BatchNorm layers are + set to eval mode to prevent inplace running-stat updates from + invalidating the autograd graph across chained forward passes. + +## Files + +```text +tno_co2/ +├── README.md +├── train.sbatch +├── eval.sbatch +├── evaluate_pressure.py # r_dnorm_dP + dnorm_dP for test data +├── evaluate_saturation.py +└── conf/ + ├── model_config.yaml # TNO architecture + loss + ├── saturation_training_config.yaml # p=96, 90 TF + 40 rollout + └── pressure_training_config.yaml # p=128, 90 TF + 40 rollout +``` + +## Reference + +```bibtex +@article{diab2025tno, + title={Temporal Neural Operator for Modeling Time-Dependent + Physical Phenomena}, + author={Diab, Waleed and Al Kobaisi, Mohammed}, + journal={arXiv preprint arXiv:2504.20249}, + year={2025}, + url={https://arxiv.org/abs/2504.20249} +} +``` diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/model_config.yaml new file mode 100644 index 0000000000..b2f12071db --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/model_config.yaml @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# TNO — Model Configuration (CO2 Sequestration) +# ============================================================================= +# Temporal Neural Operator: branch1 (input fields) + branch2/t-branch +# (previous solution state) + trunk (time), all with U-Net processors. +# Autoregressive with temporal bundling (L=1, K=3). +# Temporal projection decoder: G maps latent p → K output timesteps directly. +# +# Reference: Diab, W. & Al Kobaisi, M. (2025). "Temporal Neural Operator +# for Modeling Time-Dependent Physical Phenomena." +# arXiv:2504.20249. https://arxiv.org/abs/2504.20249 +# +# Default width p=96 (saturation). +# Pressure overrides to p=128 in pressure_training_config.yaml. +# ============================================================================= + +arch: + dimensions: 3d + model: xdeeponet + + xdeeponet: + variant: tno + width: 96 + padding: 8 + + # Branch1: processes input fields (4 spatial + 5 scalar channels) + # Lifting → U-Net → ReLU output + branch1: + encoder: + type: linear + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 1 + num_conv_layers: 0 + modes1: 12 + modes2: 12 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: relu + internal_resolution: null + + # Branch2 (t-branch): processes previous solution state + # Lifting → U-Net → ReLU output + branch2: + encoder: + type: linear + activation_fn: relu + layers: + num_fourier_layers: 0 + num_unet_layers: 1 + num_conv_layers: 0 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: relu + + # Trunk: tanh FNN, narrow (40) but deep (14 layers), linear output. + # Original code: TimeNN(in_width=L, hidden_width=40, output_dim=width, num_layers=14) + trunk: + input_type: time + hidden_width: 40 + num_layers: 14 + activation_fn: tanh + output_activation: false + + # Decoder: ReLU → Linear → Linear(→K) + # Original code: fc_2(relu) → fc_3(linear) → fc_4(width→K) + # decoder_layers=1 creates: FCLayer(relu) + FCLayer(identity) + temporal_head + decoder_type: temporal_projection + decoder_width: 96 + decoder_layers: 1 + decoder_activation_fn: relu + +# Relative L2 loss (matches original TNO code: LpLoss) +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: false + weight: 0.0 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/pressure_training_config.yaml new file mode 100644 index 0000000000..426362876b --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/pressure_training_config.yaml @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TNO — Pressure Buildup (CO2), width p=128 +# Original code: width=128 (paper Appendix F states p=160, code uses 128) + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/tno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +# Override width from 96 to 128 for pressure +arch: + xdeeponet: + width: 128 + branch2: + encoder: + hidden_width: 128 + decoder_width: 128 + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + num_timesteps: 16 + +training: + batch_size: 4 + initial_lr: 0.0006 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: autoregressive + + autoregressive: + input_window: 1 + output_window: 3 + teacher_forcing_epochs: 90 + pushforward_epochs: 0 + rollout_epochs: 40 + max_unroll: 5 + noise_std: 0.0 + use_feedback_channel: false + lr_reset_factor: 1.0 + rollout_mode: live_gradients + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 2 + gamma: 0.92 + +logging: + use_mlflow: false + experiment_name: tno_co2_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/saturation_training_config.yaml new file mode 100644 index 0000000000..3996ef89e4 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/conf/saturation_training_config.yaml @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TNO — Gas Saturation (CO2), width p=96 + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/tno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: false + mask_channel: 0 + denormalize_fn: null + num_timesteps: 16 + +training: + batch_size: 4 + initial_lr: 0.0006 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: autoregressive + + autoregressive: + input_window: 1 + output_window: 3 + teacher_forcing_epochs: 90 + pushforward_epochs: 0 + rollout_epochs: 40 + max_unroll: 5 + noise_std: 0.0 + use_feedback_channel: false + lr_reset_factor: 1.0 + rollout_mode: live_gradients + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 2 + gamma: 0.92 + +logging: + use_mlflow: false + experiment_name: tno_co2_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_pressure.py b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_pressure.py new file mode 100644 index 0000000000..e9b3b0faaa --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_pressure.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 pressure buildup evaluation — matches the validated U-FNO evaluation algorithm. + +dnorm_dP is applied to prediction only. MRE and R² are computed with +pred in physical bar and target in raw (normalized) space, matching the +original paper evaluation. R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def dnorm_dP(dP): + """Denormalize pressure predictions from normalized to physical units.""" + dP = dP * 18.772821433027488 + dP = dP + 4.172939172019009 + return dP + + +def r_dnorm_dP(dP): + """Reverse-denormalize pressure targets from physical to normalized units.""" + return (dP - 4.172939172019009) / 18.772821433027488 + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def mean_relative_error(y_pred, y_true): + """Compute mean relative error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true) / (y_true.max() - y_true.min())) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 pressure model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "pressure" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mre_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + from training.ar_utils import ar_validate_full_rollout + + _is_tno = model_config.get("variant", "") == "tno" + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + if _is_tno: + y_norm = r_dnorm_dP(y_batch).to(device) + pred_batch = ar_validate_full_rollout( + model, x_batch, y_norm, L=1, K=3, is_tno=True + ) + else: + _has_b2 = model_config.get("variant", "") in ( + "mionet", + "fourier_mionet", + ) + fwd = {"x_branch2": x_batch[:, 0, 0, 0, :]} if _has_b2 else {} + pred_batch = model(x_batch, **fwd) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = dnorm_dP(pred_batch.cpu().numpy()) + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mre = mean_relative_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mre_list.append(np.mean(mre)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mre = np.mean(mre_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Pressure Buildup — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MRE: {overall_mre:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + if _is_tno: + y_norm = r_dnorm_dP(y).unsqueeze(0).to(device) + pred = ar_validate_full_rollout(model, x, y_norm, L=1, K=3, is_tno=True) + else: + pred = model(x) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = dnorm_dP(pred.squeeze(0).cpu().numpy()) + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True dP, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred dP (bar), t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"pressure_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_saturation.py b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_saturation.py new file mode 100644 index 0000000000..df518c2396 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/tno_co2/evaluate_saturation.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 gas saturation evaluation — matches the validated U-FNO evaluation algorithm. + +No denormalization. MPE computed within the plume region. +R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def mean_plume_error(y_pred, y_true): + """Compute mean plume error for saturation predictions.""" + mask = (y_pred != 0) & (y_true != 0) + y_pred_masked = y_pred[mask] + y_true_masked = y_true[mask] + if len(y_pred_masked) == 0: + return 0.0 + return np.mean(np.abs(y_pred_masked - y_true_masked)) + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 saturation model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "saturation" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mpe_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + from training.ar_utils import ar_validate_full_rollout + + _is_tno = model_config.get("variant", "") == "tno" + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + if _is_tno: + y_dev = y_batch.to(device) + pred_batch = ar_validate_full_rollout( + model, x_batch, y_dev, L=1, K=3, is_tno=True + ) + else: + _has_b2 = model_config.get("variant", "") in ( + "mionet", + "fourier_mionet", + ) + fwd = {"x_branch2": x_batch[:, 0, 0, 0, :]} if _has_b2 else {} + pred_batch = model(x_batch, **fwd) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = pred_batch.cpu().numpy() + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mpe = mean_plume_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mpe_list.append(np.mean(mpe)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mpe = np.mean(mpe_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Gas Saturation — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MPE: {overall_mpe:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + if _is_tno: + y_dev = y.unsqueeze(0).to(device) + pred = ar_validate_full_rollout( + model, x, y_dev, L=1, K=3, is_tno=True + ).squeeze(0) + else: + vfwd = {} + pred = model(x, **vfwd).squeeze(0) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = pred.squeeze(0).cpu().numpy() + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"saturation_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/README.md new file mode 100644 index 0000000000..f919fd65c3 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/README.md @@ -0,0 +1,109 @@ +# U-DeepONet for CO2 Sequestration + +Reproduces the main results from: + +> **U-DeepONet: U-Net enhanced deep operator network for** +> **geologic carbon sequestration** +> Diab, W. & Al Kobaisi, M. (2024). +> *Scientific Reports*, 14, 21298. +> [doi:10.1038/s41598-024-72393-0](https://doi.org/10.1038/s41598-024-72393-0) + +## Overview + +The U-DeepONet uses a DeepONet architecture with 3 U-Net +blocks in the branch network and a sinusoidal trunk for time +encoding. No Fourier layers are used, making it significantly +faster to train than U-FNO while matching or exceeding accuracy. + +| Setting | Saturation | Pressure | +|---------|-----------|----------| +| Width (f) | 64 | 96 | +| Branch | 3 × UNet2D + ReLU | Same | +| Trunk | 10-layer sin FNN | Same | +| Decoder | FC (f → 1) | Same | +| Learning rate | 0.0007 | 0.0006 | + +## Dataset + +Same CO2 sequestration dataset as the U-FNO example: + +- **Grid**: 96 × 200 × 24 time steps +- **Input channels**: 12 +- **Samples**: 4,500 train / 500 val / 500 test + +The dataset is publicly available at: + + +## Usage + +All commands from `neural_operator_factory/`: + +### Training + +```bash +# Gas saturation (width=64) +sbatch examples/udeeponet_co2/train.sbatch saturation_training_config + +# Pressure buildup (width=96) +sbatch examples/udeeponet_co2/train.sbatch pressure_training_config +``` + +### Evaluation + +```bash +sbatch examples/udeeponet_co2/eval.sbatch saturation +sbatch examples/udeeponet_co2/eval.sbatch pressure +``` + +## Results + +### Paper Reference (Table 4) + +| Variable | MPE/MRE | MAE | R² | +|----------|---------|-----|-----| +| Saturation | 0.0158 | 0.0146 | 0.985 | +| Pressure | 0.0072 | 0.64 | 0.994 | + +### NOF Reproduction (500 test samples, 8× GPU) + +| Variable | NOF MPE/MRE | Paper | NOF R² | Paper R² | +|----------|------------|-------|--------|----------| +| Saturation | 0.0195 | 0.0158 | 0.991 | 0.985 | +| Pressure | 0.0082 | 0.0072 | 0.992 | 0.994 | + +### Training Time (8× GPU DDP) + +| Variable | Epochs | Time | +|----------|--------|------| +| Saturation | 100 | ~14 min | +| Pressure | 140 | ~19 min | + +## Files + +```text +udeeponet_co2/ +├── README.md +├── train.sbatch +├── eval.sbatch +├── evaluate_pressure.py +├── evaluate_saturation.py +└── conf/ + ├── model_config.yaml # U-DeepONet (width=64) + ├── saturation_training_config.yaml # lr=0.0007 + └── pressure_training_config.yaml # lr=0.0006, width=96 +``` + +## Reference + +```bibtex +@article{diab2024udeeponet, + title={{U-DeepONet}: {U-Net} enhanced deep operator + network for geologic carbon sequestration}, + author={Diab, Waleed and Al Kobaisi, Mohammed}, + journal={Scientific Reports}, + volume={14}, + pages={21298}, + year={2024}, + publisher={Nature Publishing Group} +} +``` diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/model_config.yaml new file mode 100644 index 0000000000..ed0fb1ef3f --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/model_config.yaml @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-DeepONet — Model Configuration (CO2 Sequestration) +# ============================================================================= +# U-Net enhanced DeepONet: 3 U-Net blocks in the branch, sinusoidal +# trunk for time encoding. No Fourier layers. Full-mapping regime +# (predict entire trajectory in a single forward pass). +# +# Default width p=64 (saturation). +# Pressure overrides to p=96 in pressure_training_config.yaml. +# +# Reference: Diab & Al Kobaisi (2024), Sci. Rep. 14, 21298. +# https://doi.org/10.1038/s41598-024-72393-0 +# ============================================================================= + +arch: + dimensions: 3d + model: xdeeponet + + xdeeponet: + variant: u_deeponet + width: 64 + padding: 8 + + # Branch: 3 U-Net blocks in series (no Fourier layers) + branch1: + encoder: + type: linear + activation_fn: sin + layers: + num_fourier_layers: 0 + num_unet_layers: 3 + num_conv_layers: 0 + modes1: 12 + modes2: 12 + kernel_size: 3 + dropout: 0.0 + unet_impl: custom + activation_fn: sin + internal_resolution: null + + # Trunk: 10-layer sinusoidal FNN processing time only + # output_activation: true (default) — sin applied to output layer. + # U-DeepONet uses activated trunk output, unlike TNO which uses linear. + trunk: + input_type: time + hidden_width: 64 + num_layers: 10 + activation_fn: sin + output_activation: true + + # Decoder: shallow MLP projection (per-timestep trunk query → 1) + decoder_type: mlp + decoder_width: 64 + decoder_layers: 1 + decoder_activation_fn: relu + +# Relative L2 loss + spatial derivative regularization (dx) +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/pressure_training_config.yaml new file mode 100644 index 0000000000..2e293cc0c4 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/pressure_training_config.yaml @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-DeepONet — Pressure Buildup (CO2 Sequestration) +# ============================================================================= +# Width = 96 (overrides model_config.yaml default of 64). +# Expected test-set results (paper Table 4b): +# MRE = 0.0072, MAE = 0.64, R² = 0.994 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/udeeponet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +# Override width from 64 to 96, and UNet activation from sin to relu +arch: + xdeeponet: + width: 96 + branch1: + encoder: + activation_fn: relu + layers: + activation_fn: relu + trunk: + hidden_width: 96 + decoder_width: 96 + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.0006 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: udeeponet_co2_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/saturation_training_config.yaml new file mode 100644 index 0000000000..c1f540aa68 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/conf/saturation_training_config.yaml @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-DeepONet — Gas Saturation (CO2 Sequestration) +# ============================================================================= +# Width = 64 (from model_config.yaml, no override needed). +# Expected test-set results (paper Table 4a): +# MPE = 0.0158, MAE = 0.0146, R² = 0.985 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/udeeponet_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.0007 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: udeeponet_co2_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_pressure.py b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_pressure.py new file mode 100644 index 0000000000..a7e217e355 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_pressure.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 pressure buildup evaluation — matches the validated U-FNO evaluation algorithm. + +dnorm_dP is applied to prediction only. MRE and R² are computed with +pred in physical bar and target in raw (normalized) space, matching the +original paper evaluation. R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def dnorm_dP(dP): + """Denormalize pressure predictions from normalized to physical units.""" + dP = dP * 18.772821433027488 + dP = dP + 4.172939172019009 + return dP + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def mean_relative_error(y_pred, y_true): + """Compute mean relative error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true) / (y_true.max() - y_true.min())) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 pressure model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "pressure" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mre_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + pred_batch = model(x_batch) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = dnorm_dP(pred_batch.cpu().numpy()) + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mre = mean_relative_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mre_list.append(np.mean(mre)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mre = np.mean(mre_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Pressure Buildup — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MRE: {overall_mre:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + pred = model(x) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = dnorm_dP(pred.squeeze(0).cpu().numpy()) + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True dP, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred dP (bar), t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"pressure_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_saturation.py b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_saturation.py new file mode 100644 index 0000000000..355c45b8ae --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/udeeponet_co2/evaluate_saturation.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 gas saturation evaluation — matches the validated U-FNO evaluation algorithm. + +No denormalization. MPE computed within the plume region. +R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def mean_plume_error(y_pred, y_true): + """Compute mean plume error for saturation predictions.""" + mask = (y_pred != 0) & (y_true != 0) + y_pred_masked = y_pred[mask] + y_true_masked = y_true[mask] + if len(y_pred_masked) == 0: + return 0.0 + return np.mean(np.abs(y_pred_masked - y_true_masked)) + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 saturation model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "saturation" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mpe_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + pred_batch = model(x_batch) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = pred_batch.cpu().numpy() + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mpe = mean_plume_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mpe_list.append(np.mean(mpe)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mpe = np.mean(mpe_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Gas Saturation — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MPE: {overall_mpe:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + pred = model(x) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = pred.squeeze(0).cpu().numpy() + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"saturation_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/README.md b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/README.md new file mode 100644 index 0000000000..9b74deed56 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/README.md @@ -0,0 +1,172 @@ +# U-FNO for CO2 Sequestration + +Reproduces the main results from: + +> **U-FNO — An enhanced Fourier neural operator-based** +> **deep-learning model for multiphase flow** +> Wen, G., Li, Z., Azizzadenesheli, K., Anandkumar, A., & Benson, S. M. (2022). +> *Advances in Water Resources*, 163, 104180. +> [arXiv:2109.03697](https://arxiv.org/abs/2109.03697) + +## Overview + +This example trains three FNO-based architectures on the CO2 sequestration dataset +for both **gas saturation** and **pressure buildup** prediction: + +| Architecture | FNO Layers | U-Net Layers | Conv Layers | ~Parameters | +|--------------|-----------|-------------|-------------|-------------| +| **FNO** | 6 | 0 | 0 | 31M | +| **Conv-FNO** | 3 | 0 | 3 | 31M | +| **U-FNO** | 3 | 3 | 0 | 33M | + +All models use width=36, Fourier modes (10, 10, 10), and a 2-layer MLP decoder (36→128→1). + +## Dataset + +- **Grid**: 96 (vertical) × 200 (radial) × 24 (time steps) +- **Input channels**: 12 (permeability, porosity, injection config, scalar params, grid widths) +- **Output**: Gas saturation `sg` or pressure buildup `dP` (separate models) +- **Samples**: 4,500 train / 500 val / 500 test + +The dataset is publicly available at: + + +## Usage + +All commands are run from the `neural_operator_factory/` directory. + +### Training (SLURM) + +```bash +# U-FNO on gas saturation (flagship result) +sbatch examples/ufno_co2/train.sbatch U-FNO saturation_training_config + +# U-FNO on pressure buildup +sbatch examples/ufno_co2/train.sbatch U-FNO pressure_training_config + +# FNO baseline on saturation +sbatch examples/ufno_co2/train.sbatch FNO saturation_training_config + +# FNO on pressure +sbatch examples/ufno_co2/train.sbatch FNO pressure_training_config + +# Conv-FNO on saturation +sbatch examples/ufno_co2/train.sbatch Conv-FNO saturation_training_config + +# Conv-FNO on pressure +sbatch examples/ufno_co2/train.sbatch Conv-FNO pressure_training_config + +# Run all 6 experiments +for arch in FNO Conv-FNO U-FNO; do + for var in saturation_training_config pressure_training_config; do + sbatch examples/ufno_co2/train.sbatch $arch $var + done +done +``` + +### Evaluation (SLURM) + +```bash +# Evaluate best gas saturation model (auto-detects checkpoint) +sbatch examples/ufno_co2/eval.sbatch saturation + +# Evaluate best pressure buildup model +sbatch examples/ufno_co2/eval.sbatch pressure + +# Explicit checkpoint +CHECKPOINT=checkpoints/best_model_saturation_ufno_custom.pth \ + sbatch examples/ufno_co2/eval.sbatch saturation +``` + +## Results + +### Paper Reference (Table I.14) + +#### Gas Saturation — Test Set + +| Model | MPE (mean) | MPE (std) | R² plume | +|-------|-----------|----------|----------| +| FNO | 0.0276 | 0.0160 | 0.961 | +| Conv-FNO | 0.0224 | 0.0125 | 0.970 | +| **U-FNO** | **0.0161** | **0.0105** | **0.981** | + +#### Pressure Buildup — Test Set + +| Model | MRE (mean) | MRE (std) | R² | +|-------|-----------|----------|-----| +| FNO | 0.0082 | 0.0052 | 0.989 | +| Conv-FNO | 0.0078 | 0.0048 | 0.990 | +| **U-FNO** | **0.0068** | **0.0045** | **0.992** | + +### NOF Reproduction (this example, 500 test samples) + +#### Gas Saturation + +| Model | NOF MPE | Paper MPE | NOF R² | Paper R² | +|-------|---------|-----------|--------|----------| +| FNO | 0.0303 | 0.0276 | 0.984 | 0.961 | +| Conv-FNO | 0.0234 | 0.0224 | 0.988 | 0.970 | +| **U-FNO** | **0.0182** | **0.0161** | **0.993** | **0.981** | + +#### Pressure Buildup + +| Model | NOF MRE | Paper MRE | NOF R² | Paper R² | +|-------|---------|-----------|--------|----------| +| FNO | 0.0089 | 0.0082 | 0.976 | 0.989 | +| Conv-FNO | 0.0087 | 0.0078 | 0.984 | 0.990 | +| **U-FNO** | **0.0068** | **0.0068** | **0.991** | **0.992** | + +### Training Time (8× GPU DDP) + +| Model | Saturation (100 ep) | Pressure (140 ep) | +|-------|--------------------|--------------------| +| FNO | ~39 min | ~54 min | +| Conv-FNO | ~55 min | ~78 min | +| U-FNO | ~72 min | ~102 min | + +## Loss Function + +Matches the paper's Equation 12: relative L2 loss + +radial derivative regularization. +Active cell masking is applied +(cells outside the reservoir are zero-padded). + +## Files + +```text +ufno_co2/ +├── README.md +├── train.sbatch # SLURM training (8 GPU) +├── eval.sbatch # SLURM evaluation (1 GPU) +├── evaluate_pressure.py # Pressure eval with dnorm_dP +├── evaluate_saturation.py # Saturation eval with MPE +└── conf/ + ├── FNO/ # Pure FNO (6 Fourier layers) + │ ├── model_config.yaml + │ ├── saturation_training_config.yaml + │ └── pressure_training_config.yaml + ├── Conv-FNO/ # Conv-FNO (3 Fourier + 3 Conv) + │ ├── model_config.yaml + │ ├── saturation_training_config.yaml + │ └── pressure_training_config.yaml + └── U-FNO/ # U-FNO (3 Fourier + 3 U-Net) + ├── model_config.yaml + ├── saturation_training_config.yaml + └── pressure_training_config.yaml +``` + +## Reference + +```bibtex +@article{wen2022ufno, + title={{U-FNO--An enhanced Fourier neural operator-based + deep-learning model for multiphase flow}}, + author={Wen, Gege and Li, Zongyi and Azizzadenesheli, + Kamyar and Anandkumar, Anima and Benson, Sally M}, + journal={Advances in Water Resources}, + volume={163}, + pages={104180}, + year={2022}, + publisher={Elsevier} +} +``` diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/model_config.yaml new file mode 100644 index 0000000000..210681b228 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/model_config.yaml @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Conv-FNO — Model Configuration (CO2 Sequestration) +# ============================================================================= +# 3 Fourier layers + 3 Conv-enhanced layers (Conv3d skip connections). +# Reference: Wen et al. (2022), arXiv:2109.03697, Table G.12 — ~31M params. +# ============================================================================= + +arch: + dimensions: 3d + model: xfno + + xfno: + out_channels: 1 + width: 36 + padding: 8 + activation_fn: relu + + modes1: 10 + modes2: 10 + modes3: 10 + + num_fno_layers: 3 + num_unet_layers: 0 + num_conv_layers: 3 + + lifting_type: mlp + lifting_layers: 1 + lifting_width: 36 + + decoder_type: mlp + decoder_layers: 1 + decoder_width: 128 + decoder_activation_fn: relu + + unet_type: custom + unet_kernel_size: 3 + unet_dropout: 0.0 + conv_kernel_size: 3 + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/pressure_training_config.yaml new file mode 100644 index 0000000000..fecf350259 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/pressure_training_config.yaml @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Conv-FNO — Pressure Buildup (CO2 Sequestration) +# ============================================================================= +# Expected test-set results (paper Table I.14b): +# MRE = 0.0078 ± 0.0048, R² = 0.990 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_convfno_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/saturation_training_config.yaml new file mode 100644 index 0000000000..a2c3a27e64 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/Conv-FNO/saturation_training_config.yaml @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# Conv-FNO — Gas Saturation (CO2 Sequestration) +# ============================================================================= +# Expected test-set results (paper Table I.14a): +# MPE = 0.0224 ± 0.0125, R²_plume = 0.970 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_convfno_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/model_config.yaml new file mode 100644 index 0000000000..802497285d --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/model_config.yaml @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# FNO — Model Configuration (CO2 Sequestration) +# ============================================================================= +# Pure FNO: 6 Fourier layers, no U-Net or Conv skip connections. +# Reference: Wen et al. (2022), arXiv:2109.03697, Table F.11 — ~31M params. +# ============================================================================= + +arch: + dimensions: 3d + model: xfno + + xfno: + out_channels: 1 + width: 36 + padding: 8 + activation_fn: relu + + modes1: 10 + modes2: 10 + modes3: 10 + + num_fno_layers: 6 + num_unet_layers: 0 + num_conv_layers: 0 + + lifting_type: mlp + lifting_layers: 1 + lifting_width: 36 + + decoder_type: mlp + decoder_layers: 1 + decoder_width: 128 + decoder_activation_fn: relu + + unet_type: custom + unet_kernel_size: 3 + unet_dropout: 0.0 + conv_kernel_size: 3 + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/pressure_training_config.yaml new file mode 100644 index 0000000000..2de163c725 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/pressure_training_config.yaml @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# FNO — Pressure Buildup (CO2 Sequestration) +# ============================================================================= +# Expected test-set results (paper Table I.14b): +# MRE = 0.0082 ± 0.0052, R² = 0.989 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_fno_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/saturation_training_config.yaml new file mode 100644 index 0000000000..7d35558342 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/FNO/saturation_training_config.yaml @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# FNO — Gas Saturation (CO2 Sequestration) +# ============================================================================= +# Expected test-set results (paper Table I.14a): +# MPE = 0.0276 ± 0.0160, R²_plume = 0.961 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_fno_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/model_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/model_config.yaml new file mode 100644 index 0000000000..dba5fe0a67 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/model_config.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-FNO — Model Configuration (CO2 Sequestration) +# ============================================================================= +# 3 Fourier layers + 3 U-Fourier layers (U-Net skip connections). +# Flagship architecture from the paper. +# Reference: Wen et al. (2022), arXiv:2109.03697, Table H.13 — ~33M params. +# ============================================================================= + +arch: + dimensions: 3d + model: xfno + + xfno: + out_channels: 1 + width: 36 + padding: 8 + activation_fn: relu + + modes1: 10 + modes2: 10 + modes3: 10 + + num_fno_layers: 3 + num_unet_layers: 3 + num_conv_layers: 0 + + lifting_type: mlp + lifting_layers: 1 + lifting_width: 36 + + decoder_type: mlp + decoder_layers: 2 + decoder_width: 128 + decoder_activation_fn: relu + + unet_type: custom + unet_kernel_size: 3 + unet_dropout: 0.0 + conv_kernel_size: 3 + +loss: + types: [relative_l2] + weights: [1.0] + + derivative: + enabled: true + weight: 0.5 + dims: [dx] + metric: null + + physics: + mass_conservation: + enabled: false + weight: 0.0 + use_cell_volumes: false diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/pressure_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/pressure_training_config.yaml new file mode 100644 index 0000000000..9fd6030820 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/pressure_training_config.yaml @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-FNO — Pressure Buildup (CO2 Sequestration) +# ============================================================================= +# Flagship architecture from the paper. +# Expected test-set results (paper Table I.14b): +# MRE = 0.0068 ± 0.0045, R² = 0.992 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: pressure + input_file: null + output_file: null + normalize: false + num_workers: 1 + val_metric: mre + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 140 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_ufno_pressure + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/saturation_training_config.yaml b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/saturation_training_config.yaml new file mode 100644 index 0000000000..a25a91d04a --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/conf/U-FNO/saturation_training_config.yaml @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# U-FNO — Gas Saturation (CO2 Sequestration) +# ============================================================================= +# Flagship architecture from the paper. +# Expected test-set results (paper Table I.14a): +# MPE = 0.0161 ± 0.0105, R²_plume = 0.981 +# ============================================================================= + +defaults: + - model_config + - _self_ + +hydra: + job: + chdir: False + run: + dir: ./outputs/ufno_co2/${now:%Y-%m-%d}/${now:%H-%M-%S} + +data: + data_path: /data/co2 + variable: saturation + input_file: null + output_file: null + normalize: false + num_workers: 4 + val_metric: mpe + mask_enabled: true + mask_channel: 0 + denormalize_fn: null + +training: + batch_size: 4 + epochs: 100 + initial_lr: 0.001 + checkpoint_dir: ./checkpoints + validate_freq: 5 + use_amp: false + resume_from_checkpoint: null + regime: full_mapping + +optimizer: + weight_decay: 0.0001 + betas: [0.9, 0.999] + eps: 1.0e-8 + +scheduler: + type: step + step_size: 10 + gamma: 0.85 + +logging: + use_mlflow: false + experiment_name: ufno_co2_ufno_saturation + +seed: 42 + +compute: + benchmark: true + deterministic: true diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_pressure.py b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_pressure.py new file mode 100644 index 0000000000..a7e217e355 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_pressure.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 pressure buildup evaluation — matches the validated U-FNO evaluation algorithm. + +dnorm_dP is applied to prediction only. MRE and R² are computed with +pred in physical bar and target in raw (normalized) space, matching the +original paper evaluation. R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def dnorm_dP(dP): + """Denormalize pressure predictions from normalized to physical units.""" + dP = dP * 18.772821433027488 + dP = dP + 4.172939172019009 + return dP + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def mean_relative_error(y_pred, y_true): + """Compute mean relative error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true) / (y_true.max() - y_true.min())) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 pressure model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "pressure" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mre_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + pred_batch = model(x_batch) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = dnorm_dP(pred_batch.cpu().numpy()) + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mre = mean_relative_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mre_list.append(np.mean(mre)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mre = np.mean(mre_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Pressure Buildup — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MRE: {overall_mre:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + pred = model(x) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = dnorm_dP(pred.squeeze(0).cpu().numpy()) + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True dP, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred dP (bar), t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"pressure_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_saturation.py b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_saturation.py new file mode 100644 index 0000000000..355c45b8ae --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/examples/ufno_co2/evaluate_saturation.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2 gas saturation evaluation — matches the validated U-FNO evaluation algorithm. + +No denormalization. MPE computed within the plume region. +R² is computed globally across all samples. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) + +import matplotlib +import numpy as np +import torch + +matplotlib.use("Agg") +import argparse + +import matplotlib.pyplot as plt +from data.dataloader import ReservoirDataset +from data.validation import print_validation_summary, validate_sample_dimensions +from utils.checkpoint import build_model_from_config + + +def mean_plume_error(y_pred, y_true): + """Compute mean plume error for saturation predictions.""" + mask = (y_pred != 0) & (y_true != 0) + y_pred_masked = y_pred[mask] + y_true_masked = y_true[mask] + if len(y_pred_masked) == 0: + return 0.0 + return np.mean(np.abs(y_pred_masked - y_true_masked)) + + +def mean_absolute_error(y_pred, y_true): + """Compute mean absolute error between predictions and targets.""" + return np.mean(np.abs(y_pred - y_true)) + + +def main(): + """Run evaluation on the test set and print metrics.""" + parser = argparse.ArgumentParser( + description="Evaluate CO2 saturation model (U-FNO example)" + ) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument( + "--data_path", + type=str, + default="/data/co2", + ) + parser.add_argument("--batch_size", type=int, default=6) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + variable = "saturation" + data_path = Path(args.data_path) + + if args.checkpoint: + checkpoint_path = Path(args.checkpoint) + else: + checkpoint_dir = Path("checkpoints") + checkpoints = list(checkpoint_dir.glob(f"best_model_{variable}_*.pth")) + if not checkpoints: + raise FileNotFoundError( + f"No checkpoints found for {variable}. Specify --checkpoint" + ) + checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime) + print(f"Auto-detected checkpoint: {checkpoint_path}") + + print(f"\nLoading test dataset for {variable}...") + test_dataset = ReservoirDataset( + data_path=data_path, mode="test", variable=variable, normalize=False + ) + print(f"Test dataset size: {len(test_dataset)}") + + sample_input, sample_target = test_dataset[0] + validate_sample_dimensions(sample_input, sample_target, variable) + print_validation_summary( + tuple(sample_input.shape), + tuple(sample_target.shape), + variable, + is_batch=False, + ) + + print(f"\nLoading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + model_config = checkpoint["model_config"] + print(f"Model: {model_config.get('model_arch_name', 'unknown')}") + print(f"Epoch: {checkpoint['epoch']}") + print(f"Val loss: {checkpoint['val_loss']:.6f}") + + model, model_arch_name = build_model_from_config(model_config, device=device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + print(f"Model loaded: {model_arch_name}") + + print("\n" + "=" * 80) + print(f"Evaluating on {len(test_dataset)} test samples...") + print("=" * 80) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + mpe_list = [] + mae_list = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + for batch_idx, (x_batch, y_batch) in enumerate(test_loader): + x_batch = x_batch.to(device) + pred_batch = model(x_batch) + + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + pred_np = pred_batch.cpu().numpy() + + for rr in range(x_np.shape[0]): + mask = x_np[rr, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + y_plot = y_np[rr][mask].reshape((thickness, 200, 24, -1)) + pred_plot = pred_np[rr][mask].reshape((thickness, 200, 24, -1)) + + mpe = mean_plume_error(pred_plot, y_plot) + mae = mean_absolute_error(pred_plot, y_plot) + + mpe_list.append(np.mean(mpe)) + mae_list.append(mae) + y_true_all.append(y_plot) + y_pred_all.append(pred_plot) + + if (batch_idx + 1) % 10 == 0: + n_done = (batch_idx + 1) * args.batch_size + print(f" {n_done}/{len(test_dataset)} samples...") + + print(f" {len(test_dataset)}/{len(test_dataset)} samples - done.") + + overall_mpe = np.mean(mpe_list) + overall_mae = np.mean(mae_list) + + y_true_all = np.concatenate(y_true_all, axis=0) + y_pred_all = np.concatenate(y_pred_all, axis=0) + ss_res = np.sum((y_true_all - y_pred_all) ** 2) + ss_tot = np.sum((y_true_all - y_true_all.mean()) ** 2) + r2_score = 1 - (ss_res / ss_tot) + + print("\n" + "=" * 80) + print(f"Gas Saturation — Test Set ({len(test_dataset)} samples)") + print("=" * 80) + print(f"MPE: {overall_mpe:.4f}") + print(f"MAE: {overall_mae:.4f}") + print(f"R2: {r2_score:.4f}") + print("=" * 80) + + # Visualization (sample 0) + print("\nCreating visualization for sample 0...") + x, y = test_dataset[0] + x = x.unsqueeze(0).to(device) + with torch.no_grad(): + pred = model(x) + + x_plot = x.cpu().numpy() + y_plot = y.numpy() + pred_plot = pred.squeeze(0).cpu().numpy() + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + times = np.cumsum(np.power(1.421245, range(24))) + time_labels = [ + f"{int(t)} d" if t < 365 else f"{round(int(t) / 365, 1)} y" for t in times + ] + + def pcolor(data): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], + Y[:thickness, :], + np.flipud(data), + shading="auto", + ) + + t_lst = [14, 20, 23] + plt.figure(figsize=(15, 12)) + for j, t in enumerate(t_lst): + plt.subplot(3, 3, j + 1) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"True Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 4) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"Pred Sg, t={time_labels[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + plt.subplot(3, 3, j + 7) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"Error, t={time_labels[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + output_dir = Path("visualizations") + output_dir.mkdir(exist_ok=True) + out_file = output_dir / f"saturation_{model_arch_name}.png" + plt.savefig(out_file, dpi=150, bbox_inches="tight") + print(f"Saved: {out_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/models/__init__.py b/examples/reservoir_simulation/neural_operator_factory/models/__init__.py new file mode 100644 index 0000000000..e23dff0534 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/models/__init__.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Neural operator model architectures. + +Available models: +- FNO variants: UFNO, UFNONet, FNO4D, FNO4DNet +- DeepONet variants: DeepONet, DeepONet3D, DeepONetWrapper, DeepONet3DWrapper +- U-Net: UNet2D, UNet3D, PhysicsNemoUNet2D, PhysicsNemoUNet3D, StandaloneUNet +""" + +from models.physicsnemo_unet import ( + PhysicsNemoUNet2D, + PhysicsNemoUNet3D, + StandaloneUNet, +) +from models.unet import UNet2D, UNet3D +from models.xdeeponet import ( + DeepONet, + DeepONet3D, + DeepONet3DWrapper, + DeepONetWrapper, + MLPBranch, + SpatialBranch, + SpatialBranch3D, + TrunkNet, +) +from models.xfno import FNO4D, UFNO, FNO4DNet, UFNONet + +__all__ = [ + # FNO + "UFNO", + "UFNONet", + "FNO4D", + "FNO4DNet", + # DeepONet + "TrunkNet", + "MLPBranch", + "SpatialBranch", + "SpatialBranch3D", + "DeepONet", + "DeepONet3D", + "DeepONetWrapper", + "DeepONet3DWrapper", + # U-Net + "UNet2D", + "UNet3D", + "PhysicsNemoUNet2D", + "PhysicsNemoUNet3D", + "StandaloneUNet", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/models/physicsnemo_unet.py b/examples/reservoir_simulation/neural_operator_factory/models/physicsnemo_unet.py new file mode 100644 index 0000000000..dc16030cc0 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/models/physicsnemo_unet.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PhysicsNemo U-Net wrappers for 2D and 3D data. + +This module provides wrappers around PhysicsNemo's 3D U-Net for: +- 2D spatial data (H × W) - used in DeepONet branch network +- 3D spatiotemporal data (H × W × T) - used in U-FNO and standalone models + +PhysicsNemo's UNet is natively 3D, so the 2D wrapper adds/removes a dummy +temporal dimension to enable 2D processing. +""" + +from typing import List, Optional + +import torch.nn as nn +from torch import Tensor + +from physicsnemo.models.unet import UNet as PhysicsNemoUNet + +# ============================================================================== +# PhysicsNemo UNet Wrappers (2D and 3D) +# ============================================================================== + + +class PhysicsNemoUNet2D(nn.Module): + """Wrapper to use PhysicsNemo's 3D UNet for 2D spatial data. + + This wrapper adds a dummy temporal dimension (T=1) to use PhysicsNemo's + 3D UNet architecture for 2D spatial processing (H × W). + + Use this for models that process spatial slices independently, + such as DeepONet's branch network. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + kernel_size : int + Convolution kernel size + model_depth : int + Depth of the U-Net (number of downsampling levels) + feature_map_channels : list + Number of channels at each depth level + **kwargs + Additional arguments passed to PhysicsNemo's UNet + + Example + ------- + >>> unet = PhysicsNemoUNet2D(in_channels=64, out_channels=64) + >>> x = torch.randn(4, 64, 104, 200) # (B, C, H, W) + >>> y = unet(x) # (B, 64, H, W) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + model_depth: int = 3, + feature_map_channels: Optional[List[int]] = None, + **kwargs, + ): + super().__init__() + + if feature_map_channels is None: + feature_map_channels = [in_channels] * model_depth + + self.unet = PhysicsNemoUNet( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=kwargs.get("stride", 1), + model_depth=model_depth, + feature_map_channels=feature_map_channels, + num_conv_blocks=kwargs.get("num_conv_blocks", 1), + conv_activation=kwargs.get("conv_activation", "leaky_relu"), + conv_transpose_activation=kwargs.get( + "conv_transpose_activation", "leaky_relu" + ), + padding=kwargs.get("padding", kernel_size // 2), + padding_mode=kwargs.get("padding_mode", "zeros"), + pooling_type=kwargs.get("pooling_type", "MaxPool3d"), + pool_size=kwargs.get("pool_size", 2), + normalization=kwargs.get("normalization", "batchnorm"), + normalization_args=kwargs.get("normalization_args", None), + use_attn_gate=kwargs.get("use_attn_gate", False), + gradient_checkpointing=kwargs.get("gradient_checkpointing", False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Parameters + ---------- + x : Tensor + Input of shape (B, C, H, W) + + Returns + ------- + Tensor + Output of shape (B, C, H, W) + """ + # Add dummy temporal dimension: (B, C, H, W) -> (B, C, H, W, 1) + x = x.unsqueeze(-1) + + # Forward through 3D UNet + x = self.unet(x) + + # Remove temporal dimension: (B, C, H, W, 1) -> (B, C, H, W) + x = x.squeeze(-1) + + return x + + +class PhysicsNemoUNet3D(nn.Module): + """Wrapper for PhysicsNemo's 3D UNet for spatiotemporal data. + + This is a thin wrapper around PhysicsNemo's native 3D UNet for + processing spatiotemporal data (H × W × T). + + Use this for models that process full spatiotemporal volumes, + such as U-FNO or standalone U-Net models. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + kernel_size : int + Convolution kernel size + model_depth : int + Depth of the U-Net (number of downsampling levels) + feature_map_channels : list + Number of channels at each depth level + **kwargs + Additional arguments passed to PhysicsNemo's UNet + + Example + ------- + >>> unet = PhysicsNemoUNet3D(in_channels=12, out_channels=1) + >>> x = torch.randn(4, 12, 104, 200, 24) # (B, C, H, W, T) + >>> y = unet(x) # (B, 1, H, W, T) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + model_depth: int = 3, + feature_map_channels: Optional[List[int]] = None, + **kwargs, + ): + super().__init__() + + if feature_map_channels is None: + feature_map_channels = [in_channels] * model_depth + + self.unet = PhysicsNemoUNet( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=kwargs.get("stride", 1), + model_depth=model_depth, + feature_map_channels=feature_map_channels, + num_conv_blocks=kwargs.get("num_conv_blocks", 1), + conv_activation=kwargs.get("conv_activation", "leaky_relu"), + conv_transpose_activation=kwargs.get( + "conv_transpose_activation", "leaky_relu" + ), + padding=kwargs.get("padding", kernel_size // 2), + padding_mode=kwargs.get("padding_mode", "zeros"), + pooling_type=kwargs.get("pooling_type", "MaxPool3d"), + pool_size=kwargs.get("pool_size", 2), + normalization=kwargs.get("normalization", "batchnorm"), + normalization_args=kwargs.get("normalization_args", None), + use_attn_gate=kwargs.get("use_attn_gate", False), + gradient_checkpointing=kwargs.get("gradient_checkpointing", False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Parameters + ---------- + x : Tensor + Input of shape (B, C, H, W, T) + + Returns + ------- + Tensor + Output of shape (B, C, H, W, T) + """ + return self.unet(x) + + +# ============================================================================== +# Standalone U-Net for baseline comparisons +# ============================================================================== + + +class StandaloneUNet(nn.Module): + """Standalone U-Net wrapper for spatiotemporal prediction (no Fourier layers). + + This provides a pure U-Net architecture without any spectral convolution layers, + useful for comparison and baseline experiments. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + unet_type : str + Type of UNet: "custom" (UNet3D) or "physicsnemo" (PhysicsNemo's UNet) + **unet_kwargs + Additional arguments passed to the UNet constructor + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + unet_type: str = "custom", + **unet_kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.unet_type = unet_type.lower() + + # Build U-Net + if self.unet_type == "custom": + raise ValueError( + "Custom UNet3D is only designed for use within U-FNO (where input_channels == output_channels). " + "For standalone U-Net models, please use unet_type='physicsnemo' instead." + ) + elif self.unet_type == "physicsnemo": + self.unet = PhysicsNemoUNet( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=unet_kwargs.get("kernel_size", 3), + stride=unet_kwargs.get("stride", 1), + model_depth=unet_kwargs.get("model_depth", 3), + feature_map_channels=unet_kwargs.get( + "feature_map_channels", [36, 36, 36] + ), + num_conv_blocks=unet_kwargs.get("num_conv_blocks", 1), + conv_activation=unet_kwargs.get("conv_activation", "relu"), + conv_transpose_activation=unet_kwargs.get( + "conv_transpose_activation", "relu" + ), + padding=unet_kwargs.get("padding", 1), + padding_mode=unet_kwargs.get("padding_mode", "zeros"), + pooling_type=unet_kwargs.get("pooling_type", "MaxPool3d"), + pool_size=unet_kwargs.get("pool_size", 2), + normalization=unet_kwargs.get("normalization", "batchnorm"), + normalization_args=unet_kwargs.get("normalization_args", None), + use_attn_gate=unet_kwargs.get("use_attn_gate", False), + gradient_checkpointing=unet_kwargs.get("gradient_checkpointing", False), + ) + else: + raise ValueError( + f"Unknown unet_type: {self.unet_type}. Use 'custom' or 'physicsnemo'" + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch, height, width, time, channels) + + Returns + ------- + Tensor + Output tensor of shape (batch, height, width, time) + """ + # Input: (B, H, W, T, C) + # Permute to (B, C, H, W, T) for 3D convolution + x = x.permute(0, 4, 1, 2, 3) + + # U-Net forward pass + x = self.unet(x) + + # Permute back: (B, out_channels, H, W, T) -> (B, H, W, T, out_channels) + x = x.permute(0, 2, 3, 4, 1) + + # Squeeze out channel dimension: (B, H, W, T, 1) -> (B, H, W, T) + return x.squeeze(-1) + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) diff --git a/examples/reservoir_simulation/neural_operator_factory/models/unet.py b/examples/reservoir_simulation/neural_operator_factory/models/unet.py new file mode 100644 index 0000000000..beb0b8ecd5 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/models/unet.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +U-Net Modules for spatial and spatiotemporal data. + +This module implements U-Net architectures for multi-scale feature extraction: +- UNet2D: For 2D spatial data (Height × Width) +- UNet3D: For 3D spatiotemporal data (Height × Width × Time) + +These can be used as standalone architectures or as components in hybrid models +like U-FNO and U-DeepONet. + +Architecture: +- 3 downsampling layers (conv with stride=2) +- 3 upsampling layers (transposed conv) +- Skip connections at each level +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class UNet2D(nn.Module): + """2D U-Net for spatial data (H × W).""" + + def __init__( + self, + input_channels: int, + output_channels: int, + kernel_size: int = 3, + dropout_rate: float = 0.0, + ): + super().__init__() + + self.input_channels = input_channels + self.output_channels = output_channels + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + + # Encoder + self.conv1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv2 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv2_1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + self.conv3 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv3_1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + + # Decoder + self.deconv2 = self._deconv_block(input_channels, output_channels) + self.deconv1 = self._deconv_block(input_channels * 2, output_channels) + self.deconv0 = self._deconv_block(input_channels * 2, output_channels) + + # Output + self.output_layer = self._output_block( + input_channels * 2, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + + def _conv_block( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + dropout_rate: float, + ) -> nn.Module: + """2D convolutional block.""" + return nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(0.1), + nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), + ) + + def _deconv_block(self, in_channels: int, out_channels: int) -> nn.Module: + """2D transposed convolutional block.""" + return nn.Sequential( + nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ), + nn.LeakyReLU(0.1), + ) + + def _output_block( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + dropout_rate: float, + ) -> nn.Module: + """Output layer.""" + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. Input: (batch, channels, H, W)""" + # Validate dimensions (must be divisible by 8) + h, w = x.shape[2], x.shape[3] + if h % 8 != 0 or w % 8 != 0: + raise ValueError( + f"Input dimensions ({h}, {w}) must be divisible by 8. Got shape: {x.shape}" + ) + + # Encoder + out_conv1 = self.conv1(x) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + + # Decoder with skip connections + out_deconv2 = self.deconv2(out_conv3) + if out_deconv2.shape[2:] != out_conv2.shape[2:]: + out_deconv2 = F.interpolate( + out_deconv2, + size=out_conv2.shape[2:], + mode="bilinear", + align_corners=False, + ) + concat2 = torch.cat((out_conv2, out_deconv2), dim=1) + + out_deconv1 = self.deconv1(concat2) + if out_deconv1.shape[2:] != out_conv1.shape[2:]: + out_deconv1 = F.interpolate( + out_deconv1, + size=out_conv1.shape[2:], + mode="bilinear", + align_corners=False, + ) + concat1 = torch.cat((out_conv1, out_deconv1), dim=1) + + out_deconv0 = self.deconv0(concat1) + if out_deconv0.shape[2:] != x.shape[2:]: + out_deconv0 = F.interpolate( + out_deconv0, size=x.shape[2:], mode="bilinear", align_corners=False + ) + concat0 = torch.cat((x, out_deconv0), dim=1) + + out = self.output_layer(concat0) + + return out + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class UNet3D(nn.Module): + """3D U-Net for spatiotemporal data (H × W × T).""" + + def __init__( + self, + input_channels: int, + output_channels: int, + kernel_size: int = 3, + dropout_rate: float = 0.0, + ): + super().__init__() + + self.input_channels = input_channels + self.output_channels = output_channels + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + + # Encoder + self.conv1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv2 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv2_1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + self.conv3 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=2, + dropout_rate=dropout_rate, + ) + self.conv3_1 = self._conv_block( + input_channels, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + + # Decoder + self.deconv2 = self._deconv_block(input_channels, output_channels) + self.deconv1 = self._deconv_block(input_channels * 2, output_channels) + self.deconv0 = self._deconv_block(input_channels * 2, output_channels) + + # Output + self.output_layer = self._output_block( + input_channels * 2, + output_channels, + kernel_size=kernel_size, + stride=1, + dropout_rate=dropout_rate, + ) + + def _conv_block( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + dropout_rate: float, + ) -> nn.Module: + """3D convolutional block.""" + return nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + bias=False, + ), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(0.1), + nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), + ) + + def _deconv_block(self, in_channels: int, out_channels: int) -> nn.Module: + """3D transposed convolutional block.""" + return nn.Sequential( + nn.ConvTranspose3d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ), + nn.LeakyReLU(0.1), + ) + + def _output_block( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + dropout_rate: float, + ) -> nn.Module: + """Output layer.""" + return nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. Input: (batch, channels, H, W, T)""" + # Validate dimensions (must be divisible by 8) + h, w, t = x.shape[2], x.shape[3], x.shape[4] + if h % 8 != 0 or w % 8 != 0 or t % 8 != 0: + raise ValueError( + f"Input dimensions ({h}, {w}, {t}) must be divisible by 8. Got shape: {x.shape}" + ) + + # Encoder + out_conv1 = self.conv1(x) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + + # Decoder with skip connections + out_deconv2 = self.deconv2(out_conv3) + if out_deconv2.shape[2:] != out_conv2.shape[2:]: + out_deconv2 = F.interpolate( + out_deconv2, + size=out_conv2.shape[2:], + mode="trilinear", + align_corners=False, + ) + concat2 = torch.cat((out_conv2, out_deconv2), dim=1) + + out_deconv1 = self.deconv1(concat2) + if out_deconv1.shape[2:] != out_conv1.shape[2:]: + out_deconv1 = F.interpolate( + out_deconv1, + size=out_conv1.shape[2:], + mode="trilinear", + align_corners=False, + ) + concat1 = torch.cat((out_conv1, out_deconv1), dim=1) + + out_deconv0 = self.deconv0(concat1) + if out_deconv0.shape[2:] != x.shape[2:]: + out_deconv0 = F.interpolate( + out_deconv0, size=x.shape[2:], mode="trilinear", align_corners=False + ) + concat0 = torch.cat((x, out_deconv0), dim=1) + + out = self.output_layer(concat0) + + return out + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) diff --git a/examples/reservoir_simulation/neural_operator_factory/models/xdeeponet.py b/examples/reservoir_simulation/neural_operator_factory/models/xdeeponet.py new file mode 100644 index 0000000000..ea0a68a631 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/models/xdeeponet.py @@ -0,0 +1,1264 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepONet Variants for 2D and 3D spatial problems. + +Supported variants: + - deeponet: Basic DeepONet + - u_deeponet: DeepONet with U-Net branch + - fourier_deeponet: DeepONet with Fourier layers + - conv_deeponet: DeepONet with convolutional layers + - hybrid_deeponet: Combination of Fourier + U-Net + Conv + - mionet: Multiple-input operator network (2 branches) + - fourier_mionet: MIONet with Fourier layers +""" + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from physicsnemo.models.module import Module +from torch import Tensor +from utils.padding import compute_right_pad_to_multiple, pad_spatial_right + +from models.physicsnemo_unet import PhysicsNemoUNet2D, PhysicsNemoUNet3D +from models.unet import UNet2D, UNet3D +from physicsnemo.models.layers import ( + Conv2dFCLayer, + Conv3dFCLayer, + SpectralConv2d, + SpectralConv3d, + get_activation, +) +from physicsnemo.models.mlp import FullyConnected + +# ============================================================================= +# Branch Config Normalization +# ============================================================================= + + +def _normalize_branch_config(config: dict) -> dict: + """Normalize branch config to the nested encoder/layers format. + + Supports two formats: + + **New format** (nested):: + + branch1: + encoder: + type: linear # "linear" (LazyLinear lift) or "mlp" + hidden_width: 64 # MLP-only settings + num_layers: 2 + activation_fn: tanh + layers: + num_fourier_layers: 3 + num_unet_layers: 1 + num_conv_layers: 0 + modes1: 12 + ... + internal_resolution: null + + **Old format** (flat, auto-converted for backward compat):: + + branch1: + encoder: spatial # or "mlp" + num_fourier_layers: 3 + hidden_width: 64 + ... + + Returns a dict in the new format. + """ + if "encoder" not in config: + return config + + enc = config["encoder"] + + if not isinstance(enc, str): + return config + + enc_type_str = str(enc).lower() + cfg = dict(config) + cfg.pop("encoder") + + encoder_keys = {"hidden_width", "num_layers"} + layer_keys = { + "num_fourier_layers", + "num_unet_layers", + "num_conv_layers", + "modes1", + "modes2", + "modes3", + "kernel_size", + "dropout", + "unet_impl", + } + + activation = cfg.pop("activation_fn", "sin") + internal_res = cfg.pop("internal_resolution", None) + in_channels = cfg.pop("in_channels", None) + + encoder_dict = { + "type": "mlp" if enc_type_str == "mlp" else "linear", + "activation_fn": activation, + } + for k in encoder_keys: + if k in cfg: + encoder_dict[k] = cfg.pop(k) + + layers_dict = {"activation_fn": activation} + for k in layer_keys: + if k in cfg: + layers_dict[k] = cfg.pop(k) + + result = {"encoder": encoder_dict, "layers": layers_dict} + if internal_res is not None: + result["internal_resolution"] = internal_res + if in_channels is not None: + result["in_channels"] = in_channels + + return result + + +def _build_conv_encoder(width: int, enc_config: dict) -> nn.Module: + """Build a multi-layer pointwise encoder to replace the default LazyLinear lift. + + Operates in channels-last format ``(B, *spatial, C)`` — matching the + SpatialBranch lift interface. Each layer is a ``Linear`` with activation, + equivalent to a 1x1 convolution applied independently at every spatial point. + + Parameters + ---------- + width : int + Output width (latent dimension). + enc_config : dict + Encoder config with optional ``num_layers``, ``hidden_width``, + ``activation_fn``. + """ + num_layers = enc_config.get("num_layers", 1) + activation_fn = enc_config.get("activation_fn", "relu") + act = get_activation(activation_fn) + + if num_layers <= 1: + return nn.LazyLinear(width) + + hidden_width = enc_config.get("hidden_width", width // 2) + layers_list = [nn.LazyLinear(hidden_width), act] + for _ in range(num_layers - 2): + layers_list.extend([nn.Linear(hidden_width, hidden_width), act]) + layers_list.append(nn.Linear(hidden_width, width)) + return nn.Sequential(*layers_list) + + +# ============================================================================= +# Shared Components +# ============================================================================= + + +class TrunkNet(nn.Module): + """ + MLP trunk network for encoding query coordinates. + + Input: (T, in_features) - T query points with in_features dimensions + Output: (T, out_features) - encoded representations + + Parameters + ---------- + output_activation : bool + If True (default), apply activation to the output layer. + If False, the output layer is linear (no activation). + The original TNO paper uses False (linear trunk output) + to avoid squashing the Hadamard product's dynamic range. + """ + + def __init__( + self, + in_features: int = 1, + out_features: int = 64, + hidden_width: int = 128, + num_layers: int = 6, + activation_fn: str = "sin", + output_activation: bool = True, + ): + super().__init__() + + self._output_activation = output_activation + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + self.layers = nn.ModuleList() + self.layers.append(self._make_linear(in_features, hidden_width)) + for _ in range(num_layers - 1): + self.layers.append(self._make_linear(hidden_width, hidden_width)) + + self.output_layer = self._make_linear(hidden_width, out_features) + + def _make_linear(self, in_dim: int, out_dim: int) -> nn.Linear: + layer = nn.Linear(in_dim, out_dim) + init.xavier_normal_(layer.weight) + init.zeros_(layer.bias) + return layer + + def forward(self, x: Tensor) -> Tensor: + for layer in self.layers: + x = self.activation_fn(layer(x)) + x = self.output_layer(x) + if self._output_activation: + x = self.activation_fn(x) + return x + + +class MLPBranch(nn.Module): + """ + MLP branch network for scalar/vector inputs. + + Input: (B, in_features) - batch of scalar inputs + Output: (B, out_features) - encoded representations + + Note: in_features is auto-discovered using LazyLinear on first forward pass. + """ + + def __init__( + self, + out_features: int, + hidden_width: int = 64, + num_layers: int = 3, + activation_fn: str = "relu", + ): + super().__init__() + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + self.layers = nn.ModuleList() + # First layer uses LazyLinear to auto-discover input size + self.layers.append(nn.LazyLinear(hidden_width)) + for _ in range(num_layers - 2): + self.layers.append(self._make_linear(hidden_width, hidden_width)) + + self.output_layer = self._make_linear(hidden_width, out_features) + + def _make_linear(self, in_dim: int, out_dim: int) -> nn.Linear: + layer = nn.Linear(in_dim, out_dim) + init.xavier_normal_(layer.weight) + init.zeros_(layer.bias) + return layer + + def forward(self, x: Tensor) -> Tensor: + for layer in self.layers: + x = self.activation_fn(layer(x)) + return self.activation_fn(self.output_layer(x)) + + +# ============================================================================= +# 2D Components +# ============================================================================= + + +class SpatialBranch(nn.Module): + """ + 2D spatial branch network with Fourier, U-Net, and/or Conv layers. + + Input: (B, H, W, C) - batch of 2D spatial fields + Output: (B, H, W, width) - encoded spatial representations + + When *internal_resolution* is set, the feature maps are pooled to + that fixed size before processing and upsampled back afterwards. + This enables resolution-agnostic training and inference. + """ + + def __init__( + self, + in_channels: int, + width: int, + num_fourier_layers: int = 0, + num_unet_layers: int = 0, + num_conv_layers: int = 0, + modes1: int = 12, + modes2: int = 12, + kernel_size: int = 3, + dropout: float = 0.0, + unet_impl: str = "custom", + activation_fn: str = "gelu", + internal_resolution: Optional[list] = None, + ): + super().__init__() + + self.num_fourier_layers = num_fourier_layers + self.num_unet_layers = num_unet_layers + self.num_conv_layers = num_conv_layers + self.use_fourier_base = num_fourier_layers > 0 + self.internal_resolution = ( + tuple(internal_resolution) if internal_resolution else None + ) + + total_layers = num_fourier_layers + num_unet_layers + num_conv_layers + if total_layers == 0: + raise ValueError("SpatialBranch requires at least one layer type") + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + if self.internal_resolution is not None: + self.adaptive_pool = nn.AdaptiveAvgPool2d(self.internal_resolution) + + # Lifting layer + self.lift = nn.LazyLinear(width) + + # Spectral convolutions (Fourier layers) + num_fourier_components = ( + total_layers if self.use_fourier_base else num_fourier_layers + ) + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + for _ in range(num_fourier_components): + self.spectral_convs.append(SpectralConv2d(width, width, modes1, modes2)) + self.conv_1x1s.append(nn.Conv2d(width, width, kernel_size=1)) + + # U-Net modules + self.unet_modules = nn.ModuleList() + for _ in range(num_unet_layers): + if unet_impl == "custom": + self.unet_modules.append(UNet2D(width, width, kernel_size, dropout)) + else: + self.unet_modules.append(PhysicsNemoUNet2D(width, width, kernel_size)) + + # Convolutional modules + self.conv_modules = nn.ModuleList() + padding = (kernel_size - 1) // 2 + for _ in range(num_conv_layers): + self.conv_modules.append( + nn.Sequential( + nn.Conv2d( + width, + width, + kernel_size=kernel_size, + padding=padding, + bias=False, + ), + nn.BatchNorm2d(width), + ) + ) + + def forward(self, x: Tensor) -> Tensor: + # Lift to width dimension + x = self.lift(x) + x = x.permute(0, 3, 1, 2) # (B, H, W, width) -> (B, width, H, W) + + # Adaptive pool to internal resolution if configured + original_size = x.shape[2:] + if self.internal_resolution is not None: + x = self.adaptive_pool(x) + + # Fourier layers + for i in range(self.num_fourier_layers): + x = self.activation_fn(self.spectral_convs[i](x) + self.conv_1x1s[i](x)) + + # Hybrid or standalone layers + if self.use_fourier_base: + for i in range(self.num_unet_layers): + j = self.num_fourier_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.unet_modules[i](x) + ) + for i in range(self.num_conv_layers): + j = self.num_fourier_layers + self.num_unet_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.conv_modules[i](x) + ) + else: + for unet in self.unet_modules: + x = self.activation_fn(unet(x)) + for conv in self.conv_modules: + x = self.activation_fn(conv(x)) + + # Upsample back to original resolution + if self.internal_resolution is not None and x.shape[2:] != original_size: + x = F.interpolate( + x, size=original_size, mode="bilinear", align_corners=True + ) + + return x.permute(0, 2, 3, 1) # (B, width, H, W) -> (B, H, W, width) + + +class DeepONet(Module): + """ + 2D DeepONet for operator learning. + + Input: + - x_branch1: (B, H, W, C) for spatial or (B, in_features) for MLP + - x_time: (T,) or (T, in_features) query coordinates + - x_branch2: optional second branch input for MIONet + Output: (B, H, W, T) for spatial or (B, T) for MLP + """ + + VALID_VARIANTS = [ + "deeponet", + "u_deeponet", + "fourier_deeponet", + "conv_deeponet", + "hybrid_deeponet", + "mionet", + "fourier_mionet", + "tno", + ] + + def __init__( + self, + variant: str = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: str = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + ): + super().__init__() + + self.variant = variant.lower() + self.width = width + self.decoder_type = decoder_type.lower() + self.decoder_activation_fn = decoder_activation_fn + + if self.variant not in self.VALID_VARIANTS: + raise ValueError( + f"Unknown variant: {variant}. Valid: {self.VALID_VARIANTS}" + ) + + branch1_config = branch1_config or {} + trunk_config = trunk_config or {} + + # Build networks + self.branch1 = self._build_branch(branch1_config, width) + + self.has_branch2 = branch2_config is not None + if self.has_branch2: + self.branch2 = self._build_branch(branch2_config, width) + + self.trunk = TrunkNet( + in_features=trunk_config.get("in_features", 1), + out_features=width, + hidden_width=trunk_config.get("hidden_width", 128), + num_layers=trunk_config.get("num_layers", 6), + activation_fn=trunk_config.get("activation_fn", "sin"), + output_activation=trunk_config.get("output_activation", True), + ) + + if decoder_type == "temporal_projection": + self._temporal_projection = True + self.decoder = self._build_decoder( + width, + width, + decoder_layers, + decoder_width, + "mlp", + decoder_activation_fn, + ) + self.temporal_head = None + else: + self._temporal_projection = False + self.decoder = self._build_decoder( + width, + 1, + decoder_layers, + decoder_width, + decoder_type, + decoder_activation_fn, + ) + + def set_output_window(self, K: int): + """Set the temporal projection head for K output timesteps.""" + if self._temporal_projection: + device = next(self.parameters()).device + self.temporal_head = nn.Linear(self.width, K).to(device) + + def _build_branch(self, config: dict, width: int) -> nn.Module: + config = _normalize_branch_config(config) + enc = config.get("encoder", {}) + layers = config.get("layers", {}) + + enc_type = enc.get("type", "linear") + enc_activation = enc.get("activation_fn", "sin") + + has_layers = ( + layers.get("num_fourier_layers", 0) + + layers.get("num_unet_layers", 0) + + layers.get("num_conv_layers", 0) + ) > 0 + + if enc_type == "mlp" and not has_layers: + return MLPBranch( + out_features=width, + hidden_width=enc.get("hidden_width", 64), + num_layers=enc.get("num_layers", 3), + activation_fn=enc_activation, + ) + + layer_activation = layers.get("activation_fn", enc_activation) + branch = SpatialBranch( + in_channels=config.get("in_channels", 12), + width=width, + num_fourier_layers=layers.get("num_fourier_layers", 0), + num_unet_layers=layers.get("num_unet_layers", 0), + num_conv_layers=layers.get("num_conv_layers", 0), + modes1=layers.get("modes1", 12), + modes2=layers.get("modes2", 12), + kernel_size=layers.get("kernel_size", 3), + dropout=layers.get("dropout", 0.0), + unet_impl=layers.get("unet_impl", "custom"), + activation_fn=layer_activation, + internal_resolution=config.get("internal_resolution", None), + ) + if enc_type == "conv": + branch.lift = _build_conv_encoder(width, enc) + return branch + + def _build_decoder( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + return FullyConnected( + width, hidden_width, out_channels, num_layers, activation_fn + ) + + elif decoder_type == "conv": + if num_layers == 0: + return Conv2dFCLayer(width, out_channels) + + layers = [] + in_ch = width + for _ in range(num_layers): + layers.extend( + [Conv2dFCLayer(in_ch, hidden_width), get_activation(activation_fn)] + ) + in_ch = hidden_width + layers.append(Conv2dFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers) + + else: + raise ValueError(f"Unknown decoder_type: {decoder_type}") + + def forward( + self, x_branch1: Tensor, x_time: Tensor, x_branch2: Tensor = None + ) -> Tensor: + if x_time.dim() == 1: + x_time = x_time.unsqueeze(-1) + + b1_out = self.branch1(x_branch1) + + if self.has_branch2: + if x_branch2 is None: + raise ValueError("x_branch2 required for mionet/tno variants") + b2_out = self.branch2(x_branch2) + + trunk_out = self.trunk(x_time) + + # Combine branch and trunk + if b1_out.dim() == 4: # Spatial branch + if self._temporal_projection: + # Single trunk query → (B, H, W, width) combined + trunk_single = trunk_out[0:1] # (1, width) + trunk_exp = trunk_single.unsqueeze(1).unsqueeze(2) # (1, 1, 1, w) + combined = b1_out * trunk_exp + if self.has_branch2: + if b2_out.dim() == 4: + combined = combined * b2_out + else: + combined = combined * b2_out.unsqueeze(1).unsqueeze(2) + combined = self.decoder(combined) # (B, H, W, width) + if self.temporal_head is not None: + combined = self.temporal_head(combined) # (B, H, W, K) + return combined + + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + if self.has_branch2: + if b2_out.dim() == 4: + b2_out = b2_out.unsqueeze(1) + else: + b2_out = b2_out.unsqueeze(1).unsqueeze(2).unsqueeze(3) + combined = b1_out * b2_out * trunk_out + else: + combined = b1_out * trunk_out + + if self.decoder_type == "mlp": + return self.decoder(combined).squeeze(-1).permute(0, 2, 3, 1) + + B, T, H, W, C = combined.shape + combined = combined.permute(0, 1, 4, 2, 3).reshape(B * T, C, H, W) + return self.decoder(combined).reshape(B, T, H, W).permute(0, 2, 3, 1) + + else: # MLP branch + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0) + + if self.has_branch2: + combined = b1_out * b2_out.unsqueeze(1) * trunk_out + else: + combined = b1_out * trunk_out + + return self.decoder(combined).squeeze(-1) + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class DeepONetWrapper(nn.Module): + """ + 2D DeepONet wrapper with automatic padding and input extraction. + + Input: (B, H, W, T, C) - batch of spatiotemporal fields + Output: (B, H, W, T) - predicted output field + """ + + def __init__( + self, + padding: int = 8, + variant: str = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: str = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + ): + super().__init__() + + self.padding = ((padding + 7) // 8) * 8 if padding % 8 != 0 else padding + self.variant = variant + + trunk_config = trunk_config or {} + self.trunk_input = trunk_config.get("input_type", "time").lower() + + if self.trunk_input not in ["time", "grid"]: + raise ValueError("trunk input_type must be 'time' or 'grid'") + + if self.trunk_input == "grid": + trunk_config["in_features"] = 3 # (x, y, t) + else: + trunk_config["in_features"] = trunk_config.get("in_features", 1) + + self.model = DeepONet( + variant=variant, + width=width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=decoder_type, + decoder_width=decoder_width, + decoder_layers=decoder_layers, + decoder_activation_fn=decoder_activation_fn, + ) + self._temporal_projection = self.model._temporal_projection + + def set_output_window(self, K: int): + """Delegate to the inner DeepONet model.""" + self.model.set_output_window(K) + + def forward( + self, + x: Tensor, + x_branch2: Tensor = None, + target_times: Tensor = None, + ) -> Tensor: + """Forward pass. + + Parameters + ---------- + x : Tensor + Input ``(B, H, W, T_in, C)``. + x_branch2 : Tensor, optional + Secondary branch input (MIONet variants). + target_times : Tensor, optional + Explicit trunk query coordinates ``(K,)`` or ``(K, 1)``. + When provided the trunk evaluates at these K points instead of + extracting time values from ``x``. This enables autoregressive + temporal bundling where K != T_in. + + Returns + ------- + Tensor ``(B, H, W, T_out)`` where T_out = K if target_times given, + else T_in. + """ + H, W = x.shape[1], x.shape[2] + + pad_h, pad_w = compute_right_pad_to_multiple( + (H, W), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, spatial_ndim=2, right_pad=(pad_h, pad_w), mode="replicate" + ) + + if x_branch2 is not None and x_branch2.dim() > 2: + x_branch2 = pad_spatial_right( + x_branch2, + spatial_ndim=2, + right_pad=(pad_h, pad_w), + mode="replicate", + ) + + x_spatial = x.permute(0, 4, 1, 2, 3)[..., 0].permute(0, 2, 3, 1) + + if target_times is not None: + if self.trunk_input == "grid": + t_vals = ( + target_times + if target_times.dim() == 1 + else target_times.squeeze(-1) + ) + spatial = x[0, 0, 0, 0, -3:-1] # (2,) = grid_x, grid_y + spatial_exp = spatial.unsqueeze(0).expand(t_vals.shape[0], -1) + x_trunk = torch.cat( + [spatial_exp, t_vals.unsqueeze(-1)], dim=-1 + ) # (K, 3) + else: + x_trunk = ( + target_times + if target_times.dim() == 2 + else target_times.unsqueeze(-1) + ) + elif self.trunk_input == "grid": + x_trunk = x[0, 0, 0, :, -3:] + else: + x_trunk = x[0, 0, 0, :, -1].unsqueeze(-1) + + return self.model(x_spatial, x_trunk, x_branch2)[:, :H, :W, :] + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return self.model.count_params() + + +# ============================================================================= +# 3D Components +# ============================================================================= + + +class SpatialBranch3D(nn.Module): + """ + 3D spatial branch network with Fourier, U-Net, and/or Conv layers. + + Input: (B, X, Y, Z, C) - batch of 3D spatial fields + Output: (B, X, Y, Z, width) - encoded spatial representations + """ + + def __init__( + self, + in_channels: int, + width: int, + num_fourier_layers: int = 0, + num_unet_layers: int = 0, + num_conv_layers: int = 0, + modes1: int = 10, + modes2: int = 10, + modes3: int = 8, + kernel_size: int = 3, + dropout: float = 0.0, + unet_impl: str = "custom", + activation_fn: str = "gelu", + internal_resolution: Optional[list] = None, + ): + super().__init__() + + self.num_fourier_layers = num_fourier_layers + self.num_unet_layers = num_unet_layers + self.num_conv_layers = num_conv_layers + self.use_fourier_base = num_fourier_layers > 0 + self.internal_resolution = ( + tuple(internal_resolution) if internal_resolution else None + ) + + total_layers = num_fourier_layers + num_unet_layers + num_conv_layers + if total_layers == 0: + raise ValueError("SpatialBranch3D requires at least one layer type") + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + if self.internal_resolution is not None: + self.adaptive_pool = nn.AdaptiveAvgPool3d(self.internal_resolution) + + # Lifting layer + self.lift = nn.LazyLinear(width) + + # Spectral convolutions (Fourier layers) + num_fourier_components = ( + total_layers if self.use_fourier_base else num_fourier_layers + ) + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + for _ in range(num_fourier_components): + self.spectral_convs.append( + SpectralConv3d(width, width, modes1, modes2, modes3) + ) + self.conv_1x1s.append(nn.Conv3d(width, width, kernel_size=1)) + + # U-Net modules + self.unet_modules = nn.ModuleList() + for _ in range(num_unet_layers): + if unet_impl == "custom": + self.unet_modules.append(UNet3D(width, width, kernel_size, dropout)) + else: + self.unet_modules.append(PhysicsNemoUNet3D(width, width, kernel_size)) + + # Convolutional modules + self.conv_modules = nn.ModuleList() + padding = (kernel_size - 1) // 2 + for _ in range(num_conv_layers): + self.conv_modules.append( + nn.Sequential( + nn.Conv3d( + width, + width, + kernel_size=kernel_size, + padding=padding, + bias=False, + ), + nn.BatchNorm3d(width), + ) + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.lift(x) + x = x.permute(0, 4, 1, 2, 3) + + original_size = x.shape[2:] + if self.internal_resolution is not None: + x = self.adaptive_pool(x) + + for i in range(self.num_fourier_layers): + x = self.activation_fn(self.spectral_convs[i](x) + self.conv_1x1s[i](x)) + + if self.use_fourier_base: + for i in range(self.num_unet_layers): + j = self.num_fourier_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.unet_modules[i](x) + ) + for i in range(self.num_conv_layers): + j = self.num_fourier_layers + self.num_unet_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.conv_modules[i](x) + ) + else: + for unet in self.unet_modules: + x = self.activation_fn(unet(x)) + for conv in self.conv_modules: + x = self.activation_fn(conv(x)) + + if self.internal_resolution is not None and x.shape[2:] != original_size: + x = F.interpolate( + x, size=original_size, mode="trilinear", align_corners=True + ) + + return x.permute(0, 2, 3, 4, 1) + + +class DeepONet3D(Module): + """ + 3D DeepONet for operator learning on volumetric data. + + Input: + - x_branch1: (B, X, Y, Z, C) for spatial or (B, in_features) for MLP + - x_time: (T,) or (T, in_features) query coordinates + - x_branch2: optional second branch input for MIONet + Output: (B, X, Y, Z, T) for spatial or (B, T) for MLP + """ + + VALID_VARIANTS = [ + "deeponet", + "u_deeponet", + "fourier_deeponet", + "conv_deeponet", + "hybrid_deeponet", + "mionet", + "fourier_mionet", + "tno", + ] + + def __init__( + self, + variant: str = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: str = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + ): + super().__init__() + + self.variant = variant.lower() + self.width = width + self.decoder_type = decoder_type.lower() + self.decoder_activation_fn = decoder_activation_fn + + if self.variant not in self.VALID_VARIANTS: + raise ValueError( + f"Unknown variant: {variant}. Valid: {self.VALID_VARIANTS}" + ) + + branch1_config = branch1_config or {} + trunk_config = trunk_config or {} + + # Build networks + self.branch1 = self._build_branch(branch1_config, width) + + self.has_branch2 = branch2_config is not None + if self.has_branch2: + self.branch2 = self._build_branch(branch2_config, width) + + self.trunk = TrunkNet( + in_features=trunk_config.get("in_features", 1), + out_features=width, + hidden_width=trunk_config.get("hidden_width", 128), + num_layers=trunk_config.get("num_layers", 6), + activation_fn=trunk_config.get("activation_fn", "sin"), + output_activation=trunk_config.get("output_activation", True), + ) + + if decoder_type == "temporal_projection": + self._temporal_projection = True + self.decoder = self._build_decoder( + width, + width, + decoder_layers, + decoder_width, + "mlp", + decoder_activation_fn, + ) + self.temporal_head = None + else: + self._temporal_projection = False + self.decoder = self._build_decoder( + width, + 1, + decoder_layers, + decoder_width, + decoder_type, + decoder_activation_fn, + ) + + def set_output_window(self, K: int): + """Set the temporal projection head for K output timesteps.""" + if self._temporal_projection: + device = next(self.parameters()).device + self.temporal_head = nn.Linear(self.width, K).to(device) + + def _build_branch(self, config: dict, width: int) -> nn.Module: + config = _normalize_branch_config(config) + enc = config.get("encoder", {}) + layers = config.get("layers", {}) + + enc_type = enc.get("type", "linear") + enc_activation = enc.get("activation_fn", "sin") + + has_layers = ( + layers.get("num_fourier_layers", 0) + + layers.get("num_unet_layers", 0) + + layers.get("num_conv_layers", 0) + ) > 0 + + if enc_type == "mlp" and not has_layers: + return MLPBranch( + out_features=width, + hidden_width=enc.get("hidden_width", 64), + num_layers=enc.get("num_layers", 3), + activation_fn=enc_activation, + ) + + layer_activation = layers.get("activation_fn", enc_activation) + branch = SpatialBranch3D( + in_channels=config.get("in_channels", 11), + width=width, + num_fourier_layers=layers.get("num_fourier_layers", 0), + num_unet_layers=layers.get("num_unet_layers", 0), + num_conv_layers=layers.get("num_conv_layers", 0), + modes1=layers.get("modes1", 10), + modes2=layers.get("modes2", 10), + modes3=layers.get("modes3", 8), + kernel_size=layers.get("kernel_size", 3), + dropout=layers.get("dropout", 0.0), + unet_impl=layers.get("unet_impl", "custom"), + activation_fn=layer_activation, + internal_resolution=config.get("internal_resolution", None), + ) + if enc_type == "conv": + branch.lift = _build_conv_encoder(width, enc) + return branch + + def _build_decoder( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + return FullyConnected( + width, hidden_width, out_channels, num_layers, activation_fn + ) + + elif decoder_type == "conv": + if num_layers == 0: + return Conv3dFCLayer(width, out_channels) + + layers = [] + in_ch = width + for _ in range(num_layers): + layers.extend( + [Conv3dFCLayer(in_ch, hidden_width), get_activation(activation_fn)] + ) + in_ch = hidden_width + layers.append(Conv3dFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers) + + else: + raise ValueError(f"Unknown decoder_type: {decoder_type}") + + def forward( + self, x_branch1: Tensor, x_time: Tensor, x_branch2: Tensor = None + ) -> Tensor: + if x_time.dim() == 1: + x_time = x_time.unsqueeze(-1) + + b1_out = self.branch1(x_branch1) + + if self.has_branch2: + if x_branch2 is None: + raise ValueError("x_branch2 required for mionet/tno variants") + b2_out = self.branch2(x_branch2) + + trunk_out = self.trunk(x_time) + + # Combine branch and trunk + if b1_out.dim() == 5: # Spatial branch + if self._temporal_projection: + trunk_single = trunk_out[0:1] + trunk_exp = trunk_single.unsqueeze(1).unsqueeze(2).unsqueeze(3) + combined = b1_out * trunk_exp + if self.has_branch2: + if b2_out.dim() == 5: + combined = combined * b2_out + else: + combined = combined * b2_out.unsqueeze(1).unsqueeze( + 2 + ).unsqueeze(3) + combined = self.decoder(combined) + if self.temporal_head is not None: + combined = self.temporal_head(combined) + return combined + + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4) + + if self.has_branch2: + if b2_out.dim() == 5: + b2_out = b2_out.unsqueeze(1) + else: + b2_out = b2_out.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4) + combined = b1_out * b2_out * trunk_out + else: + combined = b1_out * trunk_out + + if self.decoder_type == "mlp": + return self.decoder(combined).squeeze(-1).permute(0, 2, 3, 4, 1) + + B, T, X, Y, Z, C = combined.shape + combined = combined.permute(0, 1, 5, 2, 3, 4).reshape(B * T, C, X, Y, Z) + return self.decoder(combined).reshape(B, T, X, Y, Z).permute(0, 2, 3, 4, 1) + + else: # MLP branch + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0) + + if self.has_branch2: + combined = b1_out * b2_out.unsqueeze(1) * trunk_out + else: + combined = b1_out * trunk_out + + return self.decoder(combined).squeeze(-1) + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class DeepONet3DWrapper(nn.Module): + """ + 3D DeepONet wrapper with automatic padding and input extraction. + + Input: (B, X, Y, Z, T, C) - batch of 4D spatiotemporal fields + Output: (B, X, Y, Z, T) - predicted output field + """ + + def __init__( + self, + padding: int = 8, + variant: str = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: str = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + ): + super().__init__() + + self.padding = ((padding + 7) // 8) * 8 if padding % 8 != 0 else padding + self.variant = variant + + trunk_config = trunk_config or {} + self.trunk_input = trunk_config.get("input_type", "time").lower() + + if self.trunk_input not in ["time", "grid"]: + raise ValueError("trunk input_type must be 'time' or 'grid'") + + if self.trunk_input == "grid": + trunk_config["in_features"] = 4 # (x, y, z, t) + else: + trunk_config["in_features"] = trunk_config.get("in_features", 1) + + self.model = DeepONet3D( + variant=variant, + width=width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=decoder_type, + decoder_width=decoder_width, + decoder_layers=decoder_layers, + decoder_activation_fn=decoder_activation_fn, + ) + self._temporal_projection = self.model._temporal_projection + + def set_output_window(self, K: int): + """Delegate to the inner DeepONet3D model.""" + self.model.set_output_window(K) + + def forward( + self, + x: Tensor, + x_branch2: Tensor = None, + target_times: Tensor = None, + ) -> Tensor: + """Forward pass. + + Parameters + ---------- + x : Tensor + Input ``(B, X, Y, Z, T_in, C)``. + x_branch2 : Tensor, optional + Secondary branch input (MIONet variants). + target_times : Tensor, optional + Explicit trunk query coordinates ``(K,)`` or ``(K, 1)``. + When provided the trunk evaluates at these K points instead of + extracting time values from ``x``. This enables autoregressive + temporal bundling where K != T_in. + + Returns + ------- + Tensor ``(B, X, Y, Z, T_out)`` where T_out = K if target_times given, + else T_in. + """ + X, Y, Z = x.shape[1], x.shape[2], x.shape[3] + + pad_x, pad_y, pad_z = compute_right_pad_to_multiple( + (X, Y, Z), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, spatial_ndim=3, right_pad=(pad_x, pad_y, pad_z), mode="replicate" + ) + + if x_branch2 is not None and x_branch2.dim() > 2: + x_branch2 = pad_spatial_right( + x_branch2, + spatial_ndim=3, + right_pad=(pad_x, pad_y, pad_z), + mode="replicate", + ) + + x_spatial = x[:, :, :, :, 0, :] + + if target_times is not None: + if self.trunk_input == "grid": + t_vals = ( + target_times + if target_times.dim() == 1 + else target_times.squeeze(-1) + ) + spatial = x[0, 0, 0, 0, 0, -4:-1] # (3,) = grid_x, grid_y, grid_z + spatial_exp = spatial.unsqueeze(0).expand(t_vals.shape[0], -1) + x_trunk = torch.cat( + [spatial_exp, t_vals.unsqueeze(-1)], dim=-1 + ) # (K, 4) + else: + x_trunk = ( + target_times + if target_times.dim() == 2 + else target_times.unsqueeze(-1) + ) + elif self.trunk_input == "grid": + x_trunk = x[0, 0, 0, 0, :, -4:] + else: + x_trunk = x[0, 0, 0, 0, :, -1].unsqueeze(-1) + + return self.model(x_spatial, x_trunk, x_branch2)[:, :X, :Y, :Z, :] + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return self.model.count_params() diff --git a/examples/reservoir_simulation/neural_operator_factory/models/xfno.py b/examples/reservoir_simulation/neural_operator_factory/models/xfno.py new file mode 100644 index 0000000000..b817a2a76f --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/models/xfno.py @@ -0,0 +1,796 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +U-FNO: Enhanced Fourier Neural Operator with U-Net or Conv skip connections. + +Reference: + Wen, G., Li, Z., Azizzadenesheli, K., Anandkumar, A., & Benson, S. M. (2022). + U-FNO--An enhanced Fourier neural operator-based deep-learning model for multiphase flow. + Advances in Water Resources, 104180. +""" + +import torch +import torch.nn as nn +from physicsnemo.models.module import Module +from torch import Tensor +from utils.padding import ( + compute_right_pad_to_multiple, + compute_right_pad_to_multiple_per_dim, + pad_spatial_right, +) + +from models.physicsnemo_unet import PhysicsNemoUNet3D +from models.unet import UNet3D +from physicsnemo.models.layers import ( + Conv3dFCLayer, + ConvNdFCLayer, + ConvNdKernel1Layer, + SpectralConv3d, + SpectralConv4d, + get_activation, +) +from physicsnemo.models.mlp import FullyConnected + + +class UFNO(Module): + """U-FNO/Conv-FNO: Fourier Neural Operator enhanced with U-Net or Conv modules. + + Architecture consists of: + - Lifting network to project input to latent space + - num_fno_layers standard Fourier layers (Spectral Conv + 1x1 Conv) + - num_unet_layers enhanced layers (Spectral Conv + 1x1 Conv + U-Net) + - num_conv_layers conv-enhanced layers (Spectral Conv + 1x1 Conv + 3D Conv) + - Decoder network to project latent space to output + + Note: + - U-FNO: num_unet_layers > 0, num_conv_layers = 0 + - Conv-FNO: num_unet_layers = 0, num_conv_layers > 0 + - Standard FNO: num_unet_layers = 0, num_conv_layers = 0 + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + width : int + Latent channel dimension + modes1, modes2, modes3 : int + Number of Fourier modes in each dimension + num_fno_layers : int + Number of standard Fourier layers + num_unet_layers : int + Number of U-Net enhanced layers + num_conv_layers : int + Number of Conv enhanced layers (Conv-FNO) + conv_kernel_size : int + Kernel size for Conv layers (Conv-FNO) + unet_kernel_size : int + Kernel size for U-Net convolutions + unet_dropout : float + Dropout rate in U-Net + unet_type : str + Type of UNet: "custom" (your UNet3D) or "physicsnemo" (PhysicsNemo's UNet) + activation_fn : str + Activation function name + lifting_type : str + Type of lifting layer: "mlp" or "conv" + lifting_layers : int + Number of layers in lifting network + lifting_width : int + Hidden width factor for multi-layer lifting + decoder_type : str + Type of decoder: "mlp" or "conv" + decoder_layers : int + Number of hidden layers in decoder + decoder_width : int + Hidden layer size in decoder + decoder_activation_fn : str, optional + Activation for decoder layers (None uses activation_fn, last layer always linear) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + width: int = 36, + modes1: int = 10, + modes2: int = 10, + modes3: int = 10, + num_fno_layers: int = 3, + num_unet_layers: int = 3, + num_conv_layers: int = 0, + conv_kernel_size: int = 3, + unet_kernel_size: int = 3, + unet_dropout: float = 0.0, + unet_type: str = "custom", + activation_fn: str = "relu", + lifting_type: str = "mlp", + lifting_layers: int = 1, + lifting_width: int = 2, + decoder_type: str = "mlp", + decoder_layers: int = 1, + decoder_width: int = 128, + decoder_activation_fn: str = None, # None means use activation_fn + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.width = width + self.modes1 = modes1 + self.modes2 = modes2 + self.modes3 = modes3 + self.num_fno_layers = num_fno_layers + self.num_unet_layers = num_unet_layers + self.num_conv_layers = num_conv_layers + self.total_layers = num_fno_layers + num_unet_layers + num_conv_layers + self.lifting_type = lifting_type.lower() + self.decoder_type = decoder_type.lower() + self.unet_type = unet_type.lower() + self.activation_fn_name = activation_fn + self.decoder_activation_fn_name = ( + decoder_activation_fn if decoder_activation_fn else activation_fn + ) + self.activation_fn = get_activation(activation_fn) + self.conv_kernel_size = conv_kernel_size + + # Build lifting network + self.lift_network = self._build_lifting_network( + in_channels, width, lifting_layers, lifting_width, lifting_type + ) + + # Build Fourier layers + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + + for _ in range(self.total_layers): + self.spectral_convs.append( + SpectralConv3d(self.width, self.width, modes1, modes2, modes3) + ) + self.conv_1x1s.append(ConvNdKernel1Layer(self.width, self.width)) + + # Build U-Net modules + self.unet_modules = nn.ModuleList() + for _ in range(num_unet_layers): + if self.unet_type == "custom": + # Use custom UNet3D + self.unet_modules.append( + UNet3D( + self.width, + self.width, + kernel_size=unet_kernel_size, + dropout_rate=unet_dropout, + ) + ) + elif self.unet_type == "physicsnemo": + # Use PhysicsNemo's UNet wrapper (from physicsnemo_unet.py) + self.unet_modules.append( + PhysicsNemoUNet3D( + in_channels=self.width, + out_channels=self.width, + kernel_size=unet_kernel_size, + model_depth=3, # 3 downsampling levels like custom UNet3D + feature_map_channels=[self.width] * 3, + conv_activation="leaky_relu", + normalization="batchnorm", + ) + ) + else: + raise ValueError( + f"Unknown unet_type: {self.unet_type}. Use 'custom' or 'physicsnemo'" + ) + + # Build Conv modules (for Conv-FNO) + self.conv_modules = nn.ModuleList() + for _ in range(num_conv_layers): + # Simple 3D convolution with activation + padding = (conv_kernel_size - 1) // 2 + self.conv_modules.append( + nn.Sequential( + nn.Conv3d( + self.width, + self.width, + kernel_size=conv_kernel_size, + padding=padding, + bias=False, + ), + nn.BatchNorm3d(self.width), + get_activation(self.activation_fn_name), + ) + ) + + # Build decoder network + self.decoder = self._build_decoder_network( + width, + out_channels, + decoder_layers, + decoder_width, + decoder_type, + self.decoder_activation_fn_name, + ) + + def _build_lifting_network( + self, + in_channels: int, + width: int, + num_layers: int, + hidden_width_factor: int, + lift_type: str, + ) -> nn.Module: + """Build lifting network to project input to latent space.""" + if lift_type == "mlp": + if num_layers == 1: + return nn.Linear(in_channels, width) + else: + return FullyConnected( + in_features=in_channels, + layer_size=width // hidden_width_factor, + out_features=width, + num_layers=num_layers, + activation_fn=self.activation_fn_name, + ) + elif lift_type == "conv": + if num_layers == 1: + return Conv3dFCLayer(in_channels, width) + else: + layers_list = [] + hidden_width = width // hidden_width_factor + layers_list.append(Conv3dFCLayer(in_channels, hidden_width)) + layers_list.append(get_activation(self.activation_fn_name)) + for _ in range(num_layers - 2): + layers_list.append(Conv3dFCLayer(hidden_width, hidden_width)) + layers_list.append(get_activation(self.activation_fn_name)) + layers_list.append(Conv3dFCLayer(hidden_width, width)) + return nn.Sequential(*layers_list) + else: + raise ValueError(f"Unknown lifting_type: {lift_type}. Use 'mlp' or 'conv'") + + def _build_decoder_network( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + """Build decoder network to project latent space to output. + + Parameters + ---------- + width : int + Input width from FNO layers + out_channels : int + Output channels + num_layers : int + Number of hidden layers (0 = direct projection) + hidden_width : int + Hidden layer width + decoder_type : str + 'mlp' for fully connected, 'conv' for 1x1 convolutions + activation_fn : str + Activation function name (last layer always linear) + + Returns + ------- + nn.Module + Decoder network + """ + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + else: + return FullyConnected( + in_features=width, + layer_size=hidden_width, + out_features=out_channels, + num_layers=num_layers, + activation_fn=activation_fn, + ) + elif decoder_type == "conv": + if num_layers == 0: + return Conv3dFCLayer(width, out_channels) + else: + layers_list = [] + in_ch = width + for _ in range(num_layers): + layers_list.append(Conv3dFCLayer(in_ch, hidden_width)) + layers_list.append(get_activation(activation_fn)) + in_ch = hidden_width + layers_list.append(Conv3dFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers_list) + else: + raise ValueError( + f"Unknown decoder_type: {decoder_type}. Use 'mlp' or 'conv'" + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through U-FNO.""" + # Lifting + if self.lifting_type == "mlp": + x = self.lift_network(x) + x = x.permute(0, 4, 1, 2, 3) + else: # conv + # Conv lifting: need channel-first input + # (batch, H, W, T, in_channels) -> (batch, in_channels, H, W, T) + x = x.permute(0, 4, 1, 2, 3) + x = self.lift_network(x) + + # Standard Fourier layers + for layer_idx in range(self.num_fno_layers): + x1 = self.spectral_convs[layer_idx](x) + x2 = self.conv_1x1s[layer_idx](x) + x = x1 + x2 + x = self.activation_fn(x) + + # Fourier + U-Net layers + for unet_idx in range(self.num_unet_layers): + layer_idx = self.num_fno_layers + unet_idx + x1 = self.spectral_convs[layer_idx](x) + x2 = self.conv_1x1s[layer_idx](x) + x3 = self.unet_modules[unet_idx](x) + x = x1 + x2 + x3 + x = self.activation_fn(x) + + # Fourier + Conv layers (Conv-FNO) + for conv_idx in range(self.num_conv_layers): + layer_idx = self.num_fno_layers + self.num_unet_layers + conv_idx + x1 = self.spectral_convs[layer_idx](x) + x2 = self.conv_1x1s[layer_idx](x) + x3 = self.conv_modules[conv_idx](x) + x = x1 + x2 + x3 + x = self.activation_fn(x) + + # Decoder: project back to output space + if self.decoder_type == "mlp": + # MLP decoder: need channel-last for pointwise operations + # (batch, width, H, W, T) -> (batch, H, W, T, width) + x = x.permute(0, 2, 3, 4, 1) + x = self.decoder(x) # (batch, H, W, T, out_channels) + else: # conv + # Conv decoder: already in channel-first format + x = self.decoder(x) # (batch, out_channels, H, W, T) + # Convert to channel-last for consistency + x = x.permute(0, 2, 3, 4, 1) # (batch, H, W, T, out_channels) + + return x + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class UFNONet(nn.Module): + """Wrapper for UFNO that handles padding/de-padding.""" + + def __init__( + self, + modes1: int, + modes2: int, + modes3: int, + width: int, + in_channels: int = 12, + out_channels: int = 1, + num_fno_layers: int = 2, + num_unet_layers: int = 2, + padding: int = 8, + **kwargs, + ): + super(UFNONet, self).__init__() + + # Ensure padding is divisible by 8 for U-Net compatibility + if padding % 8 != 0: + self.padding = ((padding + 7) // 8) * 8 + print( + f"Warning: Padding adjusted from {padding} to {self.padding} for U-Net compatibility" + ) + else: + self.padding = padding + + self.time_modes = modes3 + + self.ufno = UFNO( + modes1=modes1, + modes2=modes2, + modes3=modes3, + width=width, + in_channels=in_channels, + out_channels=out_channels, + num_fno_layers=num_fno_layers, + num_unet_layers=num_unet_layers, + **kwargs, + ) + + def forward( + self, + x: Tensor, + target_times: Tensor = None, + ) -> Tensor: + """Forward pass with padding/de-padding. + + Parameters + ---------- + x : Tensor + Input ``(B, H, W, T_in, C)``. + target_times : Tensor, optional + Explicit target time coordinates ``(K,)`` or ``(K, 1)``. + When provided and K != T_in, the time axis is padded so the + FNO operates on at least L+K timesteps, and the output is + cropped to the last K timesteps. + + Returns + ------- + Tensor ``(B, H, W, T_out)`` where T_out = K if target_times given, + else T_in. + """ + h, w, t_in = x.shape[1], x.shape[2], x.shape[3] + + K = target_times.shape[0] if target_times is not None else None + + if K is not None and K != t_in: + desired_t = t_in + K + min_t = max(desired_t, 2 * self.time_modes) + extra = min_t - t_in + x = pad_spatial_right( + x, + spatial_ndim=3, + right_pad=(0, 0, extra), + mode="replicate", + ) + t_padded = x.shape[3] + else: + K = None + t_padded = t_in + + pad_h, pad_w, pad_t = compute_right_pad_to_multiple( + (h, w, t_padded), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, spatial_ndim=3, right_pad=(pad_h, pad_w, pad_t), mode="replicate" + ) + + x = self.ufno(x) + + if K is not None: + x = x[:, :h, :w, t_in : t_in + K, :] + else: + x = x[:, :h, :w, :t_in, :] + + return x.squeeze(-1) + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return self.ufno.count_params() + + +# ============================================================================= +# 4D FNO CLASSES (3D spatial + time) +# ============================================================================= +# Note: U-Net and Conv skip connections are NOT available for 4D problems +# because PyTorch does not provide native nn.Conv4d. These classes use only +# officially supported PhysicsNemo layers: SpectralConv4d, ConvNdKernel1Layer, +# and ConvNdFCLayer. +# ============================================================================= + + +class FNO4D(Module): + """4D Fourier Neural Operator for volumetric (3D space + time) problems. + + Input: (B, X, Y, Z, T, C) + Output: (B, X, Y, Z, T, out_channels) + + Architecture: + - Lifting network (ConvNdFCLayer) + - num_fno_layers Fourier layers (SpectralConv4d + ConvNdKernel1Layer) + - Decoder network (ConvNdFCLayer) + + Note: Only pure FNO mode is supported for 4D. U-Net and Conv skip + connections are not available because PyTorch has no native Conv4d. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + width : int + Latent channel dimension + modes1, modes2, modes3, modes4 : int + Number of Fourier modes in each dimension (X, Y, Z, T) + num_fno_layers : int + Number of Fourier layers + activation_fn : str + Activation function name + lifting_layers : int + Number of layers in lifting network + decoder_layers : int + Number of hidden layers in decoder + decoder_width : int + Hidden layer size in decoder + coord_features : bool + Whether to add coordinate features (x, y, z, t) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + width: int = 32, + modes1: int = 8, + modes2: int = 8, + modes3: int = 6, + modes4: int = 6, + num_fno_layers: int = 4, + activation_fn: str = "gelu", + lifting_layers: int = 2, + decoder_layers: int = 1, + decoder_width: int = 128, + coord_features: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.width = width + self.modes1 = modes1 + self.modes2 = modes2 + self.modes3 = modes3 + self.modes4 = modes4 + self.num_fno_layers = num_fno_layers + self.coord_features = coord_features + self.activation_fn_name = activation_fn + self.activation_fn = get_activation(activation_fn) + + # Coordinate features add 4 channels (x, y, z, t) + lift_in_channels = in_channels + 4 if coord_features else in_channels + + # Lifting network using ConvNdFCLayer (supports arbitrary dimensions) + self.lift_network = self._build_lifting_network( + lift_in_channels, width, lifting_layers + ) + + # Fourier layers: SpectralConv4d + ConvNdKernel1Layer + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + + for _ in range(num_fno_layers): + self.spectral_convs.append( + SpectralConv4d(self.width, self.width, modes1, modes2, modes3, modes4) + ) + self.conv_1x1s.append(ConvNdKernel1Layer(self.width, self.width)) + + # Decoder network using ConvNdFCLayer + self.decoder = self._build_decoder_network( + width, out_channels, decoder_layers, decoder_width + ) + + def _build_lifting_network( + self, in_channels: int, width: int, num_layers: int + ) -> nn.Module: + """Build lifting network using ConvNdFCLayer.""" + if num_layers == 1: + return ConvNdFCLayer(in_channels, width) + else: + layers_list = [] + hidden_width = width // 2 + layers_list.append(ConvNdFCLayer(in_channels, hidden_width)) + layers_list.append(get_activation(self.activation_fn_name)) + for _ in range(num_layers - 2): + layers_list.append(ConvNdFCLayer(hidden_width, hidden_width)) + layers_list.append(get_activation(self.activation_fn_name)) + layers_list.append(ConvNdFCLayer(hidden_width, width)) + return nn.Sequential(*layers_list) + + def _build_decoder_network( + self, width: int, out_channels: int, num_layers: int, hidden_width: int + ) -> nn.Module: + """Build decoder network using ConvNdFCLayer.""" + if num_layers == 0: + return ConvNdFCLayer(width, out_channels) + else: + layers_list = [] + in_ch = width + for _ in range(num_layers): + layers_list.append(ConvNdFCLayer(in_ch, hidden_width)) + layers_list.append(get_activation(self.activation_fn_name)) + in_ch = hidden_width + layers_list.append(ConvNdFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers_list) + + def _create_meshgrid(self, shape: list, device: torch.device) -> Tensor: + """Create 4D coordinate meshgrid (x, y, z, t) normalized to [0, 1].""" + bsize = shape[0] + size_x, size_y, size_z, size_t = shape[2], shape[3], shape[4], shape[5] + + grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device) + grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device) + grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device) + grid_t = torch.linspace(0, 1, size_t, dtype=torch.float32, device=device) + + grid_x, grid_y, grid_z, grid_t = torch.meshgrid( + grid_x, grid_y, grid_z, grid_t, indexing="ij" + ) + + grid_x = grid_x.unsqueeze(0).unsqueeze(0).expand(bsize, 1, -1, -1, -1, -1) + grid_y = grid_y.unsqueeze(0).unsqueeze(0).expand(bsize, 1, -1, -1, -1, -1) + grid_z = grid_z.unsqueeze(0).unsqueeze(0).expand(bsize, 1, -1, -1, -1, -1) + grid_t = grid_t.unsqueeze(0).unsqueeze(0).expand(bsize, 1, -1, -1, -1, -1) + + return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through FNO4D. + + Input: (B, X, Y, Z, T, C) + Output: (B, X, Y, Z, T, out_channels) + """ + # Convert to channel-first: (B, C, X, Y, Z, T) + x = x.permute(0, 5, 1, 2, 3, 4) + + # Add coordinate features + if self.coord_features: + coord_feat = self._create_meshgrid(list(x.shape), x.device) + x = torch.cat((x, coord_feat), dim=1) + + # Lifting + x = self.lift_network(x) + + # Fourier layers + for layer_idx in range(self.num_fno_layers): + x1 = self.spectral_convs[layer_idx](x) + x2 = self.conv_1x1s[layer_idx](x) + if layer_idx < self.num_fno_layers - 1: + x = self.activation_fn(x1 + x2) + else: + x = x1 + x2 + + # Decoder + x = self.decoder(x) + + # Convert to channel-last: (B, X, Y, Z, T, out_channels) + x = x.permute(0, 2, 3, 4, 5, 1) + return x + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class FNO4DNet(nn.Module): + """Wrapper for FNO4D that handles padding/de-padding. + + Input: (B, X, Y, Z, T, C) + Output: (B, X, Y, Z, T) + + Parameters + ---------- + modes1, modes2, modes3, modes4 : int + Number of Fourier modes in each dimension + width : int + Latent channel dimension + in_channels : int + Number of input channels + out_channels : int + Number of output channels + num_fno_layers : int + Number of Fourier layers + padding : int or list + Padding for each dimension (X, Y, Z, T) + **kwargs + Additional arguments passed to FNO4D + """ + + def __init__( + self, + modes1: int, + modes2: int, + modes3: int, + modes4: int, + width: int, + in_channels: int = 11, + out_channels: int = 1, + num_fno_layers: int = 4, + padding: int = 8, + **kwargs, + ): + super(FNO4DNet, self).__init__() + + # Store padding for each dimension (X, Y, Z, T) + if isinstance(padding, int): + self.padding = [padding, padding, padding, padding] + else: + self.padding = list(padding) + [0] * (4 - len(padding)) + self.padding = self.padding[:4] + + self.time_modes = modes4 + + self.fno4d = FNO4D( + modes1=modes1, + modes2=modes2, + modes3=modes3, + modes4=modes4, + width=width, + in_channels=in_channels, + out_channels=out_channels, + num_fno_layers=num_fno_layers, + **kwargs, + ) + + def forward( + self, + x: Tensor, + target_times: Tensor = None, + ) -> Tensor: + """Forward pass with padding/de-padding. + + Parameters + ---------- + x : Tensor + Input ``(B, X, Y, Z, T_in, C)``. + target_times : Tensor, optional + Explicit target time coordinates ``(K,)`` or ``(K, 1)``. + When provided and K != T_in, the time axis is padded so the + FNO operates on at least L+K timesteps, and the output is + cropped to the last K timesteps. + + Returns + ------- + Tensor ``(B, X, Y, Z, T_out)`` where T_out = K if target_times given, + else T_in. + """ + x0, y0, z0, t_in = x.shape[1], x.shape[2], x.shape[3], x.shape[4] + + K = target_times.shape[0] if target_times is not None else None + + if K is not None and K != t_in: + desired_t = t_in + K + min_t = max(desired_t, 2 * self.time_modes) + extra = min_t - t_in + x = pad_spatial_right( + x, + spatial_ndim=4, + right_pad=(0, 0, 0, extra), + mode="replicate", + ) + t_padded = x.shape[4] + else: + K = None + t_padded = t_in + + pad_x, pad_y, pad_z, pad_t = compute_right_pad_to_multiple_per_dim( + (x0, y0, z0, t_padded), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, + spatial_ndim=4, + right_pad=(pad_x, pad_y, pad_z, pad_t), + mode="replicate", + ) + + x = self.fno4d(x) + + if K is not None: + x = x[:, :x0, :y0, :z0, t_in : t_in + K, :] + else: + x = x[:, :x0, :y0, :z0, :t_in, :] + + return x.squeeze(-1) + + def count_params(self) -> int: + """Count total number of trainable parameters.""" + return self.fno4d.count_params() diff --git a/examples/reservoir_simulation/neural_operator_factory/requirements.txt b/examples/reservoir_simulation/neural_operator_factory/requirements.txt new file mode 100644 index 0000000000..b702446e02 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/requirements.txt @@ -0,0 +1,26 @@ +# Requirements for Neural Operator Factory +# Install with: pip install -r requirements.txt + +# Core dependencies +torch>=2.0.0 +numpy>=1.21.0 +hydra-core>=1.3.0 +omegaconf>=2.3.0 + +# PhysicsNeMo - install from source +# git clone https://github.com/NVIDIA/physicsnemo.git +# cd physicsnemo && pip install -e . +# Provides: DistributedManager, Module, launch/logging utils, +# SpectralConv2d/3d, FullyConnected, and GNN layers. + +# Optional: Experiment tracking (used when logging.use_mlflow: true) +# mlflow>=2.0.0 + +# Optional: Evaluation visualization (used by evaluate_*.py scripts) +# matplotlib>=3.5.0 + +# Testing +# pytest>=7.0.0 + +# Note: CUDA/cuDNN should be installed separately based on your system. +# This code is tested with CUDA 12.2 and H100 GPUs. diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/__init__.py b/examples/reservoir_simulation/neural_operator_factory/tests/__init__.py new file mode 100644 index 0000000000..36813458d8 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the neural operator factory.""" diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/conftest.py b/examples/reservoir_simulation/neural_operator_factory/tests/conftest.py new file mode 100644 index 0000000000..c0f2a97d85 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/conftest.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared pytest fixtures for neural operator factory tests.""" + +import sys +from pathlib import Path + +import pytest +import torch + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +@pytest.fixture(scope="session") +def device(): + """Return available device for the test session.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + +@pytest.fixture +def random_seed(): + """Set random seed for reproducibility.""" + torch.manual_seed(42) + return 42 + + +@pytest.fixture +def sample_input_tensor(device): + """Create a sample input tensor for reservoir simulation models.""" + # Shape: (batch, H, W, T, channels) + # Using smaller dimensions for faster tests + return torch.randn(2, 32, 64, 16, 12).to(device) + + +@pytest.fixture +def sample_target_tensor(device): + """Create a sample target tensor.""" + # Shape: (batch, H, W, T) + return torch.randn(2, 32, 64, 16).to(device) + + +@pytest.fixture +def sample_inputs_with_grid(device): + """Create sample inputs with grid coordinates for loss functions.""" + B, H, W, T, C = 2, 32, 64, 16, 12 + inputs = torch.randn(B, H, W, T, C).to(device) + + # Set grid_x channel (channel -3) with increasing values + grid_x = torch.linspace(0, 100, W).to(device) + inputs[..., -3] = grid_x.view(1, 1, W, 1).expand(B, H, W, T) + + # Set grid_y channel (channel -2) with increasing values + grid_y = torch.linspace(0, 50, H).to(device) + inputs[..., -2] = grid_y.view(1, H, 1, 1).expand(B, H, W, T) + + # Set time channel (channel -1) with increasing values + grid_t = torch.linspace(0, 30, T).to(device) + inputs[..., -1] = grid_t.view(1, 1, 1, T).expand(B, H, W, T) + + return inputs + + +def pytest_configure(config): + """Configure pytest with custom markers.""" + config.addinivalue_line("markers", "slow: mark test as slow to run") + config.addinivalue_line("markers", "gpu: mark test as requiring GPU") + + +def pytest_collection_modifyitems(config, items): + """Skip GPU tests if no GPU is available.""" + if not torch.cuda.is_available(): + skip_gpu = pytest.mark.skip(reason="CUDA not available") + for item in items: + if "gpu" in item.keywords: + item.add_marker(skip_gpu) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_ar_utils.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_ar_utils.py new file mode 100644 index 0000000000..3ba4fadb4c --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_ar_utils.py @@ -0,0 +1,1255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Comprehensive unit tests for autoregressive training utilities.""" + +import sys +from pathlib import Path + +import torch +import torch.nn as nn + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from training.ar_utils import ( + _build_branch2, + _iter_windows, + _model_accepts_target_times, + _model_accepts_x_branch2, + _time_axis_input, + _time_axis_target, + add_noise, + ar_validate_full_rollout, + compute_unroll_steps, + extract_target_times, + get_training_stage, + inject_feedback_channel, + live_rollout_step, + rollout_step, + slice_input_window, + slice_target_window, + teacher_forcing_step, +) + +# --------------------------------------------------------------------------- +# Dummy models +# --------------------------------------------------------------------------- + + +class DummyModel3D(nn.Module): + """Returns zeros; output T = target_times length if given, else T_in.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x, target_times=None): + B, H, W, T_in, C = x.shape + T_out = target_times.shape[0] if target_times is not None else T_in + return ( + torch.zeros(B, H, W, T_out, device=x.device) + self.linear.weight.sum() * 0 + ) + + +class DummyModel4D(nn.Module): + """Returns zeros; output T = target_times length if given, else T_in.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x, target_times=None): + B, X, Y, Z, T_in, C = x.shape + T_out = target_times.shape[0] if target_times is not None else T_in + return ( + torch.zeros(B, X, Y, Z, T_out, device=x.device) + + self.linear.weight.sum() * 0 + ) + + +class DummyModelNoTargetTimes(nn.Module): + """Model that does NOT accept target_times (e.g. FNO).""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x): + B, H, W, T_in, C = x.shape + return ( + torch.zeros(B, H, W, T_in, device=x.device) + self.linear.weight.sum() * 0 + ) + + +class DummyTNOModel3D(nn.Module): + """TNO-style 3D model: accepts x_branch2.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x, target_times=None, x_branch2=None): + B, H, W, T_in, C = x.shape + T_out = target_times.shape[0] if target_times is not None else T_in + return ( + torch.zeros(B, H, W, T_out, device=x.device) + self.linear.weight.sum() * 0 + ) + + +class DummyTNOModel4D(nn.Module): + """TNO-style 4D model: accepts x_branch2.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x, target_times=None, x_branch2=None): + B, X, Y, Z, T_in, C = x.shape + T_out = target_times.shape[0] if target_times is not None else T_in + return ( + torch.zeros(B, X, Y, Z, T_out, device=x.device) + + self.linear.weight.sum() * 0 + ) + + +class DummyFeedbackModel3D(nn.Module): + """Asserts C == base_channels + 1 (feedback channel present).""" + + def __init__(self, base_channels=5): + super().__init__() + self.base_channels = base_channels + self.linear = nn.Linear(1, 1) + + def forward(self, x, target_times=None): + B, H, W, T_in, C = x.shape + assert C == self.base_channels + 1, ( + f"Expected {self.base_channels + 1} channels, got {C}" + ) + T_out = target_times.shape[0] if target_times is not None else T_in + return ( + torch.zeros(B, H, W, T_out, device=x.device) + self.linear.weight.sum() * 0 + ) + + +class DummyIdentityModel3D(nn.Module): + """Returns scale * ones; useful for gradient flow tests.""" + + def __init__(self): + super().__init__() + self.scale = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x, target_times=None): + B, H, W, T_in, C = x.shape + T_out = target_times.shape[0] if target_times is not None else T_in + return self.scale * torch.ones(B, H, W, T_out, device=x.device) + + +def dummy_loss(pred, target, inputs, spatial_mask=None): + """Simple MSE loss for testing.""" + return torch.mean((pred - target) ** 2) + + +# --------------------------------------------------------------------------- +# Tests: _time_axis_input, _time_axis_target +# --------------------------------------------------------------------------- + + +class TestTimeAxisHelpers: + """Tests for _time_axis_input and _time_axis_target.""" + + def test_input_3d(self): + """Verify time axis is 3 for 5D (3D-spatial) input.""" + assert _time_axis_input(torch.randn(2, 8, 10, 24, 5)) == 3 + + def test_input_4d(self): + """Verify time axis is 4 for 6D (4D-spatial) input.""" + assert _time_axis_input(torch.randn(2, 8, 10, 6, 24, 5)) == 4 + + def test_target_3d_single_output(self): + """Verify time axis is 3 for 4D single-output target.""" + assert _time_axis_target(torch.randn(2, 8, 10, 24)) == 3 + + def test_target_4d_single_output(self): + """Verify time axis is 4 for 5D single-output target.""" + assert _time_axis_target(torch.randn(2, 8, 10, 6, 24)) == 4 + + def test_input_multi_output(self): + """Multi-output input has same layout -- extra channels, same ndim.""" + assert _time_axis_input(torch.randn(2, 8, 10, 24, 11)) == 3 + + def test_target_multi_output(self): + """5D target treated as 4D-spatial (B, X, Y, Z, T); time axis is last.""" + assert _time_axis_target(torch.randn(2, 8, 10, 6, 24)) == 4 + + +# --------------------------------------------------------------------------- +# Tests: slice_input_window, slice_target_window +# --------------------------------------------------------------------------- + + +class TestSlicing: + """Tests for slice_input_window and slice_target_window.""" + + def test_slice_input_3d(self): + """Verify input window slicing for 3D data.""" + x = torch.randn(2, 8, 10, 24, 5) + w = slice_input_window(x, t0=3, width=4) + assert w.shape == (2, 8, 10, 4, 5) + assert torch.equal(w, x[:, :, :, 3:7, :]) + + def test_slice_input_4d(self): + """Verify input window slicing for 4D data.""" + x = torch.randn(2, 8, 10, 6, 24, 5) + w = slice_input_window(x, t0=5, width=3) + assert w.shape == (2, 8, 10, 6, 3, 5) + assert torch.equal(w, x[:, :, :, :, 5:8, :]) + + def test_slice_target_3d(self): + """Verify target window slicing for 3D data.""" + y = torch.randn(2, 8, 10, 24) + w = slice_target_window(y, t0=10, width=5) + assert w.shape == (2, 8, 10, 5) + assert torch.equal(w, y[:, :, :, 10:15]) + + def test_slice_target_4d(self): + """Verify target window slicing for 4D data.""" + y = torch.randn(2, 8, 10, 6, 24) + w = slice_target_window(y, t0=0, width=3) + assert w.shape == (2, 8, 10, 6, 3) + assert torch.equal(w, y[:, :, :, :, 0:3]) + + def test_slice_input_many_channels(self): + """Slicing works regardless of channel count.""" + x = torch.randn(2, 8, 10, 24, 11) + w = slice_input_window(x, t0=0, width=6) + assert w.shape == (2, 8, 10, 6, 11) + + def test_slice_target_4d_spatial(self): + """4D spatial target (B, X, Y, Z, T): slice along last dim.""" + y = torch.randn(2, 8, 10, 6, 24) + w = slice_target_window(y, t0=5, width=4) + assert w.shape == (2, 8, 10, 6, 4) + assert torch.equal(w, y[:, :, :, :, 5:9]) + + +# --------------------------------------------------------------------------- +# Tests: extract_target_times +# --------------------------------------------------------------------------- + + +class TestExtractTargetTimes: + """Tests for extract_target_times.""" + + def test_3d_extracts_last_channel(self): + """Verify target times extracted from last channel for 3D data.""" + x = torch.randn(2, 4, 6, 16, 5) + times = extract_target_times(x, t_start=5, K=3) + assert times.shape == (3,) + assert torch.equal(times, x[0, 0, 0, 5:8, -1]) + + def test_4d_extracts_last_channel(self): + """Verify target times extracted from last channel for 4D data.""" + x = torch.randn(2, 4, 6, 3, 20, 11) + times = extract_target_times(x, t_start=10, K=5) + assert times.shape == (5,) + assert torch.equal(times, x[0, 0, 0, 0, 10:15, -1]) + + def test_k_equals_1(self): + """Verify extraction works with K=1.""" + x = torch.randn(1, 4, 6, 16, 5) + times = extract_target_times(x, t_start=0, K=1) + assert times.shape == (1,) + + +# --------------------------------------------------------------------------- +# Tests: inject_feedback_channel +# --------------------------------------------------------------------------- + + +class TestInjectFeedbackChannel: + """Tests for inject_feedback_channel.""" + + def test_none_feedback_returns_unchanged(self): + """None feedback returns input window unchanged.""" + x = torch.randn(2, 4, 6, 3, 5) + result = inject_feedback_channel(x, None) + assert result.shape == (2, 4, 6, 3, 5) + assert torch.equal(result, x) + + def test_single_output_feedback(self): + """Single-output feedback (B, *spatial, T) is unsqueezed and concatenated.""" + x = torch.randn(2, 4, 6, 3, 5) + fb = torch.randn(2, 4, 6, 3) + result = inject_feedback_channel(x, fb) + assert result.shape == (2, 4, 6, 3, 6) + assert torch.equal(result[:, :, :, :, :5], x) + assert torch.equal(result[:, :, :, :, 5], fb) + + def test_multi_output_feedback(self): + """Multi-output feedback (B, *spatial, T, C_out) concatenated directly.""" + x = torch.randn(2, 4, 6, 3, 5) + fb = torch.randn(2, 4, 6, 3, 2) + result = inject_feedback_channel(x, fb) + assert result.shape == (2, 4, 6, 3, 7) + + +# --------------------------------------------------------------------------- +# Tests: add_noise +# --------------------------------------------------------------------------- + + +class TestAddNoise: + """Tests for add_noise.""" + + def test_zero_std(self): + """noise_std == 0 returns tensor unchanged.""" + t = torch.randn(3, 4) + result = add_noise(t, 0.0) + assert torch.equal(result, t) + + def test_positive_std(self): + """Positive noise_std adds noise.""" + t = torch.randn(3, 4) + result = add_noise(t, 0.1) + assert not torch.equal(result, t) + assert torch.allclose(result, t, atol=1.0) + + def test_negative_std(self): + """Negative noise_std treated as disabled.""" + t = torch.randn(3, 4) + result = add_noise(t, -0.5) + assert torch.equal(result, t) + + +# --------------------------------------------------------------------------- +# Tests: compute_unroll_steps +# --------------------------------------------------------------------------- + + +class TestComputeUnrollSteps: + """Tests for compute_unroll_steps.""" + + def test_start(self): + """At start_epoch, returns 1.""" + result = compute_unroll_steps( + epoch=10, start_epoch=10, total_epochs=100, max_unroll=10 + ) + assert result == 1 + + def test_end(self): + """At start + total, returns max_unroll.""" + result = compute_unroll_steps( + epoch=110, start_epoch=10, total_epochs=100, max_unroll=10 + ) + assert result == 10 + + def test_midpoint(self): + """Midpoint returns approximately half of max_unroll.""" + result = compute_unroll_steps( + epoch=60, start_epoch=10, total_epochs=100, max_unroll=10 + ) + assert 4 <= result <= 6 + + def test_zero_stage(self): + """total_epochs == 0 returns max_unroll immediately.""" + result = compute_unroll_steps( + epoch=0, start_epoch=0, total_epochs=0, max_unroll=10 + ) + assert result == 10 + + def test_beyond_end(self): + """Epoch past the end clamps to max_unroll.""" + result = compute_unroll_steps( + epoch=500, start_epoch=10, total_epochs=100, max_unroll=10 + ) + assert result == 10 + + def test_curriculum_end_exact(self): + """epoch=121, start=21, total=100 gives exactly max_unroll.""" + result = compute_unroll_steps( + epoch=121, start_epoch=21, total_epochs=100, max_unroll=10 + ) + assert result == 10 + + +# --------------------------------------------------------------------------- +# Tests: _iter_windows +# --------------------------------------------------------------------------- + + +class TestIterWindows: + """Tests for _iter_windows.""" + + def test_non_overlapping(self): + """stride=K produces non-overlapping windows.""" + windows = list(_iter_windows(total_T=16, L=1, K=3, stride=3)) + assert len(windows) == 5 + for t0, ts, ak in windows: + assert ts == t0 + 1 + assert ak == 3 + + def test_stride_1(self): + """stride=1 produces overlapping windows with truncation at the end.""" + windows = list(_iter_windows(total_T=6, L=1, K=3, stride=1)) + assert len(windows) == 5 + assert windows[0] == (0, 1, 3) + assert windows[-1] == (4, 5, 1) + + def test_stride_equals_K(self): + """Explicit stride=K same as non-overlapping.""" + w1 = list(_iter_windows(total_T=16, L=1, K=3, stride=3)) + w2 = list(_iter_windows(total_T=16, L=1, K=3, stride=3)) + assert w1 == w2 + + def test_last_window_truncated(self): + """Remaining timesteps < K produces a truncated window.""" + windows = list(_iter_windows(total_T=10, L=1, K=4, stride=4)) + assert len(windows) == 3 + assert windows[-1][2] == 1 + + +# --------------------------------------------------------------------------- +# Tests: get_training_stage +# --------------------------------------------------------------------------- + + +class TestGetTrainingStage: + """Tests for get_training_stage.""" + + def test_teacher_forcing_stage(self): + """Verify early epoch maps to teacher forcing stage.""" + assert ( + get_training_stage(epoch=5, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "teacher_forcing" + ) + + def test_pushforward_stage(self): + """Verify mid-range epoch maps to pushforward stage.""" + assert ( + get_training_stage(epoch=25, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "pushforward" + ) + + def test_rollout_stage(self): + """Verify late epoch maps to rollout stage.""" + assert ( + get_training_stage(epoch=55, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "rollout" + ) + + def test_no_pushforward(self): + """pf_epochs=0 jumps directly from teacher forcing to rollout.""" + assert ( + get_training_stage(epoch=25, tf_epochs=20, pf_epochs=0, ro_epochs=50) + == "rollout" + ) + + def test_no_rollout(self): + """ro_epochs=0 with large pf_epochs keeps epoch in pushforward.""" + assert ( + get_training_stage(epoch=500, tf_epochs=20, pf_epochs=10000, ro_epochs=0) + == "pushforward" + ) + + def test_tf_pf_boundary(self): + """Exact transition from teacher forcing to pushforward.""" + assert ( + get_training_stage(epoch=19, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "teacher_forcing" + ) + assert ( + get_training_stage(epoch=20, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "pushforward" + ) + + def test_pf_rollout_boundary(self): + """Exact transition from pushforward to rollout.""" + assert ( + get_training_stage(epoch=49, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "pushforward" + ) + assert ( + get_training_stage(epoch=50, tf_epochs=20, pf_epochs=30, ro_epochs=50) + == "rollout" + ) + + +# --------------------------------------------------------------------------- +# Tests: _model_accepts_target_times, _model_accepts_x_branch2 +# --------------------------------------------------------------------------- + + +class TestModelIntrospection: + """Tests for _model_accepts_target_times and _model_accepts_x_branch2.""" + + def test_model_with_target_times(self): + """Verify model with target_times parameter is detected.""" + assert _model_accepts_target_times(DummyModel3D()) is True + + def test_model_without_target_times(self): + """Verify model without target_times parameter returns False.""" + assert _model_accepts_target_times(DummyModelNoTargetTimes()) is False + + def test_tno_accepts_x_branch2(self): + """Verify TNO model with x_branch2 parameter is detected.""" + assert _model_accepts_x_branch2(DummyTNOModel3D()) is True + + def test_standard_model_no_x_branch2(self): + """Verify standard model without x_branch2 returns False.""" + assert _model_accepts_x_branch2(DummyModel3D()) is False + + def test_no_target_times_no_x_branch2(self): + """Verify model with neither target_times nor x_branch2 returns False.""" + assert _model_accepts_x_branch2(DummyModelNoTargetTimes()) is False + + +# --------------------------------------------------------------------------- +# Tests: teacher_forcing_step +# --------------------------------------------------------------------------- + + +class TestTeacherForcing: + """Tests for teacher_forcing_step.""" + + def test_3d_returns_scalar_loss(self): + """Verify teacher forcing returns a scalar loss for 3D data.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step(model, inputs, targets, dummy_loss, L=1, K=3) + assert loss.dim() == 0 + + def test_4d_returns_scalar_loss(self): + """Verify teacher forcing returns a scalar loss for 4D data.""" + model = DummyModel4D() + inputs = torch.randn(1, 4, 6, 3, 16, 5) + targets = torch.randn(1, 4, 6, 3, 16) + loss = teacher_forcing_step(model, inputs, targets, dummy_loss, L=1, K=3) + assert loss.dim() == 0 + + def test_k_equals_1(self): + """Verify teacher forcing works with K=1.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step(model, inputs, targets, dummy_loss, L=1, K=1) + assert loss.dim() == 0 + + def test_stride_1(self): + """stride=1 processes overlapping windows without error.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + stride=1, + ) + assert loss.dim() == 0 + + def test_tno(self): + """Verify teacher forcing runs with TNO model.""" + model = DummyTNOModel3D() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + is_tno=True, + ) + assert loss.dim() == 0 + + def test_noise(self): + """Non-zero noise_std does not crash.""" + model = DummyTNOModel3D() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + is_tno=True, + noise_std=0.1, + ) + assert loss.dim() == 0 + + def test_feedback_channel(self): + """feedback_channel injects an extra channel from GT target.""" + model = DummyFeedbackModel3D(base_channels=5) + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + feedback_channel=0, + ) + assert loss.dim() == 0 + + def test_no_target_times_model(self): + """Model without target_times still works (K must equal L).""" + model = DummyModelNoTargetTimes() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=1, + ) + assert loss.dim() == 0 + + +# --------------------------------------------------------------------------- +# Tests: live_rollout_step (pushforward / live-gradient rollout) +# --------------------------------------------------------------------------- + + +class TestLiveRolloutStep: + """Tests for live_rollout_step.""" + + def test_3d(self): + """Verify live rollout returns a scalar loss for 3D data.""" + model = DummyModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=2, + ) + assert loss.dim() == 0 + + def test_4d(self): + """Verify live rollout returns a scalar loss for 4D data.""" + model = DummyModel4D() + model.train() + inputs = torch.randn(1, 4, 6, 3, 16, 5) + targets = torch.randn(1, 4, 6, 3, 16) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=2, + ) + assert loss.dim() == 0 + + def test_tno_live_gradients(self): + """TNO pushforward runs without error.""" + model = DummyTNOModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=2, + is_tno=True, + ) + assert loss.dim() == 0 + + def test_unroll_1(self): + """max_steps=1 processes one step per group.""" + model = DummyModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=1, + ) + assert loss.dim() == 0 + + def test_large_unroll(self): + """Large max_steps covers all windows in a single group.""" + model = DummyModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=100, + ) + assert loss.dim() == 0 + + def test_feedback_channel(self): + """Feedback channel with pushforward.""" + model = DummyFeedbackModel3D(base_channels=5) + model.train() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=2, + feedback_channel=0, + ) + assert loss.dim() == 0 + + def test_gradient_flow(self): + """live_rollout_step calls backward internally; model params get grads.""" + model = DummyIdentityModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 16, 5) + targets = torch.randn(2, 4, 6, 16) + model.zero_grad() + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + max_steps=3, + ) + assert loss.dim() == 0 + assert not loss.requires_grad + assert model.scale.grad is not None + assert model.scale.grad.abs().item() > 0 + + +# --------------------------------------------------------------------------- +# Tests: rollout_step +# --------------------------------------------------------------------------- + + +class TestRolloutStep: + """Tests for rollout_step.""" + + def test_3d_returns_scalar_loss(self): + """Verify rollout step returns a scalar loss for 3D data.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + use_checkpointing=False, + ) + assert loss.dim() == 0 + + def test_4d_with_checkpointing(self): + """Verify rollout step with gradient checkpointing for 4D data.""" + model = DummyModel4D() + inputs = torch.randn(1, 4, 6, 3, 20, 5) + targets = torch.randn(1, 4, 6, 3, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + use_checkpointing=True, + ) + assert loss.dim() == 0 + + def test_tno(self): + """Verify rollout step runs with TNO model.""" + model = DummyTNOModel3D() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + use_checkpointing=False, + is_tno=True, + ) + assert loss.dim() == 0 + + def test_feedback_channel(self): + """Feedback channel with rollout.""" + model = DummyFeedbackModel3D(base_channels=5) + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + use_checkpointing=False, + feedback_channel=0, + ) + assert loss.dim() == 0 + + def test_noise(self): + """noise_std > 0 does not crash.""" + model = DummyTNOModel3D() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + use_checkpointing=False, + is_tno=True, + noise_std=0.05, + ) + assert loss.dim() == 0 + + def test_stride_1(self): + """stride=1 with rollout processes overlapping windows.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 12, 5) + targets = torch.randn(2, 4, 6, 12) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + stride=1, + use_checkpointing=False, + ) + assert loss.dim() == 0 + + +# --------------------------------------------------------------------------- +# Tests: ar_validate_full_rollout +# --------------------------------------------------------------------------- + + +class TestFullRollout: + """Tests for ar_validate_full_rollout.""" + + def test_3d_output_shape(self): + """Verify full rollout output shape matches target for 3D data.""" + model = DummyModel3D() + inputs = torch.randn(1, 4, 6, 16, 5) + targets = torch.randn(1, 4, 6, 16) + pred = ar_validate_full_rollout(model, inputs, targets, L=1, K=3) + assert pred.shape == targets.shape + + def test_4d_output_shape(self): + """Verify full rollout output shape matches target for 4D data.""" + model = DummyModel4D() + inputs = torch.randn(1, 4, 6, 3, 20, 5) + targets = torch.randn(1, 4, 6, 3, 20) + pred = ar_validate_full_rollout(model, inputs, targets, L=1, K=3) + assert pred.shape == targets.shape + + def test_prefix_matches_gt(self): + """First L timesteps are copied from ground truth.""" + model = DummyModel3D() + inputs = torch.randn(1, 4, 6, 16, 5) + targets = torch.randn(1, 4, 6, 16) + pred = ar_validate_full_rollout(model, inputs, targets, L=1, K=3) + assert torch.equal(pred[:, :, :, :1], targets[:, :, :, :1]) + + def test_tno_3d(self): + """Verify full rollout with TNO model for 3D data.""" + model = DummyTNOModel3D() + inputs = torch.randn(1, 4, 6, 16, 5) + targets = torch.randn(1, 4, 6, 16) + pred = ar_validate_full_rollout( + model, + inputs, + targets, + L=1, + K=3, + is_tno=True, + ) + assert pred.shape == targets.shape + + def test_tno_4d(self): + """Verify full rollout with TNO model for 4D data.""" + model = DummyTNOModel4D() + inputs = torch.randn(1, 4, 6, 3, 16, 5) + targets = torch.randn(1, 4, 6, 3, 16) + pred = ar_validate_full_rollout( + model, + inputs, + targets, + L=1, + K=3, + is_tno=True, + ) + assert pred.shape == targets.shape + + def test_k_equals_1(self): + """Verify full rollout works with K=1.""" + model = DummyModel3D() + inputs = torch.randn(1, 4, 6, 10, 5) + targets = torch.randn(1, 4, 6, 10) + pred = ar_validate_full_rollout(model, inputs, targets, L=1, K=1) + assert pred.shape == targets.shape + + def test_feedback_channel(self): + """Feedback channel in full rollout validation.""" + model = DummyFeedbackModel3D(base_channels=5) + inputs = torch.randn(1, 4, 6, 16, 5) + targets = torch.randn(1, 4, 6, 16) + pred = ar_validate_full_rollout( + model, + inputs, + targets, + L=1, + K=3, + feedback_channel=0, + ) + assert pred.shape == targets.shape + + +# --------------------------------------------------------------------------- +# Tests: _build_branch2 +# --------------------------------------------------------------------------- + + +class TestBuildBranch2: + """Tests for _build_branch2.""" + + def test_non_tno_returns_none(self): + """Verify _build_branch2 returns None for non-TNO models.""" + targets = torch.randn(2, 4, 6, 16) + t_ax = _time_axis_target(targets) + result = _build_branch2(targets, None, 0, 1, t_ax, is_tno=False) + assert result is None + + def test_first_window_uses_gt(self): + """prev_pred=None returns GT slice.""" + targets = torch.randn(2, 4, 6, 16) + t_ax = _time_axis_target(targets) + result = _build_branch2(targets, None, 0, 3, t_ax, is_tno=True) + expected = slice_target_window(targets, 0, 3) + assert torch.equal(result, expected) + + def test_subsequent_uses_prev_pred(self): + """With prev_pred, uses the prediction instead of GT.""" + targets = torch.randn(2, 4, 6, 16) + prev_pred = torch.randn(2, 4, 6, 3) + t_ax = _time_axis_target(targets) + result = _build_branch2(targets, prev_pred, 3, 3, t_ax, is_tno=True) + assert torch.equal(result, prev_pred) + + def test_noise_applied(self): + """noise_std > 0 perturbs the branch2 tensor (prev_pred path only).""" + targets = torch.randn(2, 4, 6, 16) + prev_pred = torch.randn(2, 4, 6, 3) + t_ax = _time_axis_target(targets) + b2_clean = _build_branch2( + targets, + prev_pred, + 3, + 3, + t_ax, + is_tno=True, + noise_std=0.0, + ) + b2_noisy = _build_branch2( + targets, + prev_pred, + 3, + 3, + t_ax, + is_tno=True, + noise_std=0.1, + ) + assert torch.equal(b2_clean, prev_pred) + assert not torch.equal(b2_noisy, prev_pred) + + +# --------------------------------------------------------------------------- +# Tests: L=2 K=4 configurations +# --------------------------------------------------------------------------- + + +class TestL2K4Configs: + """Tests for L=2, K=4 configuration.""" + + def test_teacher_forcing(self): + """Verify teacher forcing works with L=2, K=4 configuration.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=2, + K=4, + ) + assert loss.dim() == 0 + + def test_pushforward(self): + """Verify pushforward works with L=2, K=4 configuration.""" + model = DummyModel3D() + model.train() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = live_rollout_step( + model, + inputs, + targets, + dummy_loss, + L=2, + K=4, + max_steps=3, + ) + assert loss.dim() == 0 + + def test_rollout(self): + """Verify rollout works with L=2, K=4 configuration.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 20, 5) + targets = torch.randn(2, 4, 6, 20) + loss = rollout_step( + model, + inputs, + targets, + dummy_loss, + L=2, + K=4, + use_checkpointing=False, + ) + assert loss.dim() == 0 + + def test_full_rollout(self): + """Verify full rollout works with L=2, K=4 configuration.""" + model = DummyModel3D() + inputs = torch.randn(1, 4, 6, 20, 5) + targets = torch.randn(1, 4, 6, 20) + pred = ar_validate_full_rollout(model, inputs, targets, L=2, K=4) + assert pred.shape == targets.shape + + +# --------------------------------------------------------------------------- +# Tests: edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Edge cases: trajectory too short, single window, empty rollout.""" + + def test_trajectory_too_short(self): + """total_T <= L yields no windows and zero loss.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 1, 5) + targets = torch.randn(2, 4, 6, 1) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + ) + assert loss.item() == 0.0 + + def test_single_window(self): + """total_T = L + K produces exactly one window.""" + model = DummyModel3D() + inputs = torch.randn(2, 4, 6, 4, 5) + targets = torch.randn(2, 4, 6, 4) + loss = teacher_forcing_step( + model, + inputs, + targets, + dummy_loss, + L=1, + K=3, + ) + assert loss.dim() == 0 + + def test_empty_rollout(self): + """Full rollout with too-short trajectory returns zeros_like.""" + model = DummyModel3D() + inputs = torch.randn(1, 4, 6, 1, 5) + targets = torch.randn(1, 4, 6, 1) + pred = ar_validate_full_rollout(model, inputs, targets, L=1, K=3) + assert pred.shape == targets.shape + assert torch.equal(pred, torch.zeros_like(targets)) + + +# =================================================================== +# Feedback noise injection +# =================================================================== + + +class TestFeedbackNoise: + """Tests for noise injection on feedback channel.""" + + def _make_model(self, in_ch=5, spatial_ndim=2): + class DummyModel(torch.nn.Module): + def __init__(self, in_channels): + super().__init__() + self.seen_channels = None + + def forward(self, x, **kwargs): + self.seen_channels = x.shape[-1] + spatial = x.shape[1:-2] if x.dim() == 5 else x.shape[1:-1] + T = kwargs.get("target_times", torch.zeros(1)).shape[0] + return torch.zeros(x.shape[0], *spatial, T) + + return DummyModel(in_ch) + + def test_noise_applied_to_feedback_in_rollout(self): + """Rollout with feedback_channel and noise_std > 0 should produce + different results across runs (noise is stochastic).""" + B, H, W, T, C = 1, 4, 6, 6, 5 + inputs = torch.randn(B, H, W, T, C) + targets = torch.randn(B, H, W, T) + model = self._make_model(C + 1) + + def loss_fn(p, t, i, spatial_mask=None): + """Compute MSE loss for testing.""" + return (p - t).pow(2).mean() + + torch.manual_seed(0) + loss1 = rollout_step( + model, + inputs, + targets, + loss_fn, + L=1, + K=2, + noise_std=0.5, + feedback_channel=1, + ) + torch.manual_seed(1) + loss2 = rollout_step( + model, + inputs, + targets, + loss_fn, + L=1, + K=2, + noise_std=0.5, + feedback_channel=1, + ) + # With different seeds, noise differs so losses should differ + # (unless model output is trivially zero, which it is here, + # but the key test is that no error occurs) + assert not torch.isnan(loss1) + assert not torch.isnan(loss2) + + def test_no_noise_in_validation(self): + """ar_validate_full_rollout should be deterministic (no noise).""" + B, H, W, T, C = 1, 4, 6, 6, 5 + inputs = torch.randn(B, H, W, T, C) + targets = torch.randn(B, H, W, T) + model = self._make_model(C + 1) + + torch.manual_seed(42) + pred1 = ar_validate_full_rollout( + model, + inputs, + targets, + L=1, + K=2, + feedback_channel=1, + ) + torch.manual_seed(99) + pred2 = ar_validate_full_rollout( + model, + inputs, + targets, + L=1, + K=2, + feedback_channel=1, + ) + assert torch.equal(pred1, pred2), "Validation should be deterministic" + + def test_zero_noise_no_effect(self): + """noise_std=0 should be a no-op.""" + B, H, W, T, C = 1, 4, 6, 6, 5 + inputs = torch.randn(B, H, W, T, C) + targets = torch.randn(B, H, W, T) + model = self._make_model(C + 1) + + def loss_fn(p, t, i, spatial_mask=None): + """Compute MSE loss for testing.""" + return (p - t).pow(2).mean() + + torch.manual_seed(42) + loss1 = teacher_forcing_step( + model, + inputs, + targets, + loss_fn, + L=1, + K=2, + noise_std=0.0, + feedback_channel=1, + ) + torch.manual_seed(42) + loss2 = teacher_forcing_step( + model, + inputs, + targets, + loss_fn, + L=1, + K=2, + noise_std=0.0, + feedback_channel=1, + ) + assert torch.isclose(loss1, loss2, atol=1e-6) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_checkpoint.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_checkpoint.py new file mode 100644 index 0000000000..45140f0438 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_checkpoint.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for checkpoint utilities.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from utils.checkpoint import build_model_from_config, load_checkpoint, save_checkpoint + + +class TestBuildModelFromConfig: + """Tests for build_model_from_config.""" + + def test_xdeeponet_u_deeponet_3d(self): + """Verify building a 3D U-DeepONet model from config.""" + cfg = { + "model_type": "xdeeponet", + "dimensions": "4d", + "variant": "u_deeponet", + "width": 32, + "padding": 8, + "branch1_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "trunk_config": { + "input_type": "time", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "tanh", + }, + "decoder_type": "mlp", + "decoder_width": 32, + "decoder_layers": 1, + "decoder_activation_fn": "relu", + } + model, name = build_model_from_config(cfg) + assert model is not None + assert "deeponet3d" in name + + def test_xdeeponet_tno_3d(self): + """Verify building a 3D TNO model from config.""" + cfg = { + "model_type": "xdeeponet", + "dimensions": "4d", + "variant": "tno", + "width": 32, + "padding": 8, + "branch1_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "branch2_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "trunk_config": { + "input_type": "time", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "tanh", + }, + "decoder_type": "mlp", + "decoder_width": 32, + "decoder_layers": 1, + } + model, name = build_model_from_config(cfg) + assert "tno" in name + + def test_xdeeponet_2d(self): + """Verify building a 2D U-DeepONet model from config.""" + cfg = { + "model_type": "xdeeponet", + "dimensions": "3d", + "variant": "u_deeponet", + "width": 32, + "padding": 8, + "branch1_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "trunk_config": { + "input_type": "time", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "tanh", + }, + } + model, name = build_model_from_config(cfg) + assert "deeponet_" in name # 2D has no "3d" in name + + def test_xfno_4d(self): + """Verify building a 4D FNO model from config.""" + cfg = { + "model_type": "xfno", + "dimensions": "4d", + "in_channels": 11, + "out_channels": 1, + "width": 16, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "modes4": 3, + "num_fno_layers": 2, + "padding": 8, + } + model, name = build_model_from_config(cfg) + assert "fno4d" in name + + def test_xfno_3d(self): + """Verify building a 3D FNO model from config.""" + cfg = { + "model_type": "xfno", + "dimensions": "3d", + "in_channels": 12, + "out_channels": 1, + "width": 16, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "num_fno_layers": 2, + "padding": 8, + } + model, name = build_model_from_config(cfg) + assert model is not None + + def test_unknown_type_raises(self): + """Verify unknown model_type raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model_type"): + build_model_from_config({"model_type": "invalid"}) + + +class TestSaveLoadCheckpoint: + """Tests for save/load round-trip.""" + + def test_round_trip(self, tmp_path): + """Verify save/load checkpoint round-trip preserves all fields.""" + cfg = { + "model_type": "xdeeponet", + "dimensions": "4d", + "variant": "u_deeponet", + "width": 16, + "padding": 8, + "branch1_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "trunk_config": { + "input_type": "time", + "hidden_width": 16, + "num_layers": 2, + "activation_fn": "tanh", + }, + } + model, _ = build_model_from_config(cfg) + + # Dummy forward to init lazy modules + x = torch.randn(1, 8, 16, 8, 2, 5) + with torch.no_grad(): + model(x) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) + + path = tmp_path / "test_ckpt.pth" + save_checkpoint( + path=path, + model=model, + epoch=42, + val_loss=0.05, + metric_key="val_rmse", + metric_value=0.1, + model_config=cfg, + optimizer=optimizer, + scheduler=scheduler, + ) + + ckpt = load_checkpoint(path) + assert ckpt["epoch"] == 42 + assert abs(ckpt["val_loss"] - 0.05) < 1e-8 + assert "model_state_dict" in ckpt + assert "optimizer_state_dict" in ckpt + assert "scheduler_state_dict" in ckpt + assert ckpt["model_config"] == cfg + + def test_rebuild_from_checkpoint(self, tmp_path): + """Save a model, then rebuild from checkpoint config and load weights.""" + cfg = { + "model_type": "xdeeponet", + "dimensions": "4d", + "variant": "u_deeponet", + "width": 16, + "padding": 8, + "branch1_config": { + "encoder": "spatial", + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, + "trunk_config": { + "input_type": "time", + "hidden_width": 16, + "num_layers": 2, + "activation_fn": "tanh", + }, + } + model1, _ = build_model_from_config(cfg) + x = torch.randn(1, 8, 16, 8, 2, 5) + with torch.no_grad(): + model1(x) + + path = tmp_path / "test_ckpt.pth" + save_checkpoint( + path=path, + model=model1, + epoch=10, + val_loss=0.1, + metric_key="val_rmse", + metric_value=0.2, + model_config=cfg, + ) + + # Rebuild from checkpoint + ckpt = load_checkpoint(path) + model2, _ = build_model_from_config(ckpt["model_config"]) + with torch.no_grad(): + model2(x) # init lazy + model2.load_state_dict(ckpt["model_state_dict"]) + + # Verify weights match + for (n1, p1), (n2, p2) in zip( + model1.named_parameters(), model2.named_parameters() + ): + assert n1 == n2 + assert torch.equal(p1, p2), f"Mismatch in {n1}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_data_validation.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_data_validation.py new file mode 100644 index 0000000000..452a81b18b --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_data_validation.py @@ -0,0 +1,400 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for data validation utilities.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from data.validation import ( + detect_dimensions, + get_dimension_info, + print_validation_summary, + validate_batch_dimensions, + validate_sample_dimensions, +) + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +def batch_3d_input(): + """3D batched input tensor: (B, H, W, T, C)""" + return torch.randn(4, 32, 64, 16, 12) + + +@pytest.fixture +def batch_3d_target(): + """3D batched target tensor: (B, H, W, T)""" + return torch.randn(4, 32, 64, 16) + + +@pytest.fixture +def batch_4d_input(): + """4D batched input tensor: (B, X, Y, Z, T, C)""" + return torch.randn(2, 16, 24, 12, 10, 8) + + +@pytest.fixture +def batch_4d_target(): + """4D batched target tensor: (B, X, Y, Z, T)""" + return torch.randn(2, 16, 24, 12, 10) + + +@pytest.fixture +def sample_3d_input(): + """3D single sample input: (H, W, T, C)""" + return torch.randn(32, 64, 16, 12) + + +@pytest.fixture +def sample_3d_target(): + """3D single sample target: (H, W, T)""" + return torch.randn(32, 64, 16) + + +@pytest.fixture +def sample_4d_input(): + """4D single sample input: (X, Y, Z, T, C)""" + return torch.randn(16, 24, 12, 10, 8) + + +@pytest.fixture +def sample_4d_target(): + """4D single sample target: (X, Y, Z, T)""" + return torch.randn(16, 24, 12, 10) + + +# ============================================================================= +# Test detect_dimensions +# ============================================================================= + + +class TestDetectDimensions: + """Tests for detect_dimensions function.""" + + def test_detect_3d_from_batch(self, batch_3d_input): + """Test detecting 3D from batched input.""" + assert detect_dimensions(batch_3d_input) == "3d" + + def test_detect_4d_from_batch(self, batch_4d_input): + """Test detecting 4D from batched input.""" + assert detect_dimensions(batch_4d_input) == "4d" + + def test_detect_invalid_dimensions(self): + """Test error for unsupported dimensions.""" + # 3D tensor (not 5D or 6D) + invalid_tensor = torch.randn(10, 32, 64) + + with pytest.raises(ValueError) as excinfo: + detect_dimensions(invalid_tensor) + + assert "Cannot detect dimensions" in str(excinfo.value) + + def test_detect_from_7d_tensor(self): + """Test error for too many dimensions.""" + tensor_7d = torch.randn(2, 8, 8, 8, 8, 8, 4) + + with pytest.raises(ValueError): + detect_dimensions(tensor_7d) + + +# ============================================================================= +# Test validate_batch_dimensions +# ============================================================================= + + +class TestValidateBatchDimensions: + """Tests for validate_batch_dimensions function.""" + + def test_validate_3d_batch(self, batch_3d_input, batch_3d_target): + """Test validation of 3D batched data.""" + result = validate_batch_dimensions(batch_3d_input, batch_3d_target, "pressure") + + assert result["dimensions"] == "3d" + assert result["batch_size"] == 4 + assert result["spatial_shape"] == (32, 64) + assert result["time_steps"] == 16 + assert result["num_channels"] == 12 + + def test_validate_4d_batch(self, batch_4d_input, batch_4d_target): + """Test validation of 4D batched data.""" + result = validate_batch_dimensions( + batch_4d_input, batch_4d_target, "saturation" + ) + + assert result["dimensions"] == "4d" + assert result["batch_size"] == 2 + assert result["spatial_shape"] == (16, 24, 12) + assert result["time_steps"] == 10 + assert result["num_channels"] == 8 + + def test_batch_size_mismatch(self, batch_3d_input): + """Test error when batch sizes don't match.""" + wrong_target = torch.randn(8, 32, 64, 16) # Different batch size + + with pytest.raises(ValueError) as excinfo: + validate_batch_dimensions(batch_3d_input, wrong_target) + + assert "Batch size mismatch" in str(excinfo.value) + + def test_spatial_mismatch(self, batch_3d_input): + """Test error when spatial dimensions don't match.""" + wrong_target = torch.randn(4, 48, 64, 16) # Different H + + with pytest.raises(ValueError) as excinfo: + validate_batch_dimensions(batch_3d_input, wrong_target) + + assert "mismatch" in str(excinfo.value).lower() + + def test_wrong_input_ndim(self): + """Test error for wrong input dimensions.""" + wrong_input = torch.randn(4, 32, 64) # 3D instead of 5D/6D + target = torch.randn(4, 32, 64) + + with pytest.raises(ValueError) as excinfo: + validate_batch_dimensions(wrong_input, target) + + assert "Invalid input shape" in str(excinfo.value) + + def test_wrong_target_ndim_for_3d(self, batch_3d_input): + """Test error when target has wrong dimensions for 3D data.""" + wrong_target = torch.randn(4, 32, 64, 16, 1) # 5D instead of 4D + + with pytest.raises(ValueError) as excinfo: + validate_batch_dimensions(batch_3d_input, wrong_target) + + assert "Invalid target shape" in str(excinfo.value) + + +# ============================================================================= +# Test validate_sample_dimensions +# ============================================================================= + + +class TestValidateSampleDimensions: + """Tests for validate_sample_dimensions function.""" + + def test_validate_3d_sample(self, sample_3d_input, sample_3d_target): + """Test validation of 3D single sample.""" + result = validate_sample_dimensions(sample_3d_input, sample_3d_target) + + assert result["dimensions"] == "3d" + assert result["spatial_shape"] == (32, 64) + assert result["time_steps"] == 16 + assert result["num_channels"] == 12 + + def test_validate_4d_sample(self, sample_4d_input, sample_4d_target): + """Test validation of 4D single sample.""" + result = validate_sample_dimensions(sample_4d_input, sample_4d_target) + + assert result["dimensions"] == "4d" + assert result["spatial_shape"] == (16, 24, 12) + assert result["time_steps"] == 10 + assert result["num_channels"] == 8 + + def test_sample_spatial_mismatch(self, sample_3d_input): + """Test error when sample spatial dimensions don't match.""" + wrong_target = torch.randn(48, 64, 16) # Different H + + with pytest.raises(ValueError) as excinfo: + validate_sample_dimensions(sample_3d_input, wrong_target) + + assert "mismatch" in str(excinfo.value).lower() + + def test_wrong_sample_input_ndim(self): + """Test error for wrong sample input dimensions.""" + wrong_input = torch.randn(32, 64, 16) # 3D instead of 4D/5D + target = torch.randn(32, 64, 16) + + with pytest.raises(ValueError) as excinfo: + validate_sample_dimensions(wrong_input, target) + + assert "Invalid input shape" in str(excinfo.value) + + +# ============================================================================= +# Test print_validation_summary +# ============================================================================= + + +class TestPrintValidationSummary: + """Tests for print_validation_summary function.""" + + def test_print_3d_batch_summary(self, capsys): + """Test printing summary for 3D batched data.""" + input_shape = (4, 32, 64, 16, 12) + target_shape = (4, 32, 64, 16) + + print_validation_summary(input_shape, target_shape, "pressure", is_batch=True) + + captured = capsys.readouterr() + assert "validation passed" in captured.out.lower() + assert "3D" in captured.out + assert "32" in captured.out # H + assert "64" in captured.out # W + + def test_print_4d_batch_summary(self, capsys): + """Test printing summary for 4D batched data.""" + input_shape = (2, 16, 24, 12, 10, 8) + target_shape = (2, 16, 24, 12, 10) + + print_validation_summary(input_shape, target_shape, "saturation", is_batch=True) + + captured = capsys.readouterr() + assert "4D" in captured.out + assert "X" in captured.out or "Y" in captured.out or "Z" in captured.out + + def test_print_3d_sample_summary(self, capsys): + """Test printing summary for 3D single sample.""" + input_shape = (32, 64, 16, 12) + target_shape = (32, 64, 16) + + print_validation_summary(input_shape, target_shape, "pressure", is_batch=False) + + captured = capsys.readouterr() + assert "validation passed" in captured.out.lower() + + def test_print_with_logger(self): + """Test printing with custom logger.""" + + class MockLogger: + """Minimal logger stub for testing print_validation_summary.""" + + def __init__(self): + self.messages = [] + + def success(self, msg): + """Record a success-level message.""" + self.messages.append(("success", msg)) + + def info(self, msg): + """Record an info-level message.""" + self.messages.append(("info", msg)) + + logger = MockLogger() + input_shape = (4, 32, 64, 16, 12) + target_shape = (4, 32, 64, 16) + + print_validation_summary( + input_shape, target_shape, "pressure", is_batch=True, logger=logger + ) + + assert len(logger.messages) > 0 + assert any("success" in msg[0] for msg in logger.messages) + + +# ============================================================================= +# Test get_dimension_info +# ============================================================================= + + +class TestGetDimensionInfo: + """Tests for get_dimension_info function.""" + + def test_get_info_3d_batch_input(self, batch_3d_input): + """Test getting info from 3D batched input.""" + info = get_dimension_info(batch_3d_input, is_batch=True) + + assert info["dimensions"] == "3d" + assert info["batch_size"] == 4 + assert info["spatial_shape"] == (32, 64) + assert info["time_steps"] == 16 + assert info["num_channels"] == 12 + + def test_get_info_4d_batch_input(self, batch_4d_input): + """Test getting info from 4D batched input.""" + info = get_dimension_info(batch_4d_input, is_batch=True) + + assert info["dimensions"] == "4d" + assert info["batch_size"] == 2 + assert info["spatial_shape"] == (16, 24, 12) + assert info["time_steps"] == 10 + assert info["num_channels"] == 8 + + def test_get_info_3d_sample(self, sample_3d_input): + """Test getting info from 3D single sample.""" + info = get_dimension_info(sample_3d_input, is_batch=False) + + assert info["dimensions"] == "3d" + assert "batch_size" not in info + assert info["spatial_shape"] == (32, 64) + + def test_get_info_4d_sample(self, sample_4d_input): + """Test getting info from 4D single sample.""" + info = get_dimension_info(sample_4d_input, is_batch=False) + + assert info["dimensions"] == "4d" + assert info["spatial_shape"] == (16, 24, 12) + + def test_get_info_invalid_tensor(self): + """Test error for invalid tensor dimensions.""" + invalid_tensor = torch.randn(10, 32) # 2D tensor + + with pytest.raises(ValueError): + get_dimension_info(invalid_tensor, is_batch=True) + + +# ============================================================================= +# Test Integration with Dataset +# ============================================================================= + + +class TestValidationIntegration: + """Integration tests with actual data tensors.""" + + def test_full_validation_pipeline_3d(self, batch_3d_input, batch_3d_target): + """Test full validation pipeline for 3D data.""" + # Detect dimensions + dims = detect_dimensions(batch_3d_input) + assert dims == "3d" + + # Validate + info = validate_batch_dimensions(batch_3d_input, batch_3d_target) + assert info["dimensions"] == "3d" + + # Get detailed info + detailed = get_dimension_info(batch_3d_input, is_batch=True) + assert detailed["dimensions"] == info["dimensions"] + assert detailed["batch_size"] == info["batch_size"] + + def test_full_validation_pipeline_4d(self, batch_4d_input, batch_4d_target): + """Test full validation pipeline for 4D data.""" + dims = detect_dimensions(batch_4d_input) + assert dims == "4d" + + info = validate_batch_dimensions(batch_4d_input, batch_4d_target) + assert info["dimensions"] == "4d" + + detailed = get_dimension_info(batch_4d_input, is_batch=True) + assert detailed["dimensions"] == "4d" + + def test_validation_preserves_tensor_data(self, batch_3d_input, batch_3d_target): + """Test that validation doesn't modify the tensors.""" + original_input = batch_3d_input.clone() + original_target = batch_3d_target.clone() + + validate_batch_dimensions(batch_3d_input, batch_3d_target) + + assert torch.equal(batch_3d_input, original_input) + assert torch.equal(batch_3d_target, original_target) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_dataset.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_dataset.py new file mode 100644 index 0000000000..2c9eaa4ee8 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_dataset.py @@ -0,0 +1,505 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the unified ReservoirDataset and data loading utilities.""" + +import sys +import tempfile +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from data.dataloader import ( + ReservoirDataset, + collate_fn, + create_dataloaders, + get_dataset_info, +) + +# ============================================================================= +# Fixtures for Creating Test Data +# ============================================================================= + + +@pytest.fixture +def temp_data_dir(): + """Create a temporary directory for test data.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def mock_3d_data(temp_data_dir): + """Create mock 3D dataset files (CO2 format).""" + # 3D: (N, H, W, T, C) input, (N, H, W, T) output + N, H, W, T, C = 10, 32, 64, 16, 12 + + for mode in ["train", "val", "test"]: + input_data = torch.randn(N, H, W, T, C) + output_data = torch.randn(N, H, W, T) + + torch.save(input_data, temp_data_dir / f"dP_{mode}_a.pt") + torch.save(output_data, temp_data_dir / f"dP_{mode}_u.pt") + + return temp_data_dir, {"N": N, "H": H, "W": W, "T": T, "C": C} + + +@pytest.fixture +def mock_4d_data(temp_data_dir): + """Create mock 4D dataset files (Norne format).""" + # 4D: (N, X, Y, Z, T, C) input, (N, X, Y, Z, T) output + N, X, Y, Z, T, C = 8, 16, 24, 12, 10, 8 + + for mode in ["train", "val", "test"]: + input_data = torch.randn(N, X, Y, Z, T, C) + output_data = torch.randn(N, X, Y, Z, T) + + torch.save(input_data, temp_data_dir / f"norne_{mode}_input.pt") + torch.save(output_data, temp_data_dir / f"norne_{mode}_output.pt") + + return temp_data_dir, {"N": N, "X": X, "Y": Y, "Z": Z, "T": T, "C": C} + + +@pytest.fixture +def mock_generic_data(temp_data_dir): + """Create mock generic dataset files.""" + N, H, W, T, C = 6, 24, 48, 12, 6 + + for mode in ["train", "val", "test"]: + input_data = torch.randn(N, H, W, T, C) + output_data = torch.randn(N, H, W, T) + + torch.save(input_data, temp_data_dir / f"{mode}_input.pt") + torch.save(output_data, temp_data_dir / f"{mode}_output.pt") + + return temp_data_dir, {"N": N, "H": H, "W": W, "T": T, "C": C} + + +# ============================================================================= +# Test ReservoirDataset - 3D Data +# ============================================================================= + + +class TestReservoirDataset3D: + """Tests for ReservoirDataset with 3D data.""" + + def test_load_3d_with_variable(self, mock_3d_data): + """Test loading 3D data using variable name (CO2 convention).""" + data_path, dims = mock_3d_data + + ds = ReservoirDataset(data_path, mode="train", variable="pressure") + + assert len(ds) == dims["N"] + assert ds.dimensions == "3d" + assert ds.spatial_shape == (dims["H"], dims["W"]) + assert ds.time_steps == dims["T"] + assert ds.num_channels == dims["C"] + + def test_getitem_3d(self, mock_3d_data): + """Test __getitem__ returns correct shapes for 3D data.""" + data_path, dims = mock_3d_data + + ds = ReservoirDataset( + data_path, mode="train", variable="pressure", normalize=False + ) + x, y = ds[0] + + # Single sample shapes (no batch dimension) + assert x.shape == (dims["H"], dims["W"], dims["T"], dims["C"]) + assert y.shape == (dims["H"], dims["W"], dims["T"]) + + def test_normalization_3d(self, mock_3d_data): + """Test normalization for 3D data.""" + data_path, dims = mock_3d_data + + ds_norm = ReservoirDataset( + data_path, mode="train", variable="pressure", normalize=True + ) + ds_raw = ReservoirDataset( + data_path, mode="train", variable="pressure", normalize=False + ) + + x_norm, y_norm = ds_norm[0] + x_raw, y_raw = ds_raw[0] + + # Normalized data should be different from raw + assert not torch.allclose(x_norm, x_raw) + + def test_normalization_stats_sharing(self, mock_3d_data): + """Test that normalization stats can be shared across datasets.""" + data_path, _ = mock_3d_data + + train_ds = ReservoirDataset( + data_path, mode="train", variable="pressure", normalize=True + ) + val_ds = ReservoirDataset( + data_path, mode="val", variable="pressure", normalize=True + ) + + # Share normalization from train to val + norm_stats = train_ds.get_normalization_stats() + val_ds.set_normalization(*norm_stats) + + # Check that val_ds now has train's normalization + assert torch.allclose(val_ds.input_mean, train_ds.input_mean) + assert torch.allclose(val_ds.input_std, train_ds.input_std) + + +# ============================================================================= +# Test ReservoirDataset - 4D Data +# ============================================================================= + + +class TestReservoirDataset4D: + """Tests for ReservoirDataset with 4D data.""" + + def test_load_4d_with_explicit_files(self, mock_4d_data): + """Test loading 4D data using explicit file patterns.""" + data_path, dims = mock_4d_data + + ds = ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + ) + + assert len(ds) == dims["N"] + assert ds.dimensions == "4d" + assert ds.spatial_shape == (dims["X"], dims["Y"], dims["Z"]) + assert ds.time_steps == dims["T"] + assert ds.num_channels == dims["C"] + + def test_getitem_4d(self, mock_4d_data): + """Test __getitem__ returns correct shapes for 4D data.""" + data_path, dims = mock_4d_data + + ds = ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + normalize=False, + ) + x, y = ds[0] + + # Single sample shapes (no batch dimension) + assert x.shape == (dims["X"], dims["Y"], dims["Z"], dims["T"], dims["C"]) + assert y.shape == (dims["X"], dims["Y"], dims["Z"], dims["T"]) + + def test_normalization_4d(self, mock_4d_data): + """Test normalization for 4D data.""" + data_path, dims = mock_4d_data + + ds_norm = ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + normalize=True, + ) + + # Check normalization stats have correct shape + # Input mean should be (1, 1, 1, 1, 1, C) for 6D data + assert ds_norm.input_mean.shape[-1] == dims["C"] + + +# ============================================================================= +# Test Dimension Validation +# ============================================================================= + + +class TestDimensionValidation: + """Tests for dimension validation against config.""" + + def test_expected_dimensions_match(self, mock_3d_data): + """Test that matching expected dimensions passes.""" + data_path, _ = mock_3d_data + + # Should not raise - dimensions match + ds = ReservoirDataset( + data_path, mode="train", variable="pressure", expected_dimensions="3d" + ) + assert ds.dimensions == "3d" + + def test_expected_dimensions_mismatch(self, mock_3d_data): + """Test that mismatched expected dimensions raises error.""" + data_path, _ = mock_3d_data + + # Should raise - expecting 4d but data is 3d + with pytest.raises(ValueError) as excinfo: + ReservoirDataset( + data_path, mode="train", variable="pressure", expected_dimensions="4d" + ) + + assert "Dimension mismatch" in str(excinfo.value) + assert "4d" in str(excinfo.value) + assert "3d" in str(excinfo.value) + + def test_expected_dimensions_4d(self, mock_4d_data): + """Test expected dimensions validation for 4D data.""" + data_path, _ = mock_4d_data + + # Correct expectation + ds = ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + expected_dimensions="4d", + ) + assert ds.dimensions == "4d" + + # Wrong expectation + with pytest.raises(ValueError): + ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + expected_dimensions="3d", + ) + + +# ============================================================================= +# Test Auto-Detection +# ============================================================================= + + +class TestAutoDetection: + """Tests for file auto-detection.""" + + def test_auto_detect_generic_files(self, mock_generic_data): + """Test auto-detection of generic file naming pattern.""" + data_path, dims = mock_generic_data + + ds = ReservoirDataset(data_path, mode="train") + + assert len(ds) == dims["N"] + assert ds.dimensions == "3d" + + def test_auto_detect_co2_files(self, mock_3d_data): + """Test auto-detection when variable is specified.""" + data_path, dims = mock_3d_data + + ds = ReservoirDataset(data_path, mode="train", variable="pressure") + assert len(ds) == dims["N"] + + def test_auto_detect_fails_gracefully(self, temp_data_dir): + """Test that auto-detection gives helpful error for unknown files.""" + # Create files with unusual naming + torch.save(torch.randn(5, 10, 10, 5, 3), temp_data_dir / "weird_input.pt") + torch.save(torch.randn(5, 10, 10, 5), temp_data_dir / "weird_output.pt") + + with pytest.raises(FileNotFoundError) as excinfo: + ReservoirDataset(temp_data_dir, mode="train") + + assert "weird_input.pt" in str(excinfo.value) or "weird_output.pt" in str( + excinfo.value + ) + + +# ============================================================================= +# Test Collate Functions +# ============================================================================= + + +class TestCollateFunctions: + """Tests for collate functions.""" + + def test_collate_fn(self, mock_3d_data): + """Test collate function for 3D data.""" + data_path, dims = mock_3d_data + ds = ReservoirDataset( + data_path, mode="train", variable="pressure", normalize=False + ) + + batch = [ds[i] for i in range(3)] + inputs, targets = collate_fn(batch) + + assert inputs.shape == (3, dims["H"], dims["W"], dims["T"], dims["C"]) + assert targets.shape == (3, dims["H"], dims["W"], dims["T"]) + + def test_collate_fn_4d(self, mock_4d_data): + """Test collate function for 4D data.""" + data_path, dims = mock_4d_data + ds = ReservoirDataset( + data_path, + mode="train", + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + normalize=False, + ) + + batch = [ds[i] for i in range(2)] + inputs, targets = collate_fn(batch) + + assert inputs.shape == ( + 2, + dims["X"], + dims["Y"], + dims["Z"], + dims["T"], + dims["C"], + ) + assert targets.shape == (2, dims["X"], dims["Y"], dims["Z"], dims["T"]) + + +# ============================================================================= +# Test Dataloaders +# ============================================================================= + + +class TestDataloaders: + """Tests for create_dataloaders function.""" + + def test_create_dataloaders_3d(self, mock_3d_data): + """Test creating dataloaders for 3D data.""" + data_path, dims = mock_3d_data + + train, val, test = create_dataloaders( + data_path, variable="pressure", batch_size=2, num_workers=0, normalize=False + ) + + assert len(train.dataset) == dims["N"] + assert len(val.dataset) == dims["N"] + assert len(test.dataset) == dims["N"] + + # Check batch dimensions + inputs, targets = next(iter(train)) + assert inputs.shape[0] == 2 # batch size + assert inputs.dim() == 5 # (B, H, W, T, C) + + def test_create_dataloaders_4d(self, mock_4d_data): + """Test creating dataloaders for 4D data.""" + data_path, dims = mock_4d_data + + train, val, test = create_dataloaders( + data_path, + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + batch_size=2, + num_workers=0, + normalize=False, + ) + + inputs, targets = next(iter(train)) + assert inputs.dim() == 6 # (B, X, Y, Z, T, C) + assert targets.dim() == 5 # (B, X, Y, Z, T) + + def test_create_dataloaders_with_dimension_validation(self, mock_3d_data): + """Test dataloaders with expected_dimensions validation.""" + data_path, _ = mock_3d_data + + # Should work + train, _, _ = create_dataloaders( + data_path, + variable="pressure", + batch_size=2, + num_workers=0, + expected_dimensions="3d", + ) + assert len(train.dataset) > 0 + + # Should fail + with pytest.raises(ValueError): + create_dataloaders( + data_path, + variable="pressure", + batch_size=2, + num_workers=0, + expected_dimensions="4d", + ) + + +# ============================================================================= +# Test get_dataset_info Utility +# ============================================================================= + + +class TestDatasetInfo: + """Tests for get_dataset_info utility function.""" + + def test_get_dataset_info_3d(self, mock_3d_data): + """Test getting dataset info for 3D data.""" + data_path, dims = mock_3d_data + + info = get_dataset_info(data_path, variable="pressure") + + assert info["dimensions"] == "3d" + assert info["spatial_shape"] == (dims["H"], dims["W"]) + assert info["time_steps"] == dims["T"] + assert info["num_channels"] == dims["C"] + assert info["num_samples"]["train"] == dims["N"] + + def test_get_dataset_info_4d(self, mock_4d_data): + """Test getting dataset info for 4D data.""" + data_path, dims = mock_4d_data + + info = get_dataset_info( + data_path, + input_file="norne_{mode}_input.pt", + output_file="norne_{mode}_output.pt", + ) + + assert info["dimensions"] == "4d" + assert info["spatial_shape"] == (dims["X"], dims["Y"], dims["Z"]) + assert info["time_steps"] == dims["T"] + assert info["num_channels"] == dims["C"] + + +# ============================================================================= +# Test Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_invalid_mode(self, mock_3d_data): + """Test that invalid mode raises error.""" + data_path, _ = mock_3d_data + + with pytest.raises(ValueError) as excinfo: + ReservoirDataset(data_path, mode="invalid", variable="pressure") + + assert ( + "train" in str(excinfo.value).lower() or "val" in str(excinfo.value).lower() + ) + + def test_missing_files(self, temp_data_dir): + """Test error when data files don't exist.""" + with pytest.raises(FileNotFoundError): + ReservoirDataset( + temp_data_dir, + mode="train", + input_file="nonexistent_input.pt", + output_file="nonexistent_output.pt", + ) + + def test_invalid_data_dimensions(self, temp_data_dir): + """Test error for unsupported data dimensions.""" + # Create 3D input (wrong) + torch.save(torch.randn(10, 32, 64, 12), temp_data_dir / "train_input.pt") + torch.save(torch.randn(10, 32, 64), temp_data_dir / "train_output.pt") + + with pytest.raises(ValueError) as excinfo: + ReservoirDataset(temp_data_dir, mode="train") + + assert "Unsupported data dimensions" in str(excinfo.value) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_losses.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_losses.py new file mode 100644 index 0000000000..ac2bd6b6cf --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_losses.py @@ -0,0 +1,738 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Comprehensive unit tests for loss functions.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from training.losses import SimpleRelativeL2Loss, UnifiedLoss, get_loss_function + +# --------------------------------------------------------------------------- +# Helpers: build inputs with grid-width channels (NOF convention) +# --------------------------------------------------------------------------- + + +def _make_2d_inputs(B, H, W, T, C, dx=None, dy=None): + """Create (B, H, W, T, C) inputs with grid widths in last 3 channels.""" + inputs = torch.randn(B, H, W, T, C) + if dx is None: + dx = torch.ones(W) + if dy is None: + dy = torch.ones(H) + dt = torch.linspace(0, 30, T) + inputs[..., -3] = dx.view(1, 1, W, 1).expand(B, H, W, T) + inputs[..., -2] = dy.view(1, H, 1, 1).expand(B, H, W, T) + inputs[..., -1] = dt.view(1, 1, 1, T).expand(B, H, W, T) + return inputs + + +def _make_3d_inputs(B, X, Y, Z, T, C, dx=None, dy=None, dz=None): + """Create (B, X, Y, Z, T, C) inputs with grid widths in last 4 channels.""" + inputs = torch.randn(B, X, Y, Z, T, C) + if dx is None: + dx = torch.ones(X) + if dy is None: + dy = torch.ones(Y) + if dz is None: + dz = torch.ones(Z) + dt = torch.linspace(0, 30, T) + inputs[..., -4] = dx.view(1, X, 1, 1, 1).expand(B, X, Y, Z, T) + inputs[..., -3] = dy.view(1, 1, Y, 1, 1).expand(B, X, Y, Z, T) + inputs[..., -2] = dz.view(1, 1, 1, Z, 1).expand(B, X, Y, Z, T) + inputs[..., -1] = dt.view(1, 1, 1, 1, T).expand(B, X, Y, Z, T) + return inputs + + +# =================================================================== +# SimpleRelativeL2Loss +# =================================================================== + + +class TestSimpleRelativeL2Loss: + """Tests for SimpleRelativeL2Loss.""" + + def test_zero_for_identical(self): + """Verify relative L2 loss is zero for identical pred and target.""" + target = torch.randn(2, 8, 16, 4) + loss = SimpleRelativeL2Loss()(target.clone(), target) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + def test_positive_for_different(self): + """Verify relative L2 loss is positive for different pred and target.""" + target = torch.randn(2, 8, 16, 4) + pred = target + 0.1 * torch.randn_like(target) + assert SimpleRelativeL2Loss()(pred, target) > 0 + + def test_scale_invariance(self): + """Verify relative L2 loss is scale-invariant.""" + target = torch.randn(2, 8, 16, 4) + 2 + pred = target + 0.1 * torch.randn_like(target) + fn = SimpleRelativeL2Loss() + assert torch.isclose(fn(pred, target), fn(pred * 5, target * 5), rtol=1e-4) + + def test_epsilon_prevents_nan(self): + """Verify epsilon prevents NaN when target is all zeros.""" + target = torch.zeros(2, 4, 4, 2) + pred = torch.ones(2, 4, 4, 2) + loss = SimpleRelativeL2Loss(eps=1e-8)(pred, target) + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + +# =================================================================== +# UnifiedLoss — Data Losses +# =================================================================== + + +class TestDataLosses: + """Tests for DataLosses.""" + + @pytest.mark.parametrize("loss_type", ["mse", "l1", "relative_l2", "huber"]) + def test_all_types_run(self, loss_type): + """Verify each loss type runs and produces a positive, finite loss.""" + pred = torch.randn(2, 8, 16, 4) + target = torch.randn(2, 8, 16, 4) + fn = UnifiedLoss(types=[loss_type], weights=[1.0]) + loss = fn(pred, target) + assert loss > 0 + assert not torch.isnan(loss) + + def test_mse_value(self): + """Verify MSE loss computes correct value.""" + target = torch.zeros(1, 4, 4, 2) + pred = torch.ones(1, 4, 4, 2) + loss = UnifiedLoss(types=["mse"])(pred, target) + assert torch.isclose(loss, torch.tensor(1.0)) + + def test_l1_value(self): + """Verify L1 loss computes correct value.""" + target = torch.zeros(1, 4, 4, 2) + pred = torch.full_like(target, 2.0) + loss = UnifiedLoss(types=["l1"])(pred, target) + assert torch.isclose(loss, torch.tensor(2.0)) + + def test_huber_equals_mse_for_small_errors(self): + """Verify Huber loss approximates MSE for small errors.""" + target = torch.randn(2, 8, 8, 4) + pred = target + 0.01 * torch.randn_like(target) + mse_loss = UnifiedLoss(types=["mse"])(pred, target) + huber_loss = UnifiedLoss(types=["huber"], huber_delta=1.0)(pred, target) + assert torch.isclose(mse_loss, huber_loss * 2, rtol=0.1) + + def test_relative_l2_epsilon_zero_target(self): + """Verify relative L2 with epsilon handles zero target without NaN.""" + target = torch.zeros(2, 4, 4, 2) + pred = torch.ones(2, 4, 4, 2) + loss = UnifiedLoss(types=["relative_l2"], eps=1e-6)(pred, target) + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_multiple_losses_combined(self): + """Verify combined losses with zero weight reduce to single loss.""" + pred = torch.randn(2, 8, 8, 4) + target = torch.randn(2, 8, 8, 4) + fn_single = UnifiedLoss(types=["mse"], weights=[1.0]) + fn_multi = UnifiedLoss(types=["mse", "l1"], weights=[1.0, 0.0]) + assert torch.isclose(fn_single(pred, target), fn_multi(pred, target)) + + def test_invalid_type_raises(self): + """Verify invalid loss type raises ValueError.""" + with pytest.raises(ValueError, match="Loss type must be"): + UnifiedLoss(types=["invalid"]) + + def test_invalid_reduction_raises(self): + """Verify invalid reduction mode raises ValueError.""" + with pytest.raises(ValueError, match="reduction"): + UnifiedLoss(reduction="invalid") + + def test_mismatched_lengths_raises(self): + """Verify mismatched types/weights lengths raises ValueError.""" + with pytest.raises(ValueError, match="same length"): + UnifiedLoss(types=["mse", "l1"], weights=[1.0]) + + def test_gradient_flow(self): + """Verify gradients flow through MSE loss to prediction tensor.""" + target = torch.randn(2, 8, 8, 4) + pred = torch.randn(2, 8, 8, 4, requires_grad=True) + loss = UnifiedLoss(types=["mse"])(pred, target) + loss.backward() + assert pred.grad is not None + + +# =================================================================== +# UnifiedLoss — Masking +# =================================================================== + + +class TestMasking: + """Tests for Masking.""" + + def test_mse_mask_only_active(self): + """MSE should average only over active cells.""" + B, H, W, T = 1, 4, 4, 2 + target = torch.zeros(B, H, W, T) + pred = torch.ones(B, H, W, T) + mask = torch.zeros(H, W, dtype=torch.bool) + mask[0, 0] = True # only one cell active + + fn = UnifiedLoss(types=["mse"]) + loss = fn(pred, target, spatial_mask=mask) + assert torch.isclose(loss, torch.tensor(1.0)) + + def test_l1_mask_only_active(self): + """Verify L1 loss averages only over active masked cells.""" + B, H, W, T = 1, 4, 4, 2 + target = torch.zeros(B, H, W, T) + pred = torch.full((B, H, W, T), 3.0) + mask = torch.zeros(H, W, dtype=torch.bool) + mask[0, 0] = True + + loss = UnifiedLoss(types=["l1"])(pred, target, spatial_mask=mask) + assert torch.isclose(loss, torch.tensor(3.0)) + + def test_mask_zeros_dont_leak_2d(self): + """Error only in masked-out region => loss = 0.""" + B, H, W, T = 1, 8, 8, 4 + target = torch.randn(B, H, W, T) + pred = target.clone() + pred[:, :4, :, :] += 10.0 + + mask = torch.zeros(H, W, dtype=torch.bool) + mask[4:, :] = True + + loss = UnifiedLoss(types=["mse"])(pred, target, spatial_mask=mask) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + def test_mask_3d(self): + """Verify 3D spatial masking excludes inactive cells from loss.""" + B, X, Y, Z, T = 1, 4, 4, 2, 3 + target = torch.randn(B, X, Y, Z, T) + pred = target.clone() + pred[:, :2, :, :, :] += 10.0 + + mask = torch.zeros(X, Y, Z, dtype=torch.bool) + mask[2:, :, :] = True + + loss = UnifiedLoss(types=["mse"])(pred, target, spatial_mask=mask) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + +# =================================================================== +# UnifiedLoss — Derivative Loss (2D) +# =================================================================== + + +class TestDerivativeLoss2D: + """Tests for DerivativeLoss2D.""" + + def test_derivative_dx_runs(self): + """Verify dx derivative loss runs and produces a positive value.""" + B, H, W, T, C = 2, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + pred = torch.randn(B, H, W, T) + target = torch.randn(B, H, W, T) + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + assert not torch.isnan(loss) + + def test_derivative_dy_runs(self): + """Verify dy derivative loss runs and produces a positive value.""" + B, H, W, T, C = 2, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + pred = torch.randn(B, H, W, T) + target = torch.randn(B, H, W, T) + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dy"]}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + + def test_derivative_both_dims(self): + """Verify derivative loss with both dx and dy dimensions.""" + B, H, W, T, C = 2, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + pred = torch.randn(B, H, W, T) + target = torch.randn(B, H, W, T) + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx", "dy"]}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + + def test_derivative_zero_for_identical(self): + """Verify derivative loss is zero when pred equals target.""" + B, H, W, T, C = 1, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + pred = target.clone() + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 1.0, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + def test_derivative_gradient_flow(self): + """Verify gradients flow through derivative loss to prediction.""" + B, H, W, T, C = 2, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + pred = torch.randn(B, H, W, T, requires_grad=True) + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs) + loss.backward() + assert pred.grad is not None + + def test_derivative_non_uniform_grid(self): + """Verify derivative uses non-uniform spacing correctly.""" + B, H, W, T, C = 1, 4, 6, 2, 12 + dx = torch.tensor([10.0, 20.0, 30.0, 15.0, 25.0, 10.0]) + dy = torch.tensor([5.0, 10.0, 5.0, 10.0]) + inputs = _make_2d_inputs(B, H, W, T, C, dx=dx, dy=dy) + + target = torch.zeros(B, H, W, T) + pred = torch.zeros(B, H, W, T) + pred[0, :, :, 0] = torch.arange(W).float().unsqueeze(0).expand(H, W) + + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 1.0, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + assert not torch.isnan(loss) + + def test_requires_inputs(self): + """Verify derivative loss raises ValueError when inputs are missing.""" + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + with pytest.raises(ValueError, match="inputs required"): + fn(torch.randn(2, 8, 8, 4), torch.randn(2, 8, 8, 4)) + + +# =================================================================== +# UnifiedLoss — Derivative Loss (3D) +# =================================================================== + + +class TestDerivativeLoss3D: + """Tests for DerivativeLoss3D.""" + + def test_derivative_all_3d_dims(self): + """Verify derivative loss with all three 3D spatial dimensions.""" + B, X, Y, Z, T, C = 2, 6, 8, 4, 3, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + pred = torch.randn(B, X, Y, Z, T) + target = torch.randn(B, X, Y, Z, T) + fn = UnifiedLoss( + types=["mse"], + derivative_config={ + "enabled": True, + "weight": 0.5, + "dims": ["dx", "dy", "dz"], + }, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + assert not torch.isnan(loss) + + def test_derivative_single_dim_3d(self): + """Verify each individual 3D derivative dimension works.""" + B, X, Y, Z, T, C = 1, 6, 8, 4, 2, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + pred = torch.randn(B, X, Y, Z, T) + target = torch.randn(B, X, Y, Z, T) + for dim in ["dx", "dy", "dz"]: + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": [dim]}, + ) + loss = fn(pred, target, inputs) + assert loss > 0, f"Failed for {dim}" + + def test_invalid_dim_for_2d_raises(self): + """Verify dz derivative on 2D data raises ValueError.""" + B, H, W, T, C = 1, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dz"]}, + ) + with pytest.raises(ValueError, match="not valid for 2D"): + fn(torch.randn(B, H, W, T), torch.randn(B, H, W, T), inputs) + + +class TestDerivativeWithMask: + """Tests for derivative loss with spatial masking (Norne-like sparse grids).""" + + def test_derivative_with_sparse_mask_no_nan(self): + """Derivative loss must not produce NaN on grids with many inactive cells.""" + B, X, Y, Z, T, C = 1, 10, 12, 4, 3, 11 + dx = torch.tensor([10.0, 20.0, 15.0, 10.0, 25.0, 30.0, 10.0, 20.0, 15.0, 10.0]) + dy = torch.ones(Y) * 5.0 + dz = torch.ones(Z) * 8.0 + inputs = _make_3d_inputs(B, X, Y, Z, T, C, dx=dx, dy=dy, dz=dz) + target = torch.randn(B, X, Y, Z, T) + pred = target + 0.1 * torch.randn_like(target) + + # Sparse mask: only 30% active (similar to Norne's 39%) + mask = torch.zeros(X, Y, Z, dtype=torch.bool) + mask[2:5, 3:8, :] = True + + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs, spatial_mask=mask) + assert not torch.isnan(loss), "Loss is NaN with sparse mask" + assert not torch.isinf(loss), "Loss is Inf with sparse mask" + + def test_derivative_all_dims_with_mask(self): + """All 3D derivative directions work with masking.""" + B, X, Y, Z, T, C = 1, 8, 10, 4, 2, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + target = torch.randn(B, X, Y, Z, T) + pred = target + 0.05 * torch.randn_like(target) + + mask = torch.zeros(X, Y, Z, dtype=torch.bool) + mask[1:6, 2:8, 1:3] = True + + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={ + "enabled": True, + "weight": 0.5, + "dims": ["dx", "dy", "dz"], + }, + ) + loss = fn(pred, target, inputs, spatial_mask=mask) + assert not torch.isnan(loss) + + def test_derivative_auto_detects_inactive_no_mask(self): + """Derivative loss auto-detects inactive cells even without spatial_mask.""" + B, X, Y, Z, T, C = 1, 8, 10, 4, 3, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + target = torch.randn(B, X, Y, Z, T) + # Make some cells inactive (zero across all timesteps) + target[:, :3, :, :, :] = 0.0 + pred = target + 0.1 * torch.randn_like(target) + pred[:, :3, :, :, :] = 0.0 # pred also zero there + + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + # No spatial_mask passed — auto-detection should handle it + loss = fn(pred, target, inputs, spatial_mask=None) + assert not torch.isnan(loss), "NaN with auto-detected inactive cells" + assert not torch.isinf(loss) + + def test_derivative_all_active_no_mask(self): + """When all cells are active and no mask, derivative works normally.""" + B, H, W, T, C = 1, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + 1.0 # no zeros + pred = target + 0.05 * torch.randn_like(target) + + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx", "dy"]}, + ) + loss = fn(pred, target, inputs, spatial_mask=None) + assert not torch.isnan(loss) + assert loss > 0 + + def test_derivative_mask_excludes_boundary_artifacts(self): + """Error at inactive cells should not affect derivative loss.""" + B, H, W, T, C = 1, 8, 12, 2, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + pred = target.clone() + # Large error only in masked-out region + pred[:, :3, :, :] += 100.0 + + mask = torch.zeros(H, W, dtype=torch.bool) + mask[4:, :] = True # only bottom half active + + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 1.0, "dims": ["dx", "dy"]}, + ) + loss = fn(pred, target, inputs, spatial_mask=mask) + # Data loss should be ~0, derivative loss should be ~0 + # because pred=target in the active region + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-4), ( + f"Expected ~0, got {loss.item()}" + ) + + +# =================================================================== +# get_loss_function factory +# =================================================================== + + +class TestFactory: + """Tests for Factory.""" + + def test_defaults(self): + """Verify default factory produces UnifiedLoss with relative_l2.""" + fn = get_loss_function({}) + assert isinstance(fn, UnifiedLoss) + assert fn.loss_types == ["relative_l2"] + + def test_with_physics(self): + """Verify factory registers physics losses from config.""" + cfg = { + "types": ["mse"], + "weights": [1.0], + "physics": {"mass_conservation": {"enabled": True, "weight": 0.5}}, + } + fn = get_loss_function(cfg, variable="saturation") + assert "mass_conservation" in fn._physics_losses + + def test_new_derivative_config(self): + """Verify factory applies derivative config from loss config.""" + cfg = { + "types": ["mse"], + "weights": [1.0], + "derivative": {"enabled": True, "weight": 0.3, "dims": ["dx", "dy"]}, + } + fn = get_loss_function(cfg) + assert fn._deriv_enabled is True + assert fn._deriv_weight == 0.3 + assert fn._deriv_dims == ["dx", "dy"] + + def test_pressure_warning(self): + """Verify mass conservation on pressure variable emits a warning.""" + cfg = { + "types": ["mse"], + "weights": [1.0], + "physics": {"mass_conservation": {"enabled": True, "weight": 1.0}}, + } + with pytest.warns(UserWarning, match="pressure"): + get_loss_function(cfg, variable="pressure") + + +# =================================================================== +# AR window compatibility +# =================================================================== + + +class TestARCompatibility: + """Tests for ARCompatibility.""" + + def test_single_timestep_2d(self): + """Verify loss with single-timestep 2D AR window.""" + B, H, W, C = 2, 8, 16, 12 + pred = torch.randn(B, H, W, 1) + target = torch.randn(B, H, W, 1) + inputs = _make_2d_inputs(B, H, W, 1, C) + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + + def test_small_window_3d(self): + """Verify loss with small 3D AR window (K=3).""" + B, X, Y, Z, C = 2, 6, 8, 4, 11 + K = 3 + pred = torch.randn(B, X, Y, Z, K) + target = torch.randn(B, X, Y, Z, K) + inputs = _make_3d_inputs(B, X, Y, Z, K, C) + fn = UnifiedLoss(types=["mse"]) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +# =================================================================== +# Derivative loss with all metric types +# =================================================================== + + +class TestDerivativeAllMetrics: + """Verify derivative loss works with every loss metric + masking.""" + + @pytest.mark.parametrize("metric", ["mse", "l1", "relative_l2", "huber"]) + def test_derivative_metric_2d(self, metric): + """All metrics produce finite loss for 2D derivative.""" + B, H, W, T, C = 1, 8, 16, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + 1.0 + pred = target + 0.05 * torch.randn_like(target) + + fn = UnifiedLoss( + types=["mse"], + derivative_config={ + "enabled": True, + "weight": 0.5, + "dims": ["dx"], + "metric": metric, + }, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss), f"NaN with metric={metric}" + assert loss > 0 + + @pytest.mark.parametrize("metric", ["mse", "l1", "relative_l2", "huber"]) + def test_derivative_metric_with_mask(self, metric): + """All metrics work with derivative + sparse mask.""" + B, X, Y, Z, T, C = 1, 8, 10, 4, 3, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + target = torch.randn(B, X, Y, Z, T) + 1.0 + pred = target + 0.05 * torch.randn_like(target) + + mask = torch.zeros(X, Y, Z, dtype=torch.bool) + mask[2:6, 2:8, 1:3] = True + + fn = UnifiedLoss( + types=["mse"], + derivative_config={ + "enabled": True, + "weight": 0.5, + "dims": ["dx"], + "metric": metric, + }, + ) + loss = fn(pred, target, inputs, spatial_mask=mask) + assert not torch.isnan(loss), f"NaN with metric={metric} + mask" + + +# =================================================================== +# Edge cases +# =================================================================== + + +class TestLossEdgeCases: + """Edge cases: minimum grids, single batch, small AR windows.""" + + def test_minimum_grid_for_derivative_2d(self): + """3 cells along derivative axis = minimum for central difference.""" + B, H, W, T, C = 1, 3, 3, 2, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + 1.0 + pred = target + 0.1 + + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx", "dy"]}, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + assert loss > 0 + + def test_minimum_grid_for_derivative_3d(self): + """3x3x3 grid with all 3 derivative dims.""" + B, X, Y, Z, T, C = 1, 3, 3, 3, 2, 11 + inputs = _make_3d_inputs(B, X, Y, Z, T, C) + target = torch.randn(B, X, Y, Z, T) + 1.0 + pred = target + 0.1 + + fn = UnifiedLoss( + types=["mse"], + derivative_config={ + "enabled": True, + "weight": 0.5, + "dims": ["dx", "dy", "dz"], + }, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + + def test_single_batch(self): + """B=1 with all loss components.""" + B, H, W, T, C = 1, 8, 8, 4, 12 + inputs = _make_2d_inputs(B, H, W, T, C) + target = torch.randn(B, H, W, T) + 1.0 + pred = target + 0.05 * torch.randn_like(target) + + from training.physics_losses import MassConservationLoss + + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.3, "dims": ["dx"]}, + physics_losses={"mc": (MassConservationLoss(), 0.5)}, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + + def test_k_equals_1_ar_window(self): + """K=1 (single output timestep) with derivative and physics.""" + B, X, Y, Z, C = 1, 6, 8, 4, 11 + inputs = _make_3d_inputs(B, X, Y, Z, 1, C) + target = torch.randn(B, X, Y, Z, 1) + 1.0 + pred = target + 0.1 + + from training.physics_losses import MassConservationLoss + + fn = UnifiedLoss( + types=["mse"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + physics_losses={"mc": (MassConservationLoss(), 1.0)}, + ) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + + def test_norne_regression(self): + """Regression test: Norne-like grid (39% active, normalized widths) + with relative_l2 derivative should NOT produce NaN.""" + B, X, Y, Z, K, C = 2, 46, 112, 22, 3, 11 + pred = torch.randn(B, X, Y, Z, K) + target = torch.randn(B, X, Y, Z, K) + inputs = torch.zeros(B, X, Y, Z, 1, C) + + mask = torch.zeros(X, Y, Z, dtype=torch.bool) + mask[5:35, 10:100, 2:18] = True + + gx = torch.linspace(0, 1, X).view(X, 1, 1).expand(X, Y, Z) * mask.float() + gy = torch.linspace(0, 1, Y).view(1, Y, 1).expand(X, Y, Z) * mask.float() + gz = torch.linspace(0, 1, Z).view(1, 1, Z).expand(X, Y, Z) * mask.float() + inputs[..., 0, -4] = gx.unsqueeze(0).expand(B, -1, -1, -1) + inputs[..., 0, -3] = gy.unsqueeze(0).expand(B, -1, -1, -1) + inputs[..., 0, -2] = gz.unsqueeze(0).expand(B, -1, -1, -1) + + pred[:, ~mask, :] = 0.0 + target[:, ~mask, :] = 0.0 + + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx"]}, + ) + loss = fn(pred, target, inputs, spatial_mask=mask) + assert not torch.isnan(loss), "Norne regression: NaN detected" + assert not torch.isinf(loss), "Norne regression: Inf detected" diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_metrics.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_metrics.py new file mode 100644 index 0000000000..2b61083213 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_metrics.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for evaluation metrics (numpy and torch).""" + +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from training.metrics import ( + compute_r2_score, + compute_relative_l1_error, + compute_relative_l2_error, + mae_torch, + max_absolute_error, + max_error_torch, + mean_absolute_error, + mean_plume_error, + mean_relative_error, + mse_torch, + normalized_mse, + peak_signal_to_noise_ratio, + psnr_torch, + r2_score_torch, + relative_l1_torch, + relative_l2_torch, + rmse_torch, +) + + +class TestNumpyMetrics: + """Tests for numpy-based metrics.""" + + def test_mae_known_value(self): + """Verify MAE computes expected value for known inputs.""" + y_true = np.array([1.0, 2.0, 3.0]) + y_pred = np.array([1.5, 2.5, 3.5]) + assert abs(mean_absolute_error(y_pred, y_true) - 0.5) < 1e-8 + + def test_mae_identical(self): + """Verify MAE is zero when prediction equals target.""" + y = np.random.randn(100) + assert mean_absolute_error(y, y) == 0.0 + + def test_max_absolute_error(self): + """Verify max absolute error returns the largest element-wise difference.""" + y_true = np.array([0.0, 0.0, 0.0]) + y_pred = np.array([1.0, 3.0, 2.0]) + assert abs(max_absolute_error(y_pred, y_true) - 3.0) < 1e-8 + + def test_mre_known_value(self): + """Verify mean relative error computes expected value for known inputs.""" + y_true = np.array([10.0, 20.0, 30.0]) + y_pred = np.array([11.0, 21.0, 31.0]) + data_range = 30.0 - 10.0 # 20 + expected = np.mean(np.abs(y_pred - y_true)) / data_range # 1/20 = 0.05 + assert abs(mean_relative_error(y_pred, y_true) - expected) < 1e-8 + + def test_mpe_only_plume_region(self): + """Verify mean plume error only considers cells above plume threshold.""" + y_true = np.array([0.0, 0.0, 0.5, 0.8]) + y_pred = np.array([0.0, 0.0, 0.6, 0.9]) + mpe = mean_plume_error(y_pred, y_true) + assert abs(mpe - 0.1) < 1e-8 + + def test_mpe_no_plume(self): + """Verify mean plume error is zero when no plume region exists.""" + y_true = np.zeros(10) + y_pred = np.zeros(10) + assert mean_plume_error(y_pred, y_true) == 0.0 + + def test_r2_perfect(self): + """Verify R2 score is 1.0 for a perfect prediction.""" + y = np.random.randn(50) + assert abs(compute_r2_score(y, y) - 1.0) < 1e-8 + + def test_r2_constant_prediction(self): + """Verify R2 score is zero when prediction equals the target mean.""" + y_true = np.array([1.0, 2.0, 3.0, 4.0]) + y_pred = np.full_like(y_true, y_true.mean()) + assert abs(compute_r2_score(y_pred, y_true)) < 1e-8 + + def test_r2_negative_for_bad_prediction(self): + """Verify R2 score is negative for a prediction worse than the mean.""" + y_true = np.array([1.0, 2.0, 3.0]) + y_pred = np.array([10.0, 20.0, 30.0]) + assert compute_r2_score(y_pred, y_true) < 0 + + def test_relative_l2_known(self): + """Verify relative L2 error equals 1.0 when prediction is all zeros.""" + y_true = np.array([3.0, 4.0]) # norm = 5 + y_pred = np.array([0.0, 0.0]) # diff norm = 5 + assert abs(compute_relative_l2_error(y_pred, y_true) - 1.0) < 1e-6 + + def test_relative_l1_known(self): + """Verify relative L1 error equals 1.0 when prediction is all zeros.""" + y_true = np.array([1.0, 2.0, 3.0]) # L1 norm = 6 + y_pred = np.array([0.0, 0.0, 0.0]) # diff L1 = 6 + assert abs(compute_relative_l1_error(y_pred, y_true) - 1.0) < 1e-6 + + def test_nmse_variance(self): + """Verify normalized MSE with variance normalization for known inputs.""" + y_true = np.array([1.0, 3.0]) # var = 1 + y_pred = np.array([2.0, 2.0]) # mse = 1 + assert abs(normalized_mse(y_pred, y_true, "variance") - 1.0) < 1e-6 + + def test_psnr_perfect(self): + """Verify PSNR is infinite for a perfect prediction.""" + y = np.random.randn(100) + assert peak_signal_to_noise_ratio(y, y) == float("inf") + + def test_psnr_known(self): + """Verify PSNR computes expected dB value for known MSE and data range.""" + y_true = np.array([0.0, 1.0]) + y_pred = np.array([0.0, 0.9]) # mse = 0.005, range = 1 + psnr = peak_signal_to_noise_ratio(y_pred, y_true) + expected = 20 * np.log10(1.0 / np.sqrt(0.005)) + assert abs(psnr - expected) < 1e-4 + + +class TestTorchMetrics: + """Tests for torch-based metrics.""" + + def test_mse_torch_value(self): + """Verify torch MSE computes expected value for known inputs.""" + pred = torch.tensor([1.0, 2.0, 3.0]) + target = torch.tensor([1.5, 2.5, 3.5]) + assert torch.isclose(mse_torch(pred, target), torch.tensor(0.25)) + + def test_rmse_torch_value(self): + """Verify torch RMSE computes expected value for known inputs.""" + pred = torch.tensor([1.0, 2.0]) + target = torch.tensor([2.0, 3.0]) + assert torch.isclose(rmse_torch(pred, target), torch.tensor(1.0)) + + def test_mae_torch_value(self): + """Verify torch MAE computes expected value for known inputs.""" + pred = torch.tensor([0.0, 0.0]) + target = torch.tensor([1.0, 3.0]) + assert torch.isclose(mae_torch(pred, target), torch.tensor(2.0)) + + def test_relative_l2_torch_known(self): + """Verify torch relative L2 error equals 1.0 for zero prediction.""" + target = torch.tensor([3.0, 4.0]) # norm = 5 + pred = torch.zeros(2) + assert torch.isclose( + relative_l2_torch(pred, target), torch.tensor(1.0), atol=1e-6 + ) + + def test_relative_l1_torch_known(self): + """Verify torch relative L1 error equals 1.0 for zero prediction.""" + target = torch.tensor([1.0, 2.0, 3.0]) + pred = torch.zeros(3) + assert torch.isclose( + relative_l1_torch(pred, target), torch.tensor(1.0), atol=1e-6 + ) + + def test_r2_torch_perfect(self): + """Verify torch R2 score is 1.0 for a perfect prediction.""" + y = torch.randn(50) + assert torch.isclose(r2_score_torch(y, y), torch.tensor(1.0)) + + def test_r2_torch_negative(self): + """Verify torch R2 score is negative for a prediction worse than the mean.""" + target = torch.tensor([1.0, 2.0, 3.0]) + pred = torch.tensor([10.0, 20.0, 30.0]) + assert r2_score_torch(pred, target) < 0 + + def test_max_error_torch(self): + """Verify torch max error returns the largest element-wise difference.""" + pred = torch.tensor([0.0, 0.0]) + target = torch.tensor([1.0, 5.0]) + assert torch.isclose(max_error_torch(pred, target), torch.tensor(5.0)) + + def test_psnr_torch_perfect(self): + """Verify torch PSNR is infinite for a perfect prediction.""" + y = torch.randn(50) + assert psnr_torch(y, y) == float("inf") + + +class TestTorchNumpyConsistency: + """Verify torch and numpy metrics agree on the same data.""" + + def test_mae_consistency(self): + """Verify numpy and torch MAE agree on the same random data.""" + pred_np = np.random.randn(100) + target_np = np.random.randn(100) + np_val = mean_absolute_error(pred_np, target_np) + torch_val = mae_torch(torch.tensor(pred_np), torch.tensor(target_np)).item() + assert abs(np_val - torch_val) < 1e-6 + + def test_relative_l2_consistency(self): + """Verify numpy and torch relative L2 error agree on the same data.""" + pred_np = np.random.randn(100) + target_np = np.random.randn(100) + 2 + np_val = compute_relative_l2_error(pred_np, target_np) + torch_val = relative_l2_torch( + torch.tensor(pred_np), torch.tensor(target_np) + ).item() + assert abs(np_val - torch_val) < 1e-5 + + def test_r2_consistency(self): + """Verify numpy and torch R2 score agree on the same data.""" + pred_np = np.random.randn(100) + target_np = np.random.randn(100) + np_val = compute_r2_score(pred_np, target_np) + torch_val = r2_score_torch( + torch.tensor(pred_np), torch.tensor(target_np) + ).item() + assert abs(np_val - torch_val) < 1e-5 + + +class TestEdgeCases: + """Edge cases for metrics.""" + + def test_single_element(self): + """Verify metrics handle single-element arrays correctly.""" + y = np.array([5.0]) + assert mean_absolute_error(y, y) == 0.0 + assert compute_r2_score(y, y) == 1.0 + + def test_zero_target_relative_l2(self): + """Verify relative L2 error returns a finite value when target is zero.""" + pred = np.array([1.0]) + target = np.array([0.0]) + val = compute_relative_l2_error(pred, target) + assert np.isfinite(val) + + def test_constant_target_r2(self): + """Verify R2 score is zero when target has zero variance.""" + target = np.ones(10) + pred = np.ones(10) * 1.1 + r2 = compute_r2_score(pred, target) + assert r2 == 0.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_padding.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_padding.py new file mode 100644 index 0000000000..b99301d218 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_padding.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for padding utilities (dimension-agnostic).""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from utils.padding import ( + compute_right_pad_to_multiple, + compute_right_pad_to_multiple_per_dim, + pad_right_nd, + pad_spatial_right, +) + + +class TestComputePads: + """Tests for ComputePads.""" + + def test_compute_right_pad_to_multiple(self): + """Verify right-pad computation aligns dimensions to a given multiple.""" + assert compute_right_pad_to_multiple((16, 24), multiple=8, min_right_pad=0) == ( + 0, + 0, + ) + assert compute_right_pad_to_multiple((17, 24), multiple=8, min_right_pad=0) == ( + 7, + 0, + ) + assert compute_right_pad_to_multiple((16, 24), multiple=8, min_right_pad=8) == ( + 8, + 8, + ) + # Ensure min_right_pad does not break alignment (d + pad must stay multiple-of-8) + assert compute_right_pad_to_multiple((46,), multiple=8, min_right_pad=8) == ( + 10, + ) + + def test_compute_right_pad_to_multiple_per_dim(self): + """Verify per-dimension padding respects independent min-pad constraints.""" + assert compute_right_pad_to_multiple_per_dim( + (16, 17), multiple=8, min_right_pad=(0, 0) + ) == (0, 7) + assert compute_right_pad_to_multiple_per_dim( + (16, 17), multiple=8, min_right_pad=(8, 0) + ) == (8, 7) + + +class TestPadRightNd: + """Tests for PadRightNd.""" + + def test_replicate_right_pad_6d(self): + """Verify replicate padding extends a 6D tensor along the T dimension.""" + # Shape: (B, X, Y, Z, T, C) + x = torch.zeros(1, 1, 1, 1, 2, 1) + x[..., 0, 0] = 10.0 + x[..., 1, 0] = 20.0 + + y = pad_right_nd(x, dims=(4,), right_pad=(3,), mode="replicate") + assert y.shape == (1, 1, 1, 1, 5, 1) + # Last value should replicate the original last along T (20) + assert y[0, 0, 0, 0, -1, 0].item() == 20.0 + + +class TestPadSpatialRight: + """Tests for PadSpatialRight.""" + + def test_2d_spatial_keeps_rest(self): + """Verify 2D spatial padding grows H and W while preserving T and C.""" + x = torch.randn(2, 5, 7, 3, 4) # (B,H,W,T,C) + y = pad_spatial_right(x, spatial_ndim=2, right_pad=(1, 2), mode="replicate") + assert y.shape == (2, 6, 9, 3, 4) + + def test_3d_spatial_includes_time_when_requested(self): + """Verify 3D spatial padding pads H, W, and T dimensions.""" + x = torch.randn(2, 5, 7, 3, 4) # (B,H,W,T,C) + y = pad_spatial_right(x, spatial_ndim=3, right_pad=(1, 2, 3), mode="replicate") + assert y.shape == (2, 6, 9, 6, 4) + + def test_4d_spatial_works_for_6d_inputs(self): + """Verify 4D spatial padding works on 6D inputs with replicate mode.""" + # (B,X,Y,Z,T,C) + x = torch.tensor([[[[[[10.0], [20.0]]]]]]) # (1,1,1,1,2,1) + y = pad_spatial_right( + x, spatial_ndim=4, right_pad=(1, 1, 1, 2), mode="replicate" + ) + assert y.shape == (1, 2, 2, 2, 4, 1) + assert y[0, -1, -1, -1, -1, 0].item() == 20.0 + + +class TestPaddingAdditional: + """Additional padding tests for edge cases.""" + + def test_zero_padding_returns_unchanged(self): + """Verify zero padding returns an identical tensor.""" + from utils.padding import pad_spatial_right + + x = torch.randn(2, 8, 16, 4, 3) + out = pad_spatial_right(x, spatial_ndim=2, right_pad=(0, 0)) + assert torch.equal(x, out) + + def test_non_uniform_per_dim_padding(self): + """Verify per-dim padding satisfies both alignment and minimum-pad constraints.""" + from utils.padding import compute_right_pad_to_multiple_per_dim + + pads = compute_right_pad_to_multiple_per_dim( + (10, 13, 7), multiple=8, min_right_pad=[2, 4, 1] + ) + for i, (orig, pad) in enumerate(zip([10, 13, 7], pads)): + assert (orig + pad) % 8 == 0, f"Dim {i}: {orig}+{pad} not multiple of 8" + assert pad >= [2, 4, 1][i], f"Dim {i}: pad {pad} < min {[2, 4, 1][i]}" + + def test_constant_mode(self): + """Verify constant-mode padding fills new cells with the specified value.""" + from utils.padding import pad_right_nd + + x = torch.zeros(2, 4, 6) + out = pad_right_nd( + x, dims=[1, 2], right_pad=[2, 3], mode="constant", constant_value=99.0 + ) + assert out.shape == (2, 6, 9) + assert out[0, 5, 0].item() == 99.0 + assert out[0, 0, 7].item() == 99.0 + + def test_replicate_mode(self): + """Verify replicate-mode padding copies the last value along padded dims.""" + from utils.padding import pad_right_nd + + x = torch.arange(4).float().unsqueeze(0).unsqueeze(0) # (1, 1, 4) + out = pad_right_nd(x, dims=[2], right_pad=[2], mode="replicate") + assert out.shape == (1, 1, 6) + assert out[0, 0, 4] == out[0, 0, 3] + assert out[0, 0, 5] == out[0, 0, 3] + + def test_invalid_spatial_ndim_raises(self): + """Verify ValueError is raised for an unsupported spatial_ndim.""" + from utils.padding import pad_spatial_right + + with pytest.raises(ValueError, match="spatial_ndim must be"): + pad_spatial_right(torch.randn(2, 4, 4), spatial_ndim=1, right_pad=(2,)) + + def test_wrong_right_pad_length_raises(self): + """Verify ValueError is raised when right_pad length mismatches spatial_ndim.""" + from utils.padding import pad_spatial_right + + with pytest.raises(ValueError, match="right_pad must have length"): + pad_spatial_right( + torch.randn(2, 4, 4, 3), spatial_ndim=2, right_pad=(2, 3, 4) + ) + + def test_4d_spatial_padding(self): + """Verify 4D spatial padding produces correct output shape on 6D input.""" + from utils.padding import pad_spatial_right + + x = torch.randn(2, 4, 6, 3, 5, 8) # (B, X, Y, Z, T, C) + out = pad_spatial_right(x, spatial_ndim=4, right_pad=(1, 2, 1, 3)) + assert out.shape == (2, 5, 8, 4, 8, 8) + + def test_multiple_of_8_already_aligned(self): + """Verify zero padding when dimensions are already aligned to the multiple.""" + from utils.padding import compute_right_pad_to_multiple + + pads = compute_right_pad_to_multiple((16, 24), multiple=8, min_right_pad=0) + assert pads == (0, 0) + + def test_multiple_with_min_pad(self): + """Verify padding meets the minimum pad constraint while staying aligned.""" + from utils.padding import compute_right_pad_to_multiple + + pads = compute_right_pad_to_multiple((16,), multiple=8, min_right_pad=4) + assert pads[0] >= 4 + assert (16 + pads[0]) % 8 == 0 diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_physics_losses.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_physics_losses.py new file mode 100644 index 0000000000..acb29e22c3 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_physics_losses.py @@ -0,0 +1,698 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Comprehensive tests for physics losses with analytically-verifiable dummy reservoirs.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from training.losses import UnifiedLoss, get_loss_function +from training.physics_losses import ( + MassConservationLoss, + _extract_grid_widths, + build_physics_losses, + cell_centre_distance, + central_difference, + compute_cell_volumes_from_widths, + extract_grid_widths_for_axis, +) + +# =================================================================== +# Dummy Reservoir Fixtures +# =================================================================== +# 2D reservoir: H=4, W=6, T=3, C=12 +# dx = [10, 20, 30, 15, 25, 10] metres (varies along W, channel -3) +# dy = [5, 10, 5, 10] metres (varies along H, channel -2) +# Total area = sum_h sum_w (dy[h] * dx[w]) +# = (5+10+5+10) * (10+20+30+15+25+10) = 30 * 110 = 3300 m^2 +# +# 3D reservoir: X=3, Y=4, Z=2, T=2, C=11 +# dx = [10, 20, 15] metres (varies along X, channel -4) +# dy = [5, 10, 5, 10] metres (varies along Y, channel -3) +# dz = [8, 12] metres (varies along Z, channel -2) +# Total volume = 45 * 30 * 20 = 27000 m^3 + +DX_2D = torch.tensor([10.0, 20.0, 30.0, 15.0, 25.0, 10.0]) +DY_2D = torch.tensor([5.0, 10.0, 5.0, 10.0]) +TOTAL_AREA_2D = DY_2D.sum().item() * DX_2D.sum().item() # 3300 + +DX_3D = torch.tensor([10.0, 20.0, 15.0]) +DY_3D = torch.tensor([5.0, 10.0, 5.0, 10.0]) +DZ_3D = torch.tensor([8.0, 12.0]) +TOTAL_VOL_3D = DX_3D.sum().item() * DY_3D.sum().item() * DZ_3D.sum().item() # 27000 + + +def _build_2d_reservoir(B=1, T=3, saturation_fn=None): + """Build a 2D reservoir with known grid widths and saturation field. + + Returns (inputs, target, pred) where: + - inputs: (B, 4, 6, T, 12) + - target: (B, 4, 6, T) saturation field + - pred: (B, 4, 6, T) prediction (= target by default) + """ + H, W, C = 4, 6, 12 + inputs = torch.zeros(B, H, W, T, C) + inputs[..., -3] = DX_2D.view(1, 1, W, 1).expand(B, H, W, T) + inputs[..., -2] = DY_2D.view(1, H, 1, 1).expand(B, H, W, T) + inputs[..., -1] = torch.linspace(0, 10, T).view(1, 1, 1, T).expand(B, H, W, T) + + if saturation_fn is not None: + target = saturation_fn(B, H, W, T) + else: + target = torch.ones(B, H, W, T) * 0.5 + + return inputs, target, target.clone() + + +def _build_3d_reservoir(B=1, T=2, saturation_fn=None): + """Build a 3D reservoir with known grid widths. + + Returns (inputs, target, pred). + """ + X, Y, Z, C = 3, 4, 2, 11 + inputs = torch.zeros(B, X, Y, Z, T, C) + inputs[..., -4] = DX_3D.view(1, X, 1, 1, 1).expand(B, X, Y, Z, T) + inputs[..., -3] = DY_3D.view(1, 1, Y, 1, 1).expand(B, X, Y, Z, T) + inputs[..., -2] = DZ_3D.view(1, 1, 1, Z, 1).expand(B, X, Y, Z, T) + inputs[..., -1] = torch.linspace(0, 10, T).view(1, 1, 1, 1, T).expand(B, X, Y, Z, T) + + if saturation_fn is not None: + target = saturation_fn(B, X, Y, Z, T) + else: + target = torch.ones(B, X, Y, Z, T) * 0.3 + + return inputs, target, target.clone() + + +# =================================================================== +# Grid width extraction +# =================================================================== + + +class TestExtractGridWidths: + """Tests for ExtractGridWidths.""" + + def test_2d_extraction(self): + """Verify 2D grid width extraction returns correct dx and dy vectors.""" + inputs, _, _ = _build_2d_reservoir() + widths = _extract_grid_widths(inputs, spatial_ndim=2) + assert len(widths) == 2 + assert torch.allclose(widths[0], DY_2D) # axis-1 = H = dy + assert torch.allclose(widths[1], DX_2D) # axis-2 = W = dx + + def test_3d_extraction(self): + """Verify 3D grid width extraction returns correct dx, dy, and dz vectors.""" + inputs, _, _ = _build_3d_reservoir() + widths = _extract_grid_widths(inputs, spatial_ndim=3) + assert len(widths) == 3 + assert torch.allclose(widths[0], DX_3D) # axis-1 = X = dx + assert torch.allclose(widths[1], DY_3D) # axis-2 = Y = dy + assert torch.allclose(widths[2], DZ_3D) # axis-3 = Z = dz + + +# =================================================================== +# Cell volume computation +# =================================================================== + + +class TestCellVolumes: + """Tests for CellVolumes.""" + + def test_2d_total_area(self): + """Total area of 2D reservoir should be 3300 m^2.""" + inputs, _, _ = _build_2d_reservoir() + vol = compute_cell_volumes_from_widths(inputs, spatial_ndim=2) + assert vol.shape == (4, 6) + total = vol.sum().item() + assert abs(total - TOTAL_AREA_2D) < 1e-3, ( + f"Expected {TOTAL_AREA_2D}, got {total}" + ) + + def test_2d_individual_cells(self): + """Spot-check: cell [0,0] should have area dy[0]*dx[0] = 5*10 = 50.""" + inputs, _, _ = _build_2d_reservoir() + vol = compute_cell_volumes_from_widths(inputs, spatial_ndim=2) + assert torch.isclose(vol[0, 0], torch.tensor(50.0)) + # cell [1,2] = dy[1]*dx[2] = 10*30 = 300 + assert torch.isclose(vol[1, 2], torch.tensor(300.0)) + # cell [3,5] = dy[3]*dx[5] = 10*10 = 100 + assert torch.isclose(vol[3, 5], torch.tensor(100.0)) + + def test_3d_total_volume(self): + """Total volume of 3D reservoir should be 27000 m^3.""" + inputs, _, _ = _build_3d_reservoir() + vol = compute_cell_volumes_from_widths(inputs, spatial_ndim=3) + assert vol.shape == (3, 4, 2) + total = vol.sum().item() + assert abs(total - TOTAL_VOL_3D) < 1e-2, f"Expected {TOTAL_VOL_3D}, got {total}" + + def test_3d_individual_cells(self): + """Spot-check: cell [0,0,0] = dx[0]*dy[0]*dz[0] = 10*5*8 = 400.""" + inputs, _, _ = _build_3d_reservoir() + vol = compute_cell_volumes_from_widths(inputs, spatial_ndim=3) + assert torch.isclose(vol[0, 0, 0], torch.tensor(400.0)) + # cell [1,2,1] = 20*5*12 = 1200 + assert torch.isclose(vol[1, 2, 1], torch.tensor(1200.0)) + + +# =================================================================== +# Cell centre distance +# =================================================================== + + +class TestCellCentreDistance: + """Tests for CellCentreDistance.""" + + def test_uniform_grid(self): + """Verify cell-centre distances are constant on a uniform grid.""" + widths = torch.tensor([2.0, 2.0, 2.0, 2.0]) + d = cell_centre_distance(widths) + assert d.shape == (2,) + # d[i] = w[i]/2 + w[i+1] + w[i+2]/2 = 1+2+1 = 4 + assert torch.allclose(d, torch.tensor([4.0, 4.0])) + + def test_non_uniform(self): + """Verify cell-centre distances match hand-computed values on a non-uniform grid.""" + widths = DX_2D # [10, 20, 30, 15, 25, 10] + d = cell_centre_distance(widths) + assert d.shape == (4,) + # d[0] = 10/2 + 20 + 30/2 = 5 + 20 + 15 = 40 + assert torch.isclose(d[0], torch.tensor(40.0)) + # d[1] = 20/2 + 30 + 15/2 = 10 + 30 + 7.5 = 47.5 + assert torch.isclose(d[1], torch.tensor(47.5)) + + +# =================================================================== +# Central difference +# =================================================================== + + +class TestCentralDifference: + """Tests for CentralDifference.""" + + def test_linear_field_2d(self): + """Derivative of linear field f(w) = w should be 1/spacing.""" + B, H, W, T = 1, 4, 6, 2 + field = torch.arange(W, dtype=torch.float).view(1, 1, W, 1).expand(B, H, W, T) + widths = torch.ones(W) + spacing = cell_centre_distance(widths) # all = 2.0 + + deriv = central_difference(field, axis=2, spacing=spacing) + # (f[i+2] - f[i]) / 2 = 2/2 = 1 for a linear field + assert deriv.shape == (B, H, W - 2, T) + assert torch.allclose(deriv, torch.ones_like(deriv)) + + def test_non_uniform_spacing(self): + """Known values with non-uniform grid.""" + _B, _H, _W, _T = 1, 1, 6, 1 + # f = [0, 1, 3, 6, 10, 15] + field = torch.tensor([0.0, 1.0, 3.0, 6.0, 10.0, 15.0]).view(1, 1, 6, 1) + widths = DX_2D # [10, 20, 30, 15, 25, 10] + spacing = cell_centre_distance(widths) # [40, 47.5, 32.5, 27.5] + + deriv = central_difference(field, axis=2, spacing=spacing) + # d[0] = (f[2]-f[0])/40 = 3/40 = 0.075 + assert torch.isclose(deriv[0, 0, 0, 0], torch.tensor(3.0 / 40.0)) + # d[1] = (f[3]-f[1])/47.5 = 5/47.5 + assert torch.isclose(deriv[0, 0, 1, 0], torch.tensor(5.0 / 47.5)) + + def test_3d_axis(self): + """Verify central difference output shape for a 3D field along axis 1.""" + B, X, Y, Z, T = 1, 6, 4, 3, 2 + field = torch.randn(B, X, Y, Z, T) + spacing = cell_centre_distance(torch.ones(X)) + deriv = central_difference(field, axis=1, spacing=spacing) + assert deriv.shape == (B, X - 2, Y, Z, T) + + +# =================================================================== +# Mass Conservation Loss — Analytic 2D Reservoir +# =================================================================== + + +class TestMassConservation2DReservoir: + """Tests for MassConservation2DReservoir.""" + + def test_zero_loss_identical_fields(self): + """Verify mass conservation loss is zero when prediction equals target.""" + inputs, target, pred = _build_2d_reservoir() + loss = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + def test_known_mass_imbalance(self): + """Add extra saturation to one cell and verify the loss reflects it.""" + inputs, target, pred = _build_2d_reservoir(T=1) + # target: all 0.5, so total mass at t=0 = 0.5 * 3300 = 1650 + # Add 1.0 to cell [0,0] (volume=50): pred mass = 1650 + 50 = 1700 + pred[0, 0, 0, 0] += 1.0 + + loss_fn = MassConservationLoss(use_cell_volumes=True) + loss = loss_fn(pred, target, inputs) + + # Analytic: M_true = [1650], M_pred = [1700] + # L = |1650-1700| / (|1650| + eps) = 50/1650 ~ 0.03030 + expected = 50.0 / 1650.0 + assert abs(loss.item() - expected) < 1e-4, ( + f"Expected ~{expected:.4f}, got {loss.item():.4f}" + ) + + def test_uniform_vs_volume_weighted(self): + """Volume weighting should change the loss when grid is non-uniform.""" + inputs, target, _ = _build_2d_reservoir(T=1) + pred = target.clone() + pred[0, 0, 0, 0] += 1.0 # cell [0,0] has small volume (50) + pred_large = target.clone() + pred_large[0, 1, 2, 0] += 1.0 # cell [1,2] has large volume (300) + + loss_fn_vol = MassConservationLoss(use_cell_volumes=True) + loss_fn_uni = MassConservationLoss(use_cell_volumes=False) + + # With volume weighting, perturbation in larger cell has more impact + loss_small_vol = loss_fn_vol(pred, target, inputs).item() + loss_large_vol = loss_fn_vol(pred_large, target, inputs).item() + assert loss_large_vol > loss_small_vol + + # Without volume weighting, both perturbations add same amount (+1.0 to sum) + loss_small_uni = loss_fn_uni(pred, target, inputs).item() + loss_large_uni = loss_fn_uni(pred_large, target, inputs).item() + assert abs(loss_small_uni - loss_large_uni) < 1e-6 + + def test_mass_conservation_multi_timestep(self): + """Mass error at different timesteps contributes via L2 norm.""" + inputs, target, pred = _build_2d_reservoir(T=3) + # Perturb only at t=0 + pred[0, 0, 0, 0] += 1.0 + loss_t0_only = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + + # Now also perturb at t=1 + pred[0, 0, 0, 1] += 1.0 + loss_t01 = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + + # L2 norm over time: more errors => higher loss + assert loss_t01 > loss_t0_only + + def test_with_spatial_mask(self): + """Mask out the perturbed cell => loss should be zero.""" + inputs, target, pred = _build_2d_reservoir(T=1) + pred[0, 0, 0, 0] += 10.0 + + mask = torch.ones(4, 6, dtype=torch.bool) + mask[0, 0] = False # mask out the perturbed cell + + loss = MassConservationLoss(use_cell_volumes=True)( + pred, target, inputs, spatial_mask=mask + ) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) + + +# =================================================================== +# Mass Conservation Loss — Analytic 3D Reservoir +# =================================================================== + + +class TestMassConservation3DReservoir: + """Tests for MassConservation3DReservoir.""" + + def test_zero_loss_identical(self): + """Verify 3D mass conservation loss is zero when prediction equals target.""" + inputs, target, pred = _build_3d_reservoir() + loss = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + def test_known_mass_imbalance_3d(self): + """Perturb one cell and check analytic loss.""" + inputs, target, pred = _build_3d_reservoir(T=1) + # target: all 0.3, total mass = 0.3 * 27000 = 8100 + # Perturb cell [0,0,0] (volume=400): pred mass = 8100 + 400*1.0 = 8500 + pred[0, 0, 0, 0, 0] += 1.0 + + loss = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + expected = 400.0 / 8100.0 + assert abs(loss.item() - expected) < 1e-3, ( + f"Expected ~{expected:.4f}, got {loss.item():.4f}" + ) + + def test_gradient_flow(self): + """Verify gradients propagate through the 3D mass conservation loss.""" + inputs, target, _ = _build_3d_reservoir() + pred = torch.randn_like(target, requires_grad=True) + loss = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + loss.backward() + assert pred.grad is not None + + def test_scale_invariance(self): + """Verify relative mass conservation loss is invariant to uniform scaling.""" + inputs, target, _ = _build_3d_reservoir() + pred = target + 0.1 * torch.randn_like(target) + fn = MassConservationLoss(use_cell_volumes=True) + l1 = fn(pred, target, inputs) + # Reset cache for fresh volumes + fn2 = MassConservationLoss(use_cell_volumes=True) + l2 = fn2(pred * 5, target * 5, inputs) + assert torch.isclose(l1, l2, rtol=1e-4) + + +# =================================================================== +# Derivative utilities — extract_grid_widths_for_axis +# =================================================================== + + +class TestExtractGridWidthsForAxis: + """Tests for ExtractGridWidthsForAxis.""" + + def test_2d_dx(self): + """Verify dx extraction from 2D reservoir inputs matches expected widths.""" + inputs, _, _ = _build_2d_reservoir() + w = extract_grid_widths_for_axis(inputs, spatial_ndim=2, dim_name="dx") + assert torch.allclose(w, DX_2D) + + def test_2d_dy(self): + """Verify dy extraction from 2D reservoir inputs matches expected widths.""" + inputs, _, _ = _build_2d_reservoir() + w = extract_grid_widths_for_axis(inputs, spatial_ndim=2, dim_name="dy") + assert torch.allclose(w, DY_2D) + + def test_3d_all(self): + """Verify dx, dy, and dz extraction from 3D reservoir inputs.""" + inputs, _, _ = _build_3d_reservoir() + assert torch.allclose(extract_grid_widths_for_axis(inputs, 3, "dx"), DX_3D) + assert torch.allclose(extract_grid_widths_for_axis(inputs, 3, "dy"), DY_3D) + assert torch.allclose(extract_grid_widths_for_axis(inputs, 3, "dz"), DZ_3D) + + def test_invalid_dim_raises(self): + """Verify ValueError is raised for an unknown derivative dimension name.""" + inputs, _, _ = _build_2d_reservoir() + with pytest.raises(ValueError, match="Unknown derivative dim"): + extract_grid_widths_for_axis(inputs, 2, "dz") + + +# =================================================================== +# build_physics_losses +# =================================================================== + + +class TestBuildPhysicsLosses: + """Tests for BuildPhysicsLosses.""" + + def test_none_config(self): + """Verify build_physics_losses returns empty dict for None config.""" + assert build_physics_losses(None) == {} + + def test_disabled(self): + """Verify build_physics_losses returns empty dict when loss is disabled.""" + assert build_physics_losses({"mass_conservation": {"enabled": False}}) == {} + + def test_enabled(self): + """Verify build_physics_losses creates a MassConservationLoss with correct weight.""" + result = build_physics_losses( + {"mass_conservation": {"enabled": True, "weight": 0.5}} + ) + assert "mass_conservation" in result + mod, w = result["mass_conservation"] + assert isinstance(mod, MassConservationLoss) + assert w == 0.5 + + def test_use_cell_volumes_flag(self): + """Verify use_cell_volumes flag is propagated to MassConservationLoss.""" + cfg = { + "mass_conservation": { + "enabled": True, + "weight": 1.0, + "use_cell_volumes": True, + } + } + mod, _ = build_physics_losses(cfg)["mass_conservation"] + assert mod.use_cell_volumes is True + + def test_pressure_warning(self): + """Verify a UserWarning is emitted when applying mass conservation to pressure.""" + cfg = {"mass_conservation": {"enabled": True, "weight": 1.0}} + with pytest.warns(UserWarning, match="pressure"): + build_physics_losses(cfg, variable="pressure") + + def test_zero_weight_skipped(self): + """Verify a physics loss with zero weight is omitted from the result.""" + assert ( + build_physics_losses( + {"mass_conservation": {"enabled": True, "weight": 0.0}} + ) + == {} + ) + + +# =================================================================== +# Integration: UnifiedLoss + physics on dummy reservoir +# =================================================================== + + +class TestIntegrationReservoir: + """Tests for IntegrationReservoir.""" + + def test_full_pipeline_2d(self): + """End-to-end: data loss + derivative + mass conservation on 2D reservoir.""" + B, T = 1, 3 + inputs, target, pred = _build_2d_reservoir(B=B, T=T) + pred = pred + 0.01 * torch.randn_like(pred) + + fn = UnifiedLoss( + types=["mse"], + weights=[1.0], + derivative_config={"enabled": True, "weight": 0.1, "dims": ["dx"]}, + physics_losses={"mc": (MassConservationLoss(use_cell_volumes=True), 0.5)}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + assert not torch.isnan(loss) + + def test_full_pipeline_3d(self): + """End-to-end on 3D reservoir with all derivative dims.""" + B, T = 1, 2 + inputs, target, pred = _build_3d_reservoir(B=B, T=T) + pred = pred + 0.01 * torch.randn_like(pred) + + fn = UnifiedLoss( + types=["relative_l2"], + weights=[1.0], + derivative_config={ + "enabled": True, + "weight": 0.1, + "dims": ["dx", "dy", "dz"], + }, + physics_losses={"mc": (MassConservationLoss(use_cell_volumes=True), 0.5)}, + ) + loss = fn(pred, target, inputs) + assert loss > 0 + assert not torch.isnan(loss) + + def test_gradient_full_pipeline(self): + """Verify gradients propagate through the full unified loss pipeline.""" + inputs, target, _ = _build_2d_reservoir() + pred = torch.randn_like(target, requires_grad=True) + fn = UnifiedLoss( + types=["relative_l2"], + derivative_config={"enabled": True, "weight": 0.5, "dims": ["dx", "dy"]}, + physics_losses={"mc": (MassConservationLoss(use_cell_volumes=True), 1.0)}, + ) + loss = fn(pred, target, inputs) + loss.backward() + assert pred.grad is not None + + def test_factory_end_to_end(self): + """Build from config dict like Hydra would produce.""" + cfg = { + "types": ["relative_l2"], + "weights": [1.0], + "derivative": {"enabled": True, "weight": 0.3, "dims": ["dx"]}, + "physics": { + "mass_conservation": { + "enabled": True, + "weight": 0.5, + "use_cell_volumes": True, + }, + }, + } + fn = get_loss_function(cfg, variable="saturation") + inputs, target, pred = _build_2d_reservoir() + pred = pred + 0.01 * torch.randn_like(pred) + loss = fn(pred, target, inputs) + assert loss > 0 + + +# =================================================================== +# AR window compatibility +# =================================================================== + + +class TestARWindow: + """Tests for ARWindow.""" + + def test_single_timestep_mass(self): + """Verify mass conservation loss works for a single-timestep AR window.""" + B, H, W, C = 1, 4, 6, 12 + inputs = torch.zeros(B, H, W, 1, C) + inputs[..., -3] = DX_2D.view(1, 1, 6, 1) + inputs[..., -2] = DY_2D.view(1, 4, 1, 1) + inputs[..., -1] = 0.0 + + target = torch.ones(B, H, W, 1) * 0.5 + pred = target.clone() + pred[0, 0, 0, 0] += 1.0 + + loss = MassConservationLoss(use_cell_volumes=True)(pred, target, inputs) + assert not torch.isnan(loss) + assert loss > 0 + + +# =================================================================== +# Mass conservation metric options +# =================================================================== + + +class TestMassConservationMetrics: + """Tests for configurable metric in MassConservationLoss.""" + + @pytest.fixture + def reservoir_data(self): + """Simple 2D reservoir data for metric tests.""" + inputs, target, pred = _build_2d_reservoir(T=3) + pred = pred + 0.05 * torch.randn_like(pred) + return inputs, target, pred + + @pytest.mark.parametrize("metric", ["relative_l2", "mse", "l1", "huber"]) + def test_all_metrics_run(self, metric, reservoir_data): + """Verify each supported metric produces a finite positive loss.""" + inputs, target, pred = reservoir_data + fn = MassConservationLoss(metric=metric) + loss = fn(pred, target, inputs) + assert not torch.isnan(loss) + assert loss > 0 + + def test_invalid_metric_raises(self): + """Verify ValueError is raised for an unsupported metric name.""" + with pytest.raises(ValueError, match="metric must be"): + MassConservationLoss(metric="invalid") + + def test_default_metric_is_relative_l2(self): + """Verify the default metric for MassConservationLoss is relative_l2.""" + fn = MassConservationLoss() + assert fn.metric == "relative_l2" + + def test_mse_metric_value(self): + """MSE metric on integrated quantities should match hand calculation.""" + inputs, target, _ = _build_2d_reservoir(T=1) + pred = target.clone() + pred[0, 0, 0, 0] += 1.0 + fn = MassConservationLoss(use_cell_volumes=True, metric="mse") + loss = fn(pred, target, inputs) + # M_true = 0.5 * 3300 = 1650, M_pred = 1650 + 50 = 1700 + # MSE over time (T=1): (1650 - 1700)^2 = 2500 + expected = 2500.0 + assert abs(loss.item() - expected) < 1.0 + + def test_l1_metric_value(self): + """Verify L1 metric returns the expected absolute mass difference.""" + inputs, target, _ = _build_2d_reservoir(T=1) + pred = target.clone() + pred[0, 0, 0, 0] += 1.0 + fn = MassConservationLoss(use_cell_volumes=True, metric="l1") + loss = fn(pred, target, inputs) + # L1: |1650 - 1700| = 50 + expected = 50.0 + assert abs(loss.item() - expected) < 0.1 + + def test_different_metrics_give_different_values(self, reservoir_data): + """Verify different metrics produce distinct loss values on the same data.""" + inputs, target, pred = reservoir_data + losses = {} + for metric in ["relative_l2", "mse", "l1"]: + fn = MassConservationLoss(metric=metric) + losses[metric] = fn(pred, target, inputs).item() + assert losses["mse"] != losses["l1"] + assert losses["relative_l2"] != losses["mse"] + + +class TestMetricInheritance: + """Tests for metric inheritance from data loss via factory.""" + + def test_null_inherits_first_data_loss(self): + """Verify null metric inherits the provided default_metric.""" + cfg = { + "mass_conservation": {"enabled": True, "weight": 1.0, "metric": None}, + } + result = build_physics_losses(cfg, default_metric="mse") + mod, _ = result["mass_conservation"] + assert mod.metric == "mse" + + def test_null_inherits_relative_l2_by_default(self): + """Verify omitted metric defaults to relative_l2 when no default given.""" + cfg = { + "mass_conservation": {"enabled": True, "weight": 1.0}, + } + result = build_physics_losses(cfg) + mod, _ = result["mass_conservation"] + assert mod.metric == "relative_l2" + + def test_explicit_metric_overrides_default(self): + """Verify an explicit metric in config overrides the default_metric.""" + cfg = { + "mass_conservation": {"enabled": True, "weight": 1.0, "metric": "l1"}, + } + result = build_physics_losses(cfg, default_metric="mse") + mod, _ = result["mass_conservation"] + assert mod.metric == "l1" + + def test_factory_end_to_end_inheritance(self): + """Verify factory inherits data loss type as physics metric end-to-end.""" + from training.losses import get_loss_function + + cfg = { + "types": ["mse"], + "weights": [1.0], + "physics": { + "mass_conservation": {"enabled": True, "weight": 0.5}, + }, + } + fn = get_loss_function(cfg, variable="saturation") + mod, _ = fn._physics_losses["mass_conservation"] + assert mod.metric == "mse" + + def test_factory_explicit_override(self): + """Verify factory respects an explicit physics metric override in config.""" + from training.losses import get_loss_function + + cfg = { + "types": ["mse"], + "weights": [1.0], + "physics": { + "mass_conservation": { + "enabled": True, + "weight": 0.5, + "metric": "relative_l2", + }, + }, + } + fn = get_loss_function(cfg, variable="saturation") + mod, _ = fn._physics_losses["mass_conservation"] + assert mod.metric == "relative_l2" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_physicsnemo_unet.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_physicsnemo_unet.py new file mode 100644 index 0000000000..eb14fa33fd --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_physicsnemo_unet.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for PhysicsNeMo UNet wrappers.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from models.physicsnemo_unet import PhysicsNemoUNet2D, PhysicsNemoUNet3D, StandaloneUNet + + +class TestPhysicsNemoUNet2D: + """Tests for 2D UNet wrapper (adds/removes dummy T dim).""" + + def test_output_shape(self): + """Verify 2D UNet output shape matches input spatial dimensions.""" + unet = PhysicsNemoUNet2D( + in_channels=32, out_channels=32, kernel_size=3, model_depth=1 + ) + x = torch.randn(2, 32, 16, 24) + out = unet(x) + assert out.shape == (2, 32, 16, 24) + + def test_preserves_spatial_dims(self): + """Verify 2D UNet preserves spatial H and W dimensions.""" + unet = PhysicsNemoUNet2D(in_channels=16, out_channels=16, model_depth=1) + x = torch.randn(1, 16, 32, 48) + assert unet(x).shape[2:] == x.shape[2:] + + def test_different_in_out_channels(self): + """Verify 2D UNet handles different input and output channel counts.""" + unet = PhysicsNemoUNet2D(in_channels=8, out_channels=16, model_depth=1) + x = torch.randn(1, 8, 16, 16) + assert unet(x).shape == (1, 16, 16, 16) + + def test_gradient_flow(self): + """Verify gradients propagate through the 2D UNet.""" + unet = PhysicsNemoUNet2D(in_channels=16, out_channels=16, model_depth=1) + x = torch.randn(1, 16, 16, 24, requires_grad=True) + unet(x).sum().backward() + assert x.grad is not None + + +class TestPhysicsNemoUNet3D: + """Tests for 3D UNet wrapper (passthrough).""" + + def test_output_shape(self): + """Verify 3D UNet output shape matches input spatial dimensions.""" + unet = PhysicsNemoUNet3D(in_channels=32, out_channels=32, kernel_size=3) + x = torch.randn(2, 32, 8, 16, 8) + out = unet(x) + assert out.shape == (2, 32, 8, 16, 8) + + def test_different_channels(self): + """Verify 3D UNet handles different input and output channel counts.""" + unet = PhysicsNemoUNet3D(in_channels=8, out_channels=4) + x = torch.randn(1, 8, 8, 16, 8) + assert unet(x).shape == (1, 4, 8, 16, 8) + + +class TestStandaloneUNet: + """Tests for StandaloneUNet (channel-last convention).""" + + def test_output_shape(self): + """Verify StandaloneUNet output shape with channel-last convention.""" + unet = StandaloneUNet(in_channels=12, out_channels=1, unet_type="physicsnemo") + x = torch.randn(2, 16, 24, 8, 12) # (B, H, W, T, C) + out = unet(x) + assert out.shape == (2, 16, 24, 8) # (B, H, W, T) + + def test_gradient_flow(self): + """Verify gradients propagate through the StandaloneUNet.""" + unet = StandaloneUNet(in_channels=5, out_channels=1, unet_type="physicsnemo") + x = torch.randn(1, 16, 16, 8, 5, requires_grad=True) + unet(x).sum().backward() + assert x.grad is not None + + def test_count_params(self): + """Verify count_params returns a positive parameter count.""" + unet = StandaloneUNet(in_channels=12, out_channels=1, unet_type="physicsnemo") + assert unet.count_params() > 0 + + def test_custom_unet_type_raises(self): + """Verify ValueError is raised for the unsupported 'custom' unet_type.""" + with pytest.raises(ValueError, match="Custom UNet3D"): + StandaloneUNet(in_channels=12, out_channels=1, unet_type="custom") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_scalar_utils.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_scalar_utils.py new file mode 100644 index 0000000000..17c05a516d --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_scalar_utils.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for MIONet scalar channel detection utilities.""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from data.scalar_utils import ( + create_mionet_collate_fn, + detect_scalar_channels, + verify_scalar_consistency, +) + + +def _make_sample(H=8, W=16, T=4, n_spatial=9, n_scalar=3): + """Create a sample with known scalar and spatial channels. + + Returns (sample, scalar_indices, spatial_indices) where: + - Channels 0..n_spatial-1 vary spatially + - Channels n_spatial..n_spatial+n_scalar-1 are constant + """ + C = n_spatial + n_scalar + sample = torch.randn(H, W, T, C) + scalar_indices = [] + for i in range(n_spatial, C): + val = float(i) * 10.0 + sample[..., i] = val + scalar_indices.append(i) + spatial_indices = list(range(n_spatial)) + return sample, scalar_indices, spatial_indices + + +class TestDetectScalarChannels: + """Tests for detect_scalar_channels.""" + + def test_identifies_scalar_channels(self): + """Verify scalar and spatial channel indices are correctly identified.""" + sample, expected_scalar, expected_spatial = _make_sample() + result = detect_scalar_channels(sample) + assert set(result["scalar_indices"]) == set(expected_scalar) + assert set(result["spatial_indices"]) == set(expected_spatial) + + def test_num_channels(self): + """Verify reported scalar and spatial channel counts match construction.""" + sample, _, _ = _make_sample(n_spatial=9, n_scalar=3) + result = detect_scalar_channels(sample) + assert result["num_scalar_channels"] == 3 + assert result["num_spatial_channels"] == 9 + + def test_all_spatial(self): + """Verify detection returns zero scalar channels for a fully varying tensor.""" + sample = torch.randn(8, 16, 4, 5) + result = detect_scalar_channels(sample) + assert result["num_scalar_channels"] == 0 + assert result["num_spatial_channels"] == 5 + + def test_all_scalar(self): + """Verify detection identifies all channels as scalar when all are constant.""" + sample = torch.ones(8, 16, 4, 3) + for i in range(3): + sample[..., i] = float(i) + result = detect_scalar_channels(sample) + assert result["num_scalar_channels"] == 3 + + def test_batched_input(self): + """Verify scalar detection works correctly on batched (5D) inputs.""" + sample = torch.randn(2, 8, 16, 4, 12) + sample[..., 10] = 5.0 + sample[..., 11] = 7.0 + result = detect_scalar_channels(sample) + assert 10 in result["scalar_indices"] + assert 11 in result["scalar_indices"] + + def test_scalar_values_tensor(self): + """Verify scalar_values tensor has the correct shape for detected scalars.""" + sample, _, _ = _make_sample(n_spatial=2, n_scalar=2) + result = detect_scalar_channels(sample) + assert result["scalar_values"].shape == (2,) + + +class TestVerifyScalarConsistency: + """Tests for verify_scalar_consistency.""" + + def test_consistent_dataset(self): + """Verify consistency check passes when scalar channels are uniform.""" + + class FakeDataset: + """Fake dataset with uniform scalar channels for testing.""" + + def __len__(self): + return 5 + + def __getitem__(self, i): + s, _, _ = _make_sample(n_scalar=2) + return s, torch.zeros(8, 16, 4) + + is_ok, msg = verify_scalar_consistency(FakeDataset(), [9, 10], num_samples=3) + assert is_ok + assert msg is None + + def test_inconsistent_dataset(self): + """Verify consistency check fails when scalar channels vary across samples.""" + + class FakeDataset: + """Fake dataset with inconsistent scalar channels for testing.""" + + def __len__(self): + return 5 + + def __getitem__(self, i): + s = torch.randn(8, 16, 4, 12) + if i == 0: + s[..., 10] = 5.0 + return s, torch.zeros(8, 16, 4) + + is_ok, msg = verify_scalar_consistency(FakeDataset(), [10], num_samples=3) + assert not is_ok + + +class TestMIONetCollateFn: + """Tests for create_mionet_collate_fn.""" + + def test_separates_channels(self): + """Verify collate function splits spatial and scalar channels correctly.""" + scalar_idx = [3, 4] + spatial_idx = [0, 1, 2] + collate = create_mionet_collate_fn(scalar_idx, spatial_idx) + + batch = [] + for _ in range(2): + inp = torch.randn(8, 16, 4, 5) + inp[..., 3] = 10.0 + inp[..., 4] = 20.0 + tgt = torch.randn(8, 16, 4) + batch.append((inp, tgt)) + + spatial, scalar, targets = collate(batch) + assert spatial.shape == (2, 8, 16, 4, 3) + assert scalar.shape == (2, 2) + assert targets.shape == (2, 8, 16, 4) + + def test_scalar_values_correct(self): + """Verify collate function extracts the correct scalar constant values.""" + collate = create_mionet_collate_fn([2], [0, 1]) + inp = torch.randn(4, 6, 3, 3) + inp[..., 2] = 42.0 + _, scalar, _ = collate([(inp, torch.zeros(4, 6, 3))]) + assert torch.isclose(scalar[0, 0], torch.tensor(42.0)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_unet.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_unet.py new file mode 100644 index 0000000000..0b6b7a9134 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_unet.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for custom UNet implementations (UNet2D, UNet3D).""" + +import sys +from pathlib import Path + +import pytest +import torch + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from models.unet import UNet2D, UNet3D + + +class TestUNet2D: + """Tests for UNet2D model.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def test_forward_pass(self, device): + """Test UNet2D forward pass with valid input dimensions.""" + torch.manual_seed(42) + model = UNet2D( + input_channels=32, + output_channels=32, + kernel_size=3, + dropout_rate=0.0, + ).to(device) + + # Input must be divisible by 8: (B, C, H, W) + batch_size = 2 + x = torch.randn(batch_size, 32, 64, 64).to(device) + output = model(x) + + assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}" + + def test_different_channel_sizes(self, device): + """Test UNet2D with same input/output channel sizes. + + Note: The current UNet implementation is designed for U-FNO where + input_channels == output_channels (constant channel dimension in latent space). + """ + torch.manual_seed(42) + + # UNet is designed for same input/output channels (for U-FNO latent space) + for channels in [16, 32, 64]: + model = UNet2D( + input_channels=channels, + output_channels=channels, + kernel_size=3, + ).to(device) + + x = torch.randn(2, channels, 32, 32).to(device) + output = model(x) + + expected_shape = (2, channels, 32, 32) + assert output.shape == expected_shape, ( + f"For channels={channels}: " + f"expected {expected_shape}, got {output.shape}" + ) + + def test_invalid_dimensions(self, device): + """Test that UNet2D raises error for invalid input dimensions.""" + model = UNet2D(input_channels=32, output_channels=32).to(device) + + # Dimensions not divisible by 8 should raise ValueError + x = torch.randn(2, 32, 65, 65).to(device) + + with pytest.raises(ValueError, match="must be divisible by 8"): + model(x) + + def test_dropout(self, device): + """Test UNet2D with dropout enabled.""" + model = UNet2D( + input_channels=32, + output_channels=32, + dropout_rate=0.5, + ).to(device) + + x = torch.randn(2, 32, 32, 32).to(device) + + # Training mode - dropout active + model.train() + output1 = model(x) + _output2 = model(x) + + # With dropout, outputs should differ (with high probability) + # Note: This is a probabilistic test + assert output1.shape == x.shape + + def test_count_params(self, device): + """Test parameter counting.""" + model = UNet2D(input_channels=32, output_channels=32).to(device) + param_count = model.count_params() + + assert param_count > 0 + assert isinstance(param_count, int) + + +class TestUNet3D: + """Tests for UNet3D model.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def test_forward_pass(self, device): + """Test UNet3D forward pass with valid input dimensions.""" + torch.manual_seed(42) + model = UNet3D( + input_channels=32, + output_channels=32, + kernel_size=3, + dropout_rate=0.0, + ).to(device) + + # Input must be divisible by 8: (B, C, H, W, T) + batch_size = 2 + x = torch.randn(batch_size, 32, 32, 32, 16).to(device) + output = model(x) + + assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}" + + def test_spatiotemporal_dimensions(self, device): + """Test UNet3D with CO2 sequestration-like dimensions.""" + torch.manual_seed(42) + model = UNet3D( + input_channels=36, # Typical latent width + output_channels=36, + kernel_size=3, + ).to(device) + + # Simulated padded dimensions (H=96+8=104, W=200+8=208 -> use 104, 104, 32 for test) + # Must be divisible by 8 + x = torch.randn(2, 36, 104, 104, 32).to(device) + output = model(x) + + assert output.shape == x.shape + + def test_different_channel_sizes(self, device): + """Test UNet3D with same input/output channel sizes. + + Note: The current UNet implementation is designed for U-FNO where + input_channels == output_channels (constant channel dimension in latent space). + """ + torch.manual_seed(42) + + # UNet is designed for same input/output channels (for U-FNO latent space) + for channels in [16, 32, 64]: + model = UNet3D( + input_channels=channels, + output_channels=channels, + kernel_size=3, + ).to(device) + + x = torch.randn(1, channels, 16, 16, 16).to(device) + output = model(x) + + expected_shape = (1, channels, 16, 16, 16) + assert output.shape == expected_shape, ( + f"For channels={channels}: " + f"expected {expected_shape}, got {output.shape}" + ) + + def test_invalid_dimensions(self, device): + """Test that UNet3D raises error for invalid input dimensions.""" + model = UNet3D(input_channels=32, output_channels=32).to(device) + + # Dimensions not divisible by 8 should raise ValueError + x = torch.randn(2, 32, 33, 33, 17).to(device) + + with pytest.raises(ValueError, match="must be divisible by 8"): + model(x) + + def test_gradient_flow(self, device): + """Test that gradients flow through UNet3D.""" + torch.manual_seed(42) + model = UNet3D(input_channels=16, output_channels=16).to(device) + + # Create tensor on device directly to ensure it's a leaf tensor + x = torch.randn(1, 16, 16, 16, 16, requires_grad=True, device=device) + output = model(x) + loss = output.sum() + loss.backward() + + assert x.grad is not None + assert x.grad.shape == x.shape + + def test_count_params(self, device): + """Test parameter counting.""" + model = UNet3D(input_channels=32, output_channels=32).to(device) + param_count = model.count_params() + + assert param_count > 0 + assert isinstance(param_count, int) + + # UNet3D should have more parameters than UNet2D with same config + model_2d = UNet2D(input_channels=32, output_channels=32).to(device) + assert model.count_params() > model_2d.count_params() + + +class TestUNetIntegration: + """Integration tests for UNet models.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def test_unet3d_as_skip_connection(self, device): + """Test UNet3D used as skip connection (input + output addition).""" + torch.manual_seed(42) + model = UNet3D(input_channels=32, output_channels=32).to(device) + + x = torch.randn(2, 32, 32, 32, 16).to(device) + output = model(x) + + # Simulating skip connection: x + UNet(x) + combined = x + output + + assert combined.shape == x.shape + # Combined should be different from x + assert not torch.allclose(combined, x) + + def test_memory_efficiency(self, device): + """Test that model doesn't leak memory.""" + if device == "cpu": + pytest.skip("Memory test only relevant for GPU") + + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + model = UNet3D(input_channels=16, output_channels=16).to(device) + x = torch.randn(1, 16, 16, 16, 16).to(device) + + for _ in range(5): + output = model(x) + del output + + torch.cuda.empty_cache() + final_memory = torch.cuda.memory_allocated() + + # Memory should be roughly the same (within model parameters) + # Allow for some variance + memory_diff = abs(final_memory - initial_memory) + model_memory = sum(p.numel() * p.element_size() for p in model.parameters()) + + assert memory_diff < model_memory * 2, "Potential memory leak detected" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_xdeeponet.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_xdeeponet.py new file mode 100644 index 0000000000..920e4b02ef --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_xdeeponet.py @@ -0,0 +1,578 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for xDeepONet model variants (2D and 3D).""" + +import sys +from pathlib import Path + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from models.xdeeponet import ( + DeepONet, + DeepONet3D, + DeepONet3DWrapper, + DeepONetWrapper, + MLPBranch, + SpatialBranch, + SpatialBranch3D, + TrunkNet, +) + +BRANCH1_SPATIAL = { + "encoder": {"type": "linear", "activation_fn": "relu"}, + "layers": { + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, +} +BRANCH1_MLP = { + "encoder": { + "type": "mlp", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "relu", + }, + "layers": {"num_fourier_layers": 0, "num_unet_layers": 0, "num_conv_layers": 0}, +} +BRANCH2_SPATIAL = { + "encoder": {"type": "linear", "activation_fn": "relu"}, + "layers": { + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, +} +BRANCH2_MLP = { + "encoder": { + "type": "mlp", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "relu", + }, + "layers": {"num_fourier_layers": 0, "num_unet_layers": 0, "num_conv_layers": 0}, +} +TRUNK = { + "input_type": "time", + "hidden_width": 32, + "num_layers": 2, + "activation_fn": "tanh", +} + + +def _init_lazy(model, x, **kwargs): + """Run one forward pass to initialise LazyLinear modules.""" + with torch.no_grad(): + model(x, **kwargs) + + +class TestTrunkNet: + """Tests for TrunkNet.""" + + def test_output_shape(self): + """Verify TrunkNet output shape matches expected features.""" + trunk = TrunkNet(in_features=1, out_features=32, hidden_width=16, num_layers=3) + x = torch.randn(10, 1) + assert trunk(x).shape == (10, 32) + + def test_grid_input(self): + """Verify TrunkNet handles multi-dimensional grid input correctly.""" + trunk = TrunkNet(in_features=4, out_features=64, hidden_width=32, num_layers=2) + x = torch.randn(5, 4) + assert trunk(x).shape == (5, 64) + + +class TestMLPBranch: + """Tests for MLPBranch.""" + + def test_output_shape(self): + """Verify MLPBranch output shape matches expected features.""" + branch = MLPBranch(out_features=32, hidden_width=16, num_layers=3) + x = torch.randn(2, 50) + out = branch(x) + assert out.shape == (2, 32) + + +class TestSpatialBranch2D: + """Tests for 2D SpatialBranch.""" + + def test_output_shape(self): + """Verify 2D SpatialBranch output shape matches expected width.""" + branch = SpatialBranch( + in_channels=5, + width=16, + num_unet_layers=1, + kernel_size=3, + unet_impl="custom", + activation_fn="relu", + ) + x = torch.randn(2, 16, 24, 5) + _init_lazy(branch, x) + out = branch(x) + assert out.shape == (2, 16, 24, 16) + + +class TestSpatialBranch3D: + """Tests for 3D SpatialBranch.""" + + def test_output_shape(self): + """Verify 3D SpatialBranch output shape matches expected width.""" + branch = SpatialBranch3D( + in_channels=5, + width=16, + num_unet_layers=1, + kernel_size=3, + unet_impl="custom", + activation_fn="relu", + ) + x = torch.randn(2, 8, 16, 8, 5) + _init_lazy(branch, x) + out = branch(x) + assert out.shape == (2, 8, 16, 8, 16) + + +SINGLE_BRANCH_VARIANTS = ["deeponet", "u_deeponet", "conv_deeponet"] +DUAL_BRANCH_VARIANTS = ["mionet", "tno"] + + +class TestDeepONetWrapper2D: + """Tests for 2D DeepONet wrapper.""" + + @pytest.mark.parametrize("variant", SINGLE_BRANCH_VARIANTS) + def test_forward_shape_single_branch(self, variant): + """Verify 2D single-branch forward pass produces correct output shape.""" + B, H, W, T, C = 2, 16, 24, 4, 5 + model = DeepONetWrapper( + padding=8, + variant=variant, + width=32, + branch1_config=BRANCH1_SPATIAL, + trunk_config=TRUNK, + ) + x = torch.randn(B, H, W, T, C) + _init_lazy(model, x) + out = model(x) + assert out.shape == (B, H, W, T) + + @pytest.mark.parametrize("variant", DUAL_BRANCH_VARIANTS) + def test_forward_shape_dual_branch(self, variant): + """Verify 2D dual-branch forward pass produces correct output shape.""" + B, H, W, T, C = 2, 16, 24, 4, 5 + model = DeepONetWrapper( + padding=8, + variant=variant, + width=32, + branch1_config=BRANCH1_SPATIAL, + branch2_config=BRANCH2_SPATIAL, + trunk_config=TRUNK, + ) + x = torch.randn(B, H, W, T, C) + b2 = torch.randn(B, H, W, T) + _init_lazy(model, x, x_branch2=b2) + out = model(x, x_branch2=b2) + assert out.shape == (B, H, W, T) + + def test_target_times_changes_output_T(self): + """Verify target_times overrides the temporal output dimension size.""" + B, H, W, T_in, C = 2, 16, 24, 2, 5 + K = 5 + model = DeepONetWrapper( + padding=8, + variant="u_deeponet", + width=32, + branch1_config=BRANCH1_SPATIAL, + trunk_config=TRUNK, + ) + x = torch.randn(B, H, W, T_in, C) + tt = torch.linspace(0, 1, K) + _init_lazy(model, x) + out = model(x, target_times=tt) + assert out.shape == (B, H, W, K) + + def test_invalid_variant_raises(self): + """Verify ValueError is raised for an unknown DeepONet variant.""" + with pytest.raises(ValueError, match="Unknown variant"): + DeepONetWrapper( + variant="invalid", + width=32, + branch1_config=BRANCH1_SPATIAL, + trunk_config=TRUNK, + ) + + def test_count_params(self): + """Verify count_params returns a positive parameter count for 2D wrapper.""" + model = DeepONetWrapper( + padding=8, + variant="deeponet", + width=32, + branch1_config=BRANCH1_SPATIAL, + trunk_config=TRUNK, + ) + x = torch.randn(1, 16, 24, 2, 5) + _init_lazy(model, x) + assert model.count_params() > 0 + + def test_gradient_flow(self): + """Verify gradients propagate through the 2D DeepONet wrapper.""" + model = DeepONetWrapper( + padding=8, + variant="u_deeponet", + width=32, + branch1_config=BRANCH1_SPATIAL, + trunk_config=TRUNK, + ) + x = torch.randn(1, 16, 24, 2, 5) + _init_lazy(model, x) + x = torch.randn(1, 16, 24, 2, 5, requires_grad=True) + out = model(x) + out.sum().backward() + assert x.grad is not None + + +BRANCH1_3D = { + "encoder": {"type": "linear", "activation_fn": "relu"}, + "layers": { + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, +} +BRANCH2_3D = { + "encoder": {"type": "linear", "activation_fn": "relu"}, + "layers": { + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "kernel_size": 3, + "dropout": 0.0, + "unet_impl": "custom", + "activation_fn": "relu", + }, +} + + +class TestDeepONet3DWrapper: + """Tests for 3D DeepONet wrapper.""" + + @pytest.mark.parametrize("variant", SINGLE_BRANCH_VARIANTS) + def test_forward_shape_single_branch(self, variant): + """Verify 3D single-branch forward pass produces correct output shape.""" + B, X, Y, Z, T, C = 1, 8, 16, 8, 3, 5 + model = DeepONet3DWrapper( + padding=8, + variant=variant, + width=32, + branch1_config=BRANCH1_3D, + trunk_config=TRUNK, + ) + x = torch.randn(B, X, Y, Z, T, C) + _init_lazy(model, x) + out = model(x) + assert out.shape == (B, X, Y, Z, T) + + def test_tno_requires_branch2(self): + """Verify TNO variant produces correct output with a second branch.""" + B, X, Y, Z, T, C = 1, 8, 16, 8, 3, 5 + model = DeepONet3DWrapper( + padding=8, + variant="tno", + width=32, + branch1_config=BRANCH1_3D, + branch2_config=BRANCH2_3D, + trunk_config=TRUNK, + ) + x = torch.randn(B, X, Y, Z, T, C) + b2 = torch.randn(B, X, Y, Z, 1) + _init_lazy(model, x, x_branch2=b2) + out = model(x, x_branch2=b2) + assert out.shape == (B, X, Y, Z, T) + + def test_target_times_3d(self): + """Verify target_times overrides the temporal output dimension in 3D.""" + B, X, Y, Z, T_in, C = 1, 8, 16, 8, 1, 5 + K = 4 + model = DeepONet3DWrapper( + padding=8, + variant="u_deeponet", + width=32, + branch1_config=BRANCH1_3D, + trunk_config=TRUNK, + ) + x = torch.randn(B, X, Y, Z, T_in, C) + tt = torch.linspace(0, 1, K) + _init_lazy(model, x) + out = model(x, target_times=tt) + assert out.shape == (B, X, Y, Z, K) + + def test_count_params_3d(self): + """Verify count_params returns a positive parameter count for 3D wrapper.""" + model = DeepONet3DWrapper( + padding=8, + variant="deeponet", + width=32, + branch1_config=BRANCH1_3D, + trunk_config=TRUNK, + ) + x = torch.randn(1, 8, 16, 8, 2, 5) + _init_lazy(model, x) + assert model.count_params() > 0 + + +class TestHadamardProduct: + """Verify 3-way Hadamard product for multi-branch variants.""" + + def test_mionet_uses_multiplication(self): + """Verify MIONet variant computes a 3-way Hadamard product correctly.""" + model = DeepONetWrapper( + variant="mionet", + width=16, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + branch2_config={"encoder": "mlp", "hidden_width": 16, "num_layers": 2}, + trunk_config={"hidden_width": 16, "num_layers": 2}, + decoder_layers=0, + ) + x = torch.randn(2, 16, 24, 4, 6) + b2 = torch.randn(2, 6) + with torch.no_grad(): + out = model(x, x_branch2=b2) + assert out.shape == (2, 16, 24, 4) + + +class TestTemporalProjection: + """Test temporal_projection decoder mode.""" + + def test_2d_temporal_projection_output_shape(self): + """Verify 2D temporal-projection decoder produces correct output T dimension.""" + K = 3 + model = DeepONet( + variant="u_deeponet", + width=16, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + trunk_config={"hidden_width": 16, "num_layers": 2}, + decoder_type="temporal_projection", + decoder_layers=1, + decoder_width=16, + ) + model.set_output_window(K) + x_branch = torch.randn(2, 16, 24, 4) + x_time = torch.randn(1, 1) + with torch.no_grad(): + out = model(x_branch, x_time) + assert out.shape == (2, 16, 24, K) + + def test_2d_temporal_projection_with_branch2(self): + """Verify 2D temporal-projection works with a second branch input.""" + K = 5 + model = DeepONet( + variant="tno", + width=16, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + branch2_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + trunk_config={"hidden_width": 16, "num_layers": 2}, + decoder_type="temporal_projection", + decoder_layers=1, + decoder_width=16, + ) + model.set_output_window(K) + x_branch = torch.randn(2, 16, 24, 4) + x_branch2 = torch.randn(2, 16, 24, 4) + x_time = torch.randn(1, 1) + with torch.no_grad(): + out = model(x_branch, x_time, x_branch2=x_branch2) + assert out.shape == (2, 16, 24, K) + + def test_3d_temporal_projection(self): + """Verify 3D temporal-projection decoder produces correct output shape.""" + K = 4 + model = DeepONet3D( + variant="u_deeponet", + width=8, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + trunk_config={"hidden_width": 8, "num_layers": 2}, + decoder_type="temporal_projection", + decoder_layers=1, + decoder_width=8, + ) + model.set_output_window(K) + x_branch = torch.randn(2, 8, 8, 8, 4) + x_time = torch.randn(1, 1) + with torch.no_grad(): + out = model(x_branch, x_time) + assert out.shape == (2, 8, 8, 8, K) + + def test_mlp_decoder_still_works(self): + """Existing mlp decoder path is preserved.""" + model = DeepONet( + variant="u_deeponet", + width=16, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + trunk_config={"hidden_width": 16, "num_layers": 2}, + decoder_type="mlp", + decoder_layers=1, + decoder_width=16, + ) + x_branch = torch.randn(2, 16, 24, 4) + x_time = torch.randn(6, 1) + with torch.no_grad(): + out = model(x_branch, x_time) + assert out.shape == (2, 16, 24, 6) + + def test_gradient_flow_temporal_projection(self): + """Verify gradients propagate through the temporal-projection decoder.""" + K = 3 + model = DeepONet( + variant="tno", + width=16, + branch1_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + branch2_config={ + "encoder": "spatial", + "num_unet_layers": 0, + "num_conv_layers": 1, + "kernel_size": 3, + }, + trunk_config={"hidden_width": 16, "num_layers": 2}, + decoder_type="temporal_projection", + decoder_layers=1, + decoder_width=16, + ) + model.set_output_window(K) + x = torch.randn(2, 16, 24, 4, requires_grad=False) + b2 = torch.randn(2, 16, 24, 4, requires_grad=False) + t = torch.randn(1, 1) + out = model(x, t, x_branch2=b2) + loss = out.sum() + loss.backward() + assert model.temporal_head.weight.grad is not None + + +class TestInternalResolution: + """Test adaptive pooling in SpatialBranch.""" + + def test_2d_internal_resolution(self): + """Verify 2D SpatialBranch with internal_resolution preserves output shape.""" + from models.xdeeponet import SpatialBranch + + branch = SpatialBranch( + in_channels=4, + width=8, + num_fourier_layers=0, + num_unet_layers=0, + num_conv_layers=1, + kernel_size=3, + internal_resolution=[16, 24], + ) + x = torch.randn(2, 32, 48, 4) + out = branch(x) + assert out.shape == (2, 32, 48, 8) + + def test_2d_no_internal_resolution(self): + """Verify 2D SpatialBranch without internal_resolution preserves output shape.""" + from models.xdeeponet import SpatialBranch + + branch = SpatialBranch( + in_channels=4, + width=8, + num_fourier_layers=0, + num_unet_layers=0, + num_conv_layers=1, + kernel_size=3, + internal_resolution=None, + ) + x = torch.randn(2, 32, 48, 4) + out = branch(x) + assert out.shape == (2, 32, 48, 8) + + def test_3d_internal_resolution(self): + """Verify 3D SpatialBranch with internal_resolution preserves output shape.""" + from models.xdeeponet import SpatialBranch3D + + branch = SpatialBranch3D( + in_channels=4, + width=8, + num_fourier_layers=0, + num_unet_layers=0, + num_conv_layers=1, + kernel_size=3, + internal_resolution=[8, 8, 8], + ) + x = torch.randn(2, 16, 16, 16, 4) + out = branch(x) + assert out.shape == (2, 16, 16, 16, 8) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/tests/test_xfno.py b/examples/reservoir_simulation/neural_operator_factory/tests/test_xfno.py new file mode 100644 index 0000000000..05bffd08d1 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/tests/test_xfno.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for U-FNO model. + +Note: These tests require physicsnemo to be installed. If physicsnemo is not +available, all tests in this module will be skipped. +""" + +import sys +from pathlib import Path + +import pytest +import torch + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Check if physicsnemo is available +try: + import physicsnemo # noqa: F401 + + PHYSICSNEMO_AVAILABLE = True +except ImportError: + PHYSICSNEMO_AVAILABLE = False + +# Skip entire module if physicsnemo is not installed +if not PHYSICSNEMO_AVAILABLE: + pytest.skip( + "physicsnemo not installed - skipping U-FNO tests", allow_module_level=True + ) + +from models.xfno import UFNO, UFNONet + + +class TestUFNO: + """Tests for UFNO model.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + @pytest.fixture + def small_model_params(self): + """Return parameters for a small test model.""" + return { + "in_channels": 12, + "out_channels": 1, + "width": 16, + "modes1": 4, + "modes2": 4, + "modes3": 4, + "num_fno_layers": 1, + "num_unet_layers": 1, + "num_conv_layers": 0, + "unet_type": "custom", + } + + def test_forward_pass_ufno(self, device, small_model_params): + """Test U-FNO (FNO + UNet) forward pass.""" + torch.manual_seed(42) + model = UFNO(**small_model_params).to(device) + + # Input shape: (B, H, W, T, C) - must be divisible by 8 for UNet + batch_size = 2 + x = torch.randn(batch_size, 32, 32, 16, 12).to(device) + output = model(x) + + expected_shape = (batch_size, 32, 32, 16, 1) + assert output.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {output.shape}" + ) + + def test_forward_pass_convfno(self, device, small_model_params): + """Test Conv-FNO (FNO + Conv) forward pass.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["num_unet_layers"] = 0 + params["num_conv_layers"] = 2 + + model = UFNO(**params).to(device) + + batch_size = 2 + x = torch.randn(batch_size, 32, 32, 16, 12).to(device) + output = model(x) + + expected_shape = (batch_size, 32, 32, 16, 1) + assert output.shape == expected_shape + + def test_forward_pass_pure_fno(self, device, small_model_params): + """Test pure FNO (no UNet or Conv) forward pass.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["num_unet_layers"] = 0 + params["num_conv_layers"] = 0 + params["num_fno_layers"] = 3 + + model = UFNO(**params).to(device) + + batch_size = 2 + x = torch.randn(batch_size, 32, 32, 16, 12).to(device) + output = model(x) + + expected_shape = (batch_size, 32, 32, 16, 1) + assert output.shape == expected_shape + + @pytest.mark.parametrize("unet_type", ["custom", "physicsnemo"]) + def test_unet_types(self, device, small_model_params, unet_type): + """Test different UNet types.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["unet_type"] = unet_type + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + def test_invalid_unet_type(self, small_model_params): + """Test that invalid UNet type raises error.""" + params = small_model_params.copy() + params["unet_type"] = "invalid" + + with pytest.raises(ValueError, match="Unknown unet_type"): + UFNO(**params) + + @pytest.mark.parametrize("lifting_type", ["mlp", "conv"]) + def test_lifting_types(self, device, small_model_params, lifting_type): + """Test different lifting network types.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["lifting_type"] = lifting_type + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + @pytest.mark.parametrize("decoder_type", ["mlp", "conv"]) + def test_decoder_types(self, device, small_model_params, decoder_type): + """Test different decoder network types.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["decoder_type"] = decoder_type + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + def test_multi_layer_lifting(self, device, small_model_params): + """Test multi-layer lifting network.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["lifting_layers"] = 3 + params["lifting_width"] = 2 + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + def test_multi_layer_decoder(self, device, small_model_params): + """Test multi-layer decoder network.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["decoder_layers"] = 2 + params["decoder_width"] = 64 + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + @pytest.mark.parametrize("activation_fn", ["relu", "gelu", "silu"]) + def test_activation_functions(self, device, small_model_params, activation_fn): + """Test different activation functions.""" + torch.manual_seed(42) + params = small_model_params.copy() + params["activation_fn"] = activation_fn + + model = UFNO(**params).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + output = model(x) + + assert output.shape == (2, 32, 32, 16, 1) + + def test_gradient_flow(self, device, small_model_params): + """Test that gradients flow through the model.""" + torch.manual_seed(42) + model = UFNO(**small_model_params).to(device) + + x = torch.randn(1, 32, 32, 16, 12, requires_grad=True).to(device) + output = model(x) + loss = output.sum() + loss.backward() + + assert x.grad is not None + assert x.grad.shape == x.shape + + # Check that all parameters have gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + def test_count_params(self, device, small_model_params): + """Test parameter counting.""" + model = UFNO(**small_model_params).to(device) + param_count = model.count_params() + + assert param_count > 0 + assert isinstance(param_count, int) + + # Manual count should match + manual_count = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert param_count == manual_count + + def test_different_channel_sizes(self, device): + """Test with different input/output channel configurations.""" + torch.manual_seed(42) + + configs = [ + (12, 1), # Standard CO2 config + (8, 1), # Fewer inputs + (12, 3), # Multi-output + ] + + for in_ch, out_ch in configs: + model = UFNO( + in_channels=in_ch, + out_channels=out_ch, + width=16, + modes1=4, + modes2=4, + modes3=4, + num_fno_layers=1, + num_unet_layers=0, + num_conv_layers=0, + ).to(device) + + x = torch.randn(1, 32, 32, 16, in_ch).to(device) + output = model(x) + + expected_shape = (1, 32, 32, 16, out_ch) + assert output.shape == expected_shape, ( + f"For in_ch={in_ch}, out_ch={out_ch}: " + f"expected {expected_shape}, got {output.shape}" + ) + + +class TestUFNONet: + """Tests for UFNONet wrapper.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def test_forward_with_padding(self, device): + """Test UFNONet handles padding correctly.""" + torch.manual_seed(42) + model = UFNONet( + modes1=4, + modes2=4, + modes3=4, + width=16, + in_channels=12, + out_channels=1, + num_fno_layers=1, + num_unet_layers=1, + padding=8, + unet_type="custom", + ).to(device) + + # Input shape: (B, H, W, T, C) + x = torch.randn(2, 32, 40, 16, 12).to(device) + output = model(x) + + # Output should match input spatial dims without channel dim + expected_shape = (2, 32, 40, 16) + assert output.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {output.shape}" + ) + + def test_padding_adjustment(self, device): + """Test that padding is adjusted to be divisible by 8.""" + # Padding not divisible by 8 + model = UFNONet( + modes1=4, + modes2=4, + modes3=4, + width=16, + padding=10, # Will be adjusted to 16 + unet_type="custom", + num_unet_layers=0, + ).to(device) + + assert model.padding == 16 # Adjusted to next multiple of 8 + + def test_count_params(self, device): + """Test parameter counting via wrapper.""" + model = UFNONet( + modes1=4, + modes2=4, + modes3=4, + width=16, + num_fno_layers=1, + num_unet_layers=0, + unet_type="custom", + ).to(device) + + param_count = model.count_params() + assert param_count > 0 + + +class TestUFNOIntegration: + """Integration tests for UFNO model.""" + + @pytest.fixture + def device(self): + """Return available device.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def test_training_step(self, device): + """Test a simulated training step.""" + torch.manual_seed(42) + model = UFNO( + in_channels=12, + out_channels=1, + width=16, + modes1=4, + modes2=4, + modes3=4, + num_fno_layers=1, + num_unet_layers=1, + unet_type="custom", + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = torch.nn.MSELoss() + + # Simulated batch + x = torch.randn(2, 32, 32, 16, 12).to(device) + target = torch.randn(2, 32, 32, 16, 1).to(device) + + # Training step + model.train() + optimizer.zero_grad() + output = model(x) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + assert not torch.isnan(loss) + assert loss > 0 + + def test_eval_mode(self, device): + """Test model evaluation mode.""" + torch.manual_seed(42) + model = UFNO( + in_channels=12, + out_channels=1, + width=16, + modes1=4, + modes2=4, + modes3=4, + num_fno_layers=1, + num_unet_layers=1, + unet_type="custom", + unet_dropout=0.5, # Enable dropout + ).to(device) + + x = torch.randn(2, 32, 32, 16, 12).to(device) + + # In eval mode, outputs should be deterministic + model.eval() + with torch.no_grad(): + output1 = model(x) + output2 = model(x) + + assert torch.allclose(output1, output2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/reservoir_simulation/neural_operator_factory/training/__init__.py b/examples/reservoir_simulation/neural_operator_factory/training/__init__.py new file mode 100644 index 0000000000..b2bf11d59c --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/__init__.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Training utilities: loss functions, evaluation metrics, and autoregressive helpers. +""" + +from training.ar_utils import ( + ar_validate_full_rollout, + live_rollout_step, + rollout_step, + slice_input_window, + slice_target_window, + teacher_forcing_step, +) +from training.losses import ( + SimpleRelativeL2Loss, + UnifiedLoss, + get_loss_function, +) +from training.metrics import ( + compute_r2_score, + compute_relative_l1_error, + compute_relative_l2_error, + mae_torch, + max_absolute_error, + max_error_torch, + mean_absolute_error, + mean_plume_error, + mean_relative_error, + mse_torch, + normalized_mse, + peak_signal_to_noise_ratio, + psnr_torch, + r2_score_torch, + relative_l1_torch, + relative_l2_torch, + rmse_torch, +) +from training.physics_losses import ( + MassConservationLoss, + build_physics_losses, +) + +__all__ = [ + # Autoregressive + "teacher_forcing_step", + "rollout_step", + "live_rollout_step", + "ar_validate_full_rollout", + "slice_input_window", + "slice_target_window", + # Losses + "SimpleRelativeL2Loss", + "UnifiedLoss", + "get_loss_function", + # Physics losses + "MassConservationLoss", + "build_physics_losses", + # Metrics + "mean_relative_error", + "mean_plume_error", + "mean_absolute_error", + "max_absolute_error", + "compute_r2_score", + "compute_relative_l2_error", + "compute_relative_l1_error", + "normalized_mse", + "peak_signal_to_noise_ratio", + "mse_torch", + "rmse_torch", + "mae_torch", + "relative_l2_torch", + "relative_l1_torch", + "r2_score_torch", + "max_error_torch", + "psnr_torch", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/training/ar_utils.py b/examples/reservoir_simulation/neural_operator_factory/training/ar_utils.py new file mode 100644 index 0000000000..abe9cfb224 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/ar_utils.py @@ -0,0 +1,747 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Autoregressive training utilities with temporal bundling. + +Provides dimension-agnostic helpers for slicing time windows, constructing +AR model inputs, and running multi-step rollouts. Works with both 3D +(B, H, W, T, C) and 4D (B, X, Y, Z, T, C) datasets. + +Key concepts +------------ +- **L** (input_window): Number of context timesteps fed to the model. +- **K** (output_window): Number of timesteps the model predicts per step. +- The time axis is always the **second-to-last** dimension of the input + tensor and the **last** dimension of the target tensor. +- For DeepONet models, explicit **target_times** (trunk query coordinates) + are extracted from the full input tensor and passed to the model so that + K can differ from L (temporal bundling). + +Three-stage training +-------------------- +**Stage 1 -- Teacher Forcing** (``teacher_forcing_step``): + Sweeps sequentially through the full trajectory starting at t=0. + Each window [t, t+L) predicts [t+L, t+L+K). The model always + receives ground-truth input. For TNO, Branch2 also receives + GT solution. Loss is averaged over all windows. + +**Stage 2 -- Pushforward** (``pushforward_step``): + Chains multiple forward passes with *live* gradients (no detach). + The number of chained steps ramps via a linear curriculum from 1 + to ``max_unroll``. This bridges the gap between teacher forcing + and free-running rollout. + +**Stage 3 -- Rollout** (``rollout_step``): + Sweeps sequentially through the full trajectory starting at t=0. + For TNO, Branch2 receives the model's own (detached) prediction + from the previous step instead of ground truth. This trains the + model to handle its own approximation errors. +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint as grad_checkpoint + +# --------------------------------------------------------------------------- +# Time-axis helpers +# --------------------------------------------------------------------------- + + +def _time_axis_input(x: Tensor) -> int: + """Time axis index for input ``(..., T, C)``.""" + return x.dim() - 2 + + +def _time_axis_target(y: Tensor) -> int: + """Time axis index for target ``(..., T)``.""" + return y.dim() - 1 + + +# --------------------------------------------------------------------------- +# Time-window slicing +# --------------------------------------------------------------------------- + + +def slice_input_window(inputs: Tensor, t0: int, width: int) -> Tensor: + """Extract ``(B, *spatial, width, C)`` from full-trajectory input.""" + return inputs.narrow(_time_axis_input(inputs), t0, width) + + +def slice_target_window(targets: Tensor, t0: int, width: int) -> Tensor: + """Extract ``(B, *spatial, width)`` from full-trajectory target.""" + return targets.narrow(_time_axis_target(targets), t0, width) + + +# --------------------------------------------------------------------------- +# Feedback channel injection & noise +# --------------------------------------------------------------------------- + + +def inject_feedback_channel( + x_window: Tensor, + feedback: Optional[Tensor], +) -> Tensor: + """Append feedback prediction(s) as extra channel(s) to the input window. + + Parameters + ---------- + x_window : Tensor + Model input ``(B, *spatial, T, C)``. + feedback : Tensor or None + Previous prediction. Shape ``(B, *spatial, T)`` for single-output + or ``(B, *spatial, T, C_out)`` for multi-output. If *None*, + *x_window* is returned unchanged. + """ + if feedback is None: + return x_window + if feedback.dim() < x_window.dim(): + feedback = feedback.unsqueeze(-1) + return torch.cat([x_window, feedback], dim=-1) + + +def add_noise(tensor: Tensor, noise_std: float) -> Tensor: + """Add Gaussian noise. No-op when *noise_std* <= 0.""" + if noise_std <= 0: + return tensor + return tensor + torch.randn_like(tensor) * noise_std + + +# --------------------------------------------------------------------------- +# Target-time coordinate extraction +# --------------------------------------------------------------------------- + + +def extract_target_times(inputs: Tensor, t_start: int, K: int) -> Tensor: + """Extract K target time coordinates from the full input tensor. + + The time coordinate is assumed to be the **last channel** (index -1) + of the input tensor at a fixed spatial location ``[0, 0, ..., 0]``. + + Parameters + ---------- + inputs : Tensor + Full-trajectory input ``(B, *spatial, T, C)``. + t_start : int + First target timestep index. + K : int + Number of target timesteps. + + Returns + ------- + Tensor + Shape ``(K,)`` -- the time coordinate values for the K target steps. + """ + ndim = inputs.dim() + spatial_zeros = (0,) * (ndim - 3) + idx = (0,) + spatial_zeros + return inputs[idx][t_start : t_start + K, -1] + + +# --------------------------------------------------------------------------- +# Model call helpers +# --------------------------------------------------------------------------- + + +def _call_model( + model, + x_window: Tensor, + target_times: Optional[Tensor], + use_checkpointing: bool = False, + x_branch2: Optional[Tensor] = None, +) -> Tensor: + """Call the model, optionally passing target_times and x_branch2.""" + kwargs = {} + if target_times is not None and _model_accepts_target_times(model): + kwargs["target_times"] = target_times + if x_branch2 is not None and _model_accepts_x_branch2(model): + kwargs["x_branch2"] = x_branch2 + + if use_checkpointing and model.training: + return grad_checkpoint( + _forward_with_kwargs, + model, + x_window, + kwargs, + use_reentrant=False, + ) + return model(x_window, **kwargs) + + +def _forward_with_kwargs(model, x_window, kwargs): + """Thin wrapper so ``grad_checkpoint`` can pass kwargs.""" + return model(x_window, **kwargs) + + +def _model_accepts_target_times(model) -> bool: + """Check if the model's forward() accepts a ``target_times`` kwarg.""" + import inspect + + m = model.module if hasattr(model, "module") else model + sig = inspect.signature(m.forward) + return "target_times" in sig.parameters + + +def _model_accepts_x_branch2(model) -> bool: + """Check if the model's forward() accepts an ``x_branch2`` kwarg.""" + import inspect + + m = model.module if hasattr(model, "module") else model + sig = inspect.signature(m.forward) + return "x_branch2" in sig.parameters + + +# --------------------------------------------------------------------------- +# Window iteration +# --------------------------------------------------------------------------- + + +def _iter_windows(total_T: int, L: int, K: int, stride: Optional[int] = None): + """Yield ``(t0, target_start, actual_K)`` for each window. + + Parameters + ---------- + total_T : int + Total number of timesteps in the trajectory. + L : int + Input context length (number of input timesteps). + K : int + Target prediction length (number of output timesteps). + stride : int or None + Advance between consecutive windows. Defaults to *K* + (non-overlapping). Yields truncated windows when the + remaining timesteps are fewer than *K*. + """ + if stride is None: + stride = K + t = 0 + while t + L < total_T: + target_start = t + L + remaining = total_T - target_start + actual_K = min(K, remaining) + if actual_K <= 0: + break + yield (t, target_start, actual_K) + t += stride + + +# --------------------------------------------------------------------------- +# Branch-2 builder (TNO previous-solution input) +# --------------------------------------------------------------------------- + + +def _build_branch2( + targets: Tensor, + prev_pred: Optional[Tensor], + current_t: int, + L: int, + t_ax: int, + is_tno: bool, + noise_std: float = 0.0, +) -> Optional[Tensor]: + """Build the branch-2 tensor for TNO models. + + Returns *None* for non-TNO models. For TNO: + - First window (``prev_pred is None``): ground-truth at ``[current_t, current_t+L)``. + - Subsequent windows: last *L* timesteps of ``prev_pred``, padded with GT if needed. + """ + if not is_tno: + return None + if prev_pred is None: + b2 = slice_target_window(targets, current_t, L) + elif prev_pred.shape[t_ax] >= L: + b2 = prev_pred.narrow(t_ax, prev_pred.shape[t_ax] - L, L) + else: + need = L - prev_pred.shape[t_ax] + gt_part = slice_target_window(targets, current_t, need) + b2 = torch.cat([gt_part, prev_pred], dim=t_ax) + if noise_std > 0: + b2 = add_noise(b2, noise_std) + return b2 + + +# --------------------------------------------------------------------------- +# Curriculum scheduling +# --------------------------------------------------------------------------- + + +def compute_unroll_steps( + epoch: int, + start_epoch: int, + total_epochs: int, + max_unroll: int, +) -> int: + """Linear curriculum: ramp unroll steps from 1 to *max_unroll*. + + Returns 1 at *start_epoch*, *max_unroll* at ``start_epoch + total_epochs``, + and clamps outside that range. + """ + if total_epochs <= 0: + return max_unroll + progress = (epoch - start_epoch) / total_epochs + progress = max(0.0, min(1.0, progress)) + return max(1, round(1 + (max_unroll - 1) * progress)) + + +def get_training_stage( + epoch: int, + tf_epochs: int, + pf_epochs: int = 0, + ro_epochs: int = 0, +) -> str: + """Return the training stage name for a given epoch. + + Stages (in order): ``"teacher_forcing"`` -> ``"pushforward"`` -> ``"rollout"``. + Stages with zero epochs are skipped. After all stages are exhausted the + last active stage is returned. + """ + if epoch < tf_epochs: + return "teacher_forcing" + if pf_epochs > 0 and epoch < tf_epochs + pf_epochs: + return "pushforward" + if ro_epochs > 0: + return "rollout" + if pf_epochs > 0: + return "pushforward" + return "teacher_forcing" + + +# --------------------------------------------------------------------------- +# Feedback helper (shared by training steps) +# --------------------------------------------------------------------------- + + +def _get_feedback( + targets: Tensor, + prev_pred: Optional[Tensor], + current_t: int, + L: int, + t_ax: int, +) -> Tensor: + """Return feedback tensor (*L* timesteps) for feedback-channel injection.""" + if prev_pred is None: + return slice_target_window(targets, current_t, L) + if prev_pred.shape[t_ax] >= L: + return prev_pred.narrow(t_ax, prev_pred.shape[t_ax] - L, L) + need = L - prev_pred.shape[t_ax] + gt_part = slice_target_window(targets, current_t, need) + return torch.cat([gt_part, prev_pred], dim=t_ax) + + +# --------------------------------------------------------------------------- +# Teacher-forcing training step (one batch) -- sequential sweep +# --------------------------------------------------------------------------- + + +def teacher_forcing_step( + model, + inputs: Tensor, + targets: Tensor, + loss_fn, + L: int, + K: int, + spatial_mask: Optional[Tensor] = None, + is_tno: bool = False, + noise_std: float = 0.0, + feedback_channel: Optional[int] = None, + stride: Optional[int] = None, +) -> Tensor: + """One teacher-forcing training iteration over a batch. + + Sweeps sequentially from t=0 through the full trajectory, processing + every window. Uses gradient accumulation: each window is forwarded and + backwarded independently (one graph at a time). Returns a detached + scalar loss for logging. The caller should NOT call ``loss.backward()`` + -- only ``optimizer.step()``. + + For TNO, Branch2 receives the ground-truth solution at ``[t, t+L)``. + """ + total_T = targets.shape[_time_axis_target(targets)] + t_ax = _time_axis_target(targets) + + effective_stride = stride if stride is not None else K + if total_T <= L: + return torch.tensor(0.0, device=inputs.device) + num_windows = (total_T - L - K) // effective_stride + 1 + if num_windows <= 0: + return torch.tensor(0.0, device=inputs.device) + + accumulated_loss = 0.0 + current_t = 0 + + for _ in range(num_windows): + target_start = current_t + L + remaining = total_T - target_start + actual_K = min(K, remaining) + if actual_K <= 0: + break + + x_window = slice_input_window(inputs, current_t, L) + y_target = slice_target_window(targets, target_start, actual_K) + target_times = extract_target_times(inputs, target_start, actual_K) + + y_branch2 = _build_branch2( + targets, + None, + current_t, + L, + t_ax, + is_tno, + noise_std, + ) + + if feedback_channel is not None: + fb = slice_target_window(targets, current_t, L) + fb = add_noise(fb, noise_std) + x_window = inject_feedback_channel(x_window, fb) + + pred = _call_model(model, x_window, target_times, x_branch2=y_branch2) + + if pred.shape[t_ax] > actual_K: + pred = pred.narrow(t_ax, 0, actual_K) + + window_loss = loss_fn(pred, y_target, x_window, spatial_mask=spatial_mask) + if window_loss.requires_grad: + (window_loss / num_windows).backward() + accumulated_loss += window_loss.detach().item() + + current_t += effective_stride + + return torch.tensor(accumulated_loss / num_windows, device=inputs.device) + + +# --------------------------------------------------------------------------- +# Pushforward training step -- live gradients through unrolled chain +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Rollout training step (one batch) -- sequential chain from t=0 +# --------------------------------------------------------------------------- + + +def rollout_step( + model, + inputs: Tensor, + targets: Tensor, + loss_fn, + L: int, + K: int, + use_checkpointing: bool = True, + spatial_mask: Optional[Tensor] = None, + is_tno: bool = False, + noise_std: float = 0.0, + feedback_channel: Optional[int] = None, + stride: Optional[int] = None, +) -> Tensor: + """One rollout (free-running) training iteration. + + Sweeps sequentially from t=0 through the full trajectory. Uses + gradient accumulation: each window is forwarded and backwarded + independently. Returns a detached scalar loss for logging. + The caller should NOT call ``loss.backward()`` -- only ``optimizer.step()``. + + For TNO, Branch2 receives the model's own (detached) prediction + from the previous step, creating true autoregressive feedback. + """ + total_T = targets.shape[_time_axis_target(targets)] + t_ax = _time_axis_target(targets) + + effective_stride = stride if stride is not None else K + if total_T <= L: + return torch.tensor(0.0, device=inputs.device) + num_windows = (total_T - L - K) // effective_stride + 1 + if num_windows <= 0: + return torch.tensor(0.0, device=inputs.device) + + accumulated_loss = 0.0 + prev_pred = None + current_t = 0 + + for _ in range(num_windows): + target_start = current_t + L + remaining = total_T - target_start + actual_K = min(K, remaining) + if actual_K <= 0: + break + + x_window = slice_input_window(inputs, current_t, L) + y_target = slice_target_window(targets, target_start, actual_K) + target_times = extract_target_times(inputs, target_start, actual_K) + + y_branch2 = _build_branch2( + targets, + prev_pred, + current_t, + L, + t_ax, + is_tno, + noise_std, + ) + + if feedback_channel is not None: + fb = _get_feedback(targets, prev_pred, current_t, L, t_ax) + fb = add_noise(fb, noise_std) + x_window = inject_feedback_channel(x_window, fb) + + pred = _call_model( + model, + x_window, + target_times, + use_checkpointing, + x_branch2=y_branch2, + ) + + if pred.shape[t_ax] > actual_K: + pred = pred.narrow(t_ax, 0, actual_K) + + window_loss = loss_fn(pred, y_target, x_window, spatial_mask=spatial_mask) + if window_loss.requires_grad: + (window_loss / num_windows).backward() + accumulated_loss += window_loss.detach().item() + + prev_pred = pred.detach() + current_t += effective_stride + + return torch.tensor(accumulated_loss / num_windows, device=inputs.device) + + +# --------------------------------------------------------------------------- +# Full-trajectory rollout with live gradients (matches original TNO training) +# --------------------------------------------------------------------------- + + +def _freeze_batchnorm(model): + """Set all BatchNorm layers to eval mode (freeze running stats). + + Returns a list of the modules that were switched so they can be + restored afterwards. The learned gamma/beta parameters still + receive gradients; only the inplace running-stat updates are + suppressed. + """ + switched = [] + m = model.module if hasattr(model, "module") else model + for mod in m.modules(): + if isinstance( + mod, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) + ): + if mod.training: + mod.eval() + switched.append(mod) + return switched + + +def _unfreeze_batchnorm(switched): + """Restore previously-frozen BatchNorm layers to training mode.""" + for mod in switched: + mod.train() + + +def live_rollout_step( + model, + inputs: Tensor, + targets: Tensor, + loss_fn, + L: int, + K: int, + max_steps: Optional[int] = None, + spatial_mask: Optional[Tensor] = None, + is_tno: bool = False, + noise_std: float = 0.0, + feedback_channel: Optional[int] = None, + stride: Optional[int] = None, +) -> Tensor: + """Rollout with live gradients through an unrolled prediction chain. + + Collects predictions into a single tensor, computes loss once on the + concatenated trajectory, then calls ``.backward()``. Returns a + detached scalar loss for logging. The caller should NOT call + ``loss.backward()`` -- only ``optimizer.step()``. + + Gradients flow from the final loss through all intermediate + predictions, providing strong gradient signal for learning to handle + error accumulation. + + BatchNorm layers are temporarily set to eval mode during the forward + chain to prevent inplace running-stat updates from invalidating the + autograd graph. The learned affine parameters (gamma, beta) still + receive gradients normally. + + Parameters + ---------- + max_steps : int or None + Maximum number of autoregressive windows to chain. ``None`` + (default) chains all windows in the trajectory. Set to a small + value (e.g. 1-5) for pushforward-style curriculum training. + """ + total_T = targets.shape[_time_axis_target(targets)] + t_ax = _time_axis_target(targets) + + if total_T <= L: + return torch.tensor(0.0, device=inputs.device) + + # Freeze BatchNorm running stats to avoid inplace buffer updates + # that would invalidate the live autograd graph across chained forwards. + frozen_bn = _freeze_batchnorm(model) + + pred_slices = [] + prev_pred = None + current_t = 0 + step_count = 0 + + while current_t + L < total_T: + target_start = current_t + L + remaining = total_T - target_start + actual_K = min(K, remaining) + if actual_K <= 0: + break + + x_window = slice_input_window(inputs, current_t, L) + target_times = extract_target_times(inputs, target_start, actual_K) + + y_branch2 = _build_branch2( + targets, + prev_pred, + current_t, + L, + t_ax, + is_tno, + noise_std, + ) + + if feedback_channel is not None: + fb = _get_feedback(targets, prev_pred, current_t, L, t_ax) + fb = add_noise(fb, noise_std) + x_window = inject_feedback_channel(x_window, fb) + + pred = _call_model(model, x_window, target_times, x_branch2=y_branch2) + + if pred.shape[t_ax] > actual_K: + pred = pred.narrow(t_ax, 0, actual_K) + + pred_slices.append(pred) + prev_pred = pred + current_t += stride if stride is not None else K + step_count += 1 + if max_steps is not None and step_count >= max_steps: + break + + if not pred_slices: + _unfreeze_batchnorm(frozen_bn) + return torch.tensor(0.0, device=inputs.device) + + pred_full = torch.cat(pred_slices, dim=t_ax) + target_full = slice_target_window(targets, L, pred_full.shape[t_ax]) + + loss = loss_fn(pred_full, target_full, inputs, spatial_mask=spatial_mask) + if loss.requires_grad: + loss.backward() + + _unfreeze_batchnorm(frozen_bn) + return loss.detach() + + +# --------------------------------------------------------------------------- +# Full autoregressive validation (all timesteps) +# --------------------------------------------------------------------------- + + +@torch.no_grad() +def ar_validate_full_rollout( + model, + inputs: Tensor, + targets: Tensor, + L: int, + K: int, + is_tno: bool = False, + feedback_channel: Optional[int] = None, +) -> Tensor: + """Run a complete AR rollout over the full trajectory for validation. + + Always starts at t=0 and rolls out until all timesteps are covered. + Returns the full predicted trajectory (same shape as ``targets``). + """ + total_T = targets.shape[_time_axis_target(targets)] + t_ax = _time_axis_target(targets) + + pred_slices = [] + prev_pred = None + current_t = 0 + + while current_t + L < total_T: + target_start = current_t + L + remaining = total_T - target_start + actual_K = min(K, remaining) + if actual_K <= 0: + break + + x_window = slice_input_window(inputs, current_t, L) + target_times = extract_target_times(inputs, target_start, actual_K) + + if is_tno: + if prev_pred is None: + y_branch2 = slice_target_window(targets, current_t, L) + elif prev_pred.shape[t_ax] >= L: + y_branch2 = prev_pred.narrow( + t_ax, + prev_pred.shape[t_ax] - L, + L, + ) + else: + need = L - prev_pred.shape[t_ax] + gt_part = slice_target_window(targets, current_t, need) + y_branch2 = torch.cat([gt_part, prev_pred], dim=t_ax) + else: + y_branch2 = None + + if feedback_channel is not None: + fb = _get_feedback(targets, prev_pred, current_t, L, t_ax) + x_window = inject_feedback_channel(x_window, fb) + + pred = _call_model(model, x_window, target_times, x_branch2=y_branch2) + + pred_t_ax = _time_axis_target(pred) + if pred.shape[pred_t_ax] > actual_K: + pred = pred.narrow(pred_t_ax, 0, actual_K) + + pred_slices.append(pred) + prev_pred = pred + current_t += K + + if not pred_slices: + return torch.zeros_like(targets) + + pred_full = torch.cat(pred_slices, dim=t_ax) + + gt_prefix = slice_target_window(targets, 0, L) + pred_full = torch.cat([gt_prefix, pred_full], dim=t_ax) + + if pred_full.shape[t_ax] > total_T: + pred_full = pred_full.narrow(t_ax, 0, total_T) + elif pred_full.shape[t_ax] < total_T: + deficit = total_T - pred_full.shape[t_ax] + pad_slice = slice_target_window( + targets, + pred_full.shape[t_ax], + deficit, + ) + pred_full = torch.cat([pred_full, pad_slice], dim=t_ax) + + return pred_full diff --git a/examples/reservoir_simulation/neural_operator_factory/training/losses.py b/examples/reservoir_simulation/neural_operator_factory/training/losses.py new file mode 100644 index 0000000000..50ecc39893 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/losses.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified loss functions for reservoir simulation modeling. + +Data-fitting losses: mse, l1, relative_l2, huber +Regularisation: spatial derivative constraints (dimension-agnostic) +Physics losses: mass conservation (and future additions) + +All losses work with 2D spatial (B, H, W, T) and 3D spatial (B, X, Y, Z, T) +predictions, both full-mapping and autoregressive regimes. +""" + +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from training.physics_losses import ( + build_physics_losses, + cell_centre_distance, + central_difference, + extract_grid_widths_for_axis, + get_deriv_map, +) + +# --------------------------------------------------------------------------- +# Standalone convenience loss +# --------------------------------------------------------------------------- + + +class SimpleRelativeL2Loss(nn.Module): + """Relative L2 loss without bells and whistles. + + loss = mean_b( ||pred_b - target_b||_2 / (||target_b||_2 + eps) ) + """ + + def __init__(self, eps: float = 1e-8): + super().__init__() + self.eps = eps + + def forward(self, predictions, targets, inputs=None, **kwargs): + batch_size = predictions.shape[0] + pred_flat = predictions.reshape(batch_size, -1) + target_flat = targets.reshape(batch_size, -1) + diff_norm = torch.norm(pred_flat - target_flat, p=2, dim=1) + target_norm = torch.norm(target_flat, p=2, dim=1) + return (diff_norm / (target_norm + self.eps)).mean() + + +# --------------------------------------------------------------------------- +# Unified loss +# --------------------------------------------------------------------------- + + +class UnifiedLoss(nn.Module): + """Configurable loss with data-fitting, derivative, and physics terms. + + total = sum(w_i * data_loss_i) + + derivative.weight * derivative_loss + + sum(alpha_j * physics_loss_j) + + Data-fitting losses + ------------------- + mse : Mean Squared Error, mean((pred - target)^2). + Standard regression loss. Penalises large errors quadratically. + l1 : Mean Absolute Error, mean(|pred - target|). + More robust to outliers than MSE. Linear penalty. + relative_l2 : ||pred - target||_2 / (||target||_2 + eps), per sample. + Scale-invariant; recommended when output magnitude varies across + samples (e.g. pressure with large dynamic range). + huber : Smooth L1 / Huber loss (controlled by ``huber_delta``). + Behaves like MSE for errors < delta and L1 for errors > delta. + Combines MSE precision near zero with L1 robustness for outliers. + + Spatial derivative regularization + --------------------------------- + Penalises differences in spatial gradients between pred and target + using central finite differences on the NOF grid-width channels. + Configured via ``derivative_config`` with keys: + + - ``enabled`` (bool): toggle on/off. + - ``weight`` (float): multiplier for the derivative term. + - ``dims`` (list of str): which spatial directions to differentiate. + 2D data: ``dx`` (W/horizontal), ``dy`` (H/vertical). + 3D data: ``dx`` (X), ``dy`` (Y), ``dz`` (Z). + - ``metric`` (str or None): loss metric for comparing derivatives. + ``None`` inherits the first entry in ``types``. + + Masking + ------- + When ``spatial_mask`` is provided (boolean tensor over spatial dims), + MSE, L1, and Huber average only over active cells. relative_l2 + zeros out inactive cells before computing norms. + + Parameters + ---------- + types : list of str + Data loss types. + weights : list of float + Weights for each data loss type. + huber_delta : float + Transition threshold for Huber loss. + derivative_config : dict or None + Derivative regularization settings (see above). + eps : float + Epsilon for numerical stability (relative_l2 denominator). + reduction : str + 'mean', 'sum', or 'none'. + physics_losses : dict or None + Pre-built physics losses: ``{name: (module, weight)}``. + Built by :func:`training.physics_losses.build_physics_losses`. + """ + + VALID_TYPES = {"mse", "l1", "relative_l2", "huber"} + + def __init__( + self, + types=None, + weights=None, + huber_delta: float = 1.0, + derivative_config=None, + eps: float = 1e-6, + reduction: str = "mean", + physics_losses=None, + ): + super().__init__() + + if types is None: + types = ["relative_l2"] + if isinstance(types, str): + types = [types] + types = [t.lower() for t in types] + + if weights is None: + weights = [1.0] * len(types) + if isinstance(weights, (int, float)): + weights = [float(weights)] + weights = [float(w) for w in weights] + + if len(types) != len(weights): + raise ValueError( + f"types and weights must have same length, got {len(types)} vs {len(weights)}" + ) + for t in types: + if t not in self.VALID_TYPES: + raise ValueError( + f"Loss type must be one of {self.VALID_TYPES}, got '{t}'" + ) + if reduction not in ("mean", "sum", "none"): + raise ValueError( + f"reduction must be 'mean', 'sum', or 'none', got '{reduction}'" + ) + + self.loss_types = types + self.loss_weights = weights + self.huber_delta = huber_delta + self.eps = eps + self.reduction = reduction + + # Derivative config + self._deriv_cfg = derivative_config or {} + self._deriv_enabled = self._deriv_cfg.get("enabled", False) + self._deriv_weight = float(self._deriv_cfg.get("weight", 0.5)) + self._deriv_dims: List[str] = list(self._deriv_cfg.get("dims", ["dx"])) + self._deriv_metric: Optional[str] = self._deriv_cfg.get("metric", None) + + # Physics losses + self._physics_losses: Dict[str, tuple] = physics_losses or {} + for name, (mod, _w) in self._physics_losses.items(): + self.add_module(f"physics_{name}", mod) + + # ----------------------------------------------------------------------- + # Data losses + # ----------------------------------------------------------------------- + + @staticmethod + def _is_per_sample_mask(spatial_mask, pred): + """True when mask has a leading batch dimension matching pred.""" + return ( + spatial_mask.dim() == pred.dim() - 1 + and spatial_mask.shape[0] == pred.shape[0] + ) + + @staticmethod + def _expand_mask(spatial_mask, pred): + """Expand spatial mask to match pred shape, per-sample aware. + + Accepts ``(*spatial)`` (one mask for all samples) or + ``(B, *spatial)`` (per-sample masks). Returns a boolean tensor + broadcastable to ``pred`` shape ``(B, *spatial, T)``. + """ + if UnifiedLoss._is_per_sample_mask(spatial_mask, pred): + return spatial_mask.unsqueeze(-1).expand_as(pred) + else: + return spatial_mask.unsqueeze(0).unsqueeze(-1).expand_as(pred) + + def _compute_single_loss(self, pred, target, loss_type, spatial_mask=None): + """Compute a single loss term with proper per-sample masking. + + When ``spatial_mask`` is provided, only active cells contribute + to the loss. For ``relative_l2``, active cells are selected + per-sample so that norms are computed exclusively on active + values (not diluted by zeros). Supports both ``(*spatial)`` + and ``(B, *spatial)`` masks. + """ + if loss_type in ("mse", "l1", "huber"): + if loss_type == "mse": + diff = (pred - target) ** 2 + elif loss_type == "l1": + diff = torch.abs(pred - target) + else: + diff = F.smooth_l1_loss( + pred, target, reduction="none", beta=self.huber_delta + ) + if spatial_mask is not None: + mask_exp = self._expand_mask(spatial_mask, diff) + diff = diff[mask_exp] + return diff.mean() if self.reduction == "mean" else diff.sum() + + elif loss_type == "relative_l2": + batch_size = pred.shape[0] + losses = [] + for i in range(batch_size): + if spatial_mask is not None: + if self._is_per_sample_mask(spatial_mask, pred): + m = spatial_mask[i] + else: + m = spatial_mask + m_exp = m.unsqueeze(-1).expand_as(pred[i]) + p = pred[i][m_exp] + t = target[i][m_exp] + else: + p = pred[i].reshape(-1) + t = target[i].reshape(-1) + diff_norm = torch.norm(p - t, p=2) + target_norm = torch.norm(t, p=2) + losses.append(diff_norm / (target_norm + self.eps)) + loss = torch.stack(losses) + return loss.mean() if self.reduction == "mean" else loss.sum() + + raise ValueError(f"Unknown loss type: {loss_type}") + + def _compute_base_loss(self, pred, target, spatial_mask=None): + total = torch.tensor(0.0, device=pred.device) + for loss_type, weight in zip(self.loss_types, self.loss_weights): + total = total + weight * self._compute_single_loss( + pred, target, loss_type, spatial_mask + ) + return total + + # ----------------------------------------------------------------------- + # Derivative regularisation (dimension-agnostic) + # ----------------------------------------------------------------------- + + @staticmethod + def _to_union_mask(spatial_mask, pred): + """Reduce a per-sample mask ``(B, *spatial)`` to ``(*spatial)``. + + Takes the union (logical OR) across the batch so every cell + active in any sample is included. Static masks ``(*spatial)`` + are returned unchanged. + """ + if spatial_mask is None: + return None + if UnifiedLoss._is_per_sample_mask(spatial_mask, pred): + return spatial_mask.any(dim=0) + return spatial_mask + + def _compute_derivative_loss(self, pred, target, inputs, spatial_mask=None): + """Compute derivative loss along each configured direction. + + Uses the **union** mask for zeroing out inactive cells before + differentiation (safe for the stencil) and for the stencil-safe + derivative mask. The derivative metric then receives this + ``(*spatial)``-shaped mask for comparison. + """ + ndim = pred.dim() + spatial_ndim = ndim - 2 + if spatial_ndim not in (2, 3): + raise ValueError( + f"Derivative loss requires 2 or 3 spatial dims, got {spatial_ndim}" + ) + + union_mask = self._to_union_mask(spatial_mask, pred) + + deriv_metric = self._deriv_metric or self.loss_types[0] + total = torch.tensor(0.0, device=pred.device) + + if union_mask is not None: + mask_exp = union_mask.unsqueeze(0).unsqueeze(-1).expand_as(pred) + pred = pred * mask_exp + target = target * mask_exp + + for dim_name in self._deriv_dims: + dmap = get_deriv_map(spatial_ndim) + if dim_name not in dmap: + valid = list(dmap.keys()) + raise ValueError( + f"Derivative dim '{dim_name}' not valid for {spatial_ndim}D data. Valid: {valid}" + ) + tensor_axis, _ch_offset = dmap[dim_name] + + widths = extract_grid_widths_for_axis(inputs, spatial_ndim, dim_name) + spacing = cell_centre_distance(widths) + + dy_pred = central_difference(pred, tensor_axis, spacing) + dy_target = central_difference(target, tensor_axis, spacing) + + deriv_mask = None + if union_mask is not None: + mask_axis = tensor_axis - 1 + n = union_mask.shape[mask_axis] + m_left = union_mask.narrow(mask_axis, 0, n - 2) + m_centre = union_mask.narrow(mask_axis, 1, n - 2) + m_right = union_mask.narrow(mask_axis, 2, n - 2) + deriv_mask = m_left & m_centre & m_right + + total = total + self._compute_single_loss( + dy_pred, dy_target, deriv_metric, spatial_mask=deriv_mask + ) + + return total / max(len(self._deriv_dims), 1) + + # ----------------------------------------------------------------------- + # Forward + # ----------------------------------------------------------------------- + + def forward(self, pred, target, inputs=None, spatial_mask=None): + """Compute unified loss. + + Parameters + ---------- + pred : Tensor + (B, H, W, T) or (B, X, Y, Z, T) + target : Tensor + Same shape as pred. + inputs : Tensor or None + (B, *spatial, T, C). Required for derivative and physics losses. + spatial_mask : Tensor or None + (*spatial) boolean mask; active cells = True/1. + """ + if self._deriv_enabled and inputs is None: + raise ValueError("inputs required when derivative loss is enabled") + + # 1. Data loss + data_loss = self._compute_base_loss(pred, target, spatial_mask) + + # 2. Derivative loss + if self._deriv_enabled: + deriv_loss = self._compute_derivative_loss( + pred, target, inputs, spatial_mask + ) + data_loss = data_loss + self._deriv_weight * deriv_loss + + # 3. Physics losses + for mod, weight in self._physics_losses.values(): + phys = mod(pred, target, inputs, spatial_mask=spatial_mask) + data_loss = data_loss + weight * phys + + return data_loss + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def get_loss_function(loss_config, variable=None): + """Create a UnifiedLoss from a Hydra config.""" + types = loss_config.get("types", ["relative_l2"]) + weights = loss_config.get("weights", None) + + if hasattr(types, "__iter__") and not isinstance(types, str): + types = list(types) + if weights is not None and hasattr(weights, "__iter__"): + weights = list(weights) + + # Derivative config + deriv_cfg = loss_config.get("derivative", None) + derivative_config = dict(deriv_cfg) if deriv_cfg is not None else {"enabled": False} + + # Physics losses + physics_cfg = loss_config.get("physics", None) + default_metric = types[0] if types else "relative_l2" + physics_losses = build_physics_losses( + physics_cfg, variable=variable, default_metric=default_metric + ) + + return UnifiedLoss( + types=types, + weights=weights, + huber_delta=float(loss_config.get("huber_delta", 1.0)), + derivative_config=derivative_config, + eps=float(loss_config.get("eps", 1e-6)), + reduction=loss_config.get("reduction", "mean"), + physics_losses=physics_losses, + ) + + +__all__ = [ + "SimpleRelativeL2Loss", + "UnifiedLoss", + "get_loss_function", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/training/metrics.py b/examples/reservoir_simulation/neural_operator_factory/training/metrics.py new file mode 100644 index 0000000000..edb6473370 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/metrics.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluation metrics for reservoir simulation models. + +This module provides standard metrics for evaluating pressure and saturation predictions. +Imports available metrics from PhysicsNemo and provides additional domain-specific metrics. +""" + +from typing import Optional + +import numpy as np +import torch +from torch import Tensor + +from physicsnemo.metrics.general.ensemble_metrics import Mean, Variance + +# ============================================================================ +# PhysicsNemo Imports (official implementations) +# ============================================================================ +from physicsnemo.metrics.general.mse import mse, rmse +from physicsnemo.metrics.general.reduction import WeightedMean, WeightedVariance + +# Re-export PhysicsNemo metrics for convenience +__all__ = [ + # PhysicsNemo imports + "mse", + "rmse", + "Mean", + "Variance", + "WeightedMean", + "WeightedVariance", + # NumPy-based metrics + "mean_relative_error", + "mean_plume_error", + "mean_absolute_error", + "max_absolute_error", + "compute_r2_score", + "compute_relative_l2_error", + "compute_relative_l1_error", + "normalized_mse", + "peak_signal_to_noise_ratio", + # PyTorch-based metrics + "mse_torch", + "rmse_torch", + "mae_torch", + "relative_l2_torch", + "relative_l1_torch", + "r2_score_torch", + "max_error_torch", + "psnr_torch", +] + + +# ============================================================================ +# NumPy-based Metrics (for evaluation/post-processing) +# ============================================================================ + + +def mean_relative_error( + y_pred: np.ndarray, + y_true: np.ndarray, + eps: float = 1e-8, +) -> float: + """ + Mean Relative Error (MRE) - normalized by range. + Appropriate for pressure predictions. + + Args: + y_pred: Predicted values + y_true: Ground truth values + eps: Small value to avoid division by zero + + Returns: + float: MRE value between 0 and 1 + """ + data_range = y_true.max() - y_true.min() + return float(np.mean(np.abs(y_pred - y_true)) / (data_range + eps)) + + +def mean_plume_error( + y_pred: np.ndarray, + y_true: np.ndarray, + threshold: float = 0.0, +) -> float: + """ + Mean Plume Error (MPE) - error only where both are non-zero. + Appropriate for saturation predictions where we care about + accuracy within the CO2 plume region. + + Args: + y_pred: Predicted saturation values + y_true: Ground truth saturation values + threshold: Threshold for considering values as "non-zero" + + Returns: + float: MPE value (absolute error in plume region) + """ + mask = (np.abs(y_pred) > threshold) & (np.abs(y_true) > threshold) + y_pred_masked = y_pred[mask] + y_true_masked = y_true[mask] + + if len(y_pred_masked) == 0: + return 0.0 + + return float(np.mean(np.abs(y_pred_masked - y_true_masked))) + + +def mean_absolute_error(y_pred: np.ndarray, y_true: np.ndarray) -> float: + """ + Mean Absolute Error (MAE). + + Args: + y_pred: Predicted values + y_true: Ground truth values + + Returns: + float: MAE value + """ + return float(np.mean(np.abs(y_pred - y_true))) + + +def max_absolute_error(y_pred: np.ndarray, y_true: np.ndarray) -> float: + """ + Maximum Absolute Error. + + Args: + y_pred: Predicted values + y_true: Ground truth values + + Returns: + float: Maximum absolute error value + """ + return float(np.max(np.abs(y_pred - y_true))) + + +def compute_r2_score(y_pred: np.ndarray, y_true: np.ndarray) -> float: + """ + Compute R² (coefficient of determination) score. + + Args: + y_pred: Predicted values + y_true: Ground truth values + + Returns: + float: R² score (1.0 is perfect, negative means worse than mean) + """ + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + if ss_tot == 0: + return 1.0 if ss_res == 0 else 0.0 + return float(1 - (ss_res / ss_tot)) + + +def compute_relative_l2_error( + pred: np.ndarray, + target: np.ndarray, + eps: float = 1e-8, +) -> float: + """ + Compute relative L2 error. + + Args: + pred: Predicted values + target: Ground truth values + eps: Small value to avoid division by zero + + Returns: + float: Relative L2 error + """ + pred_flat = pred.reshape(-1) + target_flat = target.reshape(-1) + diff_norm = np.linalg.norm(pred_flat - target_flat, ord=2) + target_norm = np.linalg.norm(target_flat, ord=2) + return float(diff_norm / (target_norm + eps)) + + +def compute_relative_l1_error( + pred: np.ndarray, + target: np.ndarray, + eps: float = 1e-8, +) -> float: + """ + Compute relative L1 error. + + Args: + pred: Predicted values + target: Ground truth values + eps: Small value to avoid division by zero + + Returns: + float: Relative L1 error + """ + pred_flat = pred.reshape(-1) + target_flat = target.reshape(-1) + diff_norm = np.linalg.norm(pred_flat - target_flat, ord=1) + target_norm = np.linalg.norm(target_flat, ord=1) + return float(diff_norm / (target_norm + eps)) + + +def normalized_mse( + y_pred: np.ndarray, + y_true: np.ndarray, + normalize_by: str = "variance", + eps: float = 1e-8, +) -> float: + """ + Normalized Mean Squared Error. + + Args: + y_pred: Predicted values + y_true: Ground truth values + normalize_by: 'variance' or 'range' + eps: Small value to avoid division by zero + + Returns: + float: Normalized MSE value + """ + mse_val = np.mean((y_pred - y_true) ** 2) + + if normalize_by == "variance": + normalizer = np.var(y_true) + eps + elif normalize_by == "range": + normalizer = (y_true.max() - y_true.min()) ** 2 + eps + else: + raise ValueError( + f"normalize_by must be 'variance' or 'range', got {normalize_by}" + ) + + return float(mse_val / normalizer) + + +def peak_signal_to_noise_ratio( + y_pred: np.ndarray, + y_true: np.ndarray, + data_range: Optional[float] = None, + eps: float = 1e-8, +) -> float: + """ + Peak Signal-to-Noise Ratio (PSNR). + + Args: + y_pred: Predicted values + y_true: Ground truth values + data_range: Dynamic range of the data (max - min). If None, computed from y_true. + eps: Small value to avoid log(0) + + Returns: + float: PSNR in dB + """ + if data_range is None: + data_range = y_true.max() - y_true.min() + + mse_val = np.mean((y_pred - y_true) ** 2) + if mse_val < eps: + return float("inf") + + return float(20 * np.log10(data_range / (np.sqrt(mse_val) + eps))) + + +# ============================================================================ +# PyTorch-based Metrics (for use during training) +# ============================================================================ + + +def mse_torch(pred: Tensor, target: Tensor, dim: Optional[int] = None) -> Tensor: + """ + Mean Squared Error (PyTorch). + Wrapper around PhysicsNemo's mse for consistent API. + + Args: + pred: Predicted tensor + target: Target tensor + dim: Reduction dimension (None for full reduction) + + Returns: + MSE value(s) + """ + return mse(pred, target, dim=dim) + + +def rmse_torch(pred: Tensor, target: Tensor, dim: Optional[int] = None) -> Tensor: + """ + Root Mean Squared Error (PyTorch). + Wrapper around PhysicsNemo's rmse for consistent API. + + Args: + pred: Predicted tensor + target: Target tensor + dim: Reduction dimension (None for full reduction) + + Returns: + RMSE value(s) + """ + return rmse(pred, target, dim=dim) + + +def mae_torch(pred: Tensor, target: Tensor, dim: Optional[int] = None) -> Tensor: + """ + Mean Absolute Error (PyTorch). + + Args: + pred: Predicted tensor + target: Target tensor + dim: Reduction dimension (None for full reduction) + + Returns: + MAE value(s) + """ + return torch.mean(torch.abs(pred - target), dim=dim) + + +def relative_l2_torch( + pred: Tensor, + target: Tensor, + dim: Optional[int] = None, + eps: float = 1e-8, +) -> Tensor: + """ + Relative L2 Error (PyTorch). + + Args: + pred: Predicted tensor + target: Target tensor + dim: Dimension(s) for computing norms. If None, computes over flattened tensors. + eps: Small value to avoid division by zero + + Returns: + Relative L2 error + """ + if dim is None: + pred_flat = pred.reshape(-1) + target_flat = target.reshape(-1) + diff_norm = torch.norm(pred_flat - target_flat, p=2) + target_norm = torch.norm(target_flat, p=2) + else: + diff_norm = torch.norm(pred - target, p=2, dim=dim) + target_norm = torch.norm(target, p=2, dim=dim) + + return diff_norm / (target_norm + eps) + + +def relative_l1_torch( + pred: Tensor, + target: Tensor, + dim: Optional[int] = None, + eps: float = 1e-8, +) -> Tensor: + """ + Relative L1 Error (PyTorch). + + Args: + pred: Predicted tensor + target: Target tensor + dim: Dimension(s) for computing norms. If None, computes over flattened tensors. + eps: Small value to avoid division by zero + + Returns: + Relative L1 error + """ + if dim is None: + pred_flat = pred.reshape(-1) + target_flat = target.reshape(-1) + diff_norm = torch.norm(pred_flat - target_flat, p=1) + target_norm = torch.norm(target_flat, p=1) + else: + diff_norm = torch.norm(pred - target, p=1, dim=dim) + target_norm = torch.norm(target, p=1, dim=dim) + + return diff_norm / (target_norm + eps) + + +def r2_score_torch(pred: Tensor, target: Tensor) -> Tensor: + """ + R² Score (coefficient of determination) in PyTorch. + + Args: + pred: Predicted tensor + target: Target tensor + + Returns: + R² score + """ + ss_res = torch.sum((target - pred) ** 2) + ss_tot = torch.sum((target - target.mean()) ** 2) + + if ss_tot == 0: + return torch.tensor(1.0 if ss_res == 0 else 0.0, device=pred.device) + + return 1 - (ss_res / ss_tot) + + +def max_error_torch(pred: Tensor, target: Tensor) -> Tensor: + """ + Maximum Absolute Error (PyTorch). + + Args: + pred: Predicted tensor + target: Target tensor + + Returns: + Maximum absolute error + """ + return torch.max(torch.abs(pred - target)) + + +def psnr_torch( + pred: Tensor, + target: Tensor, + data_range: Optional[float] = None, + eps: float = 1e-8, +) -> Tensor: + """ + Peak Signal-to-Noise Ratio (PyTorch). + + Args: + pred: Predicted tensor + target: Target tensor + data_range: Dynamic range of data. If None, computed from target. + eps: Small value to avoid log(0) + + Returns: + PSNR in dB + """ + if data_range is None: + data_range = target.max() - target.min() + + mse_val = torch.mean((pred - target) ** 2) + + if mse_val < eps: + return torch.tensor(float("inf"), device=pred.device) + + return 20 * torch.log10(data_range / (torch.sqrt(mse_val) + eps)) diff --git a/examples/reservoir_simulation/neural_operator_factory/training/physics_losses.py b/examples/reservoir_simulation/neural_operator_factory/training/physics_losses.py new file mode 100644 index 0000000000..a7f6a7640c --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/physics_losses.py @@ -0,0 +1,339 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Physics-informed loss functions for reservoir simulation modeling. + +All losses are dimension-agnostic and work with both +3D (B, H, W, T) and 4D (B, X, Y, Z, T) predictions. + +Available physics losses: +- mass_conservation: penalises discrepancies in spatially-integrated + quantities between prediction and ground truth at each timestep + (Chandra et al. 2025, arXiv:2503.11031, Eq. 4). + +Grid convention +--------------- +NOF datasets store cell widths (block sizes) as the last input channels: +- 2D: [..., grid_x, grid_y, grid_t] at channels [-3, -2, -1] +- 3D: [..., grid_x, grid_y, grid_z, grid_t] at channels [-4, -3, -2, -1] + +Cell volumes are computed as the outer product of per-axis widths. +""" + +from __future__ import annotations + +import logging +import warnings +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Grid / volume utilities +# --------------------------------------------------------------------------- + + +def _extract_grid_widths(inputs: Tensor, spatial_ndim: int) -> List[Tensor]: + """Extract per-axis cell widths from the NOF last-channel convention. + + Returns one 1-D tensor of cell widths per spatial axis, in axis order: + - 2D: [widths_H (grid_y), widths_W (grid_x)] + - 3D: [widths_X (grid_x), widths_Y (grid_y), widths_Z (grid_z)] + """ + if spatial_ndim == 2: + widths_w = inputs[0, 0, :, 0, -3] # (W,) x-direction cell widths + widths_h = inputs[0, :, 0, 0, -2] # (H,) y-direction cell widths + return [widths_h, widths_w] # axis-1, axis-2 order + else: + widths_x = inputs[0, :, 0, 0, 0, -4] # (X,) + widths_y = inputs[0, 0, :, 0, 0, -3] # (Y,) + widths_z = inputs[0, 0, 0, :, 0, -2] # (Z,) + return [widths_x, widths_y, widths_z] # axis-1, axis-2, axis-3 order + + +def compute_cell_volumes_from_widths(inputs: Tensor, spatial_ndim: int) -> Tensor: + """Compute per-cell volumes from the NOF grid-width channels. + + Returns (H, W) for 2D or (X, Y, Z) for 3D. + """ + widths = _extract_grid_widths(inputs, spatial_ndim) + if spatial_ndim == 2: + return widths[0].unsqueeze(1) * widths[1].unsqueeze(0) + else: + return ( + widths[0].unsqueeze(1).unsqueeze(2) + * widths[1].unsqueeze(0).unsqueeze(2) + * widths[2].unsqueeze(0).unsqueeze(1) + ) + + +# --------------------------------------------------------------------------- +# Derivative utilities (used by losses.py for derivative regularization) +# --------------------------------------------------------------------------- + +# {dim_name: (tensor_axis, grid_channel_offset_from_end)} +_DERIV_MAP_2D: Dict[str, Tuple[int, int]] = { + "dx": (2, -3), # W axis, grid_x channel + "dy": (1, -2), # H axis, grid_y channel +} +_DERIV_MAP_3D: Dict[str, Tuple[int, int]] = { + "dx": (1, -4), # X axis, grid_x channel + "dy": (2, -3), # Y axis, grid_y channel + "dz": (3, -2), # Z axis, grid_z channel +} + + +def get_deriv_map(spatial_ndim: int) -> Dict[str, Tuple[int, int]]: + """Return {dim_name: (tensor_axis, channel_offset)} for the given dimensionality.""" + return _DERIV_MAP_2D if spatial_ndim == 2 else _DERIV_MAP_3D + + +def cell_centre_distance(cell_widths: Tensor, min_spacing: float = 1e-6) -> Tensor: + """Distance between centres of cell i and cell i+2. + + d[i] = cell_widths[i]/2 + cell_widths[i+1] + cell_widths[i+2]/2 + + A minimum floor is applied to prevent division-by-zero when grid + widths are zero (e.g. inactive cells) or very small (normalized data). + + Parameters + ---------- + cell_widths : Tensor shape (N,) + min_spacing : float floor value for the output + + Returns + ------- + Tensor shape (N-2,) + """ + d = cell_widths[:-2] / 2.0 + cell_widths[1:-1] + cell_widths[2:] / 2.0 + return d.clamp(min=min_spacing) + + +def central_difference(field: Tensor, axis: int, spacing: Tensor) -> Tensor: + """Central-difference derivative along *axis*. + + Computes (f[i+2] - f[i]) / spacing[i] for each interior point. + + Parameters + ---------- + field : Tensor arbitrary-rank tensor + axis : int dimension along which to differentiate + spacing : Tensor (N-2,) cell-centre distances + + Returns + ------- + Tensor with field.shape[axis] reduced by 2 + """ + n = field.shape[axis] + f_right = field.narrow(axis, 2, n - 2) + f_left = field.narrow(axis, 0, n - 2) + + shape = [1] * field.dim() + shape[axis] = -1 + sp = spacing.reshape(shape) + + return (f_right - f_left) / sp + + +def extract_grid_widths_for_axis( + inputs: Tensor, + spatial_ndim: int, + dim_name: str, +) -> Tensor: + """Extract 1-D cell widths for a single derivative direction. + + Parameters + ---------- + inputs : Tensor full input tensor + spatial_ndim : int 2 or 3 + dim_name : str 'dx', 'dy', or 'dz' + + Returns + ------- + Tensor 1-D cell widths along the requested axis + """ + dmap = get_deriv_map(spatial_ndim) + if dim_name not in dmap: + valid = list(dmap.keys()) + raise ValueError( + f"Unknown derivative dim '{dim_name}' for {spatial_ndim}D. Valid: {valid}" + ) + _axis, ch_offset = dmap[dim_name] + # Extract widths along the correct spatial axis + if spatial_ndim == 2: + if dim_name == "dx": + return inputs[0, 0, :, 0, ch_offset] + else: + return inputs[0, :, 0, 0, ch_offset] + else: + if dim_name == "dx": + return inputs[0, :, 0, 0, 0, ch_offset] + elif dim_name == "dy": + return inputs[0, 0, :, 0, 0, ch_offset] + else: + return inputs[0, 0, 0, :, 0, ch_offset] + + +# --------------------------------------------------------------------------- +# Mass conservation loss +# --------------------------------------------------------------------------- + + +class MassConservationLoss(nn.Module): + """Weak mass-conservation constraint via spatial integration. + + Parameters + ---------- + use_cell_volumes : bool + If True, compute cell volumes from the grid-width input channels + (NOF convention). If False, uniform weighting (volume = 1). + eps : float + Numerical stability constant. + """ + + VALID_METRICS = {"relative_l2", "mse", "l1", "huber"} + + def __init__( + self, + use_cell_volumes: bool = False, + metric: str = "relative_l2", + eps: float = 1e-8, + ): + super().__init__() + self.use_cell_volumes = use_cell_volumes + if metric not in self.VALID_METRICS: + raise ValueError( + f"metric must be one of {self.VALID_METRICS}, got '{metric}'" + ) + self.metric = metric + self.eps = eps + self._cell_volumes: Optional[Tensor] = None + self._volumes_device: Optional[torch.device] = None + + def _get_cell_volumes(self, inputs: Tensor, spatial_ndim: int) -> Tensor: + if self._cell_volumes is not None and self._volumes_device == inputs.device: + return self._cell_volumes + if self.use_cell_volumes: + vol = compute_cell_volumes_from_widths(inputs, spatial_ndim) + else: + spatial_shape = inputs.shape[1 : 1 + spatial_ndim] + vol = torch.ones(spatial_shape, device=inputs.device, dtype=inputs.dtype) + self._cell_volumes = vol.detach() + self._volumes_device = inputs.device + return self._cell_volumes + + def forward(self, pred, target, inputs, spatial_mask=None): + ndim = pred.dim() + if ndim == 4: + spatial_ndim, spatial_dims = 2, (1, 2) + elif ndim == 5: + spatial_ndim, spatial_dims = 3, (1, 2, 3) + else: + raise ValueError( + f"Expected 4D (B,H,W,T) or 5D (B,X,Y,Z,T), got {ndim}D shape {tuple(pred.shape)}" + ) + + vol = self._get_cell_volumes(inputs, spatial_ndim) + if spatial_mask is not None: + # Reduce per-sample (B, *spatial) to union (*spatial) + m = spatial_mask + if m.dim() == pred.dim() - 1: + m = m.any(dim=0) + w = (vol * m.float()).unsqueeze(0).unsqueeze(-1) + else: + w = vol.unsqueeze(0).unsqueeze(-1) + + m_pred = (pred * w).sum(dim=spatial_dims) # (B, T) + m_true = (target * w).sum(dim=spatial_dims) # (B, T) + + if self.metric == "relative_l2": + diff_norm = torch.norm(m_true - m_pred, p=2, dim=-1) + true_norm = torch.norm(m_true, p=2, dim=-1) + per_sample = diff_norm / (true_norm + self.eps) + elif self.metric == "mse": + per_sample = ((m_true - m_pred) ** 2).mean(dim=-1) + elif self.metric == "l1": + per_sample = (m_true - m_pred).abs().mean(dim=-1) + elif self.metric == "huber": + per_sample = torch.nn.functional.smooth_l1_loss( + m_pred, m_true, reduction="none" + ).mean(dim=-1) + + return per_sample.mean() + + +# --------------------------------------------------------------------------- +# Registry / factory +# --------------------------------------------------------------------------- + +_PHYSICS_LOSS_REGISTRY: Dict[str, type] = { + "mass_conservation": MassConservationLoss, +} + + +def build_physics_losses(physics_config, variable=None, default_metric="relative_l2"): + """Instantiate physics losses from the loss.physics config block.""" + if physics_config is None: + return {} + + active: Dict[str, tuple] = {} + for name in _PHYSICS_LOSS_REGISTRY: + sub = physics_config.get(name, None) + if sub is None or not sub.get("enabled", False): + continue + weight = float(sub.get("weight", 1.0)) + if weight <= 0: + continue + + cls = _PHYSICS_LOSS_REGISTRY[name] + kwargs: dict = {} + + if name == "mass_conservation": + kwargs["use_cell_volumes"] = bool(sub.get("use_cell_volumes", False)) + kwargs["eps"] = float(sub.get("eps", 1e-8)) + metric = sub.get("metric", None) + kwargs["metric"] = metric if metric is not None else default_metric + if variable is not None and variable.lower() == "pressure": + warnings.warn( + "mass_conservation loss is enabled for variable='pressure'. " + "Spatially-integrated pressure is not a conserved quantity; " + "this loss is physically meaningful only for mass-like " + "variables (e.g. saturation, CO2 mass).", + stacklevel=2, + ) + + loss_module = cls(**kwargs) + active[name] = (loss_module, weight) + logger.info("Physics loss '%s' enabled (weight=%.4g)", name, weight) + + return active + + +__all__ = [ + "MassConservationLoss", + "compute_cell_volumes_from_widths", + "cell_centre_distance", + "central_difference", + "get_deriv_map", + "extract_grid_widths_for_axis", + "build_physics_losses", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/training/train.py b/examples/reservoir_simulation/neural_operator_factory/training/train.py new file mode 100644 index 0000000000..495986dc27 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/training/train.py @@ -0,0 +1,1265 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training script for neural operator reservoir simulation models.""" + +import sys +from pathlib import Path + +# Add parent directory (neural_operator_factory/) to path for package imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import hydra +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from models.xdeeponet import DeepONet3DWrapper, DeepONetWrapper +from models.xfno import FNO4DNet, UFNONet +from omegaconf import DictConfig +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Adam +from torch.optim.lr_scheduler import ExponentialLR, StepLR +from utils.checkpoint import load_checkpoint, save_checkpoint + + +def print_model_architecture(model, model_type: str, dimensions: str, cfg, logger): + """Print detailed model architecture for any model type.""" + logger.info("=" * 80) + logger.info("MODEL ARCHITECTURE") + logger.info("=" * 80) + + # Get the actual model (unwrap DDP if needed) + if hasattr(model, "module"): + actual_model = model.module + else: + actual_model = model + + # Print model type and dimensions + logger.info(f"Dimensions: {dimensions.upper()}") + logger.info(f"Model Type: {model_type.upper()}") + + if model_type == "xdeeponet": + variant = cfg.arch.xdeeponet.get("variant", "u_deeponet") + logger.info(f"Variant: {variant}") + logger.info("") + + # Branch configuration + branch1_cfg = cfg.arch.xdeeponet.get("branch1", {}) + b1_enc = branch1_cfg.get("encoder", {}) + b1_layers = branch1_cfg.get("layers", {}) + b1_enc_type = ( + b1_enc.get("type", "linear") if isinstance(b1_enc, dict) else b1_enc + ) + logger.info("Branch 1:") + logger.info(f" Encoder: {b1_enc_type}") + logger.info(" In Channels: auto (inferred from input tensor)") + logger.info(f" Fourier Layers: {b1_layers.get('num_fourier_layers', 0)}") + logger.info(f" UNet Layers: {b1_layers.get('num_unet_layers', 0)}") + logger.info(f" Conv Layers: {b1_layers.get('num_conv_layers', 0)}") + logger.info(f" Layer Activation: {b1_layers.get('activation_fn', 'sin')}") + + if variant in ["mionet", "fourier_mionet", "tno"]: + branch2_cfg = cfg.arch.xdeeponet.get("branch2", {}) + b2_enc = branch2_cfg.get("encoder", {}) + b2_layers = branch2_cfg.get("layers", {}) + b2_enc_type = ( + b2_enc.get("type", "linear") if isinstance(b2_enc, dict) else b2_enc + ) + logger.info("Branch 2:") + logger.info(f" Encoder: {b2_enc_type}") + logger.info(f" Fourier Layers: {b2_layers.get('num_fourier_layers', 0)}") + logger.info(f" UNet Layers: {b2_layers.get('num_unet_layers', 0)}") + logger.info(f" Layer Activation: {b2_layers.get('activation_fn', 'sin')}") + + # Trunk configuration + trunk_cfg = cfg.arch.xdeeponet.get("trunk", {}) + trunk_input = trunk_cfg.get("input_type", "time") + in_features = (4 if dimensions == "4d" else 3) if trunk_input == "grid" else 1 + coord_desc = "x,y,z,t" if dimensions == "4d" else "x,y,t" + logger.info("Trunk:") + logger.info( + f" Input Type: {trunk_input} ({coord_desc if trunk_input == 'grid' else 'just t'})" + ) + logger.info(f" In Features: {in_features}") + logger.info(f" Hidden Width: {trunk_cfg.get('hidden_width', 128)}") + logger.info(f" Num Layers: {trunk_cfg.get('num_layers', 6)}") + logger.info(f" Activation: {trunk_cfg.get('activation_fn', 'sin')}") + + # Decoder configuration + logger.info("Decoder:") + logger.info(f" Type: {cfg.arch.xdeeponet.get('decoder_type', 'mlp')}") + logger.info(f" Width: {cfg.arch.xdeeponet.get('decoder_width', 128)}") + logger.info(f" Layers: {cfg.arch.xdeeponet.get('decoder_layers', 2)}") + logger.info( + f" Activation: {cfg.arch.xdeeponet.get('decoder_activation_fn', 'relu')}" + ) + + logger.info(f"Latent Width: {cfg.arch.xdeeponet.get('width', 64)}") + logger.info(f"Padding: {cfg.arch.xdeeponet.get('padding', 8)}") + + elif model_type == "xfno": + xfno_cfg = cfg.arch.xfno + logger.info(f"Out Channels: {xfno_cfg.out_channels}") + logger.info(f"Width: {xfno_cfg.width}") + if dimensions == "4d": + logger.info( + f"Modes: ({xfno_cfg.modes1}, {xfno_cfg.modes2}, {xfno_cfg.modes3}, {xfno_cfg.modes4})" + ) + else: + logger.info( + f"Modes: ({xfno_cfg.modes1}, {xfno_cfg.modes2}, {xfno_cfg.modes3})" + ) + logger.info(f"FNO Layers: {xfno_cfg.num_fno_layers}") + if dimensions == "3d": + logger.info(f"U-Net Layers: {xfno_cfg.num_unet_layers}") + logger.info(f"Conv Layers: {xfno_cfg.num_conv_layers}") + logger.info( + f"Lifting: type={xfno_cfg.lifting_type}, layers={xfno_cfg.lifting_layers}" + ) + else: + logger.info(f"Coord Features: {xfno_cfg.coord_features}") + logger.info(f"Activation: {xfno_cfg.activation_fn}") + logger.info( + f"Decoder: layers={xfno_cfg.decoder_layers}, width={xfno_cfg.decoder_width}" + ) + + # Print full model structure + logger.info("") + logger.info("Full Model Structure:") + logger.info("-" * 80) + for line in str(actual_model).split("\n"): + logger.info(line) + logger.info("-" * 80) + + # Count parameters per component + logger.info("") + logger.info("Parameter Counts:") + total_params = 0 + for name, module in actual_model.named_children(): + params = sum(p.numel() for p in module.parameters() if p.requires_grad) + total_params += params + logger.info(f" {name}: {params:,} parameters") + logger.info(f" TOTAL: {total_params:,} parameters") + logger.info("=" * 80) + + +from data.dataloader import create_dataloaders # noqa: E402 +from data.validation import ( # noqa: E402 + print_validation_summary, + validate_batch_dimensions, +) +from utils.co2_normalization import dnorm_dP # noqa: E402 + +from physicsnemo.distributed import DistributedManager # noqa: E402 +from physicsnemo.launch.logging import LaunchLogger, PythonLogger # noqa: E402 +from training.ar_utils import ( # noqa: E402 + ar_validate_full_rollout, + compute_unroll_steps, + get_training_stage, + live_rollout_step, + rollout_step, + teacher_forcing_step, +) +from training.losses import get_loss_function # noqa: E402 +from training.metrics import ( # noqa: E402 + compute_relative_l2_error, + mean_absolute_error, + mean_plume_error, + mean_relative_error, +) + +# Registry of denormalization functions that can be selected via config. +_DENORM_REGISTRY = { + "dnorm_dP": dnorm_dP, +} + + +def _get_batch_mask(inputs, mask_channel, mask_per_sample, static_mask): + """Construct the spatial mask for the current batch. + + Works for any structured-grid dataset (3D or 4D). When the mask + is static (identical across all samples), returns ``(*spatial)``. + When it varies per sample, returns ``(B, *spatial)`` so each + sample's loss is computed only on its own active cells. + Returns *None* when no mask channel is available. + """ + if mask_channel is None: + return None + if not mask_per_sample: + return static_mask + # Per-sample: inputs shape is (B, *spatial, T, C). + # Returns (B, *spatial) boolean mask. + return inputs[..., 0, mask_channel] != 0 + + +# Registry of validation metric functions (numpy-based, operate on flat arrays). +def _rmse_np(y_pred, y_true): + return float(np.sqrt(np.mean((y_pred - y_true) ** 2))) + + +_METRIC_REGISTRY = { + "rmse": ("RMSE", "val_rmse", _rmse_np), + "mae": ("MAE", "val_mae", mean_absolute_error), + "mre": ("MRE", "val_mre", mean_relative_error), + "mpe": ("MPE", "val_mpe", mean_plume_error), + "relative_l2": ("RelL2", "val_relative_l2", compute_relative_l2_error), +} + + +def _live_rollout_ddp_safe( + model, dist, inputs, targets, loss_fn, ar_common, max_steps=None +): + """Run live_rollout_step with correct DDP gradient synchronization. + + live_rollout_step performs multiple forward passes with a single backward, + which conflicts with DDP's per-forward gradient hooks. This wrapper + disables DDP sync during the step, then manually AllReduces gradients + so all GPUs apply identical weight updates. + """ + kwargs = ( + dict(ar_common, max_steps=max_steps) + if max_steps is not None + else dict(ar_common) + ) + + if isinstance(model, DDP): + with model.no_sync(): + loss = live_rollout_step(model, inputs, targets, loss_fn, **kwargs) + for param in model.parameters(): + if param.grad is not None: + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.SUM + ) + param.grad /= dist.world_size + else: + loss = live_rollout_step(model, inputs, targets, loss_fn, **kwargs) + + return loss + + +@hydra.main(version_base="1.3", config_path="../conf", config_name="training_config") +def main(cfg: DictConfig) -> None: + """Main training function for neural operator reservoir simulation models.""" + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + # Helper variable for MLFlow logging (only on rank 0) + use_mlflow = cfg.logging.use_mlflow and dist.rank == 0 + + # Set random seeds for reproducibility + if hasattr(cfg, "seed"): + import random + + import numpy as np + + seed = cfg.seed + dist.rank # Different seed per rank for data augmentation + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + # Set deterministic behavior if requested + if cfg.compute.deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + elif cfg.compute.benchmark: + torch.backends.cudnn.benchmark = True + + # Initialize logger + logger = PythonLogger(name="nof_train") + logger.file_logging() + LaunchLogger.initialize() + + # Print header (only on rank 0) + if dist.rank == 0: + dimensions = cfg.arch.dimensions.lower() + model_type = cfg.arch.model.lower() + if model_type == "xdeeponet": + model_name = cfg.arch.xdeeponet.variant.replace("_", "-").upper() + else: + model_name = model_type.upper() + logger.info("=" * 80) + logger.info( + f"{model_name} ({dimensions.upper()}) Training | Variable: {cfg.data.variable} | GPUs: {dist.world_size}" + ) + logger.info("=" * 80) + + # Auto-set num_workers based on num_gpus if not specified + num_workers = cfg.data.num_workers + if num_workers is None and hasattr(cfg.compute, "num_gpus"): + num_workers = cfg.compute.num_gpus * 2 # 2 workers per GPU + elif num_workers is None: + num_workers = 4 # Default fallback + + # Get dimensions from config (used for model selection and data validation) + expected_dimensions = cfg.arch.dimensions.lower() + + train_loader, val_loader, test_loader = create_dataloaders( + data_path=cfg.data.data_path, + batch_size=cfg.training.batch_size, + normalize=cfg.data.normalize, + num_workers=num_workers, + device=dist.device, + input_file=cfg.data.get("input_file", None), + output_file=cfg.data.get("output_file", None), + variable=cfg.data.get("variable", None), + expected_dimensions=expected_dimensions, + use_mask=cfg.data.get("mask_enabled", False), + mask_channel=cfg.data.get("mask_channel", None), + num_timesteps=cfg.data.get("num_timesteps", None), + ) + + # Masking metadata from dataset + ds = train_loader.dataset + mask_channel = getattr(ds, "mask_channel", None) + mask_per_sample = getattr(ds, "mask_per_sample", False) + static_mask = ds.get_static_mask() + if static_mask is not None: + static_mask = static_mask.to(dist.device) + + # Detect variants with branch2 + regime = cfg.training.get("regime", "full_mapping").lower() + _variant = ( + cfg.arch.xdeeponet.get("variant", "") + if cfg.arch.model.lower() == "xdeeponet" + else "" + ) + is_tno = _variant == "tno" + has_branch2 = _variant in ("mionet", "fourier_mionet", "tno") + if is_tno: + if regime != "autoregressive": + raise ValueError("TNO variant requires regime: autoregressive") + if dist.rank == 0: + logger.info("TNO mode: branch2 receives previous solution state") + elif has_branch2 and dist.rank == 0: + logger.info(f"MIONet mode ({_variant}): branch2 processes scalar inputs") + + # Print data info (only on rank 0) + if dist.rank == 0: + effective_batch_size = cfg.training.batch_size * dist.world_size + logger.info( + f"Data: Train={len(train_loader.dataset)}, Val={len(val_loader.dataset)}, Test={len(test_loader.dataset)} | Batch size={cfg.training.batch_size} per GPU (Effective: {effective_batch_size})" + ) + + # Validate data dimensions against config + if dist.rank == 0: + logger.info("Validating data dimensions...") + + # Get a sample batch to check dimensions + sample_inputs, sample_targets = next(iter(train_loader)) + + # Validate using centralized validation function + validation_info = validate_batch_dimensions( + sample_inputs, sample_targets, cfg.data.get("variable", "unknown") + ) + detected_dimensions = validation_info["dimensions"] + + # Check that detected dimensions match config + if detected_dimensions != expected_dimensions: + raise ValueError( + f"❌ Dimension mismatch! Config specifies '{expected_dimensions}' but data is '{detected_dimensions}'.\n" + f" Config: arch.dimensions = {expected_dimensions}\n" + f" Data: Input shape {tuple(sample_inputs.shape)} → {detected_dimensions}\n" + f" Please update arch.dimensions in config to match your dataset." + ) + + # Print validation summary + print_validation_summary( + input_shape=tuple(sample_inputs.shape), + target_shape=tuple(sample_targets.shape), + variable=cfg.data.get("variable", "unknown"), + is_batch=True, + logger=logger, + ) + + # Create model based on dimensions and model type + dimensions = cfg.arch.dimensions.lower() + model_type = cfg.arch.model.lower() + + # Get in_channels from first batch (for auto-discovery) + sample_inputs, _ = next(iter(train_loader)) + in_channels = sample_inputs.shape[-1] # Last dimension is channels + + # Account for feedback channel (appended during AR training) + _ar_feedback_init = cfg.training.get( + "regime", "full_mapping" + ).lower() == "autoregressive" and cfg.training.autoregressive.get( + "use_feedback_channel", False + ) + if _ar_feedback_init: + in_channels += 1 + + if model_type == "xfno": + xfno_cfg = cfg.arch.xfno + + if dimensions == "4d": + # 4D FNO (3D spatial + time) - Pure FNO only + logger.info( + f"Creating FNO4D model (FNO layers: {xfno_cfg.num_fno_layers}, " + f"modes: [{xfno_cfg.modes1}, {xfno_cfg.modes2}, {xfno_cfg.modes3}, {xfno_cfg.modes4}])" + ) + model = FNO4DNet( + in_channels=in_channels, + out_channels=xfno_cfg.out_channels, + width=xfno_cfg.width, + modes1=xfno_cfg.modes1, + modes2=xfno_cfg.modes2, + modes3=xfno_cfg.modes3, + modes4=xfno_cfg.modes4, + num_fno_layers=xfno_cfg.num_fno_layers, + padding=xfno_cfg.padding, + activation_fn=xfno_cfg.activation_fn, + lifting_layers=xfno_cfg.lifting_layers, + decoder_layers=xfno_cfg.decoder_layers, + decoder_width=xfno_cfg.decoder_width, + coord_features=xfno_cfg.coord_features, + ).to(dist.device) + model_arch_name = "fno4d" + else: + # 3D FNO (2D spatial + time) - With optional U-Net/Conv + num_unet = xfno_cfg.num_unet_layers + num_conv = xfno_cfg.num_conv_layers + + if num_unet > 0 and num_conv > 0: + logger.warning("⚠️ Using both U-Net and Conv layers (Conv-U-FNO).") + model_arch_name = f"convufno_{xfno_cfg.unet_type}" + elif num_unet > 0: + model_arch_name = f"ufno_{xfno_cfg.unet_type}" + elif num_conv > 0: + model_arch_name = "convfno" + else: + model_arch_name = "fno" + + logger.info( + f"Creating {model_arch_name.upper()} model (FNO: {xfno_cfg.num_fno_layers}, " + f"U-Net: {num_unet}, Conv: {num_conv})" + ) + + model = UFNONet( + in_channels=in_channels, + out_channels=xfno_cfg.out_channels, + width=xfno_cfg.width, + modes1=xfno_cfg.modes1, + modes2=xfno_cfg.modes2, + modes3=xfno_cfg.modes3, + num_fno_layers=xfno_cfg.num_fno_layers, + num_unet_layers=num_unet, + num_conv_layers=num_conv, + padding=xfno_cfg.padding, + conv_kernel_size=xfno_cfg.conv_kernel_size, + unet_kernel_size=xfno_cfg.unet_kernel_size, + unet_dropout=xfno_cfg.unet_dropout, + unet_type=xfno_cfg.unet_type, + activation_fn=xfno_cfg.activation_fn, + lifting_type=xfno_cfg.lifting_type, + lifting_layers=xfno_cfg.lifting_layers, + lifting_width=xfno_cfg.lifting_width, + decoder_type=xfno_cfg.decoder_type, + decoder_layers=xfno_cfg.decoder_layers, + decoder_width=xfno_cfg.decoder_width, + decoder_activation_fn=xfno_cfg.get("decoder_activation_fn", None), + ).to(dist.device) + + elif model_type == "xdeeponet": + xdeeponet_cfg = cfg.arch.xdeeponet + variant = xdeeponet_cfg.variant + + # Build branch configs from yaml + branch1_config = dict(xdeeponet_cfg.branch1) + branch2_config = ( + dict(xdeeponet_cfg.branch2) + if variant in ["mionet", "fourier_mionet", "tno"] + else None + ) + trunk_config = dict(xdeeponet_cfg.trunk) + + if dimensions == "4d": + # 4D DeepONet (3D spatial + time) + logger.info( + f"Creating DeepONet3D model (variant: {variant}, " + f"branch1: {branch1_config.get('encoder', 'spatial')}, width: {xdeeponet_cfg.width})" + ) + model = DeepONet3DWrapper( + padding=xdeeponet_cfg.padding, + variant=variant, + width=xdeeponet_cfg.width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=xdeeponet_cfg.get("decoder_type", "mlp"), + decoder_width=xdeeponet_cfg.decoder_width, + decoder_layers=xdeeponet_cfg.decoder_layers, + decoder_activation_fn=xdeeponet_cfg.get( + "decoder_activation_fn", "relu" + ), + ).to(dist.device) + b1_enc = branch1_config.get("encoder", "spatial") + b1_enc_name = ( + b1_enc.get("type", "linear") if not isinstance(b1_enc, str) else b1_enc + ) + model_arch_name = f"deeponet3d_{variant}_{b1_enc_name}" + else: + # 3D DeepONet (2D spatial + time) + logger.info( + f"Creating DeepONet model (variant: {variant}, " + f"branch1: {branch1_config.get('encoder', 'spatial')}, width: {xdeeponet_cfg.width})" + ) + model = DeepONetWrapper( + padding=xdeeponet_cfg.padding, + variant=variant, + width=xdeeponet_cfg.width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=xdeeponet_cfg.get("decoder_type", "mlp"), + decoder_width=xdeeponet_cfg.decoder_width, + decoder_layers=xdeeponet_cfg.decoder_layers, + decoder_activation_fn=xdeeponet_cfg.get( + "decoder_activation_fn", "relu" + ), + ).to(dist.device) + b1_enc = branch1_config.get("encoder", "spatial") + b1_enc_name = ( + b1_enc.get("type", "linear") if not isinstance(b1_enc, str) else b1_enc + ) + model_arch_name = f"deeponet_{variant}_{b1_enc_name}" + + else: + raise ValueError(f"Unknown model: {model_type}. Use 'xfno' or 'xdeeponet'.") + + # Set temporal projection output window before any forward pass + if ( + regime == "autoregressive" + and hasattr(model, "_temporal_projection") + and model._temporal_projection + ): + ar_K_init = cfg.training.autoregressive.output_window + model.set_output_window(ar_K_init) + if dist.rank == 0: + logger.info(f"Temporal projection decoder: output window K={ar_K_init}") + + # Initialize lazy modules with a dummy forward pass (required for DDP) + # This is needed because nn.LazyLinear doesn't know its input size until first forward + _ar_feedback_init = regime == "autoregressive" and cfg.training.autoregressive.get( + "use_feedback_channel", False + ) + if dist.rank == 0: + logger.info("Initializing model with dummy forward pass...") + with torch.no_grad(): + dummy_batch = next(iter(train_loader)) + dummy_input = dummy_batch[0].to(dist.device) + if _ar_feedback_init: + dummy_fb = torch.zeros_like(dummy_input[..., :1]) + dummy_input = torch.cat([dummy_input, dummy_fb], dim=-1) + if is_tno: + dummy_target = dummy_batch[1].to(dist.device) + _L = cfg.training.autoregressive.input_window + dummy_b2 = dummy_target[..., :_L] + _ = model(dummy_input, x_branch2=dummy_b2) + elif has_branch2: + dummy_b2 = dummy_input[:, 0, 0, 0, :] + _ = model(dummy_input, x_branch2=dummy_b2) + else: + _ = model(dummy_input) + if dist.rank == 0: + logger.info("Model initialization complete.") + + # Wrap model with DistributedDataParallel for multi-GPU training + if dist.world_size > 1: + model = DDP( + model, + device_ids=[dist.local_rank], + output_device=dist.local_rank, + find_unused_parameters=False, + ) + + # Count trainable parameters + model_for_counting = model.module if isinstance(model, DDP) else model + if hasattr(model_for_counting, "count_params"): + trainable_params = model_for_counting.count_params() + else: + trainable_params = sum( + p.numel() for p in model_for_counting.parameters() if p.requires_grad + ) + + # Print model info (only on rank 0) + if dist.rank == 0: + logger.info( + f"Model: {model.__class__.__name__} | Parameters: {trainable_params:,}" + ) + # Print detailed model architecture + print_model_architecture(model, model_type, dimensions, cfg, logger) + + # Create training loss function + loss_fn = get_loss_function(cfg.loss, variable=cfg.data.get("variable", None)) + + # Create validation loss function (same as training loss for fair comparison) + from omegaconf import DictConfig + + # Validation loss: same base losses, no derivatives, no physics losses + val_loss_cfg = DictConfig( + { + "types": list(cfg.loss.types), + "weights": list(cfg.loss.weights), + "reduction": cfg.loss.get("reduction", "mean"), + } + ) + val_loss_fn = get_loss_function(val_loss_cfg) + + # Print loss info (only on rank 0) + if dist.rank == 0: + types_str = "+".join( + f"{w}*{t}" for t, w in zip(cfg.loss.types, cfg.loss.weights) + ) + loss_info = f"Train Loss: {types_str}" + if cfg.loss.get("derivative", {}).get("enabled", False): + loss_info += f" (+Derivative w={cfg.loss.derivative.weight}, dims={list(cfg.loss.derivative.dims)})" + logger.info(loss_info) + + # Create optimizer and scheduler + optimizer = Adam( + model.parameters(), + lr=cfg.training.initial_lr, + weight_decay=cfg.optimizer.weight_decay, + ) + + # Create scheduler based on config type + scheduler_type = cfg.scheduler.type.lower() + if scheduler_type == "step": + scheduler = StepLR( + optimizer, step_size=cfg.scheduler.step_size, gamma=cfg.scheduler.gamma + ) + elif scheduler_type == "exponential": + scheduler = ExponentialLR(optimizer, gamma=cfg.scheduler.gamma) + else: + raise ValueError( + f"Unknown scheduler type: {scheduler_type}. Must be 'step' or 'exponential'" + ) + + # Initialize AMP GradScaler if enabled + scaler = GradScaler() if cfg.training.use_amp else None + + # Print optimizer info (only on rank 0) + if dist.rank == 0: + logger.info( + f"Optimizer: Adam (lr={cfg.training.initial_lr}) | AMP: {cfg.training.use_amp}" + ) + + # Setup MLFlow tracking if enabled + if use_mlflow: + mlflow.set_experiment(cfg.logging.experiment_name) + + # Enable automatic system metrics logging (CPU, GPU, memory, disk, network) + mlflow.enable_system_metrics_logging() + + mlflow.start_run() + + # Log hyperparameters + mlflow_params = { + "dimensions": dimensions, + "model_type": model_type, + "model_arch_name": model_arch_name, + "batch_size": cfg.training.batch_size, + "epochs": cfg.training.epochs, + "learning_rate": cfg.training.initial_lr, + "optimizer": "Adam", + "train_loss": "+".join(list(cfg.loss.types)), + "loss_masking": cfg.data.get("mask_enabled", False), + "loss_derivative": cfg.loss.get("derivative", {}).get("enabled", False), + "use_amp": cfg.training.use_amp, + "use_graphs": cfg.training.use_graphs, + "variable": cfg.data.variable, + "trainable_parameters": trainable_params, + "in_channels": in_channels, + } + + if model_type == "xfno": + xfno_cfg = cfg.arch.xfno + mlflow_params.update( + { + "width": xfno_cfg.width, + "modes1": xfno_cfg.modes1, + "modes2": xfno_cfg.modes2, + "modes3": xfno_cfg.modes3, + "num_fno_layers": xfno_cfg.num_fno_layers, + "padding": xfno_cfg.padding, + "activation_fn": xfno_cfg.activation_fn, + } + ) + if dimensions == "4d": + mlflow_params["modes4"] = xfno_cfg.modes4 + else: + mlflow_params["num_unet_layers"] = xfno_cfg.num_unet_layers + mlflow_params["num_conv_layers"] = xfno_cfg.num_conv_layers + elif model_type == "xdeeponet": + xdeeponet_cfg = cfg.arch.xdeeponet + mlflow_params.update( + { + "variant": xdeeponet_cfg.variant, + "width": xdeeponet_cfg.width, + "padding": xdeeponet_cfg.padding, + "branch1_encoder": xdeeponet_cfg.branch1.get("encoder", {}).get( + "type", "linear" + ) + if isinstance(xdeeponet_cfg.branch1.get("encoder"), dict) + else xdeeponet_cfg.branch1.get("encoder", "spatial"), + } + ) + + mlflow.log_params(mlflow_params) + + # Setup checkpointing (make absolute path since chdir=False) + checkpoint_dir = Path(cfg.training.checkpoint_dir) + if not checkpoint_dir.is_absolute(): + # If relative, make it relative to the working directory (U-FNO folder) + checkpoint_dir = Path.cwd() / checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Load checkpoint if resuming + start_epoch = 1 + best_val_loss = float("inf") + best_val_mre = float("inf") + _resume_ckpt = None + + if ( + hasattr(cfg.training, "resume_from_checkpoint") + and cfg.training.resume_from_checkpoint + ): + checkpoint_path = Path(cfg.training.resume_from_checkpoint) + if checkpoint_path.exists(): + if dist.rank == 0: + logger.info(f"Loading checkpoint from: {checkpoint_path}") + _resume_ckpt = load_checkpoint(checkpoint_path, device=dist.device) + + # Validate that the checkpoint architecture matches the current model + _resume_ckpt.get("model_config", {}) + + model_to_load = model.module if isinstance(model, DDP) else model + model_to_load.load_state_dict(_resume_ckpt["model_state_dict"]) + + start_epoch = _resume_ckpt["epoch"] + 1 + best_val_loss = _resume_ckpt.get("val_loss", float("inf")) + best_val_mre = _resume_ckpt.get("val_mre", float("inf")) + + if dist.rank == 0: + logger.success( + f"Resumed from epoch {_resume_ckpt['epoch']}, " + f"best val loss: {best_val_loss:.6f}" + ) + else: + if dist.rank == 0: + logger.warning( + f"Checkpoint not found at {checkpoint_path}, starting from scratch" + ) + else: + if dist.rank == 0: + logger.success("Starting training from scratch...") + + # Restore optimizer and scheduler state if resuming + if _resume_ckpt is not None: + if "optimizer_state_dict" in _resume_ckpt: + optimizer.load_state_dict(_resume_ckpt["optimizer_state_dict"]) + if dist.rank == 0: + logger.info( + f"Restored optimizer state (LR={optimizer.param_groups[0]['lr']:.2e})" + ) + if "scheduler_state_dict" in _resume_ckpt: + scheduler.load_state_dict(_resume_ckpt["scheduler_state_dict"]) + if dist.rank == 0: + logger.info("Restored scheduler state") + del _resume_ckpt # free memory + + # --------------------------------------------------------------------------- + # Determine training regime + # --------------------------------------------------------------------------- + regime = cfg.training.get("regime", "full_mapping").lower() + if regime == "autoregressive": + ar_cfg = cfg.training.autoregressive + ar_L = ar_cfg.input_window + ar_K = ar_cfg.output_window + ar_stride = ar_cfg.get("stride", None) + tf_epochs = ar_cfg.teacher_forcing_epochs + pf_epochs = ar_cfg.get("pushforward_epochs", 0) + ro_epochs = ar_cfg.get("rollout_epochs", 0) + total_epochs = tf_epochs + pf_epochs + ro_epochs + ar_noise_std = ar_cfg.get("noise_std", 0.0) + ar_feedback = ar_cfg.get("use_feedback_channel", False) + ar_max_unroll = ar_cfg.get( + "max_unroll", ar_cfg.get("pushforward_max_unroll", 5) + ) + ar_lr_reset = ar_cfg.get("lr_reset_factor", 1.0) + ar_rollout_mode = ar_cfg.get("rollout_mode", "detached").lower() + + if dist.rank == 0: + logger.info("=" * 80) + logger.info(f"AUTOREGRESSIVE TRAINING | L={ar_L}, K={ar_K}") + logger.info(f" Stage 1 — Teacher Forcing: {tf_epochs} epochs") + if pf_epochs > 0: + logger.info( + f" Stage 2 — Pushforward: {pf_epochs} epochs (unroll 1 -> {ar_max_unroll})" + ) + logger.info( + f" Stage 3 — Rollout: {ro_epochs} epochs ({ar_rollout_mode})" + ) + logger.info(f" Total: {total_epochs} epochs") + if ar_noise_std > 0: + logger.info(f" Noise: std={ar_noise_std}") + if ar_feedback: + logger.info(" Feedback channel: enabled") + if not is_tno and not ar_feedback: + logger.warning( + "Autoregressive training without TNO or feedback channel: " + "the model will not receive its own predictions as input. " + "Set autoregressive.use_feedback_channel: true for real AR feedback." + ) + logger.info("=" * 80) + else: + total_epochs = cfg.training.epochs + if dist.rank == 0: + logger.info("=" * 80) + logger.info("FULL-MAPPING TRAINING") + logger.info(f" Epochs: {total_epochs}") + logger.info("=" * 80) + + # Training loop + for epoch in range(start_epoch, total_epochs + 1): + # Set epoch for distributed sampler (ensures proper shuffling across epochs) + if hasattr(train_loader.sampler, "set_epoch"): + train_loader.sampler.set_epoch(epoch) + + # Training step + with LaunchLogger( + "train", epoch=epoch, num_mini_batch=len(train_loader) + ) as log: + model.train() + total_loss = 0.0 + + for batch_idx, (inputs, targets) in enumerate(train_loader): + inputs = inputs.to(dist.device) + targets = targets.to(dist.device) + optimizer.zero_grad() + + batch_mask = _get_batch_mask( + inputs, mask_channel, mask_per_sample, static_mask + ) + + if regime == "autoregressive": + stage = get_training_stage(epoch, tf_epochs, pf_epochs, ro_epochs) + ar_common = dict( + L=ar_L, + K=ar_K, + spatial_mask=batch_mask, + is_tno=is_tno, + noise_std=ar_noise_std, + feedback_channel=1 if ar_feedback else None, + stride=ar_stride, + ) + if stage == "teacher_forcing": + loss = teacher_forcing_step( + model, + inputs, + targets, + loss_fn, + **ar_common, + ) + elif stage == "pushforward": + unroll = compute_unroll_steps( + epoch, + tf_epochs + 1, + pf_epochs, + ar_max_unroll, + ) + loss = _live_rollout_ddp_safe( + model, + dist, + inputs, + targets, + loss_fn, + ar_common, + max_steps=unroll, + ) + else: + if ar_rollout_mode == "live_gradients": + loss = _live_rollout_ddp_safe( + model, + dist, + inputs, + targets, + loss_fn, + ar_common, + ) + else: + loss = rollout_step( + model, + inputs, + targets, + loss_fn, + **ar_common, + ) + else: + # Full-mapping: single forward pass over entire trajectory + fwd_kwargs = {} + if has_branch2 and not is_tno: + fwd_kwargs["x_branch2"] = inputs[:, 0, 0, 0, :] + + if cfg.training.use_amp: + with autocast(): + pred = model(inputs, **fwd_kwargs) + loss = loss_fn( + pred, targets, inputs, spatial_mask=batch_mask + ) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + if dist.world_size > 1: + loss_tensor = loss.detach().clone() + torch.distributed.all_reduce( + loss_tensor, op=torch.distributed.ReduceOp.SUM + ) + total_loss += loss_tensor / dist.world_size + else: + total_loss += loss.detach() + continue + else: + pred = model(inputs, **fwd_kwargs) + loss = loss_fn(pred, targets, inputs, spatial_mask=batch_mask) + + # Backward + step + if regime == "autoregressive": + # All AR step functions call backward() internally + # and return detached scalars for logging. + optimizer.step() + else: + loss.backward() + optimizer.step() + + if dist.world_size > 1: + loss_tensor = loss.detach().clone() + torch.distributed.all_reduce( + loss_tensor, op=torch.distributed.ReduceOp.SUM + ) + total_loss += loss_tensor / dist.world_size + else: + total_loss += loss.detach() + + avg_train_loss = total_loss / len(train_loader) + + # Log stage transitions and LR reset for AR + if regime == "autoregressive": + stage = get_training_stage(epoch, tf_epochs, pf_epochs, ro_epochs) + prev_stage = ( + get_training_stage(epoch - 1, tf_epochs, pf_epochs, ro_epochs) + if epoch > 1 + else None + ) + if prev_stage is not None and stage != prev_stage: + if ar_lr_reset != 1.0: + for pg in optimizer.param_groups: + pg["lr"] *= ar_lr_reset + if dist.rank == 0: + new_lr = optimizer.param_groups[0]["lr"] + logger.info("=" * 60) + logger.info( + f"STAGE TRANSITION: {prev_stage.upper().replace('_', ' ')} -> {stage.upper().replace('_', ' ')} (LR={new_lr:.2e})" + ) + if stage == "pushforward": + logger.info( + f" Pushforward curriculum: unroll 1 -> {ar_max_unroll} over {pf_epochs} epochs" + ) + logger.info("=" * 60) + + log.log_epoch({"loss": avg_train_loss}) + + if cfg.logging.use_mlflow and dist.rank == 0: + mlflow.log_metric("train_loss", float(avg_train_loss), step=epoch) + + # Validation step + if epoch % cfg.training.validate_freq == 0: + with LaunchLogger("valid", epoch=epoch) as log: + model.eval() + total_val_loss = 0.0 + mre_list = [] + + with torch.no_grad(): + for inputs, targets in val_loader: + inputs = inputs.to(dist.device) + targets = targets.to(dist.device) + + val_batch_mask = _get_batch_mask( + inputs, mask_channel, mask_per_sample, static_mask + ) + + # Forward pass — same regime as training + if regime == "autoregressive": + pred = ar_validate_full_rollout( + model, + inputs, + targets, + L=ar_L, + K=ar_K, + is_tno=is_tno, + feedback_channel=1 if ar_feedback else None, + ) + else: + val_fwd = {} + if has_branch2 and not is_tno: + val_fwd["x_branch2"] = inputs[:, 0, 0, 0, :] + pred = model(inputs, **val_fwd) + + if cfg.training.use_amp: + with autocast(): + val_loss = val_loss_fn( + pred, targets, inputs, spatial_mask=val_batch_mask + ) + else: + val_loss = val_loss_fn( + pred, targets, inputs, spatial_mask=val_batch_mask + ) + + # Aggregate validation loss across GPUs + if dist.world_size > 1: + val_loss_tensor = val_loss.detach().clone() + torch.distributed.all_reduce( + val_loss_tensor, op=torch.distributed.ReduceOp.SUM + ) + val_loss_tensor = val_loss_tensor / dist.world_size + total_val_loss += val_loss_tensor + else: + total_val_loss += val_loss.detach() + + # Calculate validation metric on rank 0 only (for logging) + if dist.rank == 0: + pred_cpu = pred.cpu().numpy() + targets_cpu = targets.cpu().numpy() + inputs_cpu = inputs.cpu().numpy() + + # Optional denormalization (config-driven) + denorm_name = cfg.data.get("denormalize_fn", None) + if denorm_name and denorm_name in _DENORM_REGISTRY: + denorm_fn = _DENORM_REGISTRY[denorm_name] + pred_denorm = denorm_fn(pred_cpu) + targets_denorm = denorm_fn(targets_cpu) + else: + pred_denorm = pred_cpu + targets_denorm = targets_cpu + + # Resolve metric function from config + val_metric_choice = cfg.data.get("val_metric", "rmse") + if val_metric_choice not in _METRIC_REGISTRY: + raise ValueError( + f"Unknown val_metric '{val_metric_choice}'. " + f"Choices: {list(_METRIC_REGISTRY.keys())}" + ) + _, _, metric_fn = _METRIC_REGISTRY[val_metric_choice] + + for i in range(pred_denorm.shape[0]): + if mask_channel is not None: + mask_i = inputs_cpu[i, ..., 0, mask_channel] != 0 + y_pred = pred_denorm[i][mask_i] + y_true = targets_denorm[i][mask_i] + else: + y_pred = pred_denorm[i].ravel() + y_true = targets_denorm[i].ravel() + + mre_list.append(metric_fn(y_pred, y_true)) + + avg_val_loss = total_val_loss / len(val_loader) + avg_metric = np.mean(mre_list) if len(mre_list) > 0 else 0.0 + + # Metric display name and logging key from config + val_metric_choice = cfg.data.get("val_metric", "rmse") + metric_name, metric_key, _ = _METRIC_REGISTRY[val_metric_choice] + + is_ratio_metric = val_metric_choice in ("mre", "mpe", "relative_l2") + + # Print validation metrics (only on rank 0) + if dist.rank == 0: + if is_ratio_metric: + logger.info( + f"Epoch {epoch}: Val Loss = {avg_val_loss:.6f} | Val {metric_name} = {avg_metric:.6f} ({avg_metric * 100:.2f}%)" + ) + else: + logger.info( + f"Epoch {epoch}: Val Loss = {avg_val_loss:.6f} | Val {metric_name} = {avg_metric:.6f}" + ) + + # Log to MLFlow (only on rank 0) + if cfg.logging.use_mlflow and dist.rank == 0: + mlflow.log_metric("val_loss", float(avg_val_loss), step=epoch) + mlflow.log_metric(metric_key, float(avg_metric), step=epoch) + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + best_val_mre = avg_metric + + # Print and log best validation loss (only on rank 0) + if dist.rank == 0: + if is_ratio_metric: + logger.success( + f"New best validation: Loss = {best_val_loss:.6f} | {metric_name} = {best_val_mre:.6f} ({best_val_mre * 100:.2f}%)" + ) + else: + logger.success( + f"New best validation: Loss = {best_val_loss:.6f} | {metric_name} = {best_val_mre:.6f}" + ) + + # Log best loss to MLFlow (only on rank 0) + if cfg.logging.use_mlflow and dist.rank == 0: + mlflow.log_metric( + "best_val_loss", float(best_val_loss), step=epoch + ) + mlflow.log_metric( + f"best_{metric_key}", float(best_val_mre), step=epoch + ) + + # Save best model (only on rank 0 to avoid race condition) + if dist.rank == 0: + best_model_path = ( + checkpoint_dir + / f"best_model_{cfg.data.variable}_{model_arch_name}.pth" + ) + # Prepare model config to save with checkpoint + model_config = { + "dimensions": dimensions, + "model_type": model_type, + "model_arch_name": model_arch_name, + "variable": cfg.data.variable, + "in_channels": in_channels, + "feedback_channel": 1 if _ar_feedback_init else None, + } + + if model_type == "xfno": + xfno_cfg = cfg.arch.xfno + model_config.update( + { + "out_channels": xfno_cfg.out_channels, + "width": xfno_cfg.width, + "modes1": xfno_cfg.modes1, + "modes2": xfno_cfg.modes2, + "modes3": xfno_cfg.modes3, + "num_fno_layers": xfno_cfg.num_fno_layers, + "padding": xfno_cfg.padding, + "activation_fn": xfno_cfg.activation_fn, + "decoder_layers": xfno_cfg.decoder_layers, + "decoder_width": xfno_cfg.decoder_width, + } + ) + if dimensions == "4d": + model_config.update( + { + "modes4": xfno_cfg.modes4, + "coord_features": xfno_cfg.coord_features, + "lifting_layers": xfno_cfg.lifting_layers, + } + ) + else: + model_config.update( + { + "num_unet_layers": xfno_cfg.num_unet_layers, + "num_conv_layers": xfno_cfg.num_conv_layers, + "unet_type": xfno_cfg.unet_type, + "lifting_type": xfno_cfg.lifting_type, + "lifting_layers": xfno_cfg.lifting_layers, + "lifting_width": xfno_cfg.lifting_width, + "decoder_type": xfno_cfg.decoder_type, + } + ) + + elif model_type == "xdeeponet": + xdeeponet_cfg = cfg.arch.xdeeponet + model_config.update( + { + "variant": xdeeponet_cfg.variant, + "width": xdeeponet_cfg.width, + "padding": xdeeponet_cfg.padding, + "branch1_config": dict(xdeeponet_cfg.branch1), + "trunk_config": dict(xdeeponet_cfg.trunk), + "decoder_type": xdeeponet_cfg.get( + "decoder_type", "mlp" + ), + "decoder_width": xdeeponet_cfg.decoder_width, + "decoder_layers": xdeeponet_cfg.decoder_layers, + "decoder_activation_fn": xdeeponet_cfg.get( + "decoder_activation_fn", "relu" + ), + } + ) + if xdeeponet_cfg.variant in [ + "mionet", + "fourier_mionet", + "tno", + ]: + model_config["branch2_config"] = dict( + xdeeponet_cfg.branch2 + ) + if ( + xdeeponet_cfg.get("decoder_type", "mlp") + == "temporal_projection" + ): + model_config["output_window"] = ( + cfg.training.autoregressive.output_window + ) + + save_checkpoint( + path=best_model_path, + model=model, + epoch=epoch, + val_loss=best_val_loss, + metric_key=metric_key, + metric_value=best_val_mre, + model_config=model_config, + ) + + # Log model to MLFlow + if cfg.logging.use_mlflow: + mlflow.log_artifact(str(best_model_path)) + + # Learning rate scheduling (StepLR steps automatically every step_size epochs) + scheduler.step() + + # Resolve metric metadata once for final summary / MLflow + val_metric_choice = cfg.data.get("val_metric", "rmse") + metric_name, metric_key, _ = _METRIC_REGISTRY[val_metric_choice] + is_ratio_metric = val_metric_choice in ("mre", "mpe", "relative_l2") + + # Print training completion (only on rank 0) + if dist.rank == 0: + logger.success("Training completed!") + if is_ratio_metric: + logger.info( + f"Best validation: Loss = {best_val_loss:.6f} | {metric_name} = {best_val_mre:.6f} ({best_val_mre * 100:.2f}%)" + ) + else: + logger.info( + f"Best validation: Loss = {best_val_loss:.6f} | {metric_name} = {best_val_mre:.6f}" + ) + + # End MLFlow run (only on rank 0) + if cfg.logging.use_mlflow and dist.rank == 0: + mlflow.log_metric("final_best_val_loss", float(best_val_loss)) + mlflow.log_metric(f"final_best_{metric_key}", float(best_val_mre)) + mlflow.end_run() + logger.info("MLFlow run completed") + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/neural_operator_factory/utils/__init__.py b/examples/reservoir_simulation/neural_operator_factory/utils/__init__.py new file mode 100644 index 0000000000..2cf5701e39 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/utils/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions: padding (generic) and CO2-specific normalization/visualization. +""" + +from utils.co2_normalization import ( + denormalize_inputs, + dnorm_dP, + dnorm_inj, + dnorm_lam, + dnorm_P, + dnorm_Swi, + dnorm_temp, + extract_reservoir_mask, +) +from utils.co2_visualization import ( + create_pcolor_func, + get_time_labels, + plot_4x3_comparison, + setup_plotting_grid, +) +from utils.padding import ( + compute_right_pad_to_multiple, + compute_right_pad_to_multiple_per_dim, + pad_right_nd, + pad_spatial_right, +) + +__all__ = [ + # CO2-specific normalization + "dnorm_dP", + "dnorm_inj", + "dnorm_temp", + "dnorm_P", + "dnorm_lam", + "dnorm_Swi", + "extract_reservoir_mask", + "denormalize_inputs", + # CO2-specific visualization + "setup_plotting_grid", + "get_time_labels", + "create_pcolor_func", + "plot_4x3_comparison", + # Padding (generic) + "compute_right_pad_to_multiple", + "compute_right_pad_to_multiple_per_dim", + "pad_right_nd", + "pad_spatial_right", +] diff --git a/examples/reservoir_simulation/neural_operator_factory/utils/checkpoint.py b/examples/reservoir_simulation/neural_operator_factory/utils/checkpoint.py new file mode 100644 index 0000000000..478ef802a2 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/utils/checkpoint.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint utilities: model reconstruction and save/load helpers.""" + +import torch +from models.xdeeponet import DeepONet3DWrapper, DeepONetWrapper +from models.xfno import FNO4DNet, UFNONet + + +def build_model_from_config(model_config: dict, device="cpu"): + """Reconstruct a model from a saved model_config dict. + + This is the single source of truth for model reconstruction, used by + both checkpoint resume and evaluation scripts. + + Parameters + ---------- + model_config : dict + The ``model_config`` dict saved inside a checkpoint. Must contain + at least ``model_type`` and ``dimensions``. + device : str or torch.device + Device to place the model on. + + Returns + ------- + tuple of (model, model_arch_name) + """ + model_type = model_config["model_type"] + dimensions = model_config.get("dimensions", "4d") + + if model_type == "xdeeponet": + variant = model_config.get("variant", "u_deeponet") + cls = DeepONet3DWrapper if dimensions == "4d" else DeepONetWrapper + model = cls( + padding=model_config.get("padding", 8), + variant=variant, + width=model_config.get("width", 128), + branch1_config=model_config.get("branch1_config", {}), + branch2_config=model_config.get("branch2_config"), + trunk_config=model_config.get("trunk_config", {}), + decoder_type=model_config.get("decoder_type", "mlp"), + decoder_width=model_config.get("decoder_width", 128), + decoder_layers=model_config.get("decoder_layers", 2), + decoder_activation_fn=model_config.get("decoder_activation_fn", "relu"), + ) + if model_config.get("decoder_type") == "temporal_projection": + K = model_config.get("output_window", 3) + model.set_output_window(K) + + b1_enc = model_config.get("branch1_config", {}).get("encoder", "spatial") + encoder = b1_enc.get("type", "linear") if isinstance(b1_enc, dict) else b1_enc + model_arch_name = model_config.get( + "model_arch_name", + f"deeponet{'3d' if dimensions == '4d' else ''}_{variant}_{encoder}", + ) + + elif model_type == "xfno": + if dimensions == "4d": + model = FNO4DNet( + modes1=model_config["modes1"], + modes2=model_config["modes2"], + modes3=model_config["modes3"], + modes4=model_config.get("modes4", 6), + width=model_config["width"], + in_channels=model_config["in_channels"], + out_channels=model_config.get("out_channels", 1), + num_fno_layers=model_config["num_fno_layers"], + padding=model_config.get("padding", 8), + activation_fn=model_config.get("activation_fn", "gelu"), + lifting_layers=model_config.get("lifting_layers", 1), + decoder_layers=model_config.get("decoder_layers", 1), + decoder_width=model_config.get("decoder_width", 128), + coord_features=model_config.get("coord_features", True), + ) + model_arch_name = model_config.get("model_arch_name", "fno4d") + else: + model = UFNONet( + modes1=model_config["modes1"], + modes2=model_config["modes2"], + modes3=model_config["modes3"], + width=model_config["width"], + in_channels=model_config["in_channels"], + out_channels=model_config.get("out_channels", 1), + num_fno_layers=model_config["num_fno_layers"], + num_unet_layers=model_config.get("num_unet_layers", 0), + num_conv_layers=model_config.get("num_conv_layers", 0), + padding=model_config.get("padding", 8), + unet_type=model_config.get("unet_type", "custom"), + activation_fn=model_config.get("activation_fn", "relu"), + lifting_type=model_config.get("lifting_type", "mlp"), + lifting_layers=model_config.get("lifting_layers", 1), + lifting_width=model_config.get("lifting_width", 36), + decoder_type=model_config.get("decoder_type", "mlp"), + decoder_layers=model_config.get("decoder_layers", 1), + decoder_width=model_config.get("decoder_width", 128), + ) + model_arch_name = model_config.get("model_arch_name", "ufno") + else: + raise ValueError(f"Unknown model_type in checkpoint: {model_type}") + + return model.to(device), model_arch_name + + +def save_checkpoint( + path, + model, + epoch: int, + val_loss: float, + metric_key: str, + metric_value: float, + model_config: dict, + optimizer=None, + scheduler=None, +): + """Save a training checkpoint with all state needed for resume.""" + from torch.nn.parallel import DistributedDataParallel as DDP + + model_to_save = model.module if isinstance(model, DDP) else model + ckpt = { + "epoch": epoch, + "model_state_dict": model_to_save.state_dict(), + "val_loss": val_loss, + metric_key: metric_value, + "model_config": model_config, + } + if optimizer is not None: + ckpt["optimizer_state_dict"] = optimizer.state_dict() + if scheduler is not None: + ckpt["scheduler_state_dict"] = scheduler.state_dict() + torch.save(ckpt, path) + + +def load_checkpoint(path, device="cpu"): + """Load a checkpoint and return the dict.""" + return torch.load(path, map_location=device, weights_only=False) diff --git a/examples/reservoir_simulation/neural_operator_factory/utils/co2_normalization.py b/examples/reservoir_simulation/neural_operator_factory/utils/co2_normalization.py new file mode 100644 index 0000000000..4acc386470 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/utils/co2_normalization.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2-specific denormalization and utility functions. + +This module provides denormalization functions specific to the CO2 +sequestration dataset. For other datasets, these are not needed. +""" + +import numpy as np + +# ============================================================================== +# Denormalization Functions +# ============================================================================== + + +def dnorm_dP(dP): + """ + Denormalize pressure change (dP) from normalized to physical units (bar). + + Args: + dP: Normalized pressure change values + + Returns: + Denormalized pressure change in bar + """ + dP = dP * 18.772821433027488 + dP = dP + 4.172939172019009 + return dP + + +def dnorm_inj(a): + """ + Denormalize injection rate to physical units (MT/yr). + + Args: + a: Normalized injection rate + + Returns: + Injection rate in MT/yr (megatons per year) + """ + return (a * (3e6 - 3e5) + 3e5) / (1e6 / 365 * 1000 / 1.862) + + +def dnorm_temp(a): + """ + Denormalize temperature to physical units (°C). + + Args: + a: Normalized temperature (0-1) + + Returns: + Temperature in °C (30-180°C range) + """ + return a * (180 - 30) + 30 + + +def dnorm_P(a): + """ + Denormalize initial pressure to physical units (bar). + + Args: + a: Normalized pressure (0-1) + + Returns: + Initial pressure in bar (100-300 bar range) + """ + return a * (300 - 100) + 100 + + +def dnorm_lam(a): + """ + Denormalize lambda parameter to physical units (-). + + Args: + a: Normalized lambda (0-1) + + Returns: + Lambda parameter (0.3-0.7 range) + """ + return a * 0.4 + 0.3 + + +def dnorm_Swi(a): + """ + Denormalize initial water saturation (Swi) to physical units (-). + + Args: + a: Normalized Swi (0-1) + + Returns: + Initial water saturation (0.1-0.3 range) + """ + return a * 0.2 + 0.1 + + +# ============================================================================== +# Helper Functions +# ============================================================================== + + +def extract_reservoir_mask(x_plot): + """ + Extract the reservoir mask from input data. + + The permeability map (channel 0) indicates the active reservoir region. + + Args: + x_plot: Input array with shape (..., H, W, T, C) + + Returns: + tuple: (mask, thickness) where mask is boolean array and thickness is int + """ + mask = x_plot[0, :, :, 0, 0] != 0 + thickness = int(np.sum(mask[:, 0])) + return mask, thickness + + +def denormalize_inputs(x_plot): + """ + Denormalize all input parameters from a sample. + + Args: + x_plot: Input array with shape (1, H, W, T, C) + + Returns: + dict: Dictionary with denormalized parameters + """ + params = { + "injection_rate": dnorm_inj(x_plot[0, 0, 0, 0, 4]), + "temperature": dnorm_temp(x_plot[0, 0, 0, 0, 6]), + "initial_pressure": dnorm_P(x_plot[0, 0, 0, 0, 5]), + "Swi": dnorm_Swi(x_plot[0, 0, 0, 0, 7]), + "lambda": dnorm_lam(x_plot[0, 0, 0, 0, 8]), + } + return params diff --git a/examples/reservoir_simulation/neural_operator_factory/utils/co2_visualization.py b/examples/reservoir_simulation/neural_operator_factory/utils/co2_visualization.py new file mode 100644 index 0000000000..41cd7d0f4e --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/utils/co2_visualization.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CO2-specific visualization utilities. + +This module provides plotting functions and grid setup specific to the +CO2 sequestration dataset (2D cross-section visualizations with +CO2-specific grid spacing and labels). +""" + +import matplotlib.pyplot as plt +import numpy as np + + +def setup_plotting_grid(): + """ + Setup the spatial grid for plotting reservoir data. + + Returns: + tuple: (X, Y, dx) meshgrid arrays for plotting + """ + dx = np.cumsum(3.5938 * np.power(1.035012, range(200))) + 0.1 + X, Y = np.meshgrid(dx, np.linspace(0, 200, num=96)) + return X, Y, dx + + +def get_time_labels(): + """ + Generate time labels for the 24 timesteps. + + Returns: + list: Time labels as strings (e.g., "10 d", "2.5 y") + """ + times = np.cumsum(np.power(1.421245, range(24))) + time_print = [] + + for t in range(times.shape[0]): + if times[t] < 365: + title = str(int(times[t])) + " d" + else: + title = f"{round(int(times[t]) / 365, 1)} y" + time_print.append(title) + + return time_print + + +def create_pcolor_func(X, Y, thickness): + """ + Create a pcolor plotting function for reservoir data. + + Args: + X: X meshgrid + Y: Y meshgrid + thickness: Number of vertical cells in reservoir + + Returns: + function: Plotting function that takes 2D array + """ + + def pcolor(x): + """Plot a 2D pseudocolor map on the reservoir grid.""" + plt.jet() + return plt.pcolor( + X[:thickness, :], Y[:thickness, :], np.flipud(x), shading="auto" + ) + + return pcolor + + +def plot_4x3_comparison( + x_plot, + y_plot, + pred_plot, + mask, + thickness, + variable="pressure", + timesteps=[14, 20, 23], +): + """ + Create a 4x3 comparison plot showing inputs, ground truth, predictions, and errors. + + Args: + x_plot: Input array + y_plot: Ground truth array + pred_plot: Prediction array + mask: Boolean mask for reservoir region + thickness: Number of vertical cells + variable: 'pressure' or 'saturation' + timesteps: List of 3 timesteps to visualize + + Returns: + matplotlib.figure.Figure: The created figure + """ + # Setup grid + X, Y, dx = setup_plotting_grid() + time_print = get_time_labels() + pcolor = create_pcolor_func(X, Y, thickness) + + # Extract input maps + poro_map = x_plot[0, :, :, 0, 2][mask].reshape((thickness, -1)) + kr_map = np.exp(x_plot[0, :, :, 0, 0][mask].reshape((thickness, -1)) * 15) + kz_map = np.exp(x_plot[0, :, :, 0, 1][mask].reshape((thickness, -1)) * 15) + + # Create figure + fig = plt.figure(figsize=(15, 16)) + + # Set labels based on variable type + if variable == "pressure": + pred_label = "$\hat{dP}$ (bar)" + true_label = "$dP$ (bar)" + error_label = "|$dP-\hat{dP}$|" + else: # saturation + pred_label = "$\hat{S}_g$ (-)" + true_label = "$S_g$ (-)" + error_label = "|$S_g-\hat{S}_g$|" + + for j, t in enumerate(timesteps): + # Row 1: Input parameters + plt.subplot(4, 3, j + 1) + if j == 2: + pcolor(poro_map) + plt.title("$\phi$ (-)") + elif j == 1: + pcolor(kz_map) + plt.title("$k_z$ (mD)") + else: + pcolor(kr_map) + plt.title("$k_r$ (mD)") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + # Row 2: Ground truth + plt.subplot(4, 3, j + 4) + pcolor(y_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"{true_label}, t={time_print[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + # Row 3: Prediction + plt.subplot(4, 3, j + 7) + pcolor(pred_plot[:, :, t][mask].reshape((thickness, -1))) + plt.title(f"{pred_label}, t={time_print[t]}") + plt.colorbar(fraction=0.02) + plt.xlim([0, 3500]) + + # Row 4: Error + plt.subplot(4, 3, j + 10) + error = pred_plot[:, :, t][mask].reshape((thickness, -1)) - y_plot[:, :, t][ + mask + ].reshape((thickness, -1)) + pcolor(error) + plt.colorbar(fraction=0.02) + plt.title(f"{error_label}, t={time_print[t]}") + plt.xlim([0, 3500]) + + plt.tight_layout() + + return fig diff --git a/examples/reservoir_simulation/neural_operator_factory/utils/padding.py b/examples/reservoir_simulation/neural_operator_factory/utils/padding.py new file mode 100644 index 0000000000..10ae808901 --- /dev/null +++ b/examples/reservoir_simulation/neural_operator_factory/utils/padding.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Padding utilities for reservoir simulation tensors. + +Design goals for this repo: +- Provide **right-side padding** helpers that are: + - **shape-safe** across 3D datasets (2D space + time) and 4D datasets (3D space + time) + - compatible with PyTorch limitations around non-constant padding modes + - explicit about what dimensions are being padded + +Tensor layouts used in this codebase: +- 3D dataset samples (CO2-style): (B, H, W, T, C) +- 4D dataset samples (Norne-style): (B, X, Y, Z, T, C) + +Important subtlety (xFNO vs xDeepONet): +- For xDeepONet, padding must be **spatial-only** because time is handled by the trunk/query. +- For xFNO, the operator is learned over the full domain, so the convolved dimensions + include time (e.g., SpectralConv3d over H,W,T; SpectralConv4d over X,Y,Z,T). Therefore, + "spatial_ndim" in xFNO wrappers may include the time dimension by design. +""" + +from __future__ import annotations + +from typing import Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def compute_right_pad_to_multiple( + spatial_shape: Sequence[int], + *, + multiple: int = 8, + min_right_pad: int = 0, +) -> Tuple[int, ...]: + """Compute right-side padding to reach a multiple (optionally with a minimum).""" + if multiple <= 0: + raise ValueError(f"multiple must be > 0, got {multiple}") + if min_right_pad < 0: + raise ValueError(f"min_right_pad must be >= 0, got {min_right_pad}") + + pads = [] + for d in spatial_shape: + if d <= 0: + raise ValueError( + f"spatial dimensions must be positive, got {spatial_shape}" + ) + to_mult = (multiple - (d % multiple)) % multiple + # Guarantee: + # - (d + pad) is divisible by `multiple` + # - pad >= min_right_pad + if to_mult >= min_right_pad: + pad = to_mult + else: + # Increase by whole multiples so the final size stays aligned. + deficit = min_right_pad - to_mult + k = (deficit + multiple - 1) // multiple + pad = to_mult + k * multiple + pads.append(int(pad)) + return tuple(pads) + + +def compute_right_pad_to_multiple_per_dim( + spatial_shape: Sequence[int], + *, + multiple: int = 8, + min_right_pad: Union[int, Sequence[int]] = 0, +) -> Tuple[int, ...]: + """Like `compute_right_pad_to_multiple`, but supports per-dimension minimum padding.""" + if isinstance(min_right_pad, int): + mins = [min_right_pad] * len(spatial_shape) + else: + mins = list(min_right_pad) + if len(mins) != len(spatial_shape): + raise ValueError( + f"min_right_pad length must match spatial_shape length " + f"({len(mins)} vs {len(spatial_shape)})" + ) + return tuple( + compute_right_pad_to_multiple((d,), multiple=multiple, min_right_pad=m)[0] + for d, m in zip(spatial_shape, mins) + ) + + +def pad_right_nd( + x: Tensor, + *, + dims: Sequence[int], + right_pad: Sequence[int], + mode: str = "replicate", + constant_value: float = 0.0, +) -> Tensor: + """Right-pad arbitrary dims for tensors of any rank. + + This is implemented manually so it works for `mode="replicate"` even when + PyTorch's `F.pad` doesn't support the tensor rank (e.g. 6D+ tensors). + """ + if len(dims) != len(right_pad): + raise ValueError("dims and right_pad must have the same length") + if not dims: + return x + + for dim, pad in zip(dims, right_pad): + pad = int(pad) + if pad <= 0: + continue + if dim < 0: + dim = x.dim() + dim + if dim < 0 or dim >= x.dim(): + raise ValueError(f"invalid dim {dim} for x.dim()={x.dim()}") + + if mode == "constant": + pad_shape = list(x.shape) + pad_shape[dim] = pad + pad_tensor = torch.full( + pad_shape, float(constant_value), dtype=x.dtype, device=x.device + ) + x = torch.cat([x, pad_tensor], dim=dim) + continue + + if mode != "replicate": + raise ValueError( + f"pad_right_nd currently supports mode='replicate' or 'constant', got {mode}" + ) + + last = x.select(dim, x.size(dim) - 1).unsqueeze(dim) # singleton at dim + expand_shape = list(x.shape) + expand_shape[dim] = pad + pad_tensor = last.expand(*expand_shape) + x = torch.cat([x, pad_tensor], dim=dim) + + return x + + +def pad_spatial_right( + x: Tensor, + *, + spatial_ndim: int, + right_pad: Sequence[int], + mode: str = "replicate", + constant_value: float = 0.0, +) -> Tensor: + """Pad only the first `spatial_ndim` dims after batch on the right. + + Assumes `x` is shaped: + (B, *spatial, *rest) + """ + if spatial_ndim not in (2, 3, 4): + raise ValueError(f"spatial_ndim must be 2, 3, or 4, got {spatial_ndim}") + if len(right_pad) != spatial_ndim: + raise ValueError( + f"right_pad must have length {spatial_ndim}, got {len(right_pad)}" + ) + if x.dim() < 1 + spatial_ndim: + raise ValueError( + f"expected x.dim() >= {1 + spatial_ndim}, got x.dim()={x.dim()}" + ) + if all(int(p) == 0 for p in right_pad): + return x + + # For 4 spatial dims, fall back to generic implementation (works for 6D tensors). + if spatial_ndim == 4: + dims = [1, 2, 3, 4] + return pad_right_nd( + x, dims=dims, right_pad=right_pad, mode=mode, constant_value=constant_value + ) + + # For 2D/3D spatial, use a reshape trick so we can call F.pad with replicate. + b = x.shape[0] + spatial_shape = x.shape[1 : 1 + spatial_ndim] + rest_shape = x.shape[1 + spatial_ndim :] + rest_prod = ( + 1 if len(rest_shape) == 0 else int(torch.tensor(rest_shape).prod().item()) + ) + + # (B, *spatial, *rest) -> (B, rest_prod, *spatial) + x_reshaped = x.reshape(b, *spatial_shape, rest_prod).permute( + 0, spatial_ndim + 1, *range(1, 1 + spatial_ndim) + ) + + if spatial_ndim == 2: + pad_h, pad_w = (int(p) for p in right_pad) + pad = (0, pad_w, 0, pad_h) + else: + pad_x, pad_y, pad_z = (int(p) for p in right_pad) + pad = (0, pad_z, 0, pad_y, 0, pad_x) + + if mode == "constant": + x_padded = F.pad(x_reshaped, pad, mode="constant", value=float(constant_value)) + else: + x_padded = F.pad(x_reshaped, pad, mode=mode) + + padded_spatial = x_padded.shape[2 : 2 + spatial_ndim] + return x_padded.permute(0, *range(2, 2 + spatial_ndim), 1).reshape( + b, *padded_spatial, *rest_shape + ) diff --git a/pyproject.toml b/pyproject.toml index 5c78e4f283..c8c86215a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -306,6 +306,7 @@ exclude = ["docs", "physicsnemo/experimental"] # Ignore `S101` (assertions) in all `test` files. "test/*.py" = ["S101"] +"examples/**/tests/*.py" = ["S101"]