-
Notifications
You must be signed in to change notification settings - Fork 647
🚀[FEA]: Add TensorRT compilation utility and hybrid Warp example #1565
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Example: Optimized Physics-AI Inference with TensorRT and Warp | ||
| -------------------------------------------------------------- | ||
| This example demonstrates a hybrid inference pipeline that leverages: | ||
| 1. NVIDIA Warp for high-performance geometric processing (neighbor search). | ||
| 2. TensorRT (via Torch-TensorRT) for accelerated neural network execution. | ||
|
|
||
| The model is a simplified point-cloud processor that finds neighbors using Warp | ||
| and then processes the local geometry using a TensorRT-optimized MLP. | ||
| """ | ||
|
|
||
| import time | ||
| import torch | ||
| import torch.nn as nn | ||
| import warp as wp | ||
| import numpy as np | ||
| from physicsnemo.utils.inference import compile_to_trt, is_trt_available | ||
| from physicsnemo.models.figconvnet.warp_neighbor_search import radius_search_warp | ||
|
|
||
| # 1. Define a Simple Neural Network Module | ||
class GeometryProcessor(nn.Module):
    """Small fully connected network applied to per-point coordinates.

    The network is a stack of three linear layers separated by ReLU
    activations, stored as ``self.net``.

    Parameters
    ----------
    input_dim : int, optional
        Size of the last dimension of the input, by default 3.
    hidden_dim : int, optional
        Width of the two hidden layers, by default 64.
    output_dim : int, optional
        Size of the last dimension of the output, by default 32.
    """

    def __init__(self, input_dim=3, hidden_dim=64, output_dim=32):
        super().__init__()
        # Layers are created in execution order so parameter names
        # (net.0, net.2, net.4) and initialisation order stay fixed.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the resulting features."""
        features = self.net(x)
        return features
