Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ include_local_features: true # use local features
radii: [0.05, 0.25, 1.0, 2.5] # radius for local features
neighbors_in_radius: [8, 32, 64, 128] # neighbors in radius for local features
n_hidden_local: 32 # hidden dimension for local features
guard_buffer_size: 500 # OOD guard buffer size (set to dataset size), null to disable
guard_knn_k: 10 # k for geometry kNN OOD detection

Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ defaults:
# └───────────────────────────────────────────┘

training:
raw_data_dir: ??? # set in config or via CLI: training.raw_data_dir=/path/to/train
raw_data_dir_validation: ??? # set in config or via CLI: training.raw_data_dir_validation=/path/to/validation
global_features_filepath: ??? # set in config or via CLI: training.global_features_filepath=/path/to/global_features.json
raw_data_dir: /code/datasets/bumper_beam/train/ # set in config or via CLI: training.raw_data_dir=/path/to/train
raw_data_dir_validation: /code/datasets/bumper_beam/validation/ # set in config or via CLI: training.raw_data_dir_validation=/path/to/validation
global_features_filepath: /code/datasets/bumper_beam/global_features.json # set in config or via CLI: training.global_features_filepath=/path/to/global_features.json
Comment thread
mnabian marked this conversation as resolved.
Outdated
optimizer: muon

# ┌───────────────────────────────────────────┐
Expand Down Expand Up @@ -75,4 +75,6 @@ datapipe:
model:
functional_dim: 3 # coords (3)
out_dim: 250 # (num_time_steps - 1) * 5 = 50 * 5
global_dim: 3 # must match len(datapipe.global_features)
global_dim: 3 # must match len(datapipe.global_features)
guard_buffer_size: null # OOD guard buffer size (= num_training_samples), null to disable
guard_knn_k: 0 # k for geometry kNN OOD detection (must be >= 1 when guard_buffer_size is set; ignored while the guard is disabled)
10 changes: 5 additions & 5 deletions examples/structural_mechanics/crash/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,11 @@ def main(cfg: DictConfig):
logger0.error(f"Parent directory not found: {parent_dir}")
return

