Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in Pangu, FengWu attention window shift for asymmetric longitudes
- Fixed a bug in `mesh.sampling.find_nearest_cells`, where a mixup between L2 and L-inf norms
could cause slightly incorrect nearest-neighbor assignments in highly skewed meshes.
- Fixed a TensorDict key-ordering bug in GLOBE's Barnes-Hut kernel that caused
  incorrect results when `tensordict >= 0.12` reordered leaves during
  TensorDict construction from dict literals mixing plain and nested keys.

### Security

Expand Down
31 changes: 21 additions & 10 deletions physicsnemo/experimental/models/globe/field_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,10 @@ def _evaluate_interactions(

basis_vector_components.append(vectors_hat["r"])

for k in vectors.keys(include_nested=True, leaves_only=True):
for k in sorted(
vectors.keys(include_nested=True, leaves_only=True),
key=str,
):
if k == "r":
continue

Expand Down Expand Up @@ -1341,14 +1344,24 @@ def _gather_and_evaluate(
target_positions[tgt_ids] - source_positions[src_ids]
) / reference_length

### Flatten source scalars into one tensor, gather once.
### Flatten source scalars into one tensor, gather once, split back.
# concatenate_leaves: 1 GPU kernel (torch.cat)
# [src_ids]: 1 GPU kernel (aten::index)
# Total: 2 kernels instead of K (one per TensorDict leaf).
# The split-back uses sorted keys matching concatenate_leaves's
# canonical column ordering so position i maps to the correct leaf.
src_scalar_keys = sorted(
source_scalars.keys(include_nested=True, leaves_only=True),
key=str,
)
gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids]
scalars = TensorDict(
{
"source_scalars": gathered_src_scalars,
"source_scalars": TensorDict(
{k: gathered_src_scalars[..., i] for i, k in enumerate(src_scalar_keys)},
batch_size=torch.Size([n_pairs]),
device=device,
),
"global_scalars": global_scalars.expand(
n_pairs, *global_scalars.batch_size
),
Expand All @@ -1362,18 +1375,16 @@ def _gather_and_evaluate(
# each vector leaf separately for magnitude/direction extraction and
# rotationally-equivariant basis construction. Integer indexing
# along the last dimension creates non-contiguous views (zero copies).
src_vector_keys = list(
source_vectors.keys(include_nested=True, leaves_only=True)
# Sorted keys match concatenate_leaves's canonical column ordering.
src_vector_keys = sorted(
source_vectors.keys(include_nested=True, leaves_only=True),
key=str,
)
gathered_src_vectors = concatenate_leaves(source_vectors)[src_ids]
gathered_vector_leaves = {
k: gathered_src_vectors[..., i]
for i, k in enumerate(src_vector_keys)
}
vectors = TensorDict(
{
"source_vectors": TensorDict(
gathered_vector_leaves,
{k: gathered_src_vectors[..., i] for i, k in enumerate(src_vector_keys)},
batch_size=torch.Size([n_pairs, self.n_spatial_dims]),
device=device,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor:
:math:`(*, F_{\text{total}})` where :math:`F_{\text{total}}` is the sum of
flattened features across all leaf tensors.

Leaves are sorted by ``str(key)`` before concatenation, producing a
canonical column ordering that is independent of TensorDict construction
order. This is necessary because ``TensorDict`` iteration order can
differ depending on how the object was constructed (dict literal vs
sequential ``__setitem__`` vs element-wise ops) and can change across
``tensordict`` library versions. Sorting eliminates this as a source
of bugs in any code that relies on positional column layout (e.g. the
MLP input assembly in :meth:`Kernel._evaluate_interactions`).

Parameters
----------
td : TensorDict[str, torch.Tensor]
Expand All @@ -96,14 +105,14 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor:
>>> result.shape
torch.Size([2, 17])
"""
tensors = tuple(td.values(include_nested=True, leaves_only=True))
if len(tensors) == 0:
items = list(td.items(include_nested=True, leaves_only=True))
if len(items) == 0:
return torch.empty(td.batch_size + torch.Size([0]), device=td.device)
else:
return torch.cat(
[t.reshape(td.batch_size + torch.Size([-1])) for t in tensors],
dim=-1,
)
items.sort(key=lambda kv: str(kv[0]))
return torch.cat(
[t.reshape(td.batch_size + torch.Size([-1])) for _, t in items],
dim=-1,
)


class TensorsByRank(dict):
Expand Down
15 changes: 8 additions & 7 deletions test/models/globe/test_barnes_hut_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,16 +906,14 @@ def test_bh_nested_source_data_keys(n_dims: int):

common_kwargs = dict(
n_spatial_dims=n_dims,
output_field_ranks={
k: (0 if v == "scalar" else 1) for k, v in output_field_ranks.items()
},
output_field_ranks=output_field_ranks,
source_data_ranks=source_data_ranks,
hidden_layer_sizes=[16],
)

bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE)
exact_kernel = Kernel(**common_kwargs)
exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False)
exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True)
bh_kernel.eval()
exact_kernel.eval()

Expand Down Expand Up @@ -960,9 +958,12 @@ def test_bh_nested_source_data_keys(n_dims: int):
torch.testing.assert_close(
bh_result[field_name],
exact_result[field_name],
atol=1e-3,
rtol=1e-2,
msg=f"Nested keys: {field_name!r} not close to exact at theta=0.01",
atol=1e-4,
rtol=1e-3,
msg=lambda default, f=field_name: (
f"Nested keys: {f!r} not close to exact at theta=0.01 "
f"(n_dims={n_dims}, torch={torch.__version__}).\n{default}"
),
)


Expand Down
Loading