[Stacked] Adds GLOBE Model updates for 3D + Dual Tree Traversal Acceleration #1494

Open
peterdsharpe wants to merge 29 commits into NVIDIA:main from peterdsharpe:psharpe/stacked-add-GLOBE-3D-DTT-model-changes

Conversation

@peterdsharpe
Collaborator

PhysicsNeMo Pull Request

Description

This PR includes:

  • GLOBE model architecture modifications that fundamentally break the previous O(n^2) bottleneck, enabling scaling to large industrial problems. Scaling is now O(n), verified both theoretically and empirically, which is very exciting to see for an all-to-all interaction problem. This is achieved with a dual-tree variant of classical Barnes-Hut (or Fast-Multipole-Method-like) hierarchical acceleration that exploits spatial locality. Barnes-Hut alone would get you to O(n log n), but it turns out that by putting both the sources AND the targets in trees (a "dual tree traversal" approach), the interactions are no longer the bottleneck at all. This leaves the O(n) operations as the bottleneck.
    • This dual-tree algorithm is implemented using a variant of a Linear Bounding Volume Hierarchy (LBVH), which is a nice way to make it more GPU-friendly: each layer of the tree (or, in the case of dual tree traversal, each combination of layers from the two trees) coalesces into a single kernel launch, rather than requiring a full tree traversal for every interaction.

Adds a theory doc about the dual tree traversal approach at ./physicsnemo/experimental/models/globe/hierarchical_acceleration.md.
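As a rough illustration of the dual-tree idea (this is a toy 1D recursive version with an arbitrary leaf size, not the PR's actual `ClusterTree`/LBVH implementation; all names here are hypothetical):

```python
import numpy as np

class Node:
    """A toy cluster-tree node over 1D points (median split, leaf size 4)."""
    def __init__(self, pts):
        self.pts = pts
        self.center = pts.mean()
        self.radius = np.abs(pts - self.center).max()
        if len(pts) > 4:
            s = np.sort(pts)
            mid = len(s) // 2
            self.children = [Node(s[:mid]), Node(s[mid:])]
        else:
            self.children = []

def dual_traverse(src, tgt, theta, far, near):
    """Classify (source, target) node pairs: well-separated pairs pass the
    multipole acceptance criterion (MAC) and are approximated as a whole;
    leaf-leaf pairs that fail the MAC are evaluated exactly."""
    d = abs(src.center - tgt.center)
    if src.radius + tgt.radius < theta * d:          # MAC: far-field pair
        far.append((src, tgt))
    elif not src.children and not tgt.children:      # both leaves: near field
        near.append((src, tgt))
    elif src.children and (not tgt.children or src.radius >= tgt.radius):
        for c in src.children:                        # split the larger node
            dual_traverse(c, tgt, theta, far, near)
    else:
        for c in tgt.children:
            dual_traverse(src, c, theta, far, near)

rng = np.random.default_rng(0)
sources, targets = Node(rng.random(200)), Node(rng.random(300))
far, near = [], []
dual_traverse(sources, targets, theta=0.5, far=far, near=near)

# Coverage invariant: every (source, target) pair lands in exactly one bucket.
covered = sum(len(a.pts) * len(b.pts) for a, b in far + near)
assert covered == 200 * 300
```

Because whole far-field clusters are consumed in a single approximation, the number of emitted pairs grows roughly linearly with point count rather than quadratically, which is the mechanism behind the O(n) scaling claimed above.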

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

…uations

- Introduced a new `ClusterTree` class for spatial decomposition, enabling efficient dual-tree traversal.
- Updated `GLOBE` model to utilize the new clustering mechanism, significantly reducing kernel evaluation complexity.
- Enhanced `CHANGELOG.md` to reflect the addition of the dual-tree algorithm and its impact on performance.
- Added comprehensive tests for the `BarnesHutKernel` and `ClusterTree` functionalities to ensure correctness and performance.
- Refactored existing kernel evaluation methods to integrate the new dual-tree approach, improving overall efficiency.

This update is crucial for handling large mesh scales effectively, particularly in scenarios with 800k+ faces.
commit fb4f159
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Mar 11 23:11:04 2026 -0400

    Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D (NVIDIA#1483)

    * Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D

    * Fixes docstring example for compute_cell_normals to reflect correct normal vector output in 2D case.

    * Refactor compute_cell_areas and compute_cell_normals functions to use match-case syntax for improved readability and maintainability.

commit 219aca3
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Wed Mar 11 17:23:16 2026 -0700

    Fix window shift in pangu, fengwu (NVIDIA#1492)

    * Fix window shift in pangu, fengwu

    * changelog

commit 26fcdce
Author: Kaustubh Tangsali <ktangsali@nvidia.com>
Date:   Wed Mar 11 20:20:33 2026 +0000

    fix linting issues

commit ca15f47
Author: Charlelie Laurent <claurent@nvidia.com>
Date:   Wed Mar 11 12:08:12 2026 -0700

    Resolved conflicts in checkpoint.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 95eca0c
Author: Kaustubh Tangsali <ktangsali@nvidia.com>
Date:   Tue Mar 10 23:27:40 2026 +0000

    remove the conflict block

commit 33d9a7e
Author: Kaustubh Tangsali <ktangsali@nvidia.com>
Date:   Tue Mar 10 23:22:22 2026 +0000

    update versioins

commit fbfb896
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Mon Mar 9 16:42:38 2026 -0700

    Improved docs for module.py + multiple cleanups in docs (NVIDIA#1478)

    * Improved docs for module.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Fix in save_checkpoint and load_checkpoint docstrings

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Addressed PR comments

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Improvements in docs

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Moved down section about static capture in physicsnemo.utils.rst

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    ---------

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 13c6fb4
Author: Corey adams <6619961+coreyjadams@users.noreply.github.com>
Date:   Mon Mar 9 11:25:44 2026 -0500

    Update Datapipes API (NVIDIA#1468)

    * Trying again with datapipes check in

    * Update docs/api/datapipes/physicsnemo.datapipes.cae.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.cae.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Update docs/api/datapipes/physicsnemo.datapipes.transforms.rst

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    * Resolve api commits for datapipes

    * Remove old datapipes api

    * Add link to the datapipe docs.

    ---------

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

commit 32f2261
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Fri Mar 6 15:31:53 2026 -0800

    Fix unresolved conflict (NVIDIA#1477)

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 9a32517
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Fri Mar 6 15:09:50 2026 -0800

    Diffusion API docs (NVIDIA#1473)

    * New API docs for diffusion

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Some fixes in nested API references

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Revert some changes

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Some fixes

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Some clarifications in introduction.rst

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Some clarification in diffusion models.rst

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Fixed note sections in preconditioners.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Fix some broken short-form refs in losses.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Updated DPSScorePredictor class name in samplers.rst

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Enhance clarity and structure of introduction section

    Refactor introduction section for clarity and readability. Improve formatting and organization of key concepts related to the PhysicsNeMo diffusion framework.

    * Refactor metrics.rst for improved clarity and formatting

    Reformatted the description of the module to use bullet points for clarity. Adjusted wording for consistency and readability.

    * Fix punctuation and enhance clarity in models.rst

    Corrected punctuation and improved clarity in the documentation.

    * Addressed PR comments

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    ---------

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
    Co-authored-by: megnvidia <mmiranda@nvidia.com>

commit d723767
Author: Kaustubh Tangsali <71059996+ktangsali@users.noreply.github.com>
Date:   Thu Mar 5 18:56:46 2026 -0800

    Minor edits to the install guide (NVIDIA#1470)

    * minor edits to the install guide

    * add more details

    * minor doc fix

    * add transolver to the api index

commit 6bb6d04
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Thu Mar 5 12:38:13 2026 -0800

    Fixes and renaming in dps_guidance.py (NVIDIA#1471)

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 1e14a56
Author: Corey adams <6619961+coreyjadams@users.noreply.github.com>
Date:   Wed Mar 4 15:46:15 2026 -0600

    Update API docs and structure. (NVIDIA#1337)

    * Update API docs and structure.

    * clean-up and re-organization of docs

    * fix based on new api

    * remove unused sections

    * update image paths

    * Update docs/api/models/diffusion.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/models/operators.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/models/weather.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.core.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.diffusion.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.diffusion.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Update docs/api/physicsnemo.utils.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * fix formatting

    ---------

    Co-authored-by: Kaustubh Tangsali <71059996+ktangsali@users.noreply.github.com>
    Co-authored-by: Kaustubh Tangsali <ktangsali@nvidia.com>
    Co-authored-by: megnvidia <mmiranda@nvidia.com>

commit f22cfbf
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Mar 2 17:56:32 2026 -0500

    Adds PhysicsNeMo-Mesh API Docs (NVIDIA#1461)

    * Adds mesh docs on top of RC branch

    * Update docs/mesh/boundaries.rst

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Apply suggestions from code review

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

    * Refine documentation for mesh geometry functions to clarify usage and API exposure for advanced cases.

    * better subdivision descriptions

    * clearer docs

    ---------

    Co-authored-by: megnvidia <mmiranda@nvidia.com>

commit 1e9fdd0
Author: Corey adams <6619961+coreyjadams@users.noreply.github.com>
Date:   Fri Feb 27 12:12:02 2026 -0600

    GeoTransolver: Fix attention and turn off feature broadcasting. (NVIDIA#1415)

    * Fix attention and turn off feature broadcasting.

    * Fix scalar loading shapes

    * Update volume.yaml

    Ensure the volume example works out of the box.

    * Fix Geotransolver inference tests

commit 4f0a3cb
Author: Kaustubh Tangsali <71059996+ktangsali@users.noreply.github.com>
Date:   Thu Feb 26 15:00:42 2026 -0800

    Wandb fixes (NVIDIA#1458)

    * Add wandb to requirements

    * Modify requirements for trimesh and add wandb

    Updated trimesh version constraint and added wandb.

commit 082cd36
Author: Kaustubh Tangsali <71059996+ktangsali@users.noreply.github.com>
Date:   Thu Feb 26 15:00:16 2026 -0800

    Fixes from SFB builds / testing (NVIDIA#1459)

    * Remove 'perf' extra from physicsnemo installation because NGC containers already include transformer-engine

    * update deterministic settings

commit f6ca818
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Wed Feb 25 13:20:46 2026 -0800

    Fix broken cross-ref links in docstrings (NVIDIA#1454)

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 47b8ff0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Mar 11 15:21:31 2026 -0400

    Deprecates `physicsnemo.utils.mesh.py` (NVIDIA#1487)

    * Adds DeprecationWarning on module.

    * Changelog update for deprecations

commit 74c91f9
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Mar 11 14:20:25 2026 -0400

    Adds docstrings to CombinedOptimizer tests. (NVIDIA#1486)
@peterdsharpe peterdsharpe marked this pull request as ready for review March 12, 2026 14:51
@greptile-apps
Contributor

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This PR introduces a dual-tree Barnes-Hut acceleration for the GLOBE model, reducing kernel evaluation complexity from O(N²) to O(N) by building spatial cluster trees over both source and target points and classifying interactions into four quadrants (near-near, far-far, near-far, far-near). The implementation uses a Morton-code LBVH for GPU-friendly tree construction and a bottom-up propagation pass for aggregate centroids.

Key changes:

  • cluster_tree.py (new): LBVH-based ClusterTree tensorclass + DualInteractionPlan dataclass + find_dual_interaction_pairs traversal. The two-stage leaf-pair expansion in _expand_dual_leaf_hits correctly classifies every interaction without double-counting.
  • field_kernel.py: New BarnesHutKernel (four-phase kernel evaluation with chunked near-field and node-broadcast far-field) and MultiscaleKernel (one BarnesHutKernel branch per reference length, shared trees/plans). One P1 issue: MultiscaleKernel.forward() mutates the caller's global_data TensorDict in-place (line 1760) by injecting log_reference_length_ratios, which leaks into the output Mesh.global_data and modifies the user's original object — contrary to the "passed through from input" semantics documented in the forward signature.
  • model.py: GLOBE updated to precompute trees and plans once per forward pass, reuse them across communication layers, and explicitly del comm_plans after the communication phase to release memory at scale.
  • _ragged.py (new): Clean, GPU-friendly _ragged_arange helper.
  • Test coverage is thorough: source-coverage invariants, convergence to exact at θ→0, and all four equivariance properties are verified.

Important Files Changed

Filename Overview
physicsnemo/experimental/models/globe/cluster_tree.py New file implementing a GPU-compatible LBVH cluster tree and dual-tree traversal for Barnes-Hut acceleration. Core algorithm is well-structured with clean separation between tree construction (O(log N) iterations), bottom-up propagation, and dual-tree traversal. The two-stage leaf-pair expansion (_expand_dual_leaf_hits) correctly classifies interactions into four quadrants without double-counting. Minor: stable sort on Morton codes is conservative but safe.
physicsnemo/experimental/models/globe/field_kernel.py Major update adding BarnesHutKernel and MultiscaleKernel classes. The four-phase interaction evaluation (near-near, far-far, near-far, far-near) is logically correct, with gradient-checkpointing and auto-chunking for memory efficiency. Key issue: MultiscaleKernel.forward() mutates the caller's global_data TensorDict (line 1760) by adding log_reference_length_ratios, which leaks into the output Mesh and modifies the user's original object. The _compiled_evaluate_interactions pattern via object.__setattr__ is functional but lacks a comment explaining the bypass.
physicsnemo/experimental/models/globe/model.py GLOBE model updated to use dual-tree Barnes-Hut acceleration via ClusterTree and DualInteractionPlan. Architecture cleanly separates tree/plan construction from kernel evaluation, and correctly reuses plans across communication layers. comm_plans are explicitly deleted after use to free memory (important at 800k+ faces). The _build_source_data_ranks correctly reflects the latent feature structure between layers.
physicsnemo/mesh/spatial/_ragged.py New helper _ragged_arange for expanding ragged (variable-length) segment descriptors into flat index arrays. Implementation is clean and GPU-friendly (no Python loops, uses repeat_interleave + cumsum). Correctly handles non-monotone starts, zero-count segments, and empty inputs.
test/models/globe/test_barnes_hut_kernel.py Comprehensive test suite covering ClusterTree construction, DualInteractionPlan source coverage invariants, BarnesHutKernel convergence to exact (theta→0), translation/rotation/parity equivariance, all four interaction quadrants, and a GLOBE-like production configuration. Good coverage of the new dual-tree algorithm.
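The `_ragged_arange` pattern described above (repeat_interleave + cumsum, no Python loops) can be sketched in plain NumPy; the function name and signature here are illustrative, not the PR's actual API:

```python
import numpy as np

def ragged_arange(starts, counts):
    """Expand ragged segment descriptors into one flat index array.

    For each segment i, emit starts[i], starts[i]+1, ..., starts[i]+counts[i]-1,
    concatenated in segment order. Fully vectorized: no Python-level loop.
    """
    starts = np.asarray(starts, dtype=np.int64)
    counts = np.asarray(counts, dtype=np.int64)
    total = int(counts.sum())
    # Offset of each output element within its own segment: a global arange
    # minus the (repeated) cumulative count at each segment's beginning.
    seg_begin = np.repeat(np.cumsum(counts) - counts, counts)
    within = np.arange(total, dtype=np.int64) - seg_begin
    # Shift by each segment's start value (starts need not be monotone).
    return np.repeat(starts, counts) + within

# Non-monotone starts and a zero-count segment are handled naturally.
print(ragged_arange([5, 0, 10], [2, 0, 3]))  # → [ 5  6 10 11 12]
```

An empty input (`starts=[], counts=[]`) falls out of the same arithmetic and returns an empty array, matching the edge cases the summary calls out.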

Last reviewed commit: "Merge branch 'main' ..."

Comment thread physicsnemo/experimental/models/globe/field_kernel.py Outdated
Comment thread physicsnemo/experimental/models/globe/field_kernel.py
Comment thread CHANGELOG.md
…rs, ensuring compatibility with AMP autocast settings.
…improved compatibility with torch.compile and dynamic input handling.
@peterdsharpe
Collaborator Author

peterdsharpe commented Mar 12, 2026

Greptile's find of:

Critical logic bug — _evaluate_communication_hyperlayer in model.py passes target_trees=cluster_trees and dual_plans=comm_plans to evaluate_hyperlayer, where both dicts are keyed by source BC type. For cross-BC-type interactions (e.g., evaluating freestream sources against no_slip targets), the kernel receives the source BC's self-interaction plan and tree as the target, causing wrong target centroids, incorrect broadcast indices, and potential out-of-bounds scatter_add when the two BC types have different face counts. The existing test_inference.py only exercises a single-BC-type scenario and therefore does not catch this regression.

Is actually valid. Thankfully, it doesn't affect any of our current usage (AirFRANS or DrivAerML) since these use a single BC type (no_slip) currently, but I'll get this fixed before merging.

UPDATE: fixed

…ions. Updated documentation to reflect changes in self-interaction and cross-BC interaction handling. Modified the `GLOBE` class to compute dual interaction plans for all (source BC, destination BC) pairs, improving efficiency in communication layers. Added tests for multi-BC inference to validate functionality.
- Introduced variables for output name and directory to enhance flexibility.
- Updated the AIRFRANS_DATA_DIR path for consistency with dataset location.
- Set OMP_NUM_THREADS to 1 to prevent thread oversubscription during data loading.
- Simplified head node retrieval for multi-node training setup.
…ndling

- Added internal_level_ids and internal_level_offsets to ClusterTree for efficient storage of internal node IDs in CSR-packed level order.
- Introduced internal_nodes_per_level property to retrieve internal node IDs grouped by tree depth.
- Updated _propagate_centroids_bottom_up to utilize cached internal node levels, improving performance.
- Modified BarnesHutKernel to leverage cached level ordering for bottom-up propagation, enhancing efficiency in node strength calculations.
- Adjusted lazy compilation settings for MLP and evaluation pipeline to optimize performance during execution.
@peterdsharpe peterdsharpe self-assigned this Mar 17, 2026
Comment thread physicsnemo/experimental/models/globe/field_kernel.py
Comment thread physicsnemo/experimental/models/globe/field_kernel.py Outdated
Collaborator

@coreyjadams coreyjadams left a comment


This is really nice work as usual Peter - for the GLOBE scaling challenges, it's clearly a winner. I'd like to request some changes, to make it worse, so that GeoTransolver can remain the winning 3D model a little longer! No, just kidding :).

I do have some changes I want to propose to you. I think the structure and architecture of the Dual Tree Traversal for this problem would make a lot of sense to be independent of GLOBE - if it were a semi-standalone component of physicsnemo enabling arbitrary spatial computation of some sort of 1/R^k interaction across point clouds, and it's differentiable, there could be really nice applications in many modeling fields. I'm thinking it would be immediately applicable in quantum chemistry, and maybe also electromagnetism (though we don't do a ton of that, it's definitely a possible use case).

Another reason to do this is that GLOBE is probably not the final answer, and we don't want to couple too tightly to it. That way, the next models you build with an even better design could still use these accelerations without having to refactor GLOBE to extract them.

A final reason: you mentioned the need to benchmark and consolidate some of this into kernels written in Warp or other languages. It will be much easier to do that with standalone benchmarks and synthetic codes that we can use for testing, development, and benchmarking. Plus, we want this integrated into Oliver's benchmarking suite.

At the end of the day, this is killer work and I am mostly asking, "Hey, can we package this up more amazingly with its own bow on top, and let GLOBE use it, rather than build it so tightly into GLOBE?" - what do you think?

Comment thread physicsnemo/mesh/spatial/_ragged.py
Comment on lines +333 to +338
The key memory optimization: each chunk's gather and evaluate steps are
wrapped together in a single `torch.utils.checkpoint.checkpoint` call. The
checkpoint boundary is drawn so that autograd saves only the compact int64
index arrays (~8 bytes/pair) and references to the shared source data (O(1)),
rather than the gathered float data (~300 bytes/pair). This is a ~37x
reduction in checkpoint-saved memory per chunk.
Collaborator


I didn't follow this entirely. Can you add a section on the backwards pass before this to give some perspective on how the basic backwards pass might work?

Collaborator Author

@peterdsharpe peterdsharpe Mar 23, 2026


Sure. The basic logic is:

  1. Doing this forward pass without chunking requires a ton of memory. Without DTT, we would see 300 bytes/pair * (80e3 faces)^2 ~= 1920 GB per hyperlayer. With DTT, it gets way better, as we might only have ~2e7 pairs. But that's still ~6 GB per hyperlayer in the forward pass, and the backward pass typically has 2x-5x larger peak memory - so call it 20 GB per hyperlayer. With ~3 hyperlayers + other overhead, this can quickly OOM an H100.
  2. So, we want to do chunking to mitigate memory usage. The problem is, if we use chunking alone, we don't actually save any memory! The same intermediate quantities need to be materialized simultaneously; they're just split across multiple tensors (i.e., per-chunk) rather than one tensor.
  3. To actually save memory, we need chunking + checkpointing. That way the memory usage isn't sum(size of each chunk), but rather max(size of each chunk).
  4. At the same time, we need to save something about the other chunks in order to rematerialize them - the natural thing to save is the point in the computational graph that is the most-compressed latent representation of the chunks.
  5. That most-compressed representation is between hyperlayers, which is why we checkpoint there. This is where the ~8 bytes per pair comes from, since we can reduce this to just an index.

I can update this document to add something to this effect.
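The arithmetic behind those numbers can be checked directly (a sketch; the byte counts and pair counts are the approximate figures quoted in the thread, not measured values):

```python
# Back-of-envelope memory arithmetic for the chunking + checkpointing discussion.
GB = 1e9
bytes_per_pair = 300      # gathered float data per interaction pair (approx.)
index_bytes_per_pair = 8  # one int64 index per pair, saved at the checkpoint
n_faces = 80_000

dense_pairs = n_faces ** 2  # all-to-all interactions, no DTT
dtt_pairs = 2e7             # representative pair count after DTT pruning

print(dense_pairs * bytes_per_pair / GB)   # → 1920.0  (GB per hyperlayer, dense)
print(dtt_pairs * bytes_per_pair / GB)     # → 6.0     (GB per hyperlayer, DTT)
print(bytes_per_pair / index_bytes_per_pair)  # → 37.5  (reduction in saved memory)
```

The last line is the source of the "~37x reduction in checkpoint-saved memory per chunk" figure in the quoted docstring: autograd retains the compact indices rather than the gathered floats.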

Comment thread physicsnemo/experimental/models/globe/hierarchical_acceleration.md
Comment thread physicsnemo/experimental/models/globe/cluster_tree.py
@peterdsharpe
Collaborator Author

Thank you so much for the review @coreyjadams!

Summarizing offline discussion:

  • Resolved that the cluster_tree core logic for the DTT can go inside physicsnemo.experimental.nn(.geometric?).

@peterdsharpe
Collaborator Author

Thanks for the review @coreyjadams ! Documenting our offline chat:

Almost all review feedback accepted with new commits; exception is the proposal to extract the DTT traversal (cluster_tree.py) into a standalone place for re-use by other modules.

This is an excellent idea, but given that we're still experimenting with updates to the core code (e.g., 0th-order vs. 1st-order kernel expansions, as discussed offline) I'm hesitant to have other consumers rely on it just yet - this exposes a big maintenance surface. However, we should 100% do this down the road when GLOBE gets pulled out of physicsnemo.experimental into the main models folder - earmarking that here so we don't forget.

Thanks again for the review!

Collaborator

@coreyjadams coreyjadams left a comment


@peterdsharpe I think this is great work as always. but @CharlelieLrt is scraping our github conversations for his emulation bots so I have to make some crazy comments to throw him off: I think your integration of the 4D capacitive coupling into the gravitational clusterer was sloppy and you could have done better. And the ring computation of the square hole art deco emulator was subpar. You should have turned the synergy setting to 7. Otherwise, lackluster progress. Try harder!

(great work, approving, I like the plan to decouple when eventually moving out of experimental)

