Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions examples/minimal/inference/torch_trt_warp_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example: Optimized Physics-AI Inference with TensorRT and Warp
--------------------------------------------------------------
This example demonstrates a hybrid inference pipeline that leverages:
1. NVIDIA Warp for high-performance geometric processing (neighbor search).
2. TensorRT (via Torch-TensorRT) for accelerated neural network execution.

The model is a simplified point-cloud processor that finds neighbors using Warp
and then processes the local geometry using a TensorRT-optimized MLP.
"""

import time
import torch
import torch.nn as nn
import warp as wp
import numpy as np
from physicsnemo.utils.inference import compile_to_trt, is_trt_available
from physicsnemo.models.figconvnet.warp_neighbor_search import radius_search_warp

# 1. Define a Simple Neural Network Module
class GeometryProcessor(nn.Module):
    """Small MLP that maps per-point 3D coordinates to feature vectors.

    Parameters
    ----------
    input_dim : int, optional
        Size of each input coordinate vector, by default 3.
    hidden_dim : int, optional
        Width of the two hidden layers, by default 64.
    output_dim : int, optional
        Size of the produced feature vector, by default 32.
    """

    def __init__(self, input_dim=3, hidden_dim=64, output_dim=32):
        super().__init__()
        # Two ReLU-activated hidden layers followed by a linear projection.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to a batch of coordinate vectors."""
        return self.net(x)

def run_example():
    """Run the hybrid Warp + TensorRT inference example.

    Performs a radius-based neighbor search on a random point cloud using
    Warp, then processes the neighbors of one query point with an MLP that
    is compiled to TensorRT when Torch-TensorRT is available (falling back
    to eager PyTorch otherwise). Requires a CUDA device; prints a message
    and returns early on CPU-only machines.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cpu":
        print("CUDA is not available. This example requires a GPU for Warp and TensorRT.")
        return

    # Initialize Warp
    wp.init()

    # 2. Setup Data
    num_points = 10000
    points = torch.randn(num_points, 3, device=device)
    queries = torch.randn(1000, 3, device=device)
    radius = 0.1

    # 3. Geometric Processing with Warp (Neighbor Search)
    print(f"Finding neighbors for {queries.shape[0]} queries in {points.shape[0]} points using Warp...")
    start_time = time.time()
    # neighbor_index: [total_neighbors], neighbor_dist: [total_neighbors], neighbor_offset: [num_queries + 1]
    # NOTE: radius_search_warp derives the device from its input tensors and
    # accepts no `device` keyword; passing one raises a TypeError.
    neighbor_index, neighbor_dist, neighbor_offset = radius_search_warp(
        points, queries, radius
    )
    wp_time = time.time() - start_time
    print(f"Warp neighbor search took: {wp_time:.4f}s")

    # 4. Neural Network Optimization with TensorRT
    model = GeometryProcessor(input_dim=3).to(device).eval()

    # Example input for TensorRT compilation signature
    # We'll process one neighbor at a time or in batch.
    # For simplicity, let's assume we process the relative coordinates of all neighbors.
    example_input = torch.randn(1, 3, device=device)

    if is_trt_available():
        print("Compiling GeometryProcessor to TensorRT...")
        try:
            trt_model = compile_to_trt(
                model,
                input_signature=[example_input],
                enabled_precisions={torch.float32}
            )
            process_func = trt_model
        except Exception as e:
            # Best-effort optimization: fall back to the eager model rather
            # than aborting the example on a compilation failure.
            print(f"TensorRT compilation failed, falling back to eager mode: {e}")
            process_func = model
    else:
        print("Torch-TensorRT not found, using eager PyTorch.")
        process_func = model

    # 5. Hybrid Inference Loop
    print("Running hybrid inference...")
    # For the sake of the example, we just process the first query's neighbors
    q_idx = 0
    start_idx = neighbor_offset[q_idx].item()
    end_idx = neighbor_offset[q_idx + 1].item()

    neighbor_indices = neighbor_index[start_idx:end_idx]
    if len(neighbor_indices) > 0:
        neighbor_coords = points[neighbor_indices]
        relative_coords = neighbor_coords - queries[q_idx]

        # Inference using TensorRT-optimized module
        with torch.no_grad():
            output = process_func(relative_coords)

        print(f"Query {q_idx} has {len(neighbor_indices)} neighbors.")
        print(f"Output features shape: {output.shape}")
    else:
        print(f"Query {q_idx} has no neighbors within radius {radius}.")

    print("Example completed successfully!")

if __name__ == "__main__":
    run_example()
34 changes: 17 additions & 17 deletions physicsnemo/mesh/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ neighbors of mesh elements (i.e., based on the mesh connectivity, as opposed to
Note that these use an efficient sparse (`indices`, `offsets`) encoding of the
adjacency relationships, which is used internally for all computations. (See the
dedicated
[`physicsnemo.mesh.neighbors._adjacency.py`](physicsnemo/mesh/neighbors/_adjacency.py)
[`physicsnemo.mesh.neighbors._adjacency.py`](./neighbors/_adjacency.py)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove these unrelated changes, or bring them in with a separate PR.

module.) You can convert these to a typical ragged list-of-lists representation
with `.to_list()`, which is useful for debugging or interoperability, at the
cost of performance:
Expand Down Expand Up @@ -540,35 +540,35 @@ Key design decisions enable these principles:

## Documentation & Resources

- **Examples**: See [`examples/`](examples/) directory for runnable demonstrations
- **Tests**: See [`test/`](test/) directory for comprehensive test suite showing usage
- **Examples**: See [`examples/`](../../examples/) directory for runnable demonstrations
- **Tests**: See [`test/`](../../test/mesh/) directory for comprehensive test suite showing usage
patterns
- **Source**: Explore [`physicsnemo/mesh/`](physicsnemo/mesh/) for implementation details
- **Source**: Explore [`physicsnemo/mesh/`](./) for implementation details

**Module Organization:**

- [`physicsnemo.mesh.calculus`](physicsnemo/mesh/calculus/) - Discrete differential
- [`physicsnemo.mesh.calculus`](./calculus/) - Discrete differential
operators
- [`physicsnemo.mesh.curvature`](physicsnemo/mesh/curvature/) - Gaussian and mean
- [`physicsnemo.mesh.curvature`](./curvature/) - Gaussian and mean
curvature
- [`physicsnemo.mesh.subdivision`](physicsnemo/mesh/subdivision/) - Mesh refinement
- [`physicsnemo.mesh.subdivision`](./subdivision/) - Mesh refinement
schemes
- [`physicsnemo.mesh.boundaries`](physicsnemo/mesh/boundaries/) - Boundary detection
- [`physicsnemo.mesh.boundaries`](./boundaries/) - Boundary detection
and facet extraction
- [`physicsnemo.mesh.neighbors`](physicsnemo/mesh/neighbors/) - Adjacency computations
- [`physicsnemo.mesh.spatial`](physicsnemo/mesh/spatial/) - BVH and spatial queries
- [`physicsnemo.mesh.sampling`](physicsnemo/mesh/sampling/) - Point sampling and
- [`physicsnemo.mesh.neighbors`](./neighbors/) - Adjacency computations
- [`physicsnemo.mesh.spatial`](./spatial/) - BVH and spatial queries
- [`physicsnemo.mesh.sampling`](./sampling/) - Point sampling and
interpolation
- [`physicsnemo.mesh.transformations`](physicsnemo/mesh/transformations/) - Geometric
- [`physicsnemo.mesh.transformations`](./transformations/) - Geometric
operations
- [`physicsnemo.mesh.repair`](physicsnemo/mesh/repair/) - Mesh cleaning and topology
- [`physicsnemo.mesh.repair`](./repair/) - Mesh cleaning and topology
repair
- [`physicsnemo.mesh.validation`](physicsnemo/mesh/validation/) - Quality metrics
- [`physicsnemo.mesh.validation`](./validation/) - Quality metrics
and statistics
- [`physicsnemo.mesh.visualization`](physicsnemo/mesh/visualization/) - Matplotlib
- [`physicsnemo.mesh.visualization`](./visualization/) - Matplotlib
and PyVista backends
- [`physicsnemo.mesh.io`](physicsnemo/mesh/io/) - PyVista import/export
- [`physicsnemo.mesh.examples`](physicsnemo/mesh/examples/) - Example mesh generators
- [`physicsnemo.mesh.io`](./io/) - PyVista import/export
- [`physicsnemo.mesh.examples`](./examples/) - Example mesh generators

---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def count_neighbors(
grid: wp.HashGrid,
wp_points: wp.array(dtype=wp.vec3),
wp_queries: wp.array(dtype=wp.vec3),
wp_launch_device: wp.context.Device | None,
wp_launch_device: wp.Device | None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good idea, and fixes a deprecation warning, but is unrelated to the main PR here - please bring these changes in a separate PR.

wp_launch_stream: wp.Stream | None,
radius: float,
N_queries: int,
Expand All @@ -65,7 +65,7 @@ def count_neighbors(
grid (wp.HashGrid): The hash grid to use for the search.
wp_points (wp.array): The points to search in, as a warp array.
wp_queries (wp.array): The queries to search for, as a warp array.
wp_launch_device (wp.context.Device | None): The device to launch the kernel on.
wp_launch_device (wp.Device | None): The device to launch the kernel on.
wp_launch_stream (wp.Stream | None): The stream to launch the kernel on.
radius (float): The radius that bounds the search.
N_queries (int): Total number of query points.
Expand Down Expand Up @@ -114,7 +114,7 @@ def gather_neighbors(
wp_points: wp.array(dtype=wp.vec3),
wp_queries: wp.array(dtype=wp.vec3),
wp_offset: wp.array(dtype=wp.int32),
wp_launch_device: wp.context.Device | None,
wp_launch_device: wp.Device | None,
wp_launch_stream: wp.Stream | None,
radius: float,
N_queries: int,
Expand All @@ -131,7 +131,7 @@ def gather_neighbors(
wp_points (wp.array): The points to search in, as a warp array.
wp_queries (wp.array): The queries to search for, as a warp array.
wp_offset (wp.array): The offset in output for each input point, as a warp array.
wp_launch_device (wp.context.Device | None): The device to launch the kernel on.
wp_launch_device (wp.Device | None): The device to launch the kernel on.
wp_launch_stream (wp.Stream | None): The stream to launch the kernel on.
radius (float): The radius that bounds the search.
N_queries (int): Total number of query points.
Expand Down
1 change: 1 addition & 0 deletions physicsnemo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
)
from .logging import LaunchLogger, PythonLogger, RankZeroLoggingWrapper
from .profiling import Profiler
from .inference import compile_to_trt, is_trt_available
113 changes: 113 additions & 0 deletions physicsnemo/utils/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, List, Optional, Union
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Missing Set import causes NameError at import time

Set is used in the type annotation on line 34 (Optional[Set[torch.dtype]]) but is not imported from typing. Because Python evaluates function annotations eagerly at definition time (without from __future__ import annotations), importing this module will immediately raise NameError: name 'Set' is not defined. Union is also imported but never used.

Suggested change
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Set

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unused Union import

Union is imported but never referenced in this file. It can be removed to keep the imports clean.


import torch

try:
import torch_tensorrt
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

violates repo-wide optional import conventions; will trip up importlinter

_TORCH_TRT_AVAILABLE = True
except ImportError:
_TORCH_TRT_AVAILABLE = False

logger = logging.getLogger(__name__)


def compile_to_trt(
    model: torch.nn.Module,
    input_signature: list[torch.Tensor],
    enabled_precisions: set[torch.dtype] | None = None,
    workspace_size: int = 1 << 30,
    min_block_size: int = 3,
    **kwargs: Any,
) -> torch.nn.Module:
    """Compile a PyTorch module to TensorRT for optimized inference.

    This utility provides a high-level wrapper around Torch-TensorRT to optimize
    PhysicsNeMo models. It handles standard compilation parameters and provides
    graceful fallbacks.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to compile.
    input_signature : list[torch.Tensor]
        A list of example input tensors that define the input shapes and types.
    enabled_precisions : set[torch.dtype], optional
        Set of precisions to enable (e.g., {torch.float32, torch.float16}).
        Defaults to {torch.float32}.
    workspace_size : int, optional
        Maximum workspace size for TensorRT in bytes, by default 1GB.
    min_block_size : int, optional
        Minimum number of operators in a sub-graph to be converted to TensorRT,
        by default 3.
    **kwargs : Any
        Additional arguments passed to torch_tensorrt.compile.

    Returns
    -------
    torch.nn.Module
        The compiled TensorRT-optimized model.

    Raises
    ------
    ImportError
        If torch_tensorrt is not installed.
    """
    # NOTE: builtin-generic / ``X | None`` annotations are used above instead
    # of ``typing.List``/``Optional[Set[...]]`` — the previous annotation
    # referenced ``Set`` without importing it, which raised a NameError as
    # soon as this module was imported.
    if not _TORCH_TRT_AVAILABLE:
        raise ImportError(
            "torch_tensorrt is required for TensorRT compilation. "
            "Please install it using 'pip install torch-tensorrt'."
        )

    if enabled_precisions is None:
        # Mutable default handled via None sentinel to avoid sharing the set.
        enabled_precisions = {torch.float32}

    logger.info(f"Compiling model {model.__class__.__name__} to TensorRT...")

    # Set up compilation arguments for Torch-TRT
    compile_spec = {
        "inputs": input_signature,
        "enabled_precisions": enabled_precisions,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        **kwargs,
    }

    try:
        # Use torch.compile with tensorrt backend if using PyTorch 2.x style
        # or fall back to torch_tensorrt.compile for explicit conversion.
        # Here we prefer the explicit torch_tensorrt.compile for better control
        # over the conversion process in static inference scenarios.
        trt_model = torch_tensorrt.compile(model, **compile_spec)
        logger.info("TensorRT compilation successful.")
        return trt_model
    except Exception as e:
        logger.error(f"TensorRT compilation failed: {e}")
        # Bare re-raise preserves the original traceback (``raise e`` would
        # re-anchor it at this line).
        raise


def is_trt_available() -> bool:
    """Report whether Torch-TensorRT support is present.

    Returns
    -------
    bool
        True when the optional ``torch_tensorrt`` package was importable at
        module load time, False otherwise.
    """
    return _TORCH_TRT_AVAILABLE