diff --git a/benchmarks/physicsnemo/nn/functional/registry.py b/benchmarks/physicsnemo/nn/functional/registry.py index 0e959f90a2..7222b5d4c5 100644 --- a/benchmarks/physicsnemo/nn/functional/registry.py +++ b/benchmarks/physicsnemo/nn/functional/registry.py @@ -16,6 +16,13 @@ """Registry of FunctionSpec classes to benchmark with ASV.""" +from physicsnemo.nn.functional.equivariant_ops import ( + LegendrePolynomials, + PolarAndDipoleBasis, + SmoothLog, + SphericalBasis, + VectorProject, +) from physicsnemo.nn.functional.fourier_spectral import ( IRFFT, IRFFT2, @@ -32,6 +39,7 @@ DropPath, WeightFact, ) +from physicsnemo.nn.functional.transformer import NA1D, NA2D, NA3D # FunctionSpec classes listed here must implement ``make_inputs_forward`` for ASV. # ``make_inputs_backward`` is optional and only used when backward benchmarks run. @@ -39,6 +47,16 @@ # Regularization / parameterization. DropPath, WeightFact, + # Equivariant ops. + SmoothLog, + LegendrePolynomials, + VectorProject, + PolarAndDipoleBasis, + SphericalBasis, + # Neighborhood attention. + NA1D, + NA2D, + NA3D, # Neighbor queries. KNN, RadiusSearch, diff --git a/docs/api/nn/functionals/equivariant_ops.rst b/docs/api/nn/functionals/equivariant_ops.rst new file mode 100644 index 0000000000..12af175ff5 --- /dev/null +++ b/docs/api/nn/functionals/equivariant_ops.rst @@ -0,0 +1,12 @@ +Equivariant Ops Functionals +=========================== + +.. autofunction:: physicsnemo.nn.functional.smooth_log + +.. autofunction:: physicsnemo.nn.functional.legendre_polynomials + +.. autofunction:: physicsnemo.nn.functional.vector_project + +.. autofunction:: physicsnemo.nn.functional.polar_and_dipole_basis + +.. 
autofunction:: physicsnemo.nn.functional.spherical_basis diff --git a/docs/api/physicsnemo.nn.functionals.rst b/docs/api/physicsnemo.nn.functionals.rst index 6608736d45..e0b1dd637c 100644 --- a/docs/api/physicsnemo.nn.functionals.rst +++ b/docs/api/physicsnemo.nn.functionals.rst @@ -19,6 +19,7 @@ in the documentation for performance comparisons. :name: PhysicsNeMo Functionals nn/functionals/neighbors + nn/functionals/equivariant_ops nn/functionals/geometry nn/functionals/fourier_spectral nn/functionals/regularization_parameterization diff --git a/physicsnemo/datapipes/benchmarks/darcy.py b/physicsnemo/datapipes/benchmarks/darcy.py index 0a11ce19ce..72ea468c72 100644 --- a/physicsnemo/datapipes/benchmarks/darcy.py +++ b/physicsnemo/datapipes/benchmarks/darcy.py @@ -273,6 +273,7 @@ def generate_batch(self) -> None: ], device=self.device, ) + def __iter__(self) -> Tuple[Tensor, Tensor]: """ diff --git a/physicsnemo/domain_parallel/shard_utils/natten_patches.py b/physicsnemo/domain_parallel/shard_utils/natten_patches.py index 5c3edcf669..e67bc37394 100644 --- a/physicsnemo/domain_parallel/shard_utils/natten_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/natten_patches.py @@ -32,7 +32,7 @@ MissingShardPatch, UndeterminedShardingError, ) -from physicsnemo.nn.functional.natten import na1d, na2d, na3d +from physicsnemo.nn.functional.transformer.natten import na1d, na2d, na3d _natten = OptionalImport("natten") _raw_func_map = { @@ -221,9 +221,9 @@ def _natten_wrapper( r"""Shared wrapper for natten functions to support sharded tensors. 
Registered with :meth:`ShardTensor.register_function_handler` so that calls - to :func:`~physicsnemo.nn.functional.natten.na1d`, - :func:`~physicsnemo.nn.functional.natten.na2d`, or - :func:`~physicsnemo.nn.functional.natten.na3d` automatically route through + to :func:`~physicsnemo.nn.functional.transformer.natten.na1d`, + :func:`~physicsnemo.nn.functional.transformer.natten.na2d`, or + :func:`~physicsnemo.nn.functional.transformer.natten.na3d` automatically route through this handler when any argument is a :class:`ShardTensor`. Parameters diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py index fe3b64cb46..cf0a822c78 100644 --- a/physicsnemo/nn/functional/__init__.py +++ b/physicsnemo/nn/functional/__init__.py @@ -24,9 +24,9 @@ from .fourier_spectral import imag, irfft, irfft2, real, rfft, rfft2, view_as_complex from .geometry import signed_distance_field from .interpolation import interpolation -from .natten import na1d, na2d, na3d from .neighbors import knn, radius_search from .regularization_parameterization import drop_path, weight_fact +from .transformer import na1d, na2d, na3d __all__ = [ "irfft", diff --git a/physicsnemo/nn/functional/equivariant_ops.py b/physicsnemo/nn/functional/equivariant_ops.py deleted file mode 100644 index 1b8f16369e..0000000000 --- a/physicsnemo/nn/functional/equivariant_ops.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import overload - -import torch -from jaxtyping import Float -from tensordict import TensorDict - - -@overload -def smooth_log(x: Float[torch.Tensor, "..."]) -> Float[torch.Tensor, "..."]: ... - - -@overload -def smooth_log(x: TensorDict) -> TensorDict: ... - - -def smooth_log( - x: Float[torch.Tensor, "..."] | TensorDict, -) -> Float[torch.Tensor, "..."] | TensorDict: - r"""Performs an elementwise operation on ``x`` with the following properties: - - - ``f(x) -> 0`` as ``x -> 0`` - - ``f(x) -> ln(x)`` as ``x -> infinity`` - - ``f(x)`` is smooth (``C_infty`` continuous) for all ``x >= 0`` - - ``f(x)`` is monotonically increasing for ``x > 0`` - - Has "nicely-behaved" higher-order derivatives for all ``x >= 0`` - - Function is "intended" to be used with the domain ``x`` in ``[0, inf)``; technically it - remains well-defined for ``(-1, inf)``. - - Parameters - ---------- - x : Float[torch.Tensor, "..."] or TensorDict - Input tensor or TensorDict with non-negative values. - - Returns - ------- - Float[torch.Tensor, "..."] or TensorDict - Result of the smooth log operation, same type and shape as ``x``. - """ - return (-x).expm1().neg() * x.log1p() - - -@overload -def legendre_polynomials( - x: Float[torch.Tensor, "..."], n: int -) -> list[Float[torch.Tensor, "..."]]: ... - - -@overload -def legendre_polynomials(x: TensorDict, n: int) -> list[TensorDict]: ... 
- - -def legendre_polynomials( - x: Float[torch.Tensor, "..."] | TensorDict, n: int -) -> list[Float[torch.Tensor, "..."] | TensorDict]: - r"""Computes the first ``n`` Legendre polynomials evaluated at ``x``. - - Acts elementwise on all entries of ``x``. - - Uses the recurrence relation for efficiency:: - - P_0(x) = 1 - P_1(x) = x - (n+1)*P_{n+1}(x) = (2n+1)*x*P_n(x) - n*P_{n-1}(x) - - Parameters - ---------- - x : Float[torch.Tensor, "..."] or TensorDict - Input tensor of any shape. - n : int - Number of Legendre polynomials to compute (must be >= 0). - Returns ``P_0`` through ``P_{n-1}``. - - Returns - ------- - list[Float[torch.Tensor, "..."] or TensorDict] - List of ``n`` tensors, where the i-th tensor is ``P_i(x)`` with the same - shape as ``x``. Returns an empty list when ``n = 0``. - - Raises - ------ - ValueError - If ``n`` is negative. - - Examples - -------- - >>> x = torch.tensor([0.0, 0.5, 1.0]) - >>> polys = legendre_polynomials(x, 4) - >>> # polys[0] is P_0(x) = 1 - >>> # polys[1] is P_1(x) = x - >>> # polys[2] is P_2(x) = (3x^2 - 1)/2 - >>> # polys[3] is P_3(x) = (5x^3 - 3x)/2 - """ - if n < 0: - raise ValueError(f"n must be non-negative, got {n=}") - if n == 0: - return [] - - ### Seed with the two base cases; slice to handle n=1 - polynomials: list[Float[torch.Tensor, "..."] | TensorDict] = [ - torch.ones_like(x), # P_0(x) = 1 # type: ignore[invalid-argument-type] - x, # P_1(x) = x - ][:n] - - ### Recurrence relation for P_2 and beyond - for i in range(2, n): - # (i)*P_i(x) = (2i-1)*x*P_{i-1}(x) - (i-1)*P_{i-2}(x) - p_i = ((2 * i - 1) * x * polynomials[i - 1] - (i - 1) * polynomials[i - 2]) / i # type: ignore[operator] - polynomials.append(p_i) - - return polynomials - - -def vector_project( - v: Float[torch.Tensor, "... n_dims"], - n_hat: Float[torch.Tensor, "... n_dims"], -) -> Float[torch.Tensor, "... n_dims"]: - r"""Projects vector ``v`` onto the plane orthogonal to unit vector ``n_hat``. - - Uses the Gram-Schmidt orthogonalization: - - .. 
math:: - - v_{\perp} = v - (v \cdot \hat{n}) \hat{n} - - Parameters - ---------- - v : Float[torch.Tensor, "... n_dims"] - Input vectors to project, with shape :math:`(*, D)`. - n_hat : Float[torch.Tensor, "... n_dims"] - Unit normal vectors defining the projection plane, with shape :math:`(*, D)`. - - Returns - ------- - Float[torch.Tensor, "... n_dims"] - Projected vectors with shape :math:`(*, D)`. - """ - # Below are two equivalent implementations; on my machine the second is faster, but - # in general einsums can be optimized more due to limiting intermediate allocations. - # return v - torch.einsum("...i,...i->...", v, n_hat)[..., None] * n_hat - return v - (v * n_hat).sum(dim=-1, keepdim=True) * n_hat - - -def polar_and_dipole_basis( - r_hat: Float[torch.Tensor, "... 2"], - n_hat: Float[torch.Tensor, "... 2"], - normalize_basis_vectors: bool = True, -) -> tuple[ - Float[torch.Tensor, "... 2"], - Float[torch.Tensor, "... 2"], - Float[torch.Tensor, "... 2"], -]: - r"""Computes a local vector basis for 2D vectors that is rotation-invariant - w.r.t. ``n_hat``. - - Notably, this isn't a true vector basis, as it has 3 vectors, not the - required 2. The basis is essentially a combination of a polar basis (r, - theta) and an additional dipole-like direction (kappa) for the third vector. - The axis for the dipole direction is set by ``n_hat``. - - Parameters - ---------- - r_hat : Float[torch.Tensor, "... 2"] - Unit direction vectors with shape :math:`(*, 2)`, assumed to be - normalized. - n_hat : Float[torch.Tensor, "... 2"] - Axis vectors with shape :math:`(*, 2)`, assumed to be unit vectors. - normalize_basis_vectors : bool, optional, default=True - Whether to normalize ``e_kappa`` to be unit length - (``e_r`` and ``e_theta`` are always unit). If ``False``, ``e_kappa`` is essentially - multiplied by ``sin(theta)``. This gives the sometimes-useful property that - ``e_kappa`` smoothly changes on the surface of a unit circle. 
- - Returns - ------- - tuple[Float[torch.Tensor, "... 2"], Float[torch.Tensor, "... 2"], Float[torch.Tensor, "... 2"]] - A tuple of 3 vectors, each of shape :math:`(*, 2)`: - - - ``e_r``: The radial direction, aligned with ``r_hat``. This corresponds to the - influence field direction associated with a point source (i.e., - outwards from the origin). - - - ``e_theta``: The tangential direction, orthogonal to ``e_r``. This corresponds - to the vortex field direction associated with a point vortex (i.e., - the direction of circulation around the source). - - - ``e_kappa``: A dipole-like direction. Notably, this is orthogonal to ``e_r``, - but exactly parallel to ``e_theta`` - if you need to construct a full-rank - basis, this is the one to drop. - - Note - ---- - Edge Cases (even if ``normalize_basis_vectors`` is ``True``): - - - If ``r_hat`` is a zero vector, all basis vectors will be zero vectors. - - If ``r_hat`` is aligned with ``n_hat``, ``e_kappa`` will be a zero vector. - """ - # Validate input shapes - if not torch.compiler.is_compiling(): - shape_validations = { - "n_hat": (n_hat.shape[-1], 2), - "r_hat": (r_hat.shape[-1], 2), - } - for name, (actual, expected) in shape_validations.items(): - if actual != expected: - raise ValueError( - f"Expected {name} to have shape (..., {expected}), got shape {actual}." - ) - - # e_r is simply the input unit vector - e_r = r_hat - - # Compute e_theta, the basis vector in the tangential direction - e_theta = torch.stack([-r_hat[..., 1], r_hat[..., 0]], dim=-1) - - # Compute e_kappa, the basis vector in the dipole direction - e_kappa = vector_project(-n_hat, r_hat) - r_hat_is_zero = torch.all(r_hat == 0.0, dim=-1) - e_kappa[r_hat_is_zero] = 0.0 - if normalize_basis_vectors: - norm = torch.linalg.norm(e_kappa, dim=-1) - e_kappa = e_kappa / norm[..., None] - e_kappa[norm == 0] = 0.0 # Overwrites any NaNs with zero vectors - - return e_r, e_theta, e_kappa - - -def spherical_basis( - r_hat: Float[torch.Tensor, "... 
3"], - n_hat: Float[torch.Tensor, "... 3"], - normalize_basis_vectors: bool = True, -) -> tuple[ - Float[torch.Tensor, "... 3"], - Float[torch.Tensor, "... 3"], - Float[torch.Tensor, "... 3"], -]: - r"""Computes a local vector basis for 3D vectors that is rotation-invariant - w.r.t. ``n_hat``. - - The basis is essentially a spherical coordinate system, with the axis set by - ``n_hat``. - - Parameters - ---------- - r_hat : Float[torch.Tensor, "... 3"] - Unit direction vectors with shape :math:`(*, 3)`, assumed to be - normalized. - n_hat : Float[torch.Tensor, "... 3"] - Axis vectors with shape :math:`(*, 3)`, assumed to be unit vectors. - normalize_basis_vectors : bool, optional, default=True - Whether to normalize ``e_theta`` and ``e_phi`` to unit - length (``e_r`` is always unit). If ``False``, ``e_theta`` and ``e_phi`` are - essentially multiplied by ``sin(theta)``. This gives the - sometimes-useful property that the basis vectors smoothly change on - the surface of a unit sphere. (If ``e_theta`` and ``e_phi`` are normalized, - then there is provably no possible way for these to smoothly vary on - the surface of a sphere, as shown by the Hairy Ball theorem.) - - Returns - ------- - tuple[Float[torch.Tensor, "... 3"], Float[torch.Tensor, "... 3"], Float[torch.Tensor, "... 3"]] - A tuple of 3 vectors, each of shape :math:`(*, 3)`: - - - ``e_r``: The radial direction, pointing outward from the origin. This - corresponds to the influence field direction associated with a point - source. - - - ``e_theta``: The polar direction, orthogonal to both ``e_r`` and ``n_hat``. This - corresponds to the meridional direction in spherical coordinates. - - - ``e_phi``: The azimuthal direction, orthogonal to both ``e_r`` and ``e_theta``. - This corresponds to the circumferential direction in spherical - coordinates. - - Note - ---- - Edge Cases (even if ``normalize_basis_vectors`` is ``True``): - - - If ``r_hat`` is a zero vector, all basis vectors will be zero vectors. 
- - If ``r_hat`` is aligned with ``n_hat``, ``e_theta`` and ``e_phi`` will be zero vectors. - """ - # Validate input shapes - if not torch.compiler.is_compiling(): - shape_validations = { - "n_hat": (n_hat.shape[-1], 3), - "r_hat": (r_hat.shape[-1], 3), - } - for name, (actual, expected) in shape_validations.items(): - if actual != expected: - raise ValueError( - f"Expected {name} to have shape (..., {expected}), got shape {actual}." - ) - - # e_r is simply the input unit vector - e_r = r_hat - - # Compute e_theta, the basis vector in the polar direction - e_theta = vector_project(-n_hat, r_hat) - r_hat_is_zero = torch.all(r_hat == 0.0, dim=-1) - e_theta[r_hat_is_zero] = 0.0 - if normalize_basis_vectors: - norm = torch.linalg.norm(e_theta, dim=-1) - e_theta = e_theta / norm[..., None] - e_theta[norm == 0] = 0.0 # Overwrites any NaNs with zero vectors - - # Compute e_phi, the basis vector in the azimuthal direction - e_phi = torch.cross(e_r, e_theta, dim=-1) - - return e_r, e_theta, e_phi diff --git a/physicsnemo/nn/functional/equivariant_ops/__init__.py b/physicsnemo/nn/functional/equivariant_ops/__init__.py new file mode 100644 index 0000000000..f9f2023610 --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .legendre_polynomials import LegendrePolynomials, legendre_polynomials +from .polar_and_dipole_basis import PolarAndDipoleBasis, polar_and_dipole_basis +from .smooth_log import SmoothLog, smooth_log +from .spherical_basis import SphericalBasis, spherical_basis +from .vector_project import VectorProject, vector_project + +__all__ = [ + "SmoothLog", + "LegendrePolynomials", + "VectorProject", + "PolarAndDipoleBasis", + "SphericalBasis", + "smooth_log", + "legendre_polynomials", + "vector_project", + "polar_and_dipole_basis", + "spherical_basis", +] diff --git a/physicsnemo/nn/functional/equivariant_ops/_common.py b/physicsnemo/nn/functional/equivariant_ops/_common.py new file mode 100644 index 0000000000..41a8de4c69 --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/_common.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TypeAlias + +import torch +from jaxtyping import Float +from tensordict import TensorDict + +TensorLike: TypeAlias = Float[torch.Tensor, "..."] | TensorDict + + +def _validate_last_dim(x: torch.Tensor, *, dim: int, name: str) -> None: + """Validate the final dimension for basis-vector inputs.""" + if torch.compiler.is_compiling(): + return + if x.shape[-1] != dim: + raise ValueError( + f"Expected {name} to have shape (..., {dim}), got shape {tuple(x.shape)}." + ) + + +def _safe_normalize( + x: Float[torch.Tensor, "... n_dims"], +) -> Float[torch.Tensor, "... n_dims"]: + """Normalize vectors and keep exact zeros for zero-length inputs.""" + norm = torch.linalg.norm(x, dim=-1, keepdim=True) + normalized = x / norm.clamp_min(torch.finfo(x.dtype).eps) + return torch.where(norm > 0, normalized, torch.zeros_like(x)) + + +def _make_tensordict_input(num_elements: int, device: torch.device) -> TensorDict: + """Create a simple TensorDict input used for benchmarks.""" + return TensorDict( + { + "a": torch.rand(num_elements, device=device), + "b": torch.rand(num_elements, device=device), + }, + batch_size=[num_elements], + ) + + +def _vector_project_impl( + v: Float[torch.Tensor, "... n_dims"], + n_hat: Float[torch.Tensor, "... n_dims"], +) -> Float[torch.Tensor, "... n_dims"]: + """Project vectors onto the plane orthogonal to ``n_hat``.""" + return v - (v * n_hat).sum(dim=-1, keepdim=True) * n_hat + + +__all__ = [ + "TensorLike", + "_validate_last_dim", + "_safe_normalize", + "_make_tensordict_input", + "_vector_project_impl", +] diff --git a/physicsnemo/nn/functional/equivariant_ops/legendre_polynomials.py b/physicsnemo/nn/functional/equivariant_ops/legendre_polynomials.py new file mode 100644 index 0000000000..967919542d --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/legendre_polynomials.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from jaxtyping import Float +from tensordict import TensorDict + +from physicsnemo.core.function_spec import FunctionSpec + +from ._common import TensorLike, _make_tensordict_input + + +def _legendre_polynomials_impl(x: TensorLike, n: int) -> list[TensorLike]: + if n < 0: + raise ValueError(f"n must be non-negative, got {n=}") + if n == 0: + return [] + + if isinstance(x, TensorDict): + polynomials: list[TensorDict] = [x.apply(torch.ones_like), x][:n] + for i in range(2, n): + p_i = ( + (2 * i - 1) * x * polynomials[i - 1] - (i - 1) * polynomials[i - 2] + ) / i + polynomials.append(p_i) + return polynomials + + polynomials_t: list[Float[torch.Tensor, "..."]] = [torch.ones_like(x), x][:n] + for i in range(2, n): + p_i = ( + (2 * i - 1) * x * polynomials_t[i - 1] - (i - 1) * polynomials_t[i - 2] + ) / i + polynomials_t.append(p_i) + + return polynomials_t + + +class LegendrePolynomials(FunctionSpec): + r"""Compute Legendre polynomials ``P_0`` through ``P_{n-1}`` at ``x``. + + Parameters + ---------- + x : Float[torch.Tensor, "..."] or TensorDict + Input tensor-like values. + n : int + Number of Legendre polynomials to evaluate. + implementation : {"torch"} or None + Implementation to use. When ``None``, dispatch selects the available + implementation. 
+ """ + + @FunctionSpec.register(name="torch", rank=0, baseline=True) + def torch_forward(x: TensorLike, n: int) -> list[TensorLike]: + return _legendre_polynomials_impl(x, n) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + yield ("tensor-n1024-p8", (torch.rand(1024, device=device), 8), {}) + yield ( + "tensordict-n512-p6", + (_make_tensordict_input(512, device), 6), + {}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + yield ( + "tensor-grad-n1024-p8", + (torch.rand(1024, device=device, requires_grad=True), 8), + {}, + ) + + +legendre_polynomials = LegendrePolynomials.make_function("legendre_polynomials") + + +__all__ = [ + "LegendrePolynomials", + "legendre_polynomials", +] diff --git a/physicsnemo/nn/functional/equivariant_ops/polar_and_dipole_basis.py b/physicsnemo/nn/functional/equivariant_ops/polar_and_dipole_basis.py new file mode 100644 index 0000000000..eb0cae4346 --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/polar_and_dipole_basis.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math + +import torch +from jaxtyping import Float + +from physicsnemo.core.function_spec import FunctionSpec + +from ._common import _safe_normalize, _validate_last_dim, _vector_project_impl + + +def _polar_and_dipole_basis_impl( + r_hat: Float[torch.Tensor, "... 2"], + n_hat: Float[torch.Tensor, "... 2"], + normalize_basis_vectors: bool, +) -> tuple[ + Float[torch.Tensor, "... 2"], + Float[torch.Tensor, "... 2"], + Float[torch.Tensor, "... 2"], +]: + _validate_last_dim(r_hat, dim=2, name="r_hat") + _validate_last_dim(n_hat, dim=2, name="n_hat") + + e_r = r_hat + e_theta = torch.stack((-r_hat[..., 1], r_hat[..., 0]), dim=-1) + + e_kappa = _vector_project_impl(-n_hat, r_hat) + r_hat_is_zero = torch.all(r_hat == 0.0, dim=-1, keepdim=True) + e_kappa = torch.where(r_hat_is_zero, torch.zeros_like(e_kappa), e_kappa) + + if normalize_basis_vectors: + e_kappa = _safe_normalize(e_kappa) + + return e_r, e_theta, e_kappa + + +class PolarAndDipoleBasis(FunctionSpec): + r"""Compute a local 2D basis aligned with ``r_hat`` and conditioned on ``n_hat``. + + Parameters + ---------- + r_hat : Float[torch.Tensor, "... 2"] + Unit direction vectors. + n_hat : Float[torch.Tensor, "... 2"] + Axis vectors. + normalize_basis_vectors : bool, optional + Whether to normalize ``e_kappa`` to unit length. + implementation : {"torch"} or None + Implementation to use. When ``None``, dispatch selects the available + implementation. + """ + + @FunctionSpec.register(name="torch", rank=0, baseline=True) + def torch_forward( + r_hat: Float[torch.Tensor, "... 2"], + n_hat: Float[torch.Tensor, "... 2"], + normalize_basis_vectors: bool = True, + ) -> tuple[ + Float[torch.Tensor, "... 2"], + Float[torch.Tensor, "... 2"], + Float[torch.Tensor, "... 
2"], + ]: + return _polar_and_dipole_basis_impl(r_hat, n_hat, normalize_basis_vectors) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + theta = torch.linspace(0.0, 2.0 * math.pi, 2048, device=device) + r_hat = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1) + n_hat = torch.tensor([1.0, 0.0], device=device).repeat(2048, 1) + yield ( + "basis-n2048", + (r_hat, n_hat), + {"normalize_basis_vectors": True}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + theta = torch.linspace(0.0, 2.0 * math.pi, 2048, device=device) + r_hat = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1) + n_hat = torch.tensor([1.0, 0.0], device=device).repeat(2048, 1) + yield ( + "basis-grad-n2048", + ( + r_hat.requires_grad_(True), + n_hat.requires_grad_(True), + ), + {"normalize_basis_vectors": True}, + ) + + +polar_and_dipole_basis = PolarAndDipoleBasis.make_function("polar_and_dipole_basis") + + +__all__ = [ + "PolarAndDipoleBasis", + "polar_and_dipole_basis", +] diff --git a/physicsnemo/nn/functional/equivariant_ops/smooth_log.py b/physicsnemo/nn/functional/equivariant_ops/smooth_log.py new file mode 100644 index 0000000000..3d7c854450 --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/smooth_log.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._common import TensorLike, _make_tensordict_input + + +def _smooth_log_impl(x: TensorLike) -> TensorLike: + return (-x).expm1().neg() * x.log1p() + + +class SmoothLog(FunctionSpec): + r"""Apply a smooth logarithm-like map elementwise. + + Parameters + ---------- + x : Float[torch.Tensor, "..."] or TensorDict + Input tensor-like values. + implementation : {"torch"} or None + Implementation to use. When ``None``, dispatch selects the available + implementation. 
+ """ + + @FunctionSpec.register(name="torch", rank=0, baseline=True) + def torch_forward(x: TensorLike) -> TensorLike: + return _smooth_log_impl(x) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + yield ("tensor-n1024", (torch.rand(1024, device=device),), {}) + yield ("tensordict-n512", (_make_tensordict_input(512, device),), {}) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + device = torch.device(device) + yield ( + "tensor-grad-n1024", + (torch.rand(1024, device=device, requires_grad=True),), + {}, + ) + + +smooth_log = SmoothLog.make_function("smooth_log") + + +__all__ = [ + "SmoothLog", + "smooth_log", +] diff --git a/physicsnemo/nn/functional/equivariant_ops/spherical_basis.py b/physicsnemo/nn/functional/equivariant_ops/spherical_basis.py new file mode 100644 index 0000000000..8bace15f9c --- /dev/null +++ b/physicsnemo/nn/functional/equivariant_ops/spherical_basis.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

import torch
from jaxtyping import Float

from physicsnemo.core.function_spec import FunctionSpec

from ._common import _safe_normalize, _validate_last_dim, _vector_project_impl


def _spherical_basis_impl(
    r_hat: Float[torch.Tensor, "... 3"],
    n_hat: Float[torch.Tensor, "... 3"],
    normalize_basis_vectors: bool,
) -> tuple[
    Float[torch.Tensor, "... 3"],
    Float[torch.Tensor, "... 3"],
    Float[torch.Tensor, "... 3"],
]:
    """Build the per-vector ``(e_r, e_theta, e_phi)`` spherical-like frame."""
    _validate_last_dim(r_hat, dim=3, name="r_hat")
    _validate_last_dim(n_hat, dim=3, name="n_hat")

    # The radial basis vector is the input direction itself.
    e_r = r_hat

    # Polar direction: the component of -n_hat orthogonal to the radial axis.
    e_theta = _vector_project_impl(-n_hat, r_hat)
    # Degenerate (all-zero) radial vectors would otherwise leave the raw
    # projection (-n_hat) in place; force those rows to zero instead.
    degenerate = torch.all(r_hat == 0.0, dim=-1, keepdim=True)
    e_theta = torch.where(degenerate, torch.zeros_like(e_theta), e_theta)

    if normalize_basis_vectors:
        e_theta = _safe_normalize(e_theta)

    # Azimuthal direction completes the right-handed frame.
    e_phi = torch.cross(e_r, e_theta, dim=-1)
    return e_r, e_theta, e_phi


class SphericalBasis(FunctionSpec):
    r"""Compute a local spherical-like 3D basis aligned with ``n_hat``.

    Parameters
    ----------
    r_hat : Float[torch.Tensor, "... 3"]
        Unit radial direction vectors.
    n_hat : Float[torch.Tensor, "... 3"]
        Axis vectors.
    normalize_basis_vectors : bool, optional
        Whether to normalize ``e_theta`` (``e_phi`` then inherits unit length
        whenever the inputs are unit vectors).
    implementation : {"torch"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.
    """

    @FunctionSpec.register(name="torch", rank=0, baseline=True)
    def torch_forward(
        r_hat: Float[torch.Tensor, "... 3"],
        n_hat: Float[torch.Tensor, "... 3"],
        normalize_basis_vectors: bool = True,
    ) -> tuple[
        Float[torch.Tensor, "... 3"],
        Float[torch.Tensor, "... 3"],
        Float[torch.Tensor, "... 3"],
    ]:
        return _spherical_basis_impl(r_hat, n_hat, normalize_basis_vectors)

    @classmethod
    def make_inputs_forward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` forward-benchmark inputs for ASV."""
        device = torch.device(device)
        directions = _safe_normalize(torch.randn(1024, 3, device=device))
        axes = torch.tensor([0.0, 0.0, 1.0], device=device).repeat(1024, 1)
        yield (
            "basis-n1024",
            (directions, axes),
            {"normalize_basis_vectors": True},
        )

    @classmethod
    def make_inputs_backward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` backward-benchmark inputs for ASV."""
        device = torch.device(device)
        directions = _safe_normalize(torch.randn(1024, 3, device=device))
        axes = torch.tensor([0.0, 0.0, 1.0], device=device).repeat(1024, 1)
        yield (
            "basis-grad-n1024",
            (
                directions.requires_grad_(True),
                axes.requires_grad_(True),
            ),
            {"normalize_basis_vectors": True},
        )


spherical_basis = SphericalBasis.make_function("spherical_basis")


__all__ = [
    "SphericalBasis",
    "spherical_basis",
]
from __future__ import annotations

import torch
from jaxtyping import Float

from physicsnemo.core.function_spec import FunctionSpec

from ._common import _vector_project_impl


class VectorProject(FunctionSpec):
    r"""Project vectors onto the plane orthogonal to a normal vector.

    Parameters
    ----------
    v : Float[torch.Tensor, "... n_dims"]
        Input vectors to project.
    n_hat : Float[torch.Tensor, "... n_dims"]
        Unit normal vectors defining the projection plane.
    implementation : {"torch"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.
    """

    @FunctionSpec.register(name="torch", rank=0, baseline=True)
    def torch_forward(
        v: Float[torch.Tensor, "... n_dims"],
        n_hat: Float[torch.Tensor, "... n_dims"],
    ) -> Float[torch.Tensor, "... n_dims"]:
        return _vector_project_impl(v, n_hat)

    @classmethod
    def make_inputs_forward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` forward-benchmark inputs for ASV."""
        device = torch.device(device)
        vectors = torch.randn(2048, 3, device=device)
        normals = torch.randn(2048, 3, device=device)
        # Benchmarks assume unit normals, matching the documented contract.
        normals = normals / torch.linalg.norm(normals, dim=-1, keepdim=True)
        yield ("vectors-n2048-d3", (vectors, normals), {})

    @classmethod
    def make_inputs_backward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` backward-benchmark inputs for ASV."""
        device = torch.device(device)
        vectors = torch.randn(2048, 3, device=device, requires_grad=True)
        normals = torch.randn(2048, 3, device=device)
        normals = normals / torch.linalg.norm(normals, dim=-1, keepdim=True)
        yield (
            "vectors-grad-n2048-d3",
            (vectors, normals.requires_grad_(True)),
            {},
        )


vector_project = VectorProject.make_function("vector_project")


__all__ = [
    "VectorProject",
    "vector_project",
]
-# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any - -import torch -from torch.overrides import handle_torch_function, has_torch_function - -from physicsnemo.core.version_check import OptionalImport - -_natten = OptionalImport("natten") - - -def na1d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 1D neighborhood attention, with ``__torch_function__`` dispatch. - - This is a thin wrapper around :func:`natten.functional.na1d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, L, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, L, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, L, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na1d` - (e.g. ``is_causal``, ``scale``). 
- - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. - """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na1d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na1d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -def na2d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 2D neighborhood attention, with ``__torch_function__`` dispatch. - - This is a thin wrapper around :func:`natten.functional.na2d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, H, W, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, H, W, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, H, W, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na2d` - (e.g. ``is_causal``, ``scale``). - - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. - """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na2d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na2d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -def na3d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 3D neighborhood attention, with ``__torch_function__`` dispatch. 
- - This is a thin wrapper around :func:`natten.functional.na3d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na3d` - (e.g. ``is_causal``, ``scale``). - - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. - """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na3d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na3d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -__all__ = ["na1d", "na2d", "na3d"] diff --git a/physicsnemo/nn/functional/transformer/__init__.py b/physicsnemo/nn/functional/transformer/__init__.py new file mode 100644 index 0000000000..67e407a715 --- /dev/null +++ b/physicsnemo/nn/functional/transformer/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .natten import NA1D, NA2D, NA3D, na1d, na2d, na3d + +__all__ = [ + "NA1D", + "NA2D", + "NA3D", + "na1d", + "na2d", + "na3d", +] diff --git a/physicsnemo/nn/functional/transformer/natten.py b/physicsnemo/nn/functional/transformer/natten.py new file mode 100644 index 0000000000..e8955d47ed --- /dev/null +++ b/physicsnemo/nn/functional/transformer/natten.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from typing import Any, Callable

import torch
from torch.overrides import handle_torch_function, has_torch_function

from physicsnemo.core.function_spec import FunctionSpec
from physicsnemo.core.version_check import OptionalImport

# Lazily resolved so this module imports even when ``natten`` is absent;
# attribute access on ``_natten`` triggers the real import on first use.
_natten = OptionalImport("natten")


class _NeighborhoodAttentionBase(FunctionSpec):
    """Shared FunctionSpec behavior for neighborhood attention wrappers."""

    # Module-level public callable (assigned after ``make_function`` below).
    # Handed to ``handle_torch_function`` so tensor-subclass handlers key off
    # the public function object rather than this class.
    _public_function: Callable[..., torch.Tensor] | None = None
    # Benchmark cases: (label, spatial_shape, num_heads, head_dim, kernel_size).
    # Subclasses override with dimension-appropriate shapes.
    _BENCHMARK_CASES: tuple[tuple[str, tuple[int, ...], int, int, int], ...] = ()

    @classmethod
    def dispatch(
        cls,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        dilation: int = 1,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Dispatch with ``__torch_function__`` interception for tensor subclasses."""
        implementation = kwargs.pop("implementation", None)

        if has_torch_function((q, k, v)):
            if cls._public_function is None:
                raise RuntimeError(
                    f"{cls.__name__} public function is not configured for dispatch"
                )
            # NOTE(review): an explicit ``implementation=`` choice is dropped on
            # this path — the subclass handler selects its own implementation.
            # Confirm this is intended for sharded-tensor callers.
            return handle_torch_function(
                cls._public_function,
                (q, k, v),
                q,
                k,
                v,
                kernel_size,
                dilation=dilation,
                **kwargs,
            )

        # Re-attach ``implementation`` only when explicitly requested so the
        # FunctionSpec dispatcher can auto-select otherwise.
        if implementation is not None:
            kwargs["implementation"] = implementation

        return super().dispatch(
            q,
            k,
            v,
            kernel_size,
            dilation=dilation,
            **kwargs,
        )

    @classmethod
    def make_inputs_forward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` forward-benchmark inputs for ASV."""
        device = torch.device(device)
        for label, spatial_shape, num_heads, head_dim, kernel_size in cls._BENCHMARK_CASES:
            # Heads-last natten layout: (B, *spatial, heads, head_dim).
            shape = (1, *spatial_shape, num_heads, head_dim)
            q = torch.randn(shape, device=device)
            k = torch.randn_like(q)
            v = torch.randn_like(q)
            yield (
                label,
                (q, k, v, kernel_size),
                {"dilation": 1},
            )

    @classmethod
    def make_inputs_backward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` backward-benchmark inputs for ASV."""
        device = torch.device(device)
        for label, spatial_shape, num_heads, head_dim, kernel_size in cls._BENCHMARK_CASES:
            shape = (1, *spatial_shape, num_heads, head_dim)
            # Leaf tensors with grads enabled so the backward benchmark can run.
            q = torch.randn(shape, device=device, requires_grad=True)
            k = torch.randn(shape, device=device, requires_grad=True)
            v = torch.randn(shape, device=device, requires_grad=True)
            yield (
                f"{label}-grad",
                (q, k, v, kernel_size),
                {"dilation": 1},
            )


class NA1D(_NeighborhoodAttentionBase):
    r"""Compute 1D neighborhood attention.

    This is a wrapper around :func:`natten.functional.na1d` with support for
    ``__torch_function__`` dispatch, which enables tensor subclasses (for
    example ``ShardTensor``) to intercept the call.

    Parameters
    ----------
    q : torch.Tensor
        Query tensor of shape :math:`(B, L, \text{heads}, D)`.
    k : torch.Tensor
        Key tensor of shape :math:`(B, L, \text{heads}, D)`.
    v : torch.Tensor
        Value tensor of shape :math:`(B, L, \text{heads}, D)`.
    kernel_size : int
        Attention kernel size.
    dilation : int, optional
        Kernel dilation factor.
    implementation : {"natten"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.
    """

    _BENCHMARK_CASES = (
        ("small-l64-h2-d16-k3", (64,), 2, 16, 3),
        ("medium-l128-h4-d16-k5", (128,), 4, 16, 5),
    )

    @FunctionSpec.register(
        name="natten",
        required_imports=("natten>=0.21.5",),
        rank=0,
        baseline=True,
    )
    def natten_forward(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        dilation: int = 1,
        **kwargs: Any,
    ) -> torch.Tensor:
        # Thin pass-through; extra kwargs (e.g. is_causal, scale) forwarded.
        return _natten.functional.na1d(
            q,
            k,
            v,
            kernel_size,
            dilation=dilation,
            **kwargs,
        )


class NA2D(_NeighborhoodAttentionBase):
    r"""Compute 2D neighborhood attention.

    This is a wrapper around :func:`natten.functional.na2d` with support for
    ``__torch_function__`` dispatch, which enables tensor subclasses (for
    example ``ShardTensor``) to intercept the call.

    Parameters
    ----------
    q : torch.Tensor
        Query tensor of shape :math:`(B, H, W, \text{heads}, D)`.
    k : torch.Tensor
        Key tensor of shape :math:`(B, H, W, \text{heads}, D)`.
    v : torch.Tensor
        Value tensor of shape :math:`(B, H, W, \text{heads}, D)`.
    kernel_size : int
        Attention kernel size.
    dilation : int, optional
        Kernel dilation factor.
    implementation : {"natten"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.
    """

    _BENCHMARK_CASES = (
        ("small-h16-w16-h2-d16-k3", (16, 16), 2, 16, 3),
        ("medium-h32-w32-h4-d16-k5", (32, 32), 4, 16, 5),
    )

    @FunctionSpec.register(
        name="natten",
        required_imports=("natten>=0.21.5",),
        rank=0,
        baseline=True,
    )
    def natten_forward(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        dilation: int = 1,
        **kwargs: Any,
    ) -> torch.Tensor:
        return _natten.functional.na2d(
            q,
            k,
            v,
            kernel_size,
            dilation=dilation,
            **kwargs,
        )


class NA3D(_NeighborhoodAttentionBase):
    r"""Compute 3D neighborhood attention.

    This is a wrapper around :func:`natten.functional.na3d` with support for
    ``__torch_function__`` dispatch, which enables tensor subclasses (for
    example ``ShardTensor``) to intercept the call.

    Parameters
    ----------
    q : torch.Tensor
        Query tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`.
    k : torch.Tensor
        Key tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`.
    v : torch.Tensor
        Value tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`.
    kernel_size : int
        Attention kernel size.
    dilation : int, optional
        Kernel dilation factor.
    implementation : {"natten"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.
    """

    _BENCHMARK_CASES = (
        ("small-x8-y8-z8-h2-d8-k3", (8, 8, 8), 2, 8, 3),
        ("medium-x12-y12-z12-h2-d8-k5", (12, 12, 12), 2, 8, 5),
    )

    @FunctionSpec.register(
        name="natten",
        required_imports=("natten>=0.21.5",),
        rank=0,
        baseline=True,
    )
    def natten_forward(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        dilation: int = 1,
        **kwargs: Any,
    ) -> torch.Tensor:
        return _natten.functional.na3d(
            q,
            k,
            v,
            kernel_size,
            dilation=dilation,
            **kwargs,
        )


# Public entry points; each class records its public function so ``dispatch``
# can route ``__torch_function__`` interception through it.
na1d = NA1D.make_function("na1d")
NA1D._public_function = na1d

na2d = NA2D.make_function("na2d")
NA2D._public_function = na2d

na3d = NA3D.make_function("na3d")
NA3D._public_function = na3d


__all__ = [
    "NA1D",
    "NA2D",
    "NA3D",
    "na1d",
    "na2d",
    "na3d",
]
-This module validates the correctness of :func:`physicsnemo.nn.functional.natten.na1d`, -:func:`physicsnemo.nn.functional.natten.na2d`, and -:func:`physicsnemo.nn.functional.natten.na3d` over sharded inputs, covering both +This module validates the correctness of :func:`physicsnemo.nn.functional.transformer.natten.na1d`, +:func:`physicsnemo.nn.functional.transformer.natten.na2d`, and +:func:`physicsnemo.nn.functional.transformer.natten.na3d` over sharded inputs, covering both forward and backward passes. Sharding is performed over spatial dimensions which correspond to ``Shard(1)``, ``Shard(2)``, etc. in the natten heads-last layout. """ @@ -115,7 +115,7 @@ class TestNA1D: def test_na1d_shard_l( self, distributed_mesh, L, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d _run_natten_check( na1d, @@ -148,7 +148,7 @@ class TestNA2D: def test_na2d_shard_h( self, distributed_mesh, H, W, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d _run_natten_check( na2d, @@ -171,7 +171,7 @@ def test_na2d_shard_h( def test_na2d_shard_w( self, distributed_mesh, H, W, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d _run_natten_check( na2d, @@ -205,7 +205,7 @@ class TestNA3D: def test_na3d_shard_x( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d _run_natten_check( na3d, @@ -229,7 +229,7 @@ def test_na3d_shard_x( def test_na3d_shard_y( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from 
physicsnemo.nn.functional.transformer.natten import na3d _run_natten_check( na3d, @@ -253,7 +253,7 @@ def test_na3d_shard_y( def test_na3d_shard_z( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d _run_natten_check( na3d, diff --git a/test/nn/functional/test_equivariant_ops.py b/test/nn/functional/test_equivariant_ops.py new file mode 100644 index 0000000000..61087ad1fa --- /dev/null +++ b/test/nn/functional/test_equivariant_ops.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest
import torch
from tensordict import TensorDict

from physicsnemo.nn.functional.equivariant_ops import (
    polar_and_dipole_basis,
    smooth_log,
    spherical_basis,
    vector_project,
)


def test_smooth_log_tensordict_matches_leafwise() -> None:
    """smooth_log over a TensorDict must equal smooth_log over each leaf."""
    td = TensorDict(
        {
            "a": torch.tensor([0.0, 1.0, 2.0], dtype=torch.float32),
            "b": torch.tensor([3.0, 4.0, 5.0], dtype=torch.float32),
        },
        batch_size=[3],
    )

    out = smooth_log(td)

    assert isinstance(out, TensorDict)
    for key in ("a", "b"):
        torch.testing.assert_close(out[key], smooth_log(td[key]))


@pytest.mark.parametrize("normalize_basis_vectors", [False, True])
def test_polar_basis_zero_vectors_stay_zero(
    normalize_basis_vectors: bool,
) -> None:
    """Degenerate (all-zero) 2D directions must give finite, all-zero bases."""
    zero_dirs = torch.zeros(5, 2, dtype=torch.float32)
    axes = torch.tensor([1.0, 0.0], dtype=torch.float32).repeat(5, 1)

    e_r, e_theta, e_kappa = polar_and_dipole_basis(
        zero_dirs,
        axes,
        normalize_basis_vectors=normalize_basis_vectors,
    )

    for basis_vec in (e_r, e_theta, e_kappa):
        torch.testing.assert_close(basis_vec, torch.zeros_like(basis_vec))
    assert not torch.isnan(e_kappa).any()


@pytest.mark.parametrize("normalize_basis_vectors", [False, True])
def test_spherical_basis_zero_vectors_stay_zero(
    normalize_basis_vectors: bool,
) -> None:
    """Degenerate (all-zero) 3D directions must give finite, all-zero bases."""
    zero_dirs = torch.zeros(4, 3, dtype=torch.float32)
    axes = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32).repeat(4, 1)

    e_r, e_theta, e_phi = spherical_basis(
        zero_dirs,
        axes,
        normalize_basis_vectors=normalize_basis_vectors,
    )

    for basis_vec in (e_r, e_theta, e_phi):
        torch.testing.assert_close(basis_vec, torch.zeros_like(basis_vec))
    assert not torch.isnan(e_theta).any()
    assert not torch.isnan(e_phi).any()


def test_basis_functions_validate_last_dimension() -> None:
    """Wrong trailing dimensions must raise a ValueError naming the argument."""
    with pytest.raises(ValueError, match="r_hat"):
        polar_and_dipole_basis(torch.randn(2, 3), torch.randn(2, 2))

    with pytest.raises(ValueError, match="n_hat"):
        spherical_basis(torch.randn(2, 3), torch.randn(2, 4))


def test_vector_project_is_orthogonal_to_normal() -> None:
    """The projection must have zero dot product with the unit normal."""
    vecs = torch.randn(8, 3, dtype=torch.float32)
    normals = torch.randn(8, 3, dtype=torch.float32)
    normals = normals / torch.linalg.norm(normals, dim=-1, keepdim=True)

    residual = (vector_project(vecs, normals) * normals).sum(dim=-1)

    torch.testing.assert_close(
        residual, torch.zeros_like(residual), atol=1e-6, rtol=1e-6
    )
Validates that the ``na1d``, ``na2d``, and ``na3d`` wrappers: @@ -79,7 +79,7 @@ def _sdpa_reference(q, k, v): @requires_module("natten") class TestNA1D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na1d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.transformer.natten.na1d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) @pytest.mark.parametrize("dilation", [1, 2]) @@ -87,7 +87,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): """Wrapper output must be identical to ``natten.functional.na1d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d B, L, H, D = 2, 16, 4, 8 q = torch.randn(B, L, H, D, device=device) @@ -101,7 +101,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d B, L, H, D = 1, 12, 2, 16 q = torch.randn(B, L, H, D, device=device) @@ -110,7 +110,7 @@ def test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d B, L, H, D = 1, 12, 2, 8 q = torch.randn(B, L, H, D, device=device, requires_grad=True) @@ -126,7 +126,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d B, L, H, D = 1, 8, 2, 8 q = torch.randn(B, L, H, D, device=device).as_subclass(_DispatchRecorder) @@ -140,7 +140,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel 
covers the entire sequence, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.transformer.natten import na1d B, L, H, D = 2, 7, 2, 8 q = torch.randn(B, L, H, D, device=device, dtype=torch.float32) @@ -161,7 +161,7 @@ def test_full_window_matches_sdpa(self, device): @requires_module("natten") class TestNA2D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na2d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.transformer.natten.na2d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) @pytest.mark.parametrize("dilation", [1, 2]) @@ -169,7 +169,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): """Wrapper output must be identical to ``natten.functional.na2d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d B, Ht, W, H, D = 2, 16, 16, 4, 8 q = torch.randn(B, Ht, W, H, D, device=device) @@ -183,7 +183,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d B, Ht, W, H, D = 1, 6, 6, 2, 16 q = torch.randn(B, Ht, W, H, D, device=device) @@ -192,7 +192,7 @@ def test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d B, Ht, W, H, D = 1, 6, 6, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device, requires_grad=True) @@ -208,7 +208,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na2d + from 
physicsnemo.nn.functional.transformer.natten import na2d B, Ht, W, H, D = 1, 4, 4, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device).as_subclass(_DispatchRecorder) @@ -222,7 +222,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel covers the full spatial extent, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.transformer.natten import na2d B, Ht, W, H, D = 2, 5, 5, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device, dtype=torch.float32) @@ -242,14 +242,14 @@ def test_full_window_matches_sdpa(self, device): @requires_module("natten") class TestNA3D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na3d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.transformer.natten.na3d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) def test_matches_natten_directly(self, device, kernel_size): """Wrapper output must be identical to ``natten.functional.na3d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d B, X, Y, Z, H, D = 1, 16, 16, 16, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device) @@ -263,7 +263,7 @@ def test_matches_natten_directly(self, device, kernel_size): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device) @@ -272,7 +272,7 @@ def test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device, requires_grad=True) @@ -288,7 
+288,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device).as_subclass(_DispatchRecorder) @@ -302,7 +302,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel covers the full spatial extent, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.transformer.natten import na3d B, X, Y, Z, H, D = 1, 5, 5, 5, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device, dtype=torch.float32)