Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
JointType,
Model,
ModelBuilder,
SpeculativeContactConfig,
State,
eval_fk,
eval_ik,
Expand All @@ -73,6 +74,7 @@
"JointType",
"Model",
"ModelBuilder",
"SpeculativeContactConfig",
"State",
"eval_fk",
"eval_ik",
Expand Down
64 changes: 60 additions & 4 deletions newton/_src/geometry/narrow_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def write_contact_simple(
writer_data.contact_tangent[index] = wp.normalize(world_x - wp.dot(world_x, normal) * normal)


def create_narrow_phase_primitive_kernel(writer_func: Any):
def create_narrow_phase_primitive_kernel(writer_func: Any, speculative: bool = False):
"""
Create a kernel for fast analytical collision detection of primitive shapes.

Expand All @@ -136,6 +136,8 @@ def create_narrow_phase_primitive_kernel(writer_func: Any):

Args:
writer_func: Contact writer function (e.g., write_contact_simple)
speculative: When True, the kernel reads per-shape velocity arrays and
extends ``gap_sum`` by a scalar speculative margin.

Returns:
A warp kernel for primitive collision detection
Expand All @@ -153,6 +155,10 @@ def narrow_phase_primitive_kernel(
shape_flags: wp.array[wp.int32],
writer_data: Any,
total_num_threads: int,
shape_lin_vel: wp.array[wp.vec3],
shape_ang_speed_bound: wp.array[float],
speculative_dt: float,
max_speculative_extension: float,
# Output: pairs that need GJK/MPR processing
gjk_candidate_pairs: wp.array[wp.vec2i],
gjk_candidate_pairs_count: wp.array[int],
Expand Down Expand Up @@ -233,6 +239,11 @@ def narrow_phase_primitive_kernel(
gap_b = shape_gap[shape_b]
gap_sum = gap_a + gap_b

if wp.static(speculative):
vel_rel = shape_lin_vel[shape_b] - shape_lin_vel[shape_a]
rel_speed = wp.length(vel_rel) + shape_ang_speed_bound[shape_a] + shape_ang_speed_bound[shape_b]
gap_sum = gap_sum + wp.min(rel_speed * speculative_dt, max_speculative_extension)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# =====================================================================
# Route heightfield pairs.
# Heightfield-vs-mesh and heightfield-vs-heightfield go through the
Expand Down Expand Up @@ -593,7 +604,11 @@ def narrow_phase_primitive_kernel(


def create_narrow_phase_kernel_gjk_mpr(
external_aabb: bool, writer_func: Any, support_func: Any = None, post_process_contact: Any = None
external_aabb: bool,
writer_func: Any,
support_func: Any = None,
post_process_contact: Any = None,
speculative: bool = False,
):
"""
Create a GJK/MPR narrow phase kernel for complex convex shape collisions.
Expand All @@ -606,6 +621,10 @@ def create_narrow_phase_kernel_gjk_mpr(

The remaining pairs are complex convex-convex (plane-box, plane-cylinder,
plane-cone, box-box, cylinder-cylinder, etc.) that need GJK/MPR.

Args:
speculative: When True, extends ``gap_sum`` by a scalar speculative
margin derived from per-shape velocity arrays.
"""

@wp.kernel(enable_backward=False, module="unique")
Expand All @@ -622,6 +641,10 @@ def narrow_phase_kernel_gjk_mpr(
shape_aabb_upper: wp.array[wp.vec3],
writer_data: Any,
total_num_threads: int,
shape_lin_vel: wp.array[wp.vec3],
shape_ang_speed_bound: wp.array[float],
speculative_dt: float,
max_speculative_extension: float,
):
"""
GJK/MPR collision detection for complex convex pairs.
Expand Down Expand Up @@ -737,6 +760,11 @@ def narrow_phase_kernel_gjk_mpr(
gap_b = shape_gap[shape_b]
gap_sum = gap_a + gap_b

if wp.static(speculative):
vel_rel = shape_lin_vel[shape_b] - shape_lin_vel[shape_a]
rel_speed = wp.length(vel_rel) + shape_ang_speed_bound[shape_a] + shape_ang_speed_bound[shape_b]
gap_sum = gap_sum + wp.min(rel_speed * speculative_dt, max_speculative_extension)

# Find and write contacts using GJK/MPR
wp.static(
create_find_contacts(writer_func, support_func=support_func, post_process_contact=post_process_contact)
Expand Down Expand Up @@ -1390,6 +1418,7 @@ def __init__(
has_meshes: bool = True,
has_heightfields: bool = False,
use_lean_gjk_mpr: bool = False,
speculative: bool = False,
) -> None:
"""
Initialize NarrowPhase with pre-allocated buffers.
Expand All @@ -1413,6 +1442,10 @@ def __init__(
Defaults to True for safety. Set to False when constructing from a model with no meshes.
has_heightfields: Whether the scene contains any heightfield shapes (GeoType.HFIELD). When True,
heightfield collision buffers and kernels are allocated. Defaults to False.
speculative: Enable speculative contact support in narrow-phase kernels.
When True, kernel variants that read per-shape velocity arrays and
extend gap thresholds are compiled. When False (default), the
speculative code paths are eliminated at compile time.
"""
self.max_candidate_pairs = max_candidate_pairs
self.max_triangle_pairs = max_triangle_pairs
Expand Down Expand Up @@ -1460,9 +1493,11 @@ def __init__(
self.tile_size_mesh_plane = 512
self.block_dim = 128

self.speculative = speculative

# Create the appropriate kernel variants
# Primitive kernel handles lightweight primitives and routes remaining pairs
self.primitive_kernel = create_narrow_phase_primitive_kernel(writer_func)
self.primitive_kernel = create_narrow_phase_primitive_kernel(writer_func, speculative=speculative)
# GJK/MPR kernel handles remaining convex-convex pairs
if use_lean_gjk_mpr:
# Use lean support function (CONVEX_MESH, BOX, SPHERE only) and lean post-processing
Expand All @@ -1472,9 +1507,12 @@ def __init__(
writer_func,
support_func=support_map_lean,
post_process_contact=post_process_minkowski_only,
speculative=speculative,
)
else:
self.narrow_phase_kernel = create_narrow_phase_kernel_gjk_mpr(self.external_aabb, writer_func)
self.narrow_phase_kernel = create_narrow_phase_kernel_gjk_mpr(
self.external_aabb, writer_func, speculative=speculative
)
# Create triangle contacts kernel when meshes or heightfields are present
if has_meshes or has_heightfields:
self.mesh_triangle_contacts_kernel = create_narrow_phase_process_mesh_triangle_contacts_kernel(writer_func)
Expand Down Expand Up @@ -1657,6 +1695,10 @@ def launch_custom_write(
shape_edge_range: wp.array[wp.vec2i] | None = None,
writer_data: Any,
device: Devicelike | None = None, # Device to launch on
shape_lin_vel: wp.array[wp.vec3] | None = None,
shape_ang_speed_bound: wp.array[wp.float32] | None = None,
speculative_dt: float = 0.0,
max_speculative_extension: float = 0.0,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> None:
"""
Launch narrow phase collision detection with a custom contact writer struct.
Expand Down Expand Up @@ -1690,6 +1732,12 @@ def launch_custom_write(
# Clear all counters with a single kernel launch (consolidated counter array)
self._counter_array.zero_()

# Resolve speculative velocity arrays (empty when disabled)
_empty_vec3 = wp.empty(0, dtype=wp.vec3, device=device)
_empty_float = wp.empty(0, dtype=wp.float32, device=device)
_slv = shape_lin_vel if shape_lin_vel is not None else _empty_vec3
_sasb = shape_ang_speed_bound if shape_ang_speed_bound is not None else _empty_float

# Stage 1: Launch primitive kernel for fast analytical collisions
# This handles sphere-sphere, sphere-capsule, capsule-capsule, plane-sphere, plane-capsule
# and routes remaining pairs to gjk_candidate_pairs and mesh buffers
Expand All @@ -1707,6 +1755,10 @@ def launch_custom_write(
shape_flags,
writer_data,
self.total_num_threads,
_slv,
_sasb,
speculative_dt,
max_speculative_extension,
],
outputs=[
self.gjk_candidate_pairs,
Expand Down Expand Up @@ -1746,6 +1798,10 @@ def launch_custom_write(
self.shape_aabb_upper,
writer_data,
self.total_num_threads,
_slv,
_sasb,
speculative_dt,
max_speculative_extension,
],
device=device,
block_dim=self.block_dim,
Expand Down
3 changes: 2 additions & 1 deletion newton/_src/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .articulation import eval_fk, eval_ik, eval_jacobian, eval_mass_matrix
from .builder import ModelBuilder
from .collide import CollisionPipeline
from .collide import CollisionPipeline, SpeculativeContactConfig
from .contacts import Contacts
from .control import Control
from .enums import (
Expand All @@ -25,6 +25,7 @@
"JointType",
"Model",
"ModelBuilder",
"SpeculativeContactConfig",
"State",
"eval_fk",
"eval_ik",
Expand Down
Loading
Loading