diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7ac7f2659..c1c799ee75 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -85,6 +85,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug in Pangu, FengWu attention window shift for asymmetric longitudes
 - Fixed a bug in `mesh.sampling.find_nearest_cells`, where a mixup between L2 and L-inf
   norms could cause slightly incorrect nearest-neighbor assignments in highly skewed meshes.
+- Fixed TensorDict key-ordering bug in GLOBE's Barnes-Hut kernel that caused
+  incorrect results when `tensordict >= 0.12` reordered leaves during
+  TensorDict construction from dict literals mixing plain and nested keys.
 
 ### Security
 
diff --git a/physicsnemo/experimental/models/globe/field_kernel.py b/physicsnemo/experimental/models/globe/field_kernel.py
index 7f7cda1d36..e6a11f1916 100644
--- a/physicsnemo/experimental/models/globe/field_kernel.py
+++ b/physicsnemo/experimental/models/globe/field_kernel.py
@@ -680,7 +680,10 @@ def _evaluate_interactions(
 
         basis_vector_components.append(vectors_hat["r"])
 
-        for k in vectors.keys(include_nested=True, leaves_only=True):
+        for k in sorted(
+            vectors.keys(include_nested=True, leaves_only=True),
+            key=str,
+        ):
             if k == "r":
                 continue
 
@@ -1341,14 +1344,24 @@ def _gather_and_evaluate(
             target_positions[tgt_ids] - source_positions[src_ids]
         ) / reference_length
 
-        ### Flatten source scalars into one tensor, gather once.
+        ### Flatten source scalars into one tensor, gather once, split back.
         # concatenate_leaves: 1 GPU kernel (torch.cat)
        # [src_ids]: 1 GPU kernel (aten::index)
         # Total: 2 kernels instead of K (one per TensorDict leaf).
+        # The split-back uses sorted keys matching concatenate_leaves's
+        # canonical column ordering so position i maps to the correct leaf.
+        src_scalar_keys = sorted(
+            source_scalars.keys(include_nested=True, leaves_only=True),
+            key=str,
+        )
         gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids]
         scalars = TensorDict(
             {
-                "source_scalars": gathered_src_scalars,
+                "source_scalars": TensorDict(
+                    {k: gathered_src_scalars[..., i] for i, k in enumerate(src_scalar_keys)},
+                    batch_size=torch.Size([n_pairs]),
+                    device=device,
+                ),
                 "global_scalars": global_scalars.expand(
                     n_pairs, *global_scalars.batch_size
                 ),
@@ -1362,18 +1375,16 @@ def _gather_and_evaluate(
         # each vector leaf separately for magnitude/direction extraction and
         # rotationally-equivariant basis construction. Integer indexing
         # along the last dimension creates non-contiguous views (zero copies).
-        src_vector_keys = list(
-            source_vectors.keys(include_nested=True, leaves_only=True)
+        # Sorted keys match concatenate_leaves's canonical column ordering.
+        src_vector_keys = sorted(
+            source_vectors.keys(include_nested=True, leaves_only=True),
+            key=str,
         )
         gathered_src_vectors = concatenate_leaves(source_vectors)[src_ids]
-        gathered_vector_leaves = {
-            k: gathered_src_vectors[..., i]
-            for i, k in enumerate(src_vector_keys)
-        }
         vectors = TensorDict(
             {
                 "source_vectors": TensorDict(
-                    gathered_vector_leaves,
+                    {k: gathered_src_vectors[..., i] for i, k in enumerate(src_vector_keys)},
                     batch_size=torch.Size([n_pairs, self.n_spatial_dims]),
                     device=device,
                 ),
diff --git a/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py b/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py
index e3ddb7c15e..8e939ed975 100644
--- a/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py
+++ b/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py
@@ -73,6 +73,15 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor:
     :math:`(*, F_{\text{total}})` where :math:`F_{\text{total}}` is the
     sum of flattened features across all leaf tensors.
 
+    Leaves are sorted by ``str(key)`` before concatenation, producing a
+    canonical column ordering that is independent of TensorDict construction
+    order. This is necessary because ``TensorDict`` iteration order can
+    differ depending on how the object was constructed (dict literal vs
+    sequential ``__setitem__`` vs element-wise ops) and can change across
+    ``tensordict`` library versions. Sorting eliminates this as a source
+    of bugs in any code that relies on positional column layout (e.g. the
+    MLP input assembly in :meth:`Kernel._evaluate_interactions`).
+
     Parameters
     ----------
     td : TensorDict[str, torch.Tensor]
@@ -96,14 +105,14 @@
     >>> result.shape
     torch.Size([2, 17])
     """
-    tensors = tuple(td.values(include_nested=True, leaves_only=True))
-    if len(tensors) == 0:
+    items = list(td.items(include_nested=True, leaves_only=True))
+    if len(items) == 0:
         return torch.empty(td.batch_size + torch.Size([0]), device=td.device)
-    else:
-        return torch.cat(
-            [t.reshape(td.batch_size + torch.Size([-1])) for t in tensors],
-            dim=-1,
-        )
+    items.sort(key=lambda kv: str(kv[0]))
+    return torch.cat(
+        [t.reshape(td.batch_size + torch.Size([-1])) for _, t in items],
+        dim=-1,
+    )
 
 
 class TensorsByRank(dict):
diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py
index 7b8d3a7d9c..8ccfa0c61a 100644
--- a/test/models/globe/test_barnes_hut_kernel.py
+++ b/test/models/globe/test_barnes_hut_kernel.py
@@ -906,16 +906,14 @@ def test_bh_nested_source_data_keys(n_dims: int):
 
     common_kwargs = dict(
         n_spatial_dims=n_dims,
-        output_field_ranks={
-            k: (0 if v == "scalar" else 1) for k, v in output_field_ranks.items()
-        },
+        output_field_ranks=output_field_ranks,
         source_data_ranks=source_data_ranks,
         hidden_layer_sizes=[16],
     )
 
     bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE)
     exact_kernel = Kernel(**common_kwargs)
-    exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False)
+    exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True)
 
     bh_kernel.eval()
     exact_kernel.eval()
@@ -960,9 +958,12 @@
         torch.testing.assert_close(
             bh_result[field_name],
             exact_result[field_name],
-            atol=1e-3,
-            rtol=1e-2,
-            msg=f"Nested keys: {field_name!r} not close to exact at theta=0.01",
+            atol=1e-4,
+            rtol=1e-3,
+            msg=lambda default, f=field_name: (
+                f"Nested keys: {f!r} not close to exact at theta=0.01 "
+                f"(n_dims={n_dims}, torch={torch.__version__}).\n{default}"
+            ),
         )
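
Reviewer note: the fix hinges on a single invariant, namely that the flattening in
`concatenate_leaves` and every positional split-back agree on one canonical leaf
order. Below is a standalone sketch of that round trip, not the PhysicsNeMo code:
`concat_sorted_leaves` and `split_sorted_leaves` are hypothetical simplified
stand-ins that assume scalar leaves of shape (B,) and a single batch dimension.

# Standalone sketch (not part of the patch): demonstrates the canonical
# sorted-key round trip that concatenate_leaves and the split-back rely on.
# concat_sorted_leaves / split_sorted_leaves are simplified stand-ins,
# assuming scalar leaves of shape (B,) and a single batch dimension.
import torch
from tensordict import TensorDict


def concat_sorted_leaves(td: TensorDict) -> torch.Tensor:
    # Order leaves by str(key), then stack them as columns of one (B, K) tensor.
    items = sorted(
        td.items(include_nested=True, leaves_only=True), key=lambda kv: str(kv[0])
    )
    return torch.stack([t for _, t in items], dim=-1)


def split_sorted_leaves(flat: torch.Tensor, td: TensorDict) -> dict:
    # Inverse mapping: column i belongs to the i-th key in sorted order.
    keys = sorted(td.keys(include_nested=True, leaves_only=True), key=str)
    return {k: flat[..., i] for i, k in enumerate(keys)}


# Mixing a plain key with a nested one is exactly the case where TensorDict
# iteration order is construction- and version-dependent. Under str(), the
# nested key stringifies as "('state', 'pressure')" and sorts before
# "density"; the particular order is unimportant, only that it is fixed.
td = TensorDict(
    {"density": torch.arange(4.0), "state": {"pressure": torch.ones(4)}},
    batch_size=[4],
)

flat = concat_sorted_leaves(td)  # shape (4, 2), columns in canonical order
leaves = split_sorted_leaves(flat, td)

for k, v in td.items(include_nested=True, leaves_only=True):
    assert torch.equal(leaves[k], v)  # round trip recovers every leaf

One point worth noting on the patch's `key=str`: sorting by `str(key)` rather than
comparing keys directly yields a total order over mixed plain-string and
nested-tuple keys, whereas comparing a `str` against a `tuple` would raise
`TypeError`, which is presumably why `key=str` appears at every sort site.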