Cleanup/equivariant ops transformer natten#1540
Cleanup/equivariant ops transformer natten#1540loliverhennigh wants to merge 15 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR reorganises Highlights
Issues found
Important Files Changed
Reviews (1): Last reviewed commit: "Refactor equivariant ops package and mov..." | Re-trigger Greptile |
|
|
||
| from physicsnemo.nn.functional.equivariant_ops import ( | ||
| polar_and_dipole_basis, | ||
| smooth_log, | ||
| spherical_basis, | ||
| vector_project, | ||
| ) | ||
|
|
There was a problem hiding this comment.
Missing test coverage for
legendre_polynomials
The new test file covers smooth_log, polar_and_dipole_basis, spherical_basis, and vector_project, but legendre_polynomials (one of the five ops in the new package) has no test at all. Given that the TensorDict branch was notably reworked (replacing torch.ones_like(x) with x.apply(torch.ones_like)), a basic correctness test would be valuable. For example, you could assert the recurrence relation values for a known input:
def test_legendre_polynomials_values() -> None:
"""P_0 through P_3 should match known values at x=0.5."""
from physicsnemo.nn.functional.equivariant_ops import legendre_polynomials
x = torch.tensor([0.5], dtype=torch.float32)
polys = legendre_polynomials(x, 4)
assert len(polys) == 4
torch.testing.assert_close(polys[0], torch.ones_like(x)) # P_0 = 1
torch.testing.assert_close(polys[1], x) # P_1 = x
torch.testing.assert_close(polys[2], (3 * x**2 - 1) / 2) # P_2
torch.testing.assert_close(polys[3], (5 * x**3 - 3 * x) / 2) # P_3You may also want a test for the TensorDict path now that the implementation explicitly branches on isinstance(x, TensorDict).
| @FunctionSpec.register( | ||
| name="natten", | ||
| required_imports=("natten>=0.21.5",), | ||
| rank=0, |
There was a problem hiding this comment.
Ambiguous use of
h in NA2D benchmark labels
The label "small-h16-w16-h2-d16-k3" uses h to mean both height (h16) and number of heads (h2). This makes it hard to interpret benchmark reports at a glance. Consider a cleaner separator, e.g. nh for num-heads:
| @FunctionSpec.register( | |
| name="natten", | |
| required_imports=("natten>=0.21.5",), | |
| rank=0, | |
| _BENCHMARK_CASES = ( | |
| ("small-h16-w16-nh2-d16-k3", (16, 16), 2, 16, 3), | |
| ("medium-h32-w32-nh4-d16-k5", (32, 32), 4, 16, 5), | |
| ) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| ], | ||
| device=self.device, | ||
| ) | ||
|
|
PhysicsNeMo Pull Request
Description
Cleanup of ops