Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in Pangu, FengWu attention window shift for asymmetric longitudes
- Fixed a bug in `mesh.sampling.find_nearest_cells`, where a mixup between L2 and L-inf norms
could cause slightly incorrect nearest-neighbor assignments in highly skewed meshes.
- Fixed TensorDict key-ordering bug in GLOBE's Barnes-Hut kernel that caused
incorrect results when `tensordict >= 0.12` reordered leaves during
TensorDict construction from dict literals mixing plain and nested keys.

### Security

Expand Down
31 changes: 21 additions & 10 deletions physicsnemo/experimental/models/globe/field_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,10 @@ def _evaluate_interactions(

basis_vector_components.append(vectors_hat["r"])

for k in vectors.keys(include_nested=True, leaves_only=True):
for k in sorted(
vectors.keys(include_nested=True, leaves_only=True),
key=str,
):
if k == "r":
continue

Expand Down Expand Up @@ -1341,14 +1344,24 @@ def _gather_and_evaluate(
target_positions[tgt_ids] - source_positions[src_ids]
) / reference_length

### Flatten source scalars into one tensor, gather once.
### Flatten source scalars into one tensor, gather once, split back.
# concatenate_leaves: 1 GPU kernel (torch.cat)
# [src_ids]: 1 GPU kernel (aten::index)
# Total: 2 kernels instead of K (one per TensorDict leaf).
# The split-back uses sorted keys matching concatenate_leaves's
# canonical column ordering so position i maps to the correct leaf.
src_scalar_keys = sorted(
source_scalars.keys(include_nested=True, leaves_only=True),
key=str,
)
gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids]
scalars = TensorDict(
{
"source_scalars": gathered_src_scalars,
"source_scalars": TensorDict(
{k: gathered_src_scalars[..., i] for i, k in enumerate(src_scalar_keys)},
batch_size=torch.Size([n_pairs]),
device=device,
),
"global_scalars": global_scalars.expand(
n_pairs, *global_scalars.batch_size
),
Expand All @@ -1362,18 +1375,16 @@ def _gather_and_evaluate(
# each vector leaf separately for magnitude/direction extraction and
# rotationally-equivariant basis construction. Integer indexing
# along the last dimension creates non-contiguous views (zero copies).
src_vector_keys = list(
source_vectors.keys(include_nested=True, leaves_only=True)
# Sorted keys match concatenate_leaves's canonical column ordering.
src_vector_keys = sorted(
source_vectors.keys(include_nested=True, leaves_only=True),
key=str,
)
gathered_src_vectors = concatenate_leaves(source_vectors)[src_ids]
gathered_vector_leaves = {
k: gathered_src_vectors[..., i]
for i, k in enumerate(src_vector_keys)
}
vectors = TensorDict(
{
"source_vectors": TensorDict(
gathered_vector_leaves,
{k: gathered_src_vectors[..., i] for i, k in enumerate(src_vector_keys)},
batch_size=torch.Size([n_pairs, self.n_spatial_dims]),
device=device,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor:
:math:`(*, F_{\text{total}})` where :math:`F_{\text{total}}` is the sum of
flattened features across all leaf tensors.

Leaves are sorted by ``str(key)`` before concatenation, producing a
canonical column ordering that is independent of TensorDict construction
order. This is necessary because ``TensorDict`` iteration order can
differ depending on how the object was constructed (dict literal vs
sequential ``__setitem__`` vs element-wise ops) and can change across
``tensordict`` library versions. Sorting eliminates this as a source
of bugs in any code that relies on positional column layout (e.g. the
MLP input assembly in :meth:`Kernel._evaluate_interactions`).

Parameters
----------
td : TensorDict[str, torch.Tensor]
Expand All @@ -96,14 +105,14 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor:
>>> result.shape
torch.Size([2, 17])
"""
tensors = tuple(td.values(include_nested=True, leaves_only=True))
if len(tensors) == 0:
items = list(td.items(include_nested=True, leaves_only=True))
if len(items) == 0:
return torch.empty(td.batch_size + torch.Size([0]), device=td.device)
else:
return torch.cat(
[t.reshape(td.batch_size + torch.Size([-1])) for t in tensors],
dim=-1,
)
items.sort(key=lambda kv: str(kv[0]))
return torch.cat(
[t.reshape(td.batch_size + torch.Size([-1])) for _, t in items],
dim=-1,
)


class TensorsByRank(dict):
Expand Down
110 changes: 101 additions & 9 deletions test/models/globe/test_barnes_hut_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,10 @@ def test_bh_nested_source_data_keys(n_dims: int):

The aggregation, split_by_leaf_rank, and TensorDict.cat operations must
handle this nesting correctly.

The ``msg`` passed to :func:`torch.testing.assert_close` is a callable
so the default "Greatest absolute/relative difference" diagnostics are
preserved when a failure occurs in CI.
"""
torch.manual_seed(DEFAULT_SEED)
n_src, n_tgt = 30, 15
Expand All @@ -906,16 +910,23 @@ def test_bh_nested_source_data_keys(n_dims: int):

common_kwargs = dict(
n_spatial_dims=n_dims,
output_field_ranks={
k: (0 if v == "scalar" else 1) for k, v in output_field_ranks.items()
},
output_field_ranks=output_field_ranks,
source_data_ranks=source_data_ranks,
hidden_layer_sizes=[16],
)

bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE)
exact_kernel = Kernel(**common_kwargs)
exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False)

### Invariant 1: state_dict transfer is complete and bit-exact.
exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True)
bh_sd, ex_sd = bh_kernel.state_dict(), exact_kernel.state_dict()
mismatched = [k for k in bh_sd if not torch.equal(bh_sd[k], ex_sd[k])]
assert not mismatched, (
f"state_dict value mismatch after load "
f"(torch={torch.__version__}): {mismatched}"
)

bh_kernel.eval()
exact_kernel.eval()

Expand Down Expand Up @@ -953,16 +964,97 @@ def test_bh_nested_source_data_keys(n_dims: int):
"global_data": TensorDict({}, batch_size=[]),
}

exact_result = exact_kernel(**data)
bh_result = bh_kernel(**data, theta=0.01)
### Invariant 2: per-pair pre-aggregation outputs are bit-identical.
# At theta=0.01 all pairs are near-field, so BH and Exact both call
# _evaluate_interactions on the same (target, source) pairs with
# identical weights. Capture the pre-aggregation output from each,
# reindex BH's pair ordering into Exact's row-major (t, s) order,
# and compare tightly. If this fires, there is a genuine algorithmic
# divergence (e.g. tensordict iteration-order change across library
# versions) and the final-sum tolerance is masking a real bug.
captures: dict[str, dict[str, torch.Tensor]] = {}
orig_eval = Kernel._evaluate_interactions

def _capturing_eval(tag: str):
def _patched(self, *, scalars, vectors, device):
out = orig_eval(self, scalars=scalars, vectors=vectors, device=device)
captures[tag] = {k: v.detach().clone() for k, v in out.items()}
return out

return _patched
Comment thread
peterdsharpe marked this conversation as resolved.
Outdated

try:
Kernel._evaluate_interactions = _capturing_eval("exact")
exact_result = exact_kernel(**data)
Kernel._evaluate_interactions = _capturing_eval("bh")
bh_result = bh_kernel(**data, theta=0.01)
finally:
Kernel._evaluate_interactions = orig_eval

src_tree = ClusterTree.from_points(
data["source_points"],
leaf_size=DEFAULT_LEAF_SIZE,
)
tgt_tree = ClusterTree.from_points(
data["target_points"],
leaf_size=DEFAULT_LEAF_SIZE,
)
plan = src_tree.find_dual_interaction_pairs(
target_tree=tgt_tree,
theta=0.01,
)
assert plan.n_near == n_src * n_tgt, (
f"Expected all-near at theta=0.01, got n_near={plan.n_near} "
f"of dense={n_src * n_tgt}"
)

row_of_pair = plan.near_target_ids * n_src + plan.near_source_ids
inv_perm = torch.empty_like(row_of_pair)
inv_perm[row_of_pair] = torch.arange(plan.n_near)

for field_name in output_field_ranks:
ex_pp = captures["exact"][field_name]
bh_pp = captures["bh"][field_name]
ex_flat = ex_pp.reshape(n_tgt * n_src, *ex_pp.shape[2:])
bh_reordered = bh_pp[inv_perm]

torch.testing.assert_close(
bh_reordered,
ex_flat,
atol=1e-6,
rtol=1e-6,
msg=lambda default, f=field_name: (
f"BH/Exact per-pair pre-aggregation {f!r} divergence "
f"(torch={torch.__version__}). BH and Exact paths are "
f"not computing identical per-pair tensors despite "
f"identical inputs and weights.\n{default}"
),
)
Comment thread
peterdsharpe marked this conversation as resolved.
Outdated

### Final aggregation comparison.
# The invariant checks above guarantee that BH and Exact computed
# bit-identical per-pair outputs. The only remaining difference is
# aggregation order: einsum vs scatter_add_. Tolerance matches
# test_bh_convergence_to_exact.
for field_name in output_field_ranks:

def _msg(
default: str,
field: str = field_name,
dims: int = n_dims,
) -> str:
return (
f"Nested keys: {field!r} not close to exact at theta=0.01 "
f"(n_dims={dims}, num_threads={torch.get_num_threads()}, "
f"torch={torch.__version__}).\n{default}"
)

torch.testing.assert_close(
bh_result[field_name],
exact_result[field_name],
atol=1e-3,
rtol=1e-2,
msg=f"Nested keys: {field_name!r} not close to exact at theta=0.01",
atol=1e-4,
rtol=1e-3,
msg=_msg,
)


Expand Down
Loading