diff --git a/CHANGELOG.md b/CHANGELOG.md index c1d43cd507..d94941c717 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`) +- Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`), + including new variant that uses a dual tree traversal algorithm to reduce the + complexity of the kernel evaluations from O(N^2) to O(N). - Adds GLOBE AirFRANS example case (`examples/cfd/external_aerodynamics/globe/airfrans`) - Adds concrete dropout uncertainty quantification for GeoTransolver. Learnable per-layer dropout rates enable MC-Dropout inference for uncertainty diff --git a/examples/cfd/external_aerodynamics/globe/airfrans/inference.py b/examples/cfd/external_aerodynamics/globe/airfrans/inference.py index e3ccf400fa..d056395e4d 100644 --- a/examples/cfd/external_aerodynamics/globe/airfrans/inference.py +++ b/examples/cfd/external_aerodynamics/globe/airfrans/inference.py @@ -83,7 +83,7 @@ # %% with torch.no_grad(): model.eval() - pred_mesh = model(**sample.model_input_kwargs, chunk_size=128) + pred_mesh = model(**sample.model_input_kwargs) # %% AirFRANSDataSet.postprocess( diff --git a/examples/cfd/external_aerodynamics/globe/airfrans/run.sh b/examples/cfd/external_aerodynamics/globe/airfrans/run.sh index a6acf4e746..4446acd6eb 100755 --- a/examples/cfd/external_aerodynamics/globe/airfrans/run.sh +++ b/examples/cfd/external_aerodynamics/globe/airfrans/run.sh @@ -14,9 +14,15 @@ set -euo pipefail ### [User Configuration] +OUTPUT_NAME="${SLURM_JOB_NAME:-globe_airfrans_local}" +SCRIPT_DIR="${SLURM_SUBMIT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}" +OUTPUT_DIR="${SCRIPT_DIR}/output/${OUTPUT_NAME}" + TRAIN_ARGS=( - --output-name ${SLURM_JOB_NAME:-globe_airfrans_local} + --output-name "${OUTPUT_NAME}" --airfrans-task "scarce" + --no-use-compile + --amp ) export 
AIRFRANS_DATA_DIR="${HOME}/datasets/airfrans/Dataset" # Set this to your AirFRANS dataset @@ -37,10 +43,12 @@ CUDA_MAJOR=$(sed -n 's/.*CUDA Version: \([0-9]*\).*/\1/p' <<< "$NVIDIA_SMI_OUTPU echo "Number of GPUs per node detected: $NUM_GPUS_PER_NODE" ### [Thread Configuration] +# OMP_NUM_THREADS=1: DataLoader workers use process-level parallelism +# (num_workers auto-computed as n_cpus/n_gpus), so per-process threading +# is unnecessary and causes thread oversubscription. CPUS_PER_NODE=${SLURM_CPUS_ON_NODE:-$(nproc)} -export OMP_NUM_THREADS=$((CPUS_PER_NODE / NUM_GPUS_PER_NODE)) -OMP_NUM_THREADS=$((OMP_NUM_THREADS > 0 ? OMP_NUM_THREADS : 1)) -echo "OMP_NUM_THREADS=$OMP_NUM_THREADS (${CPUS_PER_NODE} CPUs / ${NUM_GPUS_PER_NODE} GPUs)" +export OMP_NUM_THREADS=1 +echo "OMP_NUM_THREADS=$OMP_NUM_THREADS (process-level parallelism via DataLoader workers; ${CPUS_PER_NODE} CPUs / ${NUM_GPUS_PER_NODE} GPUs)" ### [Sync Dependencies] if [ -z "$CUDA_MAJOR" ]; then @@ -66,8 +74,8 @@ rm -f "$OUTPUT_DIR/SHUTDOWN" if [ "${SLURM_NNODES:-1}" -gt 1 ]; then echo "Running multi-node training..." 
- head_node=$(scontrol show hostnames $SLURM_NODELIST | head -n1) - head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + head_node=$(hostname -s) + head_node_ip=$(hostname --ip-address) echo "Head node: $head_node" echo "Head node IP: $head_node_ip" srun uv run --no-sync torchrun \ diff --git a/examples/cfd/external_aerodynamics/globe/airfrans/train.py b/examples/cfd/external_aerodynamics/globe/airfrans/train.py index 6c15a90784..c3b13b70f3 100644 --- a/examples/cfd/external_aerodynamics/globe/airfrans/train.py +++ b/examples/cfd/external_aerodynamics/globe/airfrans/train.py @@ -30,8 +30,8 @@ import torch import torch.nn.functional as F import torchinfo -from dataset import AirFRANSDataSet, AirFRANSSample, compute_max_mesh_sizes -from jaxtyping import Float, Int +from dataset import AirFRANSDataSet, AirFRANSSample +from jaxtyping import Float from mlflow.tracking.fluent import ( active_run, log_artifact, @@ -71,8 +71,8 @@ def main( amp: bool = False, use_compile: bool = True, compile_mode: Literal[ - "default", "max-autotune-no-cudagraphs", "reduce-overhead", "max-autotune" - ] = "max-autotune", + "default", "max-autotune-no-cudagraphs" + ] = "max-autotune-no-cudagraphs", points_per_iter: int = 2048, learning_rate: float = 1e-3, weight_decay: float = 1e-4, @@ -87,6 +87,8 @@ def main( n_latent_scalars: int = 12, n_latent_vectors: int = 6, n_spherical_harmonics: int = 1, + theta: float = 1.0, + leaf_size: int = 1, airfrans_task: Literal["full", "scarce", "reynolds", "aoa"] = "full", use_profiler: bool = True, make_images: bool = True, @@ -115,6 +117,8 @@ def main( n_latent_scalars: Number of scalar latent channels propagated between hyperlayers. n_latent_vectors: Number of vector latent channels propagated between hyperlayers. n_spherical_harmonics: Number of Legendre polynomial terms for angle features. + theta: Barnes-Hut opening angle. Larger = more aggressive approximation. 
+ leaf_size: Maximum sources per leaf node in the Barnes-Hut tree. airfrans_task: Which AirFRANS dataset task to train on. use_profiler: Enable PyTorch profiler for performance analysis. make_images: Whether to make images for visualization. @@ -235,6 +239,8 @@ def main( n_latent_scalars=n_latent_scalars, n_latent_vectors=n_latent_vectors, n_spherical_harmonics=n_spherical_harmonics, + theta=theta, + leaf_size=leaf_size, ).to(device) if dist.rank == 0: @@ -269,24 +275,6 @@ def main( static_graph=True, ) - ### [Compute Maximum Mesh Sizes Per BC Type and Split] - max_sizes: dict[ - Split, - TensorDict[ - str, TensorDict[Literal["n_points", "n_cells"], Int[torch.Tensor, ""]] - ], - ] = { - split: compute_max_mesh_sizes( - dataloaders[split], - device, - face_downsampling_ratio=( - train_face_downsampling_ratio if split == "train" else 1.0 - ), - rank=dist.rank, - ) - for split in splits - } - ### [Optimizer and Scheduler Setup] # Square-root batch-size scaling: when the effective batch size grows # (more GPUs or more points), gradient variance decreases proportionally, @@ -401,7 +389,7 @@ def main( ### [Training and Testing] @torch.compile( - dynamic=False, + dynamic=True, mode=compile_mode, disable=not use_compile, ) @@ -462,29 +450,17 @@ def run_epoch(split: Split) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: ) sample.boundary_meshes[bc_type] = mesh - ### Pad boundary meshes to fixed size for static compilation - split_max_sizes = max_sizes[split] - for bc_type, mesh in sample.boundary_meshes.items(): - padded = mesh.pad( - target_n_points=int(split_max_sizes[bc_type, "n_points"]), - target_n_cells=int(split_max_sizes[bc_type, "n_cells"]), - data_padding_value=0.0, - ) - ### Pre-cache all geometry on the *padded* mesh so that - # the cache structure is fully populated before torch.compile - # ever sees it. Mesh.pad() creates a new Mesh with an empty - # cache, so caching must happen *after* padding. 
Without - # this, lazy computation during the compiled forward pass - # grows the cache dict, triggering Dynamo guard failures. + ### Pre-cache geometry so lazy computation doesn't trigger + # Dynamo guard failures during compiled forward passes. + for mesh in sample.boundary_meshes.values(): if training and train_randomize_face_centers: - padded._cache["cell", "centroids"] = ( - padded.sample_random_points_on_cells() + mesh._cache["cell", "centroids"] = ( + mesh.sample_random_points_on_cells() ) else: - _ = padded.cell_centroids - _ = padded.cell_areas - _ = padded.cell_normals - sample.boundary_meshes[bc_type] = padded + _ = mesh.cell_centroids + _ = mesh.cell_areas + _ = mesh.cell_normals with record_function("data_transfer"): sample = sample.to(device) @@ -650,7 +626,6 @@ def checkpoint_metadata() -> dict[str, Any]: base_model.eval() pred_mesh = base_model( **viz_sample.model_input_kwargs, - chunk_size=points_per_iter, ) AirFRANSDataSet.postprocess( pred_mesh=pred_mesh.to(device="cpu"), diff --git a/physicsnemo/experimental/models/globe/__init__.py b/physicsnemo/experimental/models/globe/__init__.py index ab68ad5787..7067ce74e1 100644 --- a/physicsnemo/experimental/models/globe/__init__.py +++ b/physicsnemo/experimental/models/globe/__init__.py @@ -14,8 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from physicsnemo.experimental.models.globe.cluster_tree import ( + ClusterTree, + DualInteractionPlan, + SourceAggregates, +) from physicsnemo.experimental.models.globe.field_kernel import ( - ChunkedKernel, + BarnesHutKernel, Kernel, MultiscaleKernel, ) @@ -24,6 +29,9 @@ __all__ = [ "GLOBE", "Kernel", - "ChunkedKernel", + "BarnesHutKernel", "MultiscaleKernel", + "ClusterTree", + "DualInteractionPlan", + "SourceAggregates", ] diff --git a/physicsnemo/experimental/models/globe/cluster_tree.py b/physicsnemo/experimental/models/globe/cluster_tree.py new file mode 100644 index 0000000000..7b0637a7a7 --- /dev/null +++ b/physicsnemo/experimental/models/globe/cluster_tree.py @@ -0,0 +1,1388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spatial cluster tree for dual-tree Barnes-Hut acceleration of GLOBE kernels. + +This module provides a GPU-compatible hierarchical spatial decomposition over a +set of points, designed for dual-tree Barnes-Hut O(N) kernel acceleration. +Trees are built over both source and target points. The dual-tree traversal +classifies (target_node, source_node) pairs as near-field or far-field: + +- **Near-field**: both nodes are leaves and nearby - expand to individual + (target, source) pairs for exact kernel evaluation. 
+- **Far-field**: nodes are well-separated - evaluate the kernel ONCE at the + node centroids and broadcast the result to all targets in the target node. + +This reduces far-field kernel evaluations from O(N log N) (single-tree) to +O(N) (dual-tree), which is critical at large mesh scales (800k+ faces). + +Construction uses the same morton-code-based Linear BVH (LBVH) algorithm as +:mod:`physicsnemo.mesh.spatial.bvh` (morton sort, midpoint splits, bottom-up +AABB propagation), but the resulting data structure differs: ClusterTree stores +additional per-node fields (diameter, subtree ranges, area-weighted aggregates) +needed for the Barnes-Hut opening criterion, dual-tree traversal, and +far-field monopole approximation. The two classes share +:func:`~physicsnemo.mesh.spatial.bvh._compute_morton_codes` and +:func:`~physicsnemo.mesh.spatial._ragged._ragged_arange` but are otherwise +independent. +""" + +import logging + +import torch +from jaxtyping import Float, Int +from tensordict import TensorDict, tensorclass +from torch.profiler import record_function + +from physicsnemo.mesh.spatial._ragged import _ragged_arange +from physicsnemo.mesh.spatial.bvh import _compute_morton_codes + +logger = logging.getLogger("globe.cluster_tree") + + +# --------------------------------------------------------------------------- +# InteractionPlan: the output of tree traversal +# --------------------------------------------------------------------------- + + +@tensorclass +class DualInteractionPlan: + r"""Result of a dual-tree Barnes-Hut traversal: four categories of + interactions that together cover all source contributions for every + target point. + + **(near, near)**: ``(near_target_ids[i], near_source_ids[i])`` are + individual target-source pairs requiring exact kernel evaluation. 
+ + **(far, far)**: ``(far_target_node_ids[i], far_source_node_ids[i])`` + are node-to-node pairs where the kernel is evaluated ONCE at the + node centroids and the result is broadcast to all individual targets + in the target node. + + **(near, far)**: ``(nf_target_ids[i], nf_source_node_ids[i])`` are + individual target points paired with source nodes. The kernel is + evaluated at ``(target_point, source_centroid)`` using the source + node's monopole approximation. No target-side broadcast. + + **(far, near)**: ``(fn_target_node_ids[i], fn_source_ids[i])`` are + target nodes paired with individual source points. The kernel is + evaluated at ``(target_centroid, source_point)`` using exact source + data, then broadcast to stage-1 survivor targets via the + ``fn_broadcast_*`` mapping. + + All index tensors are ``int64`` on the same device as the tree. + """ + + near_target_ids: Int[torch.Tensor, " n_near"] + near_source_ids: Int[torch.Tensor, " n_near"] + far_target_node_ids: Int[torch.Tensor, " n_far_nodes"] + far_source_node_ids: Int[torch.Tensor, " n_far_nodes"] + nf_target_ids: Int[torch.Tensor, " n_nf"] + nf_source_node_ids: Int[torch.Tensor, " n_nf"] + fn_target_node_ids: Int[torch.Tensor, " n_fn"] + fn_source_ids: Int[torch.Tensor, " n_fn"] + fn_broadcast_targets: Int[torch.Tensor, " n_fn_bcast"] + fn_broadcast_starts: Int[torch.Tensor, " n_fn"] + fn_broadcast_counts: Int[torch.Tensor, " n_fn"] + + @property + def n_near(self) -> int: + """Number of (near,near) exact individual interaction pairs.""" + return self.near_target_ids.shape[0] + + @property + def n_far_nodes(self) -> int: + """Number of (far,far) node-to-node pairs (each = one kernel eval).""" + return self.far_target_node_ids.shape[0] + + @property + def n_nf(self) -> int: + """Number of (near,far) target-point-to-source-node pairs.""" + return self.nf_target_ids.shape[0] + + @property + def n_fn(self) -> int: + """Number of (far,near) target-node-to-source-point pairs.""" + return 
self.fn_target_node_ids.shape[0] + + def validate(self) -> None: + """Check internal consistency of the interaction plan. + + Verifies shape pairing, non-negativity, and fn_broadcast bounds. + Raises ``ValueError`` on any inconsistency. Intended to be called + behind a ``not torch.compiler.is_compiling()`` guard so it is + zero-cost under ``torch.compile``. + + Raises + ------ + ValueError + If any internal consistency check fails. + """ + ### Shape pairing: matched tensor pairs must have identical lengths + pairs: list[tuple[str, torch.Tensor, str, torch.Tensor]] = [ + ("near_target_ids", self.near_target_ids, + "near_source_ids", self.near_source_ids), + ("far_target_node_ids", self.far_target_node_ids, + "far_source_node_ids", self.far_source_node_ids), + ("nf_target_ids", self.nf_target_ids, + "nf_source_node_ids", self.nf_source_node_ids), + ("fn_target_node_ids", self.fn_target_node_ids, + "fn_source_ids", self.fn_source_ids), + ] + for name_a, a, name_b, b in pairs: + if a.shape != b.shape: + raise ValueError( + f"Shape mismatch: {name_a}.shape={a.shape!r} != " + f"{name_b}.shape={b.shape!r}" + ) + + ### fn_broadcast tensors must be consistently sized + n_fn = self.fn_source_ids.shape[0] + for name, tensor in [ + ("fn_broadcast_starts", self.fn_broadcast_starts), + ("fn_broadcast_counts", self.fn_broadcast_counts), + ]: + if tensor.shape != (n_fn,): + raise ValueError( + f"{name}.shape={tensor.shape!r}, expected ({n_fn},)" + ) + + ### Non-negativity + for name, tensor in [ + ("fn_broadcast_starts", self.fn_broadcast_starts), + ("fn_broadcast_counts", self.fn_broadcast_counts), + ]: + if tensor.numel() > 0 and (tensor < 0).any(): + raise ValueError(f"{name} contains negative values") + + ### fn_broadcast bounds: every (start, count) range with count > 0 + ### must fit within fn_broadcast_targets. Zero-count entries are + ### no-ops whose starts are never dereferenced. 
+ if n_fn > 0: + nonzero = self.fn_broadcast_counts > 0 + if nonzero.any(): + ends = self.fn_broadcast_starts[nonzero] + self.fn_broadcast_counts[nonzero] + max_end = ends.max().item() + bcast_len = self.fn_broadcast_targets.shape[0] + if max_end > bcast_len: + raise ValueError( + f"fn_broadcast out of bounds: max(starts + counts)=" + f"{max_end} > fn_broadcast_targets.shape[0]={bcast_len}" + ) + + +# --------------------------------------------------------------------------- +# Segmented reduction helpers +# --------------------------------------------------------------------------- + + +def _segmented_weighted_sum( + values: Float[torch.Tensor, "n *features"], + weights: Float[torch.Tensor, " n"], + seg_ids: Int[torch.Tensor, " n"], + n_segments: int, +) -> Float[torch.Tensor, "n_segments *features"]: + """Compute weighted sum per segment via scatter_add. + + Parameters + ---------- + values : torch.Tensor + Values to aggregate, shape ``(N,)`` or ``(N, F)``. + weights : torch.Tensor + Per-element weights, shape ``(N,)``. + seg_ids : torch.Tensor + Segment assignment for each element, shape ``(N,)``, int64. + n_segments : int + Total number of output segments. + + Returns + ------- + torch.Tensor + Weighted sums, shape ``(n_segments,)`` or ``(n_segments, F)``. 
+ """ + weighted = values * (weights.unsqueeze(-1) if values.ndim > 1 else weights) + out = torch.zeros( + (n_segments,) + values.shape[1:], + dtype=values.dtype, + device=values.device, + ) + idx = seg_ids.unsqueeze(-1).expand_as(weighted) if weighted.ndim > 1 else seg_ids + out.scatter_add_(0, idx, weighted) + return out + + +def _expand_dual_leaf_hits( + target_leaf_ids: Int[torch.Tensor, " n_leaf_pairs"], + source_leaf_ids: Int[torch.Tensor, " n_leaf_pairs"], + target_tree: "ClusterTree", + source_tree: "ClusterTree", + theta: float, +) -> tuple[ + Int[torch.Tensor, " n_near"], Int[torch.Tensor, " n_near"], + Int[torch.Tensor, " n_nf"], Int[torch.Tensor, " n_nf"], + Int[torch.Tensor, " n_fn"], Int[torch.Tensor, " n_fn"], + Int[torch.Tensor, " n_fn_bcast"], + Int[torch.Tensor, " n_fn"], Int[torch.Tensor, " n_fn"], +]: + """Expand ``(target_leaf, source_leaf)`` pairs with two-stage filtering. + + Applies two sequential per-point tests to classify each (target, source) + interaction within a leaf pair: + + **Stage 1 (per-target)**: Test each target against the source leaf AABB. + Targets that pass become **(near, far)** - they use the source monopole. + Targets that fail are "survivors" and proceed to stage 2. + + **Stage 2 (per-source)**: Test each source against the target leaf AABB. + Sources that pass become **(far, near)** - evaluated at the target + centroid and broadcast to all survivors. Sources that fail produce + **(near, near)** Cartesian product pairs with the survivors. + + The two stages are independent (different AABBs) and sequential (stage 2 + only applies to survivors), so no (target, source) pair is double-counted. + + Returns + ------- + near_target_ids, near_source_ids : torch.Tensor + (near, near) individual target-source pairs. + nf_target_ids, nf_source_node_ids : torch.Tensor + (near, far) individual target to source-node pairs. + fn_target_node_ids, fn_source_ids : torch.Tensor + (far, near) target-node to individual source pairs. 
+ fn_broadcast_targets : torch.Tensor + Survivor target IDs sorted by leaf pair, for (far, near) broadcast. + fn_broadcast_starts, fn_broadcast_counts : torch.Tensor + Per-fn-pair offset/count into ``fn_broadcast_targets``. + """ + device = target_leaf_ids.device + theta_sq = theta * theta + n_pairs = target_leaf_ids.shape[0] + + def _empty_result(): + e = torch.empty(0, dtype=torch.long, device=device) + return e, e.clone(), e.clone(), e.clone(), e.clone(), e.clone(), e.clone(), e.clone(), e.clone() + + if n_pairs == 0: + return _empty_result() + + t_starts = target_tree.leaf_start[target_leaf_ids] + t_counts = target_tree.leaf_count[target_leaf_ids] + s_starts = source_tree.leaf_start[source_leaf_ids] + s_counts = source_tree.leaf_count[source_leaf_ids] + + # ================================================================== + # Stage 1: per-target test against source leaf AABBs + # ================================================================== + positions_t, leaf_pair_ids_t = _ragged_arange(t_starts, t_counts) + target_point_ids = target_tree.sorted_source_order[positions_t] + target_pts = target_tree.source_points[target_point_ids] + + src_leaf_per_target = source_leaf_ids[leaf_pair_ids_t] + clamped_t = torch.clamp( + target_pts, + min=source_tree.node_aabb_min[src_leaf_per_target], + max=source_tree.node_aabb_max[src_leaf_per_target], + ) + dist_sq_t = (target_pts - clamped_t).pow(2).sum(dim=-1) + target_is_far = dist_sq_t * theta_sq > source_tree.node_diameter_sq[src_leaf_per_target] + + ### (near, far) output + nf_target_ids = target_point_ids[target_is_far] + nf_source_node_ids = src_leaf_per_target[target_is_far] + + ### Survivors: targets that failed the per-target test + surv_mask = ~target_is_far + if not surv_mask.any(): + e = torch.empty(0, dtype=torch.long, device=device) + return e, e.clone(), nf_target_ids, nf_source_node_ids, e.clone(), e.clone(), e.clone(), e.clone(), e.clone() + + surv_point_ids = target_point_ids[surv_mask] + surv_lp_ids = 
leaf_pair_ids_t[surv_mask] + + # ================================================================== + # Stage 2: per-source test against target leaf AABBs + # ================================================================== + positions_s, leaf_pair_ids_s = _ragged_arange(s_starts, s_counts) + src_point_ids = source_tree.sorted_source_order[positions_s] + src_pts = source_tree.source_points[src_point_ids] + + tgt_leaf_per_src = target_leaf_ids[leaf_pair_ids_s] + clamped_s = torch.clamp( + src_pts, + min=target_tree.node_aabb_min[tgt_leaf_per_src], + max=target_tree.node_aabb_max[tgt_leaf_per_src], + ) + dist_sq_s = (src_pts - clamped_s).pow(2).sum(dim=-1) + source_is_far = dist_sq_s * theta_sq > target_tree.node_diameter_sq[tgt_leaf_per_src] + + ### (far, near) output: source points far from the target leaf + fn_source_ids = src_point_ids[source_is_far] + fn_target_node_ids = tgt_leaf_per_src[source_is_far] + fn_lp_ids = leaf_pair_ids_s[source_is_far] + + # ================================================================== + # Build (far, near) broadcast mapping + # ================================================================== + # Group survivors by leaf pair so each fn source can look up its + # broadcast targets (all survivors from the same leaf pair). + # Only include survivors from leaf pairs that have fn sources; + # survivors from all-close leaf pairs are not referenced by any + # fn_broadcast_starts/counts entry. 
+ has_fn_source = torch.zeros(n_pairs, dtype=torch.bool, device=device) + if fn_lp_ids.numel() > 0: + has_fn_source[fn_lp_ids] = True + fn_active_mask = has_fn_source[surv_lp_ids] + + active_surv_ids = surv_point_ids[fn_active_mask] + active_surv_lp_ids = surv_lp_ids[fn_active_mask] + + surv_sort = active_surv_lp_ids.argsort(stable=True) + fn_broadcast_targets = active_surv_ids[surv_sort] + + surv_counts_per_lp = torch.bincount(active_surv_lp_ids, minlength=n_pairs) + surv_starts_per_lp = surv_counts_per_lp.cumsum(0) - surv_counts_per_lp + + fn_broadcast_starts = surv_starts_per_lp[fn_lp_ids] + fn_broadcast_counts = surv_counts_per_lp[fn_lp_ids] + + # ================================================================== + # Reduced Cartesian product: survivors × close sources only + # ================================================================== + close_mask = ~source_is_far + close_src_ids = src_point_ids[close_mask] + close_lp_ids = leaf_pair_ids_s[close_mask] + + if close_src_ids.numel() == 0 or surv_point_ids.numel() == 0: + e = torch.empty(0, dtype=torch.long, device=device) + return ( + e, e.clone(), + nf_target_ids, nf_source_node_ids, + fn_target_node_ids, fn_source_ids, + fn_broadcast_targets, fn_broadcast_starts, fn_broadcast_counts, + ) + + ### Group close sources by leaf pair for contiguous access + close_sort = close_lp_ids.argsort(stable=True) + sorted_close_srcs = close_src_ids[close_sort] + close_counts_per_lp = torch.bincount(close_lp_ids, minlength=n_pairs) + close_starts_per_lp = close_counts_per_lp.cumsum(0) - close_counts_per_lp + + ### Each survivor expands against its leaf pair's close sources + per_surv_close_counts = close_counts_per_lp[surv_lp_ids] + total_nn = int(per_surv_close_counts.sum()) + + if total_nn == 0: + e = torch.empty(0, dtype=torch.long, device=device) + return ( + e, e.clone(), + nf_target_ids, nf_source_node_ids, + fn_target_node_ids, fn_source_ids, + fn_broadcast_targets, fn_broadcast_starts, fn_broadcast_counts, + ) + 
+ expanded_near_tgts = torch.repeat_interleave(surv_point_ids, per_surv_close_counts) + per_surv_close_starts = close_starts_per_lp[surv_lp_ids] + src_positions_nn, _ = _ragged_arange(per_surv_close_starts, per_surv_close_counts) + expanded_near_srcs = sorted_close_srcs[src_positions_nn] + + return ( + expanded_near_tgts, expanded_near_srcs, + nf_target_ids, nf_source_node_ids, + fn_target_node_ids, fn_source_ids, + fn_broadcast_targets, fn_broadcast_starts, fn_broadcast_counts, + ) + + +# --------------------------------------------------------------------------- +# ClusterTree tensorclass +# --------------------------------------------------------------------------- + + +@tensorclass +class ClusterTree: + r"""Hierarchical spatial decomposition for Barnes-Hut kernel acceleration. + + Stores a binary radix tree over source points as flat GPU-compatible tensors. + The tree structure (positions, AABBs, children) is precomputable per mesh + geometry. Per-node source-data aggregates are recomputed whenever the source + features change (e.g., between communication hyperlayers). + + The tree supports both boundary face centroids and prediction point clouds + (same construction algorithm, same data structure). + + Attributes + ---------- + node_aabb_min : torch.Tensor + AABB minimum corner per node, shape ``(n_nodes, D)``. + node_aabb_max : torch.Tensor + AABB maximum corner per node, shape ``(n_nodes, D)``. + node_diameter_sq : torch.Tensor + Squared AABB diagonal per node, shape ``(n_nodes,)``. + node_left_child : torch.Tensor + Left child index per node, ``-1`` for leaves, shape ``(n_nodes,)``. + node_right_child : torch.Tensor + Right child index per node, ``-1`` for leaves, shape ``(n_nodes,)``. + leaf_start : torch.Tensor + Start offset into ``sorted_source_order`` for leaf nodes, + ``-1`` for internal nodes, shape ``(n_nodes,)``. + leaf_count : torch.Tensor + Number of sources in each leaf node, ``0`` for internal nodes, + shape ``(n_nodes,)``. 
+ node_range_start : torch.Tensor + Start offset into ``sorted_source_order`` for ALL nodes (both + leaf and internal), shape ``(n_nodes,)``. Each node's subtree + covers a contiguous range in morton-sorted order. + node_range_count : torch.Tensor + Number of points in each node's subtree, shape ``(n_nodes,)``. + For leaves this equals ``leaf_count``; for internal nodes it + equals the sum of children's range counts. + node_total_area : torch.Tensor + Total source area in each node's subtree, shape ``(n_nodes,)``. + sorted_source_order : torch.Tensor + Morton-code-sorted permutation of source indices, + shape ``(n_sources,)``. + source_points : torch.Tensor + Original source point coordinates, shape ``(n_sources, D)``. + max_depth : torch.Tensor + Scalar tensor storing the tree depth (for fixed-iteration traversal). + leaf_node_ids : torch.Tensor + Indices of leaf nodes, shape ``(n_leaves,)``. Precomputed during + tree construction so ``compute_source_aggregates`` avoids a + data-dependent ``torch.where`` that would break ``torch.compile``. + leaf_seg_ids : torch.Tensor + Per-source compact leaf segment ID in sorted order, shape + ``(n_sources,)``. Maps each source to the index of its + containing leaf within ``leaf_node_ids``, used for segmented + reductions in ``compute_source_aggregates``. + """ + + node_aabb_min: torch.Tensor + node_aabb_max: torch.Tensor + node_diameter_sq: torch.Tensor + node_left_child: torch.Tensor + node_right_child: torch.Tensor + leaf_start: torch.Tensor + leaf_count: torch.Tensor + node_range_start: torch.Tensor + node_range_count: torch.Tensor + node_total_area: torch.Tensor + sorted_source_order: torch.Tensor + source_points: torch.Tensor + max_depth: torch.Tensor + internal_level_ids: torch.Tensor + internal_level_offsets: torch.Tensor + # internal_level_ids and internal_level_offsets store the tree's + # internal node IDs in CSR-packed level order (shallowest first). 
+ # Computed once during from_points() and reused by all bottom-up + # propagation routines (_propagate_centroids_bottom_up, + # _compute_node_strengths) to avoid recomputing the BFS traversal + # that discovers this ordering. Stored as tensors (not a Python + # list) so they participate in tensorclass .to(device) moves. + leaf_node_ids: torch.Tensor + leaf_seg_ids: torch.Tensor + + @property + def n_nodes(self) -> int: + """Number of nodes in the tree.""" + return self.node_aabb_min.shape[0] + + @property + def n_sources(self) -> int: + """Number of source points.""" + return self.sorted_source_order.shape[0] + + @property + def n_spatial_dims(self) -> int: + """Spatial dimensionality.""" + return self.node_aabb_min.shape[1] + + @property + def n_leaves(self) -> int: + """Number of leaf nodes in the tree.""" + return self.leaf_node_ids.shape[0] + + @property + def internal_nodes_per_level(self) -> list[torch.Tensor]: + """Internal node IDs grouped by tree depth, shallowest first. + + Reconstructed from CSR-packed ``internal_level_ids`` and + ``internal_level_offsets`` tensors that are computed once during + tree construction in :meth:`from_points`. + """ + offsets = self.internal_level_offsets + return [ + self.internal_level_ids[offsets[i] : offsets[i + 1]] + for i in range(len(offsets) - 1) + ] + + @classmethod + def from_points( + cls, + points: Float[torch.Tensor, "n_points n_dims"], + *, + leaf_size: int = 1, + areas: Float[torch.Tensor, " n_points"] | None = None, + ) -> "ClusterTree": + r"""Build a cluster tree from a set of points via morton-code LBVH. + + Parameters + ---------- + points : Float[torch.Tensor, "n_points n_dims"] + Source point coordinates, shape :math:`(N, D)`. + leaf_size : int + Maximum sources per leaf node. Larger values produce shallower + trees (fewer traversal iterations) at the cost of more exact + near-field interactions per leaf hit. 
+ areas : Float[torch.Tensor, "n_points"] or None + Per-source area weights used for aggregate computation. If + ``None``, all areas default to 1. + + Returns + ------- + ClusterTree + Constructed tree ready for traversal and aggregate computation. + """ + if leaf_size < 1: + raise ValueError(f"leaf_size must be >= 1, got {leaf_size=!r}") + + n_points = points.shape[0] + D = points.shape[1] + device = points.device + dtype = points.dtype + + if areas is None: + areas = torch.ones(n_points, device=device, dtype=dtype) + + ### Handle empty point set + if n_points == 0: + empty_long = torch.empty(0, dtype=torch.long, device=device) + return cls( + node_aabb_min=torch.empty((0, D), dtype=dtype, device=device), + node_aabb_max=torch.empty((0, D), dtype=dtype, device=device), + node_diameter_sq=torch.empty(0, dtype=dtype, device=device), + node_left_child=empty_long, + node_right_child=empty_long, + leaf_start=empty_long, + leaf_count=empty_long, + node_range_start=empty_long, + node_range_count=empty_long, + node_total_area=torch.empty(0, dtype=dtype, device=device), + sorted_source_order=empty_long, + source_points=points, + max_depth=torch.tensor(0, dtype=torch.long, device=device), + internal_level_ids=empty_long, + internal_level_offsets=torch.tensor([0], dtype=torch.long, device=device), + leaf_node_ids=empty_long, + leaf_seg_ids=empty_long, + batch_size=torch.Size([]), + ) + + ### Sort points by morton code for spatial coherence + with record_function("cluster_tree::morton_sort"): + morton_codes = _compute_morton_codes(points) + sorted_order = morton_codes.argsort(stable=True) # (n_points,) + sorted_points = points[sorted_order] # (n_points, D) + sorted_areas = areas[sorted_order] # (n_points,) + + ### Pre-allocate node storage. + # The midpoint split guarantees each child gets at least + # floor(parent_size / 2) sources, so the minimum leaf occupancy + # is ceil(leaf_size / 2). 
From that we bound the maximum number + # of leaves and apply the full-binary-tree identity (n_internal = + # n_leaves - 1) to get max_nodes. + min_per_leaf = max(1, (leaf_size + 1) // 2) + max_leaves = (n_points + min_per_leaf - 1) // min_per_leaf + max_nodes = max(1, 2 * max_leaves - 1) + + aabb_min_buf = torch.full( + (max_nodes, D), float("inf"), dtype=dtype, device=device + ) + aabb_max_buf = torch.full( + (max_nodes, D), float("-inf"), dtype=dtype, device=device + ) + left_child = torch.full((max_nodes,), -1, dtype=torch.long, device=device) + right_child = torch.full((max_nodes,), -1, dtype=torch.long, device=device) + leaf_start_buf = torch.full((max_nodes,), -1, dtype=torch.long, device=device) + leaf_count_buf = torch.zeros(max_nodes, dtype=torch.long, device=device) + range_start_buf = torch.zeros(max_nodes, dtype=torch.long, device=device) + range_count_buf = torch.zeros(max_nodes, dtype=torch.long, device=device) + total_area_buf = torch.zeros(max_nodes, dtype=dtype, device=device) + + # ----------------------------------------------------------- + # Phase 1: Top-down LBVH construction (O(log N) iterations) + # ----------------------------------------------------------- + with record_function("cluster_tree::top_down_build"): + seg_starts = torch.tensor([0], dtype=torch.long, device=device) + seg_ends = torch.tensor([n_points], dtype=torch.long, device=device) + seg_node_ids = torch.tensor([0], dtype=torch.long, device=device) + node_count = 1 + actual_depth = 0 + + internal_nodes_per_level: list[torch.Tensor] = [] + + while len(seg_starts) > 0: + seg_sizes = seg_ends - seg_starts + + ### Store the sorted-order range for ALL nodes at this level. + # Each node covers a contiguous range [seg_start, seg_end) + # in the morton-sorted order. Used by dual-tree traversal + # to expand node-level results to individual points. 
+ range_start_buf[seg_node_ids] = seg_starts + range_count_buf[seg_node_ids] = seg_sizes + + ### Classify segments as leaf or internal + is_leaf_seg = seg_sizes <= leaf_size + is_internal_seg = ~is_leaf_seg + + ### Process leaf segments + leaf_indices = torch.where(is_leaf_seg)[0] + if len(leaf_indices) > 0: + leaf_nids = seg_node_ids[leaf_indices] + l_starts = seg_starts[leaf_indices] + l_sizes = seg_sizes[leaf_indices] + + leaf_start_buf[leaf_nids] = l_starts + leaf_count_buf[leaf_nids] = l_sizes + + # Compute leaf AABBs via segmented reduction + _fill_leaf_aabbs( + leaf_nids, + l_starts, + l_sizes, + sorted_points, + aabb_min_buf, + aabb_max_buf, + ) + + # Compute leaf total areas + _fill_leaf_total_areas( + leaf_nids, l_starts, l_sizes, sorted_areas, total_area_buf + ) + + ### Process internal segments: split at the midpoint of the + # morton-sorted range. Because morton codes preserve spatial + # locality, this approximates a spatial median split and produces + # a balanced binary tree in O(log N) iterations. 
+ internal_indices = torch.where(is_internal_seg)[0] + if len(internal_indices) == 0: + break + + actual_depth += 1 + int_starts = seg_starts[internal_indices] + int_ends = seg_ends[internal_indices] + int_sizes = seg_sizes[internal_indices] + int_node_ids = seg_node_ids[internal_indices] + + midpoints = int_starts + int_sizes // 2 + + n_internal = len(internal_indices) + left_ids = ( + node_count + + torch.arange(n_internal, dtype=torch.long, device=device) * 2 + ) + right_ids = left_ids + 1 + node_count += 2 * n_internal + + left_child[int_node_ids] = left_ids + right_child[int_node_ids] = right_ids + internal_nodes_per_level.append(int_node_ids) + + seg_starts = torch.cat([int_starts, midpoints]) + seg_ends = torch.cat([midpoints, int_ends]) + seg_node_ids = torch.cat([left_ids, right_ids]) + + # ----------------------------------------------------------- + # Phase 2: Bottom-up AABB and area propagation + # ----------------------------------------------------------- + with record_function("cluster_tree::bottom_up_aabb"): + for level_node_ids in reversed(internal_nodes_per_level): + left = left_child[level_node_ids] + right = right_child[level_node_ids] + aabb_min_buf[level_node_ids] = torch.minimum( + aabb_min_buf[left], aabb_min_buf[right] + ) + aabb_max_buf[level_node_ids] = torch.maximum( + aabb_max_buf[left], aabb_max_buf[right] + ) + total_area_buf[level_node_ids] = ( + total_area_buf[left] + total_area_buf[right] + ) + + ### Compute squared AABB diagonals + aabb_min_trimmed = aabb_min_buf[:node_count] + aabb_max_trimmed = aabb_max_buf[:node_count] + diameter_sq = (aabb_max_trimmed - aabb_min_trimmed).pow(2).sum(dim=-1) + + ### Precompute leaf indices and per-source segment IDs so that + ### compute_source_aggregates() avoids a data-dependent + ### torch.where() that would break torch.compile tracing. 
+ leaf_count_trimmed = leaf_count_buf[:node_count] + _leaf_node_ids = torch.where(leaf_count_trimmed > 0)[0] + _leaf_starts = leaf_start_buf[_leaf_node_ids] + _leaf_counts = leaf_count_trimmed[_leaf_node_ids] + _positions, _compact_ids = _ragged_arange( + _leaf_starts, _leaf_counts, total=n_points, + ) + _leaf_seg_ids = torch.zeros(n_points, dtype=torch.long, device=device) + _leaf_seg_ids[_positions] = _compact_ids + + logger.debug( + "ClusterTree: %d points -> %d nodes (%d leaves), " + "depth %d, leaf_size=%d", + n_points, node_count, _leaf_node_ids.shape[0], actual_depth, + leaf_size, + ) + + ### Pack the per-level internal node IDs into CSR tensors so they + ### survive as tensorclass attributes (device-safe, no BFS needed later). + _level_ids = ( + torch.cat(internal_nodes_per_level) + if internal_nodes_per_level + else torch.empty(0, dtype=torch.long, device=device) + ) + _level_lengths = torch.tensor( + [len(t) for t in internal_nodes_per_level], + dtype=torch.long, + device=device, + ) + _level_offsets = torch.cat([ + torch.zeros(1, dtype=torch.long, device=device), + _level_lengths.cumsum(0), + ]) + + return cls( + node_aabb_min=aabb_min_trimmed, + node_aabb_max=aabb_max_trimmed, + node_diameter_sq=diameter_sq, + node_left_child=left_child[:node_count], + node_right_child=right_child[:node_count], + leaf_start=leaf_start_buf[:node_count], + leaf_count=leaf_count_trimmed, + node_range_start=range_start_buf[:node_count], + node_range_count=range_count_buf[:node_count], + node_total_area=total_area_buf[:node_count], + sorted_source_order=sorted_order, + source_points=points, + max_depth=torch.tensor(actual_depth, dtype=torch.long, device=device), + internal_level_ids=_level_ids, + internal_level_offsets=_level_offsets, + leaf_node_ids=_leaf_node_ids, + leaf_seg_ids=_leaf_seg_ids, + batch_size=torch.Size([]), + ) + + def compute_source_aggregates( + self, + source_points: Float[torch.Tensor, "n_sources n_dims"], + areas: Float[torch.Tensor, " n_sources"], + 
source_data: TensorDict | None = None, + ) -> "SourceAggregates": + r"""Compute per-node aggregate source data for far-field approximation. + + Aggregates are area-weighted averages of source features within each + node's subtree. The total weight for each node is the sum of per-source + strengths (handled separately during kernel evaluation, not here). + + Parameters + ---------- + source_points : Float[torch.Tensor, "n_sources n_dims"] + Source coordinates, shape :math:`(N, D)`. + areas : Float[torch.Tensor, "n_sources"] + Per-source area weights, shape :math:`(N,)`. + source_data : TensorDict or None + Per-source features (normals, latents, etc.) with + ``batch_size=(N,)``. ``None`` if no per-source features. + + Returns + ------- + SourceAggregates + Per-node aggregated centroids and source data. + """ + if self.n_nodes == 0: + D = source_points.shape[1] + device = source_points.device + dtype = source_points.dtype + return SourceAggregates( + node_centroid=torch.empty((0, D), dtype=dtype, device=device), + node_source_data=None, + ) + + device = source_points.device + dtype = source_points.dtype + D = source_points.shape[1] + n_nodes = self.n_nodes + + ### Leaf aggregation: compute per-leaf centroids and source data. + ### leaf_node_ids and leaf_seg_ids were precomputed during tree + ### construction (from_points) to avoid data-dependent torch.where + ### and _ragged_arange calls that would break torch.compile. 
+ with record_function("cluster_tree::leaf_aggregation"): + leaf_node_ids = self.leaf_node_ids + n_leaves = leaf_node_ids.shape[0] + seg_ids_compact = self.leaf_seg_ids + + sorted_points = source_points[self.sorted_source_order] + sorted_areas = areas[self.sorted_source_order] + + centroid_buf = torch.zeros(n_nodes, D, dtype=dtype, device=device) + + leaf_centroids = _segmented_weighted_sum( + sorted_points, sorted_areas, seg_ids_compact, n_leaves + ) + leaf_total_areas = self.node_total_area[leaf_node_ids] + safe_areas = leaf_total_areas.clamp(min=1e-30) + leaf_centroids = leaf_centroids / safe_areas.unsqueeze(-1) + centroid_buf[leaf_node_ids] = leaf_centroids + + node_source_data: TensorDict | None = None + if source_data is not None: + sorted_source_data = source_data[self.sorted_source_order] + node_source_data = _aggregate_source_data_leaves( + sorted_source_data, + sorted_areas, + seg_ids_compact, + n_leaves, + leaf_node_ids, + leaf_total_areas, + n_nodes, + device, + ) + + ### Bottom-up propagation: internal node centroids + with record_function("cluster_tree::bottom_up_propagation"): + _propagate_centroids_bottom_up( + centroid_buf, + node_source_data, + self.node_left_child, + self.node_right_child, + self.node_total_area, + self.internal_nodes_per_level, + ) + + return SourceAggregates( + node_centroid=centroid_buf, + node_source_data=node_source_data, + ) + + def find_dual_interaction_pairs( + self, + target_tree: "ClusterTree", + theta: float = 1.0, + *, + expand_far_targets: bool = False, + ) -> DualInteractionPlan: + r"""Find near-field and far-field pairs via dual-tree traversal. + + Traverses both the source tree (``self``) and ``target_tree`` + simultaneously. For well-separated node pairs, records a single + far-field (target_node, source_node) entry - the kernel is evaluated + ONCE at the node centroids and broadcast to all targets in the node. + This reduces far-field kernel evaluations from O(N log N) to O(N). 
+
+        Uses a combined AABB-distance opening criterion:
+        ``(D_T + D_S) / r < theta``, where D_T and D_S are the AABB
+        diagonals and r is the minimum distance between the two AABBs.
+        This accounts for approximation error on both the target and
+        source sides.
+
+        Parameters
+        ----------
+        target_tree : ClusterTree
+            Tree over target points. For self-interaction (communication
+            layers), this is the same object as ``self``.
+        theta : float
+            Barnes-Hut opening angle. Larger values accept coarser
+            far-field approximations, trading accuracy for fewer exact
+            interactions; ``theta = 0`` forces all interactions to be exact.
+        expand_far_targets : bool, optional, default=False
+            If ``True``, far-field node pairs are expanded to individual
+            target points, converting ``(far, far)`` entries into
+            ``(near, far)`` entries. This eliminates the target-side
+            centroid approximation (and the blocky spatial artifacts it
+            produces) at the cost of more kernel evaluations while
+            preserving the source-side monopole speedup.
+
+        Returns
+        -------
+        DualInteractionPlan
+            Near-field individual pairs and far-field node-to-node pairs.
+ """ + source_tree = self + device = source_tree.node_aabb_min.device + theta_sq = theta * theta + + ### Handle empty trees + if source_tree.n_nodes == 0 or target_tree.n_nodes == 0: + empty = torch.empty(0, dtype=torch.long, device=device) + return DualInteractionPlan( + near_target_ids=empty, + near_source_ids=empty.clone(), + far_target_node_ids=empty.clone(), + far_source_node_ids=empty.clone(), + nf_target_ids=empty.clone(), + nf_source_node_ids=empty.clone(), + fn_target_node_ids=empty.clone(), + fn_source_ids=empty.clone(), + fn_broadcast_targets=empty.clone(), + fn_broadcast_starts=empty.clone(), + fn_broadcast_counts=empty.clone(), + ) + + with record_function("cluster_tree::dual_traversal"): + ### Initialize: root-to-root pair + active_tgt_nodes = torch.zeros(1, dtype=torch.long, device=device) + active_src_nodes = torch.zeros(1, dtype=torch.long, device=device) + + near_target_list: list[torch.Tensor] = [] + near_source_list: list[torch.Tensor] = [] + far_tgt_node_list: list[torch.Tensor] = [] + far_src_node_list: list[torch.Tensor] = [] + nf_target_list: list[torch.Tensor] = [] + nf_source_node_list: list[torch.Tensor] = [] + fn_tgt_node_list: list[torch.Tensor] = [] + fn_src_list: list[torch.Tensor] = [] + fn_bcast_targets_list: list[torch.Tensor] = [] + fn_bcast_starts_list: list[torch.Tensor] = [] + fn_bcast_counts_list: list[torch.Tensor] = [] + fn_bcast_offset = 0 + + max_iters = int(target_tree.max_depth.item()) + int(source_tree.max_depth.item()) + 1 + depth = 0 + + for depth in range(max_iters): + if active_tgt_nodes.numel() == 0: + break + + ### Combined opening criterion: minimum AABB-to-AABB gap. + # For each dimension, the gap is the positive distance + # between the two boxes (zero if they overlap). 
+ aabb_min_T = target_tree.node_aabb_min[active_tgt_nodes] + aabb_max_T = target_tree.node_aabb_max[active_tgt_nodes] + aabb_min_S = source_tree.node_aabb_min[active_src_nodes] + aabb_max_S = source_tree.node_aabb_max[active_src_nodes] + + gap = torch.clamp( + torch.maximum(aabb_min_T - aabb_max_S, aabb_min_S - aabb_max_T), + min=0, + ) + min_dist_sq = gap.pow(2).sum(dim=-1) + + diam_T = target_tree.node_diameter_sq[active_tgt_nodes].sqrt() + diam_S = source_tree.node_diameter_sq[active_src_nodes].sqrt() + combined_diam_sq = (diam_T + diam_S).pow(2) + + is_far = min_dist_sq * theta_sq > combined_diam_sq + + ### Classify active pairs + is_leaf_T = target_tree.leaf_count[active_tgt_nodes] > 0 + is_leaf_S = source_tree.leaf_count[active_src_nodes] > 0 + + ### 1. Far-field: well-separated node pairs + if is_far.any(): + if expand_far_targets: + # Expand target nodes to individual points, + # converting (far,far) → (near,far). + far_tgt_nids = active_tgt_nodes[is_far] + far_src_nids = active_src_nodes[is_far] + starts = target_tree.node_range_start[far_tgt_nids] + counts = target_tree.node_range_count[far_tgt_nids] + positions, pair_ids = _ragged_arange(starts, counts) + nf_target_list.append( + target_tree.sorted_source_order[positions] + ) + nf_source_node_list.append(far_src_nids[pair_ids]) + else: + far_tgt_node_list.append(active_tgt_nodes[is_far]) + far_src_node_list.append(active_src_nodes[is_far]) + + ### 2. Near-field, both leaves: two-stage filtered expansion. + # Stage 1 (per-target) -> (near,far). + # Stage 2 (per-source) -> (far,near). + # Remainder -> (near,near). 
+ near_leaf_leaf = (~is_far) & is_leaf_T & is_leaf_S + if near_leaf_leaf.any(): + ( + nn_tgts, nn_srcs, + nf_tgts, nf_snids, + fn_tnids, fn_sids, + fn_btgts, fn_bstarts, fn_bcounts, + ) = _expand_dual_leaf_hits( + active_tgt_nodes[near_leaf_leaf], + active_src_nodes[near_leaf_leaf], + target_tree, + source_tree, + theta, + ) + near_target_list.append(nn_tgts) + near_source_list.append(nn_srcs) + nf_target_list.append(nf_tgts) + nf_source_node_list.append(nf_snids) + fn_tgt_node_list.append(fn_tnids) + fn_src_list.append(fn_sids) + fn_bcast_targets_list.append(fn_btgts) + fn_bcast_starts_list.append(fn_bstarts + fn_bcast_offset) + fn_bcast_counts_list.append(fn_bcounts) + fn_bcast_offset += fn_btgts.shape[0] + + ### 3. Need to split: at least one is internal, not far + need_split = (~is_far) & (~near_leaf_leaf) + if not need_split.any(): + break + + split_tgt = active_tgt_nodes[need_split] + split_src = active_src_nodes[need_split] + split_is_leaf_T = is_leaf_T[need_split] + split_is_leaf_S = is_leaf_S[need_split] + split_diam_sq_T = target_tree.node_diameter_sq[split_tgt] + split_diam_sq_S = source_tree.node_diameter_sq[split_src] + + ### Splitting decision: split the larger node. + # If equal (including self-interaction T==S), split both. + # If one side is a leaf, can only split the other. 
+ do_split_T = (~split_is_leaf_T) & ( + split_is_leaf_S | (split_diam_sq_T >= split_diam_sq_S) + ) + do_split_S = (~split_is_leaf_S) & ( + split_is_leaf_T | (split_diam_sq_S >= split_diam_sq_T) + ) + + ### Generate child pairs for each split case + next_tgt_parts: list[torch.Tensor] = [] + next_src_parts: list[torch.Tensor] = [] + + # Case A: split T only (T internal, S leaf or T strictly larger) + case_T_only = do_split_T & (~do_split_S) + if case_T_only.any(): + t_ids = split_tgt[case_T_only] + s_ids = split_src[case_T_only] + left_T = target_tree.node_left_child[t_ids] + right_T = target_tree.node_right_child[t_ids] + for child_T in (left_T, right_T): + valid = child_T >= 0 + if valid.any(): + next_tgt_parts.append(child_T[valid]) + next_src_parts.append(s_ids[valid]) + + # Case B: split S only (S internal, T leaf or S strictly larger) + case_S_only = do_split_S & (~do_split_T) + if case_S_only.any(): + t_ids = split_tgt[case_S_only] + s_ids = split_src[case_S_only] + left_S = source_tree.node_left_child[s_ids] + right_S = source_tree.node_right_child[s_ids] + for child_S in (left_S, right_S): + valid = child_S >= 0 + if valid.any(): + next_tgt_parts.append(t_ids[valid]) + next_src_parts.append(child_S[valid]) + + # Case C: split both (both internal, equal diameter or T==S) + case_both = do_split_T & do_split_S + if case_both.any(): + t_ids = split_tgt[case_both] + s_ids = split_src[case_both] + left_T = target_tree.node_left_child[t_ids] + right_T = target_tree.node_right_child[t_ids] + left_S = source_tree.node_left_child[s_ids] + right_S = source_tree.node_right_child[s_ids] + for child_T in (left_T, right_T): + for child_S in (left_S, right_S): + valid = (child_T >= 0) & (child_S >= 0) + if valid.any(): + next_tgt_parts.append(child_T[valid]) + next_src_parts.append(child_S[valid]) + + if next_tgt_parts: + active_tgt_nodes = torch.cat(next_tgt_parts) + active_src_nodes = torch.cat(next_src_parts) + else: + break + + ### Concatenate accumulated pairs + if 
near_target_list: + near_tgt = torch.cat(near_target_list) + near_src = torch.cat(near_source_list) + else: + near_tgt = torch.empty(0, dtype=torch.long, device=device) + near_src = torch.empty(0, dtype=torch.long, device=device) + + if far_tgt_node_list: + far_tgt_nid = torch.cat(far_tgt_node_list) + far_src_nid = torch.cat(far_src_node_list) + else: + far_tgt_nid = torch.empty(0, dtype=torch.long, device=device) + far_src_nid = torch.empty(0, dtype=torch.long, device=device) + + if nf_target_list: + nf_tgt = torch.cat(nf_target_list) + nf_snid = torch.cat(nf_source_node_list) + else: + nf_tgt = torch.empty(0, dtype=torch.long, device=device) + nf_snid = torch.empty(0, dtype=torch.long, device=device) + + if fn_tgt_node_list: + fn_tnid = torch.cat(fn_tgt_node_list) + fn_sid = torch.cat(fn_src_list) + fn_btgts = torch.cat(fn_bcast_targets_list) + fn_bstarts = torch.cat(fn_bcast_starts_list) + fn_bcounts = torch.cat(fn_bcast_counts_list) + else: + fn_tnid = torch.empty(0, dtype=torch.long, device=device) + fn_sid = torch.empty(0, dtype=torch.long, device=device) + fn_btgts = torch.empty(0, dtype=torch.long, device=device) + fn_bstarts = torch.empty(0, dtype=torch.long, device=device) + fn_bcounts = torch.empty(0, dtype=torch.long, device=device) + + ### Sort near pairs by source index for coalesced gather + if near_src.numel() > 0: + sort_order = near_src.argsort(stable=True) + near_tgt = near_tgt[sort_order] + near_src = near_src[sort_order] + + ### Sort far pairs by source node for coalesced aggregate gather + if far_src_nid.numel() > 0: + sort_order = far_src_nid.argsort(stable=True) + far_tgt_nid = far_tgt_nid[sort_order] + far_src_nid = far_src_nid[sort_order] + + ### Sort (near,far) pairs by source node for coalesced gather + if nf_snid.numel() > 0: + sort_order = nf_snid.argsort(stable=True) + nf_tgt = nf_tgt[sort_order] + nf_snid = nf_snid[sort_order] + + ### Sort (far,near) pairs by source index for coalesced gather + if fn_sid.numel() > 0: + sort_order = 
fn_sid.argsort(stable=True) + fn_tnid = fn_tnid[sort_order] + fn_sid = fn_sid[sort_order] + fn_bstarts = fn_bstarts[sort_order] + fn_bcounts = fn_bcounts[sort_order] + + plan = DualInteractionPlan( + near_target_ids=near_tgt, + near_source_ids=near_src, + far_target_node_ids=far_tgt_nid, + far_source_node_ids=far_src_nid, + nf_target_ids=nf_tgt, + nf_source_node_ids=nf_snid, + fn_target_node_ids=fn_tnid, + fn_source_ids=fn_sid, + fn_broadcast_targets=fn_btgts, + fn_broadcast_starts=fn_bstarts, + fn_broadcast_counts=fn_bcounts, + ) + + if not torch.compiler.is_compiling(): + plan.validate() + + is_self = target_tree is self + logger.debug( + "dual traversal: %d near + %d nf + %d fn + %d far_node pairs, " + "theta=%.2f, self_interaction=%s, %d iterations", + plan.n_near, plan.n_nf, plan.n_fn, plan.n_far_nodes, + theta, is_self, depth, + ) + + return plan + + +# --------------------------------------------------------------------------- +# SourceAggregates: per-node aggregate data for far-field approximation +# --------------------------------------------------------------------------- + + +@tensorclass +class SourceAggregates: + """Per-node aggregated source data for far-field monopole approximation. + + Computed by :meth:`ClusterTree.compute_source_aggregates` and consumed + by :class:`BarnesHutKernel` during kernel evaluation. + """ + + node_centroid: Float[torch.Tensor, "n_nodes n_dims"] + """Area-weighted centroid per node.""" + + node_source_data: TensorDict | None + """Area-weighted average source features per node, or ``None`` if no + per-source features. 
Has ``batch_size=(n_nodes,)``."""
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers for tree construction
+# ---------------------------------------------------------------------------
+
+
+def _fill_leaf_aabbs(
+    leaf_nids: Int[torch.Tensor, " n_leaves"],
+    leaf_starts: Int[torch.Tensor, " n_leaves"],
+    leaf_sizes: Int[torch.Tensor, " n_leaves"],
+    sorted_points: Float[torch.Tensor, "n_sorted_sources n_dims"],
+    aabb_min_buf: Float[torch.Tensor, "n_nodes n_dims"],
+    aabb_max_buf: Float[torch.Tensor, "n_nodes n_dims"],
+) -> None:
+    """Fill AABB buffers for leaf nodes via segmented reduction (in-place)."""
+    device = leaf_nids.device
+    D = sorted_points.shape[1]
+    dtype = sorted_points.dtype
+    n_leaves = leaf_nids.shape[0]
+    total = int(leaf_sizes.sum())
+
+    # Nothing to reduce: leave the +/-inf sentinel AABBs untouched.
+    if total == 0 or n_leaves == 0:
+        return
+
+    # _ragged_arange presumably yields, for every leaf, the flat indices
+    # [start, start + size) plus the owning leaf's compact segment id —
+    # consistent with its other call sites; confirm in its definition.
+    positions, seg_ids = _ragged_arange(leaf_starts, leaf_sizes)
+    pts = sorted_points[positions]  # (total, D)
+
+    # +inf / -inf are the identity elements of amin / amax respectively,
+    # so include_self=True cannot pollute the reductions below.
+    seg_min = torch.full((n_leaves, D), float("inf"), dtype=dtype, device=device)
+    seg_max = torch.full((n_leaves, D), float("-inf"), dtype=dtype, device=device)
+    exp_ids = seg_ids.unsqueeze(1).expand_as(pts)
+    seg_min.scatter_reduce_(0, exp_ids, pts, reduce="amin", include_self=True)
+    seg_max.scatter_reduce_(0, exp_ids, pts, reduce="amax", include_self=True)
+
+    aabb_min_buf[leaf_nids] = seg_min
+    aabb_max_buf[leaf_nids] = seg_max
+
+
+def _fill_leaf_total_areas(
+    leaf_nids: Int[torch.Tensor, " n_leaves"],
+    leaf_starts: Int[torch.Tensor, " n_leaves"],
+    leaf_sizes: Int[torch.Tensor, " n_leaves"],
+    sorted_areas: Float[torch.Tensor, " n_sorted_sources"],
+    total_area_buf: Float[torch.Tensor, " n_nodes"],
+) -> None:
+    """Compute total area per leaf node (in-place)."""
+    device = leaf_nids.device
+    n_leaves = leaf_nids.shape[0]
+    total = int(leaf_sizes.sum())
+
+    if total == 0 or n_leaves == 0:
+        return
+
+    positions, seg_ids = _ragged_arange(leaf_starts, leaf_sizes)
+    areas = sorted_areas[positions]
+
+    # Segmented sum: accumulate each source's area into its leaf's slot.
+    leaf_areas = torch.zeros(n_leaves, dtype=areas.dtype, device=device)
+    leaf_areas.scatter_add_(0, seg_ids, areas)
+
+    total_area_buf[leaf_nids] = leaf_areas
+
+
+def _aggregate_source_data_leaves(
+    sorted_source_data: TensorDict,
+    sorted_areas: Float[torch.Tensor, " n_sorted_sources"],
+    seg_ids: Int[torch.Tensor, " n_sorted_sources"],
+    n_leaves: int,
+    leaf_node_ids: Int[torch.Tensor, " n_leaves"],
+    leaf_total_areas: Float[torch.Tensor, " n_leaves"],
+    n_nodes: int,
+    device: torch.device,
+) -> TensorDict:
+    """Compute area-weighted average source data for leaf nodes.
+
+    Returns a TensorDict with ``batch_size=(n_nodes,)`` where only
+    leaf entries are populated (internal nodes are zeros, filled by
+    bottom-up propagation).
+    """
+    # Avoids division by zero below when a leaf's total area is 0
+    # (e.g. degenerate / zero-area sources).
+    safe_areas = leaf_total_areas.clamp(min=1e-30)
+
+    def _aggregate_leaf(tensor: torch.Tensor) -> torch.Tensor:
+        # Flatten trailing feature dims so one segmented sum handles
+        # tensors of any rank; restore the shape on the way out.
+        trailing_shape = tensor.shape[1:]
+        flat = tensor.reshape(tensor.shape[0], -1)  # (n_sorted_sources, F)
+
+        weighted_sum = _segmented_weighted_sum(
+            flat, sorted_areas, seg_ids, n_leaves
+        )
+        avg = weighted_sum / safe_areas.unsqueeze(-1)
+
+        out = torch.zeros(
+            (n_nodes,) + trailing_shape,
+            dtype=tensor.dtype,
+            device=device,
+        )
+        # out is freshly allocated and contiguous, so reshape returns a
+        # view and the assignment below writes through to ``out``.
+        out_flat = out.reshape(n_nodes, -1)
+        out_flat[leaf_node_ids] = avg
+        return out.reshape((n_nodes,) + trailing_shape)
+
+    return sorted_source_data.apply(_aggregate_leaf, batch_size=[n_nodes])
+
+
+### Disabled for torch.compile: this function iterates over a
+### variable-length list (depth_levels), whose length equals the tree
+### depth. Dynamo unrolls this loop and specializes on the length,
+### causing recompilation every time a new tree depth is encountered
+### (each airfoil mesh produces a different-depth tree). Disabling
+### compilation here produces one clean graph break at the function
+### boundary instead of per-depth-level recompilation storms.
+@torch.compiler.disable +def _propagate_centroids_bottom_up( + centroid_buf: Float[torch.Tensor, "n_nodes n_dims"], + node_source_data: TensorDict | None, + left_child: Int[torch.Tensor, " n_nodes"], + right_child: Int[torch.Tensor, " n_nodes"], + total_area: Float[torch.Tensor, " n_nodes"], + depth_levels: list[torch.Tensor], +) -> None: + """Propagate centroids and source data from leaves to root (in-place). + + Internal node centroid = area-weighted average of its children's centroids. + Internal node source data = area-weighted average of its children's data. + + Parameters + ---------- + centroid_buf : Float[torch.Tensor, "n_nodes n_dims"] + Buffer of per-node centroids (leaf values pre-filled, internal values + written by this function). + node_source_data : TensorDict or None + Per-node source data to propagate (same structure as centroid_buf). + left_child : Int[torch.Tensor, "n_nodes"] + Left child index per node (-1 for leaves). + right_child : Int[torch.Tensor, "n_nodes"] + Right child index per node (-1 for leaves). + total_area : Float[torch.Tensor, "n_nodes"] + Total source area in each node's subtree. + depth_levels : list[torch.Tensor] + Internal node IDs grouped by tree depth (shallowest first), + from :attr:`ClusterTree.internal_nodes_per_level`. 
+ """ + for level_ids in reversed(depth_levels): + left = left_child[level_ids] + right = right_child[level_ids] + + left_area = total_area[left] + right_area = total_area[right] + total = (left_area + right_area).clamp(min=1e-30) + + # 1D base weights; each consumer unsqueezes as needed for its rank + w_left_1d = left_area / total # (n,) + w_right_1d = right_area / total # (n,) + + centroid_buf[level_ids] = ( + centroid_buf[left] * w_left_1d.unsqueeze(-1) + + centroid_buf[right] * w_right_1d.unsqueeze(-1) + ) + + if node_source_data is not None: + for key in node_source_data.keys(include_nested=True, leaves_only=True): + val_left = node_source_data[key][left] + val_right = node_source_data[key][right] + w_l = w_left_1d + w_r = w_right_1d + while w_l.ndim < val_left.ndim: + w_l = w_l.unsqueeze(-1) + w_r = w_r.unsqueeze(-1) + node_source_data[key][level_ids] = ( + val_left * w_l + val_right * w_r + ) diff --git a/physicsnemo/experimental/models/globe/field_kernel.py b/physicsnemo/experimental/models/globe/field_kernel.py index 2c517280d2..7f7cda1d36 100644 --- a/physicsnemo/experimental/models/globe/field_kernel.py +++ b/physicsnemo/experimental/models/globe/field_kernel.py @@ -19,13 +19,13 @@ import operator from functools import cached_property, reduce from math import ceil, comb, prod -from typing import Literal, Sequence +from typing import TYPE_CHECKING, Literal, Sequence import torch import torch.nn as nn -import tqdm from jaxtyping import Float from tensordict import TensorDict +from torch.profiler import record_function from torch.utils.checkpoint import checkpoint from physicsnemo.core.module import Module @@ -46,9 +46,15 @@ smooth_log, spherical_basis, ) -from physicsnemo.utils.logging import PythonLogger -logger = PythonLogger("globe.field_kernel") +logger = logging.getLogger("globe.field_kernel") + +if TYPE_CHECKING: + from physicsnemo.experimental.models.globe.cluster_tree import ( + ClusterTree, + DualInteractionPlan, + SourceAggregates, + ) class 
Kernel(Module): @@ -194,6 +200,42 @@ def __init__( f"Invalid network type: {network_type=!r}; must be one of ['pade', 'mlp']" ) + @cached_property + def _floats_per_interaction(self) -> int: + """Identifiable float allocations per (target, source) interaction. + + Counts tensor elements from feature engineering, MLP evaluation, + and post-processing that coexist at peak during ``Kernel.forward``. + Used by :class:`BarnesHutKernel` to estimate chunk memory budgets. + + This is a lower bound - the actual peak is higher due to autograd + saving input tensors for backward through each element-wise + operation. The caller applies a runtime multiplier to account for + this (see ``BarnesHutKernel._auto_chunk_size``). + """ + source_rc = rank_counts(self.source_data_ranks) + global_rc = rank_counts(self.global_data_ranks) + n_vec = 1 + source_rc[1] + global_rc[1] + n_pairs = comb(n_vec, 2) + + return ( + ### Feature engineering: spatial vectors (n_targets, n_sources, 3, ...) + 3 # r = target - source + + 3 * n_vec * 2 # vectors + unit vectors + ### Feature engineering: scalars (n_targets, n_sources, ...) + + n_vec * 3 # magnitudes: squared, raw, log + + n_pairs * (1 + 2 * self.n_spherical_harmonics) # cos_theta + harmonics + products + + self.network_in_features # concatenated MLP input + ### MLP layers (sequential; peak is largest layer plus I/O) + + self.network_in_features + + sum(self.hidden_layer_sizes) + + self.network_out_features + ### Post-processing + + self.network_out_features # reshaped output + + 1 # far-field r_mag_sq + + self.n_spatial_dims * max(1, 2 * n_vec - 1) # basis vectors + ) + @cached_property def network_in_features(self) -> int: r"""Number of input features for the kernel's internal network. 
@@ -436,6 +478,7 @@ def forward( ) ### Assemble inputs to the neural network + interaction_dims = torch.Size([n_targets, n_sources]) scalars = TensorDict( { "source_scalars": source_scalars.expand( @@ -445,7 +488,7 @@ def forward( n_targets, n_sources, *global_scalars.batch_size ), }, - batch_size=torch.Size([n_targets, n_sources]), + batch_size=interaction_dims, device=device, ) @@ -460,18 +503,71 @@ def forward( torch.Size([n_targets, n_sources]) + global_vectors.batch_size ), }, - batch_size=torch.Size([n_targets, n_sources, self.n_spatial_dims]), + batch_size=interaction_dims + torch.Size([self.n_spatial_dims]), device=device, ) vectors["r"] = ( - target_points[:, None, :] # shape (n_targets, 1, n_dims) - - source_points[None, :, :] # shape (1, n_sources, n_dims) - ) / reference_length # shape (n_targets, n_sources, n_dims) - - # At this point, cast to the autocast dtype if possible and we're - # currently in an autocast context. This saves tons of memory, and - # really the only reason we needed to keep fp32 up to this point was to - # prevent catastrophic cancellation on the `r` vector computation. 
+ target_points[:, None, :] # (n_targets, 1, n_dims) + - source_points[None, :, :] # (1, n_sources, n_dims) + ) / reference_length # (n_targets, n_sources, n_dims) + + ### Core feature engineering, network evaluation, and post-processing + result = self._evaluate_interactions( + scalars=scalars, + vectors=vectors, + device=device, + ) + + ### Aggregate over sources, weighted by source strengths + final_result = TensorDict( + { + k: torch.einsum( + "ts...,s->t...", + v, + source_strengths, + ) + for k, v in result.items() + }, + batch_size=torch.Size([n_targets]), + device=device, + ) + + return final_result + + def _evaluate_interactions( + self, + *, + scalars: TensorDict[str, Float[torch.Tensor, "*interaction_dims"]], + vectors: TensorDict[str, Float[torch.Tensor, "*interaction_dims n_spatial_dims"]], + device: torch.device, + ) -> TensorDict[str, Float[torch.Tensor, "*interaction_dims"]]: + r"""Core kernel computation: feature engineering, network, and post-processing. + + Operates on pre-assembled interaction feature tensors with arbitrary + leading batch dimensions. Both ``Kernel.forward()`` (with dense + ``(N_{tgt}, N_{src})`` interactions) and ``BarnesHutKernel`` (with + sparse ``(N_{pairs},)`` interactions) call this method. + + Parameters + ---------- + scalars : TensorDict + Scalar features with ``batch_size=(*interaction_dims,)``. + Must contain ``"source_scalars"`` and ``"global_scalars"`` sub-dicts. + vectors : TensorDict + Vector features with ``batch_size=(*interaction_dims, D)``. + Must contain ``"r"`` (displacement), ``"source_vectors"``, and + ``"global_vectors"`` sub-dicts. All values must be dimensionless. + device : torch.device + Device for tensor allocation. + + Returns + ------- + TensorDict[str, Float[torch.Tensor, "..."]] + Per-interaction output fields with ``batch_size=(*interaction_dims,)``. + NOT aggregated over sources. Scalar fields have shape + ``(*interaction_dims,)``, vector fields ``(*interaction_dims, D)``. 
+ """ + # Cast to autocast dtype after the fp32-critical r computation if torch.is_autocast_enabled(device.type): dtype = torch.get_autocast_dtype(device.type) scalars = scalars.to(dtype=dtype) @@ -482,47 +578,51 @@ def forward( smoothing_radius = torch.tensor( self.smoothing_radius, device=device, dtype=dtype ) - vectors_mag_squared: TensorDict = ( # ty: ignore[invalid-assignment] - (vectors * vectors).sum(dim=-1).apply(lambda x: x + smoothing_radius**2) - ) - vectors_mag = vectors_mag_squared.sqrt() - vectors_hat = vectors / vectors_mag.unsqueeze(-1) - vectors_log_mag = smooth_log(vectors_mag) - - # Each of the vectors' magnitudes become an input feature - scalars["vectors_log_mag"] = vectors_log_mag - - # TODO in 3D, add cross products of pairs of vectors as input features - - ### Now, engineer some features from pairs of vectors - keypairs = list(itertools.combinations(range(concatenated_length(vectors)), 2)) - k1, k2 = zip(*keypairs) if keypairs else ([], []) - vectors_hat_concatenated: torch.Tensor = concatenate_leaves(vectors_hat) - # shape: (n_targets, n_sources, n_spatial_dims, n_vectors_in) - - v1_hat = vectors_hat_concatenated[:, :, :, k1] - v2_hat = vectors_hat_concatenated[:, :, :, k2] - cos_theta_pairs = torch.sum(v1_hat * v2_hat, dim=-2) - # shape: (n_targets, n_sources, len(keypairs)) - - # [1:] skips P_0(x) = 1 (constant), which carries no angular information - spherical_harmonics: list[torch.Tensor] = legendre_polynomials( - x=cos_theta_pairs, n=self.n_spherical_harmonics + 1 - )[1:] - - vectors_mag_concatenated: torch.Tensor = concatenate_leaves(vectors_mag) - v1_mag = vectors_mag_concatenated[:, :, k1] - v2_mag = vectors_mag_concatenated[:, :, k2] - - for i, harmonics in enumerate(spherical_harmonics): - scalars[f"pairwise_spherical_harmonics_{i}"] = ( - smooth_log(v1_mag * v2_mag) * harmonics + + ### Vector magnitude, direction, and log-magnitude features + with record_function("kernel::feature_engineering"): + vectors_mag_squared: TensorDict = 
( + (vectors * vectors).sum(dim=-1).apply(lambda x: x + smoothing_radius**2) ) + vectors_mag = vectors_mag_squared.sqrt() + vectors_hat = vectors / vectors_mag.unsqueeze(-1) + vectors_log_mag = smooth_log(vectors_mag) + + # Each of the vectors' magnitudes become an input feature + scalars["vectors_log_mag"] = vectors_log_mag + + # TODO in 3D, add cross products of pairs of vectors as input features + + ### Pairwise spherical harmonic features from vector pairs + keypairs = list(itertools.combinations(range(concatenated_length(vectors)), 2)) + k1, k2 = zip(*keypairs) if keypairs else ([], []) + vectors_hat_concatenated: torch.Tensor = concatenate_leaves(vectors_hat) + # shape: (*interaction_dims, n_spatial_dims, n_vectors_in) + + v1_hat = vectors_hat_concatenated[..., :, k1] + v2_hat = vectors_hat_concatenated[..., :, k2] + cos_theta_pairs = torch.sum(v1_hat * v2_hat, dim=-2) + # shape: (*interaction_dims, len(keypairs)) + + # [1:] skips P_0(x) = 1 (constant), which carries no angular information + spherical_harmonics: list[torch.Tensor] = legendre_polynomials( + x=cos_theta_pairs, n=self.n_spherical_harmonics + 1 + )[1:] + + vectors_mag_concatenated: torch.Tensor = concatenate_leaves(vectors_mag) + v1_mag = vectors_mag_concatenated[..., k1] + v2_mag = vectors_mag_concatenated[..., k2] + + for i, harmonics in enumerate(spherical_harmonics): + scalars[f"pairwise_spherical_harmonics_{i}"] = ( + smooth_log(v1_mag * v2_mag) * harmonics + ) - cat_input_tensors: torch.Tensor = concatenate_leaves(scalars) - # shape (n_targets, n_sources, self.network_in_features) + cat_input_tensors: torch.Tensor = concatenate_leaves(scalars) + del scalars + # shape: (*interaction_dims, self.network_in_features) - ### Evaluate the neural-network-based field kernel function + ### Validate and evaluate the neural network if not torch.compiler.is_compiling(): if not cat_input_tensors.shape[-1] == self.network_in_features: raise RuntimeError( @@ -530,253 +630,189 @@ def forward( f"This is due 
to a shape inconsistency between the `network_in_features` and `forward` methods of the {self.__class__.__name__!r} class." ) - flattened_input = cat_input_tensors.view( - n_targets * n_sources, self.network_in_features - ) + interaction_dims = cat_input_tensors.shape[:-1] + flattened_input = cat_input_tensors.reshape(prod(interaction_dims), self.network_in_features) - if self.training and self.use_gradient_checkpointing: - flattened_output = checkpoint( - self.network, flattened_input, use_reentrant=False - ) # shape (n_targets * n_sources, last_layer_size) - else: + with record_function("kernel::network"): flattened_output = self.network(flattened_input) - output = flattened_output.view(n_targets, n_sources, self.network_out_features) - - ### Enforces correct far-field decay rate - r_mag_sq: torch.Tensor = vectors_mag_squared["r"] # ty: ignore[invalid-assignment] - output = output * ( - -torch.expm1(-r_mag_sq[..., None]) - ) # Lamb-Oseen vortex kernel, numerically stable using expm1 - if self.n_spatial_dims == 2: - output = output / (r_mag_sq[..., None] + 1).sqrt() - elif self.n_spatial_dims == 3: - output = output / (r_mag_sq[..., None] + 1) - else: - output = output / (r_mag_sq[..., None] + 1) ** ( - (self.n_spatial_dims - 1) / 2 - ) - - ### Add semantics to the output - n_vectors_in = len(vectors.keys(include_nested=True, leaves_only=True)) - result: TensorDict[str, Float[torch.Tensor, "..."]] = self.add_semantics( - output, - shape_for_scalars=torch.Size([]), - shape_for_vectors=torch.Size( - [ - 1 # r_hat - + 2 * (n_vectors_in - 1), # All non-r vectors - ] - ), - ) - # Values are tensors of shape (n_targets, n_sources, field_dim), where - # field_dim is taken from the `size_for_` arguments above. - - ### Vector Reprojection - # If there are any vector fields, we want to interpret them as a vector - # field on a local basis defined by the vectors we already have - this - # preserves rotational invariance. 
- - ranks_dict = flatten_rank_spec(self.output_field_ranks) - vector_reprojection_needed = any( - rank == 1 for rank in ranks_dict.values() - ) - - if vector_reprojection_needed: - ### Compute the local basis vectors - # Note that each combination of source and target points yields its - # own basis. In both 2D and 3D, we take the axis of the coordinate - # system used to generate the basis vectors to be `source_vectors` - - # a convenient source of non-arbitrary direction. We then repeat - # this for each source vector, and stack them together. - - # This is effectively an expanded version of a Helmholtz - # decomposition for vector fields: each field is the sum of a - # uniform field, a source field, a solenoidal field, and a - # dipole-like field. - - basis_vector_components: list[torch.Tensor] = [] - # Eventually, this is a list of length 3 * n_source_vectors (in both - # 2D and 3D) with tensors of shape (n_targets, n_sources, - # n_spatial_dims) - - basis_vector_components.append(vectors_hat["r"]) - - for k in vectors.keys(include_nested=True, leaves_only=True): - if k == "r": - continue - - scale: torch.Tensor = vectors_log_mag[k][..., None] # ty: ignore[invalid-assignment] - - basis_vector_components.append(scale * vectors_hat[k]) - - if self.n_spatial_dims == 2: - # In 2D, we use a polar/dipole basis: e_r is radial, e_theta - # is tangential (orthogonal to e_r), and e_kappa is a - # dipole-like direction (orthogonal to e_r, parallel to - # e_theta). This basis is not a true vector basis (it has 3 - # vectors, not 2), but this third basis vector increases - # expressivity. 
- _, e_theta, e_kappa = polar_and_dipole_basis( - r_hat=vectors_hat["r"], - n_hat=vectors_hat[k], - normalize_basis_vectors=False, - ) # shape (n_targets, n_sources, 2) - - basis_vector_components.extend( - [ - # scale * e_theta, # Vortex-like direction - scale * e_kappa, # Dipole-like direction - ] - ) - - elif self.n_spatial_dims == 3: - # In 3D, we use a modified spherical coordinate basis: e_r - # is radial, e_theta is the polar / dipole-like / "latitude" - # direction, and e_phi is the azimuthal / vortex-like / - # "longitude" direction. - _, e_theta, e_phi = spherical_basis( - r_hat=vectors_hat["r"], - n_hat=vectors_hat[k], - normalize_basis_vectors=False, - ) # Shape of each: (n_targets, n_sources, 3) - - basis_vector_components.extend( - [ - scale * e_theta, # Polar / meridional direction - # scale * e_phi, # Vortex-like / azimuthal direction - ] - ) - - else: - raise NotImplementedError( - f"The {self.__class__.__name__!r} class does not support {self.n_spatial_dims=!r}-dimensional problems." - ) - - basis_vectors = torch.stack(basis_vector_components, dim=-1) - # shape (n_targets, n_sources, n_spatial_dims, 4 * n_vectors) - - ### Now, reproject each vector field onto the basis vectors - for field_name, rank in ranks_dict.items(): - if rank == 1: - # # ORIGINAL (SLOW) - keeping for reference, as this is the most readable version - # # Axes: t = target, s = source, d = dim, b = basis vector id - # result[field_name] = torch.einsum( - # "tsb,tsdb->tsd", - # result[field_name], - # basis_vectors, - # ) - - # OPTIMIZED VERSION: Manual broadcast matrix-vector multiplication - # This is ~16x faster than the original einsum and uses less memory - result[field_name] = torch.sum( - basis_vectors - * result[field_name].unsqueeze(-2), # Broadcasting - dim=-1, - ) - - # Incorporate the source strengths and sum over all source points - # Axes: t = target, s = source, ... 
= all remaining dimensions (i.e., n_spatial_dims for vectors, nothing for scalars) - final_result = TensorDict( - { - k: torch.einsum( - "ts...,s->t...", - v, - source_strengths, + output = flattened_output.reshape(*interaction_dims, self.network_out_features) + + ### Far-field decay envelope and vector reprojection + with record_function("kernel::postprocess"): + r_mag_sq: torch.Tensor = vectors_mag_squared["r"] + output = output * ( + -torch.expm1(-r_mag_sq[..., None]) + ) # Lamb-Oseen vortex kernel, numerically stable via expm1 + if self.n_spatial_dims == 2: + output = output / (r_mag_sq[..., None] + 1).sqrt() + elif self.n_spatial_dims == 3: + output = output / (r_mag_sq[..., None] + 1) + else: + output = output / (r_mag_sq[..., None] + 1) ** ( + (self.n_spatial_dims - 1) / 2 ) - for k, v in result.items() - }, - batch_size=torch.Size([n_targets]), - device=device, - ) - return final_result + ### Add field-name semantics to the flat output channels + n_vectors_in = len(vectors.keys(include_nested=True, leaves_only=True)) + result: TensorDict[str, Float[torch.Tensor, "..."]] = self.add_semantics( + output, + shape_for_scalars=torch.Size([]), + shape_for_vectors=torch.Size( + [ + 1 # r_hat + + 2 * (n_vectors_in - 1), # All non-r vectors + ] + ), + ) + ### Vector reprojection onto local rotationally-equivariant basis + ranks_dict = flatten_rank_spec(self.output_field_ranks) + vector_reprojection_needed = any( + rank == 1 for rank in ranks_dict.values() + ) -class ChunkedKernel(Kernel): - r"""Memory-efficient kernel evaluation through automatic target point chunking. + if vector_reprojection_needed: + # Helmholtz-like decomposition: each vector field is expressed in a + # local basis derived from the input vectors (r_hat, source vectors, + # and their derived dipole/polar/spherical directions). 
+ basis_vector_components: list[torch.Tensor] = [] + + basis_vector_components.append(vectors_hat["r"]) + + for k in vectors.keys(include_nested=True, leaves_only=True): + if k == "r": + continue + + scale: torch.Tensor = vectors_log_mag[k][..., None] + + basis_vector_components.append(scale * vectors_hat[k]) + + if self.n_spatial_dims == 2: + _, e_theta, e_kappa = polar_and_dipole_basis( + r_hat=vectors_hat["r"], + n_hat=vectors_hat[k], + normalize_basis_vectors=False, + ) + basis_vector_components.append(scale * e_kappa) + + elif self.n_spatial_dims == 3: + _, e_theta, e_phi = spherical_basis( + r_hat=vectors_hat["r"], + n_hat=vectors_hat[k], + normalize_basis_vectors=False, + ) + basis_vector_components.append(scale * e_theta) + + else: + raise NotImplementedError( + f"The {self.__class__.__name__!r} class does not support {self.n_spatial_dims=!r}-dimensional problems." + ) + + basis_vectors = torch.stack(basis_vector_components, dim=-1) + + for field_name, rank in ranks_dict.items(): + if rank == 1: + result[field_name] = torch.sum( + basis_vectors + * result[field_name].unsqueeze(-2), + dim=-1, + ) - :class:`ChunkedKernel` extends the base :class:`Kernel` class with chunking - capabilities that enable memory-efficient evaluation on large target point sets. - The kernel evaluation has ``O(n_sources * n_targets)`` memory complexity due to - the all-to-all pairwise computation, which can exhaust GPU memory for large - problems. Chunking processes target points in smaller batches, trading modest - computational overhead for dramatic memory reduction. + return result - Chunking is particularly useful in three scenarios: - 1. **Training**: When using downsampled query points (e.g., 4096 points) but many - source faces, chunking can reduce memory during the backward pass. - 2. **Inference on dense grids**: When evaluating on complete high-resolution volume - meshes (e.g., 100k+ points), chunking prevents out-of-memory errors. - 3. 
**Limited GPU memory**: When running on GPUs with constrained memory (e.g., during - development or deployment on smaller hardware). +class BarnesHutKernel(Kernel): + r"""Tree-accelerated kernel evaluation via Barnes-Hut monopole approximation. - The chunking is implemented at the target point dimension, so each chunk independently - computes its output from all source points, then results are concatenated. This is - numerically identical to non-chunked evaluation - there are no approximations. + Reduces the :math:`O(N_{src} \cdot N_{tgt})` cost of the all-to-all kernel + evaluation to :math:`O((N_{src} + N_{tgt}) \log N_{src})` by building a + spatial cluster tree over source points and using aggregate (monopole) + representations for distant clusters. - Chunk size selection: + For each target point, sources are classified as either: - - ``chunk_size=None``: No chunking, fastest but highest memory (default for small - problems) - - ``chunk_size="auto"``: Automatically determines size targeting ~1GB per chunk - - ``chunk_size=int``: Manual specification for fine control + - **Near-field**: within the opening-angle threshold, evaluated exactly + using the underlying :class:`Kernel`'s neural network. + - **Far-field**: beyond the threshold, approximated by evaluating the + same network with the cluster's area-weighted centroid, average normal, + and average features as a "virtual source." - The ``"auto"`` mode estimates memory based on network layer sizes and interaction - count, providing a good balance for most use cases. The implementation uses recursive - calls to handle the chunking logic, and the overhead is minimal for reasonable chunk - sizes. + Both near- and far-field interactions are accumulated into a single batch + and evaluated in one call to :meth:`Kernel._evaluate_interactions`, + minimizing kernel launch overhead ("accumulate pairs, evaluate once"). 
- Inherits all other functionality from :class:`Kernel`, including invariant feature - engineering, Pade-approximant networks, far-field decay, and equivariant vector - reprojection. + The ``ClusterTree`` spatial structure can be precomputed per mesh geometry + and reused across kernel branches and hyperlayers. The + ``DualInteractionPlan`` can be cached when targets equal sources + (communication hyperlayers). Parameters ---------- Inherits all parameters from :class:`Kernel`. + leaf_size : int, optional, default=1 + Maximum sources per tree leaf node. Larger values produce shallower + trees (fewer traversal iterations) at the cost of more exact + interactions per leaf. + Forward ------- - Same parameters as :class:`Kernel`, with the addition of: - - chunk_size : None or int or {"auto"}, optional, default="auto" - Controls chunking behavior. ``"auto"`` determines chunk size targeting - ~1GB per chunk. An integer processes in exact chunk sizes. ``None`` - evaluates all at once. + Same parameters as :class:`Kernel`, with additions: + + theta : float, optional, default=1.0 + Barnes-Hut opening angle. A node is approximated when + ``D/r < theta``. Larger values are more aggressive (more + approximation, faster). At ``theta = 0``, all interactions + are exact. + cluster_tree : ClusterTree or None, optional, default=None + Precomputed spatial tree over source points. If ``None``, built + from ``source_points`` on each call. + dual_plan : DualInteractionPlan or None, optional, default=None + Precomputed dual traversal plan. If ``None``, computed from the + trees and target points on each call. + source_areas : Float[torch.Tensor, "n_sources"] or None, optional, default=None + Per-source areas for aggregate weighting. Defaults to ones. + source_aggregates : SourceAggregates or None, optional, default=None + Precomputed per-node aggregates. If ``None``, computed on each + call. Pass this to avoid redundant computation across branches. 
Outputs ------- TensorDict[str, Float[torch.Tensor, "n_targets ..."]] - TensorDict with batch_size :math:`(N_{targets},)` containing the computed - fields. Numerically identical to non-chunked :class:`Kernel` evaluation. - - Examples - -------- - >>> # For a large problem with 1M query points: - >>> kernel = ChunkedKernel( - ... n_spatial_dims=3, - ... output_fields={"pressure": "scalar"}, - ... n_source_vectors=1, - ... hidden_layer_sizes=[64, 64], - ... ) - >>> # Evaluate with automatic chunking to prevent OOM - >>> result = kernel( - ... source_points=boundary_centers, # e.g., 10k faces - ... target_points=volume_points, # e.g., 1M points - ... reference_length=torch.tensor(1.0), - ... source_vectors=TensorDict({"normal": normals}, ...), - ... chunk_size="auto", # Will process in chunks of ~10-20k points - ... ) - - Notes - ----- - During training, chunking has limited benefit because PyTorch's autograd must - store all intermediate activations regardless. Memory reduction is most effective - during inference (with ``torch.no_grad()``) where chunking can reduce peak usage - by orders of magnitude. + Approximate kernel output, converging to the exact result as + ``theta`` approaches zero. 
""" + def __init__( + self, + *, + n_spatial_dims: int, + output_field_ranks: RankSpecDict, + source_data_ranks: RankSpecDict | None = None, + global_data_ranks: RankSpecDict | None = None, + smoothing_radius: float = 1e-8, + hidden_layer_sizes: Sequence[int] | None = None, + n_spherical_harmonics: int = 4, + network_type: Literal["pade", "mlp"] = "pade", + spectral_norm: bool = False, + use_gradient_checkpointing: bool = True, + leaf_size: int = 1, + ): + super().__init__( + n_spatial_dims=n_spatial_dims, + output_field_ranks=output_field_ranks, + source_data_ranks=source_data_ranks, + global_data_ranks=global_data_ranks, + smoothing_radius=smoothing_radius, + hidden_layer_sizes=hidden_layer_sizes, + n_spherical_harmonics=n_spherical_harmonics, + network_type=network_type, + spectral_norm=spectral_norm, + use_gradient_checkpointing=use_gradient_checkpointing, + ) + self.leaf_size = leaf_size + def forward( self, *, @@ -784,114 +820,618 @@ def forward( source_points: Float[torch.Tensor, "n_sources n_dims"], target_points: Float[torch.Tensor, "n_targets n_dims"], source_strengths: Float[torch.Tensor, " n_sources"] | None = None, - source_data: TensorDict[str, Float[torch.Tensor, "n_sources ..."]] - | None = None, - global_data: TensorDict[str, Float[torch.Tensor, "..."]] | None = None, - chunk_size: None | int | Literal["auto"] = "auto", + source_data: TensorDict | None = None, + global_data: TensorDict | None = None, + theta: float = 1.0, + cluster_tree: "ClusterTree | None" = None, + target_tree: "ClusterTree | None" = None, + dual_plan: "DualInteractionPlan | None" = None, + source_areas: Float[torch.Tensor, " n_sources"] | None = None, + source_aggregates: "SourceAggregates | None" = None, + target_centroids: Float[torch.Tensor, "n_target_nodes n_dims"] | None = None, + near_chunk_size: int | None = None, + expand_far_targets: bool = False, ) -> TensorDict[str, Float[torch.Tensor, "n_targets ..."]]: - r"""Evaluates the kernel with optional chunking for memory 
efficiency. + r"""Evaluate the kernel with dual-tree Barnes-Hut acceleration. - Parameters - ---------- - chunk_size : None or int or {"auto"}, optional - Controls chunking behavior: + Uses two separate evaluation phases: - - ``"auto"``: Automatically determine chunk size based on estimated memory - usage, targeting approximately 1GB per chunk. - - ``int``: Process target points in chunks of exactly this size. - - ``None``: No chunking, evaluate all target points at once. + - **Phase A (near-field)**: individual target-source pairs from + nearby leaf nodes, evaluated exactly with chunked processing. + - **Phase B (far-field node pairs)**: the kernel is evaluated ONCE + at ``(centroid_T, centroid_S, avg_data_S)`` per well-separated + node pair, then broadcast to all individual targets in the + target node via scatter_add. - **kernel_kwargs - All arguments accepted by :meth:`Kernel.forward`, including: - ``reference_length``, ``source_points``, ``target_points``, - ``source_strengths``, ``source_data``, ``global_data``. + Parameters + ---------- + reference_length : Float[torch.Tensor, ""] + Reference length scale for nondimensionalization. + source_points : Float[torch.Tensor, "n_sources n_dims"] + Source point coordinates. + target_points : Float[torch.Tensor, "n_targets n_dims"] + Target point coordinates. + source_strengths : Float[torch.Tensor, "n_sources"] or None + Per-source strength weights. Defaults to ones. + source_data : TensorDict or None + Per-source features (normals, latents). + global_data : TensorDict or None + Problem-level conditioning features. + theta : float + Barnes-Hut opening angle (larger = more aggressive). + cluster_tree : ClusterTree or None + Precomputed source tree. Built on-the-fly if ``None``. + target_tree : ClusterTree or None + Precomputed target tree. Built on-the-fly if ``None``. + For self-interaction (comm layers), pass the same tree as + ``cluster_tree``. 
+ dual_plan : DualInteractionPlan or None + Precomputed dual traversal plan. Computed on-the-fly if ``None``. + source_areas : Float[torch.Tensor, "n_sources"] or None + Per-source areas for aggregate weighting. Defaults to ones. + source_aggregates : SourceAggregates or None + Precomputed per-node source aggregates. + target_centroids : Float[torch.Tensor, "n_target_nodes n_dims"] or None + Per-node centroids for the target tree. If ``None`` and + ``target_tree is cluster_tree`` (self-interaction), source + aggregates' centroids are reused. Otherwise computed from + the target tree. + near_chunk_size : int or None + Fixed chunk size for near-field pair processing. When provided, + overrides :meth:`_auto_chunk_size`. Pass this from an outer scope + to ensure deterministic chunking inside ``torch.utils.checkpoint`` + replay (free GPU memory changes between forward and backward, + so ``_auto_chunk_size`` would return different values). + expand_far_targets : bool, optional, default=False + If ``True``, far-field node pairs are expanded to individual + target points during plan construction, eliminating the + target-side centroid broadcast. Passed through to + :meth:`ClusterTree.find_dual_interaction_pairs`. Returns ------- TensorDict[str, Float[torch.Tensor, "n_targets ..."]] - TensorDict mapping field names to computed tensors. - Each scalar field has shape :math:`(N_{targets},)` and each vector field - has shape :math:`(N_{targets}, D)`. + Kernel output fields at target points. 
""" - n_sources: int = len(source_points) - n_targets: int = len(target_points) - n_interactions: int = n_targets * n_sources + from physicsnemo.experimental.models.globe.cluster_tree import ( + ClusterTree, + DualInteractionPlan, + SourceAggregates, + ) + from physicsnemo.mesh.spatial._ragged import _ragged_arange - if chunk_size == "auto": - approx_n_floats = n_interactions * ( - self.network_in_features - + sum(self.hidden_layer_sizes) - + self.network_out_features - ) - approx_n_bytes = ( - approx_n_floats * 4 - ) # float32; conservative enough for bfloat16 too - approx_memory_gb = approx_n_bytes / (1024**3) - target_memory_gb = 1.0 + n_sources = source_points.shape[0] + n_targets = target_points.shape[0] + device = source_points.device - n_chunks_needed = max(1, ceil(approx_memory_gb / target_memory_gb)) - chunk_size: int = max(1, ceil(n_targets / n_chunks_needed)) + ### Set defaults + if source_strengths is None: + source_strengths = torch.ones(n_sources, device=device) + if source_data is None: + source_data = TensorDict({}, batch_size=[n_sources], device=device) + if global_data is None: + global_data = TensorDict({}, device=device) + if source_areas is None: + source_areas = torch.ones(n_sources, device=device) - if not torch.compiler.is_compiling(): - logger.debug(f"Auto-chunking: {chunk_size=!r}, {n_chunks_needed=!r}") + ### Build trees if not precomputed + if cluster_tree is None: + cluster_tree = ClusterTree.from_points( + source_points, leaf_size=self.leaf_size, areas=source_areas + ) + if target_tree is None: + target_tree = ClusterTree.from_points( + target_points, leaf_size=self.leaf_size, + ) - return self.forward( - reference_length=reference_length, + ### Find dual interaction pairs if not precomputed + if dual_plan is None: + dual_plan = cluster_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=theta, + expand_far_targets=expand_far_targets, + ) + + ### Compute source aggregates for far-field clusters. 
+ if source_aggregates is not None: + aggregates = source_aggregates + else: + aggregates = cluster_tree.compute_source_aggregates( source_points=source_points, - target_points=target_points, - source_strengths=source_strengths, + areas=source_areas, source_data=source_data, - global_data=global_data, - chunk_size=chunk_size, ) - elif isinstance(chunk_size, int): - result_pieces: list[TensorDict[str, Float[torch.Tensor, "..."]]] = [] + ### Resolve target centroids for far-field node pairs. + # For self-interaction (target_tree is cluster_tree), reuse source + # centroids. For separate targets, compute from the target tree. + if target_centroids is None: + if target_tree is cluster_tree: + target_centroids = aggregates.node_centroid + else: + tgt_agg = target_tree.compute_source_aggregates( + source_points=target_points, + areas=torch.ones(n_targets, device=device, dtype=target_points.dtype), + source_data=None, + ) + target_centroids = tgt_agg.node_centroid + + with record_function("bh_kernel::compute_strengths"): + node_total_strength = self._compute_node_strengths( + cluster_tree, source_strengths + ) + + ### Prepare rank-split source/global data (shared setup) + with record_function("bh_kernel::prepare_data"): + source_by_rank = split_by_leaf_rank(source_data) + source_scalars = source_by_rank[0] + source_vectors = source_by_rank[1] + source_vectors.batch_size = torch.Size([n_sources, self.n_spatial_dims]) - start_indices = range(0, n_targets, chunk_size) + global_by_rank = split_by_leaf_rank(global_data) + global_scalars = global_by_rank[0] + global_vectors = global_by_rank[1] + global_vectors.batch_size = torch.Size([self.n_spatial_dims]) - if not torch.compiler.is_compiling() and logger.isEnabledFor(logging.DEBUG): - start_indices = tqdm.tqdm( - start_indices, - desc="Evaluating kernel in chunks", - unit=" chunks", + n_near = dual_plan.n_near + n_nf = dual_plan.n_nf + n_fn = dual_plan.n_fn + n_far_nodes = dual_plan.n_far_nodes + + if not 
torch.compiler.is_compiling(): + n_dense = n_sources * n_targets + logger.debug( + "BarnesHutKernel: %d near + %d nf + %d fn + %d far_node " + "(%d sources x %d targets = %d dense, %.2f%% near-field)", + n_near, n_nf, n_fn, n_far_nodes, + n_sources, n_targets, n_dense, + 100.0 * n_near / max(n_dense, 1), ) - for start_idx in start_indices: - end_idx = min(start_idx + chunk_size, n_targets) - target_points_chunk = target_points[start_idx:end_idx] - - chunk_result = self.forward( - reference_length=reference_length, - source_points=source_points, - target_points=target_points_chunk, - source_strengths=source_strengths, - source_data=source_data, - global_data=global_data, - chunk_size=None, + if n_near == 0 and n_nf == 0 and n_fn == 0 and n_far_nodes == 0: + return self._empty_result(n_targets, device) + + ### Prepare aggregate data for far-field and (near,far) phases + if n_far_nodes > 0 or n_nf > 0: + if aggregates.node_source_data is not None: + agg_by_rank = split_by_leaf_rank(aggregates.node_source_data) + else: + agg_by_rank = split_by_leaf_rank( + TensorDict( + {}, batch_size=[cluster_tree.n_nodes], device=device + ) + ) + agg_scalars = agg_by_rank[0] + agg_vectors = agg_by_rank[1] + agg_vectors.batch_size = torch.Size( + [cluster_tree.n_nodes, self.n_spatial_dims] ) - result_pieces.append(chunk_result) + ### Initialize output buffers + output_bufs: dict[str, torch.Tensor] = {} + + # ================================================================== + # Phase A: Near-field (individual target-source pairs, chunked) + # ================================================================== + if n_near > 0: + near_tgt_ids = dual_plan.near_target_ids + near_src_ids = dual_plan.near_source_ids + chunk_size = ( + near_chunk_size + if near_chunk_size is not None + else self._auto_chunk_size(n_near, device) + ) - result = TensorDict.cat(result_pieces, dim=0) + for start in range(0, n_near, chunk_size): + end = min(start + chunk_size, n_near) + + chunk_tgt_ids = 
near_tgt_ids[start:end] + chunk_src_ids = near_src_ids[start:end] + + ### Gather + evaluate inside one checkpoint boundary. + # By checkpointing a function that takes INDICES (int64, + # ~8 bytes/pair) and references to the shared source data + # (O(1)), the autograd graph saves only the indices - not + # the gathered float data (~300 bytes/pair). This is a + # ~37x reduction in checkpoint-saved memory per branch. + with record_function("bh_kernel::near_chunk"): + if self.training and self.use_gradient_checkpointing: + chunk_result = checkpoint( + self._gather_and_evaluate, + chunk_tgt_ids, chunk_src_ids, + target_points, source_points, + source_scalars, source_vectors, + global_scalars, global_vectors, + reference_length, device, + use_reentrant=False, + ) + else: + chunk_result = self._gather_and_evaluate( + chunk_tgt_ids, chunk_src_ids, + target_points, source_points, + source_scalars, source_vectors, + global_scalars, global_vectors, + reference_length, device, + ) + + with record_function("bh_kernel::near_scatter"): + chunk_strengths = source_strengths[chunk_src_ids] + for k, v in chunk_result.items(): + weighted = v * chunk_strengths.view(-1, *([1] * (v.ndim - 1))) + if k not in output_bufs: + output_bufs[k] = torch.zeros( + (n_targets,) + v.shape[1:], + dtype=weighted.dtype, + device=device, + ) + idx = chunk_tgt_ids.view( + -1, *([1] * (v.ndim - 1)) + ).expand_as(weighted) + output_bufs[k].scatter_add_(0, idx, weighted) + + # ================================================================== + # Phase B: Far-field node pairs (evaluate once, broadcast to targets) + # ================================================================== + if n_far_nodes > 0: + far_tgt_nids = dual_plan.far_target_node_ids + far_src_nids = dual_plan.far_source_node_ids + + ### Evaluate kernel at (centroid_T, centroid_S, avg_data_S). + # Same gather-inside-checkpoint pattern: the checkpoint saves + # only the node ID indices, not the gathered aggregate data. 
+ with record_function("bh_kernel::far_node_evaluate"): + if self.training and self.use_gradient_checkpointing: + far_result = checkpoint( + self._gather_and_evaluate, + far_tgt_nids, far_src_nids, + target_centroids, aggregates.node_centroid, + agg_scalars, agg_vectors, + global_scalars, global_vectors, + reference_length, device, + use_reentrant=False, + ) + else: + far_result = self._gather_and_evaluate( + far_tgt_nids, far_src_nids, + target_centroids, aggregates.node_centroid, + agg_scalars, agg_vectors, + global_scalars, global_vectors, + reference_length, device, + ) - return result + ### Broadcast node-level results to individual targets. + with record_function("bh_kernel::far_node_broadcast"): + far_strengths = node_total_strength[far_src_nids] + + node_starts = target_tree.node_range_start[far_tgt_nids] + node_counts = target_tree.node_range_count[far_tgt_nids] + positions, pair_ids = _ragged_arange(node_starts, node_counts) + expanded_tgt_ids = target_tree.sorted_source_order[positions] + + for k, v in far_result.items(): + weighted = v * far_strengths.view(-1, *([1] * (v.ndim - 1))) + expanded = weighted[pair_ids] + if k not in output_bufs: + output_bufs[k] = torch.zeros( + (n_targets,) + v.shape[1:], + dtype=expanded.dtype, + device=device, + ) + idx = expanded_tgt_ids.view( + -1, *([1] * (v.ndim - 1)) + ).expand_as(expanded) + output_bufs[k].scatter_add_(0, idx, expanded) + + # ================================================================== + # Phase C: (near,far) - individual targets × source node centroids + # ================================================================== + if n_nf > 0: + nf_tgt_ids = dual_plan.nf_target_ids + nf_src_nids = dual_plan.nf_source_node_ids + + ### Same evaluation as Phase B (source centroids + aggregates), + # but same scatter as Phase A (per-target, no broadcast). 
+ with record_function("bh_kernel::nf_evaluate"): + if self.training and self.use_gradient_checkpointing: + nf_result = checkpoint( + self._gather_and_evaluate, + nf_tgt_ids, nf_src_nids, + target_points, aggregates.node_centroid, + agg_scalars, agg_vectors, + global_scalars, global_vectors, + reference_length, device, + use_reentrant=False, + ) + else: + nf_result = self._gather_and_evaluate( + nf_tgt_ids, nf_src_nids, + target_points, aggregates.node_centroid, + agg_scalars, agg_vectors, + global_scalars, global_vectors, + reference_length, device, + ) - elif chunk_size is None: - return super().forward( - reference_length=reference_length, - source_points=source_points, - target_points=target_points, - source_strengths=source_strengths, - source_data=source_data, - global_data=global_data, + with record_function("bh_kernel::nf_scatter"): + nf_strengths = node_total_strength[nf_src_nids] + for k, v in nf_result.items(): + weighted = v * nf_strengths.view(-1, *([1] * (v.ndim - 1))) + if k not in output_bufs: + output_bufs[k] = torch.zeros( + (n_targets,) + v.shape[1:], + dtype=weighted.dtype, + device=device, + ) + idx = nf_tgt_ids.view( + -1, *([1] * (v.ndim - 1)) + ).expand_as(weighted) + output_bufs[k].scatter_add_(0, idx, weighted) + + # ================================================================== + # Phase D: (far,near) - target node centroid × individual sources, + # broadcast to stage-1 survivors + # ================================================================== + if n_fn > 0: + fn_tgt_nids = dual_plan.fn_target_node_ids + fn_src_ids = dual_plan.fn_source_ids + + ### Evaluate K(target_centroid, source_point, source_data). + # Uses target centroids (like Phase B) but individual source + # points and data (like Phase A). 
+ with record_function("bh_kernel::fn_evaluate"): + if self.training and self.use_gradient_checkpointing: + fn_result = checkpoint( + self._gather_and_evaluate, + fn_tgt_nids, fn_src_ids, + target_centroids, source_points, + source_scalars, source_vectors, + global_scalars, global_vectors, + reference_length, device, + use_reentrant=False, + ) + else: + fn_result = self._gather_and_evaluate( + fn_tgt_nids, fn_src_ids, + target_centroids, source_points, + source_scalars, source_vectors, + global_scalars, global_vectors, + reference_length, device, + ) + + ### Broadcast to stage-1 survivors via the ragged mapping. + with record_function("bh_kernel::fn_broadcast"): + fn_strengths = source_strengths[fn_src_ids] + + positions, pair_ids = _ragged_arange( + dual_plan.fn_broadcast_starts, + dual_plan.fn_broadcast_counts, + ) + expanded_tgt_ids = dual_plan.fn_broadcast_targets[positions] + + for k, v in fn_result.items(): + weighted = v * fn_strengths.view(-1, *([1] * (v.ndim - 1))) + expanded = weighted[pair_ids] + if k not in output_bufs: + output_bufs[k] = torch.zeros( + (n_targets,) + v.shape[1:], + dtype=expanded.dtype, + device=device, + ) + idx = expanded_tgt_ids.view( + -1, *([1] * (v.ndim - 1)) + ).expand_as(expanded) + output_bufs[k].scatter_add_(0, idx, expanded) + + if not output_bufs: + return self._empty_result(n_targets, device) + + return TensorDict( + output_bufs, + batch_size=torch.Size([n_targets]), + device=device, + ) + + def _compute_node_strengths( + self, + tree: "ClusterTree", + source_strengths: Float[torch.Tensor, " n_sources"], + ) -> Float[torch.Tensor, " n_nodes"]: + """Compute total source strength per tree node via bottom-up summation. + + Parameters + ---------- + tree : ClusterTree + The spatial cluster tree. + source_strengths : Float[torch.Tensor, "n_sources"] + Per-source strength values. + + Returns + ------- + torch.Tensor + Total strength per node, shape ``(n_nodes,)``. 
+ """ + device = source_strengths.device + n_nodes = tree.n_nodes + node_strengths = torch.zeros(n_nodes, dtype=source_strengths.dtype, device=device) + + is_leaf = tree.leaf_count > 0 + leaf_ids = torch.where(is_leaf)[0] + + if leaf_ids.numel() == 0: + return node_strengths + + ### Sum strengths within each leaf + leaf_starts = tree.leaf_start[leaf_ids] + leaf_counts = tree.leaf_count[leaf_ids] + n_leaves = leaf_ids.shape[0] + + if int(leaf_counts.sum()) > 0: + from physicsnemo.mesh.spatial._ragged import _ragged_arange + + positions, seg_ids = _ragged_arange( + leaf_starts, leaf_counts, total=tree.n_sources, ) + sorted_strengths = source_strengths[tree.sorted_source_order[positions]] + leaf_sums = torch.zeros(n_leaves, dtype=source_strengths.dtype, device=device) + leaf_sums.scatter_add_(0, seg_ids, sorted_strengths) + node_strengths[leaf_ids] = leaf_sums + + ### Bottom-up propagation using cached level ordering + for level_ids in reversed(tree.internal_nodes_per_level): + node_strengths[level_ids] = ( + node_strengths[tree.node_left_child[level_ids]] + + node_strengths[tree.node_right_child[level_ids]] + ) + + return node_strengths + + def _empty_result( + self, + n_targets: int, + device: torch.device, + ) -> TensorDict[str, Float[torch.Tensor, "n_targets ..."]]: + """Produce a zero-valued result TensorDict for the degenerate case.""" + # Match the dtype that AMP autocast would produce for real activations, + # so downstream ops don't hit a float32-vs-half mismatch. 
+ dtype = ( + torch.get_autocast_dtype(device.type) + if torch.is_autocast_enabled(device.type) + else torch.float32 + ) + ranks_dict = flatten_rank_spec(self.output_field_ranks) + fields: dict[str, torch.Tensor] = {} + for name, rank in sorted(ranks_dict.items()): + if rank == 0: + fields[name] = torch.zeros(n_targets, device=device, dtype=dtype) + else: + fields[name] = torch.zeros( + n_targets, self.n_spatial_dims, device=device, dtype=dtype + ) + return TensorDict(fields, batch_size=torch.Size([n_targets]), device=device) + + def _gather_and_evaluate( + self, + tgt_ids: torch.Tensor, + src_ids: torch.Tensor, + target_positions: torch.Tensor, + source_positions: torch.Tensor, + source_scalars: TensorDict, + source_vectors: TensorDict, + global_scalars: TensorDict, + global_vectors: TensorDict, + reference_length: torch.Tensor, + device: torch.device, + ) -> TensorDict: + """Gather source/target data by index and evaluate interactions. + + This function is the checkpoint boundary for memory-efficient + training. By wrapping both the gather (indexing into shared + source data) and the evaluate (feature engineering + MLP) in one + checkpointed call, the autograd graph saves only the int64 index + tensors (~8 bytes/pair) and references to the shared source data + (O(1)), instead of the gathered float features (~300 bytes/pair). + + Source scalars and vectors are pre-flattened via + ``concatenate_leaves`` before indexing, reducing K per-leaf index + ops to 1 cat + 1 index each. Vectors are split back into + individual named leaves afterward because the feature engineering + pipeline in ``_evaluate_interactions`` processes each vector + separately (magnitudes, dot products, basis construction). + """ + n_pairs = tgt_ids.shape[0] + chunk_r = ( + target_positions[tgt_ids] - source_positions[src_ids] + ) / reference_length + + ### Flatten source scalars into one tensor, gather once. 
+ # concatenate_leaves: 1 GPU kernel (torch.cat) + # [src_ids]: 1 GPU kernel (aten::index) + # Total: 2 kernels instead of K (one per TensorDict leaf). + gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids] + scalars = TensorDict( + { + "source_scalars": gathered_src_scalars, + "global_scalars": global_scalars.expand( + n_pairs, *global_scalars.batch_size + ), + }, + batch_size=torch.Size([n_pairs]), + device=device, + ) + + ### Flatten source vectors, gather once, split back into named leaves. + # The split-back is required because _evaluate_interactions processes + # each vector leaf separately for magnitude/direction extraction and + # rotationally-equivariant basis construction. Integer indexing + # along the last dimension creates non-contiguous views (zero copies). + src_vector_keys = list( + source_vectors.keys(include_nested=True, leaves_only=True) + ) + gathered_src_vectors = concatenate_leaves(source_vectors)[src_ids] + gathered_vector_leaves = { + k: gathered_src_vectors[..., i] + for i, k in enumerate(src_vector_keys) + } + vectors = TensorDict( + { + "source_vectors": TensorDict( + gathered_vector_leaves, + batch_size=torch.Size([n_pairs, self.n_spatial_dims]), + device=device, + ), + "global_vectors": global_vectors.expand( + torch.Size([n_pairs]) + global_vectors.batch_size + ), + }, + batch_size=torch.Size([n_pairs, self.n_spatial_dims]), + device=device, + ) + vectors["r"] = chunk_r + + return self._evaluate_interactions(scalars=scalars, vectors=vectors, device=device) + + def _auto_chunk_size(self, n_total_pairs: int, device: torch.device) -> int: + """Determine chunk size for pair-batched kernel evaluation. + + Estimates peak memory per pair from the kernel's feature engineering + pipeline and sizes chunks to fit within ~50% of GPU memory. During + inference (no grad), the autograd overhead multiplier is dropped, + allowing larger chunks. 
+ Returns ``n_total_pairs`` (i.e., no chunking) when the estimated + peak fits comfortably, or when running on CPU. + """ + if device.type != "cuda": + return n_total_pairs + + if torch.is_autocast_enabled(device.type): + element_bytes = torch.tensor( + [], dtype=torch.get_autocast_dtype(device.type) + ).element_size() else: - raise ValueError( - f"Got {chunk_size=!r}; this must be one of ['auto', int, None]" + element_bytes = 4 # fp32 + + autograd_overhead = 5 if torch.is_grad_enabled() else 1 + approx_peak_bytes = ( + n_total_pairs + * self._floats_per_interaction + * element_bytes + * autograd_overhead + ) + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + target_bytes = free_bytes // 2 + + n_chunks = max(1, ceil(approx_peak_bytes / target_bytes)) + chunk_size = max(1, ceil(n_total_pairs / n_chunks)) + + if not torch.compiler.is_compiling(): + logger.debug( + "auto_chunk_size: %d pairs -> %d chunks of %d " + "(%.1f MB est. peak, %.1f MB free / %.1f MB total GPU)", + n_total_pairs, n_chunks, chunk_size, + approx_peak_bytes / 1e6, free_bytes / 1e6, total_bytes / 1e6, ) + return chunk_size + class MultiscaleKernel(Module): r"""Multiscale kernel composition that linearly combines kernels at different length scales. @@ -964,13 +1504,23 @@ class MultiscaleKernel(Module): ``reference_length_names``. Defaults to all ones. source_data : TensorDict or None, optional, default=None Per-source features with ``batch_size=(N_sources,)``. Mixed-rank - TensorDict passed through to each :class:`ChunkedKernel` branch. + TensorDict passed through to each :class:`BarnesHutKernel` branch. global_data : TensorDict or None, optional, default=None Problem-level features with ``batch_size=()``. Automatically augmented with log-ratios of reference lengths before being passed to each kernel branch. - chunk_size : None or int or {"auto"}, optional, default="auto" - Chunking behavior. + theta : float, optional, default=1.0 + Barnes-Hut opening angle (larger = more aggressive). 
+ cluster_tree : ClusterTree or None, optional, default=None + Pre-built cluster tree for source points. If ``None``, one is + built from ``source_points`` using the kernel's ``leaf_size``. + target_tree : ClusterTree or None, optional, default=None + Pre-built target tree. For self-interaction, pass the same tree + as ``cluster_tree``. + dual_plan : DualInteractionPlan or None, optional, default=None + Pre-computed dual traversal plan. If ``None``, computed from trees. + source_areas : Float[torch.Tensor, " n_sources"] or None, optional, default=None + Area weight per source, used for cluster aggregation. Outputs ------- @@ -1013,6 +1563,7 @@ def __init__( network_type: Literal["pade", "mlp"] = "pade", spectral_norm: bool = False, use_gradient_checkpointing: bool = True, + leaf_size: int = 1, ): super().__init__() @@ -1032,6 +1583,7 @@ def __init__( self.network_type = network_type self.spectral_norm = spectral_norm self.use_gradient_checkpointing = use_gradient_checkpointing + self.leaf_size = leaf_size ### Augment global_data_ranks with log-ratio entries for each # pair of reference lengths. These are rank-0 (scalar) features. 
@@ -1045,7 +1597,7 @@ def __init__( self.kernels = nn.ModuleDict( { - name: ChunkedKernel( + name: BarnesHutKernel( n_spatial_dims=n_spatial_dims, output_field_ranks=output_field_ranks, source_data_ranks=source_data_ranks, @@ -1056,6 +1608,7 @@ def __init__( network_type=network_type, spectral_norm=spectral_norm, use_gradient_checkpointing=use_gradient_checkpointing, + leaf_size=leaf_size, ) for name in reference_length_names } @@ -1076,46 +1629,56 @@ def forward( source_data: TensorDict[str, Float[torch.Tensor, "n_sources ..."]] | None = None, global_data: TensorDict[str, Float[torch.Tensor, "..."]] | None = None, - chunk_size: None | int | Literal["auto"] = "auto", + theta: float = 1.0, + cluster_tree: "ClusterTree | None" = None, + target_tree: "ClusterTree | None" = None, + dual_plan: "DualInteractionPlan | None" = None, + source_areas: Float[torch.Tensor, " n_sources"] | None = None, + expand_far_targets: bool = False, ) -> TensorDict[str, Float[torch.Tensor, "n_targets ..."]]: r"""Evaluates the multiscale kernel by combining results from multiple scales. - Evaluates each constituent kernel at its respective reference length - (scaled by a learnable factor), automatically adds log-ratios of - reference lengths to ``global_data`` as scalar features, and sums - the results across all scales. + Builds a shared :class:`ClusterTree` and :class:`DualInteractionPlan` + once, then evaluates each :class:`BarnesHutKernel` branch at its + respective reference length. Parameters ---------- reference_lengths : dict[str, torch.Tensor] Mapping of reference length names to scalar tensors. source_points : Float[torch.Tensor, "n_sources n_dims"] - Tensor of shape :math:`(N_{sources}, D)`. Physical coordinates of - the source points. + Source point coordinates, shape :math:`(N_{sources}, D)`. target_points : Float[torch.Tensor, "n_targets n_dims"] - Tensor of shape :math:`(N_{targets}, D)`. Physical coordinates of - the target points. 
- source_strengths : TensorDict[str, Float[torch.Tensor, " n_sources"]] or None, optional - Per-source, per-branch strength values, keyed by - ``reference_length_names``. Defaults to all ones. + Target point coordinates, shape :math:`(N_{targets}, D)`. + source_strengths : TensorDict or None, optional + Per-source, per-branch strength values. Defaults to all ones. source_data : TensorDict or None, optional - Per-source features with ``batch_size=(N_sources,)``. Passed - through to each :class:`ChunkedKernel` branch unchanged. + Per-source features with ``batch_size=(N_sources,)``. global_data : TensorDict or None, optional - Problem-level features with ``batch_size=()``. Augmented with - log-ratios of reference lengths before being passed to each - kernel branch. - chunk_size : None or int or {"auto"}, optional - Chunking behavior passed to :meth:`ChunkedKernel.forward`. - Default is ``"auto"``. + Problem-level features with ``batch_size=()``. + theta : float + Barnes-Hut opening angle (larger = more aggressive). + cluster_tree : ClusterTree or None, optional + Precomputed source tree. Built from ``source_points`` if ``None``. + target_tree : ClusterTree or None, optional + Precomputed target tree. Built from ``target_points`` if ``None``. + dual_plan : DualInteractionPlan or None, optional + Precomputed dual traversal plan. Computed if ``None``. + source_areas : Float[torch.Tensor, "n_sources"] or None, optional + Per-source areas for aggregate weighting. Defaults to ones. + expand_far_targets : bool, optional, default=False + If ``True``, eliminates target-side centroid broadcast by + expanding far-field node pairs to individual target points. + Passed through to + :meth:`ClusterTree.find_dual_interaction_pairs`. Returns ------- TensorDict[str, Float[torch.Tensor, "n_targets ..."]] - Dictionary mapping field names to the summed results from all kernels. 
- Each scalar field has shape :math:`(N_{targets},)` and each vector field - has shape :math:`(N_{targets}, D)`. + Summed results from all kernel branches. """ + from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree + n_sources: int = len(source_points) device = source_points.device @@ -1133,6 +1696,8 @@ def forward( source_data = TensorDict({}, batch_size=[n_sources], device=device) if global_data is None: global_data = TensorDict({}, device=device) + if source_areas is None: + source_areas = torch.ones(n_sources, device=device) # Skip validation when running under torch.compile for performance if not torch.compiler.is_compiling(): @@ -1152,6 +1717,28 @@ def forward( f"but the forward-method input gives {actual} {name}." ) + ### Build shared trees, dual plan, and aggregates (reused across branches) + with record_function("multiscale_kernel::build_tree"): + if cluster_tree is None: + cluster_tree = ClusterTree.from_points( + source_points, leaf_size=self.leaf_size, areas=source_areas, + ) + if target_tree is None: + target_tree = ClusterTree.from_points( + target_points, leaf_size=self.leaf_size, + ) + if dual_plan is None: + dual_plan = cluster_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=theta, + expand_far_targets=expand_far_targets, + ) + with record_function("multiscale_kernel::compute_aggregates"): + source_aggregates = cluster_tree.compute_source_aggregates( + source_points=source_points, + areas=source_areas, + source_data=source_data, + ) + ### Augment global_data with log-ratios of reference lengths. 
log_ratios = TensorDict( { @@ -1164,21 +1751,102 @@ def forward( }, device=device, ) + global_data = global_data.copy() global_data["log_reference_length_ratios"] = log_ratios - results_pieces: list[TensorDict[str, Float[torch.Tensor, "n_targets ..."]]] = [ - self.kernels[name]( - reference_length=reference_lengths[name] - * torch.exp(self.log_scalefactors[name]), - source_points=source_points, - target_points=target_points, - source_strengths=source_strengths[name], - source_data=source_data, - global_data=global_data, - chunk_size=chunk_size, + ### Precompute near-field chunk sizes outside the checkpoint boundary. + # _auto_chunk_size queries free GPU memory, which differs between + # forward and checkpoint replay (backward). Computing here ensures + # each branch's chunk size is a fixed checkpoint input. + near_chunk_sizes: dict[str, int] = { + name: self.kernels[name]._auto_chunk_size( + dual_plan.n_near, source_points.device ) for name in self.reference_length_names - ] + } + + ### Decide whether branch-level checkpointing is worthwhile. + # Each branch accumulates ~34 bytes/near-pair of autograd state + # (int64 checkpoint-saved indices + multiply/scatter graph nodes). + # Branch checkpointing avoids holding all branches' graphs + # simultaneously, which is essential at large N (800k+ faces) + # but a pure compute overhead at small N (20k faces). + _AUTOGRAD_BYTES_PER_PAIR = 34 + n_branches = len(self.reference_length_names) + use_branch_ckpt = False + if self.training and self.use_gradient_checkpointing and n_branches > 1: + n_total_pairs = dual_plan.n_near + dual_plan.n_nf + dual_plan.n_fn + per_branch_bytes = n_total_pairs * _AUTOGRAD_BYTES_PER_PAIR + all_branches_bytes = per_branch_bytes * n_branches + if device.type == "cuda": + free_bytes = torch.cuda.mem_get_info(device)[0] + use_branch_ckpt = all_branches_bytes > free_bytes * 0.1 + else: + use_branch_ckpt = False + + if not torch.compiler.is_compiling(): + logger.debug( + "branch checkpoint: %s (est. 
%.1f MB/branch, " + "%.1f MB all branches, %.1f MB free, %d branches)", + "ENABLED" if use_branch_ckpt else "DISABLED", + per_branch_bytes / 1e6, + all_branches_bytes / 1e6, + free_bytes / 1e6 if device.type == "cuda" else 0, + n_branches, + ) + + ### Evaluate each branch with the shared tree, plan, and aggregates. + # When enabled, branch-level checkpointing ensures only ONE branch's + # autograd graph exists at a time during backward, preventing + # autograd memory from accumulating across all branches. + results_pieces: list[TensorDict[str, Float[torch.Tensor, "n_targets ..."]]] = [] + for name in self.reference_length_names: + with record_function(f"multiscale_kernel::branch/{name}"): + ref_length = ( + reference_lengths[name] + * torch.exp(self.log_scalefactors[name]) + ) + strengths = source_strengths[name] + chunk_size = near_chunk_sizes[name] + kernel = self.kernels[name] + if use_branch_ckpt: + results_pieces.append( + checkpoint( + kernel, + use_reentrant=False, + reference_length=ref_length, + source_points=source_points, + target_points=target_points, + source_strengths=strengths, + source_data=source_data, + global_data=global_data, + theta=theta, + cluster_tree=cluster_tree, + target_tree=target_tree, + dual_plan=dual_plan, + source_areas=source_areas, + source_aggregates=source_aggregates, + near_chunk_size=chunk_size, + ) + ) + else: + results_pieces.append( + kernel( + reference_length=ref_length, + source_points=source_points, + target_points=target_points, + source_strengths=strengths, + source_data=source_data, + global_data=global_data, + theta=theta, + cluster_tree=cluster_tree, + target_tree=target_tree, + dual_plan=dual_plan, + source_areas=source_areas, + source_aggregates=source_aggregates, + near_chunk_size=chunk_size, + ) + ) result: TensorDict[str, Float[torch.Tensor, "n_targets ..."]] = reduce( operator.add, results_pieces diff --git a/physicsnemo/experimental/models/globe/hierarchical_acceleration.md 
b/physicsnemo/experimental/models/globe/hierarchical_acceleration.md new file mode 100644 index 0000000000..eb95c33523 --- /dev/null +++ b/physicsnemo/experimental/models/globe/hierarchical_acceleration.md @@ -0,0 +1,585 @@ +# Hierarchical Acceleration for GLOBE + +This document describes the dual-tree hierarchical acceleration applied to +GLOBE's field kernel evaluation, reducing the O(N^2) all-to-all interaction +cost to O(N log N). It assumes familiarity with the base GLOBE architecture +(the whitepaper's Sections 3-4) and focuses on the acceleration strategy. + +--- + +## 1. Motivation + +GLOBE's field kernel computes, for each target point, the influence of *every* +source face on the boundary mesh. This produces an `(N_tgt, N_src, D)` +displacement tensor, followed by per-pair feature engineering, neural network +evaluation, and an aggregation sum over sources. The cost is +O(N_tgt * N_src) - quadratic in the mesh size. + +This quadratic cost appears in two places: + +- **Communication hyperlayers** (boundary-to-boundary): N_src = N_tgt = N_faces. + With N_faces = 20k, this is 400M interactions per layer. +- **Final prediction** (boundary-to-volume): N_src = N_faces, N_tgt = N_prediction. + At DrivAerML scale (100k+ faces, 180k prediction points), this is 18 billion + interactions. + +The key observation enabling acceleration is GLOBE's explicit far-field decay +envelope. The kernel output is multiplied by a Lamb-Oseen-like factor +`(1 - exp(-|r|^2)) / (|r|^2 + 1)^p` that forces contributions to decay as +`1/r^(d-1)` at large distances. This means distant sources contribute weakly, +and grouping them into clusters introduces only small approximation error. + +--- + +## 2. 
The Monopole Approximation + +For a target point far from a cluster C_S of source faces, the exact sum + +```text +exact = sum_{s in C_S} strength_s * K(target, source_s, data_s) +``` + +is approximated by + +```text +approx = total_strength_{C_S} * K(target, centroid_{C_S}, avg_data_{C_S}) +``` + +where: + +- `centroid_{C_S}` is the area-weighted centroid of sources in C_S +- `avg_data_{C_S}` is the area-weighted average of source features (normals, + latent scalars/vectors) +- `total_strength_{C_S} = sum_{s in C_S} strength_s` is the sum of learned + per-source strengths + +The same neural network evaluates both exact and approximate interactions - +cluster centroids are treated as "virtual sources" with averaged features. +This is a zeroth-order (monopole) Taylor expansion of the kernel about the +cluster centroid. + +### Dual-tree extension: node-to-node evaluation + +The dual-tree variant goes further. When a cluster C_T of *targets* is +well-separated from a cluster C_S of *sources*, the kernel varies slowly +across all targets in C_T (the target cluster is small relative to the +inter-cluster distance). The kernel is evaluated **once** at the pair of +centroids `(centroid_{C_T}, centroid_{C_S})` and the result is broadcast to +all individual targets in C_T via scatter-add. + +This reduces far-field evaluations from O(N_target * #source_nodes) to +O(#node_pairs), which is typically O(N) for well-separated geometries. + +### Why area-weighting, not strength-weighting? + +The spatial averages (centroid, feature means) use *area*-weighting, while the +multiplicative strength factor is summed separately. Areas are fixed +geometric properties of the mesh (always positive, always stable), making the +aggregates reusable across kernel branches (the `MultiscaleKernel` has +multiple branches sharing the same source geometry). Strengths, by contrast, +are learned per-source and per-branch values that change between communication +layers. 
Separating these concerns means: + +1. Aggregates are computed once per forward pass and shared across branches. +2. Only strength summation is per-branch (cheap O(N) work). +3. The aggregation is numerically stable (no division by near-zero when + learned strengths cancel within a cluster). + +--- + +## 3. Spatial Data Structure: ClusterTree + +### 3.1 Construction via LBVH + +The tree is built using a Linear Bounding Volume Hierarchy (LBVH) algorithm +(Karras 2012), the same approach used in PhysicsNeMo Mesh's existing `BVH` +class for mesh spatial decomposition: + +1. **Morton codes**: Each point is assigned a 63-bit Morton code that + interleaves the quantized coordinates. Morton codes produce a + space-filling Z-curve ordering that preserves spatial locality - nearby + points in space tend to have nearby codes. + +2. **Sort**: Points are sorted by Morton code. After sorting, spatially + nearby points are contiguous in the array. + +3. **Top-down recursive splitting**: Starting from the full sorted range as + the root, each segment with more than `leaf_size` points is split at its + midpoint. Because Morton-sorted order preserves spatial locality, midpoint + splitting approximates a spatial median split, producing a balanced binary + tree. Each iteration processes all segments at the current depth in + parallel, yielding O(log N) Python-level iterations. + +4. **Bottom-up axis-aligned bounding box (AABB) propagation**: Leaf AABBs are + computed from the actual points they contain. Internal node AABBs are the + union of their children's AABBs. Total areas are similarly propagated + (sum, not average). + +The tree is stored as flat tensor arrays (`node_aabb_min`, `node_aabb_max`, +`node_left_child`, etc.) indexed by node ID, making it fully GPU-compatible. + +### 3.2 Node Pre-allocation Bounds + +Before construction, arrays are pre-allocated at the worst-case node count. 
+The midpoint split guarantees each child gets at least `floor(parent_size/2)` +sources, so the minimum leaf occupancy is `ceil(leaf_size/2)`. The maximum +number of leaves is `ceil(N / min_per_leaf)`, and by the full-binary-tree +identity (`n_internal = n_leaves - 1`), the maximum total node count is +`2 * max_leaves - 1`. After construction, the arrays are trimmed to the +actual count. + +### 3.3 Source Aggregates + +Per-node aggregate data is computed bottom-up for far-field evaluation: + +- **Centroid**: area-weighted mean of source positions +- **Source features** (normals, latent scalars/vectors): area-weighted mean + via `TensorDict.apply()` with segmented scatter operations +- **Total area**: sum (not average) of children's areas + +Internal node aggregates are computed from their children's aggregates using +area-weighted averaging via a BFS level-ordering: internal nodes are +discovered by depth, then processed deepest-first so children are correct +before their parents read from them. + +Aggregates depend on the source data (which changes between communication +layers as latent features are updated) but NOT on the tree structure (which +depends only on geometry). The tree is built once per forward pass; aggregates +are recomputed each time the source data changes. + +--- + +## 4. Dual-Tree Traversal + +The classical Barnes-Hut algorithm pairs each *individual* target point with +tree nodes, yielding O(N_tgt * log N_src) far-field evaluations. The +dual-tree variant builds trees for **both** sources and targets, then +traverses pairs of nodes from the two trees simultaneously. This produces +far-field node-to-node pairs whose count can be as low as O(N). 
+
+### 4.1 Acceptance Criterion
+
+The dual-tree acceptance criterion generalizes the single-tree Barnes-Hut
+opening test by accounting for the spatial extent of *both* nodes:
+
+```text
+(D_T + D_S) / r < theta
+```
+
+where D_T and D_S are the AABB diagonals of the target and source nodes, and
+r is the minimum distance between the two AABBs (gap distance). In code:
+
+```python
+# Per-dimension gap between AABBs (0 where they overlap)
+gap = torch.clamp(
+    torch.maximum(aabb_min_T - aabb_max_S, aabb_min_S - aabb_max_T),
+    min=0,
+)
+min_dist_sq = gap.pow(2).sum(dim=-1)
+
+combined_diam_sq = (diam_T + diam_S).pow(2)
+
+is_far = min_dist_sq * theta_sq > combined_diam_sq
+```
+
+The combined-diameter criterion is more conservative than the single-tree test
+(which effectively sets D_T = 0). This is appropriate because the far-field
+broadcast approximation assumes the kernel is roughly constant across the
+*target* node as well - an assumption that degrades when the target node is
+large relative to the inter-node distance.
+
+When the AABBs overlap (`gap = 0` in every dimension, so `min_dist_sq = 0`),
+the criterion always fails, forcing refinement. This eliminates edge cases
+where a node's centroid might be close to the cluster boundary.
+
+### 4.2 Theta Parameter Semantics
+
+The `theta` parameter follows the standard Barnes-Hut convention (Barnes &
+Hut 1986):
+
+- **Larger theta** = more aggressive (more approximations, faster).
+- **Smaller theta** = more conservative (more exact interactions, slower).
+- **theta = 0** = all interactions are exact (no approximation).
+
+Typical values for GLOBE: `theta = 0.5` (conservative) to `theta = 1.5`
+(aggressive). The default is `theta = 1.0`.
+
+### 4.3 Breadth-First Traversal
+
+The traversal processes all active (target_node, source_node) pairs at each
+level simultaneously:
+
+1. **Initialize**: the single pair `(root_T, root_S)`.
+2. 
**For each iteration** (bounded by `depth_T + depth_S + 1`), classify
+   active pairs into three categories:
+
+   - **Far-field**: passes the acceptance criterion. Record the
+     `(target_node, source_node)` pair.
+   - **Near-field leaves**: fails the criterion, and BOTH nodes are leaves.
+     Expand into the Cartesian product of individual targets and sources
+     within those leaves.
+   - **Needs refinement**: fails the criterion, and at least one node is
+     internal. Split into child pairs for the next iteration.
+
+3. **Splitting rule**: Split the node with the larger AABB diameter (by
+   squared diagonal). If both nodes have equal diameter, split both. If one
+   side is a leaf, only the other can be split.
+
+   - Split target only: 2 child pairs `(left_T, S)` and `(right_T, S)`.
+   - Split source only: 2 child pairs `(T, left_S)` and `(T, right_S)`.
+   - Split both: 4 child pairs `(left_T, left_S)`, `(left_T, right_S)`,
+     `(right_T, left_S)`, `(right_T, right_S)`.
+
+4. **Post-processing**: Near pairs are sorted by source index, far pairs by
+   source node, for cache-friendly memory access during kernel evaluation.
+
+The output is a `DualInteractionPlan` containing four index arrays:
+
+- `(near_target_ids, near_source_ids)`: individual target-source pairs
+  requiring exact evaluation.
+- `(far_target_node_ids, far_source_node_ids)`: node-to-node pairs using the
+  monopole approximation with target-side broadcast.
+
+### 4.4 Self-Interaction and Cross-BC Interaction
+
+For communication layers with a single BC type (or the self-interaction
+portion of a multi-BC model), the same `ClusterTree` is used for both the
+source and target sides. The traversal starts with `(root, root)` and
+proceeds normally. The splitting rule defaults to splitting both nodes when
+diameters are equal, as at the initial `(root, root)` pair in self-interaction
+(both sides reference the same tree). 
+ +When multiple BC types are present, communication layers also evaluate +cross-BC interactions: source BC "A" contributes to destination BC "B" and +vice versa. For cross-BC pairs, the source tree and target tree are different +objects (built from different point sets), and a separate +`DualInteractionPlan` is computed for each (source BC, destination BC) pair. +This produces B^2 plans for B BC types. Since B is small in practice (1-4), +the additional traversal cost is negligible. + +### 4.5 Caching Interaction Plans + +The interaction plan depends only on the geometric positions of sources and +targets, not on the source data or strengths. For communication hyperlayers, +all B^2 plans (covering both self-interaction and cross-BC pairs) are computed +once and reused across all layers. For the final prediction evaluation, +separate plans are computed from each source BC tree to the prediction-point +target tree. This eliminates redundant traversals. + +--- + +## 5. Two-Phase Kernel Evaluation + +`BarnesHutKernel.forward()` evaluates near-field and far-field interactions in +two distinct phases, each with its own gather-evaluate-scatter pipeline. The +same `_evaluate_interactions()` method handles both - it operates on generic +`(N_pairs, ...)` tensors and is agnostic to whether the pairs are individual +points or node centroids. + +### 5.1 Phase A: Near-Field (Individual Pairs) + +Near-field pairs are individual (target, source) interactions requiring exact +kernel evaluation. They are processed in chunks: + +1. **Chunk the pair arrays**: Slice `near_target_ids[start:end]` and + `near_source_ids[start:end]`. +2. **Gather**: Index into the shared source/target point arrays and feature + data to build per-chunk float tensors. +3. **Evaluate**: Run `_evaluate_interactions()` (feature engineering + MLP + + post-processing). +4. **Weight and scatter**: Multiply by per-source strengths, then + `scatter_add` into the output buffer at the target indices. 
+ +### 5.2 Phase B: Far-Field (Node Pairs with Broadcast) + +Far-field pairs are node-to-node interactions that exploit the monopole +approximation with target-side broadcast: + +1. **Gather**: Index into node centroids and aggregate features for both + the target nodes (`far_target_node_ids`) and source nodes + (`far_source_node_ids`). +2. **Evaluate**: Run `_evaluate_interactions()` at the centroid pair, yielding + one result per node pair. +3. **Weight**: Multiply by total source-node strength. +4. **Broadcast to targets**: For each target node, use `_ragged_arange` to + expand the node-level result to all individual targets within that node, + then `scatter_add` to the output buffer. + +The broadcast step uses the target tree's `node_range_start` and +`node_range_count` arrays to find which individual targets belong to each +target node, and `sorted_source_order` to map back to original target indices. + +### 5.3 The _evaluate_interactions() Factoring + +The core feature engineering pipeline (vector magnitudes, spherical harmonics, +network evaluation, far-field decay, vector reprojection) lives in a shared +`_evaluate_interactions()` method. This method operates on generic +`(*interaction_dims, ...)` tensors - it does not know or care whether the +interactions are dense `(N_tgt, N_src)` or sparse `(N_chunk,)`. + +- `Kernel.forward()` calls it with `interaction_dims = (N_tgt, N_src)` (dense, + brute-force evaluation) +- `BarnesHutKernel.forward()` calls it with `interaction_dims = (N_chunk,)` in + both Phase A and Phase B + +This avoids duplicating the ~250-line feature engineering pipeline. + +--- + +## 6. Memory Management + +### 6.1 Gather-Inside-Checkpoint Pattern + +The key memory optimization: each chunk's gather and evaluate steps are +wrapped together in a single `torch.utils.checkpoint.checkpoint` call. 
The +checkpoint boundary is drawn so that autograd saves only the compact int64 +index arrays (~8 bytes/pair) and references to the shared source data (O(1)), +rather than the gathered float data (~300 bytes/pair). This is a ~37x +reduction in checkpoint-saved memory per chunk. + +### 6.2 Auto-Chunk Sizing + +`_auto_chunk_size()` estimates peak memory per interaction pair from the +kernel's feature engineering pipeline (counting intermediate floats for +spatial vectors, scalar features, MLP layers, and post-processing) and sizes +chunks to fit within ~50% of free GPU memory. During training, a 5x +multiplier accounts for autograd tensor retention; during inference, this +multiplier is dropped, allowing larger chunks. + +### 6.3 Branch-Level Checkpointing + +`MultiscaleKernel` wraps each `BarnesHutKernel` branch call in +`checkpoint(use_reentrant=False)`. This ensures only ONE branch's autograd +graph exists at a time during backward, preventing autograd memory from +accumulating across all branches. Combined with the gather-inside-checkpoint +pattern, peak autograd memory scales as O(chunk_size \* indices_only) rather +than O(n_branches \* n_pairs \* features). + +The branch-level and chunk-level checkpoints nest correctly: +`use_reentrant=False` composes via `saved_tensors_hooks`. + +### 6.4 Chunk-Size Determinism + +`_auto_chunk_size()` queries free GPU memory (`torch.cuda.mem_get_info`), +which changes between the forward pass and a checkpoint replay during +backward. If the chunk size changes, intermediate tensors have different +shapes, and the outer (branch-level) checkpoint raises `CheckpointError`. + +To prevent this, `MultiscaleKernel.forward()` precomputes each branch's chunk +size **outside** the checkpoint boundary and passes it as a fixed input via +the `near_chunk_size` kwarg. The checkpoint saves this value as an input and +replays with the identical value, regardless of current GPU memory state. + +--- + +## 7. 
Integration with GLOBE + +### 7.1 Tree and Plan Lifecycle + +Within a single `GLOBE.forward()` call: + +1. **Phase 1 (init)**: Build one `ClusterTree` per boundary condition type + from the cell centroids. Compute `DualInteractionPlan`s for communication + covering all (source BC, destination BC) pairs - B^2 plans for B BC types. + For self-interaction pairs (source == destination), the target tree is the + same object as the source tree. All trees and plans are cached for the + duration of the forward pass. + +2. **Phase 2 (communication)**: For each communication hyperlayer, reuse the + cached trees and plans. Only source aggregates are recomputed (the latent + features change between layers). + +3. **Phase 3 (prediction)**: Build a single target tree for prediction points + and compute one interaction plan per source BC type (B plans total). + Source trees are reused from Phase 1. + +Tree construction and plan finding are decorated with +`@torch.compiler.disable` because they involve irregular control flow (Morton +code bit operations, data-dependent loop termination) that `torch.compile` +cannot trace. The kernel evaluation inside `_evaluate_interactions` compiles +normally. + +### 7.2 Shared Aggregates Across Branches + +`MultiscaleKernel` computes source aggregates once and passes them to all +`BarnesHutKernel` branches via the `source_aggregates` parameter. Since +aggregates depend only on geometry and source data (both shared across +branches), this eliminates redundant computation. Only per-node strength +summation (which depends on per-branch strengths) is computed per-branch. + +### 7.3 Dynamic Shapes + +The hierarchical approach naturally requires dynamic tensor shapes (each mesh +produces a different tree, different interaction plan, different pair counts). +Training scripts use `torch.compile(dynamic=True)` and +`compile_mode="max-autotune-no-cudagraphs"` to accommodate this. 
Mesh padding +(previously used for static-shape CUDA graph compatibility) has been removed. + +--- + +## 8. Parameter Tuning + +### 8.1 Theta (opening angle) + +The `theta` parameter controls accuracy vs. speed: + +| theta | Character | Typical use case | +|-------|----------------------|----------------------------------------| +| 0 | Exact | No approximation (equivalent to dense) | +| 0.5 | Conservative | High accuracy, for validation | +| 1.0 | Moderate | Good default for production training | +| 1.5 | Aggressive | Fast approximate evaluation | +| 100+ | Extremely aggressive | Testing only | + +The approximation error per interaction scales with theta, but the total +error is bounded by the kernel's far-field decay. Distant clusters contribute +little regardless of approximation quality, providing a natural error ceiling. + +### 8.2 Leaf Size + +The `leaf_size` parameter (default 1) controls tree granularity: + +- **Smaller leaf_size** (e.g., 1-4): deeper trees, finer-grained near/far + classification, more far-field approximations at higher precision (each + node represents a smaller spatial region, so centroids are more accurate). + Near-field count drops dramatically since the opening criterion passes more + easily for small-diameter nodes. +- **Larger leaf_size** (e.g., 32-64): shallower trees, coarser + classification, fewer traversal iterations, but each near-field leaf-pair + hit expands into up to `leaf_size^2` individual interactions, and far-field + node centroids are coarser averages over larger spatial regions. + +Crucially, **smaller leaf_size does not reduce accuracy** for a fixed theta. +The far-field approximation for a single-point leaf (leaf_size=1) is exact in +the source coordinate (the "centroid" is the point itself), so all +approximation error comes from the target side, which is controlled by theta. +Smaller leaves produce strictly finer-resolution far-field evaluations. 
+ +Benchmarks on DrivAerML (20k boundary faces, H100) show `leaf_size=1` is +3.8x faster than `leaf_size=32` with no accuracy penalty. The default is +`leaf_size=1`. + +--- + +## 9. Complexity Analysis + +| Component | Time complexity | Memory complexity | +|--------------------|---------------------|---------------------| +| Tree construction | O(N log N) | O(N) | +| Aggregate computation | O(N) | O(N) | +| Dual-tree traversal | O(N log N) | O(N log N) | +| Near-field evaluation | O(N log N) | O(chunk_size) | +| Far-field evaluation | O(N) | O(N_far_pairs) | +| Far-field broadcast | O(N log N) | O(N_targets) | +| **Total** | **O(N log N)** | **O(N log N)** | + +The far-field evaluation step is O(N) rather than O(N log N) because the +number of well-separated node pairs grows linearly for typical point +distributions. This is a concrete improvement over single-tree Barnes-Hut, +where each target individually evaluates against O(log N) source nodes. + +Compare with the all-to-all baseline: + +| Component | Time complexity | Memory complexity | +|--------------------|---------------------|---------------------| +| Dense displacement | O(N^2) | O(N^2) | +| Feature engineering| O(N^2) | O(N^2) | +| Network evaluation | O(N^2) | O(N^2) | +| Aggregation | O(N^2) | O(N) | +| **Total** | **O(N^2)** | **O(N^2)** | + +For N = 100k sources and targets, this represents a ~5000x reduction in +interaction count (from 10 billion to ~2 million at theta=1.0). + +--- + +## 10. 
Architecture Summary + +```text +GLOBE.forward() + | + +-- _build_trees_and_plans() [outside torch.compile] + | Build ClusterTree per BC type (B trees) + | Find DualInteractionPlan for all (src, dst) BC pairs (B^2 plans) + | + +-- Phase 2: Communication hyperlayers (repeat n_comm times) + | | + | +-- _evaluate_hyperlayer() + | | + | +-- MultiscaleKernel.forward() + | | + | +-- compute_source_aggregates() [once, shared across branches] + | +-- precompute near_chunk_sizes [outside checkpoint boundary] + | | + | +-- for each branch: + | checkpoint(BarnesHutKernel.forward(), ...) + | | + | +-- _compute_node_strengths() + | | + | +-- Phase A: Near-field + | | for each chunk: + | | checkpoint(_gather_and_evaluate()) + | | weight by strength, scatter_add to output + | | + | +-- Phase B: Far-field + | checkpoint(_gather_and_evaluate()) [node centroids] + | weight by node strength + | broadcast to individual targets via scatter_add + | + +-- _build_prediction_plans() [outside torch.compile] + | Build target tree for prediction points + | Find DualInteractionPlan (pred: different target points) + | + +-- Phase 3: Final evaluation + (same structure as communication, different target points) +``` + +--- + +## 11. Testing Strategy + +The implementation is validated through several complementary test categories: + +- **Convergence to exact**: As theta decreases toward 0, `BarnesHutKernel` + output converges monotonically to the exact `Kernel` output. At + theta = 0.01, the two agree within floating-point tolerance. Tested + across all combinations of 2D/3D, scalar/vector outputs, and + scalar/vector source features. + +- **Source coverage invariant**: For every target, the union of near-field + sources and far-field node subtrees equals the complete source set + `{0, ..., N-1}` with no duplicates and no omissions. This is the + fundamental correctness property of the dual-tree traversal. 
+
+- **Gradient correctness**: Gradients through `BarnesHutKernel` match exact
+  `Kernel` gradients at low theta (the near-exact regime), verifying that the
+  non-differentiable traversal decisions do not corrupt gradient flow through
+  the differentiable kernel evaluation.
+
+- **Equivariance preservation**: Translation, rotation, and source-permutation
+  equivariance are preserved by the hierarchical approximation, verified at
+  both moderate and high theta.
+
+- **Nested key structure**: Tests with deeply nested TensorDict keys matching
+  GLOBE's actual production data format (physical/latent/strength namespaces).
+
+---
+
+## 12. References
+
+- Barnes & Hut (1986). "A hierarchical O(N log N) force-calculation algorithm."
+  *Nature* 324, 446-449.
+- Appel (1985). "An Efficient Program for Many-Body Simulation." *SIAM J. Sci.
+  Stat. Comput.* 6(1), 85-103. Early dual-tree variant of the Barnes-Hut idea.
+- Gray & Moore (2001). "'N-Body' Problems in Statistical Learning."
+  *NIPS 2001*. Formalized dual-tree algorithms with generalized acceptance
+  criteria.
+- Karras (2012). "Maximizing Parallelism in the Construction of BVHs, Octrees,
+  and k-d Trees." *HPG 2012*. The LBVH construction algorithm used here.
+- Burtscher & Pingali (2011). "An Efficient CUDA Implementation of the
+  Tree-Based Barnes Hut n-Body Algorithm." *GPU Computing Gems Emerald Edition*.
+- Lukat & Banerjee (2015). "A GPU accelerated Barnes-Hut tree code for FLASH4."
+  Describes AABB-distance opening criterion.
+- Madan et al. (2025). "Stochastic Barnes-Hut Approximation of Kernel Matrices."
+  *SIGGRAPH 2025*. Uses a `beta = 1/theta` convention (inverted relative to
+  the original Barnes & Hut convention used in this codebase). 
diff --git a/physicsnemo/experimental/models/globe/model.py b/physicsnemo/experimental/models/globe/model.py index 746ddb16f5..9b84e74729 100644 --- a/physicsnemo/experimental/models/globe/model.py +++ b/physicsnemo/experimental/models/globe/model.py @@ -17,21 +17,27 @@ import operator from dataclasses import dataclass from functools import reduce -from typing import Literal, Sequence +from typing import Sequence import torch import torch.nn as nn from jaxtyping import Float from tensordict import TensorDict +from torch.profiler import record_function from physicsnemo.core.meta import ModelMetaData from physicsnemo.core.module import Module +from physicsnemo.experimental.models.globe.cluster_tree import ( + ClusterTree, + DualInteractionPlan, +) from physicsnemo.experimental.models.globe.field_kernel import MultiscaleKernel from physicsnemo.experimental.models.globe.utilities.rank_spec import ( RankSpecDict, flatten_rank_spec, ) from physicsnemo.mesh import Mesh +from physicsnemo.utils.logging import PythonLogger # allow_in_graph wraps these TensorDict methods as opaque graph nodes so that # torch.compile doesn't trace into them (their internals cause graph breaks). @@ -43,15 +49,13 @@ # these wrappers. _flatten_keys = torch.compiler.allow_in_graph(TensorDict.flatten_keys) _unflatten_keys = torch.compiler.allow_in_graph(TensorDict.unflatten_keys) -from physicsnemo.utils.logging import PythonLogger logger = PythonLogger("globe.model") - @dataclass class MetaData(ModelMetaData): - jit: bool = True - cuda_graphs: bool = True + jit: bool = False # Refers to torch.compile compatibility - this is compatible. + cuda_graphs: bool = False # Computational graph changes depending on inputs due to tree traversals amp: bool = True torch_fx: bool = False onnx: bool = False @@ -122,12 +126,28 @@ class GLOBE(Module): n_spherical_harmonics : int, optional, default=4 Number of Legendre polynomial terms used for angle-dependent features in kernel functions. 
+ theta : float, optional, default=1.0 + Barnes-Hut opening angle controlling the near/far-field split in the + dual-tree traversal. The criterion is + :math:`(D_T + D_S) / r < \theta`, where :math:`D_T` and :math:`D_S` + are AABB diagonals and :math:`r` is the minimum inter-AABB distance. + Larger values approximate more aggressively; ``0`` forces all + interactions to be exact (no far-field approximation). + leaf_size : int, optional, default=1 + Maximum number of source points per leaf node in the cluster tree. + Larger values produce shallower trees (fewer traversal iterations) at + the cost of more exact near-field interactions per leaf hit. + expand_far_targets : bool, optional, default=False + If ``True``, far-field target nodes are expanded to individual points, + converting ``(far, far)`` pairs into ``(near, far)`` pairs. This + eliminates the target-side approximation at the cost of more kernel + evaluations. Forward ------- prediction_points : Float[torch.Tensor, "n_points n_dims"] Target points for field evaluation of shape :math:`(N_{points}, D)`. - boundary_meshes : dict[str, Mesh] + boundary_meshes : dict[str, Mesh["n-1", "n"]] Dictionary mapping boundary condition type names to :class:`~physicsnemo.mesh.Mesh` objects. Keys must be a subset of the model's boundary condition names (from ``boundary_source_data_ranks``). @@ -136,12 +156,10 @@ class GLOBE(Module): global_data : TensorDict or None, optional, default=None Nondimensional conditioning features. Leaf keys and ranks must match ``global_data_ranks``. Passed through to the output Mesh. - chunk_size : None | int | Literal["auto"], optional, default=None - Controls memory usage during kernel evaluation. Outputs ------- - Mesh + Mesh[0, "n"] A point-cloud :class:`~physicsnemo.mesh.Mesh` (0-dimensional manifold) whose ``.points`` attribute equals the input ``prediction_points``. 
The predicted fields are in ``.point_data``, keyed by the names from @@ -161,6 +179,11 @@ class GLOBE(Module): - Cell areas are automatically normalized by ``reference_area`` to preserve discretization-invariance. - The cell normal vector is automatically added to source data for each mesh. + - The ``Mesh["n-1", "n"]`` type annotations assume the PDE domain fills the + full ambient space (domain manifold dim = spatial dim), so boundary meshes + are codimension-1 in the ambient space. For a PDE on a ``d``-dimensional + manifold embedded in ``n``-dimensional space (``d < n``), the boundary + type would be ``Mesh[d-1, n]`` instead. Examples -------- @@ -197,6 +220,9 @@ def __init__( smoothing_radius: float = 1e-8, hidden_layer_sizes: Sequence[int] | None = None, n_spherical_harmonics: int = 4, + theta: float = 1.0, + leaf_size: int = 1, + expand_far_targets: bool = False, ): if hidden_layer_sizes is None: hidden_layer_sizes = [64, 64, 64] @@ -234,6 +260,9 @@ def __init__( self.smoothing_radius = smoothing_radius self.hidden_layer_sizes = hidden_layer_sizes self.n_spherical_harmonics = n_spherical_harmonics + self.theta = theta + self.leaf_size = leaf_size + self.expand_far_targets = expand_far_targets ### Build the intermediate output-field rank spec for communication # hyperlayers. Only the final hyperlayer emits output_field_ranks. @@ -267,6 +296,7 @@ def __init__( smoothing_radius=smoothing_radius, hidden_layer_sizes=hidden_layer_sizes, n_spherical_harmonics=n_spherical_harmonics, + leaf_size=leaf_size, ) for bc_type in boundary_condition_names } @@ -312,50 +342,157 @@ def _build_source_data_ranks( } return result + @torch.compiler.disable + def _build_trees_and_plans( + self, + boundary_meshes: dict[str, Mesh["n-1", "n"]], # ty: ignore[unresolved-reference] + ) -> tuple[ + dict[str, ClusterTree], + dict[str, torch.Tensor], + dict[str, dict[str, DualInteractionPlan]], + ]: + """Build per-BC-type cluster trees and cross-BC dual interaction plans. 
+ + Builds one :class:`ClusterTree` per BC type (O(B) trees), then computes + a :class:`DualInteractionPlan` for every (source BC, destination BC) + pair (B^2 plans total). For self-interaction (source == destination), + the target tree is the same object as the source tree. Plans are + reused across all communication layers since the geometry is fixed. + + Returns + ------- + cluster_trees : dict[str, ClusterTree] + Per-BC-type cluster trees built from cell centroids. + bc_areas : dict[str, torch.Tensor] + Per-BC-type normalized cell area tensors. + comm_plans : dict[str, dict[str, DualInteractionPlan]] + Communication plans indexed as ``comm_plans[dst_bc][src_bc]``. + """ + from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree + + cluster_trees: dict[str, ClusterTree] = {} + bc_areas: dict[str, torch.Tensor] = {} + for bc_type, mesh in boundary_meshes.items(): + areas = mesh.cell_areas / self.reference_area + bc_areas[bc_type] = areas + cluster_trees[bc_type] = ClusterTree.from_points( + mesh.cell_centroids, leaf_size=self.leaf_size, areas=areas + ) + + ### Build interaction plans for all (source, destination) BC pairs. 
+ comm_plans: dict[str, dict[str, DualInteractionPlan]] = {} + for dst_bc in boundary_meshes: + comm_plans[dst_bc] = { + src_bc: cluster_trees[src_bc].find_dual_interaction_pairs( + target_tree=cluster_trees[dst_bc], theta=self.theta, + expand_far_targets=self.expand_far_targets, + ) + for src_bc in boundary_meshes + } + + for dst_bc, plans_for_dst in comm_plans.items(): + n_dst = boundary_meshes[dst_bc].n_cells + for src_bc, plan in plans_for_dst.items(): + n_src = boundary_meshes[src_bc].n_cells + logger.logger.debug( + "comm plan [%s -> %s]: %d near + %d nf + %d fn + %d far_node " + "(%.2f%% near-field, %d src x %d dst faces, " + "theta=%.2f, leaf_size=%d)", + src_bc, dst_bc, + plan.n_near, plan.n_nf, plan.n_fn, plan.n_far_nodes, + 100.0 * plan.n_near / max(n_src * n_dst, 1), + n_src, n_dst, self.theta, self.leaf_size, + ) + + return cluster_trees, bc_areas, comm_plans + + @torch.compiler.disable + def _build_prediction_plans( + self, + cluster_trees: dict[str, ClusterTree], + prediction_points: torch.Tensor, + ) -> tuple[ClusterTree, dict[str, DualInteractionPlan]]: + """Build target tree and dual plans for prediction-point evaluation. + + Builds a single target tree from ``prediction_points`` and computes + one :class:`DualInteractionPlan` per source BC type against it. + + Returns + ------- + pred_target_tree : ClusterTree + Target tree built from prediction points. + pred_plans : dict[str, DualInteractionPlan] + Plans indexed by source BC type, each computed from that source + BC's tree to ``pred_target_tree``. 
+ """ + from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree + + pred_target_tree = ClusterTree.from_points( + prediction_points, leaf_size=self.leaf_size, + ) + pred_plans = { + bc_type: tree.find_dual_interaction_pairs( + target_tree=pred_target_tree, theta=self.theta, + expand_far_targets=self.expand_far_targets, + ) + for bc_type, tree in cluster_trees.items() + } + + n_pred = prediction_points.shape[0] + for bc_type, plan in pred_plans.items(): + n_src = cluster_trees[bc_type].n_sources + logger.logger.debug( + "pred plan [%s]: %d near + %d nf + %d fn + %d far_node " + "(%d sources x %d targets, theta=%.2f)", + bc_type, plan.n_near, plan.n_nf, plan.n_fn, plan.n_far_nodes, + n_src, n_pred, self.theta, + ) + + return pred_target_tree, pred_plans + def _evaluate_hyperlayer( self, layer_idx: int, target_points: Float[torch.Tensor, "n_targets n_dims"], - source_meshes: dict[str, Mesh], + source_meshes: dict[str, Mesh["n-1", "n"]], # ty: ignore[unresolved-reference] reference_lengths: dict[str, Float[torch.Tensor, ""]], global_data: TensorDict[str, Float[torch.Tensor, "..."]] | None, - chunk_size: None | int | Literal["auto"], + cluster_trees: dict[str, ClusterTree], + target_tree: ClusterTree | None, + dual_plans: dict[str, DualInteractionPlan], + source_areas: dict[str, torch.Tensor], ) -> TensorDict[str, Float[torch.Tensor, "n_targets ..."]]: r"""Evaluate one hyperlayer by summing kernel contributions from all BC types. - For each boundary condition type, extracts source data from the mesh's - enriched ``cell_data``, evaluates the corresponding - :class:`MultiscaleKernel`, and sums the results. 
- - Each mesh's ``cell_data`` carries a namespaced structure: - - - ``"physical"``: original boundary condition features - - ``"strengths"``: per-reference-length scalar multipliers that modulate - each source face's kernel contribution (learned during communication - and area-normalized before use) - - ``"latent"``: (after first layer) learned scalar and vector features - - Strengths are extracted and area-normalized separately. All remaining - features (plus cell normals) are combined into a unified - ``source_data`` TensorDict and passed to the kernel, which splits - them by tensor rank internally. + Each call evaluates all source BC types against a single set of target + points. The ``target_tree`` and per-source-BC ``dual_plans`` must + correspond to those target points. Parameters ---------- layer_idx : int - Index into ``self.kernel_layers`` selecting which hyperlayer to evaluate. + Index into ``self.kernel_layers``. target_points : Float[torch.Tensor, "n_targets n_dims"] Target points of shape :math:`(N_{targets}, D)`. - source_meshes : dict[str, Mesh] - Mapping of BC type names to enriched :class:`~physicsnemo.mesh.Mesh` - objects whose ``cell_data`` carries both physical features and latent - state. + source_meshes : dict[str, Mesh["n-1", "n"]] + Enriched boundary meshes with cell_data containing physical + features, strengths, and (after layer 0) latent state. reference_lengths : dict[str, Float[torch.Tensor, ""]] - Mapping of reference length names to scalar tensors. + Reference length names to scalar tensors. global_data : TensorDict or None - Problem-level features (mixed scalar/vector ranks). - chunk_size : None or int or {"auto"} - Controls memory usage during kernel evaluation. + Problem-level features. + cluster_trees : dict[str, ClusterTree] + Per-BC-type precomputed source trees. + target_tree : ClusterTree or None + Precomputed target tree shared by all source BCs in this call. 
+ For communication self-interaction, this is the destination BC's + own cluster tree. If ``None``, each kernel branch builds a tree + from ``target_points`` on the fly. + dual_plans : dict[str, DualInteractionPlan] + Per-source-BC-type precomputed dual interaction plans, each + computed from that source BC's tree to ``target_tree``. + source_areas : dict[str, torch.Tensor] + Per-BC-type source area tensors. Returns ------- @@ -371,9 +508,6 @@ def _evaluate_hyperlayer( ) ) - ### Combine non-strength features with cell normals into source_data. - # flatten_keys produces a flat namespace so the kernel's - # split_by_leaf_rank can separate scalars from vectors by rank. source_data = _flatten_keys(mesh.cell_data.exclude("strengths")) source_data["normals"] = mesh.cell_normals @@ -385,7 +519,11 @@ def _evaluate_hyperlayer( target_points=target_points, reference_lengths=reference_lengths, global_data=global_data, - chunk_size=chunk_size, + theta=self.theta, + cluster_tree=cluster_trees[bc_type], + target_tree=target_tree, + dual_plan=dual_plans[bc_type], + source_areas=source_areas[bc_type], ) result_pieces.append(_unflatten_keys(kernel_result)) @@ -394,40 +532,28 @@ def _evaluate_hyperlayer( def _evaluate_communication_hyperlayer( self, layer_idx: int, - boundary_meshes: dict[str, Mesh], + boundary_meshes: dict[str, Mesh["n-1", "n"]], # ty: ignore[unresolved-reference] reference_lengths: dict[str, Float[torch.Tensor, ""]], global_data: TensorDict[str, Float[torch.Tensor, "..."]] | None, - chunk_size: None | int | Literal["auto"], - ) -> dict[str, Mesh]: + cluster_trees: dict[str, ClusterTree], + comm_plans: dict[str, dict[str, DualInteractionPlan]], + source_areas: dict[str, torch.Tensor], + ) -> dict[str, Mesh["n-1", "n"]]: # ty: ignore[unresolved-reference] r"""Run one boundary-to-boundary communication step. 
- For each BC type, evaluates :meth:`_evaluate_hyperlayer` at the mesh's - cell centroids and wraps the result into an enriched Mesh that carries - both the original physical ``cell_data`` (under ``"physical"``) and the - new latent state (``"strengths"``, ``"latent"``). - - Geometry tensors and cached properties (centroids, areas, normals) are - shared by reference across layers - no copies are made. - - Parameters - ---------- - layer_idx : int - Index into ``self.kernel_layers`` for this communication layer. - boundary_meshes : dict[str, Mesh] - Current enriched boundary meshes (from the previous layer or init). - reference_lengths : dict[str, Float[torch.Tensor, ""]] - Mapping of reference length names to scalar tensors. - global_data : TensorDict[str, Float[torch.Tensor, "..."]] or None - Problem-level features (mixed scalar/vector ranks). - chunk_size : None or int or {"auto"} - Controls memory usage during kernel evaluation. + For each destination BC type, evaluates :meth:`_evaluate_hyperlayer` + at that BC's cell centroids, summing contributions from all source + BC types. The target tree for each destination is that BC's own + cluster tree; for self-interaction (source == destination), this is + the same object as the source tree. Returns ------- - dict[str, Mesh] - New enriched boundary meshes for the next layer. + dict[str, Mesh["n-1", "n"]] + Updated boundary meshes with evaluation results merged into + each mesh's ``cell_data``. 
""" - new_meshes: dict[str, Mesh] = {} + new_meshes: dict[str, Mesh["n-1", "n"]] = {} # ty: ignore[unresolved-reference] for bc_type, mesh in boundary_meshes.items(): result_td = self._evaluate_hyperlayer( layer_idx=layer_idx, @@ -435,7 +561,10 @@ def _evaluate_communication_hyperlayer( source_meshes=boundary_meshes, reference_lengths=reference_lengths, global_data=global_data, - chunk_size=chunk_size, + cluster_trees=cluster_trees, + target_tree=cluster_trees[bc_type], + dual_plans=comm_plans[bc_type], + source_areas=source_areas, ) new_cell_data = TensorDict( {"physical": mesh.cell_data["physical"]}, @@ -454,11 +583,10 @@ def _evaluate_communication_hyperlayer( def forward( self, prediction_points: Float[torch.Tensor, "n_points n_dims"], - boundary_meshes: dict[str, Mesh], + boundary_meshes: dict[str, Mesh["n-1", "n"]], # ty: ignore[unresolved-reference] reference_lengths: dict[str, torch.Tensor], global_data: TensorDict[str, Float[torch.Tensor, "..."]] | None = None, - chunk_size: None | int | Literal["auto"] = None, - ) -> Mesh: + ) -> Mesh[0, "n"]: # ty: ignore[unresolved-reference] r"""Evaluate GLOBE model to predict fields at target points. Runs the full GLOBE forward pass in three phases: @@ -475,7 +603,7 @@ def forward( ---------- prediction_points : Float[torch.Tensor, "n_points n_dims"] Target points of shape :math:`(N_{points}, D)`. - boundary_meshes : dict[str, Mesh] + boundary_meshes : dict[str, Mesh["n-1", "n"]] Dictionary mapping BC type names to pre-merged :class:`~physicsnemo.mesh.Mesh` objects. reference_lengths : dict[str, torch.Tensor] @@ -483,20 +611,11 @@ def forward( global_data : TensorDict or None, optional, default=None Nondimensional conditioning features. Leaf keys and ranks must match ``global_data_ranks``. Passed through to the output Mesh. - chunk_size : None | int | Literal["auto"], optional, default=None - Controls memory usage during kernel evaluation. 
Returns ------- - Mesh - A point-cloud Mesh (0-dimensional manifold) with: - - - ``points``: the input ``prediction_points`` - - ``point_data``: calibrated output fields (keys from - ``output_fields``) - - ``global_data``: the input ``global_data``, passed through - - ``cells``: empty (shape ``(0, 1)``) - - ``cell_data``: empty + Mesh[0, "n"] + A point-cloud Mesh (0-dimensional manifold) with predicted fields. """ device = prediction_points.device @@ -504,7 +623,6 @@ def forward( global_data = TensorDict({}, device=device) ### Input validation - # Skip validation when running under torch.compile for performance if not torch.compiler.is_compiling(): if prediction_points.ndim != 2: raise ValueError( @@ -540,62 +658,92 @@ def forward( ) ### Phase 1: Enrich boundary meshes with initial (all-ones) strengths. - # Wraps original cell_data under "physical" and adds "strengths". - # Geometry tensors are shared by reference - no copies. - boundary_meshes = { - bc_type: Mesh( - points=mesh.points, - cells=mesh.cells, - cell_data=TensorDict( - { - "physical": mesh.cell_data, - "strengths": TensorDict( - { - name: torch.ones(mesh.n_cells, device=device) - for name in self.reference_length_names - }, - batch_size=torch.Size([mesh.n_cells]), - device=device, - ), - }, - batch_size=torch.Size([mesh.n_cells]), - device=device, - ), - _cache=mesh._cache, + with record_function("globe::enrich_meshes"): + boundary_meshes = { + bc_type: Mesh( + points=mesh.points, + cells=mesh.cells, + cell_data=TensorDict( + { + "physical": mesh.cell_data, + "strengths": TensorDict( + { + name: torch.ones(mesh.n_cells, device=device) + for name in self.reference_length_names + }, + batch_size=torch.Size([mesh.n_cells]), + device=device, + ), + }, + batch_size=torch.Size([mesh.n_cells]), + device=device, + ), + _cache=mesh._cache, + ) + for bc_type, mesh in boundary_meshes.items() + } + + ### Build per-BC-type trees and areas (reused across all layers). 
+ ### Tree construction and traversal involve irregular control flow + ### (morton codes, variable-depth loops) that cannot be traced by + ### torch.compile, so we skip compilation for this block. + with record_function("globe::build_trees_and_plans"): + cluster_trees, bc_areas, comm_plans = self._build_trees_and_plans( + boundary_meshes ) - for bc_type, mesh in boundary_meshes.items() - } ### Phase 2: Communication hyperlayers (boundary-to-boundary). + # Trees and comm_plans are reused across all layers because cell + # centroids (the source/target points) are fixed - only the + # cell_data (latent features, strengths) changes between layers. for i in range(self.n_communication_hyperlayers): - boundary_meshes = self._evaluate_communication_hyperlayer( - layer_idx=i, - boundary_meshes=boundary_meshes, + with record_function(f"globe::communication_layer/{i}"): + boundary_meshes = self._evaluate_communication_hyperlayer( + layer_idx=i, + boundary_meshes=boundary_meshes, + reference_lengths=reference_lengths, + global_data=global_data, + cluster_trees=cluster_trees, + comm_plans=comm_plans, + source_areas=bc_areas, + ) + + ### Free comm plans - no longer needed after communication layers. + # At 800k faces, near-pair indices can be ~3 GB of int64. + del comm_plans + + ### Phase 3: Final evaluation at prediction points. 
+ with record_function("globe::build_prediction_plans"): + pred_target_tree, pred_plans = self._build_prediction_plans( + cluster_trees, prediction_points + ) + + with record_function("globe::final_evaluation"): + result: TensorDict[str, Float[torch.Tensor, "n_points ..."]] = self._evaluate_hyperlayer( + layer_idx=self.n_communication_hyperlayers, + target_points=prediction_points, + source_meshes=boundary_meshes, reference_lengths=reference_lengths, global_data=global_data, - chunk_size=chunk_size, + cluster_trees=cluster_trees, + target_tree=pred_target_tree, + dual_plans=pred_plans, + source_areas=bc_areas, ) - ### Phase 3: Final evaluation at prediction points. - result: TensorDict[str, Float[torch.Tensor, "n_points ..."]] = self._evaluate_hyperlayer( - layer_idx=self.n_communication_hyperlayers, - target_points=prediction_points, - source_meshes=boundary_meshes, - reference_lengths=reference_lengths, - global_data=global_data, - chunk_size=chunk_size, - ) + del pred_plans, pred_target_tree ### Wrap as point-cloud Mesh and apply per-field calibration. 
- output_mesh = Mesh( - points=prediction_points, - point_data=result, - global_data=global_data, - ) - for idx, name in enumerate(self._output_field_order): - key = tuple(name.split(".")) - t = output_mesh.point_data[key] - output_mesh.point_data[key] = self.final_field_transforms[idx]( - t.view(-1, 1) - ).view(t.shape) + with record_function("globe::calibration"): + output_mesh = Mesh( + points=prediction_points, + point_data=result, + global_data=global_data, + ) + for idx, name in enumerate(self._output_field_order): + key = tuple(name.split(".")) + t = output_mesh.point_data[key] + output_mesh.point_data[key] = self.final_field_transforms[idx]( + t.view(-1, 1) + ).view(t.shape) return output_mesh diff --git a/physicsnemo/mesh/spatial/_ragged.py b/physicsnemo/mesh/spatial/_ragged.py new file mode 100644 index 0000000000..d09befb116 --- /dev/null +++ b/physicsnemo/mesh/spatial/_ragged.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Segmented (ragged) tensor utilities for spatial data structures.""" + +import torch + + +def _ragged_arange( + starts: torch.Tensor, + counts: torch.Tensor, + total: int | torch.SymInt | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Expand segment descriptors ``(start, count)`` into flat index arrays. 
+ + Given *N* segments where segment *i* spans positions + ``[starts[i], starts[i] + counts[i])``, produces two flat tensors of + length ``sum(counts)``: + + - ``positions[k]``: the absolute index for element *k* + - ``seg_ids[k]``: the segment (``0..N-1``) that element *k* belongs to + + Conceptually, this concatenates ``arange(s, s+c)`` for each ``(s, c)`` + pair, along with the corresponding segment labels. + + The implementation uses ``searchsorted`` rather than + ``repeat_interleave``, so it is fully traceable by ``torch.compile``. + + Parameters + ---------- + starts : torch.Tensor + Start offset per segment, shape ``(N,)``, int64. + counts : torch.Tensor + Element count per segment, shape ``(N,)``, int64. + Entries may be zero (those segments produce no output elements). + total : int | torch.SymInt | None, optional + Pre-computed ``counts.sum()``. When available from a tensor shape + (e.g. ``some_tensor.shape[0]``), passing it avoids an internal + ``.item()`` call and the associated ``torch.compile`` graph break. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(positions, seg_ids)`` each with shape ``(sum(counts),)``. + """ + device = starts.device + if total is None: + total = counts.sum() + + # Exclusive prefix sum: flat-space start of each segment. + seg_start_flat = counts.cumsum(0) - counts # (N,) + + # For each flat position, find the owning segment via binary search. + # searchsorted(right=True) returns the index *after* the last matching + # entry, so subtracting 1 gives the last segment whose flat start is + # <= the query position. This correctly skips zero-count segments + # (which share a seg_start_flat value with the next non-zero segment). + flat_idx = torch.arange(total, dtype=torch.long, device=device) + seg_ids = torch.searchsorted(seg_start_flat, flat_idx, right=True) - 1 + + # Within-segment offsets: [0, 1, ..., c0-1, 0, 1, ..., c1-1, ...] 
+ intra_offset = flat_idx - seg_start_flat[seg_ids] + positions = starts[seg_ids] + intra_offset + + return positions, seg_ids diff --git a/physicsnemo/mesh/spatial/bvh.py b/physicsnemo/mesh/spatial/bvh.py index 102e933d8b..4d66371f00 100644 --- a/physicsnemo/mesh/spatial/bvh.py +++ b/physicsnemo/mesh/spatial/bvh.py @@ -31,6 +31,7 @@ from tensordict import tensorclass from physicsnemo.mesh.neighbors._adjacency import Adjacency, build_adjacency_from_pairs +from physicsnemo.mesh.spatial._ragged import _ragged_arange if TYPE_CHECKING: from physicsnemo.mesh.mesh import Mesh @@ -133,25 +134,17 @@ def _expand_leaf_hits( """ starts = leaf_start[leaf_node_indices] # (n_hits,) counts = leaf_count[leaf_node_indices] # (n_hits,) - total = int(counts.sum()) device = leaf_query_indices.device - if total == 0: + if int(counts.sum()) == 0: return ( torch.empty(0, dtype=torch.long, device=device), torch.empty(0, dtype=torch.long, device=device), ) - ### Expand query indices: repeat each by its leaf's cell count expanded_queries = torch.repeat_interleave(leaf_query_indices, counts) - ### Compute position-within-leaf offsets: [0,1,...,c0-1, 0,1,...,c1-1, ...] 
- cum = counts.cumsum(0) - offsets_within = torch.arange(total, dtype=torch.long, device=device) - offsets_within = offsets_within - torch.repeat_interleave(cum - counts, counts) - - ### Map to original cell indices through the sorted permutation - sorted_positions = torch.repeat_interleave(starts, counts) + offsets_within + sorted_positions, _ = _ragged_arange(starts, counts) expanded_cells = sorted_cell_order[sorted_positions] return expanded_queries, expanded_cells @@ -193,27 +186,15 @@ def _compute_leaf_aabbs( D = sorted_aabb_min.shape[1] dtype = sorted_aabb_min.dtype n_leaf_segs = len(leaf_seg_starts) - total_cells = leaf_seg_sizes.sum().item() - if total_cells == 0 or n_leaf_segs == 0: + if int(leaf_seg_sizes.sum()) == 0 or n_leaf_segs == 0: return ( torch.empty((0, D), dtype=dtype, device=device), torch.empty((0, D), dtype=dtype, device=device), ) - ### Build segment-ID for each cell across all leaf segments - seg_ids = torch.repeat_interleave( - torch.arange(n_leaf_segs, dtype=torch.long, device=device), - leaf_seg_sizes, - ) # (total_cells,) - - ### Build positions into the sorted cell array - cum = leaf_seg_sizes.cumsum(0) - offsets = torch.arange(total_cells, dtype=torch.long, device=device) - offsets = offsets - torch.repeat_interleave(cum - leaf_seg_sizes, leaf_seg_sizes) - cell_pos = torch.repeat_interleave(leaf_seg_starts, leaf_seg_sizes) + offsets + cell_pos, seg_ids = _ragged_arange(leaf_seg_starts, leaf_seg_sizes) - ### Gather cell AABBs cell_mins = sorted_aabb_min[cell_pos] # (total_cells, D) cell_maxs = sorted_aabb_max[cell_pos] # (total_cells, D) diff --git a/test/mesh/spatial/test_ragged.py b/test/mesh/spatial/test_ragged.py new file mode 100644 index 0000000000..50251dc9a4 --- /dev/null +++ b/test/mesh/spatial/test_ragged.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _ragged_arange segmented tensor utility.""" + +import pytest +import torch +from torch._dynamo.utils import counters + +from physicsnemo.mesh.spatial._ragged import _ragged_arange + + +@pytest.mark.parametrize( + "starts, counts", + [ + pytest.param([0, 5, 10], [3, 2, 4], id="basic"), + pytest.param([7], [5], id="single_segment"), + pytest.param([0, 1, 2, 3], [1, 1, 1, 1], id="all_ones"), + pytest.param([0, 10], [10, 3], id="unequal"), + pytest.param([100, 200, 300], [1, 1, 1], id="large_starts"), + pytest.param([10, 20, 30], [2, 0, 3], id="zero_middle"), + pytest.param([10, 20, 30], [0, 0, 3], id="zero_leading"), + pytest.param([10, 20, 30], [3, 0, 0], id="zero_trailing"), + pytest.param([10, 20, 30, 40, 50], [2, 0, 0, 3, 0], id="zero_interleaved"), + pytest.param([10, 20, 30], [0, 0, 0], id="zero_all"), + pytest.param([10], [0], id="zero_single"), + pytest.param([10, 20, 30], [1, 0, 1], id="zero_between_units"), + pytest.param([], [], id="empty"), + ], +) +def test_ragged_arange_correctness(starts: list[int], counts: list[int]): + """Verify positions and seg_ids match the naive per-segment arange.""" + starts_t = torch.tensor(starts) + counts_t = torch.tensor(counts) + + positions, seg_ids = _ragged_arange(starts_t, counts_t) + + # Build expected output the obvious way + pos_parts = [torch.arange(s, s + c) for s, c in zip(starts, counts)] + seg_parts = 
[torch.full((c,), i, dtype=torch.long) for i, c in enumerate(counts)] + expected_pos = ( + torch.cat(pos_parts) if pos_parts else torch.empty(0, dtype=torch.long) + ) + expected_seg = ( + torch.cat(seg_parts) if seg_parts else torch.empty(0, dtype=torch.long) + ) + + assert torch.equal(positions, expected_pos) + assert torch.equal(seg_ids, expected_seg) + + +def test_ragged_arange_explicit_total(): + """When total is passed, it should be used instead of counts.sum().""" + starts = torch.tensor([0, 5, 10]) + counts = torch.tensor([3, 2, 4]) + + pos1, seg1 = _ragged_arange(starts, counts) + pos2, seg2 = _ragged_arange(starts, counts, total=9) + + assert torch.equal(pos1, pos2) + assert torch.equal(seg1, seg2) + + +@pytest.mark.parametrize( + "starts, counts", + [ + pytest.param([0, 5, 10], [3, 2, 4], id="no_zeros"), + pytest.param([10, 20, 30], [2, 0, 3], id="with_zeros"), + ], +) +def test_ragged_arange_no_graph_break_with_explicit_total( + starts: list[int], + counts: list[int], +): + """searchsorted implementation + explicit total should produce zero graph breaks.""" + + def fn(starts_t, counts_t, total_holder): + pos, seg = _ragged_arange(starts_t, counts_t, total=total_holder.shape[0]) + return pos.sum() + seg.sum() + + starts_t = torch.tensor(starts) + counts_t = torch.tensor(counts) + total_holder = torch.empty(int(counts_t.sum())) + + counters.clear() + compiled = torch.compile(fn, dynamic=True, backend="eager") + compiled(starts_t, counts_t, total_holder) + + n_breaks = ( + sum(counters["graph_break"].values()) if counters.get("graph_break") else 0 + ) + assert n_breaks == 0, f"Expected 0 graph breaks, got {n_breaks}" diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py new file mode 100644 index 0000000000..7b8d3a7d9c --- /dev/null +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -0,0 +1,1549 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Barnes-Hut accelerated kernel evaluation. + +Covers: ClusterTree construction and aggregation, BarnesHutKernel convergence +to exact results, gradient correctness, equivariance preservation, and +MultiscaleKernel integration. +""" + +from typing import Any, Literal + +import pytest +import torch +import torch.nn.functional as F +from tensordict import TensorDict + +from physicsnemo.experimental.models.globe.cluster_tree import ( + ClusterTree, + DualInteractionPlan, +) +from physicsnemo.experimental.models.globe.field_kernel import ( + BarnesHutKernel, + Kernel, + MultiscaleKernel, +) +from physicsnemo.mesh.spatial._ragged import _ragged_arange + +DEFAULT_SEED = 42 +DEFAULT_LEAF_SIZE = 4 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bh_kernel_and_data( + n_spatial_dims: int = 2, + n_source_scalars: int = 0, + n_source_vectors: int = 1, + output_fields: dict[str, Literal["scalar", "vector"]] | None = None, + n_global_scalars: int = 0, + n_global_vectors: int = 0, + hidden_layer_sizes: list[int] | None = None, + n_source_points: int = 30, + n_target_points: int = 20, + leaf_size: int = DEFAULT_LEAF_SIZE, + device: str = "cpu", + seed: int = DEFAULT_SEED, +) -> tuple[BarnesHutKernel, 
Kernel, dict[str, Any]]: + """Create matched BH and exact kernels with shared weights and test data.""" + if output_fields is None: + output_fields = {"pressure": "scalar", "velocity": "vector"} + if hidden_layer_sizes is None: + hidden_layer_sizes = [32, 32] + + device_obj = torch.device(device) + torch.manual_seed(seed) + + output_field_ranks = { + k: (0 if v == "scalar" else 1) for k, v in output_fields.items() + } + source_data_ranks = { + **{f"source_scalar_{i}": 0 for i in range(n_source_scalars)}, + **{f"source_vector_{i}": 1 for i in range(n_source_vectors)}, + } + global_data_ranks = { + **{f"global_scalar_{i}": 0 for i in range(n_global_scalars)}, + **{f"global_vector_{i}": 1 for i in range(n_global_vectors)}, + } + + common_kwargs = dict( + n_spatial_dims=n_spatial_dims, + output_field_ranks=output_field_ranks, + source_data_ranks=source_data_ranks, + global_data_ranks=global_data_ranks, + hidden_layer_sizes=hidden_layer_sizes, + ) + + bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=leaf_size).to(device_obj) + exact_kernel = Kernel(**common_kwargs).to(device_obj) + + # Share weights so outputs are comparable + exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False) + bh_kernel.eval() + exact_kernel.eval() + + torch.manual_seed(seed + 1) + + source_data_dict: dict[str, torch.Tensor] = {} + for i in range(n_source_scalars): + source_data_dict[f"source_scalar_{i}"] = torch.randn( + n_source_points, device=device_obj + ) + for i in range(n_source_vectors): + source_data_dict[f"source_vector_{i}"] = F.normalize( + torch.randn(n_source_points, n_spatial_dims, device=device_obj), dim=-1 + ) + + global_data_dict: dict[str, torch.Tensor] = {} + for i in range(n_global_scalars): + global_data_dict[f"global_scalar_{i}"] = torch.randn( + 1, device=device_obj + ).squeeze() + for i in range(n_global_vectors): + global_data_dict[f"global_vector_{i}"] = F.normalize( + torch.randn(n_spatial_dims, device=device_obj), dim=0 + ) + + input_data = { + 
"source_points": torch.randn( + n_source_points, n_spatial_dims, device=device_obj + ), + "target_points": torch.randn(n_target_points, n_spatial_dims, device=device_obj) + * 5, + "source_strengths": torch.randn(n_source_points, device=device_obj).abs() + 0.1, + "reference_length": torch.ones((), device=device_obj), + "source_data": TensorDict( + source_data_dict, batch_size=[n_source_points], device=device_obj + ), + "global_data": TensorDict(global_data_dict, batch_size=[], device=device_obj), + } + + return bh_kernel, exact_kernel, input_data + + +# --------------------------------------------------------------------------- +# ClusterTree tests +# --------------------------------------------------------------------------- + + +class TestClusterTree: + """Tests for ClusterTree construction and traversal.""" + + def test_construction_basic(self): + """Tree construction produces valid node structure.""" + torch.manual_seed(DEFAULT_SEED) + points = torch.randn(50, 3) + tree = ClusterTree.from_points(points, leaf_size=4) + + assert tree.n_nodes > 0 + assert tree.n_sources == 50 + assert tree.n_spatial_dims == 3 + assert tree.sorted_source_order.shape == (50,) + # Sorted order is a permutation of [0, N) + assert set(tree.sorted_source_order.tolist()) == set(range(50)) + + def test_construction_empty(self): + """Empty point set produces empty tree.""" + tree = ClusterTree.from_points(torch.empty(0, 2), leaf_size=4) + assert tree.n_nodes == 0 + assert tree.n_sources == 0 + + def test_construction_single_point(self): + """Single point produces a single-leaf tree.""" + tree = ClusterTree.from_points(torch.randn(1, 2), leaf_size=4) + assert tree.n_nodes == 1 + assert tree.leaf_count[0].item() == 1 + + def test_aabb_containment(self): + """Every source point is contained in the root's AABB.""" + torch.manual_seed(DEFAULT_SEED) + points = torch.randn(100, 3) + tree = ClusterTree.from_points(points, leaf_size=8) + + root_min = tree.node_aabb_min[0] + root_max = 
tree.node_aabb_max[0] + + assert (points >= root_min - 1e-6).all(), "Some points below root AABB min" + assert (points <= root_max + 1e-6).all(), "Some points above root AABB max" + + def test_leaf_source_coverage(self): + """All sources are covered by exactly one leaf node.""" + torch.manual_seed(DEFAULT_SEED) + points = torch.randn(60, 2) + tree = ClusterTree.from_points(points, leaf_size=8) + + is_leaf = tree.leaf_count > 0 + leaf_ids = torch.where(is_leaf)[0] + total_sources = tree.leaf_count[leaf_ids].sum().item() + assert total_sources == 60, ( + f"Expected 60 sources in leaves, got {total_sources}" + ) + + @pytest.mark.parametrize("n_dims", [2, 3]) + @pytest.mark.parametrize("theta", [0.3, 1.0, 5.0]) + def test_interaction_plan_source_coverage(self, n_dims: int, theta: float): + """For every target, near + far pairs cover all sources exactly once. + + This is the fundamental invariant of the dual-tree traversal: + every source must be accounted for (no omissions) and no source + may be double-counted (no duplicates). For far-field node pairs, + we expand both the target node (to individual targets) and the + source node (to individual sources via DFS) to verify coverage. 
+ """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 40, 10 + source_pts = torch.randn(n_src, n_dims) + target_pts = torch.randn(n_tgt, n_dims) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=theta + ) + + all_sources = set(range(n_src)) + + def _collect_sources(tree: ClusterTree, node_id: int) -> set[int]: + """DFS to collect all source indices under a tree node.""" + count = tree.leaf_count[node_id].item() + if count > 0: + start = tree.leaf_start[node_id].item() + return { + tree.sorted_source_order[start + j].item() for j in range(count) + } + result: set[int] = set() + left = tree.node_left_child[node_id].item() + right = tree.node_right_child[node_id].item() + if left >= 0: + result |= _collect_sources(tree, left) + if right >= 0: + result |= _collect_sources(tree, right) + return result + + near_tgt = plan.near_target_ids.tolist() + near_src = plan.near_source_ids.tolist() + far_tgt_nids = plan.far_target_node_ids.tolist() + far_src_nids = plan.far_source_node_ids.tolist() + nf_tgt = plan.nf_target_ids.tolist() + nf_src_nids = plan.nf_source_node_ids.tolist() + fn_src = plan.fn_source_ids.tolist() + fn_bcast_tgts = plan.fn_broadcast_targets.tolist() + fn_bcast_starts = plan.fn_broadcast_starts.tolist() + fn_bcast_counts = plan.fn_broadcast_counts.tolist() + + ### Expand (far,far): target node × source node + far_expanded: list[tuple[int, int]] = [] + for tgt_nid, src_nid in zip(far_tgt_nids, far_src_nids): + tgt_set = _collect_sources(target_tree, tgt_nid) + src_set = _collect_sources(source_tree, src_nid) + far_expanded.extend((t, s) for t in tgt_set for s in src_set) + + ### Expand (near,far): individual target × source node + nf_expanded: list[tuple[int, int]] = [] + for ti, src_nid in zip(nf_tgt, nf_src_nids): + src_set = _collect_sources(source_tree, src_nid) + 
nf_expanded.extend((ti, s) for s in src_set) + + ### Expand (far,near): broadcast to survivors × individual source + fn_expanded: list[tuple[int, int]] = [] + for src_id, start, count in zip(fn_src, fn_bcast_starts, fn_bcast_counts): + fn_expanded.extend((fn_bcast_tgts[start + j], src_id) for j in range(count)) + + for t in range(n_tgt): + near_sources = {s for ti, s in zip(near_tgt, near_src) if ti == t} + far_sources = {s for ti, s in far_expanded if ti == t} + nf_sources = {s for ti, s in nf_expanded if ti == t} + fn_sources = {s for ti, s in fn_expanded if ti == t} + + all_sets = [near_sources, far_sources, nf_sources, fn_sources] + for i, (a, name_a) in enumerate(zip(all_sets, ["near", "far", "nf", "fn"])): + for b, name_b in zip(all_sets[i + 1 :], ["far", "nf", "fn"][i:]): + overlap = a & b + assert not overlap, ( + f"Target {t}: sources {overlap} in both {name_a} and {name_b}" + ) + + covered = near_sources | far_sources | nf_sources | fn_sources + assert covered == all_sources, ( + f"Target {t}: missing sources {all_sources - covered}, " + f"extra sources {covered - all_sources}" + ) + + def test_large_theta_all_far(self): + """With very large theta, most interactions become far-field node pairs.""" + torch.manual_seed(DEFAULT_SEED) + source_pts = torch.randn(30, 2) * 0.1 + target_pts = torch.randn(10, 2) * 100 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=100.0 + ) + assert plan.n_far_nodes > 0, "Expected some far-field node pairs" + + def test_zero_theta_all_near(self): + """With theta=0 (exact), all interactions are near-field.""" + torch.manual_seed(DEFAULT_SEED) + source_pts = torch.randn(20, 2) + target_pts = torch.randn(5, 2) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = 
source_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=0.0 + ) + + # theta=0: all per-point criteria also fail, everything is (near,near). + assert plan.n_near > 0 + assert plan.n_near == 20 * 5, ( + f"Expected {20 * 5} near-field pairs, got {plan.n_near}" + ) + assert plan.n_far_nodes == 0 + assert plan.n_nf == 0, f"Expected 0 nf pairs at theta=0, got {plan.n_nf}" + assert plan.n_fn == 0, f"Expected 0 fn pairs at theta=0, got {plan.n_fn}" + + # Every (target, source) pair must be unique. + pairs = torch.stack([plan.near_target_ids, plan.near_source_ids], dim=1) + unique_pairs = pairs.unique(dim=0) + assert unique_pairs.shape[0] == pairs.shape[0], ( + f"Found {pairs.shape[0] - unique_pairs.shape[0]} duplicate " + f"(target, source) pairs" + ) + + def test_aggregate_centroid_accuracy(self): + """Root centroid matches brute-force area-weighted mean.""" + torch.manual_seed(DEFAULT_SEED) + points = torch.randn(30, 3) + areas = torch.rand(30) + 0.1 + tree = ClusterTree.from_points(points, leaf_size=4, areas=areas) + agg = tree.compute_source_aggregates(points, areas) + + expected_centroid = (points * areas.unsqueeze(-1)).sum(0) / areas.sum() + root_centroid = agg.node_centroid[0] + + torch.testing.assert_close( + root_centroid, expected_centroid, atol=1e-5, rtol=1e-5 + ) + + def test_aggregate_source_data_scalars(self): + """Root aggregate of scalar source data matches brute-force.""" + torch.manual_seed(DEFAULT_SEED) + n = 30 + points = torch.randn(n, 3) + areas = torch.rand(n) + 0.1 + scalar_feat = torch.randn(n) + + tree = ClusterTree.from_points(points, leaf_size=4, areas=areas) + source_data = TensorDict({"my_scalar": scalar_feat}, batch_size=[n]) + agg = tree.compute_source_aggregates(points, areas, source_data=source_data) + + expected = (scalar_feat * areas).sum() / areas.sum() + actual = agg.node_source_data["my_scalar"][0] + + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) + + def test_aggregate_source_data_mixed(self): + 
"""Root aggregate of mixed scalar + vector source data matches brute-force.""" + torch.manual_seed(DEFAULT_SEED) + n = 40 + D = 3 + points = torch.randn(n, D) + areas = torch.rand(n) + 0.1 + scalar_feat = torch.randn(n) + vector_feat = torch.randn(n, D) + + tree = ClusterTree.from_points(points, leaf_size=4, areas=areas) + source_data = TensorDict({"s": scalar_feat, "v": vector_feat}, batch_size=[n]) + agg = tree.compute_source_aggregates(points, areas, source_data=source_data) + + total_area = areas.sum() + expected_s = (scalar_feat * areas).sum() / total_area + expected_v = (vector_feat * areas.unsqueeze(-1)).sum(0) / total_area + + torch.testing.assert_close( + agg.node_source_data["s"][0], expected_s, atol=1e-5, rtol=1e-5 + ) + torch.testing.assert_close( + agg.node_source_data["v"][0], expected_v, atol=1e-5, rtol=1e-5 + ) + + # -- Precomputed leaf field consistency tests ---------------------------- + + @pytest.mark.parametrize( + "n_points, leaf_size, n_dims", + [ + (50, 4, 3), + (1, 4, 2), + (10, 100, 2), + (20, 1, 3), + ], + ids=["normal", "single_point", "root_only_leaf", "one_per_leaf"], + ) + def test_precomputed_leaf_node_ids( + self, + n_points: int, + leaf_size: int, + n_dims: int, + ): + """Precomputed leaf_node_ids matches torch.where(leaf_count > 0).""" + torch.manual_seed(DEFAULT_SEED) + points = torch.randn(n_points, n_dims) + tree = ClusterTree.from_points(points, leaf_size=leaf_size) + + expected = torch.where(tree.leaf_count > 0)[0] + assert torch.equal(tree.leaf_node_ids, expected) + assert tree.n_leaves == expected.shape[0] + + @pytest.mark.parametrize( + "n_points, leaf_size, n_dims", + [ + (50, 4, 3), + (1, 4, 2), + (10, 100, 2), + (20, 1, 3), + ], + ids=["normal", "single_point", "root_only_leaf", "one_per_leaf"], + ) + def test_precomputed_leaf_seg_ids( + self, + n_points: int, + leaf_size: int, + n_dims: int, + ): + """Precomputed leaf_seg_ids matches on-the-fly _ragged_arange computation.""" + torch.manual_seed(DEFAULT_SEED) + points = 
torch.randn(n_points, n_dims) + tree = ClusterTree.from_points(points, leaf_size=leaf_size) + + assert tree.leaf_seg_ids.shape == (n_points,) + assert tree.leaf_seg_ids.dtype == torch.long + if n_points > 0: + assert tree.leaf_seg_ids.max() < tree.n_leaves + + # Rebuild seg_ids from scratch and compare + leaf_starts = tree.leaf_start[tree.leaf_node_ids] + leaf_counts = tree.leaf_count[tree.leaf_node_ids] + positions, compact_ids = _ragged_arange( + leaf_starts, + leaf_counts, + total=n_points, + ) + expected = torch.zeros(n_points, dtype=torch.long) + expected[positions] = compact_ids + + assert torch.equal(tree.leaf_seg_ids, expected) + + def test_precomputed_leaf_fields_empty_tree(self): + """Empty tree has empty leaf_node_ids and leaf_seg_ids.""" + tree = ClusterTree.from_points(torch.empty(0, 2), leaf_size=4) + + assert tree.leaf_node_ids.numel() == 0 + assert tree.leaf_seg_ids.numel() == 0 + assert tree.n_leaves == 0 + + def test_compute_source_aggregates_single_point(self): + """Single-point tree centroid equals the point itself.""" + point = torch.tensor([[3.0, -1.0, 7.0]]) + area = torch.tensor([2.5]) + tree = ClusterTree.from_points(point, leaf_size=4, areas=area) + agg = tree.compute_source_aggregates(point, area) + + torch.testing.assert_close(agg.node_centroid[0], point[0]) + + def test_compute_source_aggregates_root_only_leaf(self): + """Root-is-only-leaf centroid matches brute-force area-weighted mean.""" + torch.manual_seed(DEFAULT_SEED) + n = 10 + points = torch.randn(n, 3) + areas = torch.rand(n) + 0.1 + tree = ClusterTree.from_points(points, leaf_size=100, areas=areas) + + assert tree.n_leaves == 1, "Expected single leaf (root)" + agg = tree.compute_source_aggregates(points, areas) + + expected = (points * areas.unsqueeze(-1)).sum(0) / areas.sum() + torch.testing.assert_close( + agg.node_centroid[0], + expected, + atol=1e-5, + rtol=1e-5, + ) + + +# --------------------------------------------------------------------------- +# BarnesHutKernel 
# convergence tests
# ---------------------------------------------------------------------------


dims_params = pytest.mark.parametrize("n_dims", [2, 3])
output_fields_params = pytest.mark.parametrize(
    "output_fields",
    [
        {"potential": "scalar"},
        {"velocity": "vector"},
        {"potential": "scalar", "velocity": "vector"},
    ],
)
source_config_params = pytest.mark.parametrize(
    "n_source_scalars, n_source_vectors",
    [(0, 1), (2, 0), (2, 1)],
    ids=["vectors_only", "scalars_only", "mixed"],
)


@dims_params
@output_fields_params
@source_config_params
def test_bh_convergence_to_exact(
    n_dims: int,
    output_fields: dict[str, Literal["scalar", "vector"]],
    n_source_scalars: int,
    n_source_vectors: int,
):
    """BarnesHutKernel converges to exact Kernel as theta decreases toward 0."""
    bh_kernel, exact_kernel, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields=output_fields,
        n_source_scalars=n_source_scalars,
        n_source_vectors=n_source_vectors,
        n_source_points=30,
        n_target_points=15,
    )

    exact_result = exact_kernel(**data)

    # As theta decreases (more conservative), the BH result converges to the
    # exact one. The factor-of-3 slack accounts for the four-quadrant
    # classification: as theta changes, interactions shift between (near,far),
    # (far,near), and (near,near) modes, each with different approximation
    # properties, so the error need not shrink monotonically at large theta.
    prev_max_err = float("inf")
    for theta in (10.0, 2.0, 0.5, 0.01):
        bh_result = bh_kernel(**data, theta=theta)

        max_err = max(
            (bh_result[k] - exact_result[k]).abs().max().item() for k in output_fields
        )

        assert max_err <= prev_max_err * 3.0 + 1e-5, (
            f"Error increased from {prev_max_err:.2e} to {max_err:.2e} at theta={theta}"
        )
        prev_max_err = max_err

    # At theta=0.01 the last bh_result should be very close to exact.
    for field_name in output_fields:
        torch.testing.assert_close(
            bh_result[field_name],
            exact_result[field_name],
            atol=1e-4,
            rtol=1e-3,
            msg=f"Field {field_name!r} not close to exact at theta=0.01",
        )


@dims_params
@source_config_params
def test_bh_gradient_correctness(
    n_dims: int,
    n_source_scalars: int,
    n_source_vectors: int,
):
    """Gradients through BarnesHutKernel match exact kernel at low theta."""
    bh_kernel, exact_kernel, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields={"field": "scalar"},
        n_source_scalars=n_source_scalars,
        n_source_vectors=n_source_vectors,
        n_source_points=15,
        n_target_points=8,
    )
    bh_kernel.train()
    exact_kernel.train()

    # Differentiate w.r.t. the source positions.
    data["source_points"] = data["source_points"].clone().requires_grad_(True)

    # Reference gradient through the exact (dense) kernel.
    exact_result = exact_kernel(**data)
    exact_loss = exact_result["field"].sum()
    exact_loss.backward()
    exact_grad = data["source_points"].grad.clone()

    data["source_points"].grad = None

    # BH gradient at low theta (near-exact, should match closely).
    bh_result = bh_kernel(**data, theta=0.01)
    bh_loss = bh_result["field"].sum()
    bh_loss.backward()
    bh_grad = data["source_points"].grad.clone()

    torch.testing.assert_close(
        bh_grad,
        exact_grad,
        atol=1e-3,
        rtol=1e-2,
        msg="BH gradients don't match exact at low theta",
    )


# ---------------------------------------------------------------------------
# Equivariance tests
# ---------------------------------------------------------------------------


@dims_params
@output_fields_params
@source_config_params
def test_bh_translation_equivariance(
    n_dims: int,
    output_fields: dict[str, Literal["scalar", "vector"]],
    n_source_scalars: int,
    n_source_vectors: int,
):
    """Barnes-Hut kernel preserves translation equivariance.

    Translation does not change the morton-code relative ordering, so the
    tree structure and interaction plan are identical pre- and
    post-translation. A moderate theta is therefore sufficient.
    """
    bh_kernel, _, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields=output_fields,
        n_source_scalars=n_source_scalars,
        n_source_vectors=n_source_vectors,
    )

    result1 = bh_kernel(**data, theta=2.0)

    # Shift both point clouds by the same random offset.
    translation = torch.randn(n_dims)
    translated_data = {**data}
    translated_data["source_points"] = data["source_points"] + translation
    translated_data["target_points"] = data["target_points"] + translation

    result2 = bh_kernel(**translated_data, theta=2.0)

    for field_name in output_fields:
        torch.testing.assert_close(
            result1[field_name],
            result2[field_name],
            atol=1e-4,
            rtol=1e-4,
            msg=f"Translation equivariance failed for {field_name!r}",
        )


@dims_params
@output_fields_params
@source_config_params
def test_bh_rotational_equivariance(
    n_dims: int,
    output_fields: dict[str, Literal["scalar", "vector"]],
    n_source_scalars: int,
    n_source_vectors: int,
):
    """Barnes-Hut kernel preserves rotational equivariance.

    The underlying kernel is exactly equivariant, but the tree
    decomposition is axis-aligned (morton codes). Rotation changes the tree
    structure, so equivariance is only recovered in the near-exact limit;
    hence the small theta below so that nearly all interactions are exact.
    """
    # Ensure at least one source vector for basis construction.
    effective_src_vectors = max(n_source_vectors, 1)
    bh_kernel, _, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields=output_fields,
        n_source_scalars=n_source_scalars,
        n_source_vectors=effective_src_vectors,
        n_global_vectors=1,
    )

    # Build a rotation matrix: 2D gives a plain rotation by pi/3; 3D uses the
    # Rodrigues formula about a random unit axis.
    if n_dims == 2:
        angle = torch.tensor(torch.pi / 3)
        rot = torch.tensor(
            [
                [torch.cos(angle), -torch.sin(angle)],
                [torch.sin(angle), torch.cos(angle)],
            ]
        )
    else:
        axis = F.normalize(torch.randn(3), dim=0)
        angle = torch.tensor(torch.pi / 3)
        skew = torch.zeros(3, 3)
        skew[0, 1], skew[0, 2] = -axis[2], axis[1]
        skew[1, 0], skew[1, 2] = axis[2], -axis[0]
        skew[2, 0], skew[2, 1] = -axis[1], axis[0]
        rot = (
            torch.eye(3)
            + torch.sin(angle) * skew
            + (1 - torch.cos(angle)) * (skew @ skew)
        )
    R = rot

    def _rotate_td(td: TensorDict) -> TensorDict:
        # Rotate vector-valued entries; scalars (batch-dim only) pass through.
        return td.apply(lambda v: v @ R.T if v.ndim > td.batch_dims else v)

    # Low theta: near-exact, so equivariance holds.
    result1 = bh_kernel(**data, theta=0.01)

    rotated_data = {**data}
    rotated_data["source_points"] = data["source_points"] @ R.T
    rotated_data["target_points"] = data["target_points"] @ R.T
    rotated_data["source_data"] = _rotate_td(data["source_data"])
    rotated_data["global_data"] = _rotate_td(data["global_data"])

    result2 = bh_kernel(**rotated_data, theta=0.01)

    for field_name, field_type in output_fields.items():
        if field_type == "scalar":
            torch.testing.assert_close(
                result1[field_name],
                result2[field_name],
                atol=1e-4,
                rtol=1e-4,
                msg=f"Scalar {field_name!r} not invariant under rotation",
            )
        else:
            rotated_field1 = result1[field_name] @ R.T
            torch.testing.assert_close(
                rotated_field1,
                result2[field_name],
                atol=1e-4,
                rtol=1e-4,
                msg=f"Vector {field_name!r} not equivariant under rotation",
            )


# ---------------------------------------------------------------------------
# MultiscaleKernel integration
# ---------------------------------------------------------------------------


@dims_params
def test_multiscale_bh_convergence(n_dims: int):
    """MultiscaleKernel at low theta converges to exact per-branch Kernel results."""
    torch.manual_seed(DEFAULT_SEED)

    ms = MultiscaleKernel(
        n_spatial_dims=n_dims,
        output_field_ranks={"p": 0},
        reference_length_names=["short", "long"],
        source_data_ranks={"normal": 1},
        hidden_layer_sizes=[16],
        leaf_size=4,
    )
    ms.eval()

    n_src = 25
    torch.manual_seed(DEFAULT_SEED + 1)
    src = torch.randn(n_src, n_dims)
    tgt = torch.randn(10, n_dims) * 3
    normals = F.normalize(torch.randn(n_src, n_dims), dim=-1)
    ref_lengths = {"short": torch.tensor(0.1), "long": torch.tensor(1.0)}

    # Exact reference: evaluate each branch's underlying Kernel directly
    # (the parent class forward is the exact dense evaluation) and sum.
    from physicsnemo.experimental.models.globe.field_kernel import Kernel

    exact_total = None
    for name in ms.reference_length_names:
        branch: BarnesHutKernel = ms.kernels[name]
        branch_result = Kernel.forward(
            branch,
            reference_length=ref_lengths[name] * torch.exp(ms.log_scalefactors[name]),
            source_points=src,
            target_points=tgt,
            source_strengths=torch.ones(n_src),
            source_data=TensorDict({"normal": normals}, batch_size=[n_src]),
            global_data=TensorDict(
                {
                    "log_reference_length_ratios": TensorDict(
                        {
                            "short_long": (
                                ref_lengths["short"] / ref_lengths["long"]
                            ).log()
                        }
                    ),
                }
            ),
        )
        if exact_total is None:
            exact_total = branch_result
        else:
            exact_total = exact_total + branch_result

    bh_result = ms(
        source_points=src,
        target_points=tgt,
        reference_lengths=ref_lengths,
        source_data=TensorDict({"normal": normals}, batch_size=[n_src]),
        theta=0.01,
    )

    torch.testing.assert_close(
        bh_result["p"],
        exact_total["p"],
        atol=1e-3,
        rtol=1e-2,
        msg="MultiscaleKernel BH doesn't converge to exact at low theta",
    )


# ---------------------------------------------------------------------------
# Source permutation equivariance
# ---------------------------------------------------------------------------


@dims_params
@source_config_params
def test_bh_source_permutation(
    n_dims: int,
    n_source_scalars: int,
    n_source_vectors: int,
):
    """Result is independent of source ordering."""
    bh_kernel, _, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields={"p": "scalar"},
        n_source_scalars=n_source_scalars,
        n_source_vectors=n_source_vectors,
    )

    result1 = bh_kernel(**data, theta=2.0)

    # Apply the same random permutation to every per-source input.
    perm = torch.randperm(data["source_points"].shape[0])
    perm_data = {**data}
    perm_data["source_points"] = data["source_points"][perm]
    perm_data["source_strengths"] = data["source_strengths"][perm]
    perm_data["source_data"] = data["source_data"][perm]

    result2 = bh_kernel(**perm_data, theta=2.0)

    torch.testing.assert_close(
        result1["p"],
        result2["p"],
        atol=1e-4,
        rtol=1e-4,
        msg="BH result changed under source permutation",
    )


# ---------------------------------------------------------------------------
# GLOBE-like configuration (mimics communication hyperlayer source data)
# ---------------------------------------------------------------------------


@dims_params
def test_bh_globe_like_config(n_dims: int):
    """Convergence with a source data configuration matching GLOBE's
    communication hyperlayers: multiple latent scalars, latent vectors,
    and strength scalars - the exact mix that triggered the production bug.
    """
    bh_kernel, exact_kernel, data = _make_bh_kernel_and_data(
        n_spatial_dims=n_dims,
        output_fields={"p": "scalar", "u": "vector"},
        n_source_scalars=8,
        n_source_vectors=3,
        n_global_scalars=1,
        n_global_vectors=1,
        n_source_points=40,
        n_target_points=20,
    )

    exact_result = exact_kernel(**data)
    bh_result = bh_kernel(**data, theta=0.01)

    # Wider tolerance than basic tests: 8 scalars + 3 vectors + globals
    # produces more accumulated floating-point error through the aggregation
    # and feature engineering pipeline, even at low theta.
    for field in ("p", "u"):
        torch.testing.assert_close(
            bh_result[field],
            exact_result[field],
            atol=5e-3,
            rtol=5e-2,
            msg=f"GLOBE-like config: {field!r} not close to exact at theta=0.01",
        )


# ---------------------------------------------------------------------------
# Nested source_data keys (matches GLOBE's actual data structure)
# ---------------------------------------------------------------------------


@dims_params
def test_bh_nested_source_data_keys(n_dims: int):
    """Convergence with nested TensorDict keys matching GLOBE's production format.

    GLOBE passes source_data structured like:
        {"physical": {"velocity": ...}, "latent": {"scalars": {"0": ...},
        "vectors": {"0": ...}}, "normals": ...}

    The aggregation, split_by_leaf_rank, and TensorDict.cat operations must
    handle this nesting correctly.
    """
    torch.manual_seed(DEFAULT_SEED)
    n_src, n_tgt = 30, 15

    source_data_ranks = {
        "physical": {"pressure": 0},
        "latent": {"scalars": {"0": 0, "1": 0}, "vectors": {"0": 1}},
        "normals": 1,
    }
    # Ranks: 0 = scalar output, 1 = vector output.
    output_field_ranks = {"p": 0, "u": 1}

    # BUGFIX: this previously remapped the dict with
    # {k: (0 if v == "scalar" else 1) for k, v in output_field_ranks.items()},
    # but the values are already integer ranks (0/1), never the string
    # "scalar", so BOTH fields were silently forced to rank 1 (vector) and
    # the scalar output "p" was never actually exercised. Pass the ranks
    # through unchanged.
    common_kwargs = dict(
        n_spatial_dims=n_dims,
        output_field_ranks=output_field_ranks,
        source_data_ranks=source_data_ranks,
        hidden_layer_sizes=[16],
    )

    bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE)
    exact_kernel = Kernel(**common_kwargs)
    # Same weights in both kernels so only the evaluation strategy differs.
    exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False)
    bh_kernel.eval()
    exact_kernel.eval()

    torch.manual_seed(DEFAULT_SEED + 1)
    source_data = TensorDict(
        {
            "physical": TensorDict(
                {"pressure": torch.randn(n_src)},
                batch_size=[n_src],
            ),
            "latent": TensorDict(
                {
                    "scalars": TensorDict(
                        {"0": torch.randn(n_src), "1": torch.randn(n_src)},
                        batch_size=[n_src],
                    ),
                    "vectors": TensorDict(
                        {"0": F.normalize(torch.randn(n_src, n_dims), dim=-1)},
                        batch_size=[n_src],
                    ),
                },
                batch_size=[n_src],
            ),
            "normals": F.normalize(torch.randn(n_src, n_dims), dim=-1),
        },
        batch_size=[n_src],
    )

    data = {
        "source_points": torch.randn(n_src, n_dims),
        "target_points": torch.randn(n_tgt, n_dims) * 5,
        "source_strengths": torch.rand(n_src) + 0.1,
        "reference_length": torch.ones(()),
        "source_data": source_data,
        "global_data": TensorDict({}, batch_size=[]),
    }

    exact_result = exact_kernel(**data)
    bh_result = bh_kernel(**data, theta=0.01)

    for field_name in output_field_ranks:
        torch.testing.assert_close(
            bh_result[field_name],
            exact_result[field_name],
            atol=1e-3,
            rtol=1e-2,
            msg=f"Nested keys: {field_name!r} not close to exact at theta=0.01",
        )


# ---------------------------------------------------------------------------
# Four-quadrant interaction mode tests
# ---------------------------------------------------------------------------


@dims_params
@source_config_params
def test_all_four_categories_active_and_correct(
    n_dims: int,
    n_source_scalars: int,
    n_source_vectors: int,
):
    """At moderate theta, all four interaction categories should be active
    and the combined result should still converge to exact.

    This is the critical test for the (near,far) and (far,near) code paths:
    the convergence tests at theta=0.01 barely exercise them because nearly
    everything is (near,near) at low theta.
    """
    # Balanced source/target scales: the (far,near) target-centroid broadcast
    # and the (near,far) source monopole then have comparable accuracy. The
    # default helper scales targets by 5x, which makes target leaf diameters
    # much larger and the (far,near) approximation very coarse.
    torch.manual_seed(DEFAULT_SEED)
    n_src, n_tgt = 60, 30
    common_kwargs = dict(
        n_spatial_dims=n_dims,
        output_field_ranks={"p": 0, "v": 1},
        source_data_ranks={
            **{f"source_scalar_{i}": 0 for i in range(n_source_scalars)},
            **{f"source_vector_{i}": 1 for i in range(max(n_source_vectors, 1))},
        },
        hidden_layer_sizes=[32, 32],
    )
    bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=4)
    exact_kernel = Kernel(**common_kwargs)
    exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False)
    bh_kernel.eval()
    exact_kernel.eval()

    torch.manual_seed(DEFAULT_SEED + 1)
    source_pts = torch.randn(n_src, n_dims)
    target_pts = torch.randn(n_tgt, n_dims)

    source_data_dict: dict[str, torch.Tensor] = {}
    for i in range(n_source_scalars):
        source_data_dict[f"source_scalar_{i}"] = torch.randn(n_src)
    for i in range(max(n_source_vectors, 1)):
        source_data_dict[f"source_vector_{i}"] = F.normalize(
            torch.randn(n_src, n_dims), dim=-1
        )

    data = {
        "source_points": source_pts,
        "target_points": target_pts,
        "source_strengths": torch.randn(n_src).abs() + 0.1,
        "reference_length": torch.ones(()),
        "source_data": TensorDict(source_data_dict, batch_size=[n_src]),
        "global_data": TensorDict({}, batch_size=[]),
    }

    exact_result = exact_kernel(**data)

    # Sweep theta to find one where all four categories are active. With
    # balanced geometry (source and target at the same scale) and theta=1.0,
    # the diagnostic shows near=751, nf=200, fn=131, far=2.
    for theta in (1.0, 1.5, 2.0):
        source_tree = ClusterTree.from_points(source_pts, leaf_size=4)
        target_tree = ClusterTree.from_points(target_pts, leaf_size=4)
        plan = source_tree.find_dual_interaction_pairs(
            target_tree=target_tree, theta=theta
        )
        if plan.n_near > 0 and plan.n_nf > 0 and plan.n_fn > 0 and plan.n_far_nodes > 0:
            break
    else:
        pytest.skip("Could not find theta with all four categories active")

    bh_result = bh_kernel(**data, theta=theta)

    # With all four modes contributing, the combined result must still be
    # close to the exact dense evaluation (loose tolerance by construction).
    for field_name in ("p", "v"):
        torch.testing.assert_close(
            bh_result[field_name],
            exact_result[field_name],
            atol=0.1,
            rtol=0.3,
            msg=f"Field {field_name!r} not close to exact at theta={theta} "
            f"with all four categories active "
            f"(near={plan.n_near}, nf={plan.n_nf}, fn={plan.n_fn}, far={plan.n_far_nodes})",
        )


@dims_params
@pytest.mark.parametrize("theta", [0.3, 1.0, 5.0])
def test_self_interaction_source_coverage(n_dims: int, theta: float):
    """Source coverage invariant for self-interaction (target_tree is source_tree).

    Communication hyperlayers use self-interaction where the same tree
    serves as both source and target. The traversal starts with
    (root, root) and D_T == D_S at every level.
    """
    torch.manual_seed(DEFAULT_SEED)
    n_pts = 40
    points = torch.randn(n_pts, n_dims)
    tree = ClusterTree.from_points(points, leaf_size=4)
    plan = tree.find_dual_interaction_pairs(target_tree=tree, theta=theta)

    all_sources = set(range(n_pts))

    def _collect(tree: ClusterTree, node_id: int) -> set[int]:
        # Gather every original point index owned by the subtree at node_id.
        count = tree.leaf_count[node_id].item()
        if count > 0:
            start = tree.leaf_start[node_id].item()
            return {tree.sorted_source_order[start + j].item() for j in range(count)}
        result: set[int] = set()
        left = tree.node_left_child[node_id].item()
        right = tree.node_right_child[node_id].item()
        if left >= 0:
            result |= _collect(tree, left)
        if right >= 0:
            result |= _collect(tree, right)
        return result

    near_tgt = plan.near_target_ids.tolist()
    near_src = plan.near_source_ids.tolist()

    # Expand node-level far-far pairs into point-level pairs.
    far_expanded: list[tuple[int, int]] = []
    for tgt_nid, src_nid in zip(
        plan.far_target_node_ids.tolist(), plan.far_source_node_ids.tolist()
    ):
        far_expanded.extend(
            (t, s) for t in _collect(tree, tgt_nid) for s in _collect(tree, src_nid)
        )

    # Expand (near-target, far-source-node) pairs.
    nf_expanded: list[tuple[int, int]] = []
    for ti, src_nid in zip(
        plan.nf_target_ids.tolist(), plan.nf_source_node_ids.tolist()
    ):
        nf_expanded.extend((ti, s) for s in _collect(tree, src_nid))

    # Expand (far-target-node, near-source) pairs via the broadcast arrays.
    fn_expanded: list[tuple[int, int]] = []
    fn_bcast = plan.fn_broadcast_targets.tolist()
    for src_id, start, count in zip(
        plan.fn_source_ids.tolist(),
        plan.fn_broadcast_starts.tolist(),
        plan.fn_broadcast_counts.tolist(),
    ):
        fn_expanded.extend((fn_bcast[start + j], src_id) for j in range(count))

    # Every target must see every source exactly through the union of modes.
    for t in range(n_pts):
        near_s = {s for ti, s in zip(near_tgt, near_src) if ti == t}
        far_s = {s for ti, s in far_expanded if ti == t}
        nf_s = {s for ti, s in nf_expanded if ti == t}
        fn_s = {s for ti, s in fn_expanded if ti == t}

        covered = near_s | far_s | nf_s | fn_s
        assert covered == all_sources, (
            f"Self-interaction target {t} at theta={theta}: "
            f"missing {all_sources - covered}"
        )

+def test_near_field_monotonicity(): + """Near-field pair count should decrease as theta increases. + + At higher theta, more interactions move to approximate modes + (near-far, far-near, far-far), reducing the exact near-field count. + """ + torch.manual_seed(DEFAULT_SEED) + source_pts = torch.randn(50, 3) + target_pts = torch.randn(25, 3) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + + prev_n_near = float("inf") + for theta in [0.1, 0.5, 1.0, 2.0, 5.0]: + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=theta + ) + assert plan.n_near <= prev_n_near, ( + f"Near-field count increased from {prev_n_near} to {plan.n_near} " + f"when theta increased to {theta}" + ) + prev_n_near = plan.n_near + + +# --------------------------------------------------------------------------- +# DualInteractionPlan validation tests +# --------------------------------------------------------------------------- +class TestDualInteractionPlanValidate: + """Tests for DualInteractionPlan.validate().""" + + def _make_valid_plan(self) -> DualInteractionPlan: + """Construct a minimal valid DualInteractionPlan.""" + return DualInteractionPlan( + near_target_ids=torch.tensor([0, 1]), + near_source_ids=torch.tensor([2, 3]), + far_target_node_ids=torch.tensor([0]), + far_source_node_ids=torch.tensor([1]), + nf_target_ids=torch.tensor([0]), + nf_source_node_ids=torch.tensor([1]), + fn_target_node_ids=torch.tensor([0, 0]), + fn_source_ids=torch.tensor([1, 2]), + fn_broadcast_targets=torch.tensor([3, 4, 5]), + fn_broadcast_starts=torch.tensor([0, 1]), + fn_broadcast_counts=torch.tensor([1, 2]), + ) + + def test_valid_plan_passes(self): + """A correctly constructed plan passes validation.""" + plan = self._make_valid_plan() + plan.validate() + + def test_empty_plan_passes(self): + """An empty plan (all zero-length tensors) passes validation.""" + e = torch.empty(0, dtype=torch.long) 
+ plan = DualInteractionPlan( + near_target_ids=e, + near_source_ids=e.clone(), + far_target_node_ids=e.clone(), + far_source_node_ids=e.clone(), + nf_target_ids=e.clone(), + nf_source_node_ids=e.clone(), + fn_target_node_ids=e.clone(), + fn_source_ids=e.clone(), + fn_broadcast_targets=e.clone(), + fn_broadcast_starts=e.clone(), + fn_broadcast_counts=e.clone(), + ) + plan.validate() + + def test_shape_mismatch_detected(self): + """Mismatched near_target_ids / near_source_ids shapes are caught.""" + plan = DualInteractionPlan( + near_target_ids=torch.tensor([0, 1, 2]), + near_source_ids=torch.tensor([0, 1]), + far_target_node_ids=torch.empty(0, dtype=torch.long), + far_source_node_ids=torch.empty(0, dtype=torch.long), + nf_target_ids=torch.empty(0, dtype=torch.long), + nf_source_node_ids=torch.empty(0, dtype=torch.long), + fn_target_node_ids=torch.empty(0, dtype=torch.long), + fn_source_ids=torch.empty(0, dtype=torch.long), + fn_broadcast_targets=torch.empty(0, dtype=torch.long), + fn_broadcast_starts=torch.empty(0, dtype=torch.long), + fn_broadcast_counts=torch.empty(0, dtype=torch.long), + ) + with pytest.raises(ValueError, match="Shape mismatch"): + plan.validate() + + def test_broadcast_out_of_bounds_detected(self): + """fn_broadcast_starts + counts exceeding targets length is caught. + + This is the exact invariant violation that caused the original + IndexError bug (starts + counts pointed beyond fn_broadcast_targets). 
+ """ + plan = DualInteractionPlan( + near_target_ids=torch.empty(0, dtype=torch.long), + near_source_ids=torch.empty(0, dtype=torch.long), + far_target_node_ids=torch.empty(0, dtype=torch.long), + far_source_node_ids=torch.empty(0, dtype=torch.long), + nf_target_ids=torch.empty(0, dtype=torch.long), + nf_source_node_ids=torch.empty(0, dtype=torch.long), + fn_target_node_ids=torch.tensor([0]), + fn_source_ids=torch.tensor([1]), + fn_broadcast_targets=torch.tensor([0, 1]), + fn_broadcast_starts=torch.tensor([1]), + fn_broadcast_counts=torch.tensor([3]), + ) + with pytest.raises(ValueError, match="fn_broadcast out of bounds"): + plan.validate() + + def test_negative_counts_detected(self): + """Negative fn_broadcast_counts values are caught.""" + plan = DualInteractionPlan( + near_target_ids=torch.empty(0, dtype=torch.long), + near_source_ids=torch.empty(0, dtype=torch.long), + far_target_node_ids=torch.empty(0, dtype=torch.long), + far_source_node_ids=torch.empty(0, dtype=torch.long), + nf_target_ids=torch.empty(0, dtype=torch.long), + nf_source_node_ids=torch.empty(0, dtype=torch.long), + fn_target_node_ids=torch.tensor([0]), + fn_source_ids=torch.tensor([1]), + fn_broadcast_targets=torch.tensor([0, 1, 2]), + fn_broadcast_starts=torch.tensor([0]), + fn_broadcast_counts=torch.tensor([-1]), + ) + with pytest.raises(ValueError, match="negative values"): + plan.validate() + + @pytest.mark.parametrize("n_dims", [2, 3]) + @pytest.mark.parametrize("theta", [0.3, 1.0, 5.0]) + def test_validate_called_by_find_dual_interaction_pairs( + self, + n_dims: int, + theta: float, + ): + """validate() is exercised on every plan produced by the traversal. + + This also checks external validity: all index tensors reference + valid source/target/node indices within their respective trees. 
+ """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 40, 15 + source_pts = torch.randn(n_src, n_dims) + target_pts = torch.randn(n_tgt, n_dims) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, theta=theta + ) + + ### External validity: indices within tree-specific ranges + if plan.n_near > 0: + assert plan.near_target_ids.max() < n_tgt + assert plan.near_source_ids.max() < n_src + if plan.n_far_nodes > 0: + assert plan.far_target_node_ids.max() < target_tree.n_nodes + assert plan.far_source_node_ids.max() < source_tree.n_nodes + if plan.n_nf > 0: + assert plan.nf_target_ids.max() < n_tgt + assert plan.nf_source_node_ids.max() < source_tree.n_nodes + if plan.n_fn > 0: + assert plan.fn_source_ids.max() < n_src + assert plan.fn_target_node_ids.max() < target_tree.n_nodes + if plan.fn_broadcast_targets.numel() > 0: + assert plan.fn_broadcast_targets.max() < n_tgt + + +# --------------------------------------------------------------------------- +# fn_broadcast expansion round-trip test +# --------------------------------------------------------------------------- + + +@dims_params +@pytest.mark.parametrize("theta", [0.5, 1.0, 3.0]) +def test_fn_broadcast_ragged_arange_matches_python_expansion( + n_dims: int, + theta: float, +): + """The _ragged_arange expansion of fn_broadcast (BarnesHutKernel's code + path) produces the same (target, source) pairs as the pure-Python + expansion used in test_interaction_plan_source_coverage. + + This bridges the gap between "the plan is semantically correct" and + "the consumer expands it correctly via _ragged_arange." 
+ """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 40, 15 + source_pts = torch.randn(n_src, n_dims) + target_pts = torch.randn(n_tgt, n_dims) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, + theta=theta, + ) + + if plan.n_fn == 0: + pytest.skip("No fn pairs at this theta") + + ### Reference: pure-Python expansion (same logic as source coverage test) + ref_pairs: set[tuple[int, int]] = set() + for src_id, start, count in zip( + plan.fn_source_ids.tolist(), + plan.fn_broadcast_starts.tolist(), + plan.fn_broadcast_counts.tolist(), + ): + for j in range(count): + ref_pairs.add((plan.fn_broadcast_targets[start + j].item(), src_id)) + + ### Actual: _ragged_arange expansion (same code path as BarnesHutKernel) + positions, pair_ids = _ragged_arange( + plan.fn_broadcast_starts, + plan.fn_broadcast_counts, + ) + expanded_tgt_ids = plan.fn_broadcast_targets[positions] + expanded_src_ids = plan.fn_source_ids[pair_ids] + + actual_pairs = set( + zip( + expanded_tgt_ids.tolist(), + expanded_src_ids.tolist(), + ) + ) + + assert actual_pairs == ref_pairs, ( + f"Ragged expansion mismatch: " + f"{len(actual_pairs - ref_pairs)} extra, " + f"{len(ref_pairs - actual_pairs)} missing" + ) + + +# --------------------------------------------------------------------------- +# fn_broadcast_targets dead-entry detection +# --------------------------------------------------------------------------- + + +@dims_params +@pytest.mark.parametrize("theta", [0.5, 1.0, 3.0]) +def test_fn_broadcast_targets_no_dead_entries(n_dims: int, theta: float): + """Every entry in fn_broadcast_targets is reachable from at least one + fn pair's (start, count) range. + + Dead entries (survivors from leaf pairs with no fn sources) inflate + fn_broadcast_targets.shape[0] beyond fn_broadcast_counts.sum(). 
This + was the root cause of the original IndexError: the consumer passed + total=fn_broadcast_targets.shape[0] to _ragged_arange, which + generated out-of-bounds positions. + """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 40, 15 + source_pts = torch.randn(n_src, n_dims) + target_pts = torch.randn(n_tgt, n_dims) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, + theta=theta, + ) + + if plan.n_fn == 0: + pytest.skip("No fn pairs at this theta") + + ### Build the set of all referenced positions in fn_broadcast_targets + referenced = torch.zeros( + plan.fn_broadcast_targets.shape[0], + dtype=torch.bool, + ) + for start, count in zip( + plan.fn_broadcast_starts.tolist(), + plan.fn_broadcast_counts.tolist(), + ): + referenced[start : start + count] = True + + n_dead = int((~referenced).sum()) + assert n_dead == 0, ( + f"{n_dead} of {plan.fn_broadcast_targets.shape[0]} entries in " + f"fn_broadcast_targets are unreferenced (dead)" + ) + + +# --------------------------------------------------------------------------- +# Post-sort invariant preservation +# --------------------------------------------------------------------------- + + +@dims_params +@pytest.mark.parametrize("theta", [0.5, 1.0, 3.0]) +def test_fn_sort_preserves_broadcast_mapping(n_dims: int, theta: float): + """The source-ID sort in find_dual_interaction_pairs preserves the + fn_broadcast expansion semantics. + + After sorting, fn_broadcast_starts/counts are permuted but still + reference the same (unsorted) fn_broadcast_targets array. This test + verifies the expansion produces the same (target, source) pair set + regardless of the sort order. 
+ """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 40, 15 + source_pts = torch.randn(n_src, n_dims) + target_pts = torch.randn(n_tgt, n_dims) * 3 + source_tree = ClusterTree.from_points(source_pts, leaf_size=4) + target_tree = ClusterTree.from_points(target_pts, leaf_size=4) + plan = source_tree.find_dual_interaction_pairs( + target_tree=target_tree, + theta=theta, + ) + + if plan.n_fn == 0: + pytest.skip("No fn pairs at this theta") + + ### Expand with the current (sorted) order + sorted_pairs: set[tuple[int, int]] = set() + for src_id, start, count in zip( + plan.fn_source_ids.tolist(), + plan.fn_broadcast_starts.tolist(), + plan.fn_broadcast_counts.tolist(), + ): + for j in range(count): + sorted_pairs.add((plan.fn_broadcast_targets[start + j].item(), src_id)) + + ### Expand with a random permutation of the fn entries + perm = torch.randperm(plan.n_fn) + permuted_pairs: set[tuple[int, int]] = set() + perm_src = plan.fn_source_ids[perm] + perm_starts = plan.fn_broadcast_starts[perm] + perm_counts = plan.fn_broadcast_counts[perm] + for src_id, start, count in zip( + perm_src.tolist(), + perm_starts.tolist(), + perm_counts.tolist(), + ): + for j in range(count): + permuted_pairs.add((plan.fn_broadcast_targets[start + j].item(), src_id)) + + assert sorted_pairs == permuted_pairs, ( + "fn_broadcast expansion changed under permutation of fn entries" + ) + + +# --------------------------------------------------------------------------- +# Tightened four-quadrant accuracy test +# --------------------------------------------------------------------------- + + +@dims_params +@source_config_params +def test_four_quadrant_per_category_accuracy( + n_dims: int, + n_source_scalars: int, + n_source_vectors: int, +): + """Verify that each interaction category individually produces + reasonable results, not just the combined sum. + + Compares the BH result at a moderate theta against exact, with + tighter tolerances than test_all_four_categories_active_and_correct. 
+ Also verifies that the near-field contribution alone (theta=0, + no approximation) exactly matches the exact kernel. + """ + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 60, 30 + common_kwargs = dict( + n_spatial_dims=n_dims, + output_field_ranks={"p": 0}, + source_data_ranks={ + **{f"source_scalar_{i}": 0 for i in range(n_source_scalars)}, + **{f"source_vector_{i}": 1 for i in range(max(n_source_vectors, 1))}, + }, + hidden_layer_sizes=[32, 32], + ) + bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=4) + exact_kernel = Kernel(**common_kwargs) + exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False) + bh_kernel.eval() + exact_kernel.eval() + + torch.manual_seed(DEFAULT_SEED + 1) + source_data_dict: dict[str, torch.Tensor] = {} + for i in range(n_source_scalars): + source_data_dict[f"source_scalar_{i}"] = torch.randn(n_src) + for i in range(max(n_source_vectors, 1)): + source_data_dict[f"source_vector_{i}"] = F.normalize( + torch.randn(n_src, n_dims), dim=-1 + ) + + data = { + "source_points": torch.randn(n_src, n_dims), + "target_points": torch.randn(n_tgt, n_dims), + "source_strengths": torch.randn(n_src).abs() + 0.1, + "reference_length": torch.ones(()), + "source_data": TensorDict(source_data_dict, batch_size=[n_src]), + "global_data": TensorDict({}, batch_size=[]), + } + + exact_result = exact_kernel(**data) + + ### Near-only (theta=0): should be numerically exact + near_only = bh_kernel(**data, theta=0.0) + torch.testing.assert_close( + near_only["p"], + exact_result["p"], + atol=1e-5, + rtol=1e-5, + msg="Near-only (theta=0) doesn't match exact", + ) + + ### Low theta (0.01): near-exact, tight tolerance + low_theta = bh_kernel(**data, theta=0.01) + torch.testing.assert_close( + low_theta["p"], + exact_result["p"], + atol=1e-4, + rtol=1e-3, + msg="Low theta (0.01) not close enough to exact", + ) + + +if __name__ == "__main__": + pytest.main() diff --git a/test/models/globe/test_field_kernel.py b/test/models/globe/test_field_kernel.py 
index 727d9d734b..f808f32bc3 100644 --- a/test/models/globe/test_field_kernel.py +++ b/test/models/globe/test_field_kernel.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from tensordict import TensorDict -from physicsnemo.experimental.models.globe.field_kernel import ChunkedKernel +from physicsnemo.experimental.models.globe.field_kernel import Kernel DEFAULT_RTOL = 1e-5 # Default relative tolerance for comparisons DEFAULT_ATOL = 1e-5 # Default absolute tolerance for comparisons @@ -48,7 +48,7 @@ def make_kernel_and_input_data( n_target_points: int = 12, device: str | torch.device = "cpu", seed: int = DEFAULT_SEED, -) -> tuple[ChunkedKernel, dict[str, Any]]: +) -> tuple[Kernel, dict[str, Any]]: """Create a kernel and compatible input data for testing. Returns: @@ -80,7 +80,7 @@ def make_kernel_and_input_data( **{f"global_vector_{i}": 1 for i in range(n_global_vectors)}, } - kernel = ChunkedKernel( + kernel = Kernel( n_spatial_dims=n_spatial_dims, output_field_ranks=output_field_ranks, source_data_ranks=source_data_ranks, @@ -132,14 +132,13 @@ def make_kernel_and_input_data( global_data_dict, device=device, ), - "chunk_size": None, } return kernel, input_data def evaluate_kernel_with_transform( - kernel: ChunkedKernel, + kernel: Kernel, base_data: dict[str, Any], transform_fn: Callable[[dict[str, Any]], dict[str, Any]], ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: @@ -269,45 +268,6 @@ def test_kernel_forward( assert result[field_name].shape == (11, n_dims) -@device_params -@pytest.mark.parametrize("chunk_size", [20, 40, "auto"]) -def test_kernel_chunking( - device: torch.device, - chunk_size: int | str | None, -): - """Verify that varying chunk sizes produces numerically consistent results.""" - # Build a kernel and common input payload - kernel, input_data = make_kernel_and_input_data( - n_spatial_dims=2, - output_fields={"sfield": "scalar", "vfield": "vector"}, - device=device, - n_source_points=10, - n_target_points=100, - seed=DEFAULT_SEED, 
- ) - - # Always get the reference answer without chunking for comparison - reference = kernel(**{**input_data, "chunk_size": None}) - - # Skip redundant run for the reference case - if chunk_size is None: - return - - candidate = kernel(**{**input_data, "chunk_size": chunk_size}) - - # Scalars - assert torch.allclose( - reference["sfield"], candidate["sfield"], atol=CHUNKING_ATOL, rtol=0.0 - ), f"Scalar mismatch for {chunk_size=}." - - # Vectors - assert torch.allclose( - reference["vfield"], candidate["vfield"], atol=CHUNKING_ATOL, rtol=0.0 - ), f"Vector mismatch for {chunk_size=}." - - assert reference.batch_size == candidate.batch_size - - @device_params @dims_params def test_kernel_gradient_flow( diff --git a/test/models/globe/test_inference.py b/test/models/globe/test_inference.py index cac7afd0c3..9406f80947 100644 --- a/test/models/globe/test_inference.py +++ b/test/models/globe/test_inference.py @@ -59,13 +59,12 @@ def test_globe_inference(device: str) -> None: prediction_points=prediction_points, boundary_meshes={"no_slip": mesh}, reference_lengths=reference_lengths, - chunk_size=None, ) ### Validate Mesh structure from physicsnemo.mesh import Mesh - assert isinstance(output_mesh, Mesh) + assert isinstance(output_mesh, Mesh[0, 3]) assert output_mesh.points.shape == (N_PREDICTION_POINTS, 3) ### Validate output fields and shapes @@ -79,3 +78,54 @@ def test_globe_inference(device: str) -> None: ### Validate outputs are finite (no NaN or Inf from the forward pass) assert torch.all(torch.isfinite(fields["pressure"])) assert torch.all(torch.isfinite(fields["velocity"])) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_globe_inference_multi_bc(device: str) -> None: + """Inference with two BC types exercises cross-BC interaction plans.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + ### Create model with two BC types + model = GLOBE( + n_spatial_dims=3, + output_field_ranks={"pressure": 0, 
"velocity": 1}, + boundary_source_data_ranks={"no_slip": {}, "freestream": {}}, + reference_length_names=["test_length"], + reference_area=1.0, + hidden_layer_sizes=[8], + ).to(device) + model.eval() + + ### Use meshes with different face counts to stress-test cross-BC + ### interaction plans (same face count would mask index-range bugs). + mesh_no_slip = lumpy_sphere.load(subdivisions=1, device=device) # 80 faces + mesh_freestream = lumpy_sphere.load(subdivisions=0, device=device) # 20 faces + assert mesh_no_slip.n_cells != mesh_freestream.n_cells + + generator = torch.Generator(device=device).manual_seed(0) + prediction_points = torch.randn( + N_PREDICTION_POINTS, 3, generator=generator, device=device + ) + reference_lengths = { + "test_length": torch.tensor(1.0, dtype=torch.float32, device=device) + } + + ### Run inference + with torch.no_grad(): + output_mesh = model( + prediction_points=prediction_points, + boundary_meshes={ + "no_slip": mesh_no_slip, + "freestream": mesh_freestream, + }, + reference_lengths=reference_lengths, + ) + + ### Validate structure and outputs + fields = output_mesh.point_data + assert set(fields.keys()) == {"pressure", "velocity"} + assert fields["pressure"].shape == (N_PREDICTION_POINTS,) + assert fields["velocity"].shape == (N_PREDICTION_POINTS, 3) + assert torch.all(torch.isfinite(fields["pressure"])) + assert torch.all(torch.isfinite(fields["velocity"]))