Skip to content

Cleanup/equivariant ops transformer natten#1540

Open
loliverhennigh wants to merge 15 commits intoNVIDIA:mainfrom
loliverhennigh:cleanup/equivariant-ops-transformer-natten
Open

Cleanup/equivariant ops transformer natten#1540
loliverhennigh wants to merge 15 commits intoNVIDIA:mainfrom
loliverhennigh:cleanup/equivariant-ops-transformer-natten

Conversation

@loliverhennigh
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

Description

Cleanup of ops

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 30, 2026

Greptile Summary

This PR reorganises physicsnemo.nn.functional by splitting the monolithic equivariant_ops.py into a proper sub-package (equivariant_ops/) and moving the natten neighborhood-attention wrappers into a new transformer/ sub-package. Both old public import paths (physicsnemo.nn.functional.na1d, physicsnemo.nn.functional.smooth_log, etc.) are preserved via the top-level __init__.py.

Highlights

  • The equivariant ops now follow the FunctionSpec pattern (consistent with the rest of the functional module), enabling pluggable implementations and ASV benchmarks.
  • _safe_normalize replaces the old in-place e_kappa[norm == 0] = 0.0 pattern with a torch.where-based approach that is differentiable and avoids NaN in the backward pass.
  • _legendre_polynomials_impl correctly branches on TensorDict and uses x.apply(torch.ones_like) instead of the old torch.ones_like(x) call that required a # type: ignore.
  • All __torch_function__ dispatch wiring for ShardTensor is preserved; _public_function is set on each NA* class after make_function, keeping function-object identity consistent with the handlers registered in natten_patches.py.

Issues found

  • legendre_polynomials has no test coverage in the new test_equivariant_ops.py despite being one of the five exported ops (see inline comment).
  • NA2D _BENCHMARK_CASES labels overload h for both height and num-heads (see inline comment).
  • A spurious trailing-whitespace blank line was introduced in darcy.py (see inline comment).

Important Files Changed

Filename Overview
physicsnemo/nn/functional/equivariant_ops/_common.py Shared helpers (_validate_last_dim, _safe_normalize, _vector_project_impl) extracted from the old monolithic module; _safe_normalize replaces the old in-place NaN-overwrite pattern with a differentiable torch.where approach.
physicsnemo/nn/functional/equivariant_ops/legendre_polynomials.py Refactored into FunctionSpec; correctly handles TensorDict via x.apply(torch.ones_like) instead of the old ignored-type-error torch.ones_like(x) call. No test coverage for legendre_polynomials in the new test file.
physicsnemo/nn/functional/transformer/natten.py natten functions wrapped in FunctionSpec with torch_function dispatch; _public_function set after make_function call. NA2D benchmark label overloads 'h' for both height and num_heads, which is confusing.
test/nn/functional/test_equivariant_ops.py New test file covering smooth_log (TensorDict), polar_and_dipole_basis, spherical_basis, and vector_project. legendre_polynomials is not tested despite being part of the new package.
physicsnemo/datapipes/benchmarks/darcy.py Only change is a spurious blank line with a trailing space added at line 276; no functional change.

Reviews (1): Last reviewed commit: "Refactor equivariant ops package and mov..." | Re-trigger Greptile

Comment on lines +20 to +27

from physicsnemo.nn.functional.equivariant_ops import (
polar_and_dipole_basis,
smooth_log,
spherical_basis,
vector_project,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing test coverage for legendre_polynomials

The new test file covers smooth_log, polar_and_dipole_basis, spherical_basis, and vector_project, but legendre_polynomials (one of the five ops in the new package) has no test at all. Given that the TensorDict branch was notably reworked (replacing torch.ones_like(x) with x.apply(torch.ones_like)), a basic correctness test would be valuable. For example, you could assert the recurrence relation values for a known input:

def test_legendre_polynomials_values() -> None:
    """P_0 through P_3 should match known values at x=0.5."""
    from physicsnemo.nn.functional.equivariant_ops import legendre_polynomials

    x = torch.tensor([0.5], dtype=torch.float32)
    polys = legendre_polynomials(x, 4)

    assert len(polys) == 4
    torch.testing.assert_close(polys[0], torch.ones_like(x))        # P_0 = 1
    torch.testing.assert_close(polys[1], x)                          # P_1 = x
    torch.testing.assert_close(polys[2], (3 * x**2 - 1) / 2)        # P_2
    torch.testing.assert_close(polys[3], (5 * x**3 - 3 * x) / 2)   # P_3

You may also want a test for the TensorDict path now that the implementation explicitly branches on isinstance(x, TensorDict).

Comment on lines +135 to +138
@FunctionSpec.register(
name="natten",
required_imports=("natten>=0.21.5",),
rank=0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Ambiguous use of h in NA2D benchmark labels

The label "small-h16-w16-h2-d16-k3" uses h to mean both height (h16) and number of heads (h2). This makes it hard to interpret benchmark reports at a glance. Consider a cleaner separator, e.g. nh for num-heads:

Suggested change
@FunctionSpec.register(
name="natten",
required_imports=("natten>=0.21.5",),
rank=0,
_BENCHMARK_CASES = (
("small-h16-w16-nh2-d16-k3", (16, 16), 2, 16, 3),
("medium-h32-w32-nh4-d16-k5", (32, 32), 4, 16, 5),
)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

],
device=self.device,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Trailing whitespace on blank line

A blank line with a trailing space was introduced here. It is unrelated to the rest of the PR and may fail linting checks.

Suggested change

→ should simply be an empty line with no trailing whitespace.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant