Note: Tests will be added after the initial review of the overall design and implementation of this guardrail.
PhysicsNeMo Pull Request
Description
Problem
When users apply a pretrained GeoTransolver checkpoint to inputs outside the training distribution (e.g., running a DrivAerML-trained model on motorcycles or aircraft), the model silently produces unreliable predictions. There is no mechanism to detect or warn about out-of-distribution (OOD) inputs at inference time.
Solution
We add two lightweight OOD guards that integrate seamlessly into the existing training and inference workflow. Both guards are controlled by a single knob (`guard_buffer_size`) and require no additional scripts, calibration steps, or changes to the training loop. Each guard emits a `warnings.warn()` if OOD is detected.

Guard 1: Global Parameters — Bounding Box
What it monitors: The raw `global_embedding` input tensor (e.g., air density, stream velocity for DrivAerML).

How it works: During training, the guard tracks the element-wise minimum and maximum of each global parameter; at inference, a warning is raised if any component falls outside this bounding box.
Why input space: Global parameters are low-dimensional scalars (2-3 dims). A bounding box is simpler, more interpretable, and more reliable than latent-space methods at this dimensionality.
Example warning:
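To make the mechanism concrete, here is a minimal, self-contained sketch of a bounding-box guard. All names here are illustrative assumptions, not the actual PhysicsNeMo implementation:

```python
import warnings

import torch


class GlobalParamGuard:
    """Sketch: track per-dimension min/max of global parameters seen in training."""

    def __init__(self, global_dim: int):
        # Analogous to the guard_global_min / guard_global_max buffers.
        self.gmin = torch.full((global_dim,), float("inf"))
        self.gmax = torch.full((global_dim,), float("-inf"))

    def collect(self, g: torch.Tensor) -> None:
        # g: (batch, global_dim) raw global parameters from a training batch.
        self.gmin = torch.minimum(self.gmin, g.min(dim=0).values)
        self.gmax = torch.maximum(self.gmax, g.max(dim=0).values)

    def check(self, g: torch.Tensor) -> None:
        # Warn if any component falls outside the training bounding box.
        outside = (g < self.gmin) | (g > self.gmax)
        if outside.any():
            dims = outside.any(dim=0).nonzero().flatten().tolist()
            warnings.warn(f"OOD global parameters in dimensions {dims}")
```

In this sketch `collect()` runs during training forwards and `check()` during eval forwards, mirroring the workflow described below.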
Guard 2: Geometry Context — kNN Distance
The implementation is based on deep nearest-neighbor OOD detection (Sun et al., 2022): https://arxiv.org/pdf/2204.06507
What it monitors: The geometry context vector produced by `GlobalContextBuilder.geometry_tokenizer` -- a learned 32-dimensional representation of the input geometry, mean-pooled over attention heads and slices.

How it works: During training, geometry context vectors are collected into a fixed-size buffer; at save time, a distance threshold is computed over the buffer, and at inference a warning is raised when a new vector's distance to its k-th nearest buffered embedding exceeds that threshold.
Why latent space: Geometry is a variable-size point cloud -- there is no fixed-dimensional input representation to bound. The post-ContextProjector embedding compresses geometry into a fixed 32-dim vector suitable for distance-based methods.
Why kNN: It is non-parametric, makes no distributional assumptions about the embedding space, and is the method validated for latent-space OOD detection in the cited paper.
Why not monitor multi-scale local features:
Example warning:
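The inference-time check can be sketched as follows; function and variable names are illustrative assumptions, not the actual implementation:

```python
import warnings

import torch


def knn_guard_check(z: torch.Tensor, bank: torch.Tensor,
                    threshold: float, k: int = 10) -> None:
    """Sketch: warn when a geometry context vector is far from the training buffer.

    z:    (batch, dim) geometry context vectors at inference
    bank: (buffer_size, dim) embeddings collected during training
    """
    d = torch.cdist(z, bank)                      # pairwise Euclidean distances
    kth = d.topk(k, largest=False).values[:, -1]  # distance to k-th nearest neighbor
    if (kth > threshold).any():
        warnings.warn("geometry context vector is OOD (kNN distance above threshold)")
```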
Usage
Enabling guards
Add to model config (or pass to constructor):
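A hypothetical example (the values and call shape are illustrative; the exact constructor signature may differ):

```python
# Hypothetical snippet -- exact constructor signature may differ.
model = GeoTransolver(
    ...,                     # usual model arguments
    guard_buffer_size=4096,  # enables both guards; None (the default) disables them
    guard_knn_k=10,          # neighbors used by the geometry kNN guard
)
```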
Both guards are enabled when `guard_buffer_size` is set, and disabled when it is `None`.

Training
No changes to the training script. Guards collect data automatically during `model.train()` forward passes. The kNN threshold is computed automatically when the checkpoint is saved (via a `state_dict()` override).

Inference
No changes to the inference script. Guards run automatically during `model.eval()` forward passes and emit Python warnings for OOD inputs.

Configuration
| Option | Default |
| --- | --- |
| `guard_buffer_size` | `None` (disabled) |
| `guard_knn_k` | `10` |

Threshold
The kNN threshold is set at the 99th percentile of training-set leave-one-out kNN distances. This means ~1% false alarm rate on in-distribution data -- near-zero warnings on validation/test sets that are in-distribution.
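Under these definitions, the threshold fit can be sketched as follows (an illustrative helper, not the actual code; in the real model the result is stored in the `guard_knn_threshold` buffer by the `state_dict()` override):

```python
import torch


def loo_knn_threshold(bank: torch.Tensor, k: int = 10, q: float = 0.99) -> float:
    """Sketch: q-quantile of leave-one-out k-NN distances over the training buffer."""
    d = torch.cdist(bank, bank)     # (n, n) pairwise distances
    d.fill_diagonal_(float("inf"))  # leave-one-out: a point is not its own neighbor
    kth = d.topk(k, largest=False).values[:, -1]  # each point's k-th NN distance
    return torch.quantile(kth, q).item()
```

At inference, a query whose k-th nearest-neighbor distance to the buffer exceeds this value triggers the warning.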
Multi-GPU / DDP
In DDP training, set `guard_buffer_size >= dataset_size` to ensure good coverage per rank.

Checkpoint Compatibility
Guard state appears as `guard_*` keys in the state dict. Requires `strict=False` when loading.

Tests
This is currently tested on the Crash recipe with the bumper beam dataset: inference on in-distribution test samples raises no OOD warnings, while new OOD samples with either out-of-range global parameters or geometry scaled by a small factor (1.05) raise the OOD warning.
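The test pattern can be sketched with a stand-in model; every name here is hypothetical, and the stand-in "guard" simply warns when inputs leave a [0, 1] box:

```python
import warnings

import torch


class DummyGuardedModel(torch.nn.Module):
    """Stand-in whose 'guard' warns when inputs leave the [0, 1] training box."""

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        if not self.training and ((g < 0.0) | (g > 1.0)).any():
            warnings.warn("OOD input detected")
        return g


def assert_ood_warning(model: torch.nn.Module, batch: torch.Tensor,
                       expect: bool) -> None:
    # Capture warnings emitted during an eval-mode forward pass.
    model.eval()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model(batch)
    assert any("OOD" in str(w.message) for w in caught) == expect
```

For the scaled-geometry case this would be used as, e.g., `assert_ood_warning(model, coords * 1.05, expect=True)`.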
Implementation
Registered buffers
All guard state is stored as registered buffers (persistent, non-trainable):
- `guard_global_min` -- shape `(global_dim,)`
- `guard_global_max` -- shape `(global_dim,)`
- `guard_geo_embeddings` -- shape `(buffer_size, dim_head)`
- `guard_geo_ptr` -- shape `(1,)`
- `guard_knn_threshold`
- `guard_knn_k`

Key methods
- `_guard_collect()` -- collects statistics and embeddings during `model.train()` forward passes
- `_guard_check()` -- runs the OOD checks during `model.eval()` forward passes
- `compute_guard_threshold()` -- invoked from the overridden `state_dict()` call (before save)

Checklist
Dependencies
None