Commits (36)
- 6344bfc: Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D (peterdsharpe, Mar 10, 2026)
- 2318514: Merge branch 'main' into psharpe/add-mesh-improvements-for-GLOBE-3D (peterdsharpe, Mar 11, 2026)
- 1f471ff: Add dual-tree traversal algorithm to GLOBE model for O(N) kernel eval… (peterdsharpe, Mar 11, 2026)
- 1cc74e0: Adds DTT-related changes to AirFRANS train.py (peterdsharpe, Mar 11, 2026)
- a984c12: Squashed commit of the following: (peterdsharpe, Mar 12, 2026)
- 137905f: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 12, 2026)
- fd1bcaf: Grammar fix (peterdsharpe, Mar 12, 2026)
- 553cf08: Docstring fix (peterdsharpe, Mar 12, 2026)
- b45ca1b: Update BarnesHutKernel to use appropriate dtype for zero-valued tenso… (peterdsharpe, Mar 12, 2026)
- 9f5a252: Update MetaData class in model.py to disable JIT and CUDA graphs for … (peterdsharpe, Mar 12, 2026)
- 2722e64: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 12, 2026)
- ad859f3: Enhance GLOBE model to support cross-boundary condition (BC) interact… (peterdsharpe, Mar 12, 2026)
- 963df9b: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 12, 2026)
- 6f5982b: Update run.sh script for improved configuration and compatibility (peterdsharpe, Mar 13, 2026)
- b06dfab: Enhance ClusterTree and BarnesHutKernel for improved internal node ha… (peterdsharpe, Mar 15, 2026)
- ca41fad: Adds theta and leaf_size forwarding (peterdsharpe, Mar 15, 2026)
- 7201fc6: Docs fixes, and properly abstracts _ragged.py to deduplicate code. (peterdsharpe, Mar 16, 2026)
- a26e7d4: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 16, 2026)
- 6466b54: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 19, 2026)
- fd84ec0: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 20, 2026)
- 55122f0: Enhance global data handling in MultiscaleKernel by adding a copy ope… (peterdsharpe, Mar 20, 2026)
- bf6c139: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 20, 2026)
- 5571f2a: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 20, 2026)
- 976decd: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Mar 23, 2026)
- dff836c: Adds ragged arange tests (peterdsharpe, Mar 23, 2026)
- f850ef7: Adds traceable ragged_arange variant (peterdsharpe, Mar 24, 2026)
- 251400a: Refactor Kernel class to simplify network evaluation and enhance perf… (peterdsharpe, Mar 24, 2026)
- 629b137: Always use tensorclass, not dataclass (peterdsharpe, Mar 24, 2026)
- 9fe6fc9: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Apr 13, 2026)
- e72467c: formatting (peterdsharpe, Apr 20, 2026)
- 76fef05: Merge branch 'main' into psharpe/stacked-add-GLOBE-3D-DTT-model-changes (peterdsharpe, Apr 20, 2026)
- 9754f99: Adds minor type hint annotation (peterdsharpe, Apr 20, 2026)
- 10130cf: Adds in option to do far-field 1st-order expansion (default off, togg… (peterdsharpe, Apr 20, 2026)
- 2ec4edb: changelog wording (peterdsharpe, Apr 20, 2026)
- f4b3b97: formatting (peterdsharpe, Apr 20, 2026)
- b2ac138: Refactor inference and training scripts to remove chunk_size paramete… (peterdsharpe, Apr 20, 2026)
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -10,7 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`)
- Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`),
including a new variant that uses a dual-tree traversal algorithm to
fundamentally reduce the complexity of the kernel evaluations from O(N^2) to
O(N).
- Adds GLOBE AirFRANS example case (`examples/cfd/external_aerodynamics/globe/airfrans`)
- Adds automatic support for `FSDP` and/or `ShardTensor` models in checkpoint save/load
functionality
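For context on the changelog entry above: the O(N^2) to O(N) claim rests on a tree-based far-field approximation. Below is a minimal single-tree Barnes-Hut sketch for a plain 1/r kernel, a toy stand-in (all names hypothetical), not the PhysicsNeMo implementation, which applies a dual-tree traversal to learned kernels with configurable `theta` and `leaf_size`:

```python
# Toy Barnes-Hut: approximate sum_j m_j / |x - p_j| by replacing
# well-separated clusters with their aggregate (total mass at centroid).
import numpy as np


class Node:
    """Binary cluster tree over source points (leaf_size = 1)."""

    def __init__(self, points: np.ndarray, masses: np.ndarray):
        self.center = np.average(points, axis=0, weights=masses)
        self.mass = float(masses.sum())
        self.size = float(np.ptp(points, axis=0).max())  # cluster extent
        if len(points) <= 1:
            self.children: list["Node"] = []
            self.points, self.masses = points, masses
        else:
            # Median split along the widest axis.
            axis = int(np.argmax(np.ptp(points, axis=0)))
            order = np.argsort(points[:, axis])
            mid = len(points) // 2
            self.children = [
                Node(points[order[:mid]], masses[order[:mid]]),
                Node(points[order[mid:]], masses[order[mid:]]),
            ]


def potential(node: Node, x: np.ndarray, theta: float = 0.5) -> float:
    """Evaluate the 1/r sum at x, opening clusters by the theta criterion."""
    d = float(np.linalg.norm(x - node.center))
    if not node.children:  # leaf: evaluate exactly
        r = np.linalg.norm(x - node.points, axis=1)
        return float((node.masses / r).sum())
    if node.size < theta * d:  # well separated: one aggregate interaction
        return node.mass / d
    return sum(potential(c, x, theta) for c in node.children)
```

With `theta = 0` every cluster is opened and the evaluation is exact (O(N) per target, O(N^2) overall); larger `theta` trades accuracy for fewer interactions. A dual-tree traversal additionally clusters the targets, which is what removes the per-target linear factor.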
20 changes: 14 additions & 6 deletions examples/cfd/external_aerodynamics/globe/airfrans/run.sh
@@ -14,9 +14,15 @@
set -euo pipefail

### [User Configuration]
OUTPUT_NAME="${SLURM_JOB_NAME:-globe_airfrans_local}"
SCRIPT_DIR="${SLURM_SUBMIT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}"
OUTPUT_DIR="${SCRIPT_DIR}/output/${OUTPUT_NAME}"

TRAIN_ARGS=(
--output-name ${SLURM_JOB_NAME:-globe_airfrans_local}
--output-name "${OUTPUT_NAME}"
--airfrans-task "scarce"
--no-use-compile
--amp
)

export AIRFRANS_DATA_DIR="${HOME}/datasets/airfrans/Dataset" # Set this to your AirFRANS dataset
@@ -37,10 +43,12 @@ CUDA_MAJOR=$(sed -n 's/.*CUDA Version: \([0-9]*\).*/\1/p' <<< "$NVIDIA_SMI_OUTPU
echo "Number of GPUs per node detected: $NUM_GPUS_PER_NODE"

### [Thread Configuration]
# OMP_NUM_THREADS=1: DataLoader workers use process-level parallelism
# (num_workers auto-computed as n_cpus/n_gpus), so per-process threading
# is unnecessary and causes thread oversubscription.
CPUS_PER_NODE=${SLURM_CPUS_ON_NODE:-$(nproc)}
export OMP_NUM_THREADS=$((CPUS_PER_NODE / NUM_GPUS_PER_NODE))
OMP_NUM_THREADS=$((OMP_NUM_THREADS > 0 ? OMP_NUM_THREADS : 1))
echo "OMP_NUM_THREADS=$OMP_NUM_THREADS (${CPUS_PER_NODE} CPUs / ${NUM_GPUS_PER_NODE} GPUs)"
export OMP_NUM_THREADS=1
echo "OMP_NUM_THREADS=$OMP_NUM_THREADS (process-level parallelism via DataLoader workers; ${CPUS_PER_NODE} CPUs / ${NUM_GPUS_PER_NODE} GPUs)"

### [Sync Dependencies]
if [ -z "$CUDA_MAJOR" ]; then
@@ -66,8 +74,8 @@ rm -f "$OUTPUT_DIR/SHUTDOWN"

if [ "${SLURM_NNODES:-1}" -gt 1 ]; then
echo "Running multi-node training..."
head_node=$(scontrol show hostnames $SLURM_NODELIST | head -n1)
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
head_node=$(hostname -s)
head_node_ip=$(hostname --ip-address)
echo "Head node: $head_node"
echo "Head node IP: $head_node_ip"
srun uv run --no-sync torchrun \
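The `OMP_NUM_THREADS=1` comment in the run.sh diff above relies on process-level parallelism being sized elsewhere; a sketch of the implied worker-count arithmetic (the function name is hypothetical, not an API from this repo):

```python
# With OMP_NUM_THREADS=1, each DataLoader worker process uses one CPU,
# so sizing workers as n_cpus // n_gpus keeps a node fully utilized
# without thread oversubscription across the per-GPU trainer processes.
def auto_num_workers(n_cpus: int, n_gpus: int) -> int:
    return max(1, n_cpus // max(1, n_gpus))
```

On a 64-CPU, 8-GPU node this yields 8 workers per rank, i.e. 64 single-threaded worker processes in total.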
62 changes: 19 additions & 43 deletions examples/cfd/external_aerodynamics/globe/airfrans/train.py
@@ -30,8 +30,8 @@
import torch
import torch.nn.functional as F
import torchinfo
from dataset import AirFRANSDataSet, AirFRANSSample, compute_max_mesh_sizes
from jaxtyping import Float, Int
from dataset import AirFRANSDataSet, AirFRANSSample
from jaxtyping import Float
from mlflow.tracking.fluent import (
active_run,
log_artifact,
@@ -71,8 +71,8 @@ def main(
amp: bool = False,
use_compile: bool = True,
compile_mode: Literal[
"default", "max-autotune-no-cudagraphs", "reduce-overhead", "max-autotune"
] = "max-autotune",
"default", "max-autotune-no-cudagraphs"
] = "max-autotune-no-cudagraphs",
points_per_iter: int = 2048,
learning_rate: float = 1e-3,
weight_decay: float = 1e-4,
@@ -87,6 +87,8 @@
n_latent_scalars: int = 12,
n_latent_vectors: int = 6,
n_spherical_harmonics: int = 1,
theta: float = 1.0,
leaf_size: int = 1,
airfrans_task: Literal["full", "scarce", "reynolds", "aoa"] = "full",
use_profiler: bool = True,
make_images: bool = True,
@@ -115,6 +117,8 @@ def main(
n_latent_scalars: Number of scalar latent channels propagated between hyperlayers.
n_latent_vectors: Number of vector latent channels propagated between hyperlayers.
n_spherical_harmonics: Number of Legendre polynomial terms for angle features.
theta: Barnes-Hut opening angle. Larger = more aggressive approximation.
leaf_size: Maximum sources per leaf node in the Barnes-Hut tree.
airfrans_task: Which AirFRANS dataset task to train on.
use_profiler: Enable PyTorch profiler for performance analysis.
make_images: Whether to make images for visualization.
@@ -235,6 +239,8 @@ def main(
n_latent_scalars=n_latent_scalars,
n_latent_vectors=n_latent_vectors,
n_spherical_harmonics=n_spherical_harmonics,
theta=theta,
leaf_size=leaf_size,
).to(device)

if dist.rank == 0:
@@ -269,24 +275,6 @@ def main(
static_graph=True,
)

### [Compute Maximum Mesh Sizes Per BC Type and Split]
max_sizes: dict[
Split,
TensorDict[
str, TensorDict[Literal["n_points", "n_cells"], Int[torch.Tensor, ""]]
],
] = {
split: compute_max_mesh_sizes(
dataloaders[split],
device,
face_downsampling_ratio=(
train_face_downsampling_ratio if split == "train" else 1.0
),
rank=dist.rank,
)
for split in splits
}

### [Optimizer and Scheduler Setup]
# Square-root batch-size scaling: when the effective batch size grows
# (more GPUs or more points), gradient variance decreases proportionally,
@@ -401,7 +389,7 @@ def main(

### [Training and Testing]
@torch.compile(
dynamic=False,
dynamic=True,
mode=compile_mode,
disable=not use_compile,
)
@@ -462,29 +450,17 @@ def run_epoch(split: Split) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
)
sample.boundary_meshes[bc_type] = mesh

### Pad boundary meshes to fixed size for static compilation
split_max_sizes = max_sizes[split]
for bc_type, mesh in sample.boundary_meshes.items():
padded = mesh.pad(
target_n_points=int(split_max_sizes[bc_type, "n_points"]),
target_n_cells=int(split_max_sizes[bc_type, "n_cells"]),
data_padding_value=0.0,
)
### Pre-cache all geometry on the *padded* mesh so that
# the cache structure is fully populated before torch.compile
# ever sees it. Mesh.pad() creates a new Mesh with an empty
# cache, so caching must happen *after* padding. Without
# this, lazy computation during the compiled forward pass
# grows the cache dict, triggering Dynamo guard failures.
### Pre-cache geometry so lazy computation doesn't trigger
# Dynamo guard failures during compiled forward passes.
for mesh in sample.boundary_meshes.values():
if training and train_randomize_face_centers:
padded._cache["cell", "centroids"] = (
padded.sample_random_points_on_cells()
mesh._cache["cell", "centroids"] = (
mesh.sample_random_points_on_cells()
)
else:
_ = padded.cell_centroids
_ = padded.cell_areas
_ = padded.cell_normals
sample.boundary_meshes[bc_type] = padded
_ = mesh.cell_centroids
_ = mesh.cell_areas
_ = mesh.cell_normals

with record_function("data_transfer"):
sample = sample.to(device)
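The optimizer comment in the train.py diff above cites square-root batch-size scaling for the learning rate; a minimal sketch of that rule (base values hypothetical):

```python
import math


def scaled_lr(base_lr: float, base_batch: int, effective_batch: int) -> float:
    # Gradient variance shrinks like 1/B as the effective batch B grows,
    # so the step size can grow like sqrt(B) at roughly constant update noise.
    return base_lr * math.sqrt(effective_batch / base_batch)
```

Quadrupling the effective batch (e.g. 4x the GPUs) thus doubles the learning rate, rather than quadrupling it as linear scaling would.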
12 changes: 10 additions & 2 deletions physicsnemo/experimental/models/globe/__init__.py
@@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from physicsnemo.experimental.models.globe.cluster_tree import (
ClusterTree,
DualInteractionPlan,
SourceAggregates,
)
from physicsnemo.experimental.models.globe.field_kernel import (
ChunkedKernel,
BarnesHutKernel,
Kernel,
MultiscaleKernel,
)
@@ -24,6 +29,9 @@
__all__ = [
"GLOBE",
"Kernel",
"ChunkedKernel",
"BarnesHutKernel",
"MultiscaleKernel",
"ClusterTree",
"DualInteractionPlan",
"SourceAggregates",
]