diff --git a/CHANGELOG.md b/CHANGELOG.md index 01631d88c8..1fc394bf60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - Add `deterministic` flag to `CollisionPipeline` and `NarrowPhase` for GPU-thread-scheduling-independent contact ordering via radix sort and deterministic fingerprint tiebreaking in contact reduction - Add fast parity-based SDF construction path for watertight meshes in `SDF.create_from_mesh`, using `wp.mesh_query_point_sign_parity` instead of winding numbers; selected via the new `sign_method` argument (`"auto"` — the default — picks parity when `Mesh.is_watertight` is true, or `"parity"` / `"winding"` to force either strategy) - Add `ViewerBase.log_arrows()` for arrow rendering (wide line + arrowhead) in the GL viewer with a dedicated geometry shader +- Add frame-to-frame contact matching via `CollisionPipeline(contact_matching=...)` with modes `"latest"` (populates `contacts.rigid_contact_match_index`) and `"sticky"` (experimental; additionally replays previous-frame contact geometry on matched contacts — the sticky update strategy may change without warning). Optional `contact_report=True` exposes new/broken contact index lists on `Contacts`. - Add `enable_multiccd` parameter to `SolverMuJoCo` for multi-CCD contact generation (up to 4 contact points per geom pair) - Support `` in the MJCF importer, and preserve authored damping, stiffness, and frictionloss when exporting ball joints to MuJoCo specs (previously silently dropped) - Add `ViewerViser.log_scalar()` for live scalar time-series plots via uPlot diff --git a/docs/api/newton_geometry.rst b/docs/api/newton_geometry.rst index 10da70edfa..f13136a76a 100644 --- a/docs/api/newton_geometry.rst +++ b/docs/api/newton_geometry.rst @@ -48,3 +48,15 @@ newton.geometry sdf_plane sdf_sphere transform_inertia + +.. rubric:: Constants + +.. list-table:: + :header-rows: 1 + + * - Name + - Value + * - ``MATCH_BROKEN`` + - ``-2`` + * - ``MATCH_NOT_FOUND`` + - ``-1`` diff --git a/docs/concepts/collisions.rst b/docs/concepts/collisions.rst index 05c483c1c8..b4f54e69a1 100644 --- a/docs/concepts/collisions.rst +++ b/docs/concepts/collisions.rst @@ -1263,6 +1263,19 @@ and is consumed by the solver :meth:`~solvers.SolverBase.step` method for contac - Contact normal, pointing from shape 0 toward shape 1 (world frame). * - ``rigid_contact_margin0``, ``rigid_contact_margin1`` - Per-shape thickness: effective radius + margin (scalar). + * - ``rigid_contact_match_index`` + - Per-contact frame-to-frame match result (int32). ``>= 0``: matched old + index, ``-1``: new, ``-2``: broken. Only allocated when + ``contact_matching`` is not ``"disabled"``. + See :ref:`Contact Matching`. + * - ``rigid_contact_new_indices``, ``rigid_contact_new_count`` + - Compact index list of new contacts in the current sorted buffer (where + ``match_index < 0``). Only allocated when ``contact_report=True``. + See :ref:`Contact Reports`. + * - ``rigid_contact_broken_indices``, ``rigid_contact_broken_count`` + - Compact index list of contacts from the previous frame that no current + contact matched. Only allocated when ``contact_report=True``. + See :ref:`Contact Reports`. **Soft contacts (particle-shape):** @@ -1679,6 +1692,138 @@ fully CUDA-graph-capturable. Hydroelastic contacts are not yet covered by deterministic ordering. +.. _Contact Matching: + +Contact Matching +---------------- + +Contact matching tracks contacts across frames, identifying which contacts +persist, which are new, and which have broken. The ``contact_matching`` +argument on :class:`~CollisionPipeline` selects one of three modes: + +- ``"disabled"`` (default) — no matching, no extra buffers. +- ``"latest"`` — match current contacts against the previous + frame and populate :attr:`Contacts.rigid_contact_match_index`, but keep the + current frame's freshly generated contact geometry in the returned + :class:`Contacts` buffer. +- ``"sticky"`` (experimental) — match like ``"latest"``, then overwrite + each matched contact's body-frame contact points (``point0``/``point1``), + offsets (``offset0``/``offset1``), and world-frame ``normal`` with the + saved previous-frame values. The remaining contact fields + (``shape0``/``shape1``, ``margin0``/``margin1``) are either key-derived + or per-shape constants and so are already identical for a matched + contact — no extra state is kept for them. Unmatched contacts pass + through with their fresh narrow-phase geometry. Useful for stacking + scenarios where small frame-to-frame geometric jitter on persistent + contacts degrades stability. + + .. warning:: + Sticky mode is experimental. The way sticky contacts are updated + across frames may change in the future without warning. + +Any non-disabled mode implies ``deterministic=True``. + +.. testsetup:: contact-matching + + import warp as wp + import newton + + builder = newton.ModelBuilder() + builder.add_ground_plane() + body = builder.add_body(xform=wp.transform((0.0, 0.0, 2.0), wp.quat_identity())) + builder.add_shape_sphere(body, radius=0.5) + model = builder.finalize() + state = model.state() + +.. testcode:: contact-matching + + pipeline = newton.CollisionPipeline( + model, + contact_matching="latest", + contact_matching_pos_threshold=0.005, # metres (default 0.0005) + contact_matching_normal_dot_threshold=0.9, # cos(~25°) + ) + contacts = pipeline.contacts() + + pipeline.collide(state, contacts) + + # Per-contact match index (int32): + # >= 0 : index of the matched contact in the previous frame + # -1 : new contact (no match found) + # -2 : key matched but position/normal thresholds exceeded (broken) + match_idx = contacts.rigid_contact_match_index.numpy() + +Each frame, the matcher binary-searches the current contacts against the +previous frame's sorted keys, then verifies candidates against a world-space +distance threshold and a normal dot-product threshold. The sort key encodes +``(shape_a, shape_b, sub_key)`` so only contacts between the same shape pair +are compared. + +The distance metric is the world-space **contact midpoint** +``0.5 * (world(point0) + world(point1))`` — symmetric in shape 0 and shape 1 +— which means swapping the two shapes of a pair does not change whether a +contact matches. It also means pure changes in penetration depth register +as motion on both sides of the contact, not just one. + +**Thresholds** + +- ``contact_matching_pos_threshold`` — maximum world-space distance [m] + between the previous and current contact midpoints for a match. Contacts + that moved more than this between frames are considered broken. Defaults + to ``0.0005`` m. +- ``contact_matching_normal_dot_threshold`` — minimum dot product between old + and new contact normals. Below this the contact is reported as broken even + if the key and position match. + +**Sticky mode** + +Replay of the matched previous-frame geometry happens after the deterministic +sort, so ``match_index`` already addresses the final sorted layout. Unmatched +rows (``MATCH_NOT_FOUND`` / ``MATCH_BROKEN``) are left untouched, so new and +threshold-broken contacts keep their fresh narrow-phase geometry. Because +matching requires both a position delta below the threshold and a normal dot +product above the threshold, the saved values are guaranteed to be a close +approximation of the current geometry and are safe to reuse. The extra +per-contact buffers (four ``vec3`` columns for the body-frame points and +offsets) are only allocated when the mode is ``"sticky"``; ``"latest"`` and +``"disabled"`` pay zero additional memory and launch no additional kernels. + +.. _Contact Reports: + +Contact Reports +^^^^^^^^^^^^^^^ + +Pass ``contact_report=True`` to also collect compact index lists of new and +broken contacts each frame. ``contact_report=True`` requires a non-disabled +matching mode: + +.. testcode:: contact-matching + + pipeline = newton.CollisionPipeline( + model, + contact_matching="latest", + contact_report=True, + ) + contacts = pipeline.contacts() + pipeline.collide(state, contacts) + + n_new = contacts.rigid_contact_new_count.numpy()[0] + new_indices = contacts.rigid_contact_new_indices.numpy()[:n_new] + + n_broken = contacts.rigid_contact_broken_count.numpy()[0] + broken_indices = contacts.rigid_contact_broken_indices.numpy()[:n_broken] + +``rigid_contact_new_indices`` holds indices into the current frame's sorted +contact buffer for every contact with ``match_index < 0``. This includes both +genuinely new contacts (``MATCH_NOT_FOUND``, ``match_index == -1``) and +threshold-broken contacts whose sort key matched a previous-frame contact but +whose position or normal exceeded the configured thresholds +(``MATCH_BROKEN``, ``match_index == -2``). Inspect +``contacts.rigid_contact_match_index`` to distinguish the two cases. + +``rigid_contact_broken_indices`` holds indices into the *previous* frame's +sorted buffer for contacts that no current contact matched. + .. _Performance: Performance diff --git a/newton/_src/geometry/__init__.py b/newton/_src/geometry/__init__.py index b3b09fc27c..a2cede80bf 100644 --- a/newton/_src/geometry/__init__.py +++ b/newton/_src/geometry/__init__.py @@ -18,6 +18,7 @@ collide_sphere_cylinder, collide_sphere_sphere, ) +from .contact_match import MATCH_BROKEN, MATCH_NOT_FOUND from .flags import ParticleFlags, ShapeFlags from .inertia import compute_inertia_shape, compute_inertia_sphere, transform_inertia from .sdf_utils import SDF @@ -32,6 +33,8 @@ from .utils import compute_shape_radius __all__ = [ + "MATCH_BROKEN", + "MATCH_NOT_FOUND", "SDF", "BroadPhaseAllPairs", "BroadPhaseExplicit", diff --git a/newton/_src/geometry/contact_match.py b/newton/_src/geometry/contact_match.py new file mode 100644 index 0000000000..ff21442253 --- /dev/null +++ b/newton/_src/geometry/contact_match.py @@ -0,0 +1,941 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 + +"""Frame-to-frame contact matching via binary search on sorted contact keys. + +Given the previous frame's sorted contacts (keys, world-space midpoints, +normals) and the current frame's unsorted contacts, this module finds +correspondences using the deterministic sort key from +:func:`~newton._src.geometry.contact_data.make_contact_sort_key`. + +For each new contact the matcher binary-searches the previous frame's +sorted keys for the ``(shape_a, shape_b)`` pair range — ignoring the +``sort_sub_key`` bits — then picks the closest previous contact in that +range whose normal also passes the dot-product threshold. "Closest" is +measured in world space between the contact *midpoints* +``0.5 * (world(point0) + world(point1))``, i.e. symmetric in shape 0 and +shape 1. The result is a per-contact match index: + +- ``>= 0``: index of the matched contact in the previous frame's sorted buffer. +- ``MATCH_NOT_FOUND (-1)``: shape pair has no prior contacts. +- ``MATCH_BROKEN (-2)``: shape pair exists but no contact within + position/normal thresholds, *or* a closer new contact won the same + prev contact in the uniqueness resolve pass. + +Why ignore sort_sub_key +----------------------- +Multi-contact manifolds (e.g. box-box face-face) can rotate the +``sort_sub_key`` assignment across frames when their internal generation +order shifts (e.g. the Sutherland-Hodgman clip's starting vertex moves +by one slot), even though the physical contact points stay essentially +in place. Matching on the full key would mark these contacts broken +every frame. Pair counts are small (a few manifold points per pair), +so the linear scan inside the pair range is cheap. + +One-to-one match via packed atomic_min +-------------------------------------- +A pair-range scan can have multiple new contacts pick the same prev +contact as their closest. To keep the mapping injective without +sorting or CAS retries, the matcher uses a single ``wp.atomic_min`` per +new contact on a per-prev ``int64`` claim word: + + claim = (float_flip(dist_sq) << 32) | tid + +``float_flip`` reinterprets the non-negative ``dist_sq`` as a +sortable ``uint32``, so the high 32 bits order claims by ascending +distance; the low 32 bits hold the new contact index, breaking ties +deterministically (smallest ``tid`` wins). After the match kernel +runs, a small finalize kernel reads ``prev_claim[best_idx]`` and +demotes any new contact whose ``tid`` does not appear in the low bits +to :data:`MATCH_BROKEN`. Losers are *not* re-matched against a +second-closest prev (kept for simplicity and speed). + +Cost: one ``int64[capacity]`` buffer, one ``wp.atomic_min`` per new +contact, and one short finalize kernel launch. No ``atomic_cas``. + +Memory efficiency +----------------- +The matcher reuses the :class:`ContactSorter`'s existing scratch buffers +(:attr:`ContactSorter.scratch_pos_world`, :attr:`ContactSorter.scratch_normal`) +to store previous-frame world-space contact midpoints and normals between +frames. This works for the *match* kernel because matching runs **before** +``ContactSorter.sort_full``, so the scratch buffers still hold the previous +frame's saved data; ``save_sorted_state`` runs **after** sorting and +refreshes them in-place for the next frame. The only additional +per-contact allocation for the non-sticky path is the ``_prev_sorted_keys`` +buffer (8 bytes/contact) since the sorter's key buffer is overwritten by +``_prepare_sort`` each frame. + +Sticky mode needs one extra dedicated buffer (``_prev_normal_sticky``, +12 bytes/contact) because :meth:`replay_matched` runs **after** +``sort_full``, at which point the sorter's ``scratch_normal`` has been +clobbered by the sort's backup pass and no longer contains the previous +frame's sorted normals. The body-frame point/offset columns already use +dedicated sticky buffers for the same reason. + +Per-frame call order (inside :class:`~newton.CollisionPipeline`):: + + matcher.match(...) # before ContactSorter.sort_full() + sorter.sort_full(...) # match_index is permuted with contacts + matcher.replay_matched(...) # sticky-only; overwrite matched rows + matcher.build_report(...) # optional; must precede save_sorted_state + matcher.save_sorted_state(...) # after sorting, replay, and report + +The ordering matters: ``save_sorted_state`` overwrites ``_prev_count`` with +the current frame's count, while ``build_report`` reads the *old* +``_prev_count`` to bound the broken-contact enumeration, and sticky +``replay_matched`` must see the post-sort ``match_index`` and the pre-save +``_prev_*`` buffers it reads from. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp + +from ..core.types import Devicelike +from .contact_sort import SORT_KEY_SENTINEL + +if TYPE_CHECKING: + from .contact_sort import ContactSorter + +MATCH_NOT_FOUND = wp.constant(wp.int32(-1)) +"""Sentinel: no matching key found in last frame's contacts.""" + +MATCH_BROKEN = wp.constant(wp.int32(-2)) +"""Sentinel: key found but position or normal threshold exceeded.""" + + +# ------------------------------------------------------------------ +# Warp helpers +# ------------------------------------------------------------------ + + +# Sentinel value for unclaimed slots in ``_prev_claim``. Larger than +# any packed (flipped_dist << 32 | tid) any kernel will ever produce, +# so the first ``atomic_min`` always wins. +_CLAIM_SENTINEL = wp.constant(wp.int64(0x7FFFFFFFFFFFFFFF)) + + +@wp.func +def _lower_bound_int64( + lower: int, + upper: int, + target: wp.int64, + keys: wp.array[wp.int64], +) -> int: + """First index in ``keys[lower:upper]`` whose value is >= *target*. + + Returns ``upper`` if no such index exists. + """ + left = lower + right = upper + while left < right: + mid = left + (right - left) // 2 + if keys[mid] < target: + left = mid + 1 + else: + right = mid + return left + + +@wp.func_native(""" +uint32_t i = reinterpret_cast(f); +uint32_t mask = (uint32_t)(-(int)(i >> 31)) | 0x80000000u; +return i ^ mask; +""") +def _float_flip(f: float) -> wp.uint32: + """Reinterpret a 32-bit float as a sortable ``uint32`` (Stereopsis trick). + + For non-negative floats this is a strictly monotone encoding, so + comparing the resulting ``uint32`` orders the original floats + correctly. We only ever feed non-negative ``dist_sq`` values into + it, so the negative branch is unused here but kept generic. + """ + ... + + +@wp.func +def _pack_claim(dist_sq: float, tid: int) -> wp.int64: + """Pack ``(dist_sq, tid)`` into a single int64 for ``atomic_min``. + + High 32 bits: ``float_flip(dist_sq)`` — ascending by distance. + Low 32 bits: ``tid`` — deterministic tie-break (smallest wins). + """ + flipped = wp.int64(_float_flip(dist_sq)) + return (flipped << wp.int64(32)) | wp.int64(tid) + + +# ------------------------------------------------------------------ +# Match kernel +# ------------------------------------------------------------------ + + +@wp.struct +class _MatchData: + """Bundled arrays for the contact match kernel.""" + + # Previous frame (sorted) — pos/normal reuse ContactSorter scratch buffers. + # ``prev_pos_world`` holds the world-space *midpoint* between shape 0's and + # shape 1's contact points, saved by the previous frame's save kernel. + prev_keys: wp.array[wp.int64] + prev_pos_world: wp.array[wp.vec3] + prev_normal: wp.array[wp.vec3] + prev_count: wp.array[wp.int32] + + # Current frame (unsorted). + new_keys: wp.array[wp.int64] + new_point0: wp.array[wp.vec3] + new_point1: wp.array[wp.vec3] + new_shape0: wp.array[wp.int32] + new_shape1: wp.array[wp.int32] + new_normal: wp.array[wp.vec3] + new_count: wp.array[wp.int32] + + # Body transforms for world-space conversion + body_q: wp.array[wp.transform] + shape_body: wp.array[wp.int32] + + # Per-prev claim word, packed (float_flip(dist_sq) << 32 | tid). + # Initialised to _CLAIM_SENTINEL each frame; race with atomic_min. + prev_claim: wp.array[wp.int64] + + # Per-new candidate prev index (final value resolved in pass 2). + match_index: wp.array[wp.int32] + + # Thresholds + pos_threshold_sq: float + normal_dot_threshold: float + + +@wp.kernel(enable_backward=False) +def _match_contacts_kernel(data: _MatchData): + """Pass 1: pick each new contact's closest prev candidate and stake + a packed claim on it via ``wp.atomic_min``. + """ + tid = wp.tid() + n_new = data.new_count[0] + if tid >= n_new: + data.match_index[tid] = MATCH_NOT_FOUND + return + + n_old = data.prev_count[0] + if n_old == 0: + data.match_index[tid] = MATCH_NOT_FOUND + return + + target_key = data.new_keys[tid] + + # World-space midpoint of the two contact points (symmetric in shape 0 / + # shape 1). Matches the quantity persisted by ``_save_sorted_state_kernel`` + # for the previous frame, so both sides of ``diff`` below measure the same + # physical quantity. + p0 = data.new_point0[tid] + bid0 = data.shape_body[data.new_shape0[tid]] + if bid0 == -1: + p0w = p0 + else: + p0w = wp.transform_point(data.body_q[bid0], p0) + + p1 = data.new_point1[tid] + bid1 = data.shape_body[data.new_shape1[tid]] + if bid1 == -1: + p1w = p1 + else: + p1w = wp.transform_point(data.body_q[bid1], p1) + + new_pos_w = 0.5 * (p0w + p1w) + new_n = data.new_normal[tid] + + # Binary search the [range_lo, range_hi) interval of prev contacts + # sharing the same (shape_a, shape_b) pair. We ignore sort_sub_key + # because for multi-contact manifolds (e.g. box-box face-face) the + # sub-key assignment is not stable across frames; matching by the + # exact key would spuriously break stable contacts. Pair counts are + # small (<= a few manifold points), so a linear scan inside the + # range is cheap. + pair_prefix = target_key & wp.int64(~0x7FFFFF) + pair_end = pair_prefix + wp.int64(0x800000) # 1 << 23 + range_lo = _lower_bound_int64(0, n_old, pair_prefix, data.prev_keys) + range_hi = _lower_bound_int64(range_lo, n_old, pair_end, data.prev_keys) + + if range_lo >= range_hi: + data.match_index[tid] = MATCH_NOT_FOUND + return + + # Closest-point match within the pair range, gated by normal dot. + best_idx = int(-1) + best_dist_sq = float(data.pos_threshold_sq) + for old_idx in range(range_lo, range_hi): + old_pos = data.prev_pos_world[old_idx] + diff = new_pos_w - old_pos + dist_sq = wp.dot(diff, diff) + old_n = data.prev_normal[old_idx] + ndot = wp.dot(new_n, old_n) + + if dist_sq <= best_dist_sq and ndot >= data.normal_dot_threshold: + best_dist_sq = dist_sq + best_idx = old_idx + + if best_idx >= 0: + data.match_index[tid] = wp.int32(best_idx) + # Race for ownership of prev[best_idx] with a single atomic_min. + # Closest distance wins; ties resolved by lowest tid. + wp.atomic_min(data.prev_claim, best_idx, _pack_claim(best_dist_sq, tid)) + else: + # Pair range exists but no contact within thresholds. + data.match_index[tid] = MATCH_BROKEN + + +@wp.kernel(enable_backward=False) +def _clear_prev_claim_kernel( + prev_claim: wp.array[wp.int64], + prev_count: wp.array[wp.int32], +): + """Reset only the active prefix of the claim buffer to ``_CLAIM_SENTINEL``. + + Launched with ``capacity`` threads so the per-frame launch fits a + static CUDA graph, but each thread guards on ``prev_count[0]`` so we + only touch the (typically much smaller) range of slots that ``match`` + will actually race on. Slots beyond ``prev_count`` are never read + by either kernel, so leaving them stale is safe. + """ + i = wp.tid() + if i < prev_count[0]: + prev_claim[i] = _CLAIM_SENTINEL + + +@wp.kernel(enable_backward=False) +def _resolve_claims_kernel( + match_index: wp.array[wp.int32], + prev_claim: wp.array[wp.int64], + prev_was_matched: wp.array[wp.int32], + new_count: wp.array[wp.int32], + has_report: int, +): + """Pass 2: keep winners, demote losers to :data:`MATCH_BROKEN`. + + The low 32 bits of ``prev_claim[cand]`` identify the winning + ``tid``; everyone else who staked a claim on the same ``cand`` + becomes :data:`MATCH_BROKEN` (no second-closest fallback). + """ + tid = wp.tid() + if tid >= new_count[0]: + return + + cand = match_index[tid] + if cand < wp.int32(0): + return # already MATCH_NOT_FOUND or MATCH_BROKEN + + winner_tid = wp.int32(prev_claim[cand] & wp.int64(0xFFFFFFFF)) + if winner_tid == wp.int32(tid): + if has_report != 0: + prev_was_matched[cand] = wp.int32(1) + else: + match_index[tid] = MATCH_BROKEN + + +# ------------------------------------------------------------------ +# Save sorted state kernel +# ------------------------------------------------------------------ + + +@wp.struct +class _SaveStateData: + """Bundled arrays for the save-sorted-state kernel. + + ``src_point0`` / ``src_point1`` and their shape indices are consumed every + frame to compute the symmetric world-space midpoint written into + ``dst_pos_world``. ``src_offset*`` and the ``dst_*_body`` columns are + only consumed when ``has_sticky != 0``; when sticky is disabled the + matcher passes dummy arrays for those slots and the kernel's + ``if has_sticky`` guard skips the extra writes, so sticky-only columns + need zero allocation in the non-sticky path. + """ + + src_keys: wp.array[wp.int64] + src_point0: wp.array[wp.vec3] + src_point1: wp.array[wp.vec3] + src_offset0: wp.array[wp.vec3] + src_offset1: wp.array[wp.vec3] + src_shape0: wp.array[wp.int32] + src_shape1: wp.array[wp.int32] + src_normal: wp.array[wp.vec3] + src_count: wp.array[wp.int32] + + body_q: wp.array[wp.transform] + shape_body: wp.array[wp.int32] + + dst_keys: wp.array[wp.int64] + dst_pos_world: wp.array[wp.vec3] # world-space midpoint of point0 and point1 + dst_normal: wp.array[wp.vec3] + dst_point0_body: wp.array[wp.vec3] + dst_point1_body: wp.array[wp.vec3] + dst_offset0_body: wp.array[wp.vec3] + dst_offset1_body: wp.array[wp.vec3] + # Dedicated sticky-replay normal buffer. Duplicates ``dst_normal`` content + # but lives in its own allocation so sticky replay (which runs between + # ``sort_full`` and the next ``save_sorted_state``) is not reading the + # sorter's ``scratch_normal`` after the sort has clobbered it. + dst_normal_sticky: wp.array[wp.vec3] + dst_count: wp.array[wp.int32] + + has_sticky: int + + +@wp.kernel(enable_backward=False) +def _save_sorted_state_kernel(data: _SaveStateData): + """Copy sorted contacts into the previous-frame buffers for next-frame matching. + + The persisted ``dst_pos_world`` is the world-space *midpoint* of the two + contact points, so the next frame's match kernel compares a shape-symmetric + quantity. + """ + i = wp.tid() + if i == 0: + data.dst_count[0] = data.src_count[0] + if i < data.src_count[0]: + data.dst_keys[i] = data.src_keys[i] + + p0 = data.src_point0[i] + bid0 = data.shape_body[data.src_shape0[i]] + if bid0 == -1: + p0w = p0 + else: + p0w = wp.transform_point(data.body_q[bid0], p0) + + p1 = data.src_point1[i] + bid1 = data.shape_body[data.src_shape1[i]] + if bid1 == -1: + p1w = p1 + else: + p1w = wp.transform_point(data.body_q[bid1], p1) + + data.dst_pos_world[i] = 0.5 * (p0w + p1w) + data.dst_normal[i] = data.src_normal[i] + + if data.has_sticky != 0: + data.dst_point0_body[i] = p0 + data.dst_point1_body[i] = p1 + data.dst_offset0_body[i] = data.src_offset0[i] + data.dst_offset1_body[i] = data.src_offset1[i] + data.dst_normal_sticky[i] = data.src_normal[i] + + +# ------------------------------------------------------------------ +# Sticky-mode replay (matched rows only) +# ------------------------------------------------------------------ +# +# Sticky mode preserves only the fields that actually change across frames +# for a matched contact: the body-frame contact points (``point0``/``point1``) +# and offsets (``offset0``/``offset1``), plus the world-frame normal (which +# is already persisted for matching in ``prev_normal``, no extra allocation). +# +# Everything else is either key-derived or a per-shape constant that does +# not change between frames, so the new frame's values are already correct: +# +# - ``shape0`` / ``shape1`` : implied by the sort key; identical by +# construction for matched contacts. +# - ``margin0`` / ``margin1``: ``radius_eff + margin``, per-shape constant. +# - ``stiffness`` / ``damping`` / ``friction``: per-shape constants, and +# contact matching for hydroelastic contacts (the only path that writes +# these) is not supported anyway. + + +@wp.struct +class _ReplayData: + """Bundled arrays for the sticky replay kernel.""" + + match_index: wp.array[wp.int32] + contact_count: wp.array[wp.int32] + + prev_point0: wp.array[wp.vec3] + prev_point1: wp.array[wp.vec3] + prev_offset0: wp.array[wp.vec3] + prev_offset1: wp.array[wp.vec3] + prev_normal: wp.array[wp.vec3] + + point0: wp.array[wp.vec3] + point1: wp.array[wp.vec3] + offset0: wp.array[wp.vec3] + offset1: wp.array[wp.vec3] + normal: wp.array[wp.vec3] + + +@wp.kernel(enable_backward=False) +def _replay_matched_kernel(data: _ReplayData): + tid = wp.tid() + if tid >= data.contact_count[0]: + return + idx = data.match_index[tid] + if idx < wp.int32(0): + return # MATCH_NOT_FOUND or MATCH_BROKEN -- keep new-frame data. + data.point0[tid] = data.prev_point0[idx] + data.point1[tid] = data.prev_point1[idx] + data.offset0[tid] = data.prev_offset0[idx] + data.offset1[tid] = data.prev_offset1[idx] + data.normal[tid] = data.prev_normal[idx] + + +# ------------------------------------------------------------------ +# Contact report kernels +# ------------------------------------------------------------------ + + +@wp.kernel(enable_backward=False) +def _collect_new_contacts_kernel( + match_index: wp.array[wp.int32], + contact_count: wp.array[wp.int32], + new_indices: wp.array[wp.int32], + new_count: wp.array[wp.int32], +): + """Collect indices of new or broken contacts (match_index < 0) after sorting.""" + i = wp.tid() + if i >= contact_count[0]: + return + if match_index[i] < wp.int32(0): + slot = wp.atomic_add(new_count, 0, wp.int32(1)) + new_indices[slot] = wp.int32(i) + + +@wp.kernel(enable_backward=False) +def _collect_broken_contacts_kernel( + prev_was_matched: wp.array[wp.int32], + prev_count: wp.array[wp.int32], + broken_indices: wp.array[wp.int32], + broken_count: wp.array[wp.int32], +): + """Collect indices of old contacts that were not matched by any new contact.""" + i = wp.tid() + if i >= prev_count[0]: + return + if prev_was_matched[i] == wp.int32(0): + slot = wp.atomic_add(broken_count, 0, wp.int32(1)) + broken_indices[slot] = wp.int32(i) + + +# ------------------------------------------------------------------ +# ContactMatcher class +# ------------------------------------------------------------------ + + +class ContactMatcher: + """Frame-to-frame contact matching using binary search on sorted keys. + + Internal helper owned by :class:`~newton.CollisionPipeline`. All user-visible + results (match index, new/broken index lists) are surfaced on the + :class:`~newton.Contacts` container; this class only owns cross-frame state. + + Pre-allocates all buffers at construction time for CUDA graph capture + compatibility. See the module docstring for the per-frame call order and + the ordering constraints between :meth:`match`, :meth:`replay_matched`, + :meth:`build_report`, and :meth:`save_sorted_state`. + + Memory is minimised by reusing the sorter's existing scratch buffers for + the previous-frame world-space contact midpoints and normals. The matcher + owns two small per-contact buffers in addition: the sorted key cache + (8 bytes/contact) and the per-prev claim word used by the ``atomic_min`` + race that keeps new→prev injective (8 bytes/contact). When + ``contact_report`` is disabled, the ``prev_was_matched`` flag array is + also skipped. + + .. note:: + Previous-frame state persists across :meth:`~newton.CollisionPipeline.collide` + calls — that is the whole point. In RL-style workflows where the user + resets or teleports bodies between episodes, call :meth:`reset` after + such discontinuities so the next frame starts fresh with all + :data:`MATCH_NOT_FOUND`. + + Args: + capacity: Maximum number of contacts (must match :class:`ContactSorter`). + sorter: The :class:`ContactSorter` whose scratch buffers will be + reused for storing previous-frame positions and normals. + pos_threshold: World-space distance threshold [m] between the + previous and current contact midpoints + ``0.5 * (world(point0) + world(point1))``. Contacts whose midpoint + moved more than this between frames are considered broken. + normal_dot_threshold: Minimum dot product between old and new contact + normals. Below this the contact is considered broken. + contact_report: Allocate the ``prev_was_matched`` flag array needed + to enumerate broken contacts in :meth:`build_report`. + sticky: Allocate five extra per-contact ``wp.vec3`` buffers + (``point0``/``point1``/``offset0``/``offset1`` body-frame, plus a + dedicated ``normal`` buffer) used by :meth:`replay_matched`. The + world-frame normal needs its own allocation because sticky replay + runs after ``ContactSorter.sort_full`` has clobbered the + ``scratch_normal`` alias the match kernel reads pre-sort. When + ``False`` these attributes are ``None`` and no extra kernel + launches are added. + device: Device to allocate on. + """ + + def __init__( + self, + capacity: int, + *, + sorter: ContactSorter, + pos_threshold: float = 0.0005, + normal_dot_threshold: float = 0.995, + contact_report: bool = False, + sticky: bool = False, + device: Devicelike = None, + ): + with wp.ScopedDevice(device): + self._capacity = capacity + self._pos_threshold_sq = pos_threshold * pos_threshold + self._normal_dot_threshold = normal_dot_threshold + self._sorter = sorter + + # Only buffer we must own: sorted keys survive across frames + # (_sort_keys_copy is overwritten by _prepare_sort each frame). + # Init with the sort-key sentinel so a debugger dump of the buffer + # before the first save_sorted_state does not look like valid keys + # for shape_a=0, shape_b=0, sub_key=0. + self._prev_sorted_keys = wp.full(capacity, SORT_KEY_SENTINEL, dtype=wp.int64) + self._prev_count = wp.zeros(1, dtype=wp.int32) + + # Per-prev claim word for the atomic_min race that keeps the + # new→prev mapping injective (see module docstring). Reset + # to _CLAIM_SENTINEL each frame; the low 32 bits of the + # surviving value identify the winning new contact ``tid``. + self._prev_claim = wp.empty(capacity, dtype=wp.int64) + + # Contact report (optional). + self._has_report = contact_report + if contact_report: + self._prev_was_matched = wp.zeros(capacity, dtype=wp.int32) + else: + # Dummy single-element array so the Warp struct is always valid. + self._prev_was_matched = wp.zeros(1, dtype=wp.int32) + + # Sticky-mode buffers. Only the body-frame point/offset pairs + # and the world-frame normal need preserving -- shape indices, + # margins, and per-shape properties are either key-derived or + # per-shape constants and so identical on the next frame for a + # matched contact. The normal cannot reuse the sorter's + # ``scratch_normal`` like the match kernel does, because sticky + # replay runs *after* ``ContactSorter.sort_full`` and by then + # ``scratch_normal`` has been clobbered with the current frame's + # pre-sort normals by the sort's backup pass. + self._sticky = sticky + if sticky: + self._prev_point0 = wp.zeros(capacity, dtype=wp.vec3) + self._prev_point1 = wp.zeros(capacity, dtype=wp.vec3) + self._prev_offset0 = wp.zeros(capacity, dtype=wp.vec3) + self._prev_offset1 = wp.zeros(capacity, dtype=wp.vec3) + self._prev_normal_sticky = wp.zeros(capacity, dtype=wp.vec3) + else: + self._prev_point0 = None + self._prev_point1 = None + self._prev_offset0 = None + self._prev_offset1 = None + self._prev_normal_sticky = None + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def has_report(self) -> bool: + """Whether the contact report buffers are allocated.""" + return self._has_report + + @property + def is_sticky(self) -> bool: + """Whether sticky-mode full-record buffers are allocated.""" + return self._sticky + + @property + def prev_contact_count(self) -> wp.array[wp.int32]: + """Device-side previous frame contact count (single-element int32).""" + return self._prev_count + + def reset(self) -> None: + """Clear cross-frame state so the next frame starts fresh. + + Use this after any discontinuity that invalidates the previous + frame's contacts (RL episode reset, teleported bodies, scene + reload). After ``reset()`` the next :meth:`match` produces all + :data:`MATCH_NOT_FOUND` and :meth:`build_report` reports zero broken + contacts. Zeroing ``_prev_count`` is sufficient because both kernels + gate on it. + """ + self._prev_count.zero_() + + # ------------------------------------------------------------------ + # Public methods + # ------------------------------------------------------------------ + + def match( + self, + sort_keys: wp.array[wp.int64], + contact_count: wp.array[wp.int32], + point0: wp.array[wp.vec3], + point1: wp.array[wp.vec3], + shape0: wp.array[wp.int32], + shape1: wp.array[wp.int32], + normal: wp.array[wp.vec3], + body_q: wp.array[wp.transform], + shape_body: wp.array[wp.int32], + match_index_out: wp.array[wp.int32], + *, + device: Devicelike = None, + ) -> None: + """Match current unsorted contacts against last frame's sorted contacts. + + Must be called **before** :meth:`ContactSorter.sort_full`. + + Distance is measured between world-space contact midpoints + (``0.5 * (world(point0) + world(point1))``) so the metric is symmetric + in shape 0 / shape 1. + + Args: + sort_keys: Current frame's unsorted int64 sort keys. + contact_count: Single-element int array with active contact count. + point0: Body-frame contact points on shape 0 (current frame). + point1: Body-frame contact points on shape 1 (current frame). + shape0: Shape indices for shape 0 (current frame). + shape1: Shape indices for shape 1 (current frame). + normal: Contact normals (current frame). + body_q: Body transforms for the current frame. + shape_body: Shape-to-body index map. + match_index_out: Output int32 array to receive match results. + Written directly (no intermediate copy). + device: Device to launch on. + """ + if self._has_report: + self._prev_was_matched.zero_() + + # Reset only the active prefix of the claim buffer. Launching + # ``capacity`` threads keeps the call shape constant for graph + # capture, but the kernel guards on ``prev_count`` so we touch + # the minimum bytes — important for sparsely-loaded pipelines + # where ``capacity >> prev_count``. + wp.launch( + _clear_prev_claim_kernel, + dim=self._capacity, + inputs=[self._prev_claim, self._prev_count], + device=device, + ) + + data = _MatchData() + data.prev_keys = self._prev_sorted_keys + # Reuse sorter scratch buffers for prev-frame world-space data. + data.prev_pos_world = self._sorter.scratch_pos_world + data.prev_normal = self._sorter.scratch_normal + data.prev_count = self._prev_count + data.new_keys = sort_keys + data.new_point0 = point0 + data.new_point1 = point1 + data.new_shape0 = shape0 + data.new_shape1 = shape1 + data.new_normal = normal + data.new_count = contact_count + data.body_q = body_q + data.shape_body = shape_body + data.match_index = match_index_out + data.prev_claim = self._prev_claim + data.pos_threshold_sq = self._pos_threshold_sq + data.normal_dot_threshold = self._normal_dot_threshold + + wp.launch(_match_contacts_kernel, dim=self._capacity, inputs=[data], device=device) + wp.launch( + _resolve_claims_kernel, + dim=self._capacity, + inputs=[ + match_index_out, + self._prev_claim, + self._prev_was_matched, + contact_count, + 1 if self._has_report else 0, + ], + device=device, + ) + + def save_sorted_state( + self, + sorted_keys: wp.array[wp.int64], + contact_count: wp.array[wp.int32], + sorted_point0: wp.array[wp.vec3], + sorted_point1: wp.array[wp.vec3], + sorted_shape0: wp.array[wp.int32], + sorted_shape1: wp.array[wp.int32], + sorted_normal: wp.array[wp.vec3], + body_q: wp.array[wp.transform], + shape_body: wp.array[wp.int32], + *, + sorted_offset0: wp.array[wp.vec3] | None = None, + sorted_offset1: wp.array[wp.vec3] | None = None, + device: Devicelike = None, + ) -> None: + """Save current frame's sorted contacts for next-frame matching. + + Must be called **after** :meth:`ContactSorter.sort_full`. The + world-space midpoint of ``sorted_point0``/``sorted_point1`` and the + sorted normal are written into the sorter's scratch buffers + (:attr:`ContactSorter.scratch_pos_world` / + :attr:`ContactSorter.scratch_normal`), which are idle between frames. + + When the matcher was built with ``sticky=True``, the body-frame + point/offset columns are also persisted for :meth:`replay_matched` in + the same kernel launch. ``sorted_offset0`` / ``sorted_offset1`` are + required in that case and ignored otherwise. + + Args: + sorted_keys: Sorted int64 keys (use :attr:`ContactSorter.sorted_keys_view`). + contact_count: Single-element int array with active contact count. + sorted_point0: Sorted body-frame contact points on shape 0. + sorted_point1: Sorted body-frame contact points on shape 1. + sorted_shape0: Sorted shape 0 indices. + sorted_shape1: Sorted shape 1 indices. + sorted_normal: Sorted contact normals. + body_q: Body transforms (current frame). + shape_body: Shape-to-body index map. + sorted_offset0, sorted_offset1: Required when sticky is enabled; + ignored otherwise. + device: Device to launch on. + """ + data = _SaveStateData() + data.src_keys = sorted_keys + data.src_point0 = sorted_point0 + data.src_point1 = sorted_point1 + data.src_shape0 = sorted_shape0 + data.src_shape1 = sorted_shape1 + data.src_normal = sorted_normal + data.src_count = contact_count + data.body_q = body_q + data.shape_body = shape_body + data.dst_keys = self._prev_sorted_keys + # Write world-space midpoint and normal into the sorter's scratch buffers. + data.dst_pos_world = self._sorter.scratch_pos_world + data.dst_normal = self._sorter.scratch_normal + data.dst_count = self._prev_count + + if self._sticky: + if sorted_offset0 is None or sorted_offset1 is None: + raise ValueError("save_sorted_state requires sorted_offset0/offset1 when sticky is enabled") + data.src_offset0 = sorted_offset0 + data.src_offset1 = sorted_offset1 + data.dst_point0_body = self._prev_point0 + data.dst_point1_body = self._prev_point1 + data.dst_offset0_body = self._prev_offset0 + data.dst_offset1_body = self._prev_offset1 + data.dst_normal_sticky = self._prev_normal_sticky + data.has_sticky = 1 + else: + # The struct requires a valid array for every field -- the + # kernel guards with has_sticky==0 and never reads/writes them. + data.src_offset0 = sorted_point0 + data.src_offset1 = sorted_point0 + data.dst_point0_body = self._sorter.scratch_pos_world + data.dst_point1_body = self._sorter.scratch_pos_world + data.dst_offset0_body = self._sorter.scratch_pos_world + data.dst_offset1_body = self._sorter.scratch_pos_world + data.dst_normal_sticky = self._sorter.scratch_pos_world + data.has_sticky = 0 + + wp.launch(_save_sorted_state_kernel, dim=self._capacity, inputs=[data], device=device) + + def replay_matched( + self, + contact_count: wp.array[wp.int32], + match_index: wp.array[wp.int32], + *, + point0: wp.array[wp.vec3], + point1: wp.array[wp.vec3], + offset0: wp.array[wp.vec3], + offset1: wp.array[wp.vec3], + normal: wp.array[wp.vec3], + device: Devicelike = None, + ) -> None: + """Overwrite matched rows with the saved previous-frame contact geometry. + + Only valid when the matcher was constructed with ``sticky=True``. Must + run **after** :meth:`ContactSorter.sort_full` and **before** + :meth:`save_sorted_state`. Unmatched rows (``match_index < 0``) are + left untouched so new contacts keep their fresh narrow-phase geometry. + Only ``point0``/``point1``/``offset0``/``offset1``/``normal`` are + restored; other fields (``shape0``/``shape1``, margins, ...) are + already identical for a matched contact. + + Args: + contact_count: Single-element int array with the active contact count. + match_index: Sorted match_index array (from :class:`Contacts`). + point0, point1, offset0, offset1, normal: Current-frame sorted + contact record to be overwritten on matched rows. + device: Device to launch on. + """ + if not self._sticky: + raise ValueError("replay_matched requires the matcher to be constructed with sticky=True") + + data = _ReplayData() + data.match_index = match_index + data.contact_count = contact_count + data.prev_point0 = self._prev_point0 + data.prev_point1 = self._prev_point1 + data.prev_offset0 = self._prev_offset0 + data.prev_offset1 = self._prev_offset1 + # Use the dedicated sticky normal buffer, NOT sorter.scratch_normal: + # replay runs after ``sort_full``, which has clobbered scratch_normal + # with the current frame's pre-sort normals during its backup pass. + data.prev_normal = self._prev_normal_sticky + data.point0 = point0 + data.point1 = point1 + data.offset0 = offset0 + data.offset1 = offset1 + data.normal = normal + + wp.launch(_replay_matched_kernel, dim=self._capacity, inputs=[data], device=device) + + def build_report( + self, + match_index: wp.array[wp.int32], + contact_count: wp.array[wp.int32], + new_indices: wp.array[wp.int32], + new_count: wp.array[wp.int32], + broken_indices: wp.array[wp.int32], + broken_count: wp.array[wp.int32], + *, + device: Devicelike = None, + ) -> None: + """Build new/broken contact index lists (optional, post-sort). + + Must be called **after** :meth:`ContactSorter.sort_full` and **before** + :meth:`save_sorted_state` (``save_sorted_state`` overwrites + ``_prev_count``, which this method reads to bound the broken-contact + enumeration). + + After this call, ``new_indices`` / ``new_count`` hold indices of + contacts in the current sorted buffer that have no prior match + (``match_index < 0``), and ``broken_indices`` / ``broken_count`` hold + indices of old contacts that were not matched by any new contact. + + Args: + match_index: Sorted match_index array (from :class:`Contacts`). + contact_count: Single-element int array with active contact count. + new_indices: Output array to receive new-contact indices. + new_count: Single-element output counter for new contacts. + broken_indices: Output array to receive broken-contact indices + (indexing the previous frame's sorted buffer). + broken_count: Single-element output counter for broken contacts. + device: Device to launch on. + """ + if not self._has_report: + return + + new_count.zero_() + broken_count.zero_() + + wp.launch( + _collect_new_contacts_kernel, + dim=self._capacity, + inputs=[match_index, contact_count, new_indices, new_count], + device=device, + ) + wp.launch( + _collect_broken_contacts_kernel, + dim=self._capacity, + inputs=[self._prev_was_matched, self._prev_count, broken_indices, broken_count], + device=device, + ) diff --git a/newton/_src/geometry/contact_sort.py b/newton/_src/geometry/contact_sort.py index 5e2135c86d..2c5f51644f 100644 --- a/newton/_src/geometry/contact_sort.py +++ b/newton/_src/geometry/contact_sort.py @@ -62,12 +62,15 @@ class _SimpleContactArrays: normal: wp.array[wp.vec3] penetration: wp.array[float] tangent: wp.array[wp.vec3] + match_index: wp.array[wp.int32] pair_buf: wp.array[wp.vec2i] position_buf: wp.array[wp.vec3] normal_buf: wp.array[wp.vec3] penetration_buf: wp.array[float] tangent_buf: wp.array[wp.vec3] + match_index_buf: wp.array[wp.int32] has_tangent: int + has_match_index: int @wp.kernel(enable_backward=False) @@ -82,6 +85,8 @@ def _backup_simple_kernel(data: _SimpleContactArrays, count: wp.array[int]): data.penetration_buf[i] = data.penetration[i] if data.has_tangent != 0: data.tangent_buf[i] = data.tangent[i] + if data.has_match_index != 0: + data.match_index_buf[i] = data.match_index[i] @wp.kernel(enable_backward=False) @@ -97,6 +102,8 @@ def _gather_simple_kernel(data: _SimpleContactArrays, perm: wp.array[wp.int32], data.penetration[i] = data.penetration_buf[p] if data.has_tangent != 0: data.tangent[i] = data.tangent_buf[p] + if data.has_match_index != 0: + data.match_index[i] = data.match_index_buf[p] @wp.struct @@ -116,6 +123,7 @@ class _FullContactArrays: stiffness: wp.array[float] damping: wp.array[float] friction: wp.array[float] + match_index: wp.array[wp.int32] shape0_buf: wp.array[wp.int32] shape1_buf: wp.array[wp.int32] point0_buf: wp.array[wp.vec3] @@ -129,7 +137,9 @@ class _FullContactArrays: stiffness_buf: wp.array[float] damping_buf: wp.array[float] friction_buf: wp.array[float] + match_index_buf: wp.array[wp.int32] has_shape_props: int + has_match_index: int @wp.kernel(enable_backward=False) @@ -152,6 +162,8 @@ def _backup_full_kernel(data: _FullContactArrays, count: wp.array[int]): data.stiffness_buf[i] = data.stiffness[i] data.damping_buf[i] = data.damping[i] data.friction_buf[i] = data.friction[i] + if data.has_match_index != 0: + data.match_index_buf[i] = data.match_index[i] @wp.kernel(enable_backward=False) @@ -175,6 +187,8 @@ def _gather_full_kernel(data: _FullContactArrays, perm: wp.array[wp.int32], coun data.stiffness[i] = data.stiffness_buf[p] data.damping[i] = data.damping_buf[p] data.friction[i] = data.friction_buf[p] + if data.has_match_index != 0: + data.match_index[i] = data.match_index_buf[p] class ContactSorter: @@ -205,6 +219,7 @@ def __init__(self, capacity: int, *, per_contact_shape_properties: bool = False, self._simple_normal_buf = wp.zeros(capacity, dtype=wp.vec3) self._simple_penetration_buf = wp.zeros(capacity, dtype=float) self._simple_tangent_buf = wp.zeros(capacity, dtype=wp.vec3) + self._simple_match_index_buf = wp.zeros(1, dtype=wp.int32) # Scratch buffers for the full gather (CollisionPipeline.collide path). self._full_shape0_buf = wp.zeros(capacity, dtype=wp.int32) @@ -225,6 +240,7 @@ def __init__(self, capacity: int, *, per_contact_shape_properties: bool = False, self._full_stiffness_buf = wp.zeros(0, dtype=float) self._full_damping_buf = wp.zeros(0, dtype=float) self._full_friction_buf = wp.zeros(0, dtype=float) + self._full_match_index_buf = wp.zeros(capacity, dtype=wp.int32) # ------------------------------------------------------------------ # Public API @@ -240,6 +256,7 @@ def sort_simple( contact_normal: wp.array, contact_penetration: wp.array, contact_tangent: wp.array | None = None, + match_index: wp.array | None = None, device: Devicelike = None, ) -> None: """Sort contacts written through the simplified narrow-phase writer. @@ -254,12 +271,16 @@ def sort_simple( contact_normal: vec3 contact normals. contact_penetration: float penetration depths. contact_tangent: Optional vec3 tangent array. + match_index: Optional int32 array of per-contact match indices + from :class:`ContactMatcher`. When provided, the array is + permuted alongside the other contact fields during sorting. device: Device to launch on. """ n = self._capacity self._sort_and_permute(sort_keys, contact_count, device=device) has_tangent = contact_tangent is not None and contact_tangent.shape[0] > 0 + has_match = match_index is not None and match_index.shape[0] > 0 data = _SimpleContactArrays() data.pair = contact_pair @@ -267,12 +288,15 @@ def sort_simple( data.normal = contact_normal data.penetration = contact_penetration data.tangent = contact_tangent if has_tangent else self._simple_tangent_buf + data.match_index = match_index if has_match else self._simple_match_index_buf data.pair_buf = self._simple_pair_buf data.position_buf = self._simple_position_buf data.normal_buf = self._simple_normal_buf data.penetration_buf = self._simple_penetration_buf data.tangent_buf = self._simple_tangent_buf + data.match_index_buf = self._simple_match_index_buf data.has_tangent = 1 if has_tangent else 0 + data.has_match_index = 1 if has_match else 0 wp.launch(_backup_simple_kernel, dim=n, inputs=[data, contact_count], device=device) wp.launch(_gather_simple_kernel, dim=n, inputs=[data, self._sort_indices, contact_count], device=device) @@ -295,6 +319,7 @@ def sort_full( stiffness: wp.array | None = None, damping: wp.array | None = None, friction: wp.array | None = None, + match_index: wp.array | None = None, device: Devicelike = None, ) -> None: """Sort contacts written through the full collide.py writer. @@ -317,12 +342,16 @@ def sort_full( stiffness: Optional float per-contact stiffness. damping: Optional float per-contact damping. friction: Optional float per-contact friction. + match_index: Optional int32 array of per-contact match indices + from :class:`ContactMatcher`. When provided, the array is + permuted alongside the other contact fields during sorting. device: Device to launch on. """ n = self._capacity self._sort_and_permute(sort_keys, contact_count, device=device) has_props = self._has_shape_props + has_match = match_index is not None and match_index.shape[0] > 0 data = _FullContactArrays() data.shape0 = shape0 @@ -342,6 +371,7 @@ def sort_full( data.friction = ( friction if has_props and friction is not None and friction.shape[0] > 0 else self._full_friction_buf ) + data.match_index = match_index if has_match else self._full_match_index_buf data.shape0_buf = self._full_shape0_buf data.shape1_buf = self._full_shape1_buf data.point0_buf = self._full_point0_buf @@ -355,11 +385,50 @@ def sort_full( data.stiffness_buf = self._full_stiffness_buf data.damping_buf = self._full_damping_buf data.friction_buf = self._full_friction_buf + data.match_index_buf = self._full_match_index_buf data.has_shape_props = 1 if has_props else 0 + data.has_match_index = 1 if has_match else 0 wp.launch(_backup_full_kernel, dim=n, inputs=[data, contact_count], device=device) wp.launch(_gather_full_kernel, dim=n, inputs=[data, self._sort_indices, contact_count], device=device) + @property + def sorted_keys_view(self) -> wp.array: + """View of sorted keys (first half of internal buffer). + + Valid only after :meth:`sort_simple` or :meth:`sort_full` returns. + The array has ``capacity`` elements; active entries are + ``sorted_keys_view[:contact_count]``. + """ + return self._sort_keys_copy[: self._capacity] + + @property + def scratch_pos_world(self) -> wp.array: + """Shared scratch buffer for external cross-frame world-space positions. + + Sized ``capacity`` :class:`wp.vec3`. Reserved for use by + :class:`~newton._src.geometry.contact_match.ContactMatcher`, which + repurposes the sorter's unused ``point0`` scratch between frames to + store the previous frame's world-space contact positions. + + .. note:: + The buffer is **only idle between frames** — i.e. between the end + of one :meth:`sort_full` call and the start of the next. Writes + outside that window will corrupt the next sort. Do not write to + this buffer unless you are implementing cross-frame state that + coordinates with the pipeline's per-frame call order. + """ + return self._full_point0_buf + + @property + def scratch_normal(self) -> wp.array: + """Shared scratch buffer for external cross-frame world-space normals. + + Sized ``capacity`` :class:`wp.vec3`. Companion to + :attr:`scratch_pos_world`; see that property for usage constraints. + """ + return self._full_normal_buf + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ diff --git a/newton/_src/sim/collide.py b/newton/_src/sim/collide.py index 953ec6e367..85ea9090bd 100644 --- a/newton/_src/sim/collide.py +++ b/newton/_src/sim/collide.py @@ -12,6 +12,7 @@ from ..geometry.broad_phase_sap import BroadPhaseSAP from ..geometry.collision_core import compute_tight_aabb_from_support from ..geometry.contact_data import ContactData, make_contact_sort_key +from ..geometry.contact_match import ContactMatcher from ..geometry.contact_sort import ContactSorter from ..geometry.differentiable_contacts import launch_differentiable_contact_augment from ..geometry.flags import ShapeFlags @@ -477,6 +478,10 @@ def __init__( narrow_phase: NarrowPhase | None = None, sdf_hydroelastic_config: HydroelasticSDF.Config | None = None, deterministic: bool = False, + contact_matching: Literal["disabled", "latest", "sticky"] = "disabled", + contact_matching_pos_threshold: float = 0.0005, + contact_matching_normal_dot_threshold: float = 0.995, + contact_report: bool = False, verify_buffers: bool = True, ): """ @@ -512,6 +517,23 @@ def __init__( deterministic: Sort contacts after the narrow phase so that results are independent of GPU thread scheduling. Adds a radix sort + gather pass. Hydroelastic contacts are not yet covered. + contact_matching: Frame-to-frame contact matching mode. One of + ``"disabled"``, ``"latest"``, or ``"sticky"``. Any + non-disabled mode implies ``deterministic=True`` and + populates :attr:`Contacts.rigid_contact_match_index`. + Defaults to ``"disabled"``. + contact_matching_pos_threshold: World-space distance threshold [m] + between the previous and current contact midpoints + ``0.5 * (world(point0) + world(point1))``. Contacts whose + midpoint moves more than this are considered broken. Defaults + to ``0.0005``. + contact_matching_normal_dot_threshold: Minimum dot product between + old and new contact normals for a match. + contact_report: Allocate ``rigid_contact_new_indices`` / + ``rigid_contact_new_count`` / ``rigid_contact_broken_indices`` + / ``rigid_contact_broken_count`` on the :class:`Contacts` + container, populated each frame. Requires a non-disabled + ``contact_matching`` mode. verify_buffers: Run a ``dim=[1]`` diagnostic kernel at the end of the narrow phase that prints warnings on any intermediate candidate-pair or final rigid contact buffer overflow; see @@ -525,6 +547,28 @@ def __init__( rigid-contact autodiff via ``rigid_contact_diff_*`` is **experimental**; see :meth:`collide`. """ + if contact_matching not in ("disabled", "latest", "sticky"): + raise ValueError( + f"contact_matching must be one of 'disabled', 'latest', 'sticky', got {contact_matching!r}" + ) + + if contact_matching_pos_threshold < 0.0: + raise ValueError( + f"contact_matching_pos_threshold must be non-negative, got {contact_matching_pos_threshold}" + ) + if not -1.0 <= contact_matching_normal_dot_threshold <= 1.0: + raise ValueError( + f"contact_matching_normal_dot_threshold must be in [-1, 1], got {contact_matching_normal_dot_threshold}" + ) + matching_enabled = contact_matching != "disabled" + matching_sticky = contact_matching == "sticky" + if contact_report and not matching_enabled: + raise ValueError('contact_report=True requires contact_matching != "disabled"') + + # Any non-disabled matching mode implies deterministic sorting. + if matching_enabled: + deterministic = True + mode_from_broad_phase: str | None = None broad_phase_instance: BroadPhaseAllPairs | BroadPhaseSAP | BroadPhaseExplicit | None = None if broad_phase is not None: @@ -727,8 +771,8 @@ def __init__( self._soft_contact_max = soft_contact_max self.requires_grad = requires_grad self.deterministic = deterministic + per_contact_props = self.narrow_phase.hydroelastic_sdf is not None if deterministic: - per_contact_props = self.narrow_phase.hydroelastic_sdf is not None with wp.ScopedDevice(device): self._sort_key_array = wp.zeros(rigid_contact_max, dtype=wp.int64, device=device) self._contact_sorter = ContactSorter( @@ -738,6 +782,23 @@ def __init__( self._sort_key_array = wp.zeros(0, dtype=wp.int64, device=device) self._contact_sorter = None + self.contact_matching = contact_matching + self._matching_enabled = matching_enabled + self._matching_sticky = matching_sticky + self.contact_report = contact_report + if matching_enabled: + self._contact_matcher = ContactMatcher( + rigid_contact_max, + sorter=self._contact_sorter, + pos_threshold=contact_matching_pos_threshold, + normal_dot_threshold=contact_matching_normal_dot_threshold, + contact_report=contact_report, + sticky=matching_sticky, + device=device, + ) + else: + self._contact_matcher = None + @property def rigid_contact_max(self) -> int: """Maximum rigid contact buffer capacity used by this pipeline.""" @@ -770,6 +831,8 @@ def contacts(self) -> Contacts: device=self.model.device, per_contact_shape_properties=self.narrow_phase.hydroelastic_sdf is not None, requested_attributes=self.model.get_requested_contact_attributes(), + contact_matching=self._matching_enabled, + contact_report=self.contact_report, ) # attach custom attributes with assignment==CONTACT @@ -971,6 +1034,28 @@ def collide( device=self.device, ) + # Match contacts against previous frame before sorting. + if self._contact_matcher is not None: + if contacts.rigid_contact_match_index is None: + raise ValueError( + "CollisionPipeline has contact_matching enabled but the " + "Contacts buffer was created without contact_matching. " + "Use pipeline.contacts() to create a compatible buffer." + ) + self._contact_matcher.match( + sort_keys=self._sort_key_array, + contact_count=contacts.rigid_contact_count, + point0=contacts.rigid_contact_point0, + point1=contacts.rigid_contact_point1, + shape0=contacts.rigid_contact_shape0, + shape1=contacts.rigid_contact_shape1, + normal=contacts.rigid_contact_normal, + body_q=state.body_q, + shape_body=model.shape_body, + match_index_out=contacts.rigid_contact_match_index, + device=self.device, + ) + if self.deterministic and self._contact_sorter is not None: self._contact_sorter.sort_full( self._sort_key_array, @@ -988,7 +1073,66 @@ def collide( stiffness=contacts.rigid_contact_stiffness, damping=contacts.rigid_contact_damping, friction=contacts.rigid_contact_friction, + match_index=contacts.rigid_contact_match_index, + device=self.device, + ) + + # Sticky mode: overwrite matched rows with the saved previous-frame + # contact geometry. Must run after sort_full (so match_index points at + # the sorted prev-frame layout *and* we target the final sorted rows) + # and before save_sorted_state (we save the record we actually used + # this frame, carrying the sticky history forward). + if self._matching_sticky: + self._contact_matcher.replay_matched( + contact_count=contacts.rigid_contact_count, + match_index=contacts.rigid_contact_match_index, + point0=contacts.rigid_contact_point0, + point1=contacts.rigid_contact_point1, + offset0=contacts.rigid_contact_offset0, + offset1=contacts.rigid_contact_offset1, + normal=contacts.rigid_contact_normal, + device=self.device, + ) + + # Build the contact report before saving state, because save + # overwrites _prev_count and the report needs the old value. + if self._contact_matcher is not None: + if self._contact_matcher.has_report: + if contacts.rigid_contact_new_indices is None: + raise ValueError( + "CollisionPipeline has contact_report enabled but the Contacts " + "buffer was created without contact_report=True. " + "Use pipeline.contacts() to create a compatible buffer." + ) + self._contact_matcher.build_report( + contacts.rigid_contact_match_index, + contacts.rigid_contact_count, + contacts.rigid_contact_new_indices, + contacts.rigid_contact_new_count, + contacts.rigid_contact_broken_indices, + contacts.rigid_contact_broken_count, + device=self.device, + ) + sticky_offsets: dict[str, wp.array] = ( + { + "sorted_offset0": contacts.rigid_contact_offset0, + "sorted_offset1": contacts.rigid_contact_offset1, + } + if self._matching_sticky + else {} + ) + self._contact_matcher.save_sorted_state( + sorted_keys=self._contact_sorter.sorted_keys_view, + contact_count=contacts.rigid_contact_count, + sorted_point0=contacts.rigid_contact_point0, + sorted_point1=contacts.rigid_contact_point1, + sorted_shape0=contacts.rigid_contact_shape0, + sorted_shape1=contacts.rigid_contact_shape1, + sorted_normal=contacts.rigid_contact_normal, + body_q=state.body_q, + shape_body=model.shape_body, device=self.device, + **sticky_offsets, ) # Differentiable contact augmentation: reconstruct world-space contact diff --git a/newton/_src/sim/contacts.py b/newton/_src/sim/contacts.py index b913d15d2e..56e89d977b 100644 --- a/newton/_src/sim/contacts.py +++ b/newton/_src/sim/contacts.py @@ -93,6 +93,8 @@ def __init__( per_contact_shape_properties: bool = False, clear_buffers: bool = False, requested_attributes: set[str] | None = None, + contact_matching: bool = False, + contact_report: bool = False, ): """ Initialize Contacts storage. @@ -115,11 +117,22 @@ def __init__( than the conservative path and safe since solvers only read up to contact_count. requested_attributes: Set of extended contact attribute names to allocate. See :attr:`EXTENDED_ATTRIBUTES` for available options. + contact_matching: Allocate a per-contact match index array + (:attr:`rigid_contact_match_index`) that stores frame-to-frame + contact correspondences filled by the collision pipeline. + contact_report: Allocate compact index lists of new and broken + contacts (:attr:`rigid_contact_new_indices`, + :attr:`rigid_contact_new_count`, + :attr:`rigid_contact_broken_indices`, + :attr:`rigid_contact_broken_count`) populated each frame by + the collision pipeline. Requires ``contact_matching=True``. .. note:: The ``rigid_contact_diff_*`` arrays allocated when ``requires_grad=True`` are **experimental**; see :meth:`newton.CollisionPipeline.collide`. """ + if contact_report and not contact_matching: + raise ValueError("contact_report=True requires contact_matching=True") self.per_contact_shape_properties = per_contact_shape_properties self.clear_buffers = clear_buffers with wp.ScopedDevice(device): @@ -210,6 +223,43 @@ def __init__( self.rigid_contact_friction = None """Per-contact friction coefficient [dimensionless], shape (rigid_contact_max,), dtype float.""" + # Contact matching index — filled by the collision pipeline when + # contact_matching is enabled. + self.contact_matching = contact_matching + self.contact_report = contact_report + if contact_matching: + self.rigid_contact_match_index = wp.full(rigid_contact_max, -1, dtype=wp.int32) + """Per-contact match index from frame-to-frame matching. + + Values: ``>= 0`` matched old contact index; + :data:`newton.geometry.MATCH_NOT_FOUND` (``-1``) new contact; + :data:`newton.geometry.MATCH_BROKEN` (``-2``) key matched but + position/normal thresholds exceeded. + Shape (rigid_contact_max,), dtype int32.""" + else: + self.rigid_contact_match_index = None + + if contact_report: + self.rigid_contact_new_indices = wp.zeros(rigid_contact_max, dtype=wp.int32) + """Indices of new contacts in the current sorted buffer (where ``match_index < 0``). + + Valid after the collision pipeline runs. + Shape (rigid_contact_max,), dtype int32.""" + self.rigid_contact_new_count = wp.zeros(1, dtype=wp.int32) + """Device-side count of new contacts (single-element int32).""" + self.rigid_contact_broken_indices = wp.zeros(rigid_contact_max, dtype=wp.int32) + """Indices of broken contacts in the previous frame's sorted buffer. + + Valid after the collision pipeline runs. + Shape (rigid_contact_max,), dtype int32.""" + self.rigid_contact_broken_count = wp.zeros(1, dtype=wp.int32) + """Device-side count of broken contacts (single-element int32).""" + else: + self.rigid_contact_new_indices = None + self.rigid_contact_new_count = None + self.rigid_contact_broken_indices = None + self.rigid_contact_broken_count = None + # soft contacts — requires_grad flows through here for differentiable simulation self.soft_contact_count = self.contact_counters[1:2] self.soft_contact_particle = wp.full(soft_contact_max, -1, dtype=int) @@ -290,6 +340,9 @@ def clear(self, bump_generation: bool = True): self.rigid_contact_damping.zero_() self.rigid_contact_friction.zero_() + if self.rigid_contact_match_index is not None: + self.rigid_contact_match_index.fill_(-1) + self.soft_contact_particle.fill_(-1) self.soft_contact_shape.fill_(-1) self.soft_contact_tids.fill_(-1) diff --git a/newton/geometry.py b/newton/geometry.py index 327f6c782e..4ab443a97b 100644 --- a/newton/geometry.py +++ b/newton/geometry.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from ._src.geometry import ( + MATCH_BROKEN, + MATCH_NOT_FOUND, BroadPhaseAllPairs, BroadPhaseExplicit, BroadPhaseSAP, @@ -25,6 +27,8 @@ from ._src.geometry.sdf_utils import compute_offset_mesh, create_empty_sdf_data __all__ = [ + "MATCH_BROKEN", + "MATCH_NOT_FOUND", "BroadPhaseAllPairs", "BroadPhaseExplicit", "BroadPhaseSAP", diff --git a/newton/tests/test_contact_matching.py b/newton/tests/test_contact_matching.py new file mode 100644 index 0000000000..d3131e59a5 --- /dev/null +++ b/newton/tests/test_contact_matching.py @@ -0,0 +1,741 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for frame-to-frame contact matching.""" + +import unittest + +import numpy as np +import warp as wp + +import newton +from newton.tests.unittest_utils import add_function_test, get_test_devices + + +class TestContactMatching(unittest.TestCase): + pass + + +class TestContactMatchingSticky(unittest.TestCase): + pass + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_simple_scene(device): + """Build a scene with 3 spheres resting on a ground plane. + + Returns (model, state). Spheres at x = -0.5, 0.0, 0.5, all at z = radius + so they touch the plane. + """ + builder = newton.ModelBuilder() + builder.add_ground_plane() + + for x in (-0.5, 0.0, 0.5): + b = builder.add_body(xform=wp.transform(wp.vec3(x, 0.0, 0.1))) + builder.add_shape_sphere(body=b, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + return model, state + + +def _collide_once(pipeline, state, contacts): + """Clear and collide, returning the contact count on host.""" + contacts.clear() + pipeline.collide(state, contacts) + return contacts.rigid_contact_count.numpy()[0] + + +# --------------------------------------------------------------------------- +# Test functions +# --------------------------------------------------------------------------- + + +def test_first_frame_all_not_found(test, device): + """First frame: prev_count is 0, so every contact must get MATCH_NOT_FOUND.""" + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + count = _collide_once(pipeline, state, contacts) + test.assertGreater(count, 0, "Expected contacts between spheres and ground plane") + + match_idx = contacts.rigid_contact_match_index.numpy()[:count] + test.assertTrue( + np.all(match_idx == -1), + f"First frame should have all MATCH_NOT_FOUND, got unique values: {np.unique(match_idx)}", + ) + + +def test_stable_scene_identity_match(test, device): + """Stable scene: deterministic sort + identical state means match_index[i] == i. + + This is the strongest possible invariant: each sorted contact maps to the + same sorted position in the previous frame. It verifies binary search, + position/normal threshold acceptance, sort permutation of match_index, + and the save-then-match round-trip through the sorter's scratch buffers. + """ + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + # Frame 1: populate previous-frame data. + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Frame 2: identical state. + count2 = _collide_once(pipeline, state, contacts) + test.assertEqual(count1, count2, "Contact count must be stable between identical frames") + + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + expected = np.arange(count2, dtype=np.int32) + np.testing.assert_array_equal( + match_idx, + expected, + err_msg="Stable scene: match_index[i] must be i (identity mapping)", + ) + + +def test_stable_scene_identity_across_three_frames(test, device): + """Identity match must hold across 3+ frames, not just the first pair.""" + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + _collide_once(pipeline, state, contacts) # frame 1 + for frame in range(2, 5): + count = _collide_once(pipeline, state, contacts) + match_idx = contacts.rigid_contact_match_index.numpy()[:count] + expected = np.arange(count, dtype=np.int32) + np.testing.assert_array_equal( + match_idx, + expected, + err_msg=f"Frame {frame}: match_index must be identity", + ) + + +def test_new_contact_detection(test, device): + """A new sphere that enters the scene produces MATCH_NOT_FOUND, + while existing contacts keep their identity match. + """ + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + builder.add_ground_plane() + for x in (-0.5, 0.5): + b = builder.add_body(xform=wp.transform(wp.vec3(x, 0.0, 0.1))) + builder.add_shape_sphere(body=b, radius=0.1) + # Third sphere far away — no contacts in frame 1. + b3 = builder.add_body(xform=wp.transform(wp.vec3(0.0, 0.0, 10.0))) + builder.add_shape_sphere(body=b3, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + # Frame 1: 2 sphere-plane contacts. + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Move third sphere to ground for frame 2. + q = state.body_q.numpy() + q[2][0:3] = [0.0, 0.0, 0.1] + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + test.assertGreater(count2, count1, "More contacts expected with third sphere on ground") + + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + + n_new = np.sum(match_idx == -1) + n_matched = np.sum(match_idx >= 0) + test.assertGreater(n_new, 0, "New sphere should produce MATCH_NOT_FOUND contacts") + test.assertEqual(n_matched, count1, f"All {count1} old contacts should still match, got {n_matched}") + + # Matched indices must be unique (no two new contacts claim the same old). + matched_vals = match_idx[match_idx >= 0] + test.assertEqual(len(np.unique(matched_vals)), len(matched_vals), "Matched indices must be unique") + + +def test_broken_pos_threshold_all_contacts(test, device): + """Moving all spheres beyond pos_threshold must break ALL contacts (not just some). + + Uses the default :attr:`CollisionPipeline.contact_matching_pos_threshold` so + the test follows any future retune of the default. ``contact_report=True`` + lets us close the loop and verify each broken new contact has a matching + entry in ``rigid_contact_broken_indices`` (the old contact was also + reported as broken — broken-on-both-sides). + """ + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline( + model, + broad_phase="nxn", + contact_matching="latest", + contact_report=True, + ) + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Shift all dynamic bodies along x by 0.2 m — well above the default + # (0.0005 m) pos_threshold but small enough to keep them on the plane. + q = state.body_q.numpy() + for i in range(len(q)): + q[i][0] += 0.2 + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + + # Every new contact should be MATCH_BROKEN (-2): key matches but + # position drifted beyond threshold. + test.assertTrue( + np.all(match_idx == newton.geometry.MATCH_BROKEN), + f"All contacts should be MATCH_BROKEN. Unique values: {np.unique(match_idx)}", + ) + + # And every old contact should appear in broken_contact_indices: + # if the new side is broken, the old side must also be broken + # (nothing matched it). + broken_count = contacts.rigid_contact_broken_count.numpy()[0] + test.assertEqual( + broken_count, + count1, + f"All {count1} old contacts should be reported as broken, got {broken_count}", + ) + broken_indices = contacts.rigid_contact_broken_indices.numpy()[:broken_count] + np.testing.assert_array_equal( + np.sort(broken_indices), + np.arange(count1, dtype=np.int32), + err_msg="broken_contact_indices must enumerate every old contact", + ) + + +def test_within_pos_threshold_still_matches(test, device): + """Moving spheres less than pos_threshold must still produce matches. + + Uses the default :attr:`CollisionPipeline.contact_matching_pos_threshold` + (0.0005 m) so the test follows any future retune of the default. + """ + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline( + model, + broad_phase="nxn", + contact_matching="latest", + ) + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Shift all dynamic bodies along x by 0.0002 m — below the default + # (0.0005 m) pos_threshold. + q = state.body_q.numpy() + for i in range(len(q)): + q[i][0] += 0.0002 + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + + test.assertTrue( + np.all(match_idx >= 0), + f"All contacts should match within default threshold. Unique: {np.unique(match_idx)}", + ) + + +def test_broken_normal_threshold(test, device): + """Moving a sphere so the contact normal direction changes beyond threshold + produces MATCH_BROKEN. + + Two spheres (radius 0.1) overlap in frame 1 along x-axis (normal ≈ (1,0,0)). + In frame 2, sphere B moves so they overlap along y-axis (normal ≈ (0,1,0)). + Same shape pair / sub_key, generous pos_threshold, but dot((1,0,0), (0,1,0)) = 0 + which is below any reasonable normal_dot_threshold → MATCH_BROKEN. + """ + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + # Two spheres overlapping along x-axis. + ba = builder.add_body(xform=wp.transform(wp.vec3(0.0, 0.0, 0.0))) + builder.add_shape_sphere(body=ba, radius=0.1) + bb = builder.add_body(xform=wp.transform(wp.vec3(0.19, 0.0, 0.0))) + builder.add_shape_sphere(body=bb, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + + pipeline = newton.CollisionPipeline( + model, + broad_phase="nxn", + contact_matching="latest", + contact_matching_pos_threshold=10.0, # very generous — ignore position + contact_matching_normal_dot_threshold=0.5, # cos(60°) — perpendicular normals break + ) + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0, "Overlapping spheres must produce contacts") + + # Move sphere B so they overlap along y-axis instead. + q = state.body_q.numpy() + q[1][0:3] = [0.0, 0.19, 0.0] + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + test.assertGreater(count2, 0, "Repositioned spheres must still produce contacts") + + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + test.assertTrue( + np.all(match_idx == -2), + f"Normal changed ~90°, all should be MATCH_BROKEN. Unique: {np.unique(match_idx)}", + ) + + +def test_contact_report_indices_correct(test, device): + """Contact report indices must be consistent with match_index values.""" + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest", contact_report=True) + contacts = pipeline.contacts() + + # Frame 1: all contacts are new. + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + new_count1 = contacts.rigid_contact_new_count.numpy()[0] + test.assertEqual(new_count1, count1, "First frame: all contacts should be new") + + # Verify new_contact_indices point to valid sorted positions. + new_indices1 = contacts.rigid_contact_new_indices.numpy()[:new_count1] + test.assertTrue(np.all(new_indices1 >= 0) and np.all(new_indices1 < count1)) + + # Verify new_contact_indices match the actual -1 positions in match_index. + match_idx1 = contacts.rigid_contact_match_index.numpy()[:count1] + expected_new = np.where(match_idx1 < 0)[0].astype(np.int32) + np.testing.assert_array_equal( + np.sort(new_indices1), + np.sort(expected_new), + err_msg="rigid_contact_new_indices must match positions where match_index < 0", + ) + + # Frame 2: stable scene — no new, no broken. + _collide_once(pipeline, state, contacts) + test.assertEqual(contacts.rigid_contact_new_count.numpy()[0], 0) + test.assertEqual(contacts.rigid_contact_broken_count.numpy()[0], 0) + + +def test_contact_report_broken_indices(test, device): + """Broken contact report must list old contacts that disappeared.""" + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + builder.add_ground_plane() + for x in (-0.5, 0.5): + b = builder.add_body(xform=wp.transform(wp.vec3(x, 0.0, 0.1))) + builder.add_shape_sphere(body=b, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest", contact_report=True) + contacts = pipeline.contacts() + + # Frame 1: 2 sphere-plane contacts. + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Frame 2: move one sphere far away so its contact disappears. + q = state.body_q.numpy() + q[1][0:3] = [0.5, 0.0, 10.0] # second sphere flies away + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + test.assertLess(count2, count1, "Fewer contacts after removing a sphere") + + broken_count = contacts.rigid_contact_broken_count.numpy()[0] + test.assertGreater(broken_count, 0, "Should have broken contacts from the removed sphere") + + # Broken indices must be valid positions in the OLD sorted buffer. + broken_indices = contacts.rigid_contact_broken_indices.numpy()[:broken_count] + test.assertTrue( + np.all(broken_indices >= 0) and np.all(broken_indices < count1), + f"Broken indices must be in [0, {count1}), got: {broken_indices}", + ) + + +def test_deterministic_implied(test, device): + """Any non-disabled contact_matching mode should imply deterministic=True.""" + with wp.ScopedDevice(device): + model, _state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + test.assertTrue(pipeline.deterministic) + test.assertEqual(pipeline.contact_matching, "latest") + + +def test_matching_disabled_no_allocation(test, device): + """DISABLED mode: match_index and report arrays should be None.""" + with wp.ScopedDevice(device): + model, _state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", deterministic=True) + contacts = pipeline.contacts() + test.assertIsNone(contacts.rigid_contact_match_index) + test.assertIsNone(contacts.rigid_contact_new_indices) + test.assertIsNone(contacts.rigid_contact_broken_indices) + test.assertEqual(pipeline.contact_matching, "disabled") + + +def test_match_index_valid_after_sort(test, device): + """After sorting, match indices must be in valid range and unique.""" + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + _collide_once(pipeline, state, contacts) # frame 1 + count = _collide_once(pipeline, state, contacts) # frame 2 + + match_idx = contacts.rigid_contact_match_index.numpy()[:count] + matched = match_idx[match_idx >= 0] + + test.assertTrue(np.all(matched < count), f"Indices must be < {count}, max: {matched.max()}") + test.assertEqual(len(np.unique(matched)), len(matched), "Matched indices must be unique") + + +def test_dynamic_body_world_transform(test, device): + """Two dynamic spheres (no ground plane) must produce identity match. + + This exercises the ``body_q[bid]`` world-space transform path in both the + match and save kernels (bid != -1), which the ground-plane tests skip. + """ + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + ba = builder.add_body(xform=wp.transform(wp.vec3(0.0, 0.0, 0.0))) + builder.add_shape_sphere(body=ba, radius=0.1) + bb = builder.add_body(xform=wp.transform(wp.vec3(0.19, 0.0, 0.0))) + builder.add_shape_sphere(body=bb, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + + # Verify shape0 is a dynamic body (not ground). + sb = model.shape_body.numpy() + test.assertNotEqual(sb[0], -1, "shape0 should be a dynamic body in this test") + + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Frame 2: identical state → identity match. + count2 = _collide_once(pipeline, state, contacts) + test.assertEqual(count1, count2) + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + np.testing.assert_array_equal( + match_idx, + np.arange(count2, dtype=np.int32), + err_msg="Dynamic-body stable scene must produce identity match", + ) + + +def test_box_on_plane_multiple_contacts(test, device): + """A box on a plane produces multiple contacts per shape pair (sub_keys 0-3). + + This verifies matching works when a single shape pair generates several + contacts with distinct sort sub-keys, and that the identity invariant + holds for all of them. + """ + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + builder.add_ground_plane() + b = builder.add_body(xform=wp.transform(wp.vec3(0.0, 0.0, 0.15))) + builder.add_shape_box(body=b, hx=0.1, hy=0.1, hz=0.1) + + model = builder.finalize(device=device) + state = model.state() + + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 1, "Box on plane should produce multiple contacts") + + # Frame 2: identical state → identity match for all contacts. + count2 = _collide_once(pipeline, state, contacts) + test.assertEqual(count1, count2) + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + np.testing.assert_array_equal( + match_idx, + np.arange(count2, dtype=np.int32), + err_msg="Box multi-contact stable scene must produce identity match", + ) + + +def test_invalid_mode_raises(test, device): + """Invalid contact_matching values must raise ValueError.""" + with wp.ScopedDevice(device): + model, _state = _build_simple_scene(device) + + with test.assertRaises(ValueError): + newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="bogus") + + with test.assertRaises(ValueError): + # Booleans no longer accepted. + newton.CollisionPipeline(model, broad_phase="nxn", contact_matching=True) + + +def test_contact_report_requires_matching(test, device): + """contact_report=True requires a non-disabled matching mode.""" + with wp.ScopedDevice(device): + model, _state = _build_simple_scene(device) + with test.assertRaises(ValueError): + newton.CollisionPipeline( + model, + broad_phase="nxn", + contact_matching="disabled", + contact_report=True, + ) + + +# --------------------------------------------------------------------------- +# Sticky mode tests +# --------------------------------------------------------------------------- + + +def test_sticky_matched_rows_replayed(test, device): + """STICKY mode: matched rows carry exact previous-frame geometry even when + the narrow phase's fresh output differs on a perturbed second frame. + + Frame 2 perturbs the bodies slightly (less than the match threshold) so + the narrow phase produces a different-but-close contact record. Sticky + replay must overwrite ``point0``/``point1``/``offset0``/``offset1`` with + the previous frame's values, so after frame 2 those columns equal the + frame-1 snapshot even though the narrow phase would have produced + something slightly different. + """ + with wp.ScopedDevice(device): + model, state = _build_simple_scene(device) + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="sticky") + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + snap_point0 = contacts.rigid_contact_point0.numpy()[:count1].copy() + snap_point1 = contacts.rigid_contact_point1.numpy()[:count1].copy() + snap_offset0 = contacts.rigid_contact_offset0.numpy()[:count1].copy() + snap_offset1 = contacts.rigid_contact_offset1.numpy()[:count1].copy() + snap_normal = contacts.rigid_contact_normal.numpy()[:count1].copy() + + # Perturb every body by 0.1 mm in x -- well below the 0.5 mm default + # pos threshold so every contact still matches, but enough for the + # narrow phase to produce a detectably different fresh record. + q = state.body_q.numpy() + for i in range(len(q)): + q[i][0] += 0.0001 + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + # Also run the narrow phase on a fresh (non-sticky) pipeline with + # the same state, so we can confirm the fresh contact values really + # differ from frame 1 -- otherwise the sticky assertion below would + # pass trivially. + pipeline_fresh = newton.CollisionPipeline(model, broad_phase="nxn") + contacts_fresh = pipeline_fresh.contacts() + _collide_once(pipeline_fresh, state, contacts_fresh) + fresh_point0 = contacts_fresh.rigid_contact_point0.numpy()[:count1] + + count2 = _collide_once(pipeline, state, contacts) + test.assertEqual(count1, count2) + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + test.assertTrue( + np.all(match_idx >= 0), + f"All perturbed contacts should still match. Unique: {np.unique(match_idx)}", + ) + + # Sanity: fresh narrow phase really did produce different point0 values + # on the perturbed frame, so the sticky assertion below is non-trivial. + test.assertFalse( + np.array_equal(fresh_point0, snap_point0), + "Precondition: perturbation must change fresh narrow-phase point0", + ) + + # Sticky contract: replayed fields equal the frame-1 snapshot. + for field, prev in ( + ("point0", snap_point0), + ("point1", snap_point1), + ("offset0", snap_offset0), + ("offset1", snap_offset1), + ("normal", snap_normal), + ): + current = getattr(contacts, f"rigid_contact_{field}").numpy()[:count2] + np.testing.assert_array_equal( + current, + prev, + err_msg=f"Sticky mode: matched rows must carry prev-frame {field} byte-for-byte", + ) + + +def test_sticky_unmatched_rows_pass_through(test, device): + """STICKY mode: unmatched rows keep the current frame's narrow-phase data. + + Add a new sphere to the scene in frame 2. Its contacts have + match_index < 0, so sticky replay must NOT overwrite them — their + shape indices must reflect the newly added shape. + """ + with wp.ScopedDevice(device): + builder = newton.ModelBuilder() + builder.add_ground_plane() + for x in (-0.5, 0.5): + b = builder.add_body(xform=wp.transform(wp.vec3(x, 0.0, 0.1))) + builder.add_shape_sphere(body=b, radius=0.1) + # Third sphere parked out of the way for frame 1. + b3 = builder.add_body(xform=wp.transform(wp.vec3(0.0, 0.0, 10.0))) + new_shape = builder.add_shape_sphere(body=b3, radius=0.1) + + model = builder.finalize(device=device) + state = model.state() + pipeline = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="sticky") + contacts = pipeline.contacts() + + count1 = _collide_once(pipeline, state, contacts) + test.assertGreater(count1, 0) + + # Bring the third sphere down onto the ground. + q = state.body_q.numpy() + q[2][0:3] = [0.0, 0.0, 0.1] + state.body_q = wp.array(q, dtype=wp.transform, device=device) + + count2 = _collide_once(pipeline, state, contacts) + test.assertGreater(count2, count1) + + match_idx = contacts.rigid_contact_match_index.numpy()[:count2] + shape0 = contacts.rigid_contact_shape0.numpy()[:count2] + shape1 = contacts.rigid_contact_shape1.numpy()[:count2] + + unmatched = match_idx < 0 + test.assertTrue(unmatched.any(), "Frame 2 must introduce at least one unmatched contact") + + # At least one unmatched row must reference the newly added shape, + # proving sticky replay did not overwrite new contacts with stale data. + involves_new = (shape0 == new_shape) | (shape1 == new_shape) + test.assertTrue( + (involves_new & unmatched).any(), + "Unmatched rows must pass through the new narrow-phase contacts for the new shape", + ) + + # Sanity: matched rows still carry valid shape indices (not -1 from + # the default-fill sentinel). + matched_mask = match_idx >= 0 + test.assertTrue( + np.all(shape0[matched_mask] >= 0) and np.all(shape1[matched_mask] >= 0), + "Matched rows must have non-sentinel shape indices after replay", + ) + + +def test_sticky_disabled_no_sticky_buffers(test, device): + """LATEST and DISABLED modes must not allocate sticky buffers.""" + with wp.ScopedDevice(device): + model, _state = _build_simple_scene(device) + + p_latest = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="latest") + test.assertIsNotNone(p_latest._contact_matcher) + test.assertFalse(p_latest._contact_matcher.is_sticky) + test.assertIsNone(p_latest._contact_matcher._prev_point0) + test.assertIsNone(p_latest._contact_matcher._prev_point1) + test.assertIsNone(p_latest._contact_matcher._prev_offset0) + test.assertIsNone(p_latest._contact_matcher._prev_offset1) + + p_off = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="disabled") + test.assertIsNone(p_off._contact_matcher) + + p_sticky = newton.CollisionPipeline(model, broad_phase="nxn", contact_matching="sticky") + test.assertTrue(p_sticky._contact_matcher.is_sticky) + test.assertIsNotNone(p_sticky._contact_matcher._prev_point0) + test.assertIsNotNone(p_sticky._contact_matcher._prev_point1) + test.assertIsNotNone(p_sticky._contact_matcher._prev_offset0) + test.assertIsNotNone(p_sticky._contact_matcher._prev_offset1) + + +# --------------------------------------------------------------------------- +# Register tests +# --------------------------------------------------------------------------- + +devices = get_test_devices() + +add_function_test( + TestContactMatching, "test_first_frame_all_not_found", test_first_frame_all_not_found, devices=devices +) +add_function_test( + TestContactMatching, "test_stable_scene_identity_match", test_stable_scene_identity_match, devices=devices +) +add_function_test( + TestContactMatching, + "test_stable_scene_identity_across_three_frames", + test_stable_scene_identity_across_three_frames, + devices=devices, +) +add_function_test(TestContactMatching, "test_new_contact_detection", test_new_contact_detection, devices=devices) +add_function_test( + TestContactMatching, + "test_broken_pos_threshold_all_contacts", + test_broken_pos_threshold_all_contacts, + devices=devices, +) +add_function_test( + TestContactMatching, + "test_within_pos_threshold_still_matches", + test_within_pos_threshold_still_matches, + devices=devices, +) +add_function_test(TestContactMatching, "test_broken_normal_threshold", test_broken_normal_threshold, devices=devices) +add_function_test( + TestContactMatching, "test_contact_report_indices_correct", test_contact_report_indices_correct, devices=devices +) +add_function_test( + TestContactMatching, "test_contact_report_broken_indices", test_contact_report_broken_indices, devices=devices +) +add_function_test(TestContactMatching, "test_deterministic_implied", test_deterministic_implied, devices=devices) +add_function_test( + TestContactMatching, "test_matching_disabled_no_allocation", test_matching_disabled_no_allocation, devices=devices +) +add_function_test( + TestContactMatching, "test_match_index_valid_after_sort", test_match_index_valid_after_sort, devices=devices +) +add_function_test( + TestContactMatching, "test_dynamic_body_world_transform", test_dynamic_body_world_transform, devices=devices +) +add_function_test( + TestContactMatching, "test_box_on_plane_multiple_contacts", test_box_on_plane_multiple_contacts, devices=devices +) +add_function_test(TestContactMatching, "test_invalid_mode_raises", test_invalid_mode_raises, devices=devices) +add_function_test( + TestContactMatching, "test_contact_report_requires_matching", test_contact_report_requires_matching, devices=devices +) + +add_function_test( + TestContactMatchingSticky, "test_sticky_matched_rows_replayed", test_sticky_matched_rows_replayed, devices=devices +) +add_function_test( + TestContactMatchingSticky, + "test_sticky_unmatched_rows_pass_through", + test_sticky_unmatched_rows_pass_through, + devices=devices, +) +add_function_test( + TestContactMatchingSticky, + "test_sticky_disabled_no_sticky_buffers", + test_sticky_disabled_no_sticky_buffers, + devices=devices, +) + +if __name__ == "__main__": + wp.clear_kernel_cache() + unittest.main(verbosity=2)