
GeoTransolver Guard 🛡️ #1544

Draft
mnabian wants to merge 5 commits into main from GeoT_Guard

mnabian (Collaborator) commented Apr 1, 2026

PhysicsNeMo Pull Request

Description

Problem

When users apply a pretrained GeoTransolver checkpoint to inputs outside the training distribution (e.g., running a DrivAerML-trained model on motorcycles or aircraft), the model silently produces unreliable predictions. There is no mechanism to detect or warn about out-of-distribution (OOD) inputs at inference time.

Solution

We add two lightweight OOD guards that integrate seamlessly into the existing training and inference workflow. Both guards are controlled by a single knob (guard_buffer_size) and require no additional scripts, calibration steps, or changes to the training loop.

  • During training: guards passively collect calibration data (zero impact on model predictions or gradients).
  • During inference: guards check inputs against the training distribution and emit a warning via warnings.warn() when an input looks OOD.
  • At checkpoint save: the kNN threshold is automatically computed from collected data.

Guard 1: Global Parameters — Bounding Box

What it monitors: The raw global_embedding input tensor (e.g., air density, stream velocity for DrivAerML).

How it works:

  • During training, tracks per-dimension running min/max across all batches.
  • At inference, compares each dimension of the input against the stored bounds.
  • Emits a warning per dimension that falls outside the training range.

Why input space: Global parameters are low-dimensional scalars (2-3 dims). A bounding box is simpler, more interpretable, and more reliable than latent-space methods at this dimensionality.

Example warning:

OOD Guard: global_embedding dim 0 value 0.7500 below training min 0.9000
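
The bounding-box logic above can be sketched in plain Python. This is a minimal illustration, not the PR's code: the class name `GlobalBoundsGuard`, its methods, and the wording of the "above training max" message are assumptions, and lists stand in for tensors. Dimensions whose bounds are still at inf/-inf are skipped, matching the statement that the guard stays inactive until training populates the buffers.

```python
import warnings


class GlobalBoundsGuard:
    """Per-dimension min/max tracker (hypothetical name, not the PR's class)."""

    def __init__(self, dim):
        self.min = [float("inf")] * dim
        self.max = [float("-inf")] * dim

    def collect(self, batch):
        """Training: widen the bounds with every row of the batch."""
        for row in batch:
            for d, value in enumerate(row):
                self.min[d] = min(self.min[d], value)
                self.max[d] = max(self.max[d], value)

    def check(self, row):
        """Inference: warn once per out-of-range dimension, return the messages."""
        messages = []
        for d, value in enumerate(row):
            if self.min[d] > self.max[d]:
                continue  # bounds still at inf/-inf: nothing collected yet
            if value < self.min[d]:
                messages.append(
                    f"OOD Guard: global_embedding dim {d} value {value:.4f} "
                    f"below training min {self.min[d]:.4f}"
                )
            elif value > self.max[d]:
                # Assumed wording, mirroring the "below training min" message.
                messages.append(
                    f"OOD Guard: global_embedding dim {d} value {value:.4f} "
                    f"above training max {self.max[d]:.4f}"
                )
        for message in messages:
            warnings.warn(message)
        return messages


guard = GlobalBoundsGuard(dim=2)
guard.collect([[0.9, 30.0], [1.2, 40.0]])  # two training samples
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    messages = guard.check([0.75, 35.0])  # dim 0 is below the training range
```

Only dimension 0 is reported here, since 35.0 lies inside the collected range for dimension 1.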

Guard 2: Geometry Context — kNN Distance

The implementation is based on this paper: https://arxiv.org/pdf/2204.06507

What it monitors: The geometry context vector produced by GlobalContextBuilder.geometry_tokenizer -- a learned 32-dimensional representation of the input geometry, mean-pooled over attention heads and slices.

How it works:

  • During training, accumulates pooled geometry embeddings into a fixed-size FIFO rolling buffer.
  • At checkpoint save, computes a kNN distance threshold from the buffer (99th percentile of leave-one-out k-th nearest neighbor distances).
  • At inference, L2-normalizes the input embedding, computes its k-th nearest neighbor distance to the stored training embeddings, and warns if above threshold.
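
A rough sketch of the inference-time check, assuming the stored training embeddings are already L2-normalized and that Euclidean distance is used. The function names and the toy 2-D bank are illustrative only; the real embeddings are 32-dimensional tensors.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def kth_nn_distance(query, bank, k):
    """Euclidean distance from query to its k-th nearest neighbour in bank."""
    if k > len(bank):
        raise ValueError("k exceeds the number of stored embeddings")
    dists = sorted(math.dist(query, ref) for ref in bank)
    return dists[k - 1]


def is_ood(embedding, bank, k, threshold):
    """Flag the input when its k-th neighbour is farther than the threshold."""
    query = l2_normalize(embedding)
    return kth_nn_distance(query, bank, k) > threshold


# Toy bank: unit vectors clustered near angle 0 (assumed already normalized).
bank = [[math.cos(a), math.sin(a)] for a in (0.0, 0.05, 0.1, 0.15, 0.2)]
in_dist = is_ood([2.0, 0.1], bank, k=2, threshold=0.2)   # close to the cluster
out_dist = is_ood([0.0, 3.0], bank, k=2, threshold=0.2)  # orthogonal direction
```

Because the query is normalized first, its magnitude does not matter; only its direction relative to the training embeddings does.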

Why latent space: Geometry is a variable-size point cloud -- there is no fixed-dimensional input representation to bound. The post-ContextProjector embedding compresses geometry into a fixed 32-dim vector suitable for distance-based methods.

Why kNN:

  • Training set is small (~400 samples for DrivAerML). Covariance estimation (Mahalanobis) is unreliable.
  • No distributional assumption -- geometries may cluster multimodally.
  • No additional trainable components needed.

Why not monitor multi-scale local features:

  • Local features are derived from geometry via ball queries -- not an independent signal.
  • They dominate the context dimension (768/832 = 92%), drowning out other signals in a combined detector.
  • 768 dims with ~400 samples is a poor regime for kNN (curse of dimensionality).

Example warning:

[image: example geometry kNN OOD warning]

Usage

Enabling guards

Add to model config (or pass to constructor):

guard_buffer_size: 500  # set to dataset size; null to disable
guard_knn_k: 10         # k for geometry kNN (optional, default 10)

Both guards are enabled when guard_buffer_size is set, and disabled when it is None.

Training

No changes to the training script. Guards collect data automatically during model.train() forward passes. The kNN threshold is computed automatically when the checkpoint is saved (via state_dict() override).

Inference

No changes to the inference script. Guards run automatically during model.eval() forward passes and emit Python warnings for OOD inputs.

Configuration

| Parameter | Default | Description |
| --- | --- | --- |
| guard_buffer_size | None (disabled) | FIFO buffer size for geometry embeddings. Set to dataset size. |
| guard_knn_k | 10 | k for k-th nearest neighbor distance. Range 5-15 recommended. |

Threshold

The kNN threshold is set at the 99th percentile of the training-set leave-one-out kNN distances. By construction, about 1% of training samples exceed it, so in-distribution validation/test sets should trigger warnings only rarely.
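
The threshold computation can be sketched as follows. This is an illustration under assumptions: a linearly interpolated percentile (the PR may use a different quantile rule), Euclidean distance, and hypothetical function names.

```python
import math


def leave_one_out_kth_distances(bank, k):
    """For each stored point, distance to its k-th neighbour among the others."""
    if len(bank) <= k:
        raise ValueError("need more than k stored embeddings")
    result = []
    for i, point in enumerate(bank):
        dists = sorted(
            math.dist(point, other) for j, other in enumerate(bank) if j != i
        )
        result.append(dists[k - 1])
    return result


def percentile(values, q):
    """Linearly interpolated percentile, q in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return ordered[lo] * (1 - frac) + ordered[hi] * frac


def knn_threshold(bank, k, q=99.0):
    """Percentile of the leave-one-out k-th neighbour distances."""
    return percentile(leave_one_out_kth_distances(bank, k), q)


# 1-D toy bank with one isolated point; the 1-NN distances are [1, 1, 1, 1, 7].
bank = [[0.0], [1.0], [2.0], [3.0], [10.0]]
threshold = knn_threshold(bank, k=1)
```

With so few points the percentile is dominated by the single isolated sample, which is one reason a buffer close to the full dataset size matters.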

Multi-GPU / DDP

  • Each rank maintains its own buffer independently -- no cross-rank communication.
  • The distributed sampler shuffles data each epoch, so after a few epochs each rank's FIFO buffer covers most of the training set.
  • Rank 0 saves the checkpoint; its buffer and threshold are what persist.
  • Recommendation: set guard_buffer_size >= dataset_size to ensure good coverage per rank.
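
To get a feel for per-rank coverage, here is a small simulation. It is only a sketch: sample indices stand in for embeddings, the shuffle-then-stride split loosely mimics a distributed sampler rather than reproducing PyTorch's exact behaviour, and all names are hypothetical.

```python
import random
from collections import deque


def rank_buffer_coverage(dataset_size, world_size, rank, buffer_size, epochs, seed=0):
    """Fraction of distinct training samples held in one rank's FIFO, per epoch."""
    buffer = deque(maxlen=buffer_size)  # FIFO: oldest entries are evicted first
    coverage = []
    for epoch in range(epochs):
        indices = list(range(dataset_size))
        random.Random(seed + epoch).shuffle(indices)  # reshuffle every epoch
        buffer.extend(indices[rank::world_size])      # this rank's shard
        coverage.append(len(set(buffer)) / dataset_size)
    return coverage


# 400 samples, 4 ranks, buffer as large as the dataset, 4 epochs.
cov = rank_buffer_coverage(400, 4, rank=0, buffer_size=400, epochs=4)
```

After the first epoch the rank has seen exactly a quarter of the data; coverage then grows as later epochs deliver different shards. Repeated samples also occupy slots, so a buffer smaller than the dataset will cap coverage.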

Checkpoint Compatibility

  • Pre-guard checkpoint loaded into guard-enabled model: Guard buffers retain their initial values (inf/-inf). Guards remain inactive until training populates the buffers.
  • Guard checkpoint loaded into non-guard model: Extra guard_* keys in the state dict. Requires strict=False when loading.
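
As an alternative to strict=False for the second case, the guard entries could be split off before loading. This is a hypothetical helper, not part of the PR; it assumes guard buffers are the only keys whose last component starts with `guard_`, and it uses a plain dict in place of a real state dict.

```python
def split_guard_keys(state_dict, prefix="guard_"):
    """Separate guard buffers from the remaining model state."""
    model_state, guard_state = {}, {}
    for key, value in state_dict.items():
        leaf = key.rsplit(".", 1)[-1]  # also handles nested module prefixes
        target = guard_state if leaf.startswith(prefix) else model_state
        target[key] = value
    return model_state, guard_state


# Placeholder values stand in for tensors.
checkpoint = {
    "blocks.0.weight": 1,
    "guard_global_min": 2,
    "guard_knn_threshold": 3,
}
model_state, guard_state = split_guard_keys(checkpoint)
```

Loading `model_state` keeps strict key checking for everything else, so unrelated mismatches are not silently ignored.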

Tests

So far this has been tested with the Crash recipe on the bumper beam dataset. Inference on the test samples raises no OOD warnings. New samples with either out-of-range global parameters or a geometry scaled by a small factor (1.05) do raise the OOD warning.

Implementation

Registered buffers

All guard state is stored as registered buffers (persistent, non-trainable):

| Buffer | Shape | Description |
| --- | --- | --- |
| guard_global_min | (global_dim,) | Per-dimension training min |
| guard_global_max | (global_dim,) | Per-dimension training max |
| guard_geo_embeddings | (buffer_size, dim_head) | FIFO geometry embedding store |
| guard_geo_ptr | (1,) | FIFO write pointer |
| guard_knn_threshold | scalar | 99th percentile kNN distance |
| guard_knn_k | scalar | k for kNN |
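
How a fixed-size store and a write pointer can act as a FIFO is sketched below. The assumption that the pointer counts total writes (so the number of valid rows is `min(ptr, buffer_size)`) is mine, not confirmed by the PR, and lists stand in for the registered buffers.

```python
class FifoEmbeddingBuffer:
    """Ring buffer: new rows overwrite the oldest once the store is full."""

    def __init__(self, buffer_size, dim):
        self.size = buffer_size
        self.slots = [[0.0] * dim for _ in range(buffer_size)]
        self.ptr = 0  # assumed to count total writes so far

    def append(self, embedding):
        self.slots[self.ptr % self.size] = list(embedding)
        self.ptr += 1

    def valid(self):
        """Only rows that were actually written, never the zero-initialized ones."""
        n_valid = min(self.ptr, self.size)
        return self.slots[:n_valid]


buf = FifoEmbeddingBuffer(buffer_size=3, dim=1)
buf.append([1.0])
buf.append([2.0])
partial = buf.valid()   # buffer not yet full
buf.append([3.0])
buf.append([4.0])       # wraps around and overwrites the oldest row
full = buf.valid()
```

Restricting distance computations to `valid()` keeps the zero-initialized rows out of the kNN search while the buffer is still filling.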

Key methods

| Method | When called | Purpose |
| --- | --- | --- |
| _guard_collect() | Every training forward pass | Update global bounds, append geometry embeddings to FIFO |
| _guard_check() | Every inference forward pass | Check bounds and kNN distance, emit warnings |
| compute_guard_threshold() | At state_dict() call (before save) | Compute 99th percentile kNN threshold from buffer |
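
The call pattern in the table can be illustrated with a toy class. This is not the real model: `ToyGuardedModel` only records which hook ran, and the dispatch on a `training` flag is an assumption based on the description above.

```python
class ToyGuardedModel:
    """Stand-in for the guarded model; records which guard hook ran."""

    def __init__(self):
        self.training = True
        self.calls = []

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def _guard_collect(self, x):
        self.calls.append("collect")

    def _guard_check(self, x):
        self.calls.append("check")

    def compute_guard_threshold(self):
        self.calls.append("threshold")

    def forward(self, x):
        # Training passes collect calibration data; eval passes run the checks.
        if self.training:
            self._guard_collect(x)
        else:
            self._guard_check(x)
        return x

    def state_dict(self):
        # Refresh the threshold right before the state is serialized.
        self.compute_guard_threshold()
        return {}


model = ToyGuardedModel()
model.train().forward([1.0])
model.state_dict()
model.eval().forward([1.0])
```

One consequence of this design is that any state_dict() call, not only a checkpoint save, triggers the threshold computation.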

Checklist

Dependencies

None

greptile-apps bot (Contributor) commented Apr 1, 2026

Greptile Summary

This PR adds an OOD (out-of-distribution) guard to GeoTransolver that tracks global-parameter bounding boxes and geometry kNN distances during training, then emits warnings during inference. It also widens VTP file discovery in the crash inference script from flat os.scandir to recursive os.walk.

  • P1 – training crash: state_dict() unconditionally calls compute_guard_threshold(), which calls dists.topk(k, ...) on an (n_valid × n_valid) matrix. If checkpointing happens before k geometry batches have been processed (n_valid < k), PyTorch raises RuntimeError: k is too big for dimension size and crashes training.
  • P1 – hardcoded paths in example YAML: bumper_geotransolver_oneshot.yaml replaces the ??? required-field markers with /code/datasets/bumper_beam/..., breaking the example for any user without data at that exact location.
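
One possible mitigation for the first P1 item is to skip the threshold update until enough embeddings exist. The sketch below is hypothetical (the function name and signature are not from the PR). It uses the stricter condition n_valid <= k because a leave-one-out search leaves only n_valid - 1 candidate neighbours.

```python
def maybe_update_threshold(current_threshold, n_valid, k, compute_fn):
    """Recompute the kNN threshold only when the buffer can support it."""
    if n_valid <= k:
        # Too few samples for a k-th leave-one-out neighbour: keep the old value.
        return current_threshold
    return compute_fn()


# Early checkpoint: 3 embeddings collected, k = 10, so nothing is recomputed.
early = maybe_update_threshold(0.0, n_valid=3, k=10, compute_fn=lambda: 1.5)
# Later checkpoint: enough embeddings, so the threshold is refreshed.
later = maybe_update_threshold(0.0, n_valid=11, k=10, compute_fn=lambda: 1.5)
```

Clamping k to n_valid - 1 would be another option, at the cost of a threshold computed with a different k than the one used at inference.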

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/experimental/models/geotransolver/geotransolver.py | Adds OOD guard feature with global-param bounding-box and geometry kNN checks; contains a crash in compute_guard_threshold/state_dict when fewer than k geometry samples have been collected, an unused import, and a stale zero-buffer issue in _guard_check. |
| physicsnemo/experimental/models/geotransolver/context_projector.py | Caches the geometry context tensor in _last_geometry_context for the OOD guard to consume; the attribute is not initialized in __init__, leaving it implicitly dynamic. |
| examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml | Replaces required Hydra ??? markers with hardcoded developer-local paths (/code/datasets/bumper_beam/...) that will fail for all other users; also adds new guard params. |
| examples/cfd/external_aerodynamics/transformer_models/src/conf/model/geotransolver.yaml | Adds guard_buffer_size and guard_knn_k config entries with sensible defaults. |
| examples/structural_mechanics/crash/inference.py | Switches from os.scandir (flat) to os.walk (recursive) for VTP file discovery; a clean, correct change. |

Reviews (1): Last reviewed commit: "add geotransolver guard"

Comment thread physicsnemo/experimental/models/geotransolver/geotransolver.py Outdated
Comment thread examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml Outdated
Comment thread physicsnemo/experimental/models/geotransolver/geotransolver.py Outdated
Comment thread physicsnemo/experimental/models/geotransolver/geotransolver.py Outdated
Comment thread physicsnemo/experimental/models/geotransolver/context_projector.py
mnabian (Collaborator, Author) commented Apr 1, 2026

Note: Tests will be added after the initial review is done for the overall design and implementation of this guardrail.

mnabian marked this pull request as draft April 7, 2026 17:37