|
|
||
def run_example():
    """Run the hybrid Warp + TensorRT inference demonstration.

    The example:

    1. Generates random point and query clouds on the GPU.
    2. Finds, for every query, the points within ``radius`` using the Warp
       based ``radius_search_warp``.
    3. Builds a ``GeometryProcessor`` and, when Torch-TensorRT is available,
       compiles it with ``compile_to_trt`` (falling back to eager PyTorch if
       compilation is unavailable or fails).
    4. Runs the (possibly compiled) network on the relative coordinates of
       the neighbors of the first query and prints the output shape.

    The function prints its progress and returns ``None``. It returns early
    when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available. This example requires a GPU for Warp and TensorRT.")
        return
    device = torch.device("cuda")

    # Initialize Warp
    wp.init()

    # 2. Setup Data
    num_points = 10000
    num_queries = 1000
    points = torch.randn(num_points, 3, device=device)
    queries = torch.randn(num_queries, 3, device=device)
    radius = 0.1

    # 3. Geometric Processing with Warp (Neighbor Search)
    print(f"Finding neighbors for {queries.shape[0]} queries in {points.shape[0]} points using Warp...")
    start_time = time.perf_counter()
    # neighbor_index: [total_neighbors], neighbor_dist: [total_neighbors],
    # neighbor_offset: [num_queries + 1]
    # radius_search_warp takes no ``device`` argument; it uses the device of
    # the input tensors.
    neighbor_index, _neighbor_dist, neighbor_offset = radius_search_warp(
        points, queries, radius
    )
    # GPU work may still be in flight when the call returns; wait for it so
    # the measured time covers the search itself.
    torch.cuda.synchronize(device)
    wp_time = time.perf_counter() - start_time
    print(f"Warp neighbor search took: {wp_time:.4f}s")

    # For the sake of the example, we just process the first query's neighbors.
    # They are gathered before compilation so the example input handed to
    # TensorRT can match the tensor that is actually run through the model.
    q_idx = 0
    start_idx = neighbor_offset[q_idx].item()
    end_idx = neighbor_offset[q_idx + 1].item()
    neighbor_indices = neighbor_index[start_idx:end_idx]
    if len(neighbor_indices) > 0:
        relative_coords = points[neighbor_indices] - queries[q_idx]
    else:
        relative_coords = None

    # 4. Neural Network Optimization with TensorRT
    model = GeometryProcessor(input_dim=3).to(device).eval()

    # Example input for the TensorRT compilation signature. A plain tensor
    # fixes the input shape of the compiled module, so use the same shape as
    # the real batch (one row of relative coordinates per neighbor).
    if relative_coords is not None:
        example_input = torch.randn_like(relative_coords)
    else:
        example_input = torch.randn(1, 3, device=device)

    if is_trt_available():
        print("Compiling GeometryProcessor to TensorRT...")
        try:
            process_func = compile_to_trt(
                model,
                input_signature=[example_input],
                enabled_precisions={torch.float32},
            )
        except Exception as e:
            # Deliberate best-effort: the example should still run end to end
            # in eager mode when TensorRT cannot compile the model.
            print(f"TensorRT compilation failed, falling back to eager mode: {e}")
            process_func = model
    else:
        print("Torch-TensorRT not found, using eager PyTorch.")
        process_func = model

    # 5. Hybrid Inference Loop
    print("Running hybrid inference...")
    if relative_coords is None:
        print(f"Query {q_idx} has no neighbors within radius {radius}.")
    else:
        # Inference using the TensorRT-optimized module (or eager fallback)
        with torch.no_grad():
            output = process_func(relative_coords)

        print(f"Query {q_idx} has {len(neighbor_indices)} neighbors.")
        print(f"Output features shape: {output.shape}")

    print("Example completed successfully!")
|
|
||
| if __name__ == "__main__": | ||
| run_example() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -422,7 +422,7 @@ neighbors of mesh elements (i.e., based on the mesh connectivity,as opposed to | |
| Note that these use an efficient sparse (`indices`, `offsets`) encoding of the | ||
| adjacency relationships, which is used internally for all computations. (See the | ||
| dedicated | ||
| [`physicsnemo.mesh.neighbors._adjacency.py`](physicsnemo/mesh/neighbors/_adjacency.py) | ||
| [`physicsnemo.mesh.neighbors._adjacency.py`](./neighbors/_adjacency.py) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove these unrelated changes, or bring them in with a separate PR. |
||
| module.) You can convert these to a typical ragged list-of-lists representation | ||
| with `.to_list()`, which is useful for debugging or interoperability, at the | ||
| cost of performance: | ||
|
|
@@ -540,35 +540,35 @@ Key design decisions enable these principles: | |
|
|
||
| ## Documentation & Resources | ||
|
|
||
| - **Examples**: See [`examples/`](examples/) directory for runnable demonstrations | ||
| - **Tests**: See [`test/`](test/) directory for comprehensive test suite showing usage | ||
| - **Examples**: See [`examples/`](../../examples/) directory for runnable demonstrations | ||
| - **Tests**: See [`test/`](../../test/mesh/) directory for comprehensive test suite showing usage | ||
| patterns | ||
| - **Source**: Explore [`physicsnemo/mesh/`](physicsnemo/mesh/) for implementation details | ||
| - **Source**: Explore [`physicsnemo/mesh/`](./) for implementation details | ||
|
|
||
| **Module Organization:** | ||
|
|
||
| - [`physicsnemo.mesh.calculus`](physicsnemo/mesh/calculus/) - Discrete differential | ||
| - [`physicsnemo.mesh.calculus`](./calculus/) - Discrete differential | ||
| operators | ||
| - [`physicsnemo.mesh.curvature`](physicsnemo/mesh/curvature/) - Gaussian and mean | ||
| - [`physicsnemo.mesh.curvature`](./curvature/) - Gaussian and mean | ||
| curvature | ||
| - [`physicsnemo.mesh.subdivision`](physicsnemo/mesh/subdivision/) - Mesh refinement | ||
| - [`physicsnemo.mesh.subdivision`](./subdivision/) - Mesh refinement | ||
| schemes | ||
| - [`physicsnemo.mesh.boundaries`](physicsnemo/mesh/boundaries/) - Boundary detection | ||
| - [`physicsnemo.mesh.boundaries`](./boundaries/) - Boundary detection | ||
| and facet extraction | ||
| - [`physicsnemo.mesh.neighbors`](physicsnemo/mesh/neighbors/) - Adjacency computations | ||
| - [`physicsnemo.mesh.spatial`](physicsnemo/mesh/spatial/) - BVH and spatial queries | ||
| - [`physicsnemo.mesh.sampling`](physicsnemo/mesh/sampling/) - Point sampling and | ||
| - [`physicsnemo.mesh.neighbors`](./neighbors/) - Adjacency computations | ||
| - [`physicsnemo.mesh.spatial`](./spatial/) - BVH and spatial queries | ||
| - [`physicsnemo.mesh.sampling`](./sampling/) - Point sampling and | ||
| interpolation | ||
| - [`physicsnemo.mesh.transformations`](physicsnemo/mesh/transformations/) - Geometric | ||
| - [`physicsnemo.mesh.transformations`](./transformations/) - Geometric | ||
| operations | ||
| - [`physicsnemo.mesh.repair`](physicsnemo/mesh/repair/) - Mesh cleaning and topology | ||
| - [`physicsnemo.mesh.repair`](./repair/) - Mesh cleaning and topology | ||
| repair | ||
| - [`physicsnemo.mesh.validation`](physicsnemo/mesh/validation/) - Quality metrics | ||
| - [`physicsnemo.mesh.validation`](./validation/) - Quality metrics | ||
| and statistics | ||
| - [`physicsnemo.mesh.visualization`](physicsnemo/mesh/visualization/) - Matplotlib | ||
| - [`physicsnemo.mesh.visualization`](./visualization/) - Matplotlib | ||
| and PyVista backends | ||
| - [`physicsnemo.mesh.io`](physicsnemo/mesh/io/) - PyVista import/export | ||
| - [`physicsnemo.mesh.examples`](physicsnemo/mesh/examples/) - Example mesh generators | ||
| - [`physicsnemo.mesh.io`](./io/) - PyVista import/export | ||
| - [`physicsnemo.mesh.examples`](./examples/) - Example mesh generators | ||
|
|
||
| --- | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,7 +53,7 @@ def count_neighbors( | |
| grid: wp.HashGrid, | ||
| wp_points: wp.array(dtype=wp.vec3), | ||
| wp_queries: wp.array(dtype=wp.vec3), | ||
| wp_launch_device: wp.context.Device | None, | ||
| wp_launch_device: wp.Device | None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good idea, and fixes a deprecation warning, but is unrelated to the main PR here - please bring these changes in a separate PR. |
||
| wp_launch_stream: wp.Stream | None, | ||
| radius: float, | ||
| N_queries: int, | ||
|
|
@@ -65,7 +65,7 @@ def count_neighbors( | |
| grid (wp.HashGrid): The hash grid to use for the search. | ||
| wp_points (wp.array): The points to search in, as a warp array. | ||
| wp_queries (wp.array): The queries to search for, as a warp array. | ||
| wp_launch_device (wp.context.Device | None): The device to launch the kernel on. | ||
| wp_launch_device (wp.Device | None): The device to launch the kernel on. | ||
| wp_launch_stream (wp.Stream | None): The stream to launch the kernel on. | ||
| radius (float): The radius that bounds the search. | ||
| N_queries (int): Total number of query points. | ||
|
|
@@ -114,7 +114,7 @@ def gather_neighbors( | |
| wp_points: wp.array(dtype=wp.vec3), | ||
| wp_queries: wp.array(dtype=wp.vec3), | ||
| wp_offset: wp.array(dtype=wp.int32), | ||
| wp_launch_device: wp.context.Device | None, | ||
| wp_launch_device: wp.Device | None, | ||
| wp_launch_stream: wp.Stream | None, | ||
| radius: float, | ||
| N_queries: int, | ||
|
|
@@ -131,7 +131,7 @@ def gather_neighbors( | |
| wp_points (wp.array): The points to search in, as a warp array. | ||
| wp_queries (wp.array): The queries to search for, as a warp array. | ||
| wp_offset (wp.array): The offset in output for each input point, as a warp array. | ||
| wp_launch_device (wp.context.Device | None): The device to launch the kernel on. | ||
| wp_launch_device (wp.Device | None): The device to launch the kernel on. | ||
| wp_launch_stream (wp.Stream | None): The stream to launch the kernel on. | ||
| radius (float): The radius that bounds the search. | ||
| N_queries (int): Total number of query points. | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,113 @@ | ||||||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||||||
| # SPDX-FileCopyrightText: All rights reserved. | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| import logging | ||||||
| from typing import Any, List, Optional, Union | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| import torch | ||||||
|
|
||||||
| try: | ||||||
| import torch_tensorrt | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. violates repo-wide optional import conventions; will trip up |
||||||
| _TORCH_TRT_AVAILABLE = True | ||||||
| except ImportError: | ||||||
| _TORCH_TRT_AVAILABLE = False | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
|
|
||||||
def compile_to_trt(
    model: torch.nn.Module,
    input_signature: list[torch.Tensor],
    enabled_precisions: Optional[set[torch.dtype]] = None,
    workspace_size: int = 1 << 30,
    min_block_size: int = 3,
    **kwargs: Any,
) -> torch.nn.Module:
    """Compile a PyTorch module to TensorRT for optimized inference.

    This utility is a thin wrapper around ``torch_tensorrt.compile`` that
    fills in common compilation parameters. It does not fall back to eager
    execution: any error raised during compilation propagates to the caller.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to compile.
    input_signature : list[torch.Tensor]
        A list of example input tensors that define the input shapes and
        types. Passed to ``torch_tensorrt.compile`` as ``inputs``.
    enabled_precisions : set[torch.dtype], optional
        Set of precisions to enable (e.g., {torch.float32, torch.float16}).
        Defaults to {torch.float32}.
    workspace_size : int, optional
        Maximum workspace size for TensorRT in bytes, by default 1GB.
    min_block_size : int, optional
        Minimum number of operators in a sub-graph to be converted to TensorRT,
        by default 3.
    **kwargs : Any
        Additional arguments passed to ``torch_tensorrt.compile``. On a key
        collision these take precedence over the named parameters above.

    Returns
    -------
    torch.nn.Module
        The compiled TensorRT-optimized model.

    Raises
    ------
    ImportError
        If torch_tensorrt is not installed.
    """
    if not _TORCH_TRT_AVAILABLE:
        raise ImportError(
            "torch_tensorrt is required for TensorRT compilation. "
            "Please install it using 'pip install torch-tensorrt'."
        )

    if enabled_precisions is None:
        enabled_precisions = {torch.float32}

    logger.info("Compiling model %s to TensorRT...", model.__class__.__name__)

    # Set up compilation arguments for Torch-TRT. ``kwargs`` is unpacked last
    # so that caller-supplied options override the defaults assembled here.
    compile_spec = {
        "inputs": input_signature,
        "enabled_precisions": enabled_precisions,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        **kwargs,
    }

    # The explicit torch_tensorrt.compile entry point is used (rather than
    # torch.compile with a TensorRT backend) for better control over the
    # conversion in static inference scenarios. Failures are intentionally
    # not caught here so the caller sees the original exception and traceback.
    trt_model = torch_tensorrt.compile(model, **compile_spec)
    logger.info("TensorRT compilation successful.")
    return trt_model
|
|
||||||
|
|
||||||
def is_trt_available() -> bool:
    """Report whether Torch-TensorRT can be used in this environment.

    Returns
    -------
    bool
        The value of the module-level ``_TORCH_TRT_AVAILABLE`` flag: True if
        ``torch_tensorrt`` was imported successfully, False otherwise.
    """
    trt_importable = _TORCH_TRT_AVAILABLE
    return trt_importable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`radius_search_warp` does not accept a `device` keyword argument. The function signature in
`physicsnemo/models/figconvnet/warp_neighbor_search.py` is `radius_search_warp(points, queries, radius, grid_dim=...)` — it derives the device directly from the input tensors and has no `device` parameter. Passing `device=device.type` will raise `TypeError: radius_search_warp() got an unexpected keyword argument 'device'` at runtime.