Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions benchmarks/physicsnemo/nn/functional/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@

"""Registry of FunctionSpec classes to benchmark with ASV."""

from physicsnemo.nn.functional.equivariant_ops import (
LegendrePolynomials,
PolarAndDipoleBasis,
SmoothLog,
SphericalBasis,
VectorProject,
)
from physicsnemo.nn.functional.fourier_spectral import (
IRFFT,
IRFFT2,
Expand All @@ -32,13 +39,24 @@
DropPath,
WeightFact,
)
from physicsnemo.nn.functional.transformer import NA1D, NA2D, NA3D

# FunctionSpec classes listed here must implement ``make_inputs_forward`` for ASV.
# ``make_inputs_backward`` is optional and only used when backward benchmarks run.
FUNCTIONAL_SPECS = (
# Regularization / parameterization.
DropPath,
WeightFact,
# Equivariant ops.
SmoothLog,
LegendrePolynomials,
VectorProject,
PolarAndDipoleBasis,
SphericalBasis,
# Neighborhood attention.
NA1D,
NA2D,
NA3D,
# Neighbor queries.
KNN,
RadiusSearch,
Expand Down
12 changes: 12 additions & 0 deletions docs/api/nn/functionals/equivariant_ops.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Equivariant Ops Functionals
===========================

.. autofunction:: physicsnemo.nn.functional.smooth_log

.. autofunction:: physicsnemo.nn.functional.legendre_polynomials

.. autofunction:: physicsnemo.nn.functional.vector_project

.. autofunction:: physicsnemo.nn.functional.polar_and_dipole_basis

.. autofunction:: physicsnemo.nn.functional.spherical_basis
1 change: 1 addition & 0 deletions docs/api/physicsnemo.nn.functionals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ in the documentation for performance comparisons.
:name: PhysicsNeMo Functionals

nn/functionals/neighbors
nn/functionals/equivariant_ops
nn/functionals/geometry
nn/functionals/fourier_spectral
nn/functionals/regularization_parameterization
Expand Down
1 change: 1 addition & 0 deletions physicsnemo/datapipes/benchmarks/darcy.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def generate_batch(self) -> None:
],
device=self.device,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Trailing whitespace on blank line

A blank line with a trailing space was introduced here. It is unrelated to the rest of the PR and may fail linting checks.

Suggested change

→ should simply be an empty line with no trailing whitespace.


def __iter__(self) -> Tuple[Tensor, Tensor]:
"""
Expand Down
8 changes: 4 additions & 4 deletions physicsnemo/domain_parallel/shard_utils/natten_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
MissingShardPatch,
UndeterminedShardingError,
)
from physicsnemo.nn.functional.natten import na1d, na2d, na3d
from physicsnemo.nn.functional.transformer.natten import na1d, na2d, na3d

_natten = OptionalImport("natten")
_raw_func_map = {
Expand Down Expand Up @@ -221,9 +221,9 @@ def _natten_wrapper(
r"""Shared wrapper for natten functions to support sharded tensors.

Registered with :meth:`ShardTensor.register_function_handler` so that calls
to :func:`~physicsnemo.nn.functional.natten.na1d`,
:func:`~physicsnemo.nn.functional.natten.na2d`, or
:func:`~physicsnemo.nn.functional.natten.na3d` automatically route through
to :func:`~physicsnemo.nn.functional.transformer.natten.na1d`,
:func:`~physicsnemo.nn.functional.transformer.natten.na2d`, or
:func:`~physicsnemo.nn.functional.transformer.natten.na3d` automatically route through
this handler when any argument is a :class:`ShardTensor`.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion physicsnemo/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from .fourier_spectral import imag, irfft, irfft2, real, rfft, rfft2, view_as_complex
from .geometry import signed_distance_field
from .interpolation import interpolation
from .natten import na1d, na2d, na3d
from .neighbors import knn, radius_search
from .regularization_parameterization import drop_path, weight_fact
from .transformer import na1d, na2d, na3d

__all__ = [
"irfft",
Expand Down
Loading
Loading