run_items = [
f.path
for f in os.scandir(parent_dir)
if f.is_file() and f.name.lower().endswith(".vtp")
]
run_items = []
for root, _dirs, files in os.walk(parent_dir):
for fname in files:
if fname.lower().endswith(".vtp"):
run_items.append(os.path.join(root, fname))
run_items.sort()
run_names = [os.path.splitext(os.path.basename(p))[0] for p in run_items]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,9 @@ def build_context(

# Tokenize geometry features
if self.geometry_tokenizer is not None and geometry is not None:
context_parts.append(self.geometry_tokenizer(geometry))
geometry_context = self.geometry_tokenizer(geometry)
self._last_geometry_context = geometry_context
Comment thread
mnabian marked this conversation as resolved.
context_parts.append(geometry_context)

# Tokenize global embedding
if self.global_tokenizer is not None and global_embedding is not None:
Expand Down
170 changes: 169 additions & 1 deletion physicsnemo/experimental/models/geotransolver/geotransolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from __future__ import annotations

import logging
import warnings
Comment thread
mnabian marked this conversation as resolved.
Outdated
from collections.abc import Sequence
from dataclasses import dataclass

Expand Down Expand Up @@ -315,6 +317,8 @@ def __init__(
radii: list[float] | None = None,
neighbors_in_radius: list[int] | None = None,
n_hidden_local: int = 32,
guard_buffer_size: int | None = None,
guard_knn_k: int = 10,
) -> None:
super().__init__(meta=GeoTransolverMetaData())
self.__name__ = "GeoTransolver"
Expand Down Expand Up @@ -437,6 +441,56 @@ def __init__(
nn.Linear(n_hidden, n_hidden),
)

# OOD guard buffers
dim_head = n_hidden // n_head
self._guard_buffer_size = guard_buffer_size

if guard_buffer_size is not None:
# Global parameters: per-dimension bounding box
if global_dim is not None:
self.register_buffer(
"guard_global_min",
torch.full((global_dim,), float("inf")),
)
self.register_buffer(
"guard_global_max",
torch.full((global_dim,), float("-inf")),
)
else:
self.register_buffer("guard_global_min", None)
self.register_buffer("guard_global_max", None)

# Geometry context: kNN rolling buffer
if geometry_dim is not None:
self.register_buffer(
"guard_geo_embeddings",
torch.zeros(guard_buffer_size, dim_head),
)
self.register_buffer(
"guard_geo_ptr",
torch.zeros(1, dtype=torch.long),
)
self.register_buffer(
"guard_knn_threshold",
torch.tensor(float("inf")),
)
else:
self.register_buffer("guard_geo_embeddings", None)
self.register_buffer("guard_geo_ptr", None)
self.register_buffer("guard_knn_threshold", None)

self.register_buffer(
"guard_knn_k",
torch.tensor(guard_knn_k, dtype=torch.long),
)
else:
self.register_buffer("guard_global_min", None)
self.register_buffer("guard_global_max", None)
self.register_buffer("guard_geo_embeddings", None)
self.register_buffer("guard_geo_ptr", None)
self.register_buffer("guard_knn_threshold", None)
self.register_buffer("guard_knn_k", None)

def forward(
self,
local_embedding: (
Expand Down Expand Up @@ -533,6 +587,13 @@ def forward(
local_embedding, local_positions, geometry, global_embedding
)

# --- OOD Guard ---
if self._guard_buffer_size is not None:
if self.training:
self._guard_collect(global_embedding, geometry)
else:
self._guard_check(global_embedding, geometry)

# Project inputs to hidden dimension: (B, N, C) -> (B, N, n_hidden)
x = [self.preprocess[i](le) for i, le in enumerate(local_embedding)]

Expand All @@ -556,4 +617,111 @@ def forward(
else:
x = tuple(x)

return x
return x

# --- OOD Guard methods ---

@torch.no_grad()
def _guard_collect(
self,
global_embedding: torch.Tensor | None,
geometry: torch.Tensor | None,
) -> None:
"""Collect OOD guard calibration data during training."""
# Update global parameter bounds
if global_embedding is not None and self.guard_global_min is not None:
batch_min = global_embedding.detach().min(dim=0).values.min(dim=0).values
batch_max = global_embedding.detach().max(dim=0).values.max(dim=0).values
self.guard_global_min.copy_(
torch.minimum(self.guard_global_min, batch_min)
)
self.guard_global_max.copy_(
torch.maximum(self.guard_global_max, batch_max)
)

# Append geometry context embeddings to FIFO buffer
if geometry is not None and self.guard_geo_embeddings is not None:
geo_ctx = self.context_builder._last_geometry_context # (B, H, S, D)
pooled = geo_ctx.detach().mean(dim=(1, 2)) # (B, D)
buf_size = self.guard_geo_embeddings.shape[0]
ptr = self.guard_geo_ptr.item()
b = pooled.shape[0]
for i in range(b):
self.guard_geo_embeddings[ptr % buf_size] = pooled[i]
ptr += 1
self.guard_geo_ptr.fill_(ptr)

@torch.no_grad()
def _guard_check(
self,
global_embedding: torch.Tensor | None,
geometry: torch.Tensor | None,
) -> None:
"""Run OOD checks during inference and emit warnings."""
_RED = "\033[91m"
_RESET = "\033[0m"

# Check global parameter bounds
if global_embedding is not None and self.guard_global_min is not None:
if not torch.isinf(self.guard_global_min).any():
vals = global_embedding.detach()
batch_min = vals.min(dim=0).values.min(dim=0).values
batch_max = vals.max(dim=0).values.max(dim=0).values
for d in range(batch_min.shape[0]):
lo = self.guard_global_min[d].item()
hi = self.guard_global_max[d].item()
if batch_min[d].item() < lo:
logging.warning(
f"{_RED}OOD Guard: global_embedding dim {d} value "
f"{batch_min[d].item():.4f} below training min {lo:.4f}{_RESET}"
)
if batch_max[d].item() > hi:
logging.warning(
f"{_RED}OOD Guard: global_embedding dim {d} value "
f"{batch_max[d].item():.4f} above training max {hi:.4f}{_RESET}"
)

# Check geometry kNN distance
if geometry is not None and self.guard_geo_embeddings is not None:
if not torch.isinf(self.guard_knn_threshold):
geo_ctx = self.context_builder._last_geometry_context
pooled = geo_ctx.detach().mean(dim=(1, 2)) # (B, D)
z = pooled / (pooled.norm(dim=-1, keepdim=True) + 1e-8)
store = self.guard_geo_embeddings
store_norm = store / (store.norm(dim=-1, keepdim=True) + 1e-8)
dists = torch.cdist(z, store_norm) # (B, buf_size)
k = self.guard_knn_k.item()
kth_dists = dists.topk(k, largest=False).values[:, -1] # (B,)
threshold = self.guard_knn_threshold.item()
for i in range(kth_dists.shape[0]):
dist_val = kth_dists[i].item()
if dist_val > threshold:
logging.warning(
f"{_RED}OOD Guard: geometry sample {i} kNN distance "
f"{dist_val:.4f} above threshold {threshold:.4f}{_RESET}"
)
Comment thread
mnabian marked this conversation as resolved.
Outdated

@torch.no_grad()
def compute_guard_threshold(self) -> None:
    """Compute the kNN threshold from accumulated geometry embeddings.

    Normalizes the valid portion of the FIFO buffer, computes each
    sample's distance to its k-th nearest neighbor, and stores the 99th
    percentile of those distances in ``guard_knn_threshold``. No-op when
    the guard is disabled or fewer than two samples were collected.
    """
    if self.guard_geo_embeddings is None:
        return
    ptr = self.guard_geo_ptr.item()
    buf_size = self.guard_geo_embeddings.shape[0]
    n_valid = min(ptr, buf_size)
    # Need at least two samples to form a neighbor distance at all.
    if n_valid < 2:
        return
    store = self.guard_geo_embeddings[:n_valid]
    store_norm = store / (store.norm(dim=-1, keepdim=True) + 1e-8)
    dists = torch.cdist(store_norm, store_norm)  # (N, N)
    dists.fill_diagonal_(float("inf"))
    # BUG FIX: each row has only n_valid - 1 finite distances after the
    # diagonal fill. An unclamped k either raises in topk (k > N) or
    # yields an inf threshold (k == N), silently disabling the guard —
    # the default guard_knn_k=10 breaks on small buffers without this.
    k = min(int(self.guard_knn_k.item()), n_valid - 1)
    kth_dists = dists.topk(k, largest=False).values[:, -1]  # (N,)
    # 99th percentile — near-zero false alarms on in-distribution data
    threshold = torch.quantile(kth_dists, 0.99)
    self.guard_knn_threshold.copy_(threshold)
Comment thread
mnabian marked this conversation as resolved.
Outdated

def state_dict(self, *args, **kwargs):
    """Refresh the OOD guard kNN threshold, then delegate to the parent.

    Ensures the calibrated threshold is always part of any saved
    checkpoint, even when training stops before an explicit call to
    ``compute_guard_threshold``.
    """
    self.compute_guard_threshold()
    state = super().state_dict(*args, **kwargs)
    return state
Loading