diff --git a/examples/minimal/inference/torch_trt_warp_inference.py b/examples/minimal/inference/torch_trt_warp_inference.py new file mode 100644 index 0000000000..f9aaed0561 --- /dev/null +++ b/examples/minimal/inference/torch_trt_warp_inference.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Optimized Physics-AI Inference with TensorRT and Warp +-------------------------------------------------------------- +This example demonstrates a hybrid inference pipeline that leverages: +1. NVIDIA Warp for high-performance geometric processing (neighbor search). +2. TensorRT (via Torch-TensorRT) for accelerated neural network execution. + +The model is a simplified point-cloud processor that finds neighbors using Warp +and then processes the local geometry using a TensorRT-optimized MLP. +""" + +import time +import torch +import torch.nn as nn +import warp as wp +import numpy as np +from physicsnemo.utils.inference import compile_to_trt, is_trt_available +from physicsnemo.models.figconvnet.warp_neighbor_search import radius_search_warp + +# 1. 
Define a Simple Neural Network Module +class GeometryProcessor(nn.Module): + def __init__(self, input_dim=3, hidden_dim=64, output_dim=32): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, x): + return self.net(x) + +def run_example(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print("CUDA is not available. This example requires a GPU for Warp and TensorRT.") + return + + # Initialize Warp + wp.init() + + # 2. Setup Data + num_points = 10000 + points = torch.randn(num_points, 3, device=device) + queries = torch.randn(1000, 3, device=device) + radius = 0.1 + + # 3. Geometric Processing with Warp (Neighbor Search) + print(f"Finding neighbors for {queries.shape[0]} queries in {points.shape[0]} points using Warp...") + start_time = time.time() + # neighbor_index: [total_neighbors], neighbor_dist: [total_neighbors], neighbor_offset: [num_queries + 1] + neighbor_index, neighbor_dist, neighbor_offset = radius_search_warp( + points, queries, radius, device=device.type + ) + wp_time = time.time() - start_time + print(f"Warp neighbor search took: {wp_time:.4f}s") + + # 4. Neural Network Optimization with TensorRT + model = GeometryProcessor(input_dim=3).to(device).eval() + + # Example input for TensorRT compilation signature + # We'll process one neighbor at a time or in batch. + # For simplicity, let's assume we process the relative coordinates of all neighbors. 
+ example_input = torch.randn(1, 3, device=device) + + if is_trt_available(): + print("Compiling GeometryProcessor to TensorRT...") + try: + trt_model = compile_to_trt( + model, + input_signature=[example_input], + enabled_precisions={torch.float32} + ) + process_func = trt_model + except Exception as e: + print(f"TensorRT compilation failed, falling back to eager mode: {e}") + process_func = model + else: + print("Torch-TensorRT not found, using eager PyTorch.") + process_func = model + + # 5. Hybrid Inference Loop + print("Running hybrid inference...") + # For the sake of the example, we just process the first query's neighbors + q_idx = 0 + start_idx = neighbor_offset[q_idx].item() + end_idx = neighbor_offset[q_idx+1].item() + + neighbor_indices = neighbor_index[start_idx:end_idx] + if len(neighbor_indices) > 0: + neighbor_coords = points[neighbor_indices] + relative_coords = neighbor_coords - queries[q_idx] + + # Inference using TensorRT-optimized module + with torch.no_grad(): + output = process_func(relative_coords) + + print(f"Query {q_idx} has {len(neighbor_indices)} neighbors.") + print(f"Output features shape: {output.shape}") + else: + print(f"Query {q_idx} has no neighbors within radius {radius}.") + + print("Example completed successfully!") + +if __name__ == "__main__": + run_example() diff --git a/physicsnemo/mesh/README.md b/physicsnemo/mesh/README.md index 601e18753c..5b555eaa3b 100644 --- a/physicsnemo/mesh/README.md +++ b/physicsnemo/mesh/README.md @@ -422,7 +422,7 @@ neighbors of mesh elements (i.e., based on the mesh connectivity,as opposed to Note that these use an efficient sparse (`indices`, `offsets`) encoding of the adjacency relationships, which is used internally for all computations. (See the dedicated -[`physicsnemo.mesh.neighbors._adjacency.py`](physicsnemo/mesh/neighbors/_adjacency.py) +[`physicsnemo.mesh.neighbors._adjacency.py`](./neighbors/_adjacency.py) module.) 
You can convert these to a typical ragged list-of-lists representation with `.to_list()`, which is useful for debugging or interoperability, at the cost of performance: @@ -540,35 +540,35 @@ Key design decisions enable these principles: ## Documentation & Resources -- **Examples**: See [`examples/`](examples/) directory for runnable demonstrations -- **Tests**: See [`test/`](test/) directory for comprehensive test suite showing usage +- **Examples**: See [`examples/`](../../examples/) directory for runnable demonstrations +- **Tests**: See [`test/`](../../test/mesh/) directory for comprehensive test suite showing usage patterns -- **Source**: Explore [`physicsnemo/mesh/`](physicsnemo/mesh/) for implementation details +- **Source**: Explore [`physicsnemo/mesh/`](./) for implementation details **Module Organization:** -- [`physicsnemo.mesh.calculus`](physicsnemo/mesh/calculus/) - Discrete differential +- [`physicsnemo.mesh.calculus`](./calculus/) - Discrete differential operators -- [`physicsnemo.mesh.curvature`](physicsnemo/mesh/curvature/) - Gaussian and mean +- [`physicsnemo.mesh.curvature`](./curvature/) - Gaussian and mean curvature -- [`physicsnemo.mesh.subdivision`](physicsnemo/mesh/subdivision/) - Mesh refinement +- [`physicsnemo.mesh.subdivision`](./subdivision/) - Mesh refinement schemes -- [`physicsnemo.mesh.boundaries`](physicsnemo/mesh/boundaries/) - Boundary detection +- [`physicsnemo.mesh.boundaries`](./boundaries/) - Boundary detection and facet extraction -- [`physicsnemo.mesh.neighbors`](physicsnemo/mesh/neighbors/) - Adjacency computations -- [`physicsnemo.mesh.spatial`](physicsnemo/mesh/spatial/) - BVH and spatial queries -- [`physicsnemo.mesh.sampling`](physicsnemo/mesh/sampling/) - Point sampling and +- [`physicsnemo.mesh.neighbors`](./neighbors/) - Adjacency computations +- [`physicsnemo.mesh.spatial`](./spatial/) - BVH and spatial queries +- [`physicsnemo.mesh.sampling`](./sampling/) - Point sampling and interpolation -- 
[`physicsnemo.mesh.transformations`](physicsnemo/mesh/transformations/) - Geometric +- [`physicsnemo.mesh.transformations`](./transformations/) - Geometric operations -- [`physicsnemo.mesh.repair`](physicsnemo/mesh/repair/) - Mesh cleaning and topology +- [`physicsnemo.mesh.repair`](./repair/) - Mesh cleaning and topology repair -- [`physicsnemo.mesh.validation`](physicsnemo/mesh/validation/) - Quality metrics +- [`physicsnemo.mesh.validation`](./validation/) - Quality metrics and statistics -- [`physicsnemo.mesh.visualization`](physicsnemo/mesh/visualization/) - Matplotlib +- [`physicsnemo.mesh.visualization`](./visualization/) - Matplotlib and PyVista backends -- [`physicsnemo.mesh.io`](physicsnemo/mesh/io/) - PyVista import/export -- [`physicsnemo.mesh.examples`](physicsnemo/mesh/examples/) - Example mesh generators +- [`physicsnemo.mesh.io`](./io/) - PyVista import/export +- [`physicsnemo.mesh.examples`](./examples/) - Example mesh generators --- diff --git a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py index 6698be028d..4afc17913f 100644 --- a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py @@ -53,7 +53,7 @@ def count_neighbors( grid: wp.HashGrid, wp_points: wp.array(dtype=wp.vec3), wp_queries: wp.array(dtype=wp.vec3), - wp_launch_device: wp.context.Device | None, + wp_launch_device: wp.Device | None, wp_launch_stream: wp.Stream | None, radius: float, N_queries: int, @@ -65,7 +65,7 @@ def count_neighbors( grid (wp.HashGrid): The hash grid to use for the search. wp_points (wp.array): The points to search in, as a warp array. wp_queries (wp.array): The queries to search for, as a warp array. - wp_launch_device (wp.context.Device | None): The device to launch the kernel on. + wp_launch_device (wp.Device | None): The device to launch the kernel on. 
wp_launch_stream (wp.Stream | None): The stream to launch the kernel on. radius (float): The radius that bounds the search. N_queries (int): Total number of query points. @@ -114,7 +114,7 @@ def gather_neighbors( wp_points: wp.array(dtype=wp.vec3), wp_queries: wp.array(dtype=wp.vec3), wp_offset: wp.array(dtype=wp.int32), - wp_launch_device: wp.context.Device | None, + wp_launch_device: wp.Device | None, wp_launch_stream: wp.Stream | None, radius: float, N_queries: int, @@ -131,7 +131,7 @@ def gather_neighbors( wp_points (wp.array): The points to search in, as a warp array. wp_queries (wp.array): The queries to search for, as a warp array. wp_offset (wp.array): The offset in output for each input point, as a warp array. - wp_launch_device (wp.context.Device | None): The device to launch the kernel on. + wp_launch_device (wp.Device | None): The device to launch the kernel on. wp_launch_stream (wp.Stream | None): The stream to launch the kernel on. radius (float): The radius that bounds the search. N_queries (int): Total number of query points. diff --git a/physicsnemo/utils/__init__.py b/physicsnemo/utils/__init__.py index 6a490d4fc3..57cd6b1791 100644 --- a/physicsnemo/utils/__init__.py +++ b/physicsnemo/utils/__init__.py @@ -26,3 +26,4 @@ ) from .logging import LaunchLogger, PythonLogger, RankZeroLoggingWrapper from .profiling import Profiler +from .inference import compile_to_trt, is_trt_available diff --git a/physicsnemo/utils/inference.py b/physicsnemo/utils/inference.py new file mode 100644 index 0000000000..1a190d949d --- /dev/null +++ b/physicsnemo/utils/inference.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, List, Optional, Set, Union + +import torch + +try: + import torch_tensorrt + _TORCH_TRT_AVAILABLE = True +except ImportError: + _TORCH_TRT_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +def compile_to_trt( + model: torch.nn.Module, + input_signature: List[torch.Tensor], + enabled_precisions: Optional[Set[torch.dtype]] = None, + workspace_size: int = 1 << 30, + min_block_size: int = 3, + **kwargs: Any, +) -> torch.nn.Module: + """Compile a PyTorch module to TensorRT for optimized inference. + + This utility provides a high-level wrapper around Torch-TensorRT to optimize + PhysicsNeMo models. It handles standard compilation parameters and provides + graceful fallbacks. + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to compile. + input_signature : List[torch.Tensor] + A list of example input tensors that define the input shapes and types. + enabled_precisions : Set[torch.dtype], optional + Set of precisions to enable (e.g., {torch.float32, torch.float16}). + Defaults to {torch.float32}. + workspace_size : int, optional + Maximum workspace size for TensorRT in bytes, by default 1GB. + min_block_size : int, optional + Minimum number of operators in a sub-graph to be converted to TensorRT, + by default 3. + **kwargs : Any + Additional arguments passed to torch_tensorrt.compile. + + Returns + ------- + torch.nn.Module + The compiled TensorRT-optimized model. + + Raises + ------ + ImportError + If torch_tensorrt is not installed. 
+ """ + if not _TORCH_TRT_AVAILABLE: + raise ImportError( + "torch_tensorrt is required for TensorRT compilation. " + "Please install it using 'pip install torch-tensorrt'." + ) + + if enabled_precisions is None: + enabled_precisions = {torch.float32} + + logger.info(f"Compiling model {model.__class__.__name__} to TensorRT...") + + # Set up compilation arguments for Torch-TRT + compile_spec = { + "inputs": input_signature, + "enabled_precisions": enabled_precisions, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + **kwargs, + } + + try: + # Use torch.compile with tensorrt backend if using PyTorch 2.x style + # or fall back to torch_tensorrt.compile for explicit conversion. + # Here we prefer the explicit torch_tensorrt.compile for better control + # over the conversion process in static inference scenarios. + trt_model = torch_tensorrt.compile(model, **compile_spec) + logger.info("TensorRT compilation successful.") + return trt_model + except Exception as e: + logger.error(f"TensorRT compilation failed: {e}") + raise e + + +def is_trt_available() -> bool: + """Check if TensorRT support is available in the current environment. + + Returns + ------- + bool + True if torch_tensorrt is installed, False otherwise. + """ + return _TORCH_TRT_AVAILABLE