From 0ac716d8b8a195ab21da848f134990260cfcafc2 Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 6 Apr 2026 19:35:48 +0000 Subject: [PATCH 1/7] Add private FeatherPGS solver implementation --- newton/_src/solvers/feather_pgs/__init__.py | 20 + newton/_src/solvers/feather_pgs/kernels.py | 4146 ++++++++++++++ .../solvers/feather_pgs/solver_feather_pgs.py | 4821 +++++++++++++++++ newton/_src/solvers/featherstone/kernels.py | 58 +- 4 files changed, 9020 insertions(+), 25 deletions(-) create mode 100644 newton/_src/solvers/feather_pgs/__init__.py create mode 100644 newton/_src/solvers/feather_pgs/kernels.py create mode 100644 newton/_src/solvers/feather_pgs/solver_feather_pgs.py diff --git a/newton/_src/solvers/feather_pgs/__init__.py b/newton/_src/solvers/feather_pgs/__init__.py new file mode 100644 index 0000000000..3163cab7f0 --- /dev/null +++ b/newton/_src/solvers/feather_pgs/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .solver_feather_pgs import SolverFeatherPGS + +__all__ = [ + "SolverFeatherPGS", +] diff --git a/newton/_src/solvers/feather_pgs/kernels.py b/newton/_src/solvers/feather_pgs/kernels.py new file mode 100644 index 0000000000..5389545566 --- /dev/null +++ b/newton/_src/solvers/feather_pgs/kernels.py @@ -0,0 +1,4146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warp as wp + +from ...math.spatial import transform_twist +from ...sim import JointType +from ...sim.articulation import ( + compute_2d_rotational_dofs, + compute_3d_rotational_dofs, +) + +PGS_CONSTRAINT_TYPE_CONTACT = 0 +PGS_CONSTRAINT_TYPE_JOINT_TARGET = 1 +PGS_CONSTRAINT_TYPE_FRICTION = 2 +PGS_CONSTRAINT_TYPE_JOINT_LIMIT = 3 + + +@wp.kernel +def copy_int_array_masked( + src: wp.array(dtype=int), + mask: wp.array(dtype=int), + # outputs + dst: wp.array(dtype=int), +): + tid = wp.tid() + if mask[tid] != 0: + dst[tid] = src[tid] + + +@wp.kernel +def compute_spatial_inertia( + body_inertia: wp.array(dtype=wp.mat33), + body_mass: wp.array(dtype=float), + # outputs + body_I_m: wp.array(dtype=wp.spatial_matrix), +): + tid = wp.tid() + I = body_inertia[tid] + m = body_mass[tid] + # fmt: off + body_I_m[tid] = wp.spatial_matrix( + m, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, m, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, m, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, I[0, 0], I[0, 1], I[0, 2], + 0.0, 0.0, 0.0, I[1, 0], I[1, 1], 
I[1, 2], + 0.0, 0.0, 0.0, I[2, 0], I[2, 1], I[2, 2], + ) + # fmt: on + + +@wp.kernel +def compute_com_transforms( + body_com: wp.array(dtype=wp.vec3), + # outputs + body_X_com: wp.array(dtype=wp.transform), +): + tid = wp.tid() + com = body_com[tid] + body_X_com[tid] = wp.transform(com, wp.quat_identity()) + + +@wp.kernel +def update_articulation_origins( + articulation_start: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + body_q: wp.array(dtype=wp.transform), + body_com: wp.array(dtype=wp.vec3), + # outputs + articulation_origin: wp.array(dtype=wp.vec3), +): + art = wp.tid() + + start = articulation_start[art] + end = articulation_start[art + 1] + + if start >= end: + articulation_origin[art] = wp.vec3() + return + + root_body = joint_child[start] + if root_body >= 0: + # Store the absolute world-space COM position of the articulation root body. + articulation_origin[art] = wp.transform_point(body_q[root_body], body_com[root_body]) + else: + articulation_origin[art] = wp.vec3() + + +@wp.kernel +def update_articulation_root_com_offsets( + articulation_start: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + body_q: wp.array(dtype=wp.transform), + body_com: wp.array(dtype=wp.vec3), + # outputs + articulation_root_com_offset: wp.array(dtype=wp.vec3), +): + # NOTE: This helper keeps the rotated root COM offset in world orientation. + # FeatherPGS currently uses update_articulation_origins() instead, which + # stores the absolute root COM world position for its free-root convention. 
+ art = wp.tid() + + start = articulation_start[art] + end = articulation_start[art + 1] + + if start >= end: + articulation_root_com_offset[art] = wp.vec3() + return + + root_body = joint_child[start] + if root_body >= 0: + rot = wp.transform_get_rotation(body_q[root_body]) + articulation_root_com_offset[art] = wp.quat_rotate(rot, body_com[root_body]) + else: + articulation_root_com_offset[art] = wp.vec3() + + +@wp.kernel +def convert_root_free_qd_world_to_local( + articulation_root_is_free: wp.array(dtype=int), + articulation_root_dof_start: wp.array(dtype=int), + articulation_root_com_offset: wp.array(dtype=wp.vec3), + # in/out + qd: wp.array(dtype=float), +): + art = wp.tid() + if articulation_root_is_free[art] == 0: + return + + ds = articulation_root_dof_start[art] + v_com = wp.vec3(qd[ds + 0], qd[ds + 1], qd[ds + 2]) + w = wp.vec3(qd[ds + 3], qd[ds + 4], qd[ds + 5]) + com_offset = articulation_root_com_offset[art] + + # Shift linear velocity from the public CoM convention to the internal + # root-body-origin linear term used by FeatherPGS integration/ID state. + v_local = v_com - wp.cross(w, com_offset) + + qd[ds + 0] = v_local[0] + qd[ds + 1] = v_local[1] + qd[ds + 2] = v_local[2] + + +@wp.kernel +def convert_root_free_qd_local_to_world( + articulation_root_is_free: wp.array(dtype=int), + articulation_root_dof_start: wp.array(dtype=int), + articulation_root_com_offset: wp.array(dtype=wp.vec3), + # in/out + qd: wp.array(dtype=float), +): + art = wp.tid() + if articulation_root_is_free[art] == 0: + return + + ds = articulation_root_dof_start[art] + v_local = wp.vec3(qd[ds + 0], qd[ds + 1], qd[ds + 2]) + w = wp.vec3(qd[ds + 3], qd[ds + 4], qd[ds + 5]) + com_offset = articulation_root_com_offset[art] + + # Convert the internal root-body-origin linear term back to the public CoM convention. 
+ v_com = v_local + wp.cross(w, com_offset) + + qd[ds + 0] = v_com[0] + qd[ds + 1] = v_com[1] + qd[ds + 2] = v_com[2] + + +@wp.func +def transform_spatial_inertia(t: wp.transform, I: wp.spatial_matrix): + """ + Transform a spatial inertia tensor to a new coordinate frame. + + This computes the change of coordinates for a spatial inertia tensor under a rigid-body + transformation `t`. The result is mathematically equivalent to: + + adj_t^-T * I * adj_t^-1 + + where `adj_t` is the adjoint transformation matrix of `t`, and `I` is the spatial inertia + tensor in the original frame. This operation is described in Lynch & Park, "Modern Robotics", + Section 8.2.3 (pg. 290). + + Args: + t (wp.transform): The rigid-body transform (destination ← source). + I (wp.spatial_matrix): The spatial inertia tensor in the source frame. + + Returns: + wp.spatial_matrix: The spatial inertia tensor expressed in the destination frame. + """ + t_inv = wp.transform_inverse(t) + + q = wp.transform_get_rotation(t_inv) + p = wp.transform_get_translation(t_inv) + + r1 = wp.quat_rotate(q, wp.vec3(1.0, 0.0, 0.0)) + r2 = wp.quat_rotate(q, wp.vec3(0.0, 1.0, 0.0)) + r3 = wp.quat_rotate(q, wp.vec3(0.0, 0.0, 1.0)) + + R = wp.matrix_from_cols(r1, r2, r3) + S = wp.skew(p) @ R + + T = wp.spatial_matrix( + R[0, 0], + R[0, 1], + R[0, 2], + S[0, 0], + S[0, 1], + S[0, 2], + R[1, 0], + R[1, 1], + R[1, 2], + S[1, 0], + S[1, 1], + S[1, 2], + R[2, 0], + R[2, 1], + R[2, 2], + S[2, 0], + S[2, 1], + S[2, 2], + 0.0, + 0.0, + 0.0, + R[0, 0], + R[0, 1], + R[0, 2], + 0.0, + 0.0, + 0.0, + R[1, 0], + R[1, 1], + R[1, 2], + 0.0, + 0.0, + 0.0, + R[2, 0], + R[2, 1], + R[2, 2], + ) + + return wp.mul(wp.mul(wp.transpose(T), I), T) + + +# compute transform across a joint +@wp.func +def jcalc_transform( + type: int, + joint_axis: wp.array(dtype=wp.vec3), + axis_start: int, + lin_axis_count: int, + ang_axis_count: int, + joint_q: wp.array(dtype=float), + q_start: int, +): + if type == JointType.PRISMATIC: + q = joint_q[q_start] +
axis = joint_axis[axis_start] + X_jc = wp.transform(axis * q, wp.quat_identity()) + return X_jc + + if type == JointType.REVOLUTE: + q = joint_q[q_start] + axis = joint_axis[axis_start] + X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q)) + return X_jc + + if type == JointType.BALL: + qx = joint_q[q_start + 0] + qy = joint_q[q_start + 1] + qz = joint_q[q_start + 2] + qw = joint_q[q_start + 3] + + X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw)) + return X_jc + + if type == JointType.FIXED: + X_jc = wp.transform_identity() + return X_jc + + if type == JointType.FREE or type == JointType.DISTANCE: + px = joint_q[q_start + 0] + py = joint_q[q_start + 1] + pz = joint_q[q_start + 2] + + qx = joint_q[q_start + 3] + qy = joint_q[q_start + 4] + qz = joint_q[q_start + 5] + qw = joint_q[q_start + 6] + + X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw)) + return X_jc + + if type == JointType.D6: + pos = wp.vec3(0.0) + rot = wp.quat_identity() + + # unroll for loop to ensure joint actions remain differentiable + # (since differentiating through a for loop that updates a local variable is not supported) + + if lin_axis_count > 0: + axis = joint_axis[axis_start + 0] + pos += axis * joint_q[q_start + 0] + if lin_axis_count > 1: + axis = joint_axis[axis_start + 1] + pos += axis * joint_q[q_start + 1] + if lin_axis_count > 2: + axis = joint_axis[axis_start + 2] + pos += axis * joint_q[q_start + 2] + + ia = axis_start + lin_axis_count + iq = q_start + lin_axis_count + if ang_axis_count == 1: + axis = joint_axis[ia] + rot = wp.quat_from_axis_angle(axis, joint_q[iq]) + if ang_axis_count == 2: + rot, _ = compute_2d_rotational_dofs( + joint_axis[ia + 0], + joint_axis[ia + 1], + joint_q[iq + 0], + joint_q[iq + 1], + 0.0, + 0.0, + ) + if ang_axis_count == 3: + rot, _ = compute_3d_rotational_dofs( + joint_axis[ia + 0], + joint_axis[ia + 1], + joint_axis[ia + 2], + joint_q[iq + 0], + joint_q[iq + 1], + joint_q[iq + 2], + 0.0, + 0.0, + 0.0, + ) + + X_jc = 
wp.transform(pos, rot) + return X_jc + + # default case + return wp.transform_identity() + + +# compute motion subspace and velocity for a joint +@wp.func +def jcalc_motion( + type: int, + joint_axis: wp.array(dtype=wp.vec3), + lin_axis_count: int, + ang_axis_count: int, + X_sc: wp.transform, + joint_qd: wp.array(dtype=float), + qd_start: int, + # outputs + joint_S_s: wp.array(dtype=wp.spatial_vector), +): + if type == JointType.PRISMATIC: + axis = joint_axis[qd_start] + S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3())) + v_j_s = S_s * joint_qd[qd_start] + joint_S_s[qd_start] = S_s + return v_j_s + + if type == JointType.REVOLUTE: + axis = joint_axis[qd_start] + S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis)) + v_j_s = S_s * joint_qd[qd_start] + joint_S_s[qd_start] = S_s + return v_j_s + + if type == JointType.D6: + v_j_s = wp.spatial_vector() + if lin_axis_count > 0: + axis = joint_axis[qd_start + 0] + S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3())) + v_j_s += S_s * joint_qd[qd_start + 0] + joint_S_s[qd_start + 0] = S_s + if lin_axis_count > 1: + axis = joint_axis[qd_start + 1] + S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3())) + v_j_s += S_s * joint_qd[qd_start + 1] + joint_S_s[qd_start + 1] = S_s + if lin_axis_count > 2: + axis = joint_axis[qd_start + 2] + S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3())) + v_j_s += S_s * joint_qd[qd_start + 2] + joint_S_s[qd_start + 2] = S_s + if ang_axis_count > 0: + axis = joint_axis[qd_start + lin_axis_count + 0] + S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis)) + v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0] + joint_S_s[qd_start + lin_axis_count + 0] = S_s + if ang_axis_count > 1: + axis = joint_axis[qd_start + lin_axis_count + 1] + S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis)) + v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1] + joint_S_s[qd_start + lin_axis_count + 1] = S_s + if ang_axis_count > 2: + 
axis = joint_axis[qd_start + lin_axis_count + 2] + S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis)) + v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2] + joint_S_s[qd_start + lin_axis_count + 2] = S_s + + return v_j_s + + if type == JointType.BALL: + S_0 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0)) + S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)) + S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0)) + + joint_S_s[qd_start + 0] = S_0 + joint_S_s[qd_start + 1] = S_1 + joint_S_s[qd_start + 2] = S_2 + + return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2] + + if type == JointType.FIXED: + return wp.spatial_vector() + + if type == JointType.FREE or type == JointType.DISTANCE: + # For FREE/DISTANCE joints we treat linear/angular velocity components as + # referenced at the root COM world point to avoid world-origin conditioning. + q_sc = wp.transform_get_rotation(X_sc) + + v_local = wp.vec3(joint_qd[qd_start + 0], joint_qd[qd_start + 1], joint_qd[qd_start + 2]) + w_local = wp.vec3(joint_qd[qd_start + 3], joint_qd[qd_start + 4], joint_qd[qd_start + 5]) + v_j_s = wp.spatial_vector(wp.quat_rotate(q_sc, v_local), wp.quat_rotate(q_sc, w_local)) + + ex = wp.quat_rotate(q_sc, wp.vec3(1.0, 0.0, 0.0)) + ey = wp.quat_rotate(q_sc, wp.vec3(0.0, 1.0, 0.0)) + ez = wp.quat_rotate(q_sc, wp.vec3(0.0, 0.0, 1.0)) + + joint_S_s[qd_start + 0] = wp.spatial_vector(ex, wp.vec3()) + joint_S_s[qd_start + 1] = wp.spatial_vector(ey, wp.vec3()) + joint_S_s[qd_start + 2] = wp.spatial_vector(ez, wp.vec3()) + joint_S_s[qd_start + 3] = wp.spatial_vector(wp.vec3(), ex) + joint_S_s[qd_start + 4] = wp.spatial_vector(wp.vec3(), ey) + joint_S_s[qd_start + 5] = wp.spatial_vector(wp.vec3(), ez) + + return v_j_s + + wp.printf("jcalc_motion not implemented for joint type %d\n", type) + + # default case + return wp.spatial_vector() + + +# computes joint space 
forces/torques in tau +@wp.func +def jcalc_tau( + type: int, + joint_S_s: wp.array(dtype=wp.spatial_vector), + joint_f: wp.array(dtype=float), + dof_start: int, + lin_axis_count: int, + ang_axis_count: int, + body_f_s: wp.spatial_vector, + # outputs + tau: wp.array(dtype=float), +): + if type == JointType.BALL: + # target_ke = joint_target_ke[dof_start] + # target_kd = joint_target_kd[dof_start] + + for i in range(3): + S_s = joint_S_s[dof_start + i] + + # w = joint_qd[dof_start + i] + # r = joint_q[coord_start + i] + + tau[dof_start + i] = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] + # tau -= w * target_kd - r * target_ke + + return + + if type == JointType.FREE or type == JointType.DISTANCE: + for i in range(6): + S_s = joint_S_s[dof_start + i] + tau[dof_start + i] = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] + + return + + if type == JointType.PRISMATIC or type == JointType.REVOLUTE or type == JointType.D6: + axis_count = lin_axis_count + ang_axis_count + + for i in range(axis_count): + j = dof_start + i + S_s = joint_S_s[j] + # total torque / force on the joint (drive forces handled via augmented mass) + tau[j] = -wp.dot(S_s, body_f_s) + joint_f[j] + + return + + +@wp.func +def jcalc_integrate( + type: int, + child: int, + body_com: wp.array(dtype=wp.vec3), + joint_q: wp.array(dtype=float), + joint_qd: wp.array(dtype=float), + joint_qdd: wp.array(dtype=float), + coord_start: int, + dof_start: int, + lin_axis_count: int, + ang_axis_count: int, + dt: float, + # outputs + joint_q_new: wp.array(dtype=float), + joint_qd_new: wp.array(dtype=float), +): + if type == JointType.FIXED: + return + + # prismatic / revolute + if type == JointType.PRISMATIC or type == JointType.REVOLUTE: + qdd = joint_qdd[dof_start] + qd = joint_qd[dof_start] + q = joint_q[coord_start] + + qd_new = qd + qdd * dt + q_new = q + qd_new * dt + + joint_qd_new[dof_start] = qd_new + joint_q_new[coord_start] = q_new + + return + + # ball + if type == JointType.BALL: + m_j = 
wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2]) + w_j = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2]) + + r_j = wp.quat( + joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3] + ) + + # symplectic Euler + w_j_new = w_j + m_j * dt + + drdt_j = wp.quat(w_j_new, 0.0) * r_j * 0.5 + + # new orientation (normalized) + r_j_new = wp.normalize(r_j + drdt_j * dt) + + # update joint coords + joint_q_new[coord_start + 0] = r_j_new[0] + joint_q_new[coord_start + 1] = r_j_new[1] + joint_q_new[coord_start + 2] = r_j_new[2] + joint_q_new[coord_start + 3] = r_j_new[3] + + # update joint vel + joint_qd_new[dof_start + 0] = w_j_new[0] + joint_qd_new[dof_start + 1] = w_j_new[1] + joint_qd_new[dof_start + 2] = w_j_new[2] + + return + + if type == JointType.FREE or type == JointType.DISTANCE: + a_s = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2]) + m_s = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5]) + + v_com = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2]) + w_s = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5]) + + # symplectic Euler + w_s = w_s + m_s * dt + v_com = v_com + a_s * dt + + p_s = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2]) + + r_s = wp.quat( + joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6] + ) + com_offset_world = wp.quat_rotate(r_s, body_com[child]) + dpdt_s = v_com - wp.cross(w_s, com_offset_world) + + drdt_s = wp.quat(w_s, 0.0) * r_s * 0.5 + + # new orientation (normalized) + p_s_new = p_s + dpdt_s * dt + r_s_new = wp.normalize(r_s + drdt_s * dt) + + # update transform + joint_q_new[coord_start + 0] = p_s_new[0] + joint_q_new[coord_start + 1] = p_s_new[1] + joint_q_new[coord_start + 2] = p_s_new[2] + + 
joint_q_new[coord_start + 3] = r_s_new[0] + joint_q_new[coord_start + 4] = r_s_new[1] + joint_q_new[coord_start + 5] = r_s_new[2] + joint_q_new[coord_start + 6] = r_s_new[3] + + joint_qd_new[dof_start + 0] = v_com[0] + joint_qd_new[dof_start + 1] = v_com[1] + joint_qd_new[dof_start + 2] = v_com[2] + joint_qd_new[dof_start + 3] = w_s[0] + joint_qd_new[dof_start + 4] = w_s[1] + joint_qd_new[dof_start + 5] = w_s[2] + + return + + # other joint types (compound, universal, D6) + if type == JointType.D6: + axis_count = lin_axis_count + ang_axis_count + + for i in range(axis_count): + qdd = joint_qdd[dof_start + i] + qd = joint_qd[dof_start + i] + q = joint_q[coord_start + i] + + qd_new = qd + qdd * dt + q_new = q + qd_new * dt + + joint_qd_new[dof_start + i] = qd_new + joint_q_new[coord_start + i] = q_new + + return + + +@wp.func +def compute_link_transform( + i: int, + joint_type: wp.array(dtype=int), + joint_parent: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + joint_q_start: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_q: wp.array(dtype=float), + joint_X_p: wp.array(dtype=wp.transform), + joint_X_c: wp.array(dtype=wp.transform), + body_X_com: wp.array(dtype=wp.transform), + joint_axis: wp.array(dtype=wp.vec3), + joint_dof_dim: wp.array(dtype=int, ndim=2), + # outputs + body_q: wp.array(dtype=wp.transform), + body_q_com: wp.array(dtype=wp.transform), +): + # parent transform + parent = joint_parent[i] + child = joint_child[i] + + # parent transform in spatial coordinates + X_pj = joint_X_p[i] + X_cj = joint_X_c[i] + # parent anchor frame in world space + X_wpj = X_pj + if parent >= 0: + X_wp = body_q[parent] + X_wpj = X_wp * X_wpj + + type = joint_type[i] + qd_start = joint_qd_start[i] + lin_axis_count = joint_dof_dim[i, 0] + ang_axis_count = joint_dof_dim[i, 1] + coord_start = joint_q_start[i] + + # compute transform across joint + X_j = jcalc_transform(type, joint_axis, qd_start, lin_axis_count, ang_axis_count, joint_q, 
coord_start) + + # transform from world to joint anchor frame at child body + X_wcj = X_wpj * X_j + # transform from world to child body frame + X_wc = X_wcj * wp.transform_inverse(X_cj) + + # compute transform of center of mass + X_cm = body_X_com[child] + X_sm = X_wc * X_cm + + # store geometry transforms + body_q[child] = X_wc + body_q_com[child] = X_sm + + +@wp.kernel +def eval_rigid_fk( + articulation_start: wp.array(dtype=int), + joint_type: wp.array(dtype=int), + joint_parent: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + joint_q_start: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_q: wp.array(dtype=float), + joint_X_p: wp.array(dtype=wp.transform), + joint_X_c: wp.array(dtype=wp.transform), + body_X_com: wp.array(dtype=wp.transform), + joint_axis: wp.array(dtype=wp.vec3), + joint_dof_dim: wp.array(dtype=int, ndim=2), + # outputs + body_q: wp.array(dtype=wp.transform), + body_q_com: wp.array(dtype=wp.transform), +): + # one thread per-articulation + index = wp.tid() + + start = articulation_start[index] + end = articulation_start[index + 1] + + for i in range(start, end): + compute_link_transform( + i, + joint_type, + joint_parent, + joint_child, + joint_q_start, + joint_qd_start, + joint_q, + joint_X_p, + joint_X_c, + body_X_com, + joint_axis, + joint_dof_dim, + body_q, + body_q_com, + ) + + +@wp.func +def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector): + w_a = wp.spatial_bottom(a) + v_a = wp.spatial_top(a) + + w_b = wp.spatial_bottom(b) + v_b = wp.spatial_top(b) + + w = wp.cross(w_a, w_b) + v = wp.cross(w_a, v_b) + wp.cross(v_a, w_b) + + return wp.spatial_vector(v, w) + + +@wp.func +def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector): + w_a = wp.spatial_bottom(a) + v_a = wp.spatial_top(a) + + w_b = wp.spatial_bottom(b) + v_b = wp.spatial_top(b) + + w = wp.cross(w_a, w_b) + wp.cross(v_a, v_b) + v = wp.cross(w_a, v_b) + + return wp.spatial_vector(v, w) + + +@wp.func +def dense_index(stride: int, i: 
int, j: int): + return i * stride + j + + +@wp.func +def compute_link_velocity( + i: int, + joint_type: wp.array(dtype=int), + joint_parent: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + joint_articulation: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_qd: wp.array(dtype=float), + joint_axis: wp.array(dtype=wp.vec3), + joint_dof_dim: wp.array(dtype=int, ndim=2), + body_I_m: wp.array(dtype=wp.spatial_matrix), + body_q: wp.array(dtype=wp.transform), + body_q_com: wp.array(dtype=wp.transform), + joint_X_p: wp.array(dtype=wp.transform), + articulation_origin: wp.array(dtype=wp.vec3), + gravity: wp.array(dtype=wp.vec3), + # outputs + joint_S_s: wp.array(dtype=wp.spatial_vector), + body_I_s: wp.array(dtype=wp.spatial_matrix), + body_v_s: wp.array(dtype=wp.spatial_vector), + body_f_s: wp.array(dtype=wp.spatial_vector), + body_a_s: wp.array(dtype=wp.spatial_vector), +): + type = joint_type[i] + child = joint_child[i] + parent = joint_parent[i] + articulation = joint_articulation[i] + qd_start = joint_qd_start[i] + origin = wp.vec3() + if articulation >= 0: + origin = articulation_origin[articulation] + + X_pj = joint_X_p[i] + # X_cj = joint_X_c[i] + + # parent anchor frame in world space + X_wpj = X_pj + if parent >= 0: + X_wp = body_q[parent] + X_wpj = X_wp * X_wpj + X_wpj_local = wp.transform( + wp.transform_get_translation(X_wpj) - origin, + wp.transform_get_rotation(X_wpj), + ) + + # compute motion subspace and velocity across the joint (also stores S_s to global memory) + lin_axis_count = joint_dof_dim[i, 0] + ang_axis_count = joint_dof_dim[i, 1] + v_j_s = jcalc_motion( + type, + joint_axis, + lin_axis_count, + ang_axis_count, + X_wpj_local, + joint_qd, + qd_start, + joint_S_s, + ) + + # parent velocity + v_parent_s = wp.spatial_vector() + a_parent_s = wp.spatial_vector() + + if parent >= 0: + v_parent_s = body_v_s[parent] + a_parent_s = body_a_s[parent] + + # body velocity, acceleration + v_s = v_parent_s + v_j_s + a_s = a_parent_s 
+ spatial_cross(v_s, v_j_s) + + # compute body forces + X_sm = body_q_com[child] + X_sm_local = wp.transform( + wp.transform_get_translation(X_sm) - origin, + wp.transform_get_rotation(X_sm), + ) + I_m = body_I_m[child] + + # gravity and external forces (expressed in frame aligned with s but centered at body mass) + m = I_m[0, 0] + + f_g = m * gravity[0] + r_com = wp.transform_get_translation(X_sm_local) + f_g_s = wp.spatial_vector(f_g, wp.cross(r_com, f_g)) + + # body forces + I_s = transform_spatial_inertia(X_sm_local, I_m) + + f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s) + + body_v_s[child] = v_s + body_a_s[child] = a_s + body_f_s[child] = f_b_s - f_g_s + body_I_s[child] = I_s + + +# Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1) +@wp.kernel +def eval_rigid_id( + articulation_start: wp.array(dtype=int), + joint_type: wp.array(dtype=int), + joint_parent: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + joint_articulation: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_qd: wp.array(dtype=float), + joint_axis: wp.array(dtype=wp.vec3), + joint_dof_dim: wp.array(dtype=int, ndim=2), + body_I_m: wp.array(dtype=wp.spatial_matrix), + body_q: wp.array(dtype=wp.transform), + body_q_com: wp.array(dtype=wp.transform), + joint_X_p: wp.array(dtype=wp.transform), + articulation_origin: wp.array(dtype=wp.vec3), + gravity: wp.array(dtype=wp.vec3), + # outputs + joint_S_s: wp.array(dtype=wp.spatial_vector), + body_I_s: wp.array(dtype=wp.spatial_matrix), + body_v_s: wp.array(dtype=wp.spatial_vector), + body_f_s: wp.array(dtype=wp.spatial_vector), + body_a_s: wp.array(dtype=wp.spatial_vector), +): + # one thread per-articulation + index = wp.tid() + + start = articulation_start[index] + end = articulation_start[index + 1] + + # compute link velocities and coriolis forces + for i in range(start, end): + compute_link_velocity( + i, + joint_type, + joint_parent, + joint_child, + joint_articulation, + 
joint_qd_start, + joint_qd, + joint_axis, + joint_dof_dim, + body_I_m, + body_q, + body_q_com, + joint_X_p, + articulation_origin, + gravity, + joint_S_s, + body_I_s, + body_v_s, + body_f_s, + body_a_s, + ) + + +@wp.kernel +def eval_rigid_tau( + articulation_start: wp.array(dtype=int), + joint_type: wp.array(dtype=int), + joint_parent: wp.array(dtype=int), + joint_child: wp.array(dtype=int), + joint_articulation: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_dof_dim: wp.array(dtype=int, ndim=2), + joint_f: wp.array(dtype=float), + joint_S_s: wp.array(dtype=wp.spatial_vector), + body_fb_s: wp.array(dtype=wp.spatial_vector), + body_f_ext: wp.array(dtype=wp.spatial_vector), + body_q: wp.array(dtype=wp.transform), + body_com: wp.array(dtype=wp.vec3), + articulation_origin: wp.array(dtype=wp.vec3), + # outputs + body_ft_s: wp.array(dtype=wp.spatial_vector), + tau: wp.array(dtype=float), +): + # one thread per-articulation + index = wp.tid() + + start = articulation_start[index] + end = articulation_start[index + 1] + count = end - start + + # compute joint forces + for offset in range(count): + # for backwards traversal + i = end - offset - 1 + + type = joint_type[i] + parent = joint_parent[i] + child = joint_child[i] + articulation = joint_articulation[i] + dof_start = joint_qd_start[i] + lin_axis_count = joint_dof_dim[i, 0] + ang_axis_count = joint_dof_dim[i, 1] + origin = wp.vec3() + if articulation >= 0: + origin = articulation_origin[articulation] + + # body forces in Featherstone frame (origin) + f_b_s = body_fb_s[child] + f_t_s = body_ft_s[child] + + # external wrench is provided at COM in world frame; shift torque to origin + f_ext_com = body_f_ext[child] + f_ext_f = wp.spatial_bottom(f_ext_com) + f_ext_t = wp.spatial_top(f_ext_com) + + X_wb = body_q[child] + com_local = body_com[child] + com_world = wp.transform_point(X_wb, com_local) + com_rel = com_world - origin + tau_origin = f_ext_f + wp.cross(com_rel, f_ext_t) + f_ext_origin = 
wp.spatial_vector(f_ext_t, tau_origin) + + # subtract external wrench to get net wrench on body + f_s = f_b_s + f_t_s - f_ext_origin + + # compute joint-space forces, writes out tau + jcalc_tau( + type, + joint_S_s, + joint_f, + dof_start, + lin_axis_count, + ang_axis_count, + f_s, + tau, + ) + + if parent >= 0: + # update parent forces, todo: check that this is valid for the backwards pass + wp.atomic_add(body_ft_s, parent, f_s) + + +@wp.kernel +def eval_rigid_mass( + articulation_start: wp.array(dtype=int), + articulation_M_start: wp.array(dtype=int), + mass_update_mask: wp.array(dtype=int), + body_I_s: wp.array(dtype=wp.spatial_matrix), + # outputs + M_blocks: wp.array(dtype=float), +): + # one thread per-articulation + index = wp.tid() + + if mass_update_mask[index] == 0: + return + + joint_start = articulation_start[index] + joint_end = articulation_start[index + 1] + joint_count = joint_end - joint_start + + M_offset = articulation_M_start[index] + + for l in range(joint_count): + I = body_I_s[joint_start + l] + block = M_offset + l * 36 + for row in range(6): + for col in range(6): + M_blocks[block + row * 6 + col] = I[row, col] + + +@wp.kernel +def compute_composite_inertia( + articulation_start: wp.array(dtype=int), + mass_update_mask: wp.array(dtype=int), + joint_ancestor: wp.array(dtype=int), + body_I_s: wp.array(dtype=wp.spatial_matrix), + # outputs + body_I_c: wp.array(dtype=wp.spatial_matrix), +): + art_idx = wp.tid() + + if mass_update_mask[art_idx] == 0: + return + + start = articulation_start[art_idx] + end = articulation_start[art_idx + 1] + count = end - start + + for i in range(count): + idx = start + i + body_I_c[idx] = body_I_s[idx] + + for i in range(count - 1, -1, -1): + child_idx = start + i + parent_idx = joint_ancestor[child_idx] + + if parent_idx >= start: + body_I_c[parent_idx] += body_I_c[child_idx] + + +@wp.func +def dense_cholesky( + n: int, + A: wp.array(dtype=float), + R: wp.array(dtype=float), + A_start: int, + R_start: int, + # 
@wp.kernel
def cholesky_loop(
    H_group: wp.array3d(dtype=float),  # [n_arts, n_dofs, n_dofs]
    R_group: wp.array2d(dtype=float),  # [n_arts, n_dofs]
    group_to_art: wp.array(dtype=int),
    mass_update_mask: wp.array(dtype=int),
    n_dofs: int,
    # output
    L_group: wp.array3d(dtype=float),  # [n_arts, n_dofs, n_dofs]
):
    """Scalar (non-tiled) Cholesky factorization for grouped storage.

    Each thread factors one group entry so that ``L L^T = H + diag(R)``.
    Entries whose articulation has ``mass_update_mask == 0`` are left
    untouched. Only the diagonal and the strictly lower triangle of
    ``L_group`` are written.
    """
    entry = wp.tid()

    # Nothing to do when this articulation's factor was not invalidated.
    if mass_update_mask[group_to_art[entry]] == 0:
        return

    for col in range(n_dofs):
        # Diagonal term: regularized H[col, col] minus the squares of the
        # already-computed entries of row `col`.
        diag = H_group[entry, col, col] + R_group[entry, col]
        for k in range(col):
            l_ck = L_group[entry, col, k]
            diag -= l_ck * l_ck

        pivot = wp.sqrt(diag)
        L_group[entry, col, col] = pivot
        inv_pivot = 1.0 / pivot

        # Column `col` below the diagonal.
        for row in range(col + 1, n_dofs):
            acc = H_group[entry, row, col]
            for k in range(col):
                acc -= L_group[entry, row, k] * L_group[entry, col, k]
            L_group[entry, row, col] = acc * inv_pivot


@wp.func
def dense_subs(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
):
    """Solve ``(L L^T) x = b`` given the Cholesky factor ``L``.

    ``L`` is addressed through ``dense_index(n, row, col)`` relative to
    ``L_start``; ``b`` and ``x`` share the offset ``b_start``. The
    intermediate forward-substitution result is stored in ``x`` and then
    overwritten in place by the backward pass.
    """
    # Forward substitution: L y = b (y lives in x).
    for row in range(n):
        acc = b[b_start + row]
        for col in range(row):
            acc -= L[L_start + dense_index(n, row, col)] * x[b_start + col]
        x[b_start + row] = acc / L[L_start + dense_index(n, row, row)]

    # Backward substitution: L^T x = y, visiting rows n-1 .. 0.
    for step in range(n):
        row = n - 1 - step
        acc = x[b_start + row]
        for col in range(row + 1, n):
            acc -= L[L_start + dense_index(n, col, row)] * x[b_start + col]
        x[b_start + row] = acc / L[L_start + dense_index(n, row, row)]


@wp.func
def dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    A: wp.array(dtype=float),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    """Forward to :func:`dense_subs`.

    ``A`` and ``tmp`` are not read here; the original comment states the
    wrapper exists so that ``tmp`` is available to the backward pass.
    """
    dense_subs(n, L_start, b_start, L, b, x)


@wp.kernel
def integrate_generalized_joints(
    joint_type: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    """Integrate one joint's generalized state via ``jcalc_integrate``.

    The thread index selects an entry of the per-joint arrays, i.e. one
    thread per joint.
    """
    j = wp.tid()

    jcalc_integrate(
        joint_type[j],
        joint_child[j],
        body_com,
        joint_q,
        joint_qd,
        joint_qdd,
        joint_q_start[j],
        joint_qd_start[j],
        joint_dof_dim[j, 0],  # linear axis count
        joint_dof_dim[j, 1],  # angular axis count
        dt,
        joint_q_new,
        joint_qd_new,
    )


@wp.kernel
def compute_velocity_predictor(
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    v_hat: wp.array(dtype=float),
):
    """Write the explicit predictor ``v_hat = qd + qdd * dt`` per entry."""
    i = wp.tid()
    v_hat[i] = joint_qd[i] + joint_qdd[i] * dt


@wp.kernel
def update_qdd_from_velocity(
    joint_qd: wp.array(dtype=float),
    v_new: wp.array(dtype=float),
    inv_dt: float,
    # outputs
    joint_qdd: wp.array(dtype=float),
):
    """Recover accelerations as ``(v_new - qd) * inv_dt`` per entry."""
    i = wp.tid()
    joint_qdd[i] = (v_new[i] - joint_qd[i]) * inv_dt


@wp.func
def contact_tangent_basis(n: wp.vec3):
    """Return two unit tangents ``(t0, t1)`` perpendicular to ``n``.

    ``t0`` is ``n x X`` normalized, falling back to ``n x Y`` when the
    first cross product is (near) zero; ``t1`` is ``n x t0`` normalized.
    """
    t0 = wp.cross(n, wp.vec3(1.0, 0.0, 0.0))
    # n (anti)parallel to X: the cross product degenerates, use Y instead.
    if wp.length_sq(t0) < 1.0e-12:
        t0 = wp.cross(n, wp.vec3(0.0, 1.0, 0.0))
    t0 = wp.normalize(t0)
    t1 = wp.normalize(wp.cross(n, t0))
    return t0, t1


@wp.kernel
def compute_contact_linear_force_from_impulses(
    contact_count: wp.array(dtype=wp.int32),
    contact_normal: wp.array(dtype=wp.vec3),
    contact_world: wp.array(dtype=wp.int32),
    contact_slot: wp.array(dtype=wp.int32),
    contact_path: wp.array(dtype=wp.int32),
    world_impulses: wp.array2d(dtype=wp.float32),
    mf_impulses: wp.array2d(dtype=wp.float32),
    enable_friction: int,
    inv_dt: float,
    # outputs
    rigid_contact_force: wp.array(dtype=wp.vec3),
):
    """Turn solved contact impulses into world-frame linear forces.

    Contacts without a slot/path, or a non-positive ``inv_dt``, produce a
    zero force. ``contact_path`` 0 reads ``world_impulses``, 1 reads
    ``mf_impulses``; any other non-negative value leaves the impulses at
    zero.
    """
    contact_id = wp.tid()
    if contact_id >= contact_count[0]:
        return

    row = contact_slot[contact_id]
    route = contact_path[contact_id]
    result = wp.vec3(0.0)

    if row >= 0 and route >= 0 and inv_dt > 0.0:
        w = contact_world[contact_id]
        # Stored normals point from shape 0 to shape 1; the solver works
        # along the opposite direction (force on shape/body 0 from 1).
        direction = -contact_normal[contact_id]

        impulse_n = 0.0
        impulse_t0 = 0.0
        impulse_t1 = 0.0
        if route == 0:
            impulse_n = world_impulses[w, row]
            if enable_friction != 0:
                impulse_t0 = world_impulses[w, row + 1]
                impulse_t1 = world_impulses[w, row + 2]
        elif route == 1:
            impulse_n = mf_impulses[w, row]
            if enable_friction != 0:
                impulse_t0 = mf_impulses[w, row + 1]
                impulse_t1 = mf_impulses[w, row + 2]

        result = impulse_n * direction
        if enable_friction != 0:
            tan0, tan1 = contact_tangent_basis(direction)
            result += impulse_t0 * tan0 + impulse_t1 * tan1
        # impulse -> force
        result *= inv_dt

    rigid_contact_force[contact_id] = result


@wp.kernel
def pack_contact_linear_force_as_spatial(
    contact_count: wp.array(dtype=wp.int32),
    rigid_contact_force: wp.array(dtype=wp.vec3),
    # outputs
    contact_force: wp.array(dtype=wp.spatial_vector),
):
    """Store each active contact's linear force as a spatial vector.

    The second (vec3) half of the spatial vector is set to zero.
    """
    contact_id = wp.tid()
    if contact_id >= contact_count[0]:
        return

    linear = rigid_contact_force[contact_id]
    contact_force[contact_id] = wp.spatial_vector(linear, wp.vec3(0.0))


# Computes J*v contribution on the fly by walking the tree
# This keeps the S vectors in L2 cache and avoids reading the large J matrix.
@wp.func
def accumulate_contact_jacobian_matrix_free(
    articulation: int,
    body_index: int,
    weight: float,
    point_world: wp.vec3,
    n_vec: wp.vec3,
    body_to_joint: wp.array(dtype=int),
    body_to_articulation: wp.array(dtype=int),
    joint_ancestor: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    articulation_origin: wp.array(dtype=wp.vec3),
    articulation_dof_start: int,
    # Outputs
    row_base_index: int,
    Jc_out: wp.array(dtype=float),
):
    """Add one body's contribution to a flat Jacobian row.

    Starting at ``body_to_joint[body_index]`` and following
    ``joint_ancestor`` until ``-1``, every DOF of every visited joint adds
    ``weight * dot(n_vec, lin + ang x (point_world - origin))`` to
    ``Jc_out[row_base_index + local_dof]``, where ``lin``/``ang`` are
    components 0-2 / 3-5 of ``joint_S_s`` and ``origin`` is
    ``articulation_origin[articulation]``.

    Does nothing when ``body_index < 0``. ``body_to_articulation`` is not
    read here.
    """
    if body_index < 0:
        return

    origin = articulation_origin[articulation]
    point_rel = point_world - origin

    curr_joint = body_to_joint[body_index]

    # Walk from the body's joint up through its ancestors.
    while curr_joint != -1:
        dof_start = joint_qd_start[curr_joint]
        dof_end = joint_qd_start[curr_joint + 1]
        dof_count = dof_end - dof_start

        for k in range(dof_count):
            global_dof = dof_start + k

            S = joint_S_s[global_dof]

            linear = wp.vec3(S[0], S[1], S[2])
            angular = wp.vec3(S[3], S[4], S[5])

            lin_vel_at_point = linear + wp.cross(angular, point_rel)
            proj = wp.dot(n_vec, lin_vel_at_point)

            local_dof = global_dof - articulation_dof_start

            # NOTE(review): local_dof is not range-checked here (the 3-D
            # variant accumulate_jacobian_row_world does check) -- confirm
            # callers guarantee the joint belongs to this articulation.
            Jc_out[row_base_index + local_dof] += weight * proj

        curr_joint = joint_ancestor[curr_joint]


@wp.kernel
def build_contact_rows_normal(
    contact_count: wp.array(dtype=int),
    contact_point0: wp.array(dtype=wp.vec3),
    contact_point1: wp.array(dtype=wp.vec3),
    contact_normal: wp.array(dtype=wp.vec3),
    contact_shape0: wp.array(dtype=int),
    contact_shape1: wp.array(dtype=int),
    contact_thickness0: wp.array(dtype=float),
    contact_thickness1: wp.array(dtype=float),
    shape_body: wp.array(dtype=int),
    body_q: wp.array(dtype=wp.transform),
    shape_transform: wp.array(dtype=wp.transform),
    shape_material_mu: wp.array(dtype=float),
    articulation_start: wp.array(dtype=int),
    articulation_H_rows: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    body_to_joint: wp.array(dtype=int),
    body_to_articulation: wp.array(dtype=int),
    joint_ancestor: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    articulation_origin: wp.array(dtype=wp.vec3),
    max_constraints: int,
    max_dofs: int,
    contact_beta: float,
    contact_cfm: float,
    enable_friction: int,
    # Outputs
    constraint_counts: wp.array(dtype=int),
    Jc_out: wp.array(dtype=float),
    phi_out: wp.array(dtype=float),
    row_beta: wp.array(dtype=float),
    row_cfm: wp.array(dtype=float),
    row_types: wp.array(dtype=int),
    target_velocity: wp.array(dtype=float),
    row_parent: wp.array(dtype=int),
    row_mu: wp.array(dtype=float),
):
    """Build per-articulation contact rows (normal plus optional friction).

    One thread per contact. A contact is dropped when it involves no
    articulation, when its two sides belong to two different
    articulations, when ``phi > 0.001``, or when its rows do not fit in
    ``max_constraints``. Otherwise the thread reserves 1 or 3 consecutive
    rows in ``constraint_counts[articulation]`` and fills the flat row
    arrays at ``articulation * max_constraints + slot`` and the Jacobian at
    ``row * max_dofs``.

    ``shape_transform`` and ``articulation_start`` are not read by this
    kernel.
    """
    tid = wp.tid()
    total_contacts = contact_count[0]
    if tid >= total_contacts:
        return

    # contact normal stored as A-to-B; negate to get B-to-A used internally
    n = -contact_normal[tid]
    shape_a = contact_shape0[tid]
    shape_b = contact_shape1[tid]

    body_a = -1
    body_b = -1
    if shape_a >= 0:
        body_a = shape_body[shape_a]
    if shape_b >= 0:
        body_b = shape_body[shape_b]

    articulation_a = -1
    articulation_b = -1
    if body_a >= 0:
        articulation_a = body_to_articulation[body_a]
    if body_b >= 0:
        articulation_b = body_to_articulation[body_b]

    # Pick the single articulation this row belongs to; contacts between
    # two distinct articulations are skipped by this kernel.
    articulation = articulation_a
    if articulation < 0:
        articulation = articulation_b
    elif articulation_b >= 0 and articulation_b != articulation:
        return
    if articulation < 0:
        return

    thickness_a = contact_thickness0[tid]
    thickness_b = contact_thickness1[tid]
    # Friction coefficient: mean of the materials of the shapes present.
    mu = 0.0
    mat_count = 0
    if shape_a >= 0:
        mu += shape_material_mu[shape_a]
        mat_count += 1
    if shape_b >= 0:
        mu += shape_material_mu[shape_b]
        mat_count += 1
    if mat_count > 0:
        mu /= float(mat_count)

    point_a_local = contact_point0[tid]
    point_b_local = contact_point1[tid]
    point_a_world = wp.vec3(0.0)
    point_b_world = wp.vec3(0.0)

    if body_a >= 0:
        X_wb_a = body_q[body_a]  # World-from-Body transform
        # Contact points are stored in body frame by collision detection
        point_a_world = wp.transform_point(X_wb_a, point_a_local) - thickness_a * n
    else:
        point_a_world = point_a_local - thickness_a * n

    if body_b >= 0:
        X_wb_b = body_q[body_b]  # World-from-Body transform
        # Contact points are stored in body frame by collision detection
        point_b_world = wp.transform_point(X_wb_b, point_b_local) + thickness_b * n
    else:
        point_b_world = point_b_local + thickness_b * n

    # Signed gap along n.
    phi = wp.dot(n, point_a_world - point_b_world)

    # NOTE(review): allocate_world_contact_slots gates with `>= 0.001`,
    # this kernel with `> 0.001` -- confirm the difference is intended.
    if phi > 0.001:
        return

    # Determine upfront if we'll add friction rows (needed for atomic slot allocation)
    dof_count = articulation_H_rows[articulation]
    will_add_friction = enable_friction != 0 and mu > 0.0 and dof_count > 0

    # Allocate all slots (normal + 2 friction) in a single atomic operation
    # This guarantees contiguous layout: [normal, friction1, friction2]
    slots_needed = 3 if will_add_friction else 1
    base_slot = wp.atomic_add(constraint_counts, articulation, slots_needed)

    # Check for overflow (all slots must fit)
    # NOTE(review): the counter is not rolled back on overflow, so
    # constraint_counts may exceed max_constraints -- confirm the consumer
    # clamps it.
    if base_slot + slots_needed > max_constraints:
        return

    art_dof_start = articulation_dof_start[articulation]

    # --- Normal contact row at base_slot ---
    phi_index = articulation * max_constraints + base_slot
    phi_out[phi_index] = phi
    row_beta[phi_index] = contact_beta
    row_cfm[phi_index] = contact_cfm
    row_types[phi_index] = PGS_CONSTRAINT_TYPE_CONTACT
    target_velocity[phi_index] = 0.0
    row_parent[phi_index] = -1
    row_mu[phi_index] = mu

    # Clear the row before accumulating: the helper uses `+=`.
    row_base = phi_index * max_dofs
    for col in range(max_dofs):
        Jc_out[row_base + col] = 0.0

    accumulate_contact_jacobian_matrix_free(
        articulation,
        body_a,
        1.0,
        point_a_world,
        n,
        body_to_joint,
        body_to_articulation,
        joint_ancestor,
        joint_qd_start,
        joint_S_s,
        articulation_origin,
        art_dof_start,
        row_base,
        Jc_out,
    )

    accumulate_contact_jacobian_matrix_free(
        articulation,
        body_b,
        -1.0,
        point_b_world,
        n,
        body_to_joint,
        body_to_articulation,
        joint_ancestor,
        joint_qd_start,
        joint_S_s,
        articulation_origin,
        art_dof_start,
        row_base,
        Jc_out,
    )

    # --- Friction rows at base_slot + 1 and base_slot + 2 ---
    if will_add_friction:
        t0, t1 = contact_tangent_basis(n)

        # Friction row 1 at base_slot + 1
        row_index_1 = articulation * max_constraints + base_slot + 1
        tangent_base_1 = row_index_1 * max_dofs

        for col in range(max_dofs):
            Jc_out[tangent_base_1 + col] = 0.0

        row_beta[row_index_1] = 0.0
        row_cfm[row_index_1] = contact_cfm
        row_types[row_index_1] = PGS_CONSTRAINT_TYPE_FRICTION
        target_velocity[row_index_1] = 0.0
        phi_out[row_index_1] = 0.0
        # Friction rows point back at their normal row (flat index).
        row_parent[row_index_1] = phi_index
        row_mu[row_index_1] = mu

        accumulate_contact_jacobian_matrix_free(
            articulation,
            body_a,
            1.0,
            point_a_world,
            t0,
            body_to_joint,
            body_to_articulation,
            joint_ancestor,
            joint_qd_start,
            joint_S_s,
            articulation_origin,
            art_dof_start,
            tangent_base_1,
            Jc_out,
        )

        accumulate_contact_jacobian_matrix_free(
            articulation,
            body_b,
            -1.0,
            point_b_world,
            t0,
            body_to_joint,
            body_to_articulation,
            joint_ancestor,
            joint_qd_start,
            joint_S_s,
            articulation_origin,
            art_dof_start,
            tangent_base_1,
            Jc_out,
        )

        # Friction row 2 at base_slot + 2
        row_index_2 = articulation * max_constraints + base_slot + 2
        tangent_base_2 = row_index_2 * max_dofs

        for col in range(max_dofs):
            Jc_out[tangent_base_2 + col] = 0.0

        row_beta[row_index_2] = 0.0
        row_cfm[row_index_2] = contact_cfm
        row_types[row_index_2] = PGS_CONSTRAINT_TYPE_FRICTION
        target_velocity[row_index_2] = 0.0
        phi_out[row_index_2] = 0.0
        row_parent[row_index_2] = phi_index
        row_mu[row_index_2] = mu

        accumulate_contact_jacobian_matrix_free(
            articulation,
            body_a,
            1.0,
            point_a_world,
            t1,
            body_to_joint,
            body_to_articulation,
            joint_ancestor,
            joint_qd_start,
            joint_S_s,
            articulation_origin,
            art_dof_start,
            tangent_base_2,
            Jc_out,
        )

        accumulate_contact_jacobian_matrix_free(
            articulation,
            body_b,
            -1.0,
            point_b_world,
            t1,
            body_to_joint,
            body_to_articulation,
            joint_ancestor,
            joint_qd_start,
            joint_S_s,
            articulation_origin,
            art_dof_start,
            tangent_base_2,
            Jc_out,
        )


@wp.kernel
def build_augmented_joint_rows(
    articulation_start: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    articulation_H_rows: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_target_pos: wp.array(dtype=float),
    joint_target_vel: wp.array(dtype=float),
    max_dofs: int,
    dt: float,
    # outputs
    row_counts: wp.array(dtype=int),
    row_dof_index: wp.array(dtype=int),
    row_K: wp.array(dtype=float),
    row_u0: wp.array(dtype=float),
    limit_counts: wp.array(dtype=int),
):
    """Collect one row per driven DOF of PRISMATIC/REVOLUTE/D6 joints.

    One thread per articulation. For every DOF with a positive
    ``K = ke * dt^2 + kd * dt`` a row is written at
    ``articulation * max_dofs + slot`` holding the DOF index, ``K`` and
    ``u0 = -(ke * (q - target_pos + dt * qd) + kd * (qd - target_vel))``.
    At most ``max_dofs`` rows are produced. ``limit_counts`` is always
    written as 0 by this kernel. ``articulation_dof_start`` is not read.
    """
    articulation = wp.tid()
    if max_dofs == 0:
        row_counts[articulation] = 0
        limit_counts[articulation] = 0
        return

    dof_count = articulation_H_rows[articulation]
    if dof_count == 0:
        row_counts[articulation] = 0
        limit_counts[articulation] = 0
        return

    joint_start = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    # int(0) makes `slot` a mutable variable for the loops below.
    slot = int(0)
    limit_counts[articulation] = 0

    for joint_index in range(joint_start, joint_end):
        type = joint_type[joint_index]
        if type != JointType.PRISMATIC and type != JointType.REVOLUTE and type != JointType.D6:
            continue

        lin_axis_count = joint_dof_dim[joint_index, 0]
        ang_axis_count = joint_dof_dim[joint_index, 1]
        axis_count = lin_axis_count + ang_axis_count

        qd_start = joint_qd_start[joint_index]
        coord_start = joint_q_start[joint_index]

        for axis in range(axis_count):
            if slot >= max_dofs:
                break
            dof_index = qd_start + axis
            # assumes coordinates and DOFs are laid out one-to-one for
            # these joint types -- TODO confirm
            coord_index = coord_start + axis

            ke = joint_target_ke[dof_index]
            kd = joint_target_kd[dof_index]
            if ke <= 0.0 and kd <= 0.0:
                continue

            K = ke * dt * dt + kd * dt
            if K <= 0.0:
                continue

            row_index = articulation * max_dofs + slot
            row_dof_index[row_index] = dof_index
            q = joint_q[coord_index]
            qd_val = joint_qd[dof_index]
            target_pos = joint_target_pos[dof_index]
            target_vel = joint_target_vel[dof_index]
            u0 = -(ke * (q - target_pos + dt * qd_val) + kd * (qd_val - target_vel))
            row_K[row_index] = K
            row_u0[row_index] = u0

            slot += 1
        if slot >= max_dofs:
            break

    row_counts[articulation] = slot
    limit_counts[articulation] = 0


@wp.kernel
def detect_limit_count_changes(
    limit_counts: wp.array(dtype=int),
    prev_limit_counts: wp.array(dtype=int),
    # outputs
    limit_change_mask: wp.array(dtype=int),
):
    """Set ``limit_change_mask[tid]`` to 1 where the two counts differ, else 0."""
    tid = wp.tid()
    change = 1 if limit_counts[tid] != prev_limit_counts[tid] else 0
    limit_change_mask[tid] = change


@wp.kernel
def build_mass_update_mask(
    global_flag: int,
    limit_change_mask: wp.array(dtype=int),
    # outputs
    mass_update_mask: wp.array(dtype=int),
):
    """Write 1 where ``global_flag`` or ``limit_change_mask[tid]`` is non-zero, else 0."""
    tid = wp.tid()
    flag = 1 if global_flag != 0 else 0
    if limit_change_mask[tid] != 0:
        flag = 1
    mass_update_mask[tid] = flag


# =============================================================================
# Joint Limit Constraint Kernels
# =============================================================================


@wp.kernel
def allocate_joint_limit_slots(
    articulation_start: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    articulation_H_rows: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_q: wp.array(dtype=float),
    art_to_world: wp.array(dtype=int),
    max_constraints: int,
    # outputs
    limit_slot: wp.array(dtype=int),
    limit_sign: wp.array(dtype=float),
    world_slot_counter: wp.array(dtype=int),
):
    """Allocate constraint slots for violated joint limits.

    For each articulation, checks all DOFs of PRISMATIC, REVOLUTE, and D6
    joints against their limits. When a DOF violates its lower or upper
    limit, a single constraint slot is atomically reserved in the world's
    slot counter (the same counter used by contacts).

    Outputs per-DOF arrays ``limit_slot`` (world-constraint row, or -1) and
    ``limit_sign`` (+1 for lower-limit violation, -1 for upper).

    When the reserved slot is ``>= max_constraints`` the DOF keeps
    ``limit_slot == -1`` but the counter stays incremented (no rollback).
    """
    art = wp.tid()
    world = art_to_world[art]

    # Initialize all DOFs of this articulation to "no limit active"
    dof_base = articulation_dof_start[art]
    dof_count = articulation_H_rows[art]
    for d in range(dof_count):
        limit_slot[dof_base + d] = -1
        limit_sign[dof_base + d] = 0.0

    joint_start = articulation_start[art]
    joint_end = articulation_start[art + 1]

    for j in range(joint_start, joint_end):
        jtype = joint_type[j]
        if jtype != JointType.PRISMATIC and jtype != JointType.REVOLUTE and jtype != JointType.D6:
            continue

        lin_count = joint_dof_dim[j, 0]
        ang_count = joint_dof_dim[j, 1]
        axis_count = lin_count + ang_count
        qd_start = joint_qd_start[j]
        q_start = joint_q_start[j]

        for axis in range(axis_count):
            dof = qd_start + axis
            q_val = joint_q[q_start + axis]
            lower = joint_limit_lower[dof]
            upper = joint_limit_upper[dof]

            # Lower limit violation (q < lower)
            if q_val < lower:
                slot = wp.atomic_add(world_slot_counter, world, 1)
                if slot < max_constraints:
                    limit_slot[dof] = slot
                    limit_sign[dof] = 1.0
            # Upper limit violation (q > upper)
            elif q_val > upper:
                slot = wp.atomic_add(world_slot_counter, world, 1)
                if slot < max_constraints:
                    limit_slot[dof] = slot
                    limit_sign[dof] = -1.0


@wp.kernel
def populate_joint_limit_J_for_size(
    articulation_start: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_q: wp.array(dtype=float),
    art_to_world: wp.array(dtype=int),
    limit_slot: wp.array(dtype=int),
    limit_sign: wp.array(dtype=float),
    group_to_art: wp.array(dtype=int),
    pgs_beta: float,
    pgs_cfm: float,
    # outputs
    J_group: wp.array3d(dtype=float),
    world_row_type: wp.array2d(dtype=int),
    world_row_parent: wp.array2d(dtype=int),
    world_row_mu: wp.array2d(dtype=float),
    world_row_beta: wp.array2d(dtype=float),
    world_row_cfm: wp.array2d(dtype=float),
    world_phi: wp.array2d(dtype=float),
    world_target_velocity: wp.array2d(dtype=float),
):
    """Populate Jacobian and metadata for joint limit constraints.

    Launched once per size group with ``dim = n_arts_of_size``. Each thread
    walks the joints of one articulation and, for every DOF whose
    ``limit_slot`` is non-negative (i.e. the limit was activated by
    :func:`allocate_joint_limit_slots`), writes:

    * A single ±1 entry in the Jacobian at the DOF's local column.
    * Constraint metadata (type, phi, beta, cfm, etc.).

    Only that one Jacobian entry is written; the rest of the row is
    presumably cleared by the caller beforehand -- verify.
    """
    group_idx = wp.tid()
    art = group_to_art[group_idx]
    world = art_to_world[art]
    dof_start = articulation_dof_start[art]

    joint_start = articulation_start[art]
    joint_end = articulation_start[art + 1]

    for j in range(joint_start, joint_end):
        jtype = joint_type[j]
        if jtype != JointType.PRISMATIC and jtype != JointType.REVOLUTE and jtype != JointType.D6:
            continue

        lin_count = joint_dof_dim[j, 0]
        ang_count = joint_dof_dim[j, 1]
        axis_count = lin_count + ang_count
        qd_start = joint_qd_start[j]
        q_start = joint_q_start[j]

        for axis in range(axis_count):
            dof = qd_start + axis
            slot = limit_slot[dof]
            if slot < 0:
                continue

            sign = limit_sign[dof]
            q_val = joint_q[q_start + axis]
            lower = joint_limit_lower[dof]
            upper = joint_limit_upper[dof]

            # phi is negative when violating
            phi = 0.0
            if sign > 0.0:
                phi = q_val - lower
            else:
                phi = upper - q_val

            # Jacobian: single ±1 entry at the local DOF column
            local_dof = dof - dof_start
            J_group[group_idx, slot, local_dof] = sign

            # Constraint metadata
            world_row_type[world, slot] = PGS_CONSTRAINT_TYPE_JOINT_LIMIT
            world_row_parent[world, slot] = -1
            world_row_mu[world, slot] = 0.0
            world_row_beta[world, slot] = pgs_beta
            world_row_cfm[world, slot] = pgs_cfm
            world_phi[world, slot] = phi
            world_target_velocity[world, slot] = 0.0


# =============================================================================
# Multi-Articulation Contact Building Kernels
# =============================================================================
# These kernels enable contacts between multiple articulations within the same
# world. The constraint system becomes world-level instead of per-articulation.
+ + +@wp.kernel +def allocate_world_contact_slots( + contact_count: wp.array(dtype=int), + contact_shape0: wp.array(dtype=int), + contact_shape1: wp.array(dtype=int), + contact_point0: wp.array(dtype=wp.vec3), + contact_point1: wp.array(dtype=wp.vec3), + contact_normal: wp.array(dtype=wp.vec3), + contact_thickness0: wp.array(dtype=float), + contact_thickness1: wp.array(dtype=float), + body_q: wp.array(dtype=wp.transform), + shape_transform: wp.array(dtype=wp.transform), + shape_body: wp.array(dtype=int), + body_to_articulation: wp.array(dtype=int), + art_to_world: wp.array(dtype=int), + is_free_rigid: wp.array(dtype=int), + has_free_rigid: int, + max_constraints: int, + mf_max_constraints: int, + enable_friction: int, + # outputs + contact_world: wp.array(dtype=int), + contact_slot: wp.array(dtype=int), + contact_art_a: wp.array(dtype=int), + contact_art_b: wp.array(dtype=int), + world_slot_counter: wp.array(dtype=int), + contact_path: wp.array(dtype=int), + mf_slot_counter: wp.array(dtype=int), +): + """ + Phase 1 of multi-articulation contact building. + + Allocates world-level constraint slots for each contact and records + which articulations are involved. Contacts where both sides are free + rigid bodies (or ground) are routed to the matrix-free path. + + Each contact reserves 3 slots (normal + 2 friction) in its world's constraint buffer. 
+ """ + c = wp.tid() + total_contacts = contact_count[0] + if c >= total_contacts: + contact_slot[c] = -1 + contact_path[c] = -1 + return + + shape_a = contact_shape0[c] + shape_b = contact_shape1[c] + + # Get bodies and articulations + body_a = -1 + body_b = -1 + if shape_a >= 0: + body_a = shape_body[shape_a] + if shape_b >= 0: + body_b = shape_body[shape_b] + + art_a = -1 + art_b = -1 + if body_a >= 0: + art_a = body_to_articulation[body_a] + if body_b >= 0: + art_b = body_to_articulation[body_b] + + # Determine world (both bodies must be in same world, or one is ground) + world = -1 + if art_a >= 0: + world = art_to_world[art_a] + if art_b >= 0: + world_b = art_to_world[art_b] + if world >= 0 and world_b != world: + # Cross-world contact - shouldn't happen, skip + contact_slot[c] = -1 + contact_path[c] = -1 + return + world = world_b + + if world < 0: + # No articulation involved (ground-ground?) + contact_slot[c] = -1 + contact_path[c] = -1 + return + + # Compute phi (same logic as populate_world_J_for_size) + # contact normal stored as A-to-B; negate to get B-to-A used internally + normal = -contact_normal[c] + point_a_local = contact_point0[c] + point_b_local = contact_point1[c] + thickness_a = contact_thickness0[c] + thickness_b = contact_thickness1[c] + + point_a_world = wp.vec3(0.0) + point_b_world = wp.vec3(0.0) + + if body_a >= 0: + X_wb_a = body_q[body_a] + # Contact points are stored in body frame by collision detection + point_a_world = wp.transform_point(X_wb_a, point_a_local) - thickness_a * normal + else: + point_a_world = point_a_local - thickness_a * normal + + if body_b >= 0: + X_wb_b = body_q[body_b] + # Contact points are stored in body frame by collision detection + point_b_world = wp.transform_point(X_wb_b, point_b_local) + thickness_b * normal + else: + point_b_world = point_b_local + thickness_b * normal + phi = wp.dot(normal, point_a_world - point_b_world) + + # Gate on margin + if phi >= 0.001: + contact_slot[c] = -1 + contact_path[c] = 
-1 + return + + # Classify: MF path if both sides are free rigid or ground + is_mf = 0 + if has_free_rigid != 0: + a_is_free_or_ground = (art_a < 0) or (is_free_rigid[art_a] != 0) + b_is_free_or_ground = (art_b < 0) or (is_free_rigid[art_b] != 0) + if a_is_free_or_ground and b_is_free_or_ground: + is_mf = 1 + + # Allocate slots (1 normal + 2 friction) + slots_needed = 1 + if enable_friction != 0: + slots_needed = 3 + + if is_mf != 0: + # Matrix-free path + slot = wp.atomic_add(mf_slot_counter, world, slots_needed) + if slot + slots_needed > mf_max_constraints: + # Roll back the counter so finalize sees only filled slots + wp.atomic_add(mf_slot_counter, world, -slots_needed) + contact_slot[c] = -1 + contact_path[c] = -1 + return + contact_world[c] = world + contact_slot[c] = slot + contact_art_a[c] = art_a + contact_art_b[c] = art_b + contact_path[c] = 1 + else: + # Dense path + slot = wp.atomic_add(world_slot_counter, world, slots_needed) + if slot + slots_needed > max_constraints: + # Roll back the counter so finalize sees only filled slots + wp.atomic_add(world_slot_counter, world, -slots_needed) + contact_slot[c] = -1 + contact_path[c] = -1 + return + contact_world[c] = world + contact_slot[c] = slot + contact_art_a[c] = art_a + contact_art_b[c] = art_b + contact_path[c] = 0 + + +@wp.func +def accumulate_jacobian_row_world( + body_index: int, + sign: float, + point_world: wp.vec3, + origin: wp.vec3, + direction: wp.vec3, + body_to_joint: wp.array(dtype=int), + joint_ancestor: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_S_s: wp.array(dtype=wp.spatial_vector), + art_dof_start: int, + n_dofs: int, + group_idx: int, + row: int, + J_group: wp.array3d(dtype=float), +): + """Accumulate Jacobian contributions by walking up the kinematic tree.""" + if body_index < 0: + return + + point_rel = point_world - origin + curr_joint = body_to_joint[body_index] + + while curr_joint >= 0: + dof_start = joint_qd_start[curr_joint] + dof_end = 
joint_qd_start[curr_joint + 1] + + for global_dof in range(dof_start, dof_end): + S = joint_S_s[global_dof] + lin = wp.vec3(S[0], S[1], S[2]) + ang = wp.vec3(S[3], S[4], S[5]) + + # Velocity at contact point from this joint + v = lin + wp.cross(ang, point_rel) + proj = wp.dot(direction, v) + + local_dof = global_dof - art_dof_start + if local_dof >= 0 and local_dof < n_dofs: + J_group[group_idx, row, local_dof] += sign * proj + + curr_joint = joint_ancestor[curr_joint] + + +@wp.kernel +def populate_world_J_for_size( + contact_count: wp.array(dtype=int), + contact_point0: wp.array(dtype=wp.vec3), + contact_point1: wp.array(dtype=wp.vec3), + contact_normal: wp.array(dtype=wp.vec3), + contact_shape0: wp.array(dtype=int), + contact_shape1: wp.array(dtype=int), + contact_thickness0: wp.array(dtype=float), + contact_thickness1: wp.array(dtype=float), + contact_world: wp.array(dtype=int), + contact_slot: wp.array(dtype=int), + contact_art_a: wp.array(dtype=int), + contact_art_b: wp.array(dtype=int), + contact_path: wp.array(dtype=int), + target_size: int, + art_size: wp.array(dtype=int), + art_group_idx: wp.array(dtype=int), + art_dof_start: wp.array(dtype=int), + articulation_origin: wp.array(dtype=wp.vec3), + body_to_joint: wp.array(dtype=int), + joint_ancestor: wp.array(dtype=int), + joint_qd_start: wp.array(dtype=int), + joint_S_s: wp.array(dtype=wp.spatial_vector), + shape_body: wp.array(dtype=int), + body_q: wp.array(dtype=wp.transform), + shape_transform: wp.array(dtype=wp.transform), + shape_material_mu: wp.array(dtype=float), + enable_friction: int, + pgs_beta: float, + pgs_cfm: float, + # outputs + J_group: wp.array3d(dtype=float), + world_row_type: wp.array2d(dtype=int), + world_row_parent: wp.array2d(dtype=int), + world_row_mu: wp.array2d(dtype=float), + world_row_beta: wp.array2d(dtype=float), + world_row_cfm: wp.array2d(dtype=float), + world_phi: wp.array2d(dtype=float), + world_target_velocity: wp.array2d(dtype=float), +): + """ + Phase 2 of 
multi-articulation contact building (per size group). + + Populates the Jacobian matrix for articulations of a specific DOF size. + Each contact may contribute to multiple articulations' J matrices. + Contacts routed to the matrix-free path (contact_path==1) are skipped. + """ + c = wp.tid() + total_contacts = contact_count[0] + if c >= total_contacts: + return + + # Skip contacts routed to MF path + if contact_path[c] != 0: + return + + slot = contact_slot[c] + if slot < 0: + return + + world = contact_world[c] + art_a = contact_art_a[c] + art_b = contact_art_b[c] + + # Get contact geometry + # contact normal stored as A-to-B; negate to get B-to-A used internally + normal = -contact_normal[c] + shape_a = contact_shape0[c] + shape_b = contact_shape1[c] + + body_a = -1 + body_b = -1 + if shape_a >= 0: + body_a = shape_body[shape_a] + if shape_b >= 0: + body_b = shape_body[shape_b] + + thickness_a = contact_thickness0[c] + thickness_b = contact_thickness1[c] + + # Compute contact points in world frame + # Contact points are stored in body frame by collision detection + point_a_local = contact_point0[c] + point_b_local = contact_point1[c] + point_a_world = wp.vec3(0.0) + point_b_world = wp.vec3(0.0) + + if body_a >= 0: + X_wb_a = body_q[body_a] + point_a_world = wp.transform_point(X_wb_a, point_a_local) - thickness_a * normal + else: + point_a_world = point_a_local - thickness_a * normal + + if body_b >= 0: + X_wb_b = body_q[body_b] + point_b_world = wp.transform_point(X_wb_b, point_b_local) + thickness_b * normal + else: + point_b_world = point_b_local + thickness_b * normal + + # Compute penetration depth + phi = wp.dot(normal, point_a_world - point_b_world) + + # Compute friction coefficient + mu = 0.0 + mat_count = 0 + if shape_a >= 0: + mu += shape_material_mu[shape_a] + mat_count += 1 + if shape_b >= 0: + mu += shape_material_mu[shape_b] + mat_count += 1 + if mat_count > 0: + mu /= float(mat_count) + + # Compute tangent basis for friction + t0, t1 = 
contact_tangent_basis(normal) + + # Handle articulation A if it matches target size + if art_a >= 0 and art_size[art_a] == target_size: + group_idx_a = art_group_idx[art_a] + dof_start_a = art_dof_start[art_a] + origin_a = articulation_origin[art_a] + + # Normal row (slot + 0) + accumulate_jacobian_row_world( + body_a, + 1.0, + point_a_world, + origin_a, + normal, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_a, + target_size, + group_idx_a, + slot, + J_group, + ) + + if enable_friction != 0: + # Friction row 1 (slot + 1) + accumulate_jacobian_row_world( + body_a, + 1.0, + point_a_world, + origin_a, + t0, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_a, + target_size, + group_idx_a, + slot + 1, + J_group, + ) + # Friction row 2 (slot + 2) + accumulate_jacobian_row_world( + body_a, + 1.0, + point_a_world, + origin_a, + t1, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_a, + target_size, + group_idx_a, + slot + 2, + J_group, + ) + + # Handle articulation B if it matches target size + if art_b >= 0 and art_size[art_b] == target_size: + group_idx_b = art_group_idx[art_b] + dof_start_b = art_dof_start[art_b] + origin_b = articulation_origin[art_b] + + # Opposite sign for body B + accumulate_jacobian_row_world( + body_b, + -1.0, + point_b_world, + origin_b, + normal, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_b, + target_size, + group_idx_b, + slot, + J_group, + ) + + if enable_friction != 0: + accumulate_jacobian_row_world( + body_b, + -1.0, + point_b_world, + origin_b, + t0, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_b, + target_size, + group_idx_b, + slot + 1, + J_group, + ) + accumulate_jacobian_row_world( + body_b, + -1.0, + point_b_world, + origin_b, + t1, + body_to_joint, + joint_ancestor, + joint_qd_start, + joint_S_s, + dof_start_b, + target_size, + group_idx_b, + slot + 2, + J_group, + ) + + # Set 
row metadata (only once per contact, from whichever articulation runs first) + # Use art_a preferentially to avoid double-writes + if art_a >= 0 and art_size[art_a] == target_size: + # Normal contact row + world_row_type[world, slot] = PGS_CONSTRAINT_TYPE_CONTACT + world_row_parent[world, slot] = -1 + world_row_mu[world, slot] = mu + world_row_beta[world, slot] = pgs_beta + world_row_cfm[world, slot] = pgs_cfm + world_phi[world, slot] = phi + world_target_velocity[world, slot] = 0.0 + + if enable_friction != 0: + # Friction row 1 + world_row_type[world, slot + 1] = PGS_CONSTRAINT_TYPE_FRICTION + world_row_parent[world, slot + 1] = slot + world_row_mu[world, slot + 1] = mu + world_row_beta[world, slot + 1] = 0.0 + world_row_cfm[world, slot + 1] = pgs_cfm + world_phi[world, slot + 1] = 0.0 + world_target_velocity[world, slot + 1] = 0.0 + + # Friction row 2 + world_row_type[world, slot + 2] = PGS_CONSTRAINT_TYPE_FRICTION + world_row_parent[world, slot + 2] = slot + world_row_mu[world, slot + 2] = mu + world_row_beta[world, slot + 2] = 0.0 + world_row_cfm[world, slot + 2] = pgs_cfm + world_phi[world, slot + 2] = 0.0 + world_target_velocity[world, slot + 2] = 0.0 + + elif art_b >= 0 and art_size[art_b] == target_size: + # Only write metadata from art_b if art_a didn't match this size + world_row_type[world, slot] = PGS_CONSTRAINT_TYPE_CONTACT + world_row_parent[world, slot] = -1 + world_row_mu[world, slot] = mu + world_row_beta[world, slot] = pgs_beta + world_row_cfm[world, slot] = pgs_cfm + world_phi[world, slot] = phi + world_target_velocity[world, slot] = 0.0 + + if enable_friction != 0: + world_row_type[world, slot + 1] = PGS_CONSTRAINT_TYPE_FRICTION + world_row_parent[world, slot + 1] = slot + world_row_mu[world, slot + 1] = mu + world_row_beta[world, slot + 1] = 0.0 + world_row_cfm[world, slot + 1] = pgs_cfm + world_phi[world, slot + 1] = 0.0 + world_target_velocity[world, slot + 1] = 0.0 + + world_row_type[world, slot + 2] = PGS_CONSTRAINT_TYPE_FRICTION + 
@wp.kernel
def finalize_world_constraint_counts(
    world_slot_counter: wp.array(dtype=int),
    max_constraints: int,
    slots_per_contact: int,
    # outputs
    world_constraint_count: wp.array(dtype=int),
):
    """Publish each world's constraint count, capped at ``max_constraints``.

    One thread per world. The stored value is the smaller of
    ``world_slot_counter[world]`` and ``max_constraints``.

    Because the cap is applied without rounding, the published range may
    include "gap" rows: slots that were reserved through the counter by a
    contact that was later rejected and therefore never written. Such rows
    are expected to have a zero Jacobian and hence a zero diagonal, which
    the PGS sweep skips.

    ``slots_per_contact`` is kept in the signature for backwards
    compatibility only; it is not read. No rounding to a multiple of it is
    performed, since the buffer may mix 3-row contact groups with
    single-row joint-limit constraints.
    """
    w = wp.tid()
    world_constraint_count[w] = wp.min(world_slot_counter[w], max_constraints)
@wp.kernel
def clamp_joint_tau(
    joint_tau: wp.array(dtype=float),
    joint_effort_limit: wp.array(dtype=float),
):
    """Saturate each joint torque to its symmetric effort limit, in place.

    One thread per entry of ``joint_tau``. For a positive limit ``L`` the
    torque is clamped to ``[-L, L]``. A limit that is zero or negative is
    treated as "unlimited" and the torque is left untouched.
    """
    dof = wp.tid()
    effort_max = joint_effort_limit[dof]

    # Only strictly positive limits are enforced.
    if effort_max > 0.0:
        tau = joint_tau[dof]
        if tau > effort_max:
            joint_tau[dof] = effort_max
        elif tau < -effort_max:
            joint_tau[dof] = -effort_max
@wp.kernel
def update_body_qd_from_featherstone(
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_q: wp.array(dtype=wp.transform),
    body_com: wp.array(dtype=wp.vec3),
    body_to_articulation: wp.array(dtype=int),
    articulation_origin: wp.array(dtype=wp.vec3),
    body_qd_out: wp.array(dtype=wp.spatial_vector),
):
    """Convert per-body spatial twists into COM-referenced body velocities.

    One thread per body. The top half of ``body_v_s[body]`` is taken as the
    linear velocity at a reference point and the bottom half as the angular
    velocity ``w``. The reference point is the owning articulation's entry
    in ``articulation_origin``, or the zero vector when
    ``body_to_articulation[body]`` is negative.

    The output is ``(v_ref + w x (com_world - ref), w)``, where ``com_world``
    is ``body_com`` transformed by ``body_q``.
    """
    body = wp.tid()

    # Point the twist's linear part refers to (zero for unarticulated bodies).
    ref_point = wp.vec3()
    art = body_to_articulation[body]
    if art >= 0:
        ref_point = articulation_origin[art]

    twist = body_v_s[body]
    lin_vel_ref = wp.spatial_top(twist)
    ang_vel = wp.spatial_bottom(twist)

    # Lever arm from the reference point to the world-space centre of mass.
    com_offset = wp.transform_point(body_q[body], body_com[body]) - ref_point

    body_qd_out[body] = wp.spatial_vector(lin_vel_ref + wp.cross(ang_vel, com_offset), ang_vel)
@wp.kernel
def rhs_accum_world_par_art(
    world_constraint_count: wp.array(dtype=int),
    max_constraints: int,
    art_to_world: wp.array(dtype=int),
    art_size: wp.array(dtype=int),
    art_group_idx: wp.array(dtype=int),
    art_dof_start: wp.array(dtype=int),
    v_hat: wp.array(dtype=float),
    group_to_art: wp.array(dtype=int),
    J_group: wp.array3d(dtype=float),
    n_dofs: int,
    # outputs
    world_rhs: wp.array2d(dtype=float),
):
    """Add ``J * v_hat`` for one size group into the world RHS.

    One thread per articulation slot in the group. For every active
    constraint row of the articulation's world, the thread computes the dot
    product of that row of ``J_group`` with the articulation's slice of
    ``v_hat`` and atomically adds it to ``world_rhs[world, row]``. The
    atomic add is what lets several articulations of one world contribute
    to the same row.

    ``world_rhs`` is accumulated into, not overwritten, so any bias already
    stored there is preserved. ``max_constraints``, ``art_size`` and
    ``art_group_idx`` are accepted but not read by this kernel.
    """
    slot = wp.tid()
    articulation = group_to_art[slot]
    w = art_to_world[articulation]
    rows = world_constraint_count[w]

    if rows > 0:
        v_base = art_dof_start[articulation]

        for row in range(rows):
            acc = float(0.0)
            for col in range(n_dofs):
                acc += J_group[slot, row, col] * v_hat[v_base + col]
            wp.atomic_add(world_rhs, w, row, acc)
@wp.kernel
def gather_JY_to_world(
    group_to_art: wp.array(dtype=int),
    art_to_world: wp.array(dtype=int),
    art_dof_start: wp.array(dtype=int),
    world_constraint_count: wp.array(dtype=int),
    world_dof_start: wp.array(dtype=int),
    J_group: wp.array3d(dtype=float),
    Y_group: wp.array3d(dtype=float),
    n_dofs: int,
    max_constraints: int,
    n_arts: int,
    # outputs
    J_world: wp.array3d(dtype=float),
    Y_world: wp.array3d(dtype=float),
):
    """Scatter one size group's J and Y entries into world-indexed arrays.

    Launched over ``n_arts * max_constraints * n_dofs`` threads. The flat
    thread id decodes, fastest-varying first, into (dof column, constraint
    row, group slot). Each thread copies a single element of ``J_group`` and
    ``Y_group`` to the column of ``J_world`` / ``Y_world`` given by the
    articulation's DOF start relative to its world's DOF start.

    Rows at or beyond the world's constraint count are skipped. Within the
    active rows, values are copied unconditionally (zeros included), so the
    destination entries covered here do not rely on pre-zeroing.
    """
    flat = wp.tid()
    col = flat % n_dofs
    pair = flat // n_dofs
    row = pair % max_constraints
    g = pair // max_constraints

    if g < n_arts:
        art = group_to_art[g]
        w = art_to_world[art]

        if row < world_constraint_count[w]:
            # Column of this DOF inside the world's contiguous DOF block.
            dst = (art_dof_start[art] - world_dof_start[w]) + col
            J_world[w, row, dst] = J_group[g, row, col]
            Y_world[w, row, dst] = Y_group[g, row, col]
wp.array(dtype=int), + contact_thickness0: wp.array(dtype=float), + contact_thickness1: wp.array(dtype=float), + contact_world: wp.array(dtype=int), + contact_slot: wp.array(dtype=int), + contact_path: wp.array(dtype=int), + contact_art_a: wp.array(dtype=int), + contact_art_b: wp.array(dtype=int), + articulation_origin: wp.array(dtype=wp.vec3), + shape_body: wp.array(dtype=int), + body_q: wp.array(dtype=wp.transform), + shape_material_mu: wp.array(dtype=float), + enable_friction: int, + pgs_beta: float, + # outputs + mf_body_a: wp.array2d(dtype=int), + mf_body_b: wp.array2d(dtype=int), + mf_J_a: wp.array3d(dtype=float), + mf_J_b: wp.array3d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_row_mu: wp.array2d(dtype=float), + mf_phi: wp.array2d(dtype=float), +): + """Build MF constraint rows for contacts between free rigid bodies / ground. + + For root free joints, the internal qd used here stores the COM-point linear + term and angular velocity, i.e. `qd = [v_com_point, omega]`, where the COM + point is the root body's world-space COM position. 
The MF contact Jacobian + uses contact position relative to that point: + J = [d, r x d] (r = p_contact - p_com_world) + """ + c = wp.tid() + total_contacts = contact_count[0] + if c >= total_contacts: + return + + if contact_path[c] != 1: + return + + slot = contact_slot[c] + if slot < 0: + return + + world = contact_world[c] + shape_a = contact_shape0[c] + shape_b = contact_shape1[c] + # contact normal stored as A-to-B; negate to get B-to-A used internally + normal = -contact_normal[c] + + body_a = -1 + body_b = -1 + if shape_a >= 0: + body_a = shape_body[shape_a] + if shape_b >= 0: + body_b = shape_body[shape_b] + + thickness_a = contact_thickness0[c] + thickness_b = contact_thickness1[c] + + # Compute contact points in world frame + point_a_local = contact_point0[c] + point_b_local = contact_point1[c] + point_a_world = wp.vec3(0.0) + point_b_world = wp.vec3(0.0) + + if body_a >= 0: + X_wb_a = body_q[body_a] + point_a_world = wp.transform_point(X_wb_a, point_a_local) - thickness_a * normal + else: + point_a_world = point_a_local - thickness_a * normal + + if body_b >= 0: + X_wb_b = body_q[body_b] + point_b_world = wp.transform_point(X_wb_b, point_b_local) + thickness_b * normal + else: + point_b_world = point_b_local + thickness_b * normal + + phi = wp.dot(normal, point_a_world - point_b_world) + + # Friction coefficient + mu = 0.0 + mat_count = 0 + if shape_a >= 0: + mu += shape_material_mu[shape_a] + mat_count += 1 + if shape_b >= 0: + mu += shape_material_mu[shape_b] + mat_count += 1 + if mat_count > 0: + mu /= float(mat_count) + + # Tangent basis + t0, t1 = contact_tangent_basis(normal) + + # Write rows for normal + friction + for row_offset in range(3): + if row_offset > 0 and enable_friction == 0: + break + + row_idx = slot + row_offset + + if row_offset == 0: + d = normal + elif row_offset == 1: + d = t0 + else: + d = t1 + + # Body A Jacobian in articulation-local frame: J = [d, r_a x d], where + # r_a is the contact point relative to articulation A's 
@wp.func
def spatial_matrix_block_inverse(M: wp.spatial_matrix):
    """Return the inverse of a 6x6 spatial matrix via 3x3 block elimination.

    With ``M = [A B; C D]`` split into 3x3 blocks and the Schur complement
    ``S = D - C * A^-1 * B``::

        M^-1 = [ A^-1 + A^-1*B*S^-1*C*A^-1,   -A^-1*B*S^-1 ]
               [ -S^-1*C*A^-1,                 S^-1        ]

    Both ``A`` and ``S`` are inverted with ``wp.inverse``; no singularity
    check is performed here.
    """
    # fmt: off
    blk_a = wp.mat33(
        M[0, 0], M[0, 1], M[0, 2],
        M[1, 0], M[1, 1], M[1, 2],
        M[2, 0], M[2, 1], M[2, 2],
    )
    blk_b = wp.mat33(
        M[0, 3], M[0, 4], M[0, 5],
        M[1, 3], M[1, 4], M[1, 5],
        M[2, 3], M[2, 4], M[2, 5],
    )
    blk_c = wp.mat33(
        M[3, 0], M[3, 1], M[3, 2],
        M[4, 0], M[4, 1], M[4, 2],
        M[5, 0], M[5, 1], M[5, 2],
    )
    blk_d = wp.mat33(
        M[3, 3], M[3, 4], M[3, 5],
        M[4, 3], M[4, 4], M[4, 5],
        M[5, 3], M[5, 4], M[5, 5],
    )
    # fmt: on

    a_inv = wp.inverse(blk_a)
    a_inv_b = a_inv * blk_b
    schur_inv = wp.inverse(blk_d - blk_c * a_inv_b)
    s_inv_c_a_inv = schur_inv * blk_c * a_inv

    # The four 3x3 blocks of the inverse.
    tl = a_inv + a_inv_b * s_inv_c_a_inv
    tr = -a_inv_b * schur_inv
    bl = -s_inv_c_a_inv
    br = schur_inv

    # fmt: off
    return wp.spatial_matrix(
        tl[0, 0], tl[0, 1], tl[0, 2], tr[0, 0], tr[0, 1], tr[0, 2],
        tl[1, 0], tl[1, 1], tl[1, 2], tr[1, 0], tr[1, 1], tr[1, 2],
        tl[2, 0], tl[2, 1], tl[2, 2], tr[2, 0], tr[2, 1], tr[2, 2],
        bl[0, 0], bl[0, 1], bl[0, 2], br[0, 0], br[0, 1], br[0, 2],
        bl[1, 0], bl[1, 1], bl[1, 2], br[1, 0], br[1, 1], br[1, 2],
        bl[2, 0], bl[2, 1], bl[2, 2], br[2, 0], br[2, 1], br[2, 2],
    )
    # fmt: on
@wp.kernel
def compute_mf_effective_mass_and_rhs(
    mf_constraint_count: wp.array(dtype=int),
    mf_body_a: wp.array2d(dtype=int),
    mf_body_b: wp.array2d(dtype=int),
    mf_J_a: wp.array3d(dtype=float),
    mf_J_b: wp.array3d(dtype=float),
    mf_body_Hinv: wp.array(dtype=wp.spatial_matrix),
    mf_phi: wp.array2d(dtype=float),
    mf_row_type: wp.array2d(dtype=int),
    pgs_cfm: float,
    pgs_beta: float,
    dt: float,
    mf_max_constraints: int,
    # outputs
    mf_eff_mass_inv: wp.array2d(dtype=float),
    mf_MiJt_a: wp.array3d(dtype=float),
    mf_MiJt_b: wp.array3d(dtype=float),
    mf_rhs: wp.array2d(dtype=float),
):
    """Precompute per-row quantities for the matrix-free PGS sweep.

    Launched over ``n_worlds * mf_max_constraints`` threads; the flat id
    decodes to (world, row). Rows at or beyond ``mf_constraint_count[world]``
    are left untouched. For each active row:

    * ``mf_MiJt_a`` / ``mf_MiJt_b`` receive ``H^-1 * J`` for the side whose
      body index is non-negative. A side with a negative body index is not
      written.
    * ``mf_eff_mass_inv`` receives ``1 / d`` with
      ``d = pgs_cfm + J_a . (H_a^-1 J_a) + J_b . (H_b^-1 J_b)``, or ``0`` if
      ``d`` is not positive.
    * ``mf_rhs`` receives only the stabilization bias
      ``pgs_beta / dt * min(phi, 0)`` for contact rows and ``0`` for every
      other row type. ``J * v`` is deliberately not included because the
      solver recomputes it from the live velocities each iteration.
    """
    flat = wp.tid()
    world = flat // mf_max_constraints
    row = flat % mf_max_constraints
    if row >= mf_constraint_count[world]:
        return

    body_a = mf_body_a[world, row]
    body_b = mf_body_b[world, row]

    # Diagonal starts from the regularization term; each dynamic side adds to it.
    diag = pgs_cfm

    if body_a >= 0:
        jac_a = wp.spatial_vector(
            mf_J_a[world, row, 0],
            mf_J_a[world, row, 1],
            mf_J_a[world, row, 2],
            mf_J_a[world, row, 3],
            mf_J_a[world, row, 4],
            mf_J_a[world, row, 5],
        )
        hj_a = mf_body_Hinv[body_a] * jac_a
        diag += wp.dot(jac_a, hj_a)
        for k in range(6):
            mf_MiJt_a[world, row, k] = hj_a[k]

    if body_b >= 0:
        jac_b = wp.spatial_vector(
            mf_J_b[world, row, 0],
            mf_J_b[world, row, 1],
            mf_J_b[world, row, 2],
            mf_J_b[world, row, 3],
            mf_J_b[world, row, 4],
            mf_J_b[world, row, 5],
        )
        hj_b = mf_body_Hinv[body_b] * jac_b
        diag += wp.dot(jac_b, hj_b)
        for k in range(6):
            mf_MiJt_b[world, row, k] = hj_b[k]

    # A non-positive diagonal is encoded as 0 so the solver can skip the row.
    inv_diag = float(0.0)
    if diag > 0.0:
        inv_diag = 1.0 / diag
    mf_eff_mass_inv[world, row] = inv_diag

    # Baumgarte bias applies to penetrating normal-contact rows only.
    bias = float(0.0)
    if mf_row_type[world, row] == PGS_CONSTRAINT_TYPE_CONTACT:
        bias = pgs_beta / dt * wp.min(mf_phi[world, row], 0.0)
    mf_rhs[world, row] = bias
mf_rhs: wp.array2d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_row_mu: wp.array2d(dtype=float), + body_to_articulation: wp.array(dtype=int), + art_dof_start: wp.array(dtype=int), + iterations: int, + omega: float, + # in/out + mf_impulses: wp.array2d(dtype=float), + v_out: wp.array(dtype=float), +): + """Matrix-free PGS solver for free rigid body contacts. + + Operates directly on body velocities stored in v_out (generalized coordinates). + Each iteration recomputes J*v from v_out and applies velocity corrections + immediately (Gauss-Seidel style). + """ + world = wp.tid() + m_count = mf_constraint_count[world] + if m_count == 0: + return + + for _ in range(iterations): + for i in range(m_count): + eff_inv = mf_eff_mass_inv[world, i] + if eff_inv <= 0.0: + continue + + ba = mf_body_a[world, i] + bb = mf_body_b[world, i] + + # Compute current J * v + jv = float(0.0) + if ba >= 0: + art_a = body_to_articulation[ba] + ds_a = art_dof_start[art_a] + for k in range(6): + jv += mf_J_a[world, i, k] * v_out[ds_a + k] + if bb >= 0: + art_b = body_to_articulation[bb] + ds_b = art_dof_start[art_b] + for k in range(6): + jv += mf_J_b[world, i, k] * v_out[ds_b + k] + + # PGS update: delta = -(J*v_current + bias) / d_ii + delta = -(jv + mf_rhs[world, i]) * eff_inv + new_impulse = mf_impulses[world, i] + omega * delta + old_impulse = mf_impulses[world, i] + + row_type = mf_row_type[world, i] + + # Project + if row_type == PGS_CONSTRAINT_TYPE_CONTACT: + if new_impulse < 0.0: + new_impulse = 0.0 + elif row_type == PGS_CONSTRAINT_TYPE_FRICTION: + parent_idx = mf_row_parent[world, i] + lambda_n = mf_impulses[world, parent_idx] + mu_val = mf_row_mu[world, i] + radius = wp.max(mu_val * lambda_n, 0.0) + + if radius <= 0.0: + new_impulse = 0.0 + else: + # Sibling friction row + if i == parent_idx + 1: + sib = parent_idx + 2 + else: + sib = parent_idx + 1 + + mf_impulses[world, i] = new_impulse + a = new_impulse + b = mf_impulses[world, 
@wp.kernel
def finalize_mf_constraint_counts(
    mf_slot_counter: wp.array(dtype=int),
    mf_max_constraints: int,
    slots_per_contact: int,
    # outputs
    mf_constraint_count: wp.array(dtype=int),
):
    """Clamp the MF slot counter and store it as the per-world constraint count.

    One thread per world. The raw counter is capped at
    ``mf_max_constraints`` and then rounded *down* to a whole multiple of
    ``slots_per_contact``. The published count therefore never ends part-way
    through a contact's row group, which matters because the MF friction
    projection addresses a friction row's sibling at ``parent + 1`` /
    ``parent + 2``.

    Unlike the dense world-level finalizer, this path does round. The
    rounding step is skipped when ``slots_per_contact <= 1``: it is a no-op
    for a group size of 1, and a group size of 0 would otherwise be an
    integer division by zero on the device.
    """
    world = wp.tid()
    count = wp.min(mf_slot_counter[world], mf_max_constraints)

    # Guarded so a degenerate group size cannot trigger a divide-by-zero.
    if slots_per_contact > 1:
        count = (count // slots_per_contact) * slots_per_contact

    mf_constraint_count[world] = count
@wp.kernel
def pgs_solve_loop(
    world_constraint_count: wp.array(dtype=int),
    max_constraints: int,
    world_diag: wp.array2d(dtype=float),
    world_C: wp.array3d(dtype=float),
    world_rhs: wp.array2d(dtype=float),
    world_impulses: wp.array2d(dtype=float),
    iterations: int,
    omega: float,
    world_row_type: wp.array2d(dtype=int),
    world_row_parent: wp.array2d(dtype=int),
    world_row_mu: wp.array2d(dtype=float),
):
    """World-level Projected Gauss-Seidel solver on a dense Delassus matrix.

    One thread per world runs ``iterations`` sequential sweeps over the
    world's ``world_constraint_count[world]`` rows, updating
    ``world_impulses`` in place. For row ``i``:

    * residual ``w = rhs_i + sum_j C_ij * lambda_j``
    * candidate ``lambda_i + omega * (-w / diag_i)``
    * projection by row type:
      - contact / joint limit: clamp to ``lambda >= 0``
      - friction: project the (row, sibling) pair onto the disk of radius
        ``mu * lambda_parent``; the impulse is zeroed if the radius is not
        positive. The sibling is assumed to sit at ``parent + 1`` or
        ``parent + 2``.
      - anything else: accept the candidate unprojected.

    Rows with a non-positive diagonal are skipped. The diagonal is tested
    *before* the residual is accumulated so that such rows (for example
    never-written gap slots) do not cost an O(m) pass over ``world_C`` on
    every iteration.

    ``max_constraints`` is accepted but not read by this kernel.
    """
    world = wp.tid()
    m = world_constraint_count[world]

    if m == 0:
        return

    for _ in range(iterations):
        for i in range(m):
            # Cheap rejection first: skipped rows never need their residual.
            denom = world_diag[world, i]
            if denom <= 0.0:
                continue

            # Compute residual: w = rhs_i + sum_j C_ij * lambda_j
            w = world_rhs[world, i]
            for j in range(m):
                w += world_C[world, i, j] * world_impulses[world, j]

            delta = -w / denom
            new_impulse = world_impulses[world, i] + omega * delta
            row_type = world_row_type[world, i]

            # --- Normal contact or joint limit: lambda_n >= 0 ---
            if row_type == PGS_CONSTRAINT_TYPE_CONTACT or row_type == PGS_CONSTRAINT_TYPE_JOINT_LIMIT:
                if new_impulse < 0.0:
                    new_impulse = 0.0
                world_impulses[world, i] = new_impulse

            # --- Friction: isotropic Coulomb ---
            elif row_type == PGS_CONSTRAINT_TYPE_FRICTION:
                parent_idx = world_row_parent[world, i]
                lambda_n = world_impulses[world, parent_idx]
                mu = world_row_mu[world, i]
                radius = wp.max(mu * lambda_n, 0.0)

                if radius <= 0.0:
                    world_impulses[world, i] = 0.0
                    continue

                world_impulses[world, i] = new_impulse

                # Sibling friction row: constraints are laid out as
                # [normal, friction1, friction2], so friction rows are at
                # parent_idx + 1 and parent_idx + 2.
                if i == parent_idx + 1:
                    sib = parent_idx + 2
                else:
                    sib = parent_idx + 1

                # Project tangent impulses onto the friction disk.
                a = world_impulses[world, i]
                b = world_impulses[world, sib]

                mag = wp.sqrt(a * a + b * b)
                if mag > radius:
                    scale = radius / mag
                    world_impulses[world, i] = a * scale
                    world_impulses[world, sib] = b * scale

            else:
                world_impulses[world, i] = new_impulse
@wp.kernel
def finalize_world_diag_cfm(
    world_constraint_count: wp.array(dtype=int),
    world_row_cfm: wp.array2d(dtype=float),
    # in/out
    world_diag: wp.array2d(dtype=float),
):
    """Add the per-row CFM term to each active entry of the world diagonal.

    One thread per world; only the first ``world_constraint_count[world]``
    rows are touched.
    """
    world = wp.tid()
    active_rows = world_constraint_count[world]

    for row in range(active_rows):
        world_diag[world, row] += world_row_cfm[world, row]
@wp.kernel
def hinv_jt_par_row(
    # Grouped Cholesky factor storage [n_arts, n_dofs, n_dofs]
    L_group: wp.array3d(dtype=float),
    # Size-grouped Jacobian [n_arts_of_size, max_constraints, n_dofs]
    J_group: wp.array3d(dtype=float),
    # Indirection arrays
    group_to_art: wp.array(dtype=int),
    art_to_world: wp.array(dtype=int),
    world_constraint_count: wp.array(dtype=int),
    # Size parameters
    n_dofs: int,
    max_constraints: int,
    n_arts: int,
    # Output: Y = H^-1 * J^T [n_arts_of_size, max_constraints, n_dofs]
    Y_group: wp.array3d(dtype=float),
):
    """
    Compute Y = H^-1 * J^T for one size group using forward/backward substitution.

    Uses L_group (3D array) grouped by DOF size.

    Each thread handles one (articulation, constraint) pair and solves, for
    the constraint row ``j = J_group[idx, c, :]``:

        L * L^T * y = j

    in two passes:

    1. Forward substitution:  L * z = j      (z is stored in ``Y_group[idx, c, :]``)
    2. Backward substitution: L^T * y = z    (y overwrites z in place)

    A zero pivot ``L[i, i]`` yields a zero entry instead of a division by
    zero, matching the convention used by ``trisolve_loop``.

    Thread dimension: n_arts_of_size * max_constraints
    """
    tid = wp.tid()

    # Decode thread index
    c = tid % max_constraints  # constraint index
    idx = tid // max_constraints  # group index (articulation within size group)

    # Bounds check for articulation
    if idx >= n_arts:
        return

    art = group_to_art[idx]
    world = art_to_world[art]
    n_constraints = world_constraint_count[world]

    # Early exit if this constraint is beyond the actual count
    if c >= n_constraints:
        return

    # ----------------------------------------------------------------
    # Forward substitution: L * z = j
    # L is lower triangular, so solve from top to bottom
    # ----------------------------------------------------------------
    for i in range(n_dofs):
        # z[i] = (j[i] - sum_{k<i} L[i,k] * z[k]) / L[i,i]
        val = J_group[idx, c, i]

        for k in range(i):
            val -= L_group[idx, i, k] * Y_group[idx, c, k]

        L_ii = L_group[idx, i, i]
        if L_ii != 0.0:
            Y_group[idx, c, i] = val / L_ii
        else:
            Y_group[idx, c, i] = 0.0

    # ----------------------------------------------------------------
    # Backward substitution: L^T * y = z
    # L^T is upper triangular, so solve from bottom to top
    # ----------------------------------------------------------------
    for i_rev in range(n_dofs):
        i = n_dofs - 1 - i_rev

        # y[i] = (z[i] - sum_{k>i} L[k,i] * y[k]) / L[i,i]
        # Note: L^T[i,k] = L[k,i], so we read L[k,i] for k > i
        val = Y_group[idx, c, i]  # This is z[i] from forward pass

        for k in range(i + 1, n_dofs):
            val -= L_group[idx, k, i] * Y_group[idx, c, k]

        L_ii = L_group[idx, i, i]
        if L_ii != 0.0:
            Y_group[idx, c, i] = val / L_ii
        else:
            Y_group[idx, c, i] = 0.0
@wp.kernel
def crba_fill_par_dof(
    articulation_start: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    mass_update_mask: wp.array(dtype=int),
    joint_ancestor: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_c: wp.array(dtype=wp.spatial_matrix),
    # Size-group parameters
    group_to_art: wp.array(dtype=int),
    n_dofs: int,  # = TILE_DOF for tiled path
    # outputs
    H_group: wp.array3d(dtype=float),  # [n_arts_of_size, n_dofs, n_dofs]
):
    """
    CRBA fill kernel that writes directly to size-grouped H storage.

    Thread dimension: n_arts_of_size * n_dofs (one thread per articulation-column pair)

    Each thread owns one column of H for one articulation of the group. It
    locates the joint that owns that DOF, forms ``F = I_c[joint] * S[col]``
    and walks ``joint_ancestor`` towards the root, writing
    ``H[row, col] = H[col, row] = dot(S[row], F)`` for every DOF of each
    visited joint. Articulations whose ``mass_update_mask`` entry is zero are
    skipped.

    This version is for homogeneous multi-articulation where all articulations
    in the group have the same DOF count ``n_dofs``.
    """
    tid = wp.tid()

    # col_idx is always in [0, n_dofs) because it is a remainder modulo n_dofs.
    group_idx = tid // n_dofs
    col_idx = tid % n_dofs

    art_idx = group_to_art[group_idx]

    if mass_update_mask[art_idx] == 0:
        return

    global_dof_start = articulation_dof_start[art_idx]
    target_dof_global = global_dof_start + col_idx

    joint_start = articulation_start[art_idx]
    joint_end = articulation_start[art_idx + 1]

    # Find the joint that owns this DOF
    pivot_joint = int(-1)
    for j in range(joint_start, joint_end):
        q_start = joint_qd_start[j]
        q_end = joint_qd_start[j + 1]
        if target_dof_global >= q_start and target_dof_global < q_end:
            pivot_joint = j
            break

    if pivot_joint == -1:
        return

    # Compute Force F = I_c[pivot] * S[column]
    # NOTE(review): body_I_c is indexed with a joint index here, which assumes
    # joint and body indices coincide -- confirm against the model layout.
    S_col = joint_S_s[target_dof_global]
    I_comp = body_I_c[pivot_joint]
    F = I_comp * S_col

    # Walk up the tree and project F onto ancestors
    # H[row, col] = S[row] * F
    curr = pivot_joint

    while curr != -1:
        # Stop once the ancestor chain leaves this articulation's joint range.
        if curr < joint_start:
            break

        q_start = joint_qd_start[curr]
        q_dim = joint_dof_dim[curr]
        count = q_dim[0] + q_dim[1]

        dof_offset_local = q_start - global_dof_start

        for k in range(count):
            row_idx = dof_offset_local + k

            S_row = joint_S_s[q_start + k]
            val = wp.dot(S_row, F)

            # Write to grouped 3D array (both triangles, H is symmetric)
            H_group[group_idx, row_idx, col_idx] = val
            H_group[group_idx, col_idx, row_idx] = val

        curr = joint_ancestor[curr]
@wp.kernel
def gather_tau_to_groups(
    joint_tau: wp.array(dtype=float),  # [total_dofs]
    group_to_art: wp.array(dtype=int),
    articulation_dof_start: wp.array(dtype=int),
    n_dofs: int,
    tau_group: wp.array3d(dtype=float),  # [n_arts, n_dofs, 1]
):
    """Copy each articulation's slice of ``joint_tau`` into the grouped buffer.

    For group slot ``slot`` the ``n_dofs`` values starting at the
    articulation's DOF offset are written to ``tau_group[slot, :, 0]``.

    Thread dimension: n_arts_of_size (one thread per articulation in this size group)
    """
    slot = wp.tid()
    first_dof = articulation_dof_start[group_to_art[slot]]
    for d in range(n_dofs):
        tau_group[slot, d, 0] = joint_tau[first_dof + d]
@wp.kernel
def pgs_convergence_diagnostic_velocity(
    # Dense constraints
    constraint_count: wp.array(dtype=int),
    world_dof_start: wp.array(dtype=int),
    rhs: wp.array2d(dtype=float),
    impulses: wp.array2d(dtype=float),
    prev_impulses: wp.array2d(dtype=float),
    row_type: wp.array2d(dtype=int),
    row_parent: wp.array2d(dtype=int),
    row_mu: wp.array2d(dtype=float),
    J_world: wp.array3d(dtype=float),
    max_constraints: int,
    max_world_dofs: int,
    # MF constraints
    mf_constraint_count: wp.array(dtype=int),
    mf_rhs: wp.array2d(dtype=float),
    mf_impulses: wp.array2d(dtype=float),
    prev_mf_impulses: wp.array2d(dtype=float),
    mf_row_type: wp.array2d(dtype=int),
    mf_row_parent: wp.array2d(dtype=int),
    mf_row_mu: wp.array2d(dtype=float),
    mf_J_a: wp.array3d(dtype=float),
    mf_J_b: wp.array3d(dtype=float),
    mf_dof_a: wp.array2d(dtype=int),
    mf_dof_b: wp.array2d(dtype=int),
    mf_max_constraints: int,
    # Velocity
    v_out: wp.array(dtype=float),
    # Output: [worlds, 4]
    metrics: wp.array2d(dtype=float),
):
    """Compute per-world PGS convergence metrics for velocity-space mode.

    One thread per world. Both the dense rows and the matrix-free (MF) rows
    of the world contribute to the same four accumulators.

    Metrics:
      [0] max|delta_lambda| across all constraint rows
      [1] sum(lambda_n * residual_n) for normal contacts (complementarity gap)
      [2] sum(residual_t^2) for sticking friction contacts (tangent residual energy)
      [3] sum(FB(lambda_n, residual_n)^2) for normal contacts (Fischer-Burmeister)

    ``v_out`` is the global generalized-velocity vector; every index into it
    is offset by the world's first DOF. ``mf_dof_a`` / ``mf_dof_b`` are treated
    as world-relative DOF offsets (the layout documented on
    ``compute_mf_world_dof_offsets``), with negative values meaning "no body".
    ``max_constraints`` and ``mf_max_constraints`` are not read by this kernel.
    """
    world = wp.tid()

    m_dense = constraint_count[world]
    m_mf = mf_constraint_count[world]
    w_dof_start = world_dof_start[world]

    max_dl = float(0.0)
    comp_gap = float(0.0)
    tang_res = float(0.0)
    fb_merit = float(0.0)

    # --- Dense constraints ---
    for i in range(m_dense):
        lam = impulses[world, i]
        prev_lam = prev_impulses[world, i]
        dl = wp.abs(lam - prev_lam)
        if dl > max_dl:
            max_dl = dl

        # Compute residual: r_i = J_i * v + bias_i
        jv = float(0.0)
        for d in range(max_world_dofs):
            jv += J_world[world, i, d] * v_out[w_dof_start + d]
        residual = jv + rhs[world, i]

        rt = row_type[world, i]
        if rt == PGS_CONSTRAINT_TYPE_CONTACT:
            # Normal: complementarity gap and FB
            comp_gap += lam * residual
            fb_val = wp.sqrt(lam * lam + residual * residual) - lam - residual
            fb_merit += fb_val * fb_val
        elif rt == PGS_CONSTRAINT_TYPE_FRICTION:
            # Friction: tangent residual for sticking contacts
            parent_idx = row_parent[world, i]
            lambda_n = impulses[world, parent_idx]
            mu = row_mu[world, i]
            radius = mu * lambda_n
            if radius > 0.0:
                # Check if sticking: |lambda_t| < mu * lambda_n
                # Sibling tangent row: rows are laid out [normal, t1, t2]
                if i == parent_idx + 1:
                    sib = parent_idx + 2
                else:
                    sib = parent_idx + 1
                lam_t1 = impulses[world, i]
                lam_t2 = impulses[world, sib]
                t_mag = wp.sqrt(lam_t1 * lam_t1 + lam_t2 * lam_t2)
                if t_mag < radius * 0.999:  # sticking (with small tolerance)
                    tang_res += residual * residual

    # --- MF constraints ---
    for i in range(m_mf):
        lam = mf_impulses[world, i]
        prev_lam = prev_mf_impulses[world, i]
        dl = wp.abs(lam - prev_lam)
        if dl > max_dl:
            max_dl = dl

        # Compute residual: r = J_a * v_a + J_b * v_b + bias
        # The stored DOF offsets are relative to the world, while v_out is
        # global, so the world's first DOF must be added back (as in the
        # dense loop above).
        dof_a = mf_dof_a[world, i]
        dof_b = mf_dof_b[world, i]
        jv = float(0.0)
        if dof_a >= 0:
            base_a = w_dof_start + dof_a
            for k in range(6):
                jv += mf_J_a[world, i, k] * v_out[base_a + k]
        if dof_b >= 0:
            base_b = w_dof_start + dof_b
            for k in range(6):
                jv += mf_J_b[world, i, k] * v_out[base_b + k]
        residual = jv + mf_rhs[world, i]

        rt = mf_row_type[world, i]
        if rt == PGS_CONSTRAINT_TYPE_CONTACT:
            comp_gap += lam * residual
            fb_val = wp.sqrt(lam * lam + residual * residual) - lam - residual
            fb_merit += fb_val * fb_val
        elif rt == PGS_CONSTRAINT_TYPE_FRICTION:
            parent_idx = mf_row_parent[world, i]
            lambda_n = mf_impulses[world, parent_idx]
            mu = mf_row_mu[world, i]
            radius = mu * lambda_n
            if radius > 0.0:
                if i == parent_idx + 1:
                    sib = parent_idx + 2
                else:
                    sib = parent_idx + 1
                lam_t1 = mf_impulses[world, i]
                lam_t2 = mf_impulses[world, sib]
                t_mag = wp.sqrt(lam_t1 * lam_t1 + lam_t2 * lam_t2)
                if t_mag < radius * 0.999:
                    tang_res += residual * residual

    metrics[world, 0] = max_dl
    metrics[world, 1] = comp_gap
    metrics[world, 2] = tang_res
    metrics[world, 3] = fb_merit
+# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import ClassVar + +import numpy as np +import warp as wp + +from ...core.types import override +from ...sim import Contacts, Control, JointType, Model, State +from ..featherstone.kernels import eval_fk_with_velocity_conversion +from ..semi_implicit.kernels_contact import ( + eval_particle_body_contact_forces, + eval_particle_contact_forces, +) +from ..semi_implicit.kernels_particle import ( + eval_bending_forces, + eval_spring_forces, + eval_tetrahedra_forces, + eval_triangle_forces, +) +from ..solver import SolverBase +from .kernels import ( + TILE_THREADS, + add_dense_contact_compliance_to_diag, + allocate_joint_limit_slots, + allocate_world_contact_slots, + apply_augmented_joint_tau, + apply_augmented_mass_diagonal_grouped, + apply_impulses_world_par_dof, + build_augmented_joint_rows, + build_mass_update_mask, + build_mf_body_map, + build_mf_contact_rows, + cholesky_loop, + clamp_joint_tau, + compute_com_transforms, + compute_composite_inertia, + compute_contact_linear_force_from_impulses, + compute_delta_and_accumulate, + compute_mf_body_Hinv, + compute_mf_effective_mass_and_rhs, + compute_mf_world_dof_offsets, + compute_spatial_inertia, + compute_velocity_predictor, + compute_world_contact_bias, + convert_root_free_qd_local_to_world, + convert_root_free_qd_world_to_local, + copy_int_array_masked, + crba_fill_par_dof, + delassus_par_row_col, + detect_limit_count_changes, + diag_from_JY_par_art, + eval_rigid_fk, + eval_rigid_id, + eval_rigid_mass, + eval_rigid_tau, + finalize_mf_constraint_counts, + finalize_world_constraint_counts, + finalize_world_diag_cfm, + gather_JY_to_world, + gather_tau_to_groups, + hinv_jt_par_row, + integrate_generalized_joints, + pack_contact_linear_force_as_spatial, + pgs_convergence_diagnostic_velocity, + pgs_solve_loop, + pgs_solve_mf_loop, + populate_joint_limit_J_for_size, + 
@wp.kernel
def localize_parent_indices(
    counts: wp.array(dtype=int),
    max_constraints: int,
    parent_arr: wp.array(dtype=int),
    parent_local_arr: wp.array(dtype=int),
):
    """Rewrite flat parent-row indices as offsets within their own block.

    One thread per block. Block ``art`` owns the flat range starting at
    ``art * max_constraints``; for each of its first ``counts[art]`` rows a
    non-negative parent index has the block start subtracted, and a negative
    parent index is stored as ``-1``.
    """
    art = wp.tid()
    block_start = art * max_constraints
    active = counts[art]

    for row in range(active):
        flat = block_start + row
        parent = parent_arr[flat]
        if parent < 0:
            parent_local_arr[flat] = -1
        else:
            parent_local_arr[flat] = parent - block_start
    def __init__(
        self,
        model: Model,
        angular_damping: float = 0.05,
        update_mass_matrix_interval: int = 1,
        friction_smoothing: float = 1.0,
        enable_contact_friction: bool = True,
        enable_joint_limits: bool = False,
        pgs_iterations: int = 12,
        pgs_beta: float = 0.2,
        pgs_cfm: float = 1.0e-6,
        dense_contact_compliance: float = 0.0,
        pgs_omega: float = 1.0,
        dense_max_constraints: int = 32,
        pgs_warmstart: bool = False,
        pgs_mode: str = "split",
        mf_max_constraints: int = 512,
        # Kernel selection per operation
        cholesky_kernel: str = "auto",
        trisolve_kernel: str = "auto",
        hinv_jt_kernel: str = "auto",
        delassus_kernel: str = "auto",
        pgs_kernel: str = "tiled_row",
        # Streaming kernel chunk sizes (None = auto-select)
        delassus_chunk_size: int | None = None,
        pgs_chunk_size: int | None = None,
        # Auto selection threshold
        small_dof_threshold: int = 12,
        # Parallelism options
        use_parallel_streams: bool = True,
        double_buffer: bool = True,
        nvtx: bool = False,
        pgs_debug: bool = False,
    ):
        """Store the solver options, validate the kernel choices and allocate buffers.

        Args:
            model (Model): the model to be simulated.
            angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
            update_mass_matrix_interval (int, optional): How often to update the mass matrix (every n-th time the :meth:`step` function gets called). Defaults to 1.
            friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0.
            enable_contact_friction (bool, optional): Enables Coulomb friction contacts inside the PGS solve. Defaults to True.
            enable_joint_limits (bool, optional): Enforce joint position limits as unilateral PGS
                constraints. Each violated limit adds one constraint row. Supported with
                ``pgs_kernel="loop"`` and ``pgs_kernel="tiled_row"``; the ``"tiled_contact"``
                and ``"streaming"`` PGS kernels are *not* compatible. Defaults to False.
                This constructor does not itself reject the incompatible combinations.
            pgs_iterations (int, optional): Number of Gauss-Seidel iterations to apply per frame. Defaults to 12.
            pgs_beta (float, optional): ERP style position correction factor. Defaults to 0.2.
            pgs_cfm (float, optional): Compliance/regularization added to the Delassus diagonal. Defaults to 1.0e-6.
            dense_contact_compliance (float, optional): Normal contact compliance [m/N] applied
                only to dense articulated contact rows. Converted to an impulse-space diagonal
                term using ``compliance / dt^2``. Defaults to 0.0.
            pgs_omega (float, optional): Successive over-relaxation factor for the PGS sweep. Defaults to 1.0.
            dense_max_constraints (int, optional): Maximum number of dense (articulation) contact constraint
                rows stored per world. Free rigid body contacts are stored separately, bounded by
                mf_max_constraints. Defaults to 32.
            pgs_warmstart (bool, optional): Re-use impulses from the previous frame when contacts persist. Defaults to False.
            pgs_mode (str, optional): PGS mode. "dense" builds the full Delassus matrix C = J*H^{-1}*J^T
                and solves in impulse space (Gauss-Seidel) for all contacts. "split" uses the dense
                path for articulated bodies and a cheaper matrix-free PGS path for free rigid body
                contacts. "matrix_free" skips C entirely, recomputes J*v each iteration, and uses only
                the diagonal for preconditioning — O(max_constraints) memory instead of
                O(max_constraints^2). Defaults to "split".
            mf_max_constraints (int, optional): Maximum number of matrix-free constraints per world. Defaults to 512.
            cholesky_kernel (str, optional): "tiled", "loop", or "auto" for Cholesky factorization. Defaults to "auto".
            trisolve_kernel (str, optional): "tiled", "loop", or "auto" for triangular solve. Defaults to "auto".
            hinv_jt_kernel (str, optional): "tiled", "par_row", or "auto" for H^{-1}J^T. Defaults to "auto".
            delassus_kernel (str, optional): "tiled", "par_row_col", or "auto" for Delassus accumulation
                (C = J * H^{-1} * J^T). "tiled" uses a streaming CUDA kernel that chunks shared memory
                and scales to any constraint count. "par_row_col" launches one thread per matrix element.
                "auto" selects "tiled" when DOFs exceed the threshold. Defaults to "auto".
            pgs_kernel (str, optional): "loop", "tiled_row", "tiled_contact", or "streaming" for the
                PGS solve. Defaults to "tiled_row".
            delassus_chunk_size (int, optional): Chunk size (in constraint rows) for the streaming Delassus
                kernel. Controls how many rows of J and Y are loaded into shared memory at once.
                None selects automatically based on shared memory heuristics. Defaults to None.
            pgs_chunk_size (int, optional): Chunk size (in contacts, i.e. groups of 3 constraint rows)
                for the streaming PGS kernel. Controls how many block-rows of the Delassus matrix are
                preloaded into shared memory at once. 1 = current streaming behavior (one block-row
                at a time). None defaults to 1. Defaults to None.
            small_dof_threshold (int, optional): DOF threshold for "auto" kernel selection. Defaults to 12.
            use_parallel_streams (bool, optional): Dispatch size groups on separate CUDA streams.
                Defaults to True.
            double_buffer (bool, optional): Stored as ``_double_buffer``. Presumably controls the
                stream prepared by ``_init_double_buffer_stream`` -- TODO confirm against that
                method. Defaults to True.
            nvtx (bool, optional): Stored as ``_nvtx``. Presumably toggles NVTX profiling
                ranges -- TODO confirm where it is read. Defaults to False.
            pgs_debug (bool, optional): Stored as ``pgs_debug``. Presumably enables recording of
                PGS convergence metrics into ``_pgs_convergence_log`` -- TODO confirm. Defaults to False.

        Raises:
            ValueError: If ``pgs_mode``, ``cholesky_kernel``, ``trisolve_kernel``,
                ``hinv_jt_kernel``, ``delassus_kernel`` or ``pgs_kernel`` is not one of the
                accepted strings.

        Auto selection behavior:
          - auto: size > threshold -> tiled, else loop/par_row.
          - Delassus auto/tiled: streaming kernel (handles any constraint count via chunking).

        """
        super().__init__(model)

        # Plain option storage; validation of the string options follows below.
        self.angular_damping = angular_damping
        self.update_mass_matrix_interval = update_mass_matrix_interval
        self.friction_smoothing = friction_smoothing
        self.enable_contact_friction = enable_contact_friction
        self.enable_joint_limits = enable_joint_limits
        self.pgs_iterations = pgs_iterations
        self.pgs_beta = pgs_beta
        self.pgs_cfm = pgs_cfm
        self.dense_contact_compliance = dense_contact_compliance
        self.pgs_omega = pgs_omega
        self.dense_max_constraints = dense_max_constraints
        self.pgs_warmstart = pgs_warmstart
        if pgs_mode not in ("dense", "split", "matrix_free"):
            raise ValueError(f"pgs_mode must be 'dense', 'split', or 'matrix_free', got {pgs_mode!r}")
        self.pgs_mode = pgs_mode
        self.mf_max_constraints = mf_max_constraints
        self._double_buffer = double_buffer
        self._nvtx = nvtx
        self.pgs_debug = pgs_debug
        # Starts empty; nothing in this method appends to it.
        self._pgs_convergence_log: list[np.ndarray] = []
        valid_cholesky = {"tiled", "loop", "auto"}
        if cholesky_kernel not in valid_cholesky:
            raise ValueError(f"cholesky_kernel must be one of {sorted(valid_cholesky)}")

        valid_trisolve = {"tiled", "loop", "auto"}
        if trisolve_kernel not in valid_trisolve:
            raise ValueError(f"trisolve_kernel must be one of {sorted(valid_trisolve)}")

        valid_hinv_jt = {"tiled", "par_row", "auto"}
        if hinv_jt_kernel not in valid_hinv_jt:
            raise ValueError(f"hinv_jt_kernel must be one of {sorted(valid_hinv_jt)}")

        valid_delassus = {"tiled", "par_row_col", "auto"}
        if delassus_kernel not in valid_delassus:
            raise ValueError(f"delassus_kernel must be one of {sorted(valid_delassus)}")

        valid_pgs = {"loop", "tiled_row", "tiled_contact", "streaming"}
        if pgs_kernel not in valid_pgs:
            raise ValueError(f"pgs_kernel must be one of {sorted(valid_pgs)}")

        self.cholesky_kernel = cholesky_kernel
        self.trisolve_kernel = trisolve_kernel
        self.hinv_jt_kernel = hinv_jt_kernel
        self.delassus_kernel = delassus_kernel
        self.pgs_kernel = pgs_kernel
        self.delassus_chunk_size = delassus_chunk_size
        # None means "one block-row at a time", i.e. a chunk size of 1.
        self.pgs_chunk_size = pgs_chunk_size if pgs_chunk_size is not None else 1
        self.small_dof_threshold = small_dof_threshold
        self.use_parallel_streams = use_parallel_streams

        self._step = 0
        self._force_mass_update = False
        self._last_step_dt = None

        # Articulation metadata is computed before any buffers are allocated.
        self._compute_articulation_metadata(model)

        self._allocate_common_buffers(model)
        self._allocate_buffers(model)
        self._allocate_world_buffers(model)
        self._allocate_mf_buffers(model)
        self._allocate_debug_buffers(model)
        self._scatter_armature_to_groups(model)
        self._init_size_group_streams(model)
        # 1x1 zero array; presumably a stand-in when no contact impulses exist -- TODO confirm usage.
        self._dummy_contact_impulses = wp.zeros((1, 1), dtype=wp.float32, device=model.device)

        # Fall back to a single zero friction coefficient when the model has none.
        if model.shape_material_mu is not None:
            self.shape_material_mu = model.shape_material_mu
        else:
            self.shape_material_mu = wp.zeros(
                (1,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad
            )

        self._init_double_buffer_stream()
articulation_dof_start = [] + articulation_coord_start = [] + + articulation_start = model.articulation_start.numpy() + joint_q_start = model.joint_q_start.numpy() + joint_qd_start = model.joint_qd_start.numpy() + + for i in range(model.articulation_count): + first_joint = articulation_start[i] + last_joint = articulation_start[i + 1] + + first_coord = joint_q_start[first_joint] + + first_dof = joint_qd_start[first_joint] + last_dof = joint_qd_start[last_joint] + + joint_count = last_joint - first_joint + dof_count = last_dof - first_dof + + articulation_J_start.append(self.J_size) + articulation_M_start.append(int(self.M_size)) + articulation_H_start.append(self.H_size) + articulation_dof_start.append(first_dof) + articulation_coord_start.append(first_coord) + + # bit of data duplication here, but will leave it as such for clarity + articulation_M_rows.append(joint_count * 6) + articulation_H_rows.append(dof_count) + articulation_J_rows.append(joint_count * 6) + articulation_J_cols.append(dof_count) + + self.J_size += 6 * joint_count * dof_count + self.M_size = wp.int64(self.M_size + wp.int64(joint_count * 36)) + self.H_size += dof_count * dof_count + + # matrix offsets for grouped gemm + self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device) + self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device) + self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device) + + self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device) + self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device) + self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device) + self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device) + + self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device) + 
self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device) + + self.articulation_max_dofs = int(max(articulation_H_rows)) if articulation_H_rows else 0 + self.M_size = int(self.M_size) + else: + self.M_size = 0 + self.articulation_max_dofs = 0 + + def _compute_root_free_metadata(self, model): + if not model.articulation_count or not model.joint_count: + self.articulation_root_is_free = None + self.articulation_root_dof_start = None + self._has_root_free = False + return + + articulation_start = model.articulation_start.numpy() + joint_type = model.joint_type.numpy() + joint_parent = model.joint_parent.numpy() + joint_qd_start = model.joint_qd_start.numpy() + + root_is_free = np.zeros(model.articulation_count, dtype=np.int32) + root_dof_start = np.zeros(model.articulation_count, dtype=np.int32) + + for art in range(model.articulation_count): + root_joint = articulation_start[art] + root_dof_start[art] = int(joint_qd_start[root_joint]) + jt = int(joint_type[root_joint]) + jp = int(joint_parent[root_joint]) + if jp == -1 and (jt == int(JointType.FREE) or jt == int(JointType.DISTANCE)): + root_is_free[art] = 1 + + self.articulation_root_is_free = wp.array(root_is_free, dtype=wp.int32, device=model.device) + self.articulation_root_dof_start = wp.array(root_dof_start, dtype=wp.int32, device=model.device) + self._has_root_free = bool(np.any(root_is_free != 0)) + + def _setup_size_grouping(self, model): + """Set up size-grouped storage and indirection arrays for multi-articulation support. + + This enables efficient handling of articulations with different DOF counts by grouping + them by size, allowing optimized tiled kernel launches for each size group. 
+ """ + if not model.articulation_count or not model.joint_count: + self.size_groups = [] + self.n_arts_by_size = {} + return + + device = model.device + + # Get DOF counts per articulation + articulation_start = model.articulation_start.numpy() + joint_qd_start = model.joint_qd_start.numpy() + + articulation_dof_counts = np.zeros(model.articulation_count, dtype=np.int32) + for art_idx in range(model.articulation_count): + first_joint = articulation_start[art_idx] + last_joint = articulation_start[art_idx + 1] + first_dof = joint_qd_start[first_joint] + last_dof = joint_qd_start[last_joint] + articulation_dof_counts[art_idx] = last_dof - first_dof + + # Determine unique sizes (sorted descending for largest first) + unique_sizes = sorted(set(articulation_dof_counts), reverse=True) + self.size_groups = unique_sizes + self.n_arts_by_size = {size: int(np.sum(articulation_dof_counts == size)) for size in unique_sizes} + + # Build indirection arrays + art_size_np = articulation_dof_counts.copy() + art_group_idx_np = np.zeros(model.articulation_count, dtype=np.int32) + group_to_art_np = {size: np.zeros(self.n_arts_by_size[size], dtype=np.int32) for size in unique_sizes} + + # Track current index within each size group + size_counters = dict.fromkeys(unique_sizes, 0) + + for art_idx in range(model.articulation_count): + size = articulation_dof_counts[art_idx] + group_idx = size_counters[size] + + art_group_idx_np[art_idx] = group_idx + group_to_art_np[size][group_idx] = art_idx + + size_counters[size] += 1 + + # Copy to GPU + self.art_size = wp.array(art_size_np, dtype=wp.int32, device=device) + self.art_group_idx = wp.array(art_group_idx_np, dtype=wp.int32, device=device) + self.group_to_art = { + size: wp.array(group_to_art_np[size], dtype=wp.int32, device=device) for size in unique_sizes + } + + def _setup_world_mapping(self, model): + """Set up world-level mapping for multi-articulation support. 
+ + Maps articulations to worlds and computes per-world articulation ranges. + """ + if not model.articulation_count: + self.world_count = 0 + self.art_to_world = None + self.world_art_start = None + self._is_multi_articulation = False + self._max_arts_per_world = 0 + return + + device = model.device + + # Get articulation-to-world mapping from model + if model.articulation_world is not None: + art_to_world_np = model.articulation_world.numpy().astype(np.int32) + # Handle -1 (global) by mapping to world 0 + art_to_world_np = np.where(art_to_world_np < 0, 0, art_to_world_np) + self.world_count = int(np.max(art_to_world_np)) + 1 + else: + # Default: one articulation per world (current behavior) + art_to_world_np = np.arange(model.articulation_count, dtype=np.int32) + self.world_count = model.articulation_count + + self.art_to_world = wp.array(art_to_world_np, dtype=wp.int32, device=device) + + # Compute per-world articulation ranges + # Count articulations per world + world_art_counts = np.zeros(self.world_count, dtype=np.int32) + for world_idx in art_to_world_np: + world_art_counts[world_idx] += 1 + + # Compute start indices (exclusive prefix sum) + world_art_start_np = np.zeros(self.world_count + 1, dtype=np.int32) + world_art_start_np[1:] = np.cumsum(world_art_counts) + + self.world_art_start = wp.array(world_art_start_np, dtype=wp.int32, device=device) + + # Detect if we have multiple articulations per world + self._max_arts_per_world = int(np.max(world_art_counts)) if len(world_art_counts) > 0 else 0 + self._is_multi_articulation = self._max_arts_per_world > 1 + + def _build_body_maps(self, model): + if not model.body_count or not model.articulation_count: + self.body_to_joint = None + self.body_to_articulation = None + return + + joint_child = model.joint_child.numpy() + articulation_start = model.articulation_start.numpy() + + body_to_joint = [-1] * model.body_count + body_to_articulation = [-1] * model.body_count + + for articulation in 
range(model.articulation_count): + joint_start = articulation_start[articulation] + joint_end = articulation_start[articulation + 1] + + for joint_index in range(joint_start, joint_end): + child = joint_child[joint_index] + if child < 0: + continue + + body_to_joint[child] = joint_index + body_to_articulation[child] = articulation + + device = model.device + self.body_to_joint = wp.array(body_to_joint, dtype=wp.int32, device=device) + self.body_to_articulation = wp.array(body_to_articulation, dtype=wp.int32, device=device) + + def _classify_free_rigid_bodies(self, model): + """Identify articulations that are single free rigid bodies. + + An articulation is "free rigid" if it has exactly 1 joint, that joint + is FREE type, and the joint parent is -1 (world). These can be solved + with a cheaper matrix-free PGS path. + """ + if not model.articulation_count or not model.joint_count: + self._has_free_rigid_bodies = False + self._has_mixed_contacts = False + self._n_free_rigid = 0 + self.is_free_rigid = None + return + + joint_type_np = model.joint_type.numpy() + joint_parent_np = model.joint_parent.numpy() + articulation_start_np = model.articulation_start.numpy() + + is_free_rigid_np = np.zeros(model.articulation_count, dtype=np.int32) + for art_idx in range(model.articulation_count): + first_joint = articulation_start_np[art_idx] + last_joint = articulation_start_np[art_idx + 1] + if last_joint - first_joint == 1: + if int(joint_type_np[first_joint]) == int(JointType.FREE) and int(joint_parent_np[first_joint]) == -1: + is_free_rigid_np[art_idx] = 1 + + n_free = int(np.sum(is_free_rigid_np)) + self._has_free_rigid_bodies = n_free > 0 + self._n_free_rigid = n_free + self._has_mixed_contacts = self._has_free_rigid_bodies and self._n_free_rigid < model.articulation_count + self.is_free_rigid = wp.array(is_free_rigid_np, dtype=wp.int32, device=model.device) + + def _compute_world_dof_mapping(self, model): + """Compute per-world DOF start and max DOF count for consolidated 
J/Y arrays.""" + art_to_world_np = self.art_to_world.numpy() + art_dof_start_np = self.articulation_dof_start.numpy() + art_H_rows_np = self.articulation_H_rows.numpy() + + world_dof_start_np = np.full(self.world_count, np.iinfo(np.int32).max, dtype=np.int32) + world_dof_end_np = np.zeros(self.world_count, dtype=np.int32) + + for art_idx in range(model.articulation_count): + w = art_to_world_np[art_idx] + ds = art_dof_start_np[art_idx] + de = ds + art_H_rows_np[art_idx] + world_dof_start_np[w] = min(world_dof_start_np[w], ds) + world_dof_end_np[w] = max(world_dof_end_np[w], de) + + # For worlds with no articulations, set start to 0 + world_dof_start_np = np.where(world_dof_start_np == np.iinfo(np.int32).max, 0, world_dof_start_np) + + world_dof_counts = world_dof_end_np - world_dof_start_np + self.max_world_dofs = int(np.max(world_dof_counts)) if len(world_dof_counts) > 0 else 0 + self.world_dof_start = wp.array(world_dof_start_np, dtype=wp.int32, device=model.device) + + def _allocate_common_buffers(self, model): + if model.joint_count: + self.M_blocks = wp.zeros( + (self.M_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad + ) + self.mass_update_mask = wp.zeros( + (model.articulation_count,), dtype=wp.int32, device=model.device, requires_grad=model.requires_grad + ) + self.v_hat = wp.zeros_like(model.joint_qd, requires_grad=model.requires_grad) + self.v_out = wp.zeros_like(model.joint_qd, requires_grad=model.requires_grad) + self.qd_work = wp.zeros_like(model.joint_qd, requires_grad=model.requires_grad) + self.v_mf_accum = wp.zeros_like(model.joint_qd, requires_grad=model.requires_grad) + self.v_out_snap = wp.zeros_like(model.joint_qd, requires_grad=model.requires_grad) + else: + self.M_blocks = None + self.mass_update_mask = None + self.v_hat = None + self.v_out = None + self.qd_work = None + self.v_mf_accum = None + self.v_out_snap = None + + if model.body_count: + self.body_I_m = wp.empty( + (model.body_count,), 
dtype=wp.spatial_matrix, device=model.device, requires_grad=model.requires_grad + ) + wp.launch( + compute_spatial_inertia, + model.body_count, + inputs=[model.body_inertia, model.body_mass], + outputs=[self.body_I_m], + device=model.device, + ) + self.body_X_com = wp.empty( + (model.body_count,), dtype=wp.transform, device=model.device, requires_grad=model.requires_grad + ) + wp.launch( + compute_com_transforms, + model.body_count, + inputs=[model.body_com], + outputs=[self.body_X_com], + device=model.device, + ) + self.body_I_c = wp.empty( + (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=model.requires_grad + ) + else: + self.body_I_m = None + self.body_X_com = None + self.body_I_c = None + + if not model.articulation_count or not model.joint_count: + self.articulation_origin = None + self.articulation_root_com_offset = None + return + + self.articulation_origin = wp.zeros( + (model.articulation_count,), dtype=wp.vec3, device=model.device, requires_grad=model.requires_grad + ) + self.articulation_root_com_offset = wp.zeros( + (model.articulation_count,), dtype=wp.vec3, device=model.device, requires_grad=model.requires_grad + ) + + max_dofs = self.articulation_max_dofs + if max_dofs == 0: + return + + device = model.device + requires_grad = model.requires_grad + articulation_count = model.articulation_count + total_rows = articulation_count * max_dofs + + self.aug_row_counts = wp.zeros( + (articulation_count,), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.aug_limit_counts = wp.zeros( + (articulation_count,), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.aug_prev_limit_counts = wp.zeros_like(self.aug_limit_counts) + self.limit_change_mask = wp.zeros_like(self.aug_limit_counts) + self.aug_row_dof_index = wp.zeros((total_rows,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.aug_row_K = wp.zeros((total_rows,), dtype=wp.float32, device=device, 
requires_grad=requires_grad) + self.aug_row_u0 = wp.zeros((total_rows,), dtype=wp.float32, device=device, requires_grad=requires_grad) + + def _allocate_buffers(self, model): + if not self.size_groups: + self.H_by_size = {} + self.L_by_size = {} + self.J_by_size = {} + self.Y_by_size = {} + self.R_by_size = {} + self.tau_by_size = {} + self.qdd_by_size = {} + self._H_bufs = None + self._J_bufs = None + return + + device = model.device + requires_grad = model.requires_grad + max_constraints = self.dense_max_constraints + + self.L_by_size = {} + self.Y_by_size = {} + self.R_by_size = {} + self.tau_by_size = {} + self.qdd_by_size = {} + + if self._double_buffer and device.is_cuda: + self._H_bufs = [{}, {}] + self._J_bufs = [{}, {}] + else: + self._H_bufs = None + self._J_bufs = None + + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + + h_dim = size + j_rows = max_constraints + + if self._H_bufs is not None: + for buf_idx in range(2): + self._H_bufs[buf_idx][size] = wp.zeros( + (n_arts, h_dim, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self._J_bufs[buf_idx][size] = wp.zeros( + (n_arts, j_rows, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + else: + pass # allocated below after the if/else + + self.L_by_size[size] = wp.zeros( + (n_arts, h_dim, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + self.Y_by_size[size] = wp.zeros( + (n_arts, j_rows, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + # Armature (regularization) [n_arts, h_dim] - needs to match H dimension for tile_diag_add + self.R_by_size[size] = wp.zeros( + (n_arts, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + # Tau and qdd grouped buffers for tiled triangular solve [n_arts, h_dim, 1] + self.tau_by_size[size] = wp.zeros((n_arts, h_dim, 1), dtype=wp.float32, device=device) + self.qdd_by_size[size] = wp.zeros((n_arts, h_dim, 1), dtype=wp.float32, 
device=device) + + if self._H_bufs is not None: + self.H_by_size = self._H_bufs[0] + self.J_by_size = self._J_bufs[0] + self._buf_idx = 0 + else: + self.H_by_size = {} + self.J_by_size = {} + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + h_dim = size + j_rows = max_constraints + self.H_by_size[size] = wp.zeros( + (n_arts, h_dim, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.J_by_size[size] = wp.zeros( + (n_arts, j_rows, h_dim), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + max_contacts = int(model.rigid_contact_max) + if max_contacts <= 0: + # Unified collision pipeline manages its own contact capacity and may + # leave model.rigid_contact_max unset. Use the same estimator as collide.py. + from ...sim.collide import _estimate_rigid_contact_max # noqa: PLC0415 + + max_contacts = int(_estimate_rigid_contact_max(model)) + max_contacts = max(max_contacts, 1) + self.contact_world = wp.zeros((max_contacts,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.contact_slot = wp.zeros((max_contacts,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.contact_art_a = wp.zeros((max_contacts,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.contact_art_b = wp.zeros((max_contacts,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.slot_counter = wp.zeros((self.world_count,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.contact_path = wp.zeros((max_contacts,), dtype=wp.int32, device=device, requires_grad=requires_grad) + + # Joint limit buffers (per-DOF tracking) + if self.enable_joint_limits and model.joint_dof_count > 0: + dof_count = model.joint_dof_count + self.limit_slot = wp.full((dof_count,), -1, dtype=wp.int32, device=device, requires_grad=requires_grad) + self.limit_sign = wp.zeros((dof_count,), dtype=wp.float32, device=device, requires_grad=requires_grad) + else: + self.limit_slot = None + 
self.limit_sign = None + + def _allocate_world_buffers(self, model): + """Allocate world-level constraint system buffers for multi-articulation support.""" + if self.world_count == 0: + return + + device = model.device + requires_grad = model.requires_grad + max_constraints = self.dense_max_constraints + + # Per-world constraint matrices and vectors + if self.pgs_mode != "matrix_free": + self.C = wp.zeros( + (self.world_count, max_constraints, max_constraints), + dtype=wp.float32, + device=device, + requires_grad=requires_grad, + ) + else: + self.C = None + + if self.pgs_mode == "matrix_free": + self._compute_world_dof_mapping(model) + self.J_world = wp.zeros( + (self.world_count, max_constraints, self.max_world_dofs), + dtype=wp.float32, + device=device, + requires_grad=requires_grad, + ) + self.Y_world = wp.zeros( + (self.world_count, max_constraints, self.max_world_dofs), + dtype=wp.float32, + device=device, + requires_grad=requires_grad, + ) + else: + self.J_world = None + self.Y_world = None + + self.rhs = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.impulses = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.diag = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + # Constraint metadata (per world x constraint) + self.row_type = wp.zeros( + (self.world_count, max_constraints), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.row_parent = wp.full( + (self.world_count, max_constraints), -1, dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.row_mu = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.row_beta = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.row_cfm = 
wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.phi = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + self.target_velocity = wp.zeros( + (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + # Per-world constraint counts + self.constraint_count = wp.zeros( + (self.world_count,), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + + def _allocate_mf_buffers(self, model): + """Allocate buffers for matrix-free PGS path for free rigid body contacts.""" + if self.pgs_mode == "dense" or (not self._has_free_rigid_bodies and self.pgs_mode != "matrix_free"): + return + + device = model.device + requires_grad = model.requires_grad + worlds = self.world_count + mf_max_c = self.mf_max_constraints + body_count = model.body_count + + self.mf_constraint_count = wp.zeros((worlds,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_slot_counter = wp.zeros((worlds,), dtype=wp.int32, device=device, requires_grad=requires_grad) + + self.mf_body_a = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_body_b = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + + self.mf_J_a = wp.zeros((worlds, mf_max_c, 6), dtype=wp.float32, device=device, requires_grad=requires_grad) + self.mf_J_b = wp.zeros((worlds, mf_max_c, 6), dtype=wp.float32, device=device, requires_grad=requires_grad) + + self.mf_MiJt_a = wp.zeros((worlds, mf_max_c, 6), dtype=wp.float32, device=device, requires_grad=requires_grad) + self.mf_MiJt_b = wp.zeros((worlds, mf_max_c, 6), dtype=wp.float32, device=device, requires_grad=requires_grad) + + self.mf_rhs = wp.zeros((worlds, mf_max_c), dtype=wp.float32, device=device, requires_grad=requires_grad) + self.mf_impulses = wp.zeros((worlds, mf_max_c), dtype=wp.float32, 
device=device, requires_grad=requires_grad) + self.mf_eff_mass_inv = wp.zeros( + (worlds, mf_max_c), dtype=wp.float32, device=device, requires_grad=requires_grad + ) + + self.mf_row_type = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_row_parent = wp.full((worlds, mf_max_c), -1, dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_row_mu = wp.zeros((worlds, mf_max_c), dtype=wp.float32, device=device, requires_grad=requires_grad) + self.mf_phi = wp.zeros((worlds, mf_max_c), dtype=wp.float32, device=device, requires_grad=requires_grad) + + self.mf_body_Hinv = wp.zeros((body_count,), dtype=wp.spatial_matrix, device=device, requires_grad=requires_grad) + + # World-relative DOF offsets for two-phase GS kernel + self.mf_dof_a = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_dof_b = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + + # Packed MF metadata for two-phase GS kernel (int4 per constraint): + # .x = (dof_a << 16) | (dof_b & 0xFFFF) + # .y = __float_as_int(eff_mass_inv) + # .z = __float_as_int(rhs) + # .w = row_type | (row_parent << 16) + self.mf_meta_packed = wp.zeros((worlds, mf_max_c * 4), dtype=wp.int32, device=device) + + # Body map buffers for tiled MF PGS kernel + self.max_mf_bodies = 64 + self.mf_body_list = wp.zeros( + (worlds, self.max_mf_bodies), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.mf_body_dof_start = wp.zeros( + (worlds, self.max_mf_bodies), dtype=wp.int32, device=device, requires_grad=requires_grad + ) + self.mf_body_count = wp.zeros((worlds,), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_local_body_a = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + self.mf_local_body_b = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) + + def 
_allocate_debug_buffers(self, model): + """Allocate buffers for PGS convergence diagnostics.""" + if not self.pgs_debug: + return + device = model.device + worlds = self.world_count + max_c = self.dense_max_constraints + mf_max_c = self.mf_max_constraints + + self._diag_metrics = wp.zeros((worlds, 4), dtype=wp.float32, device=device) + self._diag_prev_impulses = wp.zeros((worlds, max_c), dtype=wp.float32, device=device) + if hasattr(self, "mf_impulses"): + self._diag_prev_mf_impulses = wp.zeros((worlds, mf_max_c), dtype=wp.float32, device=device) + else: + self._diag_prev_mf_impulses = None + + def _scatter_armature_to_groups(self, model): + """Copy armature from model (DOF-ordered) to size-grouped storage.""" + if not self.size_groups: + return + + armature_np = model.joint_armature.numpy() + art_dof_start_np = self.articulation_dof_start.numpy() + art_H_rows_np = self.articulation_H_rows.numpy() + + # R_by_size is sized to actual DOF count (matches H_by_size allocation) + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + R_np = np.zeros((n_arts, size), dtype=np.float32) + + group_to_art_np = self.group_to_art[size].numpy() + for group_idx in range(n_arts): + art_idx = group_to_art_np[group_idx] + dof_start = art_dof_start_np[art_idx] + dof_count = art_H_rows_np[art_idx] + R_np[group_idx, :dof_count] = armature_np[dof_start : dof_start + dof_count] + + self.R_by_size[size] = wp.array(R_np, dtype=wp.float32, device=model.device) + + def _init_size_group_streams(self, model): + """Initialize CUDA streams for parallel kernel launches across size groups. + + When multiple DOF sizes exist (heterogeneous articulations), we can launch + tiled kernels for different sizes in parallel using separate CUDA streams. 
+ """ + self._size_streams: dict[int, wp.Stream | None] = {} + self._size_events: dict[int, wp.Event | None] = {} + + if self.use_parallel_streams and model.device.is_cuda and len(self.size_groups) > 1: + for size in self.size_groups: + self._size_streams[size] = wp.Stream(model.device) + self._size_events[size] = wp.Event(model.device) + else: + # No streams needed for CPU or single size group + for size in self.size_groups: + self._size_streams[size] = None + self._size_events[size] = None + + def _init_double_buffer_stream(self): + """Create a dedicated CUDA stream for async memset of H/J buffers.""" + if self._H_bufs is None or not self.model.device.is_cuda: + self._memset_stream = None + return + self._memset_stream = wp.Stream(self.model.device) + # Track the last memset-done event per buffer slot so the main stream + # can wait only for the specific buffer it needs. + self._memset_done_event: list[wp.Event | None] = [None, None] + + def seed_double_buffer_events(self): + """Record initial memset_done events on the main stream. + + Must be called inside CUDA graph capture, before the first ``step()`` call. + Since buffers are allocated with ``wp.zeros()``, they are already zeroed; + recording here provides trivially-satisfied wait targets for the first two + substeps. 
+ """ + if self._memset_stream is None: + return + main_stream = wp.get_stream(self.model.device) + self._memset_done_event[0] = main_stream.record_event() + self._memset_done_event[1] = main_stream.record_event() + + @override + def step( + self, + state_in: State, + state_out: State, + control: Control, + contacts: Contacts, + dt: float, + collide_done_event=None, + ): + if self._last_step_dt is None: + self._last_step_dt = dt + elif abs(self._last_step_dt - dt) > 1.0e-8: + self._force_mass_update = True + self._last_step_dt = dt + else: + self._last_step_dt = dt + + model = self.model + + if control is None: + control = model.control(clone_variables=False) + state_aug = self._prepare_augmented_state(state_in, state_out, control) + + if collide_done_event is not None and state_in.particle_count > 0: + wp.get_stream(self.model.device).wait_event(collide_done_event) + collide_done_event = None # consumed + + self._eval_particle_forces(state_in, control, contacts) + + if not model.joint_count: + self.integrate_particles(model, state_in, state_out, dt) + self._step += 1 + return state_out + + # Double-buffer: select buffer set and wait for its memset to finish + if self._memset_stream is not None: + self.H_by_size = self._H_bufs[self._buf_idx] + self.J_by_size = self._J_bufs[self._buf_idx] + evt = self._memset_done_event[self._buf_idx] + if evt is not None: + wp.get_stream(model.device).wait_event(evt) + + # ══════════════════════════════════════════════════════════════ + # STAGE 1: FK/ID + drives + CRBA + # ══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S1_FK_ID_CRBA", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._stage1_fk_id(state_in, state_aug, state_out) + + if model.articulation_count: + self._stage1_drives(state_in, state_aug, control, dt) + + self._stage1_crba(state_aug) + # ══════════════════════════════════════════════════════════════ + # STAGE 2: Cholesky + # 
══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S2_Cholesky", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + use_tiled = (self.cholesky_kernel == "tiled") or ( + self.cholesky_kernel == "auto" and size > self.small_dof_threshold + ) + if use_tiled: + self._stage2_cholesky_tiled(size) + else: + self._stage2_cholesky_loop(size) + # ══════════════════════════════════════════════════════════════ + # STAGE 3: Triangular solve + v_hat + # ══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S3_Trisolve_Vhat", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._stage3_zero_qdd(state_aug) + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + use_tiled = (self.trisolve_kernel == "tiled") or ( + self.trisolve_kernel == "auto" and size > self.small_dof_threshold + ) + if use_tiled: + self._stage3_trisolve_tiled(size, state_aug) + else: + self._stage3_trisolve_loop(size, state_aug) + self._stage3_compute_v_hat(state_in, state_aug, dt) + + # Wait for pipelined collide (if running on separate stream) + if collide_done_event is not None: + wp.get_stream(model.device).wait_event(collide_done_event) + + # ══════════════════════════════════════════════════════════════ + # STAGE 4: Build contact problem + # ══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S4_ContactBuild", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._stage4_build_rows(state_in, state_aug, contacts) + + if self.pgs_mode == "matrix_free": + # Compute Y = H^-1 * J^T only (no Delassus C) + with wp.ScopedTimer("S4_HinvJt_Diag_RHS", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + use_tiled = (self.hinv_jt_kernel == "tiled") or ( + self.hinv_jt_kernel == "auto" 
and size > self.small_dof_threshold + ) + if use_tiled: + self._stage4_hinv_jt_tiled(size) + else: + self._stage4_hinv_jt_par_row(size) + + # Diagonal from J*Y (no full Delassus) + self.diag.zero_() + for size in self.size_groups: + self._stage4_diag_from_JY(size) + self._stage4_finalize_world_diag_cfm() + self._stage4_add_dense_contact_compliance(dt) + + # RHS = bias only (J*v recomputed per iteration) + self._stage4_compute_rhs_world(dt) + # NOTE: skip _stage4_accumulate_rhs_world — J*v_hat not baked into rhs + + # MF: compute mf_MiJt, mf_rhs, mf_eff_mass_inv, body maps + if self._has_free_rigid_bodies: + with wp.ScopedTimer("S4_MF_Setup", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._mf_pgs_setup(state_aug, dt) + + # MF: compute world-relative DOF offsets for two-phase GS kernel + wp.launch( + compute_mf_world_dof_offsets, + dim=self.world_count * self.mf_max_constraints, + inputs=[ + self.mf_constraint_count, + self.mf_body_a, + self.mf_body_b, + self.body_to_articulation, + self.articulation_dof_start, + self.world_dof_start, + self.mf_max_constraints, + ], + outputs=[self.mf_dof_a, self.mf_dof_b], + device=self.model.device, + ) + + else: + # Existing Delassus path (unchanged) + fused_ok = ( + self._is_one_art_per_world + and self.hinv_jt_kernel != "par_row" + and all( + (self.hinv_jt_kernel == "tiled") + or (self.hinv_jt_kernel == "auto" and size > self.small_dof_threshold) + for size in self.size_groups + ) + ) + + if fused_ok: + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + self._stage4_hinv_jt_tiled_fused(size) + else: + self._stage4_zero_world_C() + + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + use_tiled = (self.hinv_jt_kernel == "tiled") or ( + self.hinv_jt_kernel == "auto" and size > self.small_dof_threshold + ) + if use_tiled: + self._stage4_hinv_jt_tiled(size) + else: + self._stage4_hinv_jt_par_row(size) + + for size in self.size_groups: + 
use_tiled_delassus = self.delassus_kernel != "par_row_col" + if use_tiled_delassus: + self._stage4_delassus_tiled(size) + else: + self._stage4_delassus_par_row_col(size) + + self._stage4_finalize_world_diag_cfm() + + self._stage4_add_dense_contact_compliance(dt) + self._stage4_compute_rhs_world(dt) + + for size in self.size_groups: + self._stage4_accumulate_rhs_world(size) + + # ══════════════════════════════════════════════════════════════ + # STAGE 5+6: PGS solve + # ══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S5_PGS_Prep", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._stage5_prepare_impulses_world() + + if self.pgs_mode == "matrix_free": + with wp.ScopedTimer("S5_GatherJY", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + # Gather J/Y from per-size-group arrays into world-indexed arrays + # No J_world/Y_world zeroing needed: gather writes all DOFs unconditionally + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + wp.launch( + gather_JY_to_world, + dim=int(n_arts * self.dense_max_constraints * size), + inputs=[ + self.group_to_art[size], + self.art_to_world, + self.articulation_dof_start, + self.constraint_count, + self.world_dof_start, + self.J_by_size[size], + self.Y_by_size[size], + size, + self.dense_max_constraints, + n_arts, + ], + outputs=[self.J_world, self.Y_world], + device=self.model.device, + ) + + # Initialize v_out = v_hat before GS loop + self._stage6_prepare_world_velocity() + + # Pack MF metadata into int4 structs for coalesced 128-bit loads + pack_kernel = TiledKernelFactory.get_pack_mf_meta_kernel(self.mf_max_constraints, self.model.device) + wp.launch_tiled( + pack_kernel, + dim=[self.world_count], + inputs=[ + self.mf_constraint_count, + self.mf_dof_a, + self.mf_dof_b, + self.mf_eff_mass_inv, + self.mf_rhs, + self.mf_row_type, + self.mf_row_parent, + ], + outputs=[self.mf_meta_packed], + block_dim=32, + device=self.model.device, + ) + + # Two-phase 
GS kernel: split-style dense + MF in one pass + with wp.ScopedTimer("S6_PGS_Solve", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + mf_gs_kernel = TiledKernelFactory.get_pgs_solve_mf_gs_kernel( + self.dense_max_constraints, + self.mf_max_constraints, + self.max_world_dofs, + self.model.device, + ) + + if self.pgs_debug: + self._pgs_convergence_log.append([]) + for _pgs_dbg_iter in range(self.pgs_iterations): + # Snapshot impulses before this iteration + wp.copy(self._diag_prev_impulses, self.impulses) + if self._diag_prev_mf_impulses is not None: + wp.copy(self._diag_prev_mf_impulses, self.mf_impulses) + + # Run 1 iteration + wp.launch_tiled( + mf_gs_kernel, + dim=[self.world_count], + inputs=[ + self.constraint_count, + self.world_dof_start, + self.rhs, + self.diag, + self.impulses, + self.J_world, + self.Y_world, + self.row_type, + self.row_parent, + self.row_mu, + self.mf_constraint_count, + self.mf_meta_packed, + self.mf_impulses, + self.mf_J_a, + self.mf_J_b, + self.mf_MiJt_a, + self.mf_MiJt_b, + self.mf_row_mu, + 1, # iterations=1 + self.pgs_omega, + ], + outputs=[self.v_out], + block_dim=32, + device=self.model.device, + ) + + # Diagnostic kernel + wp.launch( + pgs_convergence_diagnostic_velocity, + dim=self.world_count, + inputs=[ + self.constraint_count, + self.world_dof_start, + self.rhs, + self.impulses, + self._diag_prev_impulses, + self.row_type, + self.row_parent, + self.row_mu, + self.J_world, + self.dense_max_constraints, + self.max_world_dofs, + self.mf_constraint_count, + self.mf_rhs, + self.mf_impulses, + self._diag_prev_mf_impulses, + self.mf_row_type, + self.mf_row_parent, + self.mf_row_mu, + self.mf_J_a, + self.mf_J_b, + self.mf_dof_a, + self.mf_dof_b, + self.mf_max_constraints, + self.v_out, + ], + outputs=[self._diag_metrics], + device=self.model.device, + ) + + # Sync and reduce across worlds + metrics_np = self._diag_metrics.numpy() + row = np.array( + [ + np.max(metrics_np[:, 0]), # max|delta_lambda| + np.sum(metrics_np[:, 
1]), # complementarity gap + np.sum(metrics_np[:, 2]), # tangent residual + np.sum(metrics_np[:, 3]), # FB merit + ] + ) + self._pgs_convergence_log[-1].append(row) + + self._pgs_convergence_log[-1] = np.array(self._pgs_convergence_log[-1]) + + else: + wp.launch_tiled( + mf_gs_kernel, + dim=[self.world_count], + inputs=[ + # Dense + self.constraint_count, + self.world_dof_start, + self.rhs, + self.diag, + self.impulses, + self.J_world, + self.Y_world, + self.row_type, + self.row_parent, + self.row_mu, + # MF + self.mf_constraint_count, + self.mf_meta_packed, + self.mf_impulses, + self.mf_J_a, + self.mf_J_b, + self.mf_MiJt_a, + self.mf_MiJt_b, + self.mf_row_mu, + # Shared + self.pgs_iterations, + self.pgs_omega, + ], + outputs=[self.v_out], + block_dim=32, + device=self.model.device, + ) + elif self.pgs_mode == "split" and self._has_mixed_contacts: + # Split mode with mixed contacts: interleaved dense and MF, 1 iteration each + self._mf_pgs_setup(state_aug, dt) + self.v_mf_accum.zero_() + + for _pgs_iter in range(self.pgs_iterations): + # Dense PGS (1 iteration, impulse space) + self._dispatch_dense_pgs_solve(iterations=1) + + # Rebuild v_out = v_hat + Y*impulses + MF_corrections + self._stage6_prepare_world_velocity() + for size in self.size_groups: + self._stage6_apply_impulses_world(size) + wp.launch( + vector_add_inplace, + dim=self.v_out.size, + inputs=[self.v_out, self.v_mf_accum], + device=self.model.device, + ) + + # Snapshot v_out, run MF, compute delta + wp.copy(self.v_out_snap, self.v_out) + self._mf_pgs_solve(iterations=1) + + # v_mf_accum += (v_out - v_out_snap); v_out_snap = delta + wp.launch( + compute_delta_and_accumulate, + dim=self.v_out.size, + inputs=[self.v_out, self.v_out_snap, self.v_mf_accum], + device=self.model.device, + ) + + # Update dense rhs: world_rhs += J * delta_v_mf + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + wp.launch( + rhs_accum_world_par_art, + dim=n_arts, + inputs=[ + self.constraint_count, + 
self.dense_max_constraints, + self.art_to_world, + self.art_size, + self.art_group_idx, + self.articulation_dof_start, + self.v_out_snap, + self.group_to_art[size], + self.J_by_size[size], + size, + ], + outputs=[self.rhs], + device=self.model.device, + ) + + # v_out is already final (includes both dense and MF corrections) + + else: + # Dense or split without mixed contacts: dense PGS, then optional MF + if self.pgs_debug: + self._pgs_convergence_log.append([]) + for _pgs_dbg_iter in range(self.pgs_iterations): + prev_np = self.impulses.numpy().copy() + self._dispatch_dense_pgs_solve(iterations=1) + cur_np = self.impulses.numpy() + max_delta = float(np.max(np.abs(cur_np - prev_np))) + self._pgs_convergence_log[-1].append(np.array([max_delta, 0.0, 0.0, 0.0])) + self._pgs_convergence_log[-1] = np.array(self._pgs_convergence_log[-1]) + else: + self._dispatch_dense_pgs_solve(iterations=self.pgs_iterations) + + self._stage6_prepare_world_velocity() + for size in self.size_groups: + self._stage6_apply_impulses_world(size) + + if self.pgs_mode == "split" and self._has_free_rigid_bodies: + self._stage6b_mf_pgs(state_aug, dt) + + # ══════════════════════════════════════════════════════════════ + # STAGE 7: Update qdd + integrate + # ══════════════════════════════════════════════════════════════ + with wp.ScopedTimer("S7_Integrate", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._stage6_update_qdd(state_in, state_aug, dt) + self._stage6_integrate(state_in, state_aug, state_out, dt) + + # Double-buffer: fork memset stream to zero current buffer for reuse. + # ScopedStream(sync_enter=True) records an event on the main stream and + # makes the memset stream wait — this is what forks it into graph capture. 
+ if self._memset_stream is not None: + with wp.ScopedTimer("DB_Memset", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + with wp.ScopedStream(self._memset_stream): + for size in self.size_groups: + self._H_bufs[self._buf_idx][size].zero_() + self._J_bufs[self._buf_idx][size].zero_() + self._memset_done_event[self._buf_idx] = self._memset_stream.record_event() + self._buf_idx = 1 - self._buf_idx + + self._step += 1 + return state_out + + @override + def update_contacts(self, contacts: Contacts) -> None: + """Populate Newton contact-force buffers from the last FeatherPGS solve.""" + if contacts is None or contacts.rigid_contact_count is None: + return + + dt = self._last_step_dt + inv_dt = 0.0 if dt is None or dt <= 0.0 else 1.0 / dt + enable_friction_flag = 1 if self.enable_contact_friction else 0 + mf_impulses = getattr(self, "mf_impulses", None) + if mf_impulses is None: + mf_impulses = self._dummy_contact_impulses + + wp.launch( + compute_contact_linear_force_from_impulses, + dim=contacts.rigid_contact_max, + inputs=[ + contacts.rigid_contact_count, + contacts.rigid_contact_normal, + self.contact_world, + self.contact_slot, + self.contact_path, + self.impulses, + mf_impulses, + enable_friction_flag, + inv_dt, + ], + outputs=[contacts.rigid_contact_force], + device=self.model.device, + ) + + if contacts.force is not None: + wp.launch( + pack_contact_linear_force_as_spatial, + dim=contacts.rigid_contact_max, + inputs=[ + contacts.rigid_contact_count, + contacts.rigid_contact_force, + ], + outputs=[contacts.force], + device=self.model.device, + ) + + def _prepare_augmented_state( + self, + state_in: State, + state_out: State, + control: Control, + ) -> State: + requires_grad = state_in.requires_grad + state_aug = state_out if requires_grad else self + model = self.model + + if not getattr(state_aug, "_featherstone_augmented", False): + self._allocate_state_aux_vars(model, state_aug, requires_grad) + + return state_aug + + def 
_allocate_state_aux_vars(self, model, target, requires_grad): + # allocate auxiliary variables that vary with state + if model.body_count: + # joints + target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad) + target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad) + if requires_grad: + # used in the custom grad implementation of trisolve_loop + target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True) + else: + target.joint_solve_tmp = None + target.joint_S_s = wp.empty( + (model.joint_dof_count,), + dtype=wp.spatial_vector, + device=model.device, + requires_grad=requires_grad, + ) + + # derived rigid body data (maximal coordinates) + target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad) + target.body_I_s = wp.empty( + (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=requires_grad + ) + target.body_v_s = wp.empty( + (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad + ) + target.body_a_s = wp.empty( + (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad + ) + target.body_f_s = wp.zeros( + (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad + ) + target.body_ft_s = wp.zeros( + (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad + ) + + target._featherstone_augmented = True + + def _eval_particle_forces(self, state_in: State, control: Control, contacts: Contacts): + model = self.model + + particle_f = state_in.particle_f if state_in.particle_count else None + body_f = state_in.body_f if state_in.body_count else None + + # damped springs + eval_spring_forces(model, state_in, particle_f) + + # triangle elastic and lift/drag forces + eval_triangle_forces(model, state_in, control, particle_f) + + # triangle bending + eval_bending_forces(model, state_in, particle_f) + + # tetrahedral FEM 
+ eval_tetrahedra_forces(model, state_in, control, particle_f) + + # particle-particle interactions + eval_particle_contact_forces(model, state_in, particle_f) + + # particle shape contact + eval_particle_body_contact_forces(model, state_in, contacts, particle_f, body_f, body_f_in_world_frame=True) + + @contextmanager + def _parallel_size_region(self, enabled: bool = True): + """Context for parallel dispatch across size groups.""" + if not enabled or not self.use_parallel_streams or not self.model.device.is_cuda or len(self.size_groups) <= 1: + yield + return + + main_stream = wp.get_stream(self.model.device) + self._main_stream = main_stream + self._init_event = main_stream.record_event() + try: + yield + finally: + for size in self.size_groups: + stream = self._size_streams.get(size) + if stream is not None: + main_stream.wait_event(stream.record_event()) + self._main_stream = None + self._init_event = None + + @contextmanager + def _on_size_stream(self, size: int): + """Execute block on this size's CUDA stream.""" + stream = self._size_streams.get(size) + init_event = getattr(self, "_init_event", None) + if stream is not None and init_event is not None: + stream.wait_event(init_event) + with wp.ScopedStream(stream): + yield + else: + yield + + @contextmanager + def _size_dispatch(self, enabled: bool): + with self._parallel_size_region(enabled=enabled): + yield + + @contextmanager + def _size_ctx(self, size: int): + with self._on_size_stream(size): + yield + + def _for_sizes(self, enabled: bool): + # convenience generator; keeps step code tight + with self._size_dispatch(enabled): + for size in self.size_groups: + yield size, self._size_ctx(size) + + def _stage1_fk_id(self, state_in: State, state_aug: State, state_out: State): + model = self.model + + # evaluate body transforms + wp.launch( + eval_rigid_fk, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + model.joint_type, + model.joint_parent, + model.joint_child, + model.joint_q_start, + 
model.joint_qd_start, + state_in.joint_q, + model.joint_X_p, + model.joint_X_c, + self.body_X_com, + model.joint_axis, + model.joint_dof_dim, + ], + outputs=[state_in.body_q, state_aug.body_q_com], + device=model.device, + ) + + wp.launch( + update_articulation_origins, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + model.joint_child, + state_in.body_q, + model.body_com, + ], + outputs=[self.articulation_origin], + device=model.device, + ) + wp.launch( + update_articulation_root_com_offsets, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + model.joint_child, + state_in.body_q, + model.body_com, + ], + outputs=[self.articulation_root_com_offset], + device=model.device, + ) + # evaluate joint inertias, motion vectors, and forces + state_aug.body_f_s.zero_() + wp.copy(self.qd_work, state_in.joint_qd) + if self._has_root_free: + wp.launch( + convert_root_free_qd_world_to_local, + dim=model.articulation_count, + inputs=[ + self.articulation_root_is_free, + self.articulation_root_dof_start, + self.articulation_root_com_offset, + ], + outputs=[self.qd_work], + device=model.device, + ) + + wp.launch( + eval_rigid_id, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + model.joint_type, + model.joint_parent, + model.joint_child, + model.joint_articulation, + model.joint_qd_start, + self.qd_work, + model.joint_axis, + model.joint_dof_dim, + self.body_I_m, + state_in.body_q, + state_aug.body_q_com, + model.joint_X_p, + self.articulation_origin, + model.gravity, + ], + outputs=[ + state_aug.joint_S_s, + state_aug.body_I_s, + state_aug.body_v_s, + state_aug.body_f_s, + state_aug.body_a_s, + ], + device=model.device, + ) + if model.body_count: + wp.launch( + update_body_qd_from_featherstone, + dim=model.body_count, + inputs=[ + state_aug.body_v_s, + state_in.body_q, + model.body_com, + self.body_to_articulation, + self.articulation_origin, + ], + outputs=[state_out.body_qd], + device=model.device, + ) + + def 
def build_augmented_joint_targets(self, state_in: State, control: Control, dt: float):
    """Launch the kernels that build the augmented joint rows and flag limit-count changes.

    Does nothing when the model has no articulations or when
    ``self.articulation_max_dofs`` is zero. Otherwise the per-articulation
    counters ``aug_row_counts`` and ``aug_limit_counts`` are zeroed, then
    ``build_augmented_joint_rows`` and ``detect_limit_count_changes`` are
    launched with ``dim=model.articulation_count``.

    Args:
        state_in: Source of ``joint_q`` and ``joint_qd``.
        control: Source of ``joint_target_pos`` and ``joint_target_vel``.
        dt: Time step forwarded to ``build_augmented_joint_rows``.
    """
    model = self.model
    has_work = model.articulation_count != 0 and self.articulation_max_dofs != 0
    if not has_work:
        return
    device = model.device

    # Both counters are zeroed before the row-building kernel writes to them.
    self.aug_row_counts.zero_()
    self.aug_limit_counts.zero_()

    row_inputs = [
        model.articulation_start,
        self.articulation_dof_start,
        self.articulation_H_rows,
        model.joint_type,
        model.joint_q_start,
        model.joint_qd_start,
        model.joint_dof_dim,
        model.joint_target_ke,
        model.joint_target_kd,
        state_in.joint_q,
        state_in.joint_qd,
        control.joint_target_pos,
        control.joint_target_vel,
        self.articulation_max_dofs,
        dt,
    ]
    row_outputs = [
        self.aug_row_counts,
        self.aug_row_dof_index,
        self.aug_row_K,
        self.aug_row_u0,
        self.aug_limit_counts,
    ]
    wp.launch(
        build_augmented_joint_rows,
        dim=model.articulation_count,
        inputs=row_inputs,
        outputs=row_outputs,
        device=device,
    )

    # Compare the fresh limit counts against the previous ones; result goes to limit_change_mask.
    wp.launch(
        detect_limit_count_changes,
        dim=model.articulation_count,
        inputs=[self.aug_limit_counts, self.aug_prev_limit_counts],
        outputs=[self.limit_change_mask],
        device=device,
    )
def _stage2_cholesky_loop(self, size: int):
    """Launch the non-tiled ``cholesky_loop`` kernel for one size group.

    The launch dimension is ``self.n_arts_by_size[size]``. The kernel receives
    the group's ``H`` and ``R`` arrays, the group-to-articulation map, the
    mass-update mask and ``size``, and writes into ``self.L_by_size[size]``.

    Args:
        size: Key selecting the size group in the ``*_by_size`` dictionaries.
    """
    model = self.model
    group_count = self.n_arts_by_size[size]

    kernel_inputs = [
        self.H_by_size[size],
        self.R_by_size[size],
        self.group_to_art[size],
        self.mass_update_mask,
        size,
    ]
    factor_out = self.L_by_size[size]

    wp.launch(
        cholesky_loop,
        dim=group_count,
        inputs=kernel_inputs,
        outputs=[factor_out],
        device=model.device,
    )
outputs=[self.qdd_by_size[size]], + block_dim=TILE_THREADS, + device=model.device, + ) + wp.launch( + scatter_qdd_from_groups, + dim=n_arts, + inputs=[ + self.qdd_by_size[size], + self.group_to_art[size], + self.articulation_dof_start, + size, + ], + outputs=[state_aug.joint_qdd], + device=model.device, + ) + + def _stage3_trisolve_loop(self, size: int, state_aug: State): + model = self.model + n_arts = self.n_arts_by_size[size] + wp.launch( + trisolve_loop, + dim=n_arts, + inputs=[ + self.L_by_size[size], + self.group_to_art[size], + self.articulation_dof_start, + size, + state_aug.joint_tau, + ], + outputs=[state_aug.joint_qdd], + device=model.device, + ) + + def _stage3_compute_v_hat(self, state_in: State, state_aug: State, dt: float): + model = self.model + if not model.joint_count: + return + wp.launch( + compute_velocity_predictor, + dim=model.joint_dof_count, + inputs=[ + self.qd_work, + state_aug.joint_qdd, + dt, + ], + outputs=[self.v_hat], + device=model.device, + ) + + def _stage4_build_rows(self, state_in: State, state_aug: State, contacts: Contacts): + model = self.model + max_constraints = self.dense_max_constraints + mf_active = self._has_free_rigid_bodies and self.pgs_mode != "dense" + + # Zero world-level buffers (only arrays that require it) + self.slot_counter.zero_() # atomic-add counter + + if mf_active: + self.mf_slot_counter.zero_() # atomic-add counter + self.mf_constraint_count.zero_() # finalize only runs when contacts exist + self.mf_impulses.zero_() # PGS reads before first write + # mf_J_a/b, mf_MiJt_a/b: writers cover all used slots, readers gated by body >= 0 + # mf_body_a/b, mf_row_type, mf_row_parent, mf_row_mu, mf_phi: unconditionally overwritten + # constraint_count: fully overwritten by finalize_world_constraint_counts + + has_free_rigid_flag = 1 if mf_active else 0 + # Dummy arrays when MF is not active (kernel still needs valid pointers) + is_free_rigid = ( + self.is_free_rigid + if self.is_free_rigid is not None + else 
wp.zeros((1,), dtype=wp.int32, device=model.device) + ) + mf_slot_counter = ( + self.mf_slot_counter if mf_active else wp.zeros((self.world_count,), dtype=wp.int32, device=model.device) + ) + + if ( + contacts is not None + and getattr(contacts, "rigid_contact_count", None) is not None + and contacts.rigid_contact_max > 0 + ): + enable_friction_flag = 1 if self.enable_contact_friction else 0 + + wp.launch( + allocate_world_contact_slots, + dim=contacts.rigid_contact_max, + inputs=[ + contacts.rigid_contact_count, + contacts.rigid_contact_shape0, + contacts.rigid_contact_shape1, + contacts.rigid_contact_point0, + contacts.rigid_contact_point1, + contacts.rigid_contact_normal, + contacts.rigid_contact_margin0, + contacts.rigid_contact_margin1, + state_in.body_q, + model.shape_transform, + model.shape_body, + self.body_to_articulation, + self.art_to_world, + is_free_rigid, + has_free_rigid_flag, + max_constraints, + self.mf_max_constraints, + enable_friction_flag, + ], + outputs=[ + self.contact_world, + self.contact_slot, + self.contact_art_a, + self.contact_art_b, + self.slot_counter, + self.contact_path, + mf_slot_counter, + ], + device=model.device, + ) + + # Allocate joint limit constraint slots (same counter as contacts) + if self.enable_joint_limits and self.limit_slot is not None: + wp.launch( + allocate_joint_limit_slots, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + self.articulation_dof_start, + self.articulation_H_rows, + model.joint_type, + model.joint_q_start, + model.joint_qd_start, + model.joint_dof_dim, + model.joint_limit_lower, + model.joint_limit_upper, + state_in.joint_q, + self.art_to_world, + max_constraints, + ], + outputs=[ + self.limit_slot, + self.limit_sign, + self.slot_counter, + ], + device=model.device, + ) + + if self._H_bufs is None: # not double-buffered + for size in self.size_groups: + self.J_by_size[size].zero_() + + for size in self.size_groups: + wp.launch( + populate_world_J_for_size, + 
dim=contacts.rigid_contact_max, + inputs=[ + contacts.rigid_contact_count, + contacts.rigid_contact_point0, + contacts.rigid_contact_point1, + contacts.rigid_contact_normal, + contacts.rigid_contact_shape0, + contacts.rigid_contact_shape1, + contacts.rigid_contact_margin0, + contacts.rigid_contact_margin1, + self.contact_world, + self.contact_slot, + self.contact_art_a, + self.contact_art_b, + self.contact_path, + size, # target_size + self.art_size, + self.art_group_idx, + self.articulation_dof_start, + self.articulation_origin, + self.body_to_joint, + model.joint_ancestor, + model.joint_qd_start, + state_aug.joint_S_s, + model.shape_body, + state_in.body_q, + model.shape_transform, + self.shape_material_mu, + enable_friction_flag, + self.pgs_beta, + self.pgs_cfm, + ], + outputs=[ + self.J_by_size[size], + self.row_type, + self.row_parent, + self.row_mu, + self.row_beta, + self.row_cfm, + self.phi, + self.target_velocity, + ], + device=model.device, + ) + + # Populate joint limit Jacobian rows (per size group) + if self.enable_joint_limits and self.limit_slot is not None: + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + wp.launch( + populate_joint_limit_J_for_size, + dim=n_arts, + inputs=[ + model.articulation_start, + self.articulation_dof_start, + model.joint_type, + model.joint_q_start, + model.joint_qd_start, + model.joint_dof_dim, + model.joint_limit_lower, + model.joint_limit_upper, + state_in.joint_q, + self.art_to_world, + self.limit_slot, + self.limit_sign, + self.group_to_art[size], + self.pgs_beta, + self.pgs_cfm, + ], + outputs=[ + self.J_by_size[size], + self.row_type, + self.row_parent, + self.row_mu, + self.row_beta, + self.row_cfm, + self.phi, + self.target_velocity, + ], + device=model.device, + ) + + # Build MF contact rows + if mf_active: + wp.launch( + build_mf_contact_rows, + dim=contacts.rigid_contact_max, + inputs=[ + contacts.rigid_contact_count, + contacts.rigid_contact_point0, + contacts.rigid_contact_point1, + 
contacts.rigid_contact_normal, + contacts.rigid_contact_shape0, + contacts.rigid_contact_shape1, + contacts.rigid_contact_margin0, + contacts.rigid_contact_margin1, + self.contact_world, + self.contact_slot, + self.contact_path, + self.contact_art_a, + self.contact_art_b, + self.articulation_origin, + model.shape_body, + state_in.body_q, + self.shape_material_mu, + enable_friction_flag, + self.pgs_beta, + ], + outputs=[ + self.mf_body_a, + self.mf_body_b, + self.mf_J_a, + self.mf_J_b, + self.mf_row_type, + self.mf_row_parent, + self.mf_row_mu, + self.mf_phi, + ], + device=model.device, + ) + + slots_per_contact = 3 if self.enable_contact_friction else 1 + wp.launch( + finalize_mf_constraint_counts, + dim=self.world_count, + inputs=[self.mf_slot_counter, self.mf_max_constraints, slots_per_contact], + outputs=[self.mf_constraint_count], + device=model.device, + ) + + # Joint limit constraints (outside contact block — limits work with or without contacts) + if self.enable_joint_limits and self.limit_slot is not None: + has_contacts = ( + contacts is not None + and getattr(contacts, "rigid_contact_count", None) is not None + and contacts.rigid_contact_max > 0 + ) + if not has_contacts: + # Contacts block was skipped — allocate limits and J from scratch + wp.launch( + allocate_joint_limit_slots, + dim=model.articulation_count, + inputs=[ + model.articulation_start, + self.articulation_dof_start, + self.articulation_H_rows, + model.joint_type, + model.joint_q_start, + model.joint_qd_start, + model.joint_dof_dim, + model.joint_limit_lower, + model.joint_limit_upper, + state_in.joint_q, + self.art_to_world, + max_constraints, + ], + outputs=[ + self.limit_slot, + self.limit_sign, + self.slot_counter, + ], + device=model.device, + ) + if self._H_bufs is None: + for size in self.size_groups: + self.J_by_size[size].zero_() + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + wp.launch( + populate_joint_limit_J_for_size, + dim=n_arts, + inputs=[ + 
model.articulation_start, + self.articulation_dof_start, + model.joint_type, + model.joint_q_start, + model.joint_qd_start, + model.joint_dof_dim, + model.joint_limit_lower, + model.joint_limit_upper, + state_in.joint_q, + self.art_to_world, + self.limit_slot, + self.limit_sign, + self.group_to_art[size], + self.pgs_beta, + self.pgs_cfm, + ], + outputs=[ + self.J_by_size[size], + self.row_type, + self.row_parent, + self.row_mu, + self.row_beta, + self.row_cfm, + self.phi, + self.target_velocity, + ], + device=model.device, + ) + + slots_per_contact_dense = 3 if self.enable_contact_friction else 1 + wp.launch( + finalize_world_constraint_counts, + dim=self.world_count, + inputs=[self.slot_counter, max_constraints, slots_per_contact_dense], + outputs=[self.constraint_count], + device=model.device, + ) + + def _stage4_zero_world_C(self): + self.C.zero_() + self.diag.zero_() + + def _stage4_hinv_jt_tiled(self, size: int): + model = self.model + n_arts = self.n_arts_by_size[size] + hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_kernel(size, self.dense_max_constraints, model.device) + wp.launch_tiled( + hinv_jt_kernel, + dim=[n_arts], + inputs=[ + self.L_by_size[size], + self.J_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + ], + outputs=[self.Y_by_size[size]], + block_dim=TILE_THREADS, + device=model.device, + ) + + def _stage4_hinv_jt_tiled_fused(self, size: int): + model = self.model + n_arts = self.n_arts_by_size[size] + hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_fused_kernel(size, self.dense_max_constraints, model.device) + wp.launch_tiled( + hinv_jt_kernel, + dim=[n_arts], + inputs=[ + self.L_by_size[size], + self.J_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + self.row_cfm, + ], + outputs=[self.C, self.diag, self.Y_by_size[size]], + block_dim=TILE_THREADS, + device=model.device, + ) + + def _stage4_hinv_jt_par_row(self, size: int): + model = self.model + n_arts = 
self.n_arts_by_size[size] + wp.launch( + hinv_jt_par_row, + dim=n_arts * self.dense_max_constraints, + inputs=[ + self.L_by_size[size], + self.J_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + size, + self.dense_max_constraints, + n_arts, + ], + outputs=[self.Y_by_size[size]], + device=model.device, + ) + + def _stage4_delassus_par_row_col(self, size: int): + model = self.model + n_arts = self.n_arts_by_size[size] + wp.launch( + delassus_par_row_col, + dim=n_arts * self.dense_max_constraints * self.dense_max_constraints, + inputs=[ + self.J_by_size[size], + self.Y_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + size, + self.dense_max_constraints, + n_arts, + ], + outputs=[self.C, self.diag], + device=model.device, + ) + + def _stage4_delassus_tiled(self, size: int): + model = self.model + n_arts = self.n_arts_by_size[size] + delassus_kernel = TiledKernelFactory.get_delassus_kernel( + size, self.dense_max_constraints, model.device, chunk_size=self.delassus_chunk_size + ) + wp.launch_tiled( + delassus_kernel, + dim=[n_arts], + inputs=[ + self.J_by_size[size], + self.Y_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + n_arts, + ], + outputs=[self.C, self.diag], + block_dim=128, + device=model.device, + ) + + def _stage4_finalize_world_diag_cfm(self): + model = self.model + wp.launch( + finalize_world_diag_cfm, + dim=self.world_count, + inputs=[self.constraint_count, self.row_cfm], + outputs=[self.diag], + device=model.device, + ) + + def _stage4_add_dense_contact_compliance(self, dt: float): + if self.dense_contact_compliance <= 0.0: + return + + contact_alpha = float(self.dense_contact_compliance / (dt * dt)) + wp.launch( + add_dense_contact_compliance_to_diag, + dim=self.world_count, + inputs=[self.constraint_count, self.row_type, contact_alpha], + outputs=[self.diag], + device=self.model.device, + ) + + def _stage4_diag_from_JY(self, size: 
int): + n_arts = self.n_arts_by_size[size] + wp.launch( + diag_from_JY_par_art, + dim=n_arts * self.dense_max_constraints, + inputs=[ + self.J_by_size[size], + self.Y_by_size[size], + self.group_to_art[size], + self.art_to_world, + self.constraint_count, + size, + self.dense_max_constraints, + n_arts, + ], + outputs=[self.diag], + device=self.model.device, + ) + + def _stage4_compute_rhs_world(self, dt: float): + model = self.model + wp.launch( + compute_world_contact_bias, + dim=self.world_count, + inputs=[ + self.constraint_count, + self.dense_max_constraints, + self.phi, + self.row_beta, + self.row_type, + self.target_velocity, + dt, + ], + outputs=[self.rhs], + device=model.device, + ) + + def _stage4_accumulate_rhs_world(self, size: int): + model = self.model + n_arts = self.n_arts_by_size[size] + wp.launch( + rhs_accum_world_par_art, + dim=n_arts, + inputs=[ + self.constraint_count, + self.dense_max_constraints, + self.art_to_world, + self.art_size, + self.art_group_idx, + self.articulation_dof_start, + self.v_hat, + self.group_to_art[size], + self.J_by_size[size], + size, + ], + outputs=[self.rhs], + device=model.device, + ) + + def _stage5_prepare_impulses_world(self): + warmstart_flag = 1 if self.pgs_warmstart else 0 + wp.launch( + prepare_world_impulses, + dim=self.world_count, + inputs=[self.constraint_count, self.dense_max_constraints, warmstart_flag], + outputs=[self.impulses], + device=self.model.device, + ) + + def _dispatch_dense_pgs_solve(self, iterations: int): + """Dispatch the dense PGS kernel with a given iteration count.""" + saved = self.pgs_iterations + self.pgs_iterations = iterations + if self.pgs_kernel == "tiled_row": + self._stage5_pgs_solve_world_tiled_row() + elif self.pgs_kernel == "tiled_contact": + self._stage5_pgs_solve_world_tiled_contact() + elif self.pgs_kernel == "streaming": + self._stage5_pgs_solve_world_streaming() + else: + self._stage5_pgs_solve_world_loop() + self.pgs_iterations = saved + + def 
def _stage5_pgs_solve_world_streaming(self):
    """Launch the streaming PGS world-solve kernel, one tile block per world.

    The kernel is obtained from ``TiledKernelFactory`` for the current
    ``dense_max_constraints``, device and ``pgs_chunk_size``, then launched
    with a block dimension of 32.
    """
    device = self.model.device
    kernel = TiledKernelFactory.get_pgs_solve_streaming_kernel(
        self.dense_max_constraints,
        device,
        pgs_chunk_size=self.pgs_chunk_size,
    )
    launch_args = [
        self.constraint_count,
        self.C,
        self.rhs,
        self.impulses,
        self.pgs_iterations,
        self.pgs_omega,
        self.row_mu,
    ]
    wp.launch_tiled(
        kernel,
        dim=[self.world_count],
        inputs=launch_args,
        block_dim=32,
        device=device,
    )
def _mf_pgs_setup(self, state_aug: State, dt: float):
    """MF PGS setup: compute Hinv, then effective mass and RHS.

    Two kernels are launched in order:

    1. ``compute_mf_body_Hinv`` over all bodies, writing ``mf_body_Hinv``.
    2. ``compute_mf_effective_mass_and_rhs`` over
       ``world_count * mf_max_constraints`` threads, after ``mf_rhs`` and
       ``mf_eff_mass_inv`` have been zeroed.

    Args:
        state_aug: State providing ``body_I_s`` for the Hinv kernel.
        dt: Time step forwarded to the effective-mass/RHS kernel.
    """
    model = self.model
    device = model.device

    # H^-1 = inverse(body_I_s) for free rigid bodies.
    hinv_inputs = [
        state_aug.body_I_s,
        self.is_free_rigid,
        self.body_to_articulation,
    ]
    wp.launch(
        compute_mf_body_Hinv,
        dim=model.body_count,
        inputs=hinv_inputs,
        outputs=[self.mf_body_Hinv],
        device=device,
    )

    # Clear the accumulators before computing effective mass and RHS.
    self.mf_rhs.zero_()
    self.mf_eff_mass_inv.zero_()

    row_inputs = [
        self.mf_constraint_count,
        self.mf_body_a,
        self.mf_body_b,
        self.mf_J_a,
        self.mf_J_b,
        self.mf_body_Hinv,
        self.mf_phi,
        self.mf_row_type,
        self.pgs_cfm,
        self.pgs_beta,
        dt,
        self.mf_max_constraints,
    ]
    row_outputs = [
        self.mf_eff_mass_inv,
        self.mf_MiJt_a,
        self.mf_MiJt_b,
        self.mf_rhs,
    ]
    wp.launch(
        compute_mf_effective_mass_and_rhs,
        dim=self.world_count * self.mf_max_constraints,
        inputs=row_inputs,
        outputs=row_outputs,
        device=device,
    )
def _stage6_update_qdd(self, state_in: State, state_aug: State, dt: float):
    """Write ``state_aug.joint_qdd`` from the solved velocities in ``v_out``.

    When ``self._has_root_free`` is set, ``v_out`` is first passed through
    ``convert_root_free_qd_local_to_world`` (one thread per articulation).
    Then ``update_qdd_from_velocity`` is launched per joint DOF with
    ``state_in.joint_qd``, ``v_out`` and ``1 / dt``.

    Args:
        state_in: State supplying the starting ``joint_qd``.
        state_aug: State receiving ``joint_qdd``.
        dt: Time step; its reciprocal is passed to the kernel.
    """
    model = self.model
    device = model.device

    if self._has_root_free:
        root_inputs = [
            self.articulation_root_is_free,
            self.articulation_root_dof_start,
            self.articulation_root_com_offset,
        ]
        wp.launch(
            convert_root_free_qd_local_to_world,
            dim=model.articulation_count,
            inputs=root_inputs,
            outputs=[self.v_out],
            device=device,
        )

    qdd_inputs = [state_in.joint_qd, self.v_out, 1.0 / dt]
    wp.launch(
        update_qdd_from_velocity,
        dim=model.joint_dof_count,
        inputs=qdd_inputs,
        outputs=[state_aug.joint_qdd],
        device=device,
    )
@classmethod
def get_hinv_jt_kernel(cls, n_dofs: int, max_constraints: int, device: "wp.Device") -> "wp.Kernel":
    """Return the tiled H^-1*J^T kernel for these dimensions, building it on first use.

    Kernels are cached in ``cls._hinv_jt_cache`` under
    ``(n_dofs, max_constraints, device.arch)``.
    """
    cache = cls._hinv_jt_cache
    cache_key = (n_dofs, max_constraints, device.arch)
    if cache_key in cache:
        return cache[cache_key]

    kernel = cls._build_hinv_jt_kernel(n_dofs, max_constraints)
    cache[cache_key] = kernel
    return kernel
@classmethod
def _build_hinv_jt_fused_kernel(cls, n_dofs: int, max_constraints: int) -> "wp.Kernel":
    """Build specialized fused H^-1*J^T + Delassus kernel for given dimensions.

    The generated kernel performs, per launched tile block ``idx``:

    1. A two-step triangular solve with the Cholesky factor ``L``
       (``L * Z = J^T`` then ``L^T * X = Z``), storing ``X^T`` into
       ``Y_group[idx]``.
    2. The product ``C = J * X``, stored into ``world_C[world]``.
    3. On thread 0 only, ``world_diag[world, i] = C[i, i] + row_cfm[world, i]``
       for the first ``n_constraints`` rows.

    Blocks whose world has zero constraints return without writing anything.

    Args:
        n_dofs: DOF count baked into the kernel as a tile dimension.
        max_constraints: Constraint-row count baked in as a tile dimension.

    Returns:
        A Warp kernel named ``hinv_jt_tiled_fused_{n_dofs}_{max_constraints}``
        created with ``enable_backward=False`` in its own unique module.
    """
    # Tile shapes must be compile-time constants, so they are captured by closure.
    TILE_DOF_LOCAL = wp.constant(int(n_dofs))
    TILE_CONSTRAINTS_LOCAL = wp.constant(int(max_constraints))

    def hinv_jt_tiled_fused_template(
        L_group: wp.array3d(dtype=float),  # [n_arts, n_dofs, n_dofs]
        J_group: wp.array3d(dtype=float),  # [n_arts, max_c, n_dofs]
        group_to_art: wp.array(dtype=int),
        art_to_world: wp.array(dtype=int),
        world_constraint_count: wp.array(dtype=int),
        row_cfm: wp.array2d(dtype=float),
        # outputs
        world_C: wp.array3d(dtype=float),  # [world_count, max_c, max_c]
        world_diag: wp.array2d(dtype=float),  # [world_count, max_c]
        Y_group: wp.array3d(dtype=float),  # [n_arts, max_c, n_dofs]
    ):
        # idx selects the articulation within this size group; thread is the
        # index within the tile block.
        idx, thread = wp.tid()
        art = group_to_art[idx]
        world = art_to_world[art]
        n_constraints = world_constraint_count[world]

        # The count is per world, so every thread of the block takes the same branch.
        if n_constraints == 0:
            return

        # Load L (Cholesky factor) and J (Jacobian rows)
        L_tile = wp.tile_load(L_group[idx], shape=(TILE_DOF_LOCAL, TILE_DOF_LOCAL), bounds_check=False)
        J_tile = wp.tile_load(J_group[idx], shape=(TILE_CONSTRAINTS_LOCAL, TILE_DOF_LOCAL), bounds_check=False)

        # Solve L * Z = J^T (forward substitution)
        Jt_tile = wp.tile_transpose(J_tile)
        Z_tile = wp.tile_lower_solve(L_tile, Jt_tile)

        # Solve L^T * Y = Z (backward substitution)
        Lt_tile = wp.tile_transpose(L_tile)
        X_tile = wp.tile_upper_solve(Lt_tile, Z_tile)

        # Store Y = H^-1 * J^T (transpose back to row layout)
        Y_out_tile = wp.tile_transpose(X_tile)
        wp.tile_store(Y_group[idx], Y_out_tile)

        # Form C = J * H^-1 * J^T
        C_tile = wp.tile_zeros(shape=(TILE_CONSTRAINTS_LOCAL, TILE_CONSTRAINTS_LOCAL), dtype=wp.float32)
        wp.tile_matmul(J_tile, X_tile, C_tile)
        # NOTE(review): this is a plain store, not an accumulation. If several
        # articulations map to the same world, each block would overwrite
        # world_C[world] -- confirm the fused path is only used with one
        # articulation per world.
        wp.tile_store(world_C[world], C_tile)

        # A single thread writes the diagonal (Delassus diagonal plus per-row CFM)
        # for the active rows only.
        if thread == 0:
            for i in range(n_constraints):
                world_diag[world, i] = C_tile[i, i] + row_cfm[world, i]

    # Give each specialization a distinct kernel name.
    hinv_jt_tiled_fused_template.__name__ = f"hinv_jt_tiled_fused_{n_dofs}_{max_constraints}"
    hinv_jt_tiled_fused_template.__qualname__ = f"hinv_jt_tiled_fused_{n_dofs}_{max_constraints}"
    return wp.kernel(enable_backward=False, module="unique")(hinv_jt_tiled_fused_template)
art_to_world.data[art]; + int m = world_constraint_count.data[world]; + if (m == 0) return; + + __shared__ float s_J[CHUNK * TILE_D]; + __shared__ float s_Y[CHUNK * TILE_D]; + + int num_chunks = (m + CHUNK - 1) / CHUNK; + + for (int ci = 0; ci < num_chunks; ci++) {{ + int i0 = ci * CHUNK, i1 = min(i0 + CHUNK, m); + + for (int t = lane; t < (i1 - i0) * TILE_D; t += blockDim.x) + s_J[t] = J_group.data[idx * TILE_M * TILE_D + i0 * TILE_D + t]; + __syncthreads(); + + for (int cj = 0; cj < num_chunks; cj++) {{ + int j0 = cj * CHUNK, j1 = min(j0 + CHUNK, m); + + for (int t = lane; t < (j1 - j0) * TILE_D; t += blockDim.x) + s_Y[t] = Y_group.data[idx * TILE_M * TILE_D + j0 * TILE_D + t]; + __syncthreads(); + + // Each thread computes multiple C elements + for (int e = lane; e < (i1 - i0) * (j1 - j0); e += blockDim.x) {{ + int il = e / (j1 - j0), jl = e % (j1 - j0); + float sum = 0.0f; + for (int k = 0; k < TILE_D; k++) + sum += s_J[il * TILE_D + k] * s_Y[jl * TILE_D + k]; + if (sum != 0.0f) {{ + int ig = i0 + il, jg = j0 + jl; + atomicAdd(&world_C.data[world * TILE_M * TILE_M + ig * TILE_M + jg], sum); + if (ig == jg) atomicAdd(&world_diag.data[world * TILE_M + ig], sum); + }} + }} + __syncthreads(); + }} + }} +#endif +""" + + @wp.func_native(snippet) + def delassus_native( + idx: int, + J_group: wp.array3d(dtype=float), + Y_group: wp.array3d(dtype=float), + group_to_art: wp.array(dtype=int), + art_to_world: wp.array(dtype=int), + world_constraint_count: wp.array(dtype=int), + world_C: wp.array3d(dtype=float), + world_diag: wp.array2d(dtype=float), + ): ... 
@classmethod
def get_cholesky_kernel(cls, n_dofs: int, device: "wp.Device") -> "wp.Kernel":
    """Return the tiled Cholesky kernel for ``n_dofs``, building it on first use.

    Kernels are cached in ``cls._cholesky_cache`` under ``(n_dofs, device.arch)``.
    """
    lookup = (n_dofs, device.arch)
    kernels = cls._cholesky_cache
    already_built = lookup in kernels
    if not already_built:
        kernels[lookup] = cls._build_cholesky_kernel(n_dofs)
    return kernels[lookup]
@classmethod
def get_triangular_solve_kernel(cls, n_dofs: int, device: "wp.Device") -> "wp.Kernel":
    """Return the tiled triangular-solve kernel for ``n_dofs``, building it on first use.

    Kernels are cached in ``cls._triangular_solve_cache`` under
    ``(n_dofs, device.arch)``.
    """
    arch = device.arch
    cache_key = (n_dofs, arch)
    if cache_key in cls._triangular_solve_cache:
        return cls._triangular_solve_cache[cache_key]

    cls._triangular_solve_cache[cache_key] = cls._build_triangular_solve_kernel(n_dofs)
    return cls._triangular_solve_cache[cache_key]
@classmethod
def _build_pgs_solve_tiled_row_kernel(cls, max_constraints: int) -> "wp.Kernel":
    """PGS world solve kernel that stages only the LOWER triangle of Delassus.

    Shared memory footprint drops from M*M to M*(M+1)/2 floats.
    Uses symmetry in dot: C(i,j) = L(i,j) if j<=i else L(j,i).

    The CUDA body is generated as text. Each of the 32 lanes handles the
    indices ``lane + 32*k``; those per-lane statements are unrolled here in
    Python. When ``max_constraints`` is not a multiple of 32, the last
    unrolled slot would address indices ``>= max_constraints``, so that slot
    is emitted with a ``lane + offset < TILE_M`` guard. This keeps the
    loads inside the shared arrays and keeps the final store from writing
    into the next world's impulse row. For multiples of 32 the generated
    code has no guards.

    The runtime constraint count is clamped to ``TILE_M`` so that the
    iteration never indexes past the shared arrays.

    Args:
        max_constraints: Number of constraint rows per world (``TILE_M``).

    Returns:
        A Warp kernel named ``pgs_solve_tiled_row_{max_constraints}`` created
        with ``enable_backward=False`` in its own unique module.
    """
    TILE_M = max_constraints
    TILE_M_SQ = TILE_M * TILE_M
    TILE_TRI = TILE_M * (TILE_M + 1) // 2

    # Slot k of a lane handles index lane + LANE_OFFSETS[k].
    LANE_OFFSETS = list(range(0, TILE_M, 32))

    def tail_guard(offset):
        # Only a slot that straddles TILE_M can go out of range; full slots
        # stay unguarded so the common (multiple-of-32) case is unchanged.
        if offset + 32 <= TILE_M:
            return ""
        return f"if (lane + {offset} < TILE_M) "

    def gen_load_1d(dst, src):
        return "\n".join(
            f"    {tail_guard(off)}{dst}[lane + {off}] = {src}.data[off1 + lane + {off}];" for off in LANE_OFFSETS
        )

    load_code = "\n".join(
        [
            gen_load_1d("s_lam", "world_impulses"),
            gen_load_1d("s_rhs", "world_rhs"),
            gen_load_1d("s_diag", "world_diag"),
            gen_load_1d("s_rtype", "world_row_type"),
            gen_load_1d("s_parent", "world_row_parent"),
            gen_load_1d("s_mu", "world_row_mu"),
        ]
    )

    # Precompute lane's column indices (j_k) and their triangular bases (j_k*(j_k+1)/2)
    # so inside the dot we avoid multiply.
    precompute_j_code = "\n".join(
        f"    const int j{k} = lane + {off};\n    const int jb{k} = (j{k} * (j{k} + 1)) >> 1;"
        for k, off in enumerate(LANE_OFFSETS)
    )

    # Dot code: guarded on j_k < m (and m <= TILE_M), so no extra tail guard is needed.
    dot_terms = [
        f"""        if (j{k} < m) {{
            // Use symmetry to fetch C(i, j{k}) from packed-lower shared.
            // base_i = i*(i+1)/2
            float cij = (j{k} <= i) ? s_Ctri[base_i + j{k}] : s_Ctri[jb{k} + i];
            my_sum += cij * s_lam[j{k}];
        }}"""
        for k in range(len(LANE_OFFSETS))
    ]
    dot_code = "\n".join(["float my_sum = 0.0f;", "int base_i = (i * (i + 1)) >> 1;", *dot_terms])

    store_code = "\n".join(
        f"    {tail_guard(off)}world_impulses.data[off1 + lane + {off}] = s_lam[lane + {off}];"
        for off in LANE_OFFSETS
    )

    snippet = f"""
    #if defined(__CUDA_ARCH__)
        const int TILE_M = {TILE_M};
        const int TILE_M_SQ = {TILE_M_SQ};
        const int TILE_TRI = {TILE_TRI};
        const unsigned MASK = 0xFFFFFFFF;

        int lane = threadIdx.x;

        int m = world_constraint_count.data[world];
        if (m == 0) return;
        // Never iterate past the rows that fit in the shared arrays.
        if (m > TILE_M) m = TILE_M;

        // Packed LOWER triangle of C in row-major (i*(i+1)/2 + j), j<=i
        __shared__ float s_Ctri[TILE_TRI];

        __shared__ float s_lam[TILE_M];
        __shared__ float s_rhs[TILE_M];
        __shared__ float s_diag[TILE_M];
        __shared__ int s_rtype[TILE_M];
        __shared__ int s_parent[TILE_M];
        __shared__ float s_mu[TILE_M];

        int off1 = world * TILE_M;
        int off2 = world * TILE_M_SQ;

    {load_code}

        // Load only lower triangle from global full matrix into packed shared.
        // Work distribution: each lane walks rows; for each row i, lane loads j = lane, lane+32, lane+64...
        for (int i = 0; i < TILE_M; ++i) {{
            int base = (i * (i + 1)) >> 1;  // packed base for row i
            for (int j = lane; j <= i; j += 32) {{
                s_Ctri[base + j] = world_C.data[off2 + i * TILE_M + j];
            }}
        }}
        __syncwarp();

    {precompute_j_code}

        for (int iter = 0; iter < iterations; iter++) {{
            for (int i = 0; i < m; i++) {{
                // NOTE: single-warp kernel; __syncwarp here is typically unnecessary unless divergence occurs
                // before the dot. If you want max perf, try removing it after verifying correctness.
                // __syncwarp();

                {dot_code}

                // Warp reduce my_sum
                my_sum += __shfl_down_sync(MASK, my_sum, 16);
                my_sum += __shfl_down_sync(MASK, my_sum, 8);
                my_sum += __shfl_down_sync(MASK, my_sum, 4);
                my_sum += __shfl_down_sync(MASK, my_sum, 2);
                my_sum += __shfl_down_sync(MASK, my_sum, 1);
                float dot_sum = __shfl_sync(MASK, my_sum, 0);

                float denom = s_diag[i];
                if (denom <= 0.0f) continue;

                float w_val = s_rhs[i] + dot_sum;
                float delta = -w_val / denom;
                float new_impulse = s_lam[i] + omega * delta;
                int row_type = s_rtype[i];

                if (row_type == 0 || row_type == 3) {{
                    if (new_impulse < 0.0f) new_impulse = 0.0f;
                    s_lam[i] = new_impulse;
                }} else if (row_type == 2) {{
                    int parent_idx = s_parent[i];
                    float lambda_n = s_lam[parent_idx];
                    float mu = s_mu[i];
                    float radius = fmaxf(mu * lambda_n, 0.0f);

                    if (radius <= 0.0f) {{
                        s_lam[i] = 0.0f;
                    }} else {{
                        s_lam[i] = new_impulse;
                        int sib = (i == parent_idx + 1) ? (parent_idx + 2) : (parent_idx + 1);
                        float a = s_lam[i];
                        float b = s_lam[sib];
                        float mag = sqrtf(a * a + b * b);
                        if (mag > radius) {{
                            float scale = radius / mag;
                            s_lam[i] = a * scale;
                            s_lam[sib] = b * scale;
                        }}
                    }}
                }} else {{
                    s_lam[i] = new_impulse;
                }}
            }}
        }}

    {store_code}
    #endif
    """

    @wp.func_native(snippet)
    def pgs_solve_native(
        world: int,
        world_constraint_count: wp.array(dtype=int),
        world_diag: wp.array2d(dtype=float),
        world_C: wp.array3d(dtype=float),
        world_rhs: wp.array2d(dtype=float),
        world_impulses: wp.array2d(dtype=float),
        iterations: int,
        omega: float,
        world_row_type: wp.array2d(dtype=int),
        world_row_parent: wp.array2d(dtype=int),
        world_row_mu: wp.array2d(dtype=float),
    ): ...

    def pgs_solve_tiled_template(
        world_constraint_count: wp.array(dtype=int),
        world_diag: wp.array2d(dtype=float),
        world_C: wp.array3d(dtype=float),
        world_rhs: wp.array2d(dtype=float),
        world_impulses: wp.array2d(dtype=float),
        iterations: int,
        omega: float,
        world_row_type: wp.array2d(dtype=int),
        world_row_parent: wp.array2d(dtype=int),
        world_row_mu: wp.array2d(dtype=float),
    ):
        # One tile block per world; the lane index is read inside the native snippet.
        world, _lane = wp.tid()
        pgs_solve_native(
            world,
            world_constraint_count,
            world_diag,
            world_C,
            world_rhs,
            world_impulses,
            iterations,
            omega,
            world_row_type,
            world_row_parent,
            world_row_mu,
        )

    # Give each specialization a distinct kernel name.
    pgs_solve_tiled_template.__name__ = f"pgs_solve_tiled_row_{max_constraints}"
    pgs_solve_tiled_template.__qualname__ = f"pgs_solve_tiled_row_{max_constraints}"
    return wp.kernel(enable_backward=False, module="unique")(pgs_solve_tiled_template)
+ """ + TILE_M = max_constraints + # Max contacts we can handle (rounded down) + NUM_CONTACTS_MAX = TILE_M // 3 + # Actual max constraints we'll process (may be < TILE_M) + TILE_M_USABLE = NUM_CONTACTS_MAX * 3 + + # Lower triangle of block matrix (sized for max) + NUM_BLOCKS_TRI = NUM_CONTACTS_MAX * (NUM_CONTACTS_MAX + 1) // 2 + BLOCK_TRI_FLOATS = NUM_BLOCKS_TRI * 9 + + snippet = f""" + #if defined(__CUDA_ARCH__) + const int TILE_M = {TILE_M}; + const int TILE_M_USABLE = {TILE_M_USABLE}; + const int NUM_CONTACTS_MAX = {NUM_CONTACTS_MAX}; + const int BLOCK_TRI_FLOATS = {BLOCK_TRI_FLOATS}; + const unsigned MASK = 0xFFFFFFFF; + + int lane = threadIdx.x; + + int m = world_constraint_count.data[world]; + if (m == 0) return; + + // Clamp m to usable range and ensure divisible by 3 + if (m > TILE_M_USABLE) m = TILE_M_USABLE; + int num_contacts = m / 3; + + // Shared memory (sized for max) + __shared__ float s_Dtri[BLOCK_TRI_FLOATS]; + __shared__ float s_Dinv[NUM_CONTACTS_MAX * 9]; + __shared__ float s_lam[TILE_M_USABLE]; + __shared__ float s_rhs[TILE_M_USABLE]; + __shared__ float s_mu[NUM_CONTACTS_MAX]; + + int off1 = world * TILE_M; + int off2 = world * TILE_M * TILE_M; + + // ============ LOAD PHASE ============ + + // Load lambda and rhs + for (int i = lane; i < TILE_M_USABLE; i += 32) {{ + if (i < m) {{ + s_lam[i] = world_impulses.data[off1 + i]; + s_rhs[i] = world_rhs.data[off1 + i]; + }} else {{ + s_lam[i] = 0.0f; + s_rhs[i] = 0.0f; + }} + }} + + // Load mu (one per contact, stored on tangent1 row) + for (int c = lane; c < NUM_CONTACTS_MAX; c += 32) {{ + if (c < num_contacts) {{ + s_mu[c] = world_row_mu.data[off1 + c * 3 + 1]; + }} + }} + + // Load lower triangle of block Delassus + for (int c = 0; c < num_contacts; c++) {{ + int base_block = (c * (c + 1)) >> 1; + int floats_in_row = (c + 1) * 9; + + for (int f = lane; f < floats_in_row; f += 32) {{ + int j = f / 9; + int k = f % 9; + int lr = k / 3; + int lc = k % 3; + int gr = c * 3 + lr; + int gc = j * 3 + lc; + 
s_Dtri[(base_block + j) * 9 + k] = world_C.data[off2 + gr * TILE_M + gc]; + }} + }} + __syncwarp(); + + // Compute diagonal block inverses + for (int c = lane; c < num_contacts; c += 32) {{ + int diag_block_idx = ((c * (c + 1)) >> 1) + c; + const float* D = &s_Dtri[diag_block_idx * 9]; + float* Dinv = &s_Dinv[c * 9]; + + float det = D[0] * (D[4] * D[8] - D[5] * D[7]) + - D[1] * (D[3] * D[8] - D[5] * D[6]) + + D[2] * (D[3] * D[7] - D[4] * D[6]); + + float inv_det = 1.0f / det; + + Dinv[0] = (D[4] * D[8] - D[5] * D[7]) * inv_det; + Dinv[1] = (D[2] * D[7] - D[1] * D[8]) * inv_det; + Dinv[2] = (D[1] * D[5] - D[2] * D[4]) * inv_det; + Dinv[3] = (D[5] * D[6] - D[3] * D[8]) * inv_det; + Dinv[4] = (D[0] * D[8] - D[2] * D[6]) * inv_det; + Dinv[5] = (D[2] * D[3] - D[0] * D[5]) * inv_det; + Dinv[6] = (D[3] * D[7] - D[4] * D[6]) * inv_det; + Dinv[7] = (D[1] * D[6] - D[0] * D[7]) * inv_det; + Dinv[8] = (D[0] * D[4] - D[1] * D[3]) * inv_det; + }} + __syncwarp(); + + // ============ ITERATION PHASE ============ + + for (int iter = 0; iter < iterations; iter++) {{ + for (int c = 0; c < num_contacts; c++) {{ + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f; + + for (int j = lane; j < num_contacts; j += 32) {{ + float l0 = s_lam[j * 3 + 0]; + float l1 = s_lam[j * 3 + 1]; + float l2 = s_lam[j * 3 + 2]; + + int block_off; + bool transpose; + if (j <= c) {{ + block_off = (((c * (c + 1)) >> 1) + j) * 9; + transpose = false; + }} else {{ + block_off = (((j * (j + 1)) >> 1) + c) * 9; + transpose = true; + }} + + const float* B = &s_Dtri[block_off]; + + if (!transpose) {{ + sum0 += B[0] * l0 + B[1] * l1 + B[2] * l2; + sum1 += B[3] * l0 + B[4] * l1 + B[5] * l2; + sum2 += B[6] * l0 + B[7] * l1 + B[8] * l2; + }} else {{ + sum0 += B[0] * l0 + B[3] * l1 + B[6] * l2; + sum1 += B[1] * l0 + B[4] * l1 + B[7] * l2; + sum2 += B[2] * l0 + B[5] * l1 + B[8] * l2; + }} + }} + + // Warp reduce + sum0 += __shfl_down_sync(MASK, sum0, 16); + sum1 += __shfl_down_sync(MASK, sum1, 16); + sum2 += 
__shfl_down_sync(MASK, sum2, 16); + sum0 += __shfl_down_sync(MASK, sum0, 8); + sum1 += __shfl_down_sync(MASK, sum1, 8); + sum2 += __shfl_down_sync(MASK, sum2, 8); + sum0 += __shfl_down_sync(MASK, sum0, 4); + sum1 += __shfl_down_sync(MASK, sum1, 4); + sum2 += __shfl_down_sync(MASK, sum2, 4); + sum0 += __shfl_down_sync(MASK, sum0, 2); + sum1 += __shfl_down_sync(MASK, sum1, 2); + sum2 += __shfl_down_sync(MASK, sum2, 2); + sum0 += __shfl_down_sync(MASK, sum0, 1); + sum1 += __shfl_down_sync(MASK, sum1, 1); + sum2 += __shfl_down_sync(MASK, sum2, 1); + + if (lane == 0) {{ + // Corrected sign: -(rhs + D*lambda) + float res0 = -(s_rhs[c * 3 + 0] + sum0); + float res1 = -(s_rhs[c * 3 + 1] + sum1); + float res2 = -(s_rhs[c * 3 + 2] + sum2); + + const float* Dinv = &s_Dinv[c * 9]; + float d0 = Dinv[0] * res0 + Dinv[1] * res1 + Dinv[2] * res2; + float d1 = Dinv[3] * res0 + Dinv[4] * res1 + Dinv[5] * res2; + float d2 = Dinv[6] * res0 + Dinv[7] * res1 + Dinv[8] * res2; + + float new_n = s_lam[c * 3 + 0] + omega * d0; + float new_t1 = s_lam[c * 3 + 1] + omega * d1; + float new_t2 = s_lam[c * 3 + 2] + omega * d2; + + // Friction cone projection + new_n = fmaxf(new_n, 0.0f); + + float mu = s_mu[c]; + float radius = mu * new_n; + + if (radius <= 0.0f) {{ + new_t1 = 0.0f; + new_t2 = 0.0f; + }} else {{ + float t_mag_sq = new_t1 * new_t1 + new_t2 * new_t2; + if (t_mag_sq > radius * radius) {{ + float scale = radius * rsqrtf(t_mag_sq); + new_t1 *= scale; + new_t2 *= scale; + }} + }} + + s_lam[c * 3 + 0] = new_n; + s_lam[c * 3 + 1] = new_t1; + s_lam[c * 3 + 2] = new_t2; + }} + __syncwarp(); + }} + }} + + // ============ STORE PHASE ============ + + for (int i = lane; i < TILE_M_USABLE; i += 32) {{ + if (i < m) {{ + world_impulses.data[off1 + i] = s_lam[i]; + }} + }} + #endif + """ + + @wp.func_native(snippet) + def pgs_solve_contact_native( + world: int, + world_constraint_count: wp.array(dtype=int), + world_C: wp.array3d(dtype=float), + world_rhs: wp.array2d(dtype=float), + 
world_impulses: wp.array2d(dtype=float), + iterations: int, + omega: float, + world_row_mu: wp.array2d(dtype=float), + ): ... + + def pgs_solve_tiled_contact_template( + world_constraint_count: wp.array(dtype=int), + world_C: wp.array3d(dtype=float), + world_rhs: wp.array2d(dtype=float), + world_impulses: wp.array2d(dtype=float), + iterations: int, + omega: float, + world_row_mu: wp.array2d(dtype=float), + ): + world, _lane = wp.tid() + pgs_solve_contact_native( + world, + world_constraint_count, + world_C, + world_rhs, + world_impulses, + iterations, + omega, + world_row_mu, + ) + + pgs_solve_tiled_contact_template.__name__ = f"pgs_solve_tiled_contact_{max_constraints}" + pgs_solve_tiled_contact_template.__qualname__ = f"pgs_solve_tiled_contact_{max_constraints}" + return wp.kernel(enable_backward=False, module="unique")(pgs_solve_tiled_contact_template) + + @classmethod + def get_pgs_solve_streaming_kernel( + cls, max_constraints: int, device: "wp.Device", pgs_chunk_size: int = 1 + ) -> "wp.Kernel": + """Get or create a streaming contact-wise PGS world solve kernel. + + Kernels are memoized in ``cls._pgs_solve_streaming_cache`` under the key + ``(max_constraints, device.arch, pgs_chunk_size)``; on a cache miss the + kernel is generated by ``_build_pgs_solve_streaming_kernel``. + + Args: + max_constraints: Per-world row capacity compiled into the kernel + (``TILE_M``). Rows are grouped three per contact, so at most + ``max_constraints // 3`` contacts are solved. + device: Only ``device.arch`` is read, as part of the cache key. + pgs_chunk_size: Number of block-rows preloaded into shared memory per + chunk (``PGS_CHUNK``). The ``s_block_rows`` buffer holds + ``pgs_chunk_size * (max_constraints // 3) * 9`` floats. + + Returns: + The cached or newly built kernel. + """ + key = (max_constraints, device.arch, pgs_chunk_size) + if key not in cls._pgs_solve_streaming_cache: + cls._pgs_solve_streaming_cache[key] = cls._build_pgs_solve_streaming_kernel(max_constraints, pgs_chunk_size) + return cls._pgs_solve_streaming_cache[key] + + @classmethod + def _build_pgs_solve_streaming_kernel(cls, max_constraints: int, pgs_chunk_size: int = 1) -> "wp.Kernel": + """Streaming contact-wise PGS kernel that streams block-rows from global memory. + + Unlike tiled_contact which loads the entire Delassus matrix into shared memory, + this kernel keeps only lambda and auxiliaries in shared memory and streams + block-rows of C on demand. This enables handling much larger constraint counts + (hundreds of contacts) at the cost of increased global memory bandwidth. 
+ + When pgs_chunk_size > 1, multiple block-rows are preloaded into shared memory + at once, reducing the number of global memory round-trips per PGS iteration. + + Algorithm: + - Load lambda, rhs, mu, and compute diagonal block inverses once + - For each PGS iteration: + - For each chunk of pgs_chunk_size contacts: + - Preload pgs_chunk_size block-rows of C into shared memory + - For each contact c in the chunk: + - Compute block-row dot product with lambda (warp-parallel) + - Update lambda[c] with friction cone projection (lane 0) + - Store final lambda back to global memory + """ + TILE_M = max_constraints + NUM_CONTACTS_MAX = TILE_M // 3 + TILE_M_USABLE = NUM_CONTACTS_MAX * 3 + PGS_CHUNK = pgs_chunk_size + + snippet = f""" + #if defined(__CUDA_ARCH__) + const int TILE_M = {TILE_M}; + const int TILE_M_USABLE = {TILE_M_USABLE}; + const int NUM_CONTACTS_MAX = {NUM_CONTACTS_MAX}; + const int PGS_CHUNK = {PGS_CHUNK}; + const unsigned MASK = 0xFFFFFFFF; + + int lane = threadIdx.x; + + int m = world_constraint_count.data[world]; + if (m == 0) return; + + // Clamp m to usable range and ensure divisible by 3 + if (m > TILE_M_USABLE) m = TILE_M_USABLE; + int num_contacts = m / 3; + + // ═══════════════════════════════════════════════════════════════ + // SHARED MEMORY: lambda, rhs, mu, diagonal inverses, and + // block-row buffer for PGS_CHUNK contacts at a time + // ═══════════════════════════════════════════════════════════════ + __shared__ float s_lam[{TILE_M_USABLE}]; + __shared__ float s_rhs[{TILE_M_USABLE}]; + __shared__ float s_mu[{NUM_CONTACTS_MAX}]; + __shared__ float s_Dinv[{NUM_CONTACTS_MAX} * 9]; + __shared__ float s_block_rows[{PGS_CHUNK} * {NUM_CONTACTS_MAX} * 9]; + + int off1 = world * TILE_M; + int off2 = world * TILE_M * TILE_M; + + // ═══════════════════════════════════════════════════════════════ + // LOAD PHASE: Load persistent data into shared memory + // ═══════════════════════════════════════════════════════════════ + + // Load lambda and rhs 
(coalesced) + for (int i = lane; i < TILE_M_USABLE; i += 32) {{ + if (i < m) {{ + s_lam[i] = world_impulses.data[off1 + i]; + s_rhs[i] = world_rhs.data[off1 + i]; + }} else {{ + s_lam[i] = 0.0f; + s_rhs[i] = 0.0f; + }} + }} + + // Load mu (one per contact, stored on tangent1 row) + for (int c = lane; c < NUM_CONTACTS_MAX; c += 32) {{ + if (c < num_contacts) {{ + s_mu[c] = world_row_mu.data[off1 + c * 3 + 1]; + }} + }} + __syncwarp(); + + // Compute diagonal block inverses (each thread handles one contact) + for (int c = lane; c < num_contacts; c += 32) {{ + // Load diagonal block D[c,c] from global memory + int diag_row = c * 3; + float D[9]; + for (int k = 0; k < 9; k++) {{ + int lr = k / 3; + int lc = k % 3; + D[k] = world_C.data[off2 + (diag_row + lr) * TILE_M + (diag_row + lc)]; + }} + + // Compute 3x3 inverse + float det = D[0] * (D[4] * D[8] - D[5] * D[7]) + - D[1] * (D[3] * D[8] - D[5] * D[6]) + + D[2] * (D[3] * D[7] - D[4] * D[6]); + + float inv_det = 1.0f / det; + float* Dinv = &s_Dinv[c * 9]; + + Dinv[0] = (D[4] * D[8] - D[5] * D[7]) * inv_det; + Dinv[1] = (D[2] * D[7] - D[1] * D[8]) * inv_det; + Dinv[2] = (D[1] * D[5] - D[2] * D[4]) * inv_det; + Dinv[3] = (D[5] * D[6] - D[3] * D[8]) * inv_det; + Dinv[4] = (D[0] * D[8] - D[2] * D[6]) * inv_det; + Dinv[5] = (D[2] * D[3] - D[0] * D[5]) * inv_det; + Dinv[6] = (D[3] * D[7] - D[4] * D[6]) * inv_det; + Dinv[7] = (D[1] * D[6] - D[0] * D[7]) * inv_det; + Dinv[8] = (D[0] * D[4] - D[1] * D[3]) * inv_det; + }} + __syncwarp(); + + // ═══════════════════════════════════════════════════════════════ + // ITERATION PHASE: Stream block-rows in chunks and solve + // ═══════════════════════════════════════════════════════════════ + + for (int iter = 0; iter < iterations; iter++) {{ + for (int chunk_start = 0; chunk_start < num_contacts; chunk_start += PGS_CHUNK) {{ + int chunk_end = min(chunk_start + PGS_CHUNK, num_contacts); + int chunk_len = chunk_end - chunk_start; + + // 
───────────────────────────────────────────────────────── + // STREAM: Preload chunk_len block-rows of Delassus matrix + // ───────────────────────────────────────────────────────── + for (int ci = 0; ci < chunk_len; ci++) {{ + int c = chunk_start + ci; + int c_row = c * 3; + float* row_base = &s_block_rows[ci * NUM_CONTACTS_MAX * 9]; + for (int j = lane; j < num_contacts; j += 32) {{ + int j_col = j * 3; + float* dst = &row_base[j * 9]; + for (int k = 0; k < 9; k++) {{ + int lr = k / 3; + int lc = k % 3; + dst[k] = world_C.data[off2 + (c_row + lr) * TILE_M + (j_col + lc)]; + }} + }} + }} + __syncwarp(); + + // ───────────────────────────────────────────────────────── + // SOLVE: Process each contact in the chunk sequentially + // ───────────────────────────────────────────────────────── + for (int ci = 0; ci < chunk_len; ci++) {{ + int c = chunk_start + ci; + const float* row_base = &s_block_rows[ci * NUM_CONTACTS_MAX * 9]; + + // Block-row dot product sum_j C[c,j] * lambda[j] + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f; + + for (int j = lane; j < num_contacts; j += 32) {{ + float l0 = s_lam[j * 3 + 0]; + float l1 = s_lam[j * 3 + 1]; + float l2 = s_lam[j * 3 + 2]; + + const float* B = &row_base[j * 9]; + + sum0 += B[0] * l0 + B[1] * l1 + B[2] * l2; + sum1 += B[3] * l0 + B[4] * l1 + B[5] * l2; + sum2 += B[6] * l0 + B[7] * l1 + B[8] * l2; + }} + + // Warp reduce + sum0 += __shfl_down_sync(MASK, sum0, 16); + sum1 += __shfl_down_sync(MASK, sum1, 16); + sum2 += __shfl_down_sync(MASK, sum2, 16); + sum0 += __shfl_down_sync(MASK, sum0, 8); + sum1 += __shfl_down_sync(MASK, sum1, 8); + sum2 += __shfl_down_sync(MASK, sum2, 8); + sum0 += __shfl_down_sync(MASK, sum0, 4); + sum1 += __shfl_down_sync(MASK, sum1, 4); + sum2 += __shfl_down_sync(MASK, sum2, 4); + sum0 += __shfl_down_sync(MASK, sum0, 2); + sum1 += __shfl_down_sync(MASK, sum1, 2); + sum2 += __shfl_down_sync(MASK, sum2, 2); + sum0 += __shfl_down_sync(MASK, sum0, 1); + sum1 += __shfl_down_sync(MASK, sum1, 1); + sum2 
+= __shfl_down_sync(MASK, sum2, 1); + + // Update: Solve and project (lane 0 only) + if (lane == 0) {{ + float res0 = -(s_rhs[c * 3 + 0] + sum0); + float res1 = -(s_rhs[c * 3 + 1] + sum1); + float res2 = -(s_rhs[c * 3 + 2] + sum2); + + const float* Dinv = &s_Dinv[c * 9]; + float d0 = Dinv[0] * res0 + Dinv[1] * res1 + Dinv[2] * res2; + float d1 = Dinv[3] * res0 + Dinv[4] * res1 + Dinv[5] * res2; + float d2 = Dinv[6] * res0 + Dinv[7] * res1 + Dinv[8] * res2; + + float new_n = s_lam[c * 3 + 0] + omega * d0; + float new_t1 = s_lam[c * 3 + 1] + omega * d1; + float new_t2 = s_lam[c * 3 + 2] + omega * d2; + + // Friction cone projection + new_n = fmaxf(new_n, 0.0f); + + float mu = s_mu[c]; + float radius = mu * new_n; + + if (radius <= 0.0f) {{ + new_t1 = 0.0f; + new_t2 = 0.0f; + }} else {{ + float t_mag_sq = new_t1 * new_t1 + new_t2 * new_t2; + if (t_mag_sq > radius * radius) {{ + float scale = radius * rsqrtf(t_mag_sq); + new_t1 *= scale; + new_t2 *= scale; + }} + }} + + s_lam[c * 3 + 0] = new_n; + s_lam[c * 3 + 1] = new_t1; + s_lam[c * 3 + 2] = new_t2; + }} + __syncwarp(); + }} + }} + }} + + // ═══════════════════════════════════════════════════════════════ + // STORE PHASE: Write final lambda back to global memory + // ═══════════════════════════════════════════════════════════════ + for (int i = lane; i < TILE_M_USABLE; i += 32) {{ + if (i < m) {{ + world_impulses.data[off1 + i] = s_lam[i]; + }} + }} + #endif + """ + + @wp.func_native(snippet) + def pgs_solve_streaming_native( + world: int, + world_constraint_count: wp.array(dtype=int), + world_C: wp.array3d(dtype=float), + world_rhs: wp.array2d(dtype=float), + world_impulses: wp.array2d(dtype=float), + iterations: int, + omega: float, + world_row_mu: wp.array2d(dtype=float), + ): ... 
+ + def pgs_solve_streaming_template( + world_constraint_count: wp.array(dtype=int), + world_C: wp.array3d(dtype=float), + world_rhs: wp.array2d(dtype=float), + world_impulses: wp.array2d(dtype=float), + iterations: int, + omega: float, + world_row_mu: wp.array2d(dtype=float), + ): + world, _lane = wp.tid() + pgs_solve_streaming_native( + world, + world_constraint_count, + world_C, + world_rhs, + world_impulses, + iterations, + omega, + world_row_mu, + ) + + pgs_solve_streaming_template.__name__ = f"pgs_solve_streaming_{max_constraints}_chunk{pgs_chunk_size}" + pgs_solve_streaming_template.__qualname__ = f"pgs_solve_streaming_{max_constraints}_chunk{pgs_chunk_size}" + return wp.kernel(enable_backward=False, module="unique")(pgs_solve_streaming_template) + + @classmethod + def get_pgs_solve_mf_kernel(cls, mf_max_constraints: int, max_mf_bodies: int, device: "wp.Device") -> "wp.Kernel": + """Get or create a streaming MF PGS kernel for free rigid body contacts. + + Kernels are memoized in ``cls._pgs_solve_mf_cache`` under the key + ``(mf_max_constraints, max_mf_bodies, device.arch)``; on a cache miss the + kernel is generated by ``_build_pgs_solve_mf_kernel``. + + Args: + mf_max_constraints: Per-world constraint capacity compiled into the + kernel (``MF_MAX_C``); larger counts are clamped to it. + max_mf_bodies: Per-world body capacity compiled into the kernel + (``MAX_BODIES``); six velocity components are kept per body. + device: Only ``device.arch`` is read, as part of the cache key. + + Returns: + The cached or newly built kernel. + """ + key = (mf_max_constraints, max_mf_bodies, device.arch) + if key not in cls._pgs_solve_mf_cache: + cls._pgs_solve_mf_cache[key] = cls._build_pgs_solve_mf_kernel(mf_max_constraints, max_mf_bodies) + return cls._pgs_solve_mf_cache[key] + + @classmethod + def _build_pgs_solve_mf_kernel(cls, mf_max_constraints: int, max_mf_bodies: int) -> "wp.Kernel": + """Matrix-free PGS with body velocities and impulses in shared memory. + + Uses one warp (32 threads) per world. Body velocities and impulses live + in shared memory for the duration of all PGS iterations, eliminating + global memory round-trips. J, MiJt, eff_mass_inv, and rhs are read + from global memory per constraint (read-only, cache-friendly sequential access). + + All lanes take part in the load and store phases; the constraint sweep + itself runs serially on lane 0. 
+ """ + MF_MAX_C = mf_max_constraints + MAX_BODIES = max_mf_bodies + + snippet = f""" + #if defined(__CUDA_ARCH__) + const int MF_MAX_C = {MF_MAX_C}; + const int MAX_BODIES = {MAX_BODIES}; + + int lane = threadIdx.x; + + int m = mf_constraint_count.data[world]; + if (m == 0) return; + if (m > MF_MAX_C) m = MF_MAX_C; + + int n_bodies = mf_body_count.data[world]; + if (n_bodies > MAX_BODIES) n_bodies = MAX_BODIES; + + // ═══════════════════════════════════════════════════════════════ + // SHARED MEMORY + // ═══════════════════════════════════════════════════════════════ + __shared__ float s_vel[{MAX_BODIES * 6}]; + __shared__ float s_impulse[{MF_MAX_C}]; + __shared__ int s_dof_start[{MAX_BODIES}]; + + int body_off = world * MAX_BODIES; + int c_off = world * MF_MAX_C; + + // ═══════════════════════════════════════════════════════════════ + // LOAD PHASE + // ═══════════════════════════════════════════════════════════════ + + // Load body DOF starts and velocities + for (int b = lane; b < n_bodies; b += 32) {{ + int dof = mf_body_dof_start.data[body_off + b]; + s_dof_start[b] = dof; + for (int k = 0; k < 6; k++) {{ + s_vel[b * 6 + k] = v_out.data[dof + k]; + }} + }} + + // Load impulses + for (int i = lane; i < m; i += 32) {{ + s_impulse[i] = mf_impulses.data[c_off + i]; + }} + __syncwarp(); + + // ═══════════════════════════════════════════════════════════════ + // SOLVE PHASE (lane 0) + // ═══════════════════════════════════════════════════════════════ + + if (lane == 0) {{ + for (int iter = 0; iter < iterations; iter++) {{ + for (int i = 0; i < m; i++) {{ + float eff_inv = mf_eff_mass_inv.data[c_off + i]; + if (eff_inv <= 0.0f) continue; + + int lba = mf_local_body_a.data[c_off + i]; + int lbb = mf_local_body_b.data[c_off + i]; + + // Load J from global memory + int j_base = (c_off + i) * 6; + float ja0 = mf_J_a.data[j_base + 0]; + float ja1 = mf_J_a.data[j_base + 1]; + float ja2 = mf_J_a.data[j_base + 2]; + float ja3 = mf_J_a.data[j_base + 3]; + float ja4 = 
mf_J_a.data[j_base + 4]; + float ja5 = mf_J_a.data[j_base + 5]; + + float jb0 = mf_J_b.data[j_base + 0]; + float jb1 = mf_J_b.data[j_base + 1]; + float jb2 = mf_J_b.data[j_base + 2]; + float jb3 = mf_J_b.data[j_base + 3]; + float jb4 = mf_J_b.data[j_base + 4]; + float jb5 = mf_J_b.data[j_base + 5]; + + // Compute J * v from shared memory + float jv = 0.0f; + if (lba >= 0) {{ + int va = lba * 6; + jv += ja0 * s_vel[va] + ja1 * s_vel[va+1] + ja2 * s_vel[va+2] + + ja3 * s_vel[va+3] + ja4 * s_vel[va+4] + ja5 * s_vel[va+5]; + }} + if (lbb >= 0) {{ + int vb = lbb * 6; + jv += jb0 * s_vel[vb] + jb1 * s_vel[vb+1] + jb2 * s_vel[vb+2] + + jb3 * s_vel[vb+3] + jb4 * s_vel[vb+4] + jb5 * s_vel[vb+5]; + }} + + // PGS update + float rhs_i = mf_rhs.data[c_off + i]; + float delta = -(jv + rhs_i) * eff_inv; + float old_impulse = s_impulse[i]; + float new_impulse = old_impulse + omega * delta; + + int row_type = mf_row_type.data[c_off + i]; + + // Project: contact or joint limit + if (row_type == 0 || row_type == 3) {{ + if (new_impulse < 0.0f) new_impulse = 0.0f; + }} + // Project: friction + else if (row_type == 2) {{ + int parent_idx = mf_row_parent.data[c_off + i]; + float lambda_n = s_impulse[parent_idx]; + float mu = mf_row_mu.data[c_off + i]; + float radius = fmaxf(mu * lambda_n, 0.0f); + + if (radius <= 0.0f) {{ + new_impulse = 0.0f; + }} else {{ + int sib = (i == parent_idx + 1) ? 
parent_idx + 2 : parent_idx + 1; + + s_impulse[i] = new_impulse; + float a_val = new_impulse; + float b_val = s_impulse[sib]; + float mag = sqrtf(a_val * a_val + b_val * b_val); + if (mag > radius) {{ + float scale = radius / mag; + new_impulse = a_val * scale; + float sib_new = b_val * scale; + float sib_delta = sib_new - b_val; + s_impulse[sib] = sib_new; + + // Apply sibling correction to body velocities + int sib_lba = mf_local_body_a.data[c_off + sib]; + int sib_lbb = mf_local_body_b.data[c_off + sib]; + int sib_j_base = (c_off + sib) * 6; + if (sib_lba >= 0) {{ + int sva = sib_lba * 6; + s_vel[sva+0] += mf_MiJt_a.data[sib_j_base+0] * sib_delta; + s_vel[sva+1] += mf_MiJt_a.data[sib_j_base+1] * sib_delta; + s_vel[sva+2] += mf_MiJt_a.data[sib_j_base+2] * sib_delta; + s_vel[sva+3] += mf_MiJt_a.data[sib_j_base+3] * sib_delta; + s_vel[sva+4] += mf_MiJt_a.data[sib_j_base+4] * sib_delta; + s_vel[sva+5] += mf_MiJt_a.data[sib_j_base+5] * sib_delta; + }} + if (sib_lbb >= 0) {{ + int svb = sib_lbb * 6; + s_vel[svb+0] += mf_MiJt_b.data[sib_j_base+0] * sib_delta; + s_vel[svb+1] += mf_MiJt_b.data[sib_j_base+1] * sib_delta; + s_vel[svb+2] += mf_MiJt_b.data[sib_j_base+2] * sib_delta; + s_vel[svb+3] += mf_MiJt_b.data[sib_j_base+3] * sib_delta; + s_vel[svb+4] += mf_MiJt_b.data[sib_j_base+4] * sib_delta; + s_vel[svb+5] += mf_MiJt_b.data[sib_j_base+5] * sib_delta; + }} + }} + }} + }} + + float delta_impulse = new_impulse - old_impulse; + s_impulse[i] = new_impulse; + + // Apply velocity correction: v += MiJt * delta_impulse + int mijt_base = (c_off + i) * 6; + if (lba >= 0) {{ + int va = lba * 6; + s_vel[va+0] += mf_MiJt_a.data[mijt_base+0] * delta_impulse; + s_vel[va+1] += mf_MiJt_a.data[mijt_base+1] * delta_impulse; + s_vel[va+2] += mf_MiJt_a.data[mijt_base+2] * delta_impulse; + s_vel[va+3] += mf_MiJt_a.data[mijt_base+3] * delta_impulse; + s_vel[va+4] += mf_MiJt_a.data[mijt_base+4] * delta_impulse; + s_vel[va+5] += mf_MiJt_a.data[mijt_base+5] * delta_impulse; + }} + if (lbb >= 
0) {{ + int vb = lbb * 6; + s_vel[vb+0] += mf_MiJt_b.data[mijt_base+0] * delta_impulse; + s_vel[vb+1] += mf_MiJt_b.data[mijt_base+1] * delta_impulse; + s_vel[vb+2] += mf_MiJt_b.data[mijt_base+2] * delta_impulse; + s_vel[vb+3] += mf_MiJt_b.data[mijt_base+3] * delta_impulse; + s_vel[vb+4] += mf_MiJt_b.data[mijt_base+4] * delta_impulse; + s_vel[vb+5] += mf_MiJt_b.data[mijt_base+5] * delta_impulse; + }} + }} + }} + }} + __syncwarp(); + + // ═══════════════════════════════════════════════════════════════ + // STORE PHASE + // ═══════════════════════════════════════════════════════════════ + + // Write body velocities back to v_out + for (int b = lane; b < n_bodies; b += 32) {{ + int dof = s_dof_start[b]; + for (int k = 0; k < 6; k++) {{ + v_out.data[dof + k] = s_vel[b * 6 + k]; + }} + }} + + // Write impulses back + for (int i = lane; i < m; i += 32) {{ + mf_impulses.data[c_off + i] = s_impulse[i]; + }} + #endif + """ + + @wp.func_native(snippet) + def pgs_solve_mf_native( + world: int, + mf_constraint_count: wp.array(dtype=int), + mf_body_count: wp.array(dtype=int), + mf_body_dof_start: wp.array2d(dtype=int), + mf_local_body_a: wp.array2d(dtype=int), + mf_local_body_b: wp.array2d(dtype=int), + mf_J_a: wp.array3d(dtype=float), + mf_J_b: wp.array3d(dtype=float), + mf_MiJt_a: wp.array3d(dtype=float), + mf_MiJt_b: wp.array3d(dtype=float), + mf_eff_mass_inv: wp.array2d(dtype=float), + mf_rhs: wp.array2d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_row_mu: wp.array2d(dtype=float), + mf_impulses: wp.array2d(dtype=float), + v_out: wp.array(dtype=float), + iterations: int, + omega: float, + ): ... 
+ + def pgs_solve_mf_template( + mf_constraint_count: wp.array(dtype=int), + mf_body_count: wp.array(dtype=int), + mf_body_dof_start: wp.array2d(dtype=int), + mf_local_body_a: wp.array2d(dtype=int), + mf_local_body_b: wp.array2d(dtype=int), + mf_J_a: wp.array3d(dtype=float), + mf_J_b: wp.array3d(dtype=float), + mf_MiJt_a: wp.array3d(dtype=float), + mf_MiJt_b: wp.array3d(dtype=float), + mf_eff_mass_inv: wp.array2d(dtype=float), + mf_rhs: wp.array2d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_row_mu: wp.array2d(dtype=float), + mf_impulses: wp.array2d(dtype=float), + v_out: wp.array(dtype=float), + iterations: int, + omega: float, + ): + world, _lane = wp.tid() + pgs_solve_mf_native( + world, + mf_constraint_count, + mf_body_count, + mf_body_dof_start, + mf_local_body_a, + mf_local_body_b, + mf_J_a, + mf_J_b, + mf_MiJt_a, + mf_MiJt_b, + mf_eff_mass_inv, + mf_rhs, + mf_row_type, + mf_row_parent, + mf_row_mu, + mf_impulses, + v_out, + iterations, + omega, + ) + + name = f"pgs_solve_mf_{mf_max_constraints}_{max_mf_bodies}" + pgs_solve_mf_template.__name__ = name + pgs_solve_mf_template.__qualname__ = name + return wp.kernel(enable_backward=False, module="unique")(pgs_solve_mf_template) + + @classmethod + def get_pack_mf_meta_kernel(cls, mf_max_constraints: int, device: "wp.Device") -> "wp.Kernel": + """Get or create a kernel to pack MF metadata into int4 format. + + Kernels are memoized in ``cls._pack_mf_meta_cache`` under the key + ``(mf_max_constraints, device.arch)``; on a cache miss the kernel is + generated by ``_build_pack_mf_meta_kernel``. + + Args: + mf_max_constraints: Per-world row stride compiled into the kernel; + it is used to compute the per-world offsets into the inputs and + into ``mf_meta``. + device: Only ``device.arch`` is read, as part of the cache key. + + Returns: + The cached or newly built kernel. + """ + key = (mf_max_constraints, device.arch) + if key not in cls._pack_mf_meta_cache: + cls._pack_mf_meta_cache[key] = cls._build_pack_mf_meta_kernel(mf_max_constraints) + return cls._pack_mf_meta_cache[key] + + @classmethod + def _build_pack_mf_meta_kernel(cls, mf_max_constraints: int) -> "wp.Kernel": + """Build a kernel to pack MF constraint metadata into int4 structs. + + Packs dof_a, dof_b, eff_mass_inv, rhs, row_type, row_parent into + 4 contiguous int32s per constraint for 128-bit coalesced loads. + + Layout of each packed int4, as written by the snippet: + x = (dof_a << 16) | (dof_b & 0xFFFF) + y = bit pattern of eff_mass_inv (``__float_as_int``) + z = bit pattern of rhs (``__float_as_int``) + w = row_type | (row_parent << 16) + + Lane ``threadIdx.x`` handles constraints ``lane, lane + 32, ...`` of its world. 
+ """ + M_MF = mf_max_constraints + + snippet = f""" + #if defined(__CUDA_ARCH__) + int lane = threadIdx.x; + int m_mf = mf_constraint_count.data[world]; + int off_mf = world * {M_MF}; + int off_meta = off_mf * 4; + + for (int i = lane; i < m_mf; i += 32) {{ + int da = mf_dof_a.data[off_mf + i]; + int db = mf_dof_b.data[off_mf + i]; + float diag = mf_eff_mass_inv.data[off_mf + i]; + float rhs_val = mf_rhs.data[off_mf + i]; + int rt = mf_row_type.data[off_mf + i]; + int par = mf_row_parent.data[off_mf + i]; + + int4 packed; + packed.x = (da << 16) | (db & 0xFFFF); + packed.y = __float_as_int(diag); + packed.z = __float_as_int(rhs_val); + packed.w = rt | (par << 16); + *reinterpret_cast(&mf_meta.data[off_meta + i * 4]) = packed; + }} + #endif + """ + + @wp.func_native(snippet) + def pack_mf_meta_native( + world: int, + mf_constraint_count: wp.array(dtype=int), + mf_dof_a: wp.array2d(dtype=int), + mf_dof_b: wp.array2d(dtype=int), + mf_eff_mass_inv: wp.array2d(dtype=float), + mf_rhs: wp.array2d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_meta: wp.array2d(dtype=int), + ): ... 
+ + def pack_mf_meta_template( + mf_constraint_count: wp.array(dtype=int), + mf_dof_a: wp.array2d(dtype=int), + mf_dof_b: wp.array2d(dtype=int), + mf_eff_mass_inv: wp.array2d(dtype=float), + mf_rhs: wp.array2d(dtype=float), + mf_row_type: wp.array2d(dtype=int), + mf_row_parent: wp.array2d(dtype=int), + mf_meta: wp.array2d(dtype=int), + ): + world, _lane = wp.tid() + pack_mf_meta_native( + world, + mf_constraint_count, + mf_dof_a, + mf_dof_b, + mf_eff_mass_inv, + mf_rhs, + mf_row_type, + mf_row_parent, + mf_meta, + ) + + name = f"pack_mf_meta_{mf_max_constraints}" + pack_mf_meta_template.__name__ = name + pack_mf_meta_template.__qualname__ = name + return wp.kernel(enable_backward=False, module="unique")(pack_mf_meta_template) + + @classmethod + def get_pgs_solve_mf_gs_kernel( + cls, max_constraints: int, mf_max_constraints: int, max_world_dofs: int, device: "wp.Device" + ) -> "wp.Kernel": + """Get or create a two-phase GS kernel for matrix-free articulated PGS. + + Phase 1 processes dense constraints (via J_world/Y_world at ``max_constraints``). + Phase 2 processes MF constraints (via mf_J/mf_MiJt at ``mf_max_constraints``). + Both phases share a single velocity vector in shared memory. + + Kernels are memoized in ``cls._pgs_solve_mf_gs_cache`` under the key + ``(max_constraints, mf_max_constraints, max_world_dofs, device.arch)``; on a + cache miss the kernel is generated by ``_build_pgs_solve_mf_gs_kernel``. + + Args: + max_constraints: Per-world capacity of dense constraint rows. + mf_max_constraints: Per-world capacity of MF constraint rows. + max_world_dofs: Length of the per-world velocity vector held in + shared memory. + device: Only ``device.arch`` is read, as part of the cache key. + + Returns: + The cached or newly built kernel. + """ + key = (max_constraints, mf_max_constraints, max_world_dofs, device.arch) + if key not in cls._pgs_solve_mf_gs_cache: + cls._pgs_solve_mf_gs_cache[key] = cls._build_pgs_solve_mf_gs_kernel( + max_constraints, mf_max_constraints, max_world_dofs + ) + return cls._pgs_solve_mf_gs_cache[key] + + @classmethod + def _build_pgs_solve_mf_gs_kernel( + cls, max_constraints: int, mf_max_constraints: int, max_world_dofs: int + ) -> "wp.Kernel": + """Two-phase GS PGS kernel: dense + matrix-free in one pass. + + Uses one warp (32 threads) per world. + + Phase 1 (dense): warp-parallel dot/update over D DOFs using J_world/Y_world. + Phase 2 (MF): lanes 0-5 handle body_a, lanes 6-11 handle body_b (6 DOFs each). 
+ + Shared memory layout: + s_v[D] — world velocity + s_lam_dense[M_D] + metadata — dense impulses and constraint info + s_lam_mf[M_MF] — MF impulses (metadata read from global per constraint) + """ + M_D = max_constraints + M_MF = mf_max_constraints + D = max_world_dofs + + # How many DOF elements each lane handles (ceil(D/32)) + ELEMS_PER_LANE = (D + 31) // 32 + + # --- Code generation for dense phase (D-wide dot/update, software-pipelined) --- + + # Pipeline register declarations + dense_pipe_decl = "\n".join( + [f" float pre_dJ_{k} = 0.0f, pre_dY_{k} = 0.0f;" for k in range(ELEMS_PER_LANE)] + ) + + # Initial prefetch (constraint 0) + dense_prefetch_init_parts = [] + for k in range(ELEMS_PER_LANE): + d_expr = f"lane + {k * 32}" if k > 0 else "lane" + dense_prefetch_init_parts.append(f""" + if ({d_expr} < {D}) {{ + pre_dJ_{k} = J_world.data[jy_world_base + {d_expr}]; + pre_dY_{k} = Y_world.data[jy_world_base + {d_expr}]; + }}""") + dense_prefetch_init_code = "\n".join(dense_prefetch_init_parts) + + # Consume prefetched values into cur_ variables + dense_consume_code = "\n".join( + [f" float cur_dJ_{k} = pre_dJ_{k}, cur_dY_{k} = pre_dY_{k};" for k in range(ELEMS_PER_LANE)] + ) + + # Prefetch next constraint (i+1) + dense_prefetch_next_parts = [] + for k in range(ELEMS_PER_LANE): + d_expr = f"lane + {k * 32}" if k > 0 else "lane" + dense_prefetch_next_parts.append(f""" + if ({d_expr} < {D}) {{ + pre_dJ_{k} = J_world.data[next_jy_base + {d_expr}]; + pre_dY_{k} = Y_world.data[next_jy_base + {d_expr}]; + }}""") + dense_prefetch_next_code = "\n".join(dense_prefetch_next_parts) + + # Dense dot product using prefetched J: cur_dJ_k * s_v[d] + dense_dot_parts = [] + for k in range(ELEMS_PER_LANE): + d_expr = f"lane + {k * 32}" if k > 0 else "lane" + dense_dot_parts.append(f""" + if ({d_expr} < {D}) {{ + my_sum += cur_dJ_{k} * s_v[{d_expr}]; + }}""") + dense_dot_code = "\n".join(["float my_sum = 0.0f;", *dense_dot_parts]) + + # Dense v update using prefetched Y: s_v[d] += 
cur_dY_k * delta + dense_v_update_parts = [] + for k in range(ELEMS_PER_LANE): + d_expr = f"lane + {k * 32}" if k > 0 else "lane" + dense_v_update_parts.append(f""" + if ({d_expr} < {D}) {{ + s_v[{d_expr}] += cur_dY_{k} * delta_impulse; + }}""") + dense_v_update_code = "\n".join(dense_v_update_parts) + + # Dense sibling v update — NOT pipelined (random sib index) + dense_sib_v_parts = [] + for k in range(ELEMS_PER_LANE): + d_expr = f"lane + {k * 32}" if k > 0 else "lane" + dense_sib_v_parts.append(f""" + if ({d_expr} < {D}) {{ + s_v[{d_expr}] += Y_world.data[sib_row_base + {d_expr}] * sib_delta; + }}""") + dense_sib_v_code = "\n".join(dense_sib_v_parts) + + snippet = f""" + #if defined(__CUDA_ARCH__) + const unsigned MASK = 0xFFFFFFFF; + int lane = threadIdx.x; + + int m_dense = world_constraint_count.data[world]; + int m_mf = mf_constraint_count.data[world]; + if (m_dense == 0 && m_mf == 0) return; + if (m_dense > {M_D}) m_dense = {M_D}; + if (m_mf > {M_MF}) m_mf = {M_MF}; + + int w_dof_start = world_dof_start.data[world]; + int off_dense = world * {M_D}; + int off_mf = world * {M_MF}; + int off_meta = off_mf * 4; + int jy_world_base = world * {M_D} * {D}; + int mf6_base = world * {M_MF} * 6; + + // ═══════════════════════════════════════════════════════ + // SHARED MEMORY + // ═══════════════════════════════════════════════════════ + __shared__ float s_v[{D}]; + __shared__ float s_lam_dense[{M_D}]; + __shared__ float s_rhs_dense[{M_D}]; + __shared__ float s_diag_dense[{M_D}]; + __shared__ int s_rtype_dense[{M_D}]; + __shared__ int s_parent_dense[{M_D}]; + __shared__ float s_mu_dense[{M_D}]; + __shared__ float s_lam_mf[{M_MF}]; + + // ═══════════════════════════════════════════════════════ + // LOAD PHASE + // ═══════════════════════════════════════════════════════ + for (int i = lane; i < m_dense; i += 32) {{ + s_lam_dense[i] = world_impulses.data[off_dense + i]; + s_rhs_dense[i] = rhs_bias.data[off_dense + i]; + s_diag_dense[i] = world_diag.data[off_dense + i]; 
+ s_rtype_dense[i] = world_row_type.data[off_dense + i]; + s_parent_dense[i] = world_row_parent.data[off_dense + i]; + s_mu_dense[i] = world_row_mu.data[off_dense + i]; + }} + for (int i = lane; i < m_mf; i += 32) {{ + s_lam_mf[i] = mf_impulses.data[off_mf + i]; + }} + for (int d = lane; d < {D}; d += 32) {{ + s_v[d] = v_out.data[w_dof_start + d]; + }} + __syncwarp(); + + // ═══════════════════════════════════════════════════════ + // SOLVE PHASE + // ═══════════════════════════════════════════════════════ + // Dense pipeline registers +{dense_pipe_decl} + + for (int iter = 0; iter < iterations; iter++) {{ + + // ── Phase 1: Dense constraints (D-DOF warp-parallel, software-pipelined) ── + + // Prefetch constraint 0 + if (m_dense > 0) {{ + {dense_prefetch_init_code} + }} + + for (int i = 0; i < m_dense; i++) {{ + // Consume prefetched J/Y for constraint i + {dense_consume_code} + + // Prefetch constraint i+1 + if (i + 1 < m_dense) {{ + int next_jy_base = jy_world_base + (i + 1) * {D}; + {dense_prefetch_next_code} + }} + + float denom = s_diag_dense[i]; + if (denom <= 0.0f) continue; + + // J_i · v (using prefetched J) + {dense_dot_code} + + // Warp reduce + my_sum += __shfl_down_sync(MASK, my_sum, 16); + my_sum += __shfl_down_sync(MASK, my_sum, 8); + my_sum += __shfl_down_sync(MASK, my_sum, 4); + my_sum += __shfl_down_sync(MASK, my_sum, 2); + my_sum += __shfl_down_sync(MASK, my_sum, 1); + float jv = __shfl_sync(MASK, my_sum, 0); + + float residual = jv + s_rhs_dense[i]; + float delta = -residual / denom; + float old_impulse = s_lam_dense[i]; + float new_impulse = old_impulse + omega * delta; + int row_type = s_rtype_dense[i]; + + if (row_type == 0 || row_type == 3) {{ + if (new_impulse < 0.0f) new_impulse = 0.0f; + }} else if (row_type == 2) {{ + int parent_idx = s_parent_dense[i]; + float lambda_n = s_lam_dense[parent_idx]; + float mu = s_mu_dense[i]; + float radius = fmaxf(mu * lambda_n, 0.0f); + + if (radius <= 0.0f) {{ + new_impulse = 0.0f; + }} else {{ + int 
sib = (i == parent_idx + 1) ? parent_idx + 2 : parent_idx + 1; + s_lam_dense[i] = new_impulse; + float a_val = new_impulse; + float b_val = s_lam_dense[sib]; + float mag = sqrtf(a_val * a_val + b_val * b_val); + if (mag > radius) {{ + float scale = radius / mag; + new_impulse = a_val * scale; + float sib_new = b_val * scale; + float sib_delta = sib_new - b_val; + s_lam_dense[sib] = sib_new; + + int sib_row_base = jy_world_base + sib * {D}; + {dense_sib_v_code} + }} + }} + }} + + float delta_impulse = new_impulse - old_impulse; + s_lam_dense[i] = new_impulse; + + // V update using prefetched Y + if (delta_impulse != 0.0f) {{ + {dense_v_update_code} + }} + __syncwarp(); + }} + + // ── Phase 2: MF constraints (6-DOF per body, software-pipelined) ── + + // Pipeline registers: prefetch next constraint's global data + int4 pre_meta; + float pre_Ja = 0.0f, pre_Jb = 0.0f; + float pre_MiJta = 0.0f, pre_MiJtb = 0.0f; + + // Prefetch constraint 0 + if (m_mf > 0) {{ + pre_meta = *reinterpret_cast(&mf_meta.data[off_meta]); + if (lane < 6) {{ + pre_Ja = mf_J_a.data[mf6_base + lane]; + pre_MiJta = mf_MiJt_a.data[mf6_base + lane]; + }} + if (lane >= 6 && lane < 12) {{ + pre_Jb = mf_J_b.data[mf6_base + lane - 6]; + pre_MiJtb = mf_MiJt_b.data[mf6_base + lane - 6]; + }} + }} + + for (int i = 0; i < m_mf; i++) {{ + // Consume prefetched data for constraint i + int4 meta = pre_meta; + float cur_Ja = pre_Ja; + float cur_Jb = pre_Jb; + float cur_MiJta = pre_MiJta; + float cur_MiJtb = pre_MiJtb; + + // Prefetch constraint i+1 (loads issued now, complete during compute) + if (i + 1 < m_mf) {{ + int next_mf6 = mf6_base + (i + 1) * 6; + pre_meta = *reinterpret_cast(&mf_meta.data[off_meta + (i + 1) * 4]); + if (lane < 6) {{ + pre_Ja = mf_J_a.data[next_mf6 + lane]; + pre_MiJta = mf_MiJt_a.data[next_mf6 + lane]; + }} + if (lane >= 6 && lane < 12) {{ + pre_Jb = mf_J_b.data[next_mf6 + lane - 6]; + pre_MiJtb = mf_MiJt_b.data[next_mf6 + lane - 6]; + }} + }} + + // Process constraint i + int 
packed_dofs = meta.x; + int dof_a = packed_dofs >> 16; + int dof_b = (packed_dofs << 16) >> 16; + float mf_diag = __int_as_float(meta.y); + + if (mf_diag <= 0.0f) continue; + + // J · v using prefetched J values + float my_sum = 0.0f; + if (lane < 6 && dof_a >= 0) {{ + my_sum = cur_Ja * s_v[dof_a + lane]; + }} + if (lane >= 6 && lane < 12 && dof_b >= 0) {{ + my_sum = cur_Jb * s_v[dof_b + lane - 6]; + }} + my_sum += __shfl_down_sync(MASK, my_sum, 16); + my_sum += __shfl_down_sync(MASK, my_sum, 8); + my_sum += __shfl_down_sync(MASK, my_sum, 4); + my_sum += __shfl_down_sync(MASK, my_sum, 2); + my_sum += __shfl_down_sync(MASK, my_sum, 1); + float jv = __shfl_sync(MASK, my_sum, 0); + + float residual = jv + __int_as_float(meta.z); + float delta = -residual * mf_diag; + float old_impulse = s_lam_mf[i]; + float new_impulse = old_impulse + omega * delta; + int packed_tp = meta.w; + int mf_rt = packed_tp & 0xFFFF; + + if (mf_rt == 0) {{ + if (new_impulse < 0.0f) new_impulse = 0.0f; + }} else if (mf_rt == 2) {{ + int mf_par = packed_tp >> 16; + float lambda_n = s_lam_mf[mf_par]; + float mu = mf_row_mu.data[off_mf + i]; + float radius = fmaxf(mu * lambda_n, 0.0f); + + if (radius <= 0.0f) {{ + new_impulse = 0.0f; + }} else {{ + int sib = (i == mf_par + 1) ? 
mf_par + 2 : mf_par + 1; + s_lam_mf[i] = new_impulse; + float a_val = new_impulse; + float b_val = s_lam_mf[sib]; + float mag = sqrtf(a_val * a_val + b_val * b_val); + if (mag > radius) {{ + float scale = radius / mag; + new_impulse = a_val * scale; + float sib_new = b_val * scale; + float sib_delta = sib_new - b_val; + s_lam_mf[sib] = sib_new; + + // Sibling v update (can't prefetch — random sib index) + int sib_packed_dofs = mf_meta.data[off_meta + sib * 4]; + int sib_dof_a = sib_packed_dofs >> 16; + int sib_dof_b = (sib_packed_dofs << 16) >> 16; + int sib_mf6 = mf6_base + sib * 6; + if (lane < 6 && sib_dof_a >= 0) {{ + s_v[sib_dof_a + lane] += mf_MiJt_a.data[sib_mf6 + lane] * sib_delta; + }} + if (lane >= 6 && lane < 12 && sib_dof_b >= 0) {{ + s_v[sib_dof_b + lane - 6] += mf_MiJt_b.data[sib_mf6 + lane - 6] * sib_delta; + }} + }} + }} + }} + + float delta_impulse = new_impulse - old_impulse; + s_lam_mf[i] = new_impulse; + + // V update using prefetched MiJt values + if (delta_impulse != 0.0f) {{ + if (lane < 6 && dof_a >= 0) {{ + s_v[dof_a + lane] += cur_MiJta * delta_impulse; + }} + if (lane >= 6 && lane < 12 && dof_b >= 0) {{ + s_v[dof_b + lane - 6] += cur_MiJtb * delta_impulse; + }} + }} + __syncwarp(); + }} + }} + + // ═══════════════════════════════════════════════════════ + // STORE PHASE + // ═══════════════════════════════════════════════════════ + for (int d = lane; d < {D}; d += 32) {{ + v_out.data[w_dof_start + d] = s_v[d]; + }} + for (int i = lane; i < m_dense; i += 32) {{ + world_impulses.data[off_dense + i] = s_lam_dense[i]; + }} + for (int i = lane; i < m_mf; i += 32) {{ + mf_impulses.data[off_mf + i] = s_lam_mf[i]; + }} + #endif + """ + + @wp.func_native(snippet) + def pgs_solve_mf_gs_native( + world: int, + # Dense + world_constraint_count: wp.array(dtype=int), + world_dof_start: wp.array(dtype=int), + rhs_bias: wp.array2d(dtype=float), + world_diag: wp.array2d(dtype=float), + world_impulses: wp.array2d(dtype=float), + J_world: 
wp.array3d(dtype=float), + Y_world: wp.array3d(dtype=float), + world_row_type: wp.array2d(dtype=int), + world_row_parent: wp.array2d(dtype=int), + world_row_mu: wp.array2d(dtype=float), + # MF + mf_constraint_count: wp.array(dtype=int), + mf_meta: wp.array2d(dtype=int), + mf_impulses: wp.array2d(dtype=float), + mf_J_a: wp.array3d(dtype=float), + mf_J_b: wp.array3d(dtype=float), + mf_MiJt_a: wp.array3d(dtype=float), + mf_MiJt_b: wp.array3d(dtype=float), + mf_row_mu: wp.array2d(dtype=float), + # Shared + iterations: int, + omega: float, + # Output + v_out: wp.array(dtype=float), + ): ... + + def pgs_solve_mf_gs_template( + # Dense + world_constraint_count: wp.array(dtype=int), + world_dof_start: wp.array(dtype=int), + rhs_bias: wp.array2d(dtype=float), + world_diag: wp.array2d(dtype=float), + world_impulses: wp.array2d(dtype=float), + J_world: wp.array3d(dtype=float), + Y_world: wp.array3d(dtype=float), + world_row_type: wp.array2d(dtype=int), + world_row_parent: wp.array2d(dtype=int), + world_row_mu: wp.array2d(dtype=float), + # MF + mf_constraint_count: wp.array(dtype=int), + mf_meta: wp.array2d(dtype=int), + mf_impulses: wp.array2d(dtype=float), + mf_J_a: wp.array3d(dtype=float), + mf_J_b: wp.array3d(dtype=float), + mf_MiJt_a: wp.array3d(dtype=float), + mf_MiJt_b: wp.array3d(dtype=float), + mf_row_mu: wp.array2d(dtype=float), + # Shared + iterations: int, + omega: float, + # Output + v_out: wp.array(dtype=float), + ): + world, _lane = wp.tid() + pgs_solve_mf_gs_native( + world, + world_constraint_count, + world_dof_start, + rhs_bias, + world_diag, + world_impulses, + J_world, + Y_world, + world_row_type, + world_row_parent, + world_row_mu, + mf_constraint_count, + mf_meta, + mf_impulses, + mf_J_a, + mf_J_b, + mf_MiJt_a, + mf_MiJt_b, + mf_row_mu, + iterations, + omega, + v_out, + ) + + name = f"pgs_solve_mf_gs_{max_constraints}_{mf_max_constraints}_{max_world_dofs}" + pgs_solve_mf_gs_template.__name__ = name + pgs_solve_mf_gs_template.__qualname__ = name + return 
wp.kernel(enable_backward=False, module="unique")(pgs_solve_mf_gs_template) diff --git a/newton/_src/solvers/featherstone/kernels.py b/newton/_src/solvers/featherstone/kernels.py index dd782d6338..6132a00c45 100644 --- a/newton/_src/solvers/featherstone/kernels.py +++ b/newton/_src/solvers/featherstone/kernels.py @@ -5,7 +5,7 @@ import warp as wp -from ...math import transform_twist, velocity_at_point +from ...math import transform_twist from ...sim import BodyFlags, JointType, Model, State from ...sim.articulation import ( com_twist_to_point_velocity, @@ -1859,38 +1859,48 @@ def eval_single_articulation_fk_with_velocity_conversion( X_j = wp.transform(pos, rot) v_j = wp.spatial_vector(vel_v, vel_w) - # transform from world to parent joint anchor frame X_wpj = X_pj + v_wpj = wp.spatial_vector() if parent >= 0: X_wp = body_q[parent] X_wpj = X_wp * X_wpj + r_p = wp.transform_get_translation(X_wpj) - wp.transform_point(X_wp, body_com[parent]) + + v_wp = body_qd[parent] + w_p = wp.spatial_bottom(v_wp) + v_p = wp.spatial_top(v_wp) + wp.cross(w_p, r_p) + v_wpj = wp.spatial_vector(v_p, w_p) # transform from world to joint anchor frame at child body X_wcj = X_wpj * X_j # transform from world to child body frame X_wc = X_wcj * wp.transform_inverse(X_cj) - x_child_origin = wp.transform_get_translation(X_wc) - v_parent_origin = wp.vec3() - w_parent = wp.vec3() - if parent >= 0: - v_wp = body_qd[parent] - w_parent = wp.spatial_bottom(v_wp) - v_parent_origin = com_twist_to_point_velocity(v_wp, X_wp, body_com[parent], x_child_origin) - linear_joint_world = wp.transform_vector(X_wpj, wp.spatial_top(v_j)) - angular_joint_world = wp.transform_vector(X_wpj, wp.spatial_bottom(v_j)) + linear_vel = wp.transform_vector(X_wpj, wp.spatial_top(v_j)) + angular_vel = wp.transform_vector(X_wpj, wp.spatial_bottom(v_j)) + if type == JointType.FREE or type == JointType.DISTANCE: - v_j_world = transform_twist(X_wpj, v_j) - linear_joint_origin = velocity_at_point(v_j_world, x_child_origin) - 
angular_joint_world = wp.spatial_bottom(v_j_world) - else: - child_origin_offset_world = x_child_origin - wp.transform_get_translation(X_wcj) - linear_joint_origin = linear_joint_world + wp.cross(angular_joint_world, child_origin_offset_world) + # Public FREE/DISTANCE joint_qd stores the linear term at the child COM. + # Convert back to the child-body origin convention before composing the + # world-space twist, then convert back to COM on body_qd writeback below. + r_com = wp.quat_rotate(wp.transform_get_rotation(X_wc), body_com[child]) + linear_vel = linear_vel - wp.cross(angular_vel, r_com) - v_wc_origin = wp.spatial_vector(v_parent_origin + linear_joint_origin, w_parent + angular_joint_world) + v_wc = v_wpj + wp.spatial_vector(linear_vel, angular_vel) body_q[child] = X_wc - body_qd[child] = origin_twist_to_com_twist(v_wc_origin, X_wc, body_com[child]) + + # Velocity conversion for FREE and DISTANCE joints: + # v_wc is a spatial twist at the origin, but body_qd should store COM velocity + # Transform: v_com = v_origin + ω x r_com + if type == JointType.FREE or type == JointType.DISTANCE: + v_origin = wp.spatial_top(v_wc) + omega = wp.spatial_bottom(v_wc) + r_com = wp.quat_rotate(wp.transform_get_rotation(X_wc), body_com[child]) + v_com = v_origin + wp.cross(omega, r_com) + body_qd[child] = wp.spatial_vector(v_com, omega) + else: + body_qd[child] = v_wc @wp.kernel @@ -2016,11 +2026,11 @@ def eval_fk_with_velocity_conversion( indices: wp.array[int] | None = None, ): """ - Evaluates Featherstone FK from internal free-joint speeds and writes public body twists. + Evaluates the model's forward kinematics with velocity conversion for Featherstone-based state writeback. - This helper mirrors :func:`newton.eval_fk`, but it expects Featherstone's - internal FREE/DISTANCE ``joint_qd`` convention as input and still writes - the public COM-referenced :attr:`State.body_qd` output. 
+ FREE/DISTANCE joint_qd is interpreted in the public CoM convention and converted to + origin-referenced twists for FK propagation. The resulting body_qd writeback is then + converted back to CoM velocity so state.body_qd matches IsaacLab's expected convention. Args: model (Model): The model to evaluate. @@ -2068,8 +2078,6 @@ def eval_fk_with_velocity_conversion( ], device=model.device, ) - - def eval_fk_with_velocity_conversion_from_joint_starts( model: Model, articulation_indices: wp.array[int], From 2273167a6b1f0e4c294ce3bdff6fa1d8b489020c Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 12:42:50 -0400 Subject: [PATCH 2/7] Simplify FeatherPGS matrix-free API Remove the private solver mode and per-stage kernel selection knobs from SolverFeatherPGS and run the live step path as matrix-free only. This keeps the private API aligned with the current winner path instead of preserving branch-local ablation surface area. Add focused unit coverage for the stripped constructor signature and a minimal smoke step, and keep a human-review PR description draft in .agent/review using the published nightly gh-pages artifacts. 
--- .../execplans/fpgs-private-api-matrix-free.md | 194 +++ .../fpgs-private-api-pr-description-draft.md | 63 + newton/_src/solvers/feather_pgs/kernels.py | 1110 ++++++++--------- .../solvers/feather_pgs/solver_feather_pgs.py | 903 +++++--------- newton/tests/test_feather_pgs.py | 46 + 5 files changed, 1198 insertions(+), 1118 deletions(-) create mode 100644 .agent/execplans/fpgs-private-api-matrix-free.md create mode 100644 .agent/review/fpgs-private-api-pr-description-draft.md create mode 100644 newton/tests/test_feather_pgs.py diff --git a/.agent/execplans/fpgs-private-api-matrix-free.md b/.agent/execplans/fpgs-private-api-matrix-free.md new file mode 100644 index 0000000000..6d4065519b --- /dev/null +++ b/.agent/execplans/fpgs-private-api-matrix-free.md @@ -0,0 +1,194 @@ +# FeatherPGS Private API Matrix-Free ExecPlan + +## Goal + +Create a stripped-down private FeatherPGS API branch that keeps only the current +matrix-free winner path, removes obsolete top-level and kernel-level branching, +and leaves a reviewable PR-description draft in the workspace instead of editing +the GitHub PR description directly. + +## Branch and Workspace + +- Working branch: `dturpin/fpgs-private-api-matrix-free` +- Tracking branch: `origin/fpgs-private-api` +- Workspace root: repository root of this worktree + +## Constraints + +- Rebase onto `upstream/main` before the implementation settles. +- Do not open, merge, or edit the GitHub PR directly. +- Do not update `gh-pages`. +- Do not move the benchmark/design rationale into FeatherPGS docs for this task. +- Keep the matrix-free justification tight and write it to + `.agent/review/fpgs-private-api-pr-description-draft.md` for human review. +- Run `uvx pre-commit run -a` before milestone pushes when practical, and at + minimum before the final push. + +## Design Intent + +The private API branch should present one coherent implementation path rather +than a research branch with many retained ablations. 
The target is not “all +paths still exist but only one is recommended”; the target is “the private API +is the current matrix-free implementation, with the winning defaults baked in +and the obsolete branches removed.” + +For the long-lived `feather_pgs` research branch, prefer cleaner separation +between: + +1. shared articulation stages +2. contact-solve backend realization +3. kernel-policy / ablation plumbing + +That separation should make future winner-cherry-picks into smaller API-facing +branches cheaper and less error-prone. + +## Milestones + +### Milestone 0: Replan / baseline capture + +Deliverable: +- Tighten this ExecPlan with any details discovered while comparing + `origin/fpgs-private-api`, `origin/feather_pgs`, and `upstream/main`. + +Required work: +- Inspect the current private API branch contents and review comments. +- Identify which parts should be kept, dropped, or selectively borrowed from + `feather_pgs`. +- Record any conflict hotspots expected during rebase. + +Validation: +- No code changes required beyond plan updates. + +Checkpoint: +- Commit and push if the plan changes materially enough to warrant review; + otherwise continue into Milestone 1 within the same pass only if that is + clearly implementable. + +### Milestone 1: Rebase private API line onto upstream + +Deliverable: +- `dturpin/fpgs-private-api-matrix-free` rebased onto `upstream/main`. + +Required work: +- Rebase the branch onto `upstream/main`. +- Resolve conflicts carefully without regressing the private API line. +- Keep the branch reviewable after the rebase. + +Validation: +- Run at least a focused smoke check sufficient to verify the branch imports and + the affected solver modules still load. + +Checkpoint: +- Commit any post-rebase conflict resolutions if needed and push the branch. +- Stop after the push; do not start Milestone 2 in the same implementation pass. 
+ +### Milestone 2: Collapse the private API to matrix-free only + +Deliverable: +- Private FeatherPGS implementation supports only the matrix-free path and only + the winning kernel/default choices. + +Progress update (2026-04-13, pass 1): +- Shipped the first reviewable slice of this milestone: + `SolverFeatherPGS` no longer exposes `pgs_mode` or per-stage kernel-selection + constructor knobs, and the live `step()` path now executes the matrix-free + solve only. +- Added focused unit coverage for the stripped-down constructor surface and a + minimal matrix-free smoke step in `newton/tests/test_feather_pgs.py`. +- Remaining cleanup for later passes: remove dead dense/split-only helper code + and supporting branch-local references that are no longer reachable. + +Required work: +- Remove `dense` and `split` support from the private API implementation. +- Remove retained kernel-selection and intra-mode multi-path knobs whose only + purpose was ablation on the research branch. +- Simplify constructor/API surface accordingly. +- Simplify code structure, comments, and docstrings to describe one path. +- Preserve correctness-oriented pieces needed by the winner path. + +Validation: +- Add or update focused tests for the stripped-down behavior. +- Run the affected tests. + +Checkpoint: +- Commit the simplification and push the branch. +- Stop after the push. + +### Milestone 3: Clean supporting surfaces + +Deliverable: +- Bench/test/supporting code on the private API branch reflects the single-path + implementation instead of the old ablation-heavy surface. + +Required work: +- Update any tests, helper code, or branch-local references that still assume + mode or kernel multiplicity. +- Remove dead code left behind by Milestone 2. +- Keep changes scoped to the private API branch rather than the `feather_pgs` + docs or `gh-pages`. + +Validation: +- Run the relevant focused tests. +- Run `uvx pre-commit run -a`. + +Checkpoint: +- Commit and push the cleanup milestone. 
+- Stop after the push. + +### Milestone 4: Draft PR description for human review + +Deliverable: +- `.agent/review/fpgs-private-api-pr-description-draft.md` + +Required work: +- Draft a tight PR description for the private API branch. +- Explain the matrix-free-only decision briefly. +- Include small benchmark tables or bullets derived from the published nightly + data on `gh-pages` JSONL artifacts. +- Keep the scope to the private API decision and resulting simplification. +- Do not edit the actual GitHub PR description. + +Validation: +- Verify the draft reads as a plausible PR description and cites the right run + context and branch intent. + +Checkpoint: +- Commit and push the draft file if it is useful to keep with the branch. +- Stop after the push. + +### Milestone 5: Final validation pass + +Deliverable: +- Reviewable branch state with validation evidence and no pending plan items. + +Required work: +- Re-run the final focused test set. +- Re-run `uvx pre-commit run -a`. +- Update this ExecPlan to reflect what shipped and any intentional omissions. + +Validation: +- Record exact commands run and results. + +Checkpoint: +- Commit and push the final polishing pass. +- Stop and wait for human review. + +## Notes for the Implementor + +- Replan update (2026-04-13, pass 1): + `git rev-list --left-right --count upstream/main...HEAD` reports `0 97`, so the + branch already contains `upstream/main` and no Milestone 1 rebase work is + pending in this workspace. The next reviewable slice is Milestone 2 focused on + the private solver surface itself: remove the `pgs_mode` and kernel-selection + public knobs, hard-wire the solver to the matrix-free winner path in + `solver_feather_pgs.py`, and add focused unit coverage for the stripped-down + constructor behavior. Dense/split-only helper code that becomes unreachable in + that slice should be removed when practical; broader supporting cleanup stays + in Milestone 3. 
+- If a milestone is too large to finish cleanly, begin the pass by tightening + this ExecPlan with a short replan note and then complete one reviewable slice + of that milestone. +- Never update the real PR description in GitHub for this task. +- Never touch `gh-pages` for this task. +- Prefer borrowing ideas or code from `origin/feather_pgs` surgically instead of + trying to merge the whole research branch into the private API line. diff --git a/.agent/review/fpgs-private-api-pr-description-draft.md b/.agent/review/fpgs-private-api-pr-description-draft.md new file mode 100644 index 0000000000..58a35e57cb --- /dev/null +++ b/.agent/review/fpgs-private-api-pr-description-draft.md @@ -0,0 +1,63 @@ +# Draft PR Description + +## Summary + +Simplify the private FeatherPGS API line to the current matrix-free path. + +This branch stops presenting the private solver as a bundle of retained +ablations. Instead, it bakes in the current matrix-free contact solve path and +removes obsolete top-level mode selection and kernel-selection API knobs from +`SolverFeatherPGS`. + +## What Changed + +- Remove the private solver constructor knobs for `pgs_mode` and per-stage + kernel selection. +- Run the private FeatherPGS step path as matrix-free only. +- Add focused unit coverage for the stripped-down constructor surface and a + minimal step smoke test. + +## Why Matrix-Free Only + +The published nightly ablations already show that the matrix-free path is the + winner for the private line, while the dense/split modes mostly preserve + research history rather than current product intent. 
+ +Using the published nightly run `2026-04-01T20-49-30Z` (`summary.json`, +commit `53b3188`) from the `gh-pages` artifacts: + +| Scenario | Hardware | Baseline | Split | Matrix-free | Matrix-free vs split | +| --- | --- | ---: | ---: | ---: | ---: | +| `h1_tabletop_ablation` | RTX 5090 | 4,262 env_fps | 41,496 env_fps | 118,547 env_fps | 2.86x | +| `h1_tabletop_ablation` | RTX PRO 6000 Server | 4,942 env_fps | 47,528 env_fps | 114,531 env_fps | 2.41x | +| `h1_tabletop_ablation` | B200 | 3,765 env_fps | 60,684 env_fps | 107,262 env_fps | 1.77x | + +The same nightly run also supports keeping the current winner kernel choices on +the private line instead of exposing them as API: + +| Scenario | Hardware | FeatherPGS baseline | tiled `hinv_jt` | tiled PGS | parallel streams | +| --- | --- | ---: | ---: | ---: | ---: | +| `g1_flat_ablation` | RTX 5090 | 590,152 env_fps | 1,358,725 env_fps | 1,461,194 env_fps | 1,461,373 env_fps | +| `g1_flat_ablation` | RTX PRO 6000 Server | 504,100 env_fps | 1,182,672 env_fps | 1,276,901 env_fps | 1,277,128 env_fps | +| `g1_flat_ablation` | B200 | 760,212 env_fps | 1,389,444 env_fps | 1,551,859 env_fps | 1,550,952 env_fps | + +Interpretation: + +- `matrix-free` materially outperforms `split` on the published tabletop + ablation across all listed GPUs. +- `tiled hinv_jt`, tiled PGS, and parallel streams are already the winning + direction in the published flat-scene ablation, so the private API does not + need to keep exposing these as branch-local tuning knobs. + +## Validation + +- `uv run --extra dev -m newton.tests -k test_feather_pgs` +- `uvx pre-commit run -a` + +## Notes For Review + +- This PR intentionally updates only the private API line. It does not touch + `gh-pages` content and does not move benchmark rationale into public docs. +- The benchmark figures above are copied from published nightly artifacts + already present on `origin/gh-pages`, not from new benchmark runs in this + branch. 
diff --git a/newton/_src/solvers/feather_pgs/kernels.py b/newton/_src/solvers/feather_pgs/kernels.py index 5389545566..aab14167e9 100644 --- a/newton/_src/solvers/feather_pgs/kernels.py +++ b/newton/_src/solvers/feather_pgs/kernels.py @@ -30,10 +30,10 @@ @wp.kernel def copy_int_array_masked( - src: wp.array(dtype=int), - mask: wp.array(dtype=int), + src: wp.array[int], + mask: wp.array[int], # outputs - dst: wp.array(dtype=int), + dst: wp.array[int], ): tid = wp.tid() if mask[tid] != 0: @@ -42,10 +42,10 @@ def copy_int_array_masked( @wp.kernel def compute_spatial_inertia( - body_inertia: wp.array(dtype=wp.mat33), - body_mass: wp.array(dtype=float), + body_inertia: wp.array[wp.mat33], + body_mass: wp.array[float], # outputs - body_I_m: wp.array(dtype=wp.spatial_matrix), + body_I_m: wp.array[wp.spatial_matrix], ): tid = wp.tid() I = body_inertia[tid] @@ -64,9 +64,9 @@ def compute_spatial_inertia( @wp.kernel def compute_com_transforms( - body_com: wp.array(dtype=wp.vec3), + body_com: wp.array[wp.vec3], # outputs - body_X_com: wp.array(dtype=wp.transform), + body_X_com: wp.array[wp.transform], ): tid = wp.tid() com = body_com[tid] @@ -75,12 +75,12 @@ def compute_com_transforms( @wp.kernel def update_articulation_origins( - articulation_start: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - body_q: wp.array(dtype=wp.transform), - body_com: wp.array(dtype=wp.vec3), + articulation_start: wp.array[int], + joint_child: wp.array[int], + body_q: wp.array[wp.transform], + body_com: wp.array[wp.vec3], # outputs - articulation_origin: wp.array(dtype=wp.vec3), + articulation_origin: wp.array[wp.vec3], ): art = wp.tid() @@ -101,12 +101,12 @@ def update_articulation_origins( @wp.kernel def update_articulation_root_com_offsets( - articulation_start: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - body_q: wp.array(dtype=wp.transform), - body_com: wp.array(dtype=wp.vec3), + articulation_start: wp.array[int], + joint_child: wp.array[int], + body_q: 
wp.array[wp.transform], + body_com: wp.array[wp.vec3], # outputs - articulation_root_com_offset: wp.array(dtype=wp.vec3), + articulation_root_com_offset: wp.array[wp.vec3], ): # NOTE: This helper keeps the rotated root COM offset in world orientation. # FeatherPGS currently uses update_articulation_origins() instead, which @@ -130,11 +130,11 @@ def update_articulation_root_com_offsets( @wp.kernel def convert_root_free_qd_world_to_local( - articulation_root_is_free: wp.array(dtype=int), - articulation_root_dof_start: wp.array(dtype=int), - articulation_root_com_offset: wp.array(dtype=wp.vec3), + articulation_root_is_free: wp.array[int], + articulation_root_dof_start: wp.array[int], + articulation_root_com_offset: wp.array[wp.vec3], # in/out - qd: wp.array(dtype=float), + qd: wp.array[float], ): art = wp.tid() if articulation_root_is_free[art] == 0: @@ -156,11 +156,11 @@ def convert_root_free_qd_world_to_local( @wp.kernel def convert_root_free_qd_local_to_world( - articulation_root_is_free: wp.array(dtype=int), - articulation_root_dof_start: wp.array(dtype=int), - articulation_root_com_offset: wp.array(dtype=wp.vec3), + articulation_root_is_free: wp.array[int], + articulation_root_dof_start: wp.array[int], + articulation_root_com_offset: wp.array[wp.vec3], # in/out - qd: wp.array(dtype=float), + qd: wp.array[float], ): art = wp.tid() if articulation_root_is_free[art] == 0: @@ -258,11 +258,11 @@ def transform_spatial_inertia(t: wp.transform, I: wp.spatial_matrix): @wp.func def jcalc_transform( type: int, - joint_axis: wp.array(dtype=wp.vec3), + joint_axis: wp.array[wp.vec3], axis_start: int, lin_axis_count: int, ang_axis_count: int, - joint_q: wp.array(dtype=float), + joint_q: wp.array[float], q_start: int, ): if type == JointType.PRISMATIC: @@ -358,14 +358,14 @@ def jcalc_transform( @wp.func def jcalc_motion( type: int, - joint_axis: wp.array(dtype=wp.vec3), + joint_axis: wp.array[wp.vec3], lin_axis_count: int, ang_axis_count: int, X_sc: wp.transform, - joint_qd: 
wp.array(dtype=float), + joint_qd: wp.array[float], qd_start: int, # outputs - joint_S_s: wp.array(dtype=wp.spatial_vector), + joint_S_s: wp.array[wp.spatial_vector], ): if type == JointType.PRISMATIC: axis = joint_axis[qd_start] @@ -462,14 +462,14 @@ def jcalc_motion( @wp.func def jcalc_tau( type: int, - joint_S_s: wp.array(dtype=wp.spatial_vector), - joint_f: wp.array(dtype=float), + joint_S_s: wp.array[wp.spatial_vector], + joint_f: wp.array[float], dof_start: int, lin_axis_count: int, ang_axis_count: int, body_f_s: wp.spatial_vector, # outputs - tau: wp.array(dtype=float), + tau: wp.array[float], ): if type == JointType.BALL: # target_ke = joint_target_ke[dof_start] @@ -509,18 +509,18 @@ def jcalc_tau( def jcalc_integrate( type: int, child: int, - body_com: wp.array(dtype=wp.vec3), - joint_q: wp.array(dtype=float), - joint_qd: wp.array(dtype=float), - joint_qdd: wp.array(dtype=float), + body_com: wp.array[wp.vec3], + joint_q: wp.array[float], + joint_qd: wp.array[float], + joint_qdd: wp.array[float], coord_start: int, dof_start: int, lin_axis_count: int, ang_axis_count: int, dt: float, # outputs - joint_q_new: wp.array(dtype=float), - joint_qd_new: wp.array(dtype=float), + joint_q_new: wp.array[float], + joint_qd_new: wp.array[float], ): if type == JointType.FIXED: return @@ -634,20 +634,20 @@ def jcalc_integrate( @wp.func def compute_link_transform( i: int, - joint_type: wp.array(dtype=int), - joint_parent: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_q: wp.array(dtype=float), - joint_X_p: wp.array(dtype=wp.transform), - joint_X_c: wp.array(dtype=wp.transform), - body_X_com: wp.array(dtype=wp.transform), - joint_axis: wp.array(dtype=wp.vec3), - joint_dof_dim: wp.array(dtype=int, ndim=2), + joint_type: wp.array[int], + joint_parent: wp.array[int], + joint_child: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_q: 
wp.array[float], + joint_X_p: wp.array[wp.transform], + joint_X_c: wp.array[wp.transform], + body_X_com: wp.array[wp.transform], + joint_axis: wp.array[wp.vec3], + joint_dof_dim: wp.array2d[int], # outputs - body_q: wp.array(dtype=wp.transform), - body_q_com: wp.array(dtype=wp.transform), + body_q: wp.array[wp.transform], + body_q_com: wp.array[wp.transform], ): # parent transform parent = joint_parent[i] @@ -687,21 +687,21 @@ def compute_link_transform( @wp.kernel def eval_rigid_fk( - articulation_start: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_parent: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_q: wp.array(dtype=float), - joint_X_p: wp.array(dtype=wp.transform), - joint_X_c: wp.array(dtype=wp.transform), - body_X_com: wp.array(dtype=wp.transform), - joint_axis: wp.array(dtype=wp.vec3), - joint_dof_dim: wp.array(dtype=int, ndim=2), + articulation_start: wp.array[int], + joint_type: wp.array[int], + joint_parent: wp.array[int], + joint_child: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_q: wp.array[float], + joint_X_p: wp.array[wp.transform], + joint_X_c: wp.array[wp.transform], + body_X_com: wp.array[wp.transform], + joint_axis: wp.array[wp.vec3], + joint_dof_dim: wp.array2d[int], # outputs - body_q: wp.array(dtype=wp.transform), - body_q_com: wp.array(dtype=wp.transform), + body_q: wp.array[wp.transform], + body_q_com: wp.array[wp.transform], ): # one thread per-articulation index = wp.tid() @@ -764,26 +764,26 @@ def dense_index(stride: int, i: int, j: int): @wp.func def compute_link_velocity( i: int, - joint_type: wp.array(dtype=int), - joint_parent: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_articulation: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_qd: wp.array(dtype=float), - joint_axis: wp.array(dtype=wp.vec3), - joint_dof_dim: wp.array(dtype=int, 
ndim=2), - body_I_m: wp.array(dtype=wp.spatial_matrix), - body_q: wp.array(dtype=wp.transform), - body_q_com: wp.array(dtype=wp.transform), - joint_X_p: wp.array(dtype=wp.transform), - articulation_origin: wp.array(dtype=wp.vec3), - gravity: wp.array(dtype=wp.vec3), + joint_type: wp.array[int], + joint_parent: wp.array[int], + joint_child: wp.array[int], + joint_articulation: wp.array[int], + joint_qd_start: wp.array[int], + joint_qd: wp.array[float], + joint_axis: wp.array[wp.vec3], + joint_dof_dim: wp.array2d[int], + body_I_m: wp.array[wp.spatial_matrix], + body_q: wp.array[wp.transform], + body_q_com: wp.array[wp.transform], + joint_X_p: wp.array[wp.transform], + articulation_origin: wp.array[wp.vec3], + gravity: wp.array[wp.vec3], # outputs - joint_S_s: wp.array(dtype=wp.spatial_vector), - body_I_s: wp.array(dtype=wp.spatial_matrix), - body_v_s: wp.array(dtype=wp.spatial_vector), - body_f_s: wp.array(dtype=wp.spatial_vector), - body_a_s: wp.array(dtype=wp.spatial_vector), + joint_S_s: wp.array[wp.spatial_vector], + body_I_s: wp.array[wp.spatial_matrix], + body_v_s: wp.array[wp.spatial_vector], + body_f_s: wp.array[wp.spatial_vector], + body_a_s: wp.array[wp.spatial_vector], ): type = joint_type[i] child = joint_child[i] @@ -862,27 +862,27 @@ def compute_link_velocity( # Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1) @wp.kernel def eval_rigid_id( - articulation_start: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_parent: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_articulation: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_qd: wp.array(dtype=float), - joint_axis: wp.array(dtype=wp.vec3), - joint_dof_dim: wp.array(dtype=int, ndim=2), - body_I_m: wp.array(dtype=wp.spatial_matrix), - body_q: wp.array(dtype=wp.transform), - body_q_com: wp.array(dtype=wp.transform), - joint_X_p: wp.array(dtype=wp.transform), - articulation_origin: wp.array(dtype=wp.vec3), - gravity: 
wp.array(dtype=wp.vec3), + articulation_start: wp.array[int], + joint_type: wp.array[int], + joint_parent: wp.array[int], + joint_child: wp.array[int], + joint_articulation: wp.array[int], + joint_qd_start: wp.array[int], + joint_qd: wp.array[float], + joint_axis: wp.array[wp.vec3], + joint_dof_dim: wp.array2d[int], + body_I_m: wp.array[wp.spatial_matrix], + body_q: wp.array[wp.transform], + body_q_com: wp.array[wp.transform], + joint_X_p: wp.array[wp.transform], + articulation_origin: wp.array[wp.vec3], + gravity: wp.array[wp.vec3], # outputs - joint_S_s: wp.array(dtype=wp.spatial_vector), - body_I_s: wp.array(dtype=wp.spatial_matrix), - body_v_s: wp.array(dtype=wp.spatial_vector), - body_f_s: wp.array(dtype=wp.spatial_vector), - body_a_s: wp.array(dtype=wp.spatial_vector), + joint_S_s: wp.array[wp.spatial_vector], + body_I_s: wp.array[wp.spatial_matrix], + body_v_s: wp.array[wp.spatial_vector], + body_f_s: wp.array[wp.spatial_vector], + body_a_s: wp.array[wp.spatial_vector], ): # one thread per-articulation index = wp.tid() @@ -918,23 +918,23 @@ def eval_rigid_id( @wp.kernel def eval_rigid_tau( - articulation_start: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_parent: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_articulation: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - joint_f: wp.array(dtype=float), - joint_S_s: wp.array(dtype=wp.spatial_vector), - body_fb_s: wp.array(dtype=wp.spatial_vector), - body_f_ext: wp.array(dtype=wp.spatial_vector), - body_q: wp.array(dtype=wp.transform), - body_com: wp.array(dtype=wp.vec3), - articulation_origin: wp.array(dtype=wp.vec3), + articulation_start: wp.array[int], + joint_type: wp.array[int], + joint_parent: wp.array[int], + joint_child: wp.array[int], + joint_articulation: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + joint_f: wp.array[float], + joint_S_s: wp.array[wp.spatial_vector], + 
body_fb_s: wp.array[wp.spatial_vector], + body_f_ext: wp.array[wp.spatial_vector], + body_q: wp.array[wp.transform], + body_com: wp.array[wp.vec3], + articulation_origin: wp.array[wp.vec3], # outputs - body_ft_s: wp.array(dtype=wp.spatial_vector), - tau: wp.array(dtype=float), + body_ft_s: wp.array[wp.spatial_vector], + tau: wp.array[float], ): # one thread per-articulation index = wp.tid() @@ -997,12 +997,12 @@ def eval_rigid_tau( @wp.kernel def eval_rigid_mass( - articulation_start: wp.array(dtype=int), - articulation_M_start: wp.array(dtype=int), - mass_update_mask: wp.array(dtype=int), - body_I_s: wp.array(dtype=wp.spatial_matrix), + articulation_start: wp.array[int], + articulation_M_start: wp.array[int], + mass_update_mask: wp.array[int], + body_I_s: wp.array[wp.spatial_matrix], # outputs - M_blocks: wp.array(dtype=float), + M_blocks: wp.array[float], ): # one thread per-articulation index = wp.tid() @@ -1026,12 +1026,12 @@ def eval_rigid_mass( @wp.kernel def compute_composite_inertia( - articulation_start: wp.array(dtype=int), - mass_update_mask: wp.array(dtype=int), - joint_ancestor: wp.array(dtype=int), - body_I_s: wp.array(dtype=wp.spatial_matrix), + articulation_start: wp.array[int], + mass_update_mask: wp.array[int], + joint_ancestor: wp.array[int], + body_I_s: wp.array[wp.spatial_matrix], # outputs - body_I_c: wp.array(dtype=wp.spatial_matrix), + body_I_c: wp.array[wp.spatial_matrix], ): art_idx = wp.tid() @@ -1057,12 +1057,12 @@ def compute_composite_inertia( @wp.func def dense_cholesky( n: int, - A: wp.array(dtype=float), - R: wp.array(dtype=float), + A: wp.array[float], + R: wp.array[float], A_start: int, R_start: int, # outputs - L: wp.array(dtype=float), + L: wp.array[float], ): # compute the Cholesky factorization of A = L L^T with diagonal regularization R for j in range(n): @@ -1088,13 +1088,13 @@ def dense_cholesky( @wp.kernel def cholesky_loop( - H_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] - R_group: wp.array2d(dtype=float), 
# [n_arts, n_dofs] - group_to_art: wp.array(dtype=int), - mass_update_mask: wp.array(dtype=int), + H_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] + R_group: wp.array2d[float], # [n_arts, n_dofs] + group_to_art: wp.array[int], + mass_update_mask: wp.array[int], n_dofs: int, # output - L_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] + L_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] ): """Non-tiled Cholesky for grouped articulation storage. @@ -1135,10 +1135,10 @@ def dense_subs( n: int, L_start: int, b_start: int, - L: wp.array(dtype=float), - b: wp.array(dtype=float), + L: wp.array[float], + b: wp.array[float], # outputs - x: wp.array(dtype=float), + x: wp.array[float], ): # Solves (L L^T) x = b for x given the Cholesky factor L # forward substitution solves the lower triangular system L y = b for y @@ -1165,12 +1165,12 @@ def dense_solve( n: int, L_start: int, b_start: int, - A: wp.array(dtype=float), - L: wp.array(dtype=float), - b: wp.array(dtype=float), + A: wp.array[float], + L: wp.array[float], + b: wp.array[float], # outputs - x: wp.array(dtype=float), - tmp: wp.array(dtype=float), + x: wp.array[float], + tmp: wp.array[float], ): # helper function to include tmp argument for backward pass dense_subs(n, L_start, b_start, L, b, x) @@ -1178,19 +1178,19 @@ def dense_solve( @wp.kernel def integrate_generalized_joints( - joint_type: wp.array(dtype=int), - joint_child: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - body_com: wp.array(dtype=wp.vec3), - joint_q: wp.array(dtype=float), - joint_qd: wp.array(dtype=float), - joint_qdd: wp.array(dtype=float), + joint_type: wp.array[int], + joint_child: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + body_com: wp.array[wp.vec3], + joint_q: wp.array[float], + joint_qd: wp.array[float], + joint_qdd: wp.array[float], dt: float, # 
outputs - joint_q_new: wp.array(dtype=float), - joint_qd_new: wp.array(dtype=float), + joint_q_new: wp.array[float], + joint_qd_new: wp.array[float], ): # one thread per-articulation index = wp.tid() @@ -1221,11 +1221,11 @@ def integrate_generalized_joints( @wp.kernel def compute_velocity_predictor( - joint_qd: wp.array(dtype=float), - joint_qdd: wp.array(dtype=float), + joint_qd: wp.array[float], + joint_qdd: wp.array[float], dt: float, # outputs - v_hat: wp.array(dtype=float), + v_hat: wp.array[float], ): tid = wp.tid() v_hat[tid] = joint_qd[tid] + joint_qdd[tid] * dt @@ -1233,11 +1233,11 @@ def compute_velocity_predictor( @wp.kernel def update_qdd_from_velocity( - joint_qd: wp.array(dtype=float), - v_new: wp.array(dtype=float), + joint_qd: wp.array[float], + v_new: wp.array[float], inv_dt: float, # outputs - joint_qdd: wp.array(dtype=float), + joint_qdd: wp.array[float], ): tid = wp.tid() joint_qdd[tid] = (v_new[tid] - joint_qd[tid]) * inv_dt @@ -1256,17 +1256,17 @@ def contact_tangent_basis(n: wp.vec3): @wp.kernel def compute_contact_linear_force_from_impulses( - contact_count: wp.array(dtype=wp.int32), - contact_normal: wp.array(dtype=wp.vec3), - contact_world: wp.array(dtype=wp.int32), - contact_slot: wp.array(dtype=wp.int32), - contact_path: wp.array(dtype=wp.int32), - world_impulses: wp.array2d(dtype=wp.float32), - mf_impulses: wp.array2d(dtype=wp.float32), + contact_count: wp.array[wp.int32], + contact_normal: wp.array[wp.vec3], + contact_world: wp.array[wp.int32], + contact_slot: wp.array[wp.int32], + contact_path: wp.array[wp.int32], + world_impulses: wp.array2d[wp.float32], + mf_impulses: wp.array2d[wp.float32], enable_friction: int, inv_dt: float, # outputs - rigid_contact_force: wp.array(dtype=wp.vec3), + rigid_contact_force: wp.array[wp.vec3], ): """Convert solved FeatherPGS contact impulses into world-frame forces.""" c = wp.tid() @@ -1310,10 +1310,10 @@ def compute_contact_linear_force_from_impulses( @wp.kernel def 
pack_contact_linear_force_as_spatial( - contact_count: wp.array(dtype=wp.int32), - rigid_contact_force: wp.array(dtype=wp.vec3), + contact_count: wp.array[wp.int32], + rigid_contact_force: wp.array[wp.vec3], # outputs - contact_force: wp.array(dtype=wp.spatial_vector), + contact_force: wp.array[wp.spatial_vector], ): """Pack linear contact forces into Newton's spatial-force contact buffer.""" c = wp.tid() @@ -1333,16 +1333,16 @@ def accumulate_contact_jacobian_matrix_free( weight: float, point_world: wp.vec3, n_vec: wp.vec3, - body_to_joint: wp.array(dtype=int), - body_to_articulation: wp.array(dtype=int), - joint_ancestor: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_S_s: wp.array(dtype=wp.spatial_vector), - articulation_origin: wp.array(dtype=wp.vec3), + body_to_joint: wp.array[int], + body_to_articulation: wp.array[int], + joint_ancestor: wp.array[int], + joint_qd_start: wp.array[int], + joint_S_s: wp.array[wp.spatial_vector], + articulation_origin: wp.array[wp.vec3], articulation_dof_start: int, # Outputs row_base_index: int, - Jc_out: wp.array(dtype=float), + Jc_out: wp.array[float], ): if body_index < 0: return @@ -1377,42 +1377,42 @@ def accumulate_contact_jacobian_matrix_free( @wp.kernel def build_contact_rows_normal( - contact_count: wp.array(dtype=int), - contact_point0: wp.array(dtype=wp.vec3), - contact_point1: wp.array(dtype=wp.vec3), - contact_normal: wp.array(dtype=wp.vec3), - contact_shape0: wp.array(dtype=int), - contact_shape1: wp.array(dtype=int), - contact_thickness0: wp.array(dtype=float), - contact_thickness1: wp.array(dtype=float), - shape_body: wp.array(dtype=int), - body_q: wp.array(dtype=wp.transform), - shape_transform: wp.array(dtype=wp.transform), - shape_material_mu: wp.array(dtype=float), - articulation_start: wp.array(dtype=int), - articulation_H_rows: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), - body_to_joint: wp.array(dtype=int), - body_to_articulation: wp.array(dtype=int), - 
joint_ancestor: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_S_s: wp.array(dtype=wp.spatial_vector), - articulation_origin: wp.array(dtype=wp.vec3), + contact_count: wp.array[int], + contact_point0: wp.array[wp.vec3], + contact_point1: wp.array[wp.vec3], + contact_normal: wp.array[wp.vec3], + contact_shape0: wp.array[int], + contact_shape1: wp.array[int], + contact_thickness0: wp.array[float], + contact_thickness1: wp.array[float], + shape_body: wp.array[int], + body_q: wp.array[wp.transform], + shape_transform: wp.array[wp.transform], + shape_material_mu: wp.array[float], + articulation_start: wp.array[int], + articulation_H_rows: wp.array[int], + articulation_dof_start: wp.array[int], + body_to_joint: wp.array[int], + body_to_articulation: wp.array[int], + joint_ancestor: wp.array[int], + joint_qd_start: wp.array[int], + joint_S_s: wp.array[wp.spatial_vector], + articulation_origin: wp.array[wp.vec3], max_constraints: int, max_dofs: int, contact_beta: float, contact_cfm: float, enable_friction: int, # Outputs - constraint_counts: wp.array(dtype=int), - Jc_out: wp.array(dtype=float), - phi_out: wp.array(dtype=float), - row_beta: wp.array(dtype=float), - row_cfm: wp.array(dtype=float), - row_types: wp.array(dtype=int), - target_velocity: wp.array(dtype=float), - row_parent: wp.array(dtype=int), - row_mu: wp.array(dtype=float), + constraint_counts: wp.array[int], + Jc_out: wp.array[float], + phi_out: wp.array[float], + row_beta: wp.array[float], + row_cfm: wp.array[float], + row_types: wp.array[int], + target_velocity: wp.array[float], + row_parent: wp.array[int], + row_mu: wp.array[float], ): tid = wp.tid() total_contacts = contact_count[0] @@ -1651,27 +1651,27 @@ def build_contact_rows_normal( @wp.kernel def build_augmented_joint_rows( - articulation_start: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), - articulation_H_rows: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - 
joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - joint_target_ke: wp.array(dtype=float), - joint_target_kd: wp.array(dtype=float), - joint_q: wp.array(dtype=float), - joint_qd: wp.array(dtype=float), - joint_target_pos: wp.array(dtype=float), - joint_target_vel: wp.array(dtype=float), + articulation_start: wp.array[int], + articulation_dof_start: wp.array[int], + articulation_H_rows: wp.array[int], + joint_type: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + joint_target_ke: wp.array[float], + joint_target_kd: wp.array[float], + joint_q: wp.array[float], + joint_qd: wp.array[float], + joint_target_pos: wp.array[float], + joint_target_vel: wp.array[float], max_dofs: int, dt: float, # outputs - row_counts: wp.array(dtype=int), - row_dof_index: wp.array(dtype=int), - row_K: wp.array(dtype=float), - row_u0: wp.array(dtype=float), - limit_counts: wp.array(dtype=int), + row_counts: wp.array[int], + row_dof_index: wp.array[int], + row_K: wp.array[float], + row_u0: wp.array[float], + limit_counts: wp.array[int], ): articulation = wp.tid() if max_dofs == 0: @@ -1738,10 +1738,10 @@ def build_augmented_joint_rows( @wp.kernel def detect_limit_count_changes( - limit_counts: wp.array(dtype=int), - prev_limit_counts: wp.array(dtype=int), + limit_counts: wp.array[int], + prev_limit_counts: wp.array[int], # outputs - limit_change_mask: wp.array(dtype=int), + limit_change_mask: wp.array[int], ): tid = wp.tid() change = 1 if limit_counts[tid] != prev_limit_counts[tid] else 0 @@ -1751,9 +1751,9 @@ def detect_limit_count_changes( @wp.kernel def build_mass_update_mask( global_flag: int, - limit_change_mask: wp.array(dtype=int), + limit_change_mask: wp.array[int], # outputs - mass_update_mask: wp.array(dtype=int), + mass_update_mask: wp.array[int], ): tid = wp.tid() flag = 1 if global_flag != 0 else 0 @@ -1769,22 +1769,22 @@ def build_mass_update_mask( @wp.kernel def 
allocate_joint_limit_slots( - articulation_start: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), - articulation_H_rows: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - joint_limit_lower: wp.array(dtype=float), - joint_limit_upper: wp.array(dtype=float), - joint_q: wp.array(dtype=float), - art_to_world: wp.array(dtype=int), + articulation_start: wp.array[int], + articulation_dof_start: wp.array[int], + articulation_H_rows: wp.array[int], + joint_type: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + joint_limit_lower: wp.array[float], + joint_limit_upper: wp.array[float], + joint_q: wp.array[float], + art_to_world: wp.array[int], max_constraints: int, # outputs - limit_slot: wp.array(dtype=int), - limit_sign: wp.array(dtype=float), - world_slot_counter: wp.array(dtype=int), + limit_slot: wp.array[int], + limit_sign: wp.array[float], + world_slot_counter: wp.array[int], ): """Allocate constraint slots for violated joint limits. 
@@ -1842,30 +1842,30 @@ def allocate_joint_limit_slots( @wp.kernel def populate_joint_limit_J_for_size( - articulation_start: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), - joint_type: wp.array(dtype=int), - joint_q_start: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - joint_limit_lower: wp.array(dtype=float), - joint_limit_upper: wp.array(dtype=float), - joint_q: wp.array(dtype=float), - art_to_world: wp.array(dtype=int), - limit_slot: wp.array(dtype=int), - limit_sign: wp.array(dtype=float), - group_to_art: wp.array(dtype=int), + articulation_start: wp.array[int], + articulation_dof_start: wp.array[int], + joint_type: wp.array[int], + joint_q_start: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + joint_limit_lower: wp.array[float], + joint_limit_upper: wp.array[float], + joint_q: wp.array[float], + art_to_world: wp.array[int], + limit_slot: wp.array[int], + limit_sign: wp.array[float], + group_to_art: wp.array[int], pgs_beta: float, pgs_cfm: float, # outputs - J_group: wp.array3d(dtype=float), - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), - world_row_beta: wp.array2d(dtype=float), - world_row_cfm: wp.array2d(dtype=float), - world_phi: wp.array2d(dtype=float), - world_target_velocity: wp.array2d(dtype=float), + J_group: wp.array3d[float], + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], + world_row_beta: wp.array2d[float], + world_row_cfm: wp.array2d[float], + world_phi: wp.array2d[float], + world_target_velocity: wp.array2d[float], ): """Populate Jacobian and metadata for joint limit constraints. 
@@ -1937,32 +1937,32 @@ def populate_joint_limit_J_for_size( @wp.kernel def allocate_world_contact_slots( - contact_count: wp.array(dtype=int), - contact_shape0: wp.array(dtype=int), - contact_shape1: wp.array(dtype=int), - contact_point0: wp.array(dtype=wp.vec3), - contact_point1: wp.array(dtype=wp.vec3), - contact_normal: wp.array(dtype=wp.vec3), - contact_thickness0: wp.array(dtype=float), - contact_thickness1: wp.array(dtype=float), - body_q: wp.array(dtype=wp.transform), - shape_transform: wp.array(dtype=wp.transform), - shape_body: wp.array(dtype=int), - body_to_articulation: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - is_free_rigid: wp.array(dtype=int), + contact_count: wp.array[int], + contact_shape0: wp.array[int], + contact_shape1: wp.array[int], + contact_point0: wp.array[wp.vec3], + contact_point1: wp.array[wp.vec3], + contact_normal: wp.array[wp.vec3], + contact_thickness0: wp.array[float], + contact_thickness1: wp.array[float], + body_q: wp.array[wp.transform], + shape_transform: wp.array[wp.transform], + shape_body: wp.array[int], + body_to_articulation: wp.array[int], + art_to_world: wp.array[int], + is_free_rigid: wp.array[int], has_free_rigid: int, max_constraints: int, mf_max_constraints: int, enable_friction: int, # outputs - contact_world: wp.array(dtype=int), - contact_slot: wp.array(dtype=int), - contact_art_a: wp.array(dtype=int), - contact_art_b: wp.array(dtype=int), - world_slot_counter: wp.array(dtype=int), - contact_path: wp.array(dtype=int), - mf_slot_counter: wp.array(dtype=int), + contact_world: wp.array[int], + contact_slot: wp.array[int], + contact_art_a: wp.array[int], + contact_art_b: wp.array[int], + world_slot_counter: wp.array[int], + contact_path: wp.array[int], + mf_slot_counter: wp.array[int], ): """ Phase 1 of multi-articulation contact building. 
@@ -2099,15 +2099,15 @@ def accumulate_jacobian_row_world( point_world: wp.vec3, origin: wp.vec3, direction: wp.vec3, - body_to_joint: wp.array(dtype=int), - joint_ancestor: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_S_s: wp.array(dtype=wp.spatial_vector), + body_to_joint: wp.array[int], + joint_ancestor: wp.array[int], + joint_qd_start: wp.array[int], + joint_S_s: wp.array[wp.spatial_vector], art_dof_start: int, n_dofs: int, group_idx: int, row: int, - J_group: wp.array3d(dtype=float), + J_group: wp.array3d[float], ): """Accumulate Jacobian contributions by walking up the kinematic tree.""" if body_index < 0: @@ -2138,44 +2138,44 @@ def accumulate_jacobian_row_world( @wp.kernel def populate_world_J_for_size( - contact_count: wp.array(dtype=int), - contact_point0: wp.array(dtype=wp.vec3), - contact_point1: wp.array(dtype=wp.vec3), - contact_normal: wp.array(dtype=wp.vec3), - contact_shape0: wp.array(dtype=int), - contact_shape1: wp.array(dtype=int), - contact_thickness0: wp.array(dtype=float), - contact_thickness1: wp.array(dtype=float), - contact_world: wp.array(dtype=int), - contact_slot: wp.array(dtype=int), - contact_art_a: wp.array(dtype=int), - contact_art_b: wp.array(dtype=int), - contact_path: wp.array(dtype=int), + contact_count: wp.array[int], + contact_point0: wp.array[wp.vec3], + contact_point1: wp.array[wp.vec3], + contact_normal: wp.array[wp.vec3], + contact_shape0: wp.array[int], + contact_shape1: wp.array[int], + contact_thickness0: wp.array[float], + contact_thickness1: wp.array[float], + contact_world: wp.array[int], + contact_slot: wp.array[int], + contact_art_a: wp.array[int], + contact_art_b: wp.array[int], + contact_path: wp.array[int], target_size: int, - art_size: wp.array(dtype=int), - art_group_idx: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), - articulation_origin: wp.array(dtype=wp.vec3), - body_to_joint: wp.array(dtype=int), - joint_ancestor: wp.array(dtype=int), - joint_qd_start: 
wp.array(dtype=int), - joint_S_s: wp.array(dtype=wp.spatial_vector), - shape_body: wp.array(dtype=int), - body_q: wp.array(dtype=wp.transform), - shape_transform: wp.array(dtype=wp.transform), - shape_material_mu: wp.array(dtype=float), + art_size: wp.array[int], + art_group_idx: wp.array[int], + art_dof_start: wp.array[int], + articulation_origin: wp.array[wp.vec3], + body_to_joint: wp.array[int], + joint_ancestor: wp.array[int], + joint_qd_start: wp.array[int], + joint_S_s: wp.array[wp.spatial_vector], + shape_body: wp.array[int], + body_q: wp.array[wp.transform], + shape_transform: wp.array[wp.transform], + shape_material_mu: wp.array[float], enable_friction: int, pgs_beta: float, pgs_cfm: float, # outputs - J_group: wp.array3d(dtype=float), - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), - world_row_beta: wp.array2d(dtype=float), - world_row_cfm: wp.array2d(dtype=float), - world_phi: wp.array2d(dtype=float), - world_target_velocity: wp.array2d(dtype=float), + J_group: wp.array3d[float], + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], + world_row_beta: wp.array2d[float], + world_row_cfm: wp.array2d[float], + world_phi: wp.array2d[float], + world_target_velocity: wp.array2d[float], ): """ Phase 2 of multi-articulation contact building (per size group). @@ -2433,11 +2433,11 @@ def populate_world_J_for_size( @wp.kernel def finalize_world_constraint_counts( - world_slot_counter: wp.array(dtype=int), + world_slot_counter: wp.array[int], max_constraints: int, slots_per_contact: int, # outputs - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], ): """Copy and clamp the slot counter to constraint counts. 
@@ -2460,7 +2460,7 @@ def finalize_world_constraint_counts( @wp.kernel def clamp_contact_counts( - constraint_counts: wp.array(dtype=int), + constraint_counts: wp.array[int], max_constraints: int, ): articulation = wp.tid() @@ -2471,16 +2471,16 @@ def clamp_contact_counts( @wp.kernel def apply_augmented_mass_diagonal( - articulation_H_start: wp.array(dtype=int), - articulation_H_rows: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), + articulation_H_start: wp.array[int], + articulation_H_rows: wp.array[int], + articulation_dof_start: wp.array[int], max_dofs: int, - mass_update_mask: wp.array(dtype=int), - row_counts: wp.array(dtype=int), - row_dof_index: wp.array(dtype=int), - row_K: wp.array(dtype=float), + mass_update_mask: wp.array[int], + row_counts: wp.array[int], + row_dof_index: wp.array[int], + row_K: wp.array[float], # outputs - H: wp.array(dtype=float), + H: wp.array[float], ): articulation = wp.tid() if mass_update_mask[articulation] == 0: @@ -2514,16 +2514,16 @@ def apply_augmented_mass_diagonal( @wp.kernel def apply_augmented_mass_diagonal_grouped( - group_to_art: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), + group_to_art: wp.array[int], + articulation_dof_start: wp.array[int], n_dofs: int, max_dofs: int, - mass_update_mask: wp.array(dtype=int), - row_counts: wp.array(dtype=int), - row_dof_index: wp.array(dtype=int), - row_K: wp.array(dtype=float), + mass_update_mask: wp.array[int], + row_counts: wp.array[int], + row_dof_index: wp.array[int], + row_K: wp.array[float], # outputs - H_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] + H_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] ): """Apply augmented mass diagonal for grouped H storage.""" idx = wp.tid() @@ -2555,11 +2555,11 @@ def apply_augmented_mass_diagonal_grouped( @wp.kernel def apply_augmented_joint_tau( max_dofs: int, - row_counts: wp.array(dtype=int), - row_dof_index: wp.array(dtype=int), - row_u0: wp.array(dtype=float), + row_counts: 
wp.array[int], + row_dof_index: wp.array[int], + row_u0: wp.array[float], # outputs - joint_tau: wp.array(dtype=float), + joint_tau: wp.array[float], ): articulation = wp.tid() if max_dofs == 0: @@ -2581,11 +2581,11 @@ def apply_augmented_joint_tau( @wp.kernel def prepare_impulses( - constraint_counts: wp.array(dtype=int), + constraint_counts: wp.array[int], max_constraints: int, warmstart: int, # outputs - impulses: wp.array(dtype=float), + impulses: wp.array[float], ): articulation = wp.tid() m = constraint_counts[articulation] @@ -2598,8 +2598,8 @@ def prepare_impulses( @wp.kernel def clamp_joint_tau( - joint_tau: wp.array(dtype=float), - joint_effort_limit: wp.array(dtype=float), + joint_tau: wp.array[float], + joint_effort_limit: wp.array[float], ): tid = wp.tid() @@ -2638,12 +2638,12 @@ def clamp_joint_tau( @wp.kernel def update_body_qd_from_featherstone( - body_v_s: wp.array(dtype=wp.spatial_vector), - body_q: wp.array(dtype=wp.transform), - body_com: wp.array(dtype=wp.vec3), - body_to_articulation: wp.array(dtype=int), - articulation_origin: wp.array(dtype=wp.vec3), - body_qd_out: wp.array(dtype=wp.spatial_vector), + body_v_s: wp.array[wp.spatial_vector], + body_q: wp.array[wp.transform], + body_com: wp.array[wp.vec3], + body_to_articulation: wp.array[int], + articulation_origin: wp.array[wp.vec3], + body_qd_out: wp.array[wp.spatial_vector], ): tid = wp.tid() @@ -2672,15 +2672,15 @@ def update_body_qd_from_featherstone( @wp.kernel def compute_world_contact_bias( - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], max_constraints: int, - world_phi: wp.array2d(dtype=float), - world_row_beta: wp.array2d(dtype=float), - world_row_type: wp.array2d(dtype=int), - world_target_velocity: wp.array2d(dtype=float), + world_phi: wp.array2d[float], + world_row_beta: wp.array2d[float], + world_row_type: wp.array2d[int], + world_target_velocity: wp.array2d[float], dt: float, # outputs - world_rhs: wp.array2d(dtype=float), + world_rhs: 
wp.array2d[float], ): """Compute the RHS bias term for world-level PGS solve. @@ -2714,18 +2714,18 @@ def compute_world_contact_bias( @wp.kernel def rhs_accum_world_par_art( - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], max_constraints: int, - art_to_world: wp.array(dtype=int), - art_size: wp.array(dtype=int), - art_group_idx: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), - v_hat: wp.array(dtype=float), - group_to_art: wp.array(dtype=int), - J_group: wp.array3d(dtype=float), + art_to_world: wp.array[int], + art_size: wp.array[int], + art_group_idx: wp.array[int], + art_dof_start: wp.array[int], + v_hat: wp.array[float], + group_to_art: wp.array[int], + J_group: wp.array3d[float], n_dofs: int, # outputs - world_rhs: wp.array2d(dtype=float), + world_rhs: wp.array2d[float], ): """ Accumulate J*v_hat into world RHS for a single size group. @@ -2752,11 +2752,11 @@ def rhs_accum_world_par_art( @wp.kernel def prepare_world_impulses( - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], max_constraints: int, warmstart: int, # in/out - world_impulses: wp.array2d(dtype=float), + world_impulses: wp.array2d[float], ): """Initialize world impulses (zero or warmstart).""" world = wp.tid() @@ -2774,16 +2774,16 @@ def prepare_world_impulses( @wp.kernel def diag_from_JY_par_art( - J_group: wp.array3d(dtype=float), # [n_arts_of_size, max_constraints, n_dofs] - Y_group: wp.array3d(dtype=float), # [n_arts_of_size, max_constraints, n_dofs] - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), + J_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] + Y_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], n_dofs: int, max_constraints: int, n_arts: int, # output - world_diag: wp.array2d(dtype=float), 
+ world_diag: wp.array2d[float], ): """Compute diagonal of Delassus from J and Y without assembling the full matrix. @@ -2807,19 +2807,19 @@ def diag_from_JY_par_art( @wp.kernel def gather_JY_to_world( - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), - world_dof_start: wp.array(dtype=int), - J_group: wp.array3d(dtype=float), - Y_group: wp.array3d(dtype=float), + group_to_art: wp.array[int], + art_to_world: wp.array[int], + art_dof_start: wp.array[int], + world_constraint_count: wp.array[int], + world_dof_start: wp.array[int], + J_group: wp.array3d[float], + Y_group: wp.array3d[float], n_dofs: int, max_constraints: int, n_arts: int, # outputs - J_world: wp.array3d(dtype=float), - Y_world: wp.array3d(dtype=float), + J_world: wp.array3d[float], + Y_world: wp.array3d[float], ): """Gather per-size-group J/Y into world-indexed arrays. @@ -2851,34 +2851,34 @@ def gather_JY_to_world( @wp.kernel def build_mf_contact_rows( - contact_count: wp.array(dtype=int), - contact_point0: wp.array(dtype=wp.vec3), - contact_point1: wp.array(dtype=wp.vec3), - contact_normal: wp.array(dtype=wp.vec3), - contact_shape0: wp.array(dtype=int), - contact_shape1: wp.array(dtype=int), - contact_thickness0: wp.array(dtype=float), - contact_thickness1: wp.array(dtype=float), - contact_world: wp.array(dtype=int), - contact_slot: wp.array(dtype=int), - contact_path: wp.array(dtype=int), - contact_art_a: wp.array(dtype=int), - contact_art_b: wp.array(dtype=int), - articulation_origin: wp.array(dtype=wp.vec3), - shape_body: wp.array(dtype=int), - body_q: wp.array(dtype=wp.transform), - shape_material_mu: wp.array(dtype=float), + contact_count: wp.array[int], + contact_point0: wp.array[wp.vec3], + contact_point1: wp.array[wp.vec3], + contact_normal: wp.array[wp.vec3], + contact_shape0: wp.array[int], + contact_shape1: wp.array[int], + contact_thickness0: wp.array[float], + contact_thickness1: 
wp.array[float], + contact_world: wp.array[int], + contact_slot: wp.array[int], + contact_path: wp.array[int], + contact_art_a: wp.array[int], + contact_art_b: wp.array[int], + articulation_origin: wp.array[wp.vec3], + shape_body: wp.array[int], + body_q: wp.array[wp.transform], + shape_material_mu: wp.array[float], enable_friction: int, pgs_beta: float, # outputs - mf_body_a: wp.array2d(dtype=int), - mf_body_b: wp.array2d(dtype=int), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_row_mu: wp.array2d(dtype=float), - mf_phi: wp.array2d(dtype=float), + mf_body_a: wp.array2d[int], + mf_body_b: wp.array2d[int], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_row_mu: wp.array2d[float], + mf_phi: wp.array2d[float], ): """Build MF constraint rows for contacts between free rigid bodies / ground. @@ -3117,11 +3117,11 @@ def spatial_matrix_block_inverse(M: wp.spatial_matrix): @wp.kernel def compute_mf_body_Hinv( - body_I_s: wp.array(dtype=wp.spatial_matrix), - is_free_rigid: wp.array(dtype=int), - body_to_articulation: wp.array(dtype=int), + body_I_s: wp.array[wp.spatial_matrix], + is_free_rigid: wp.array[int], + body_to_articulation: wp.array[int], # outputs - mf_body_Hinv: wp.array(dtype=wp.spatial_matrix), + mf_body_Hinv: wp.array[wp.spatial_matrix], ): """Compute H^-1 = inverse(body_I_s) for free rigid bodies. 
@@ -3140,23 +3140,23 @@ def compute_mf_body_Hinv( @wp.kernel def compute_mf_effective_mass_and_rhs( - mf_constraint_count: wp.array(dtype=int), - mf_body_a: wp.array2d(dtype=int), - mf_body_b: wp.array2d(dtype=int), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_body_Hinv: wp.array(dtype=wp.spatial_matrix), - mf_phi: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), + mf_constraint_count: wp.array[int], + mf_body_a: wp.array2d[int], + mf_body_b: wp.array2d[int], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_body_Hinv: wp.array[wp.spatial_matrix], + mf_phi: wp.array2d[float], + mf_row_type: wp.array2d[int], pgs_cfm: float, pgs_beta: float, dt: float, mf_max_constraints: int, # outputs - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_rhs: wp.array2d(dtype=float), + mf_eff_mass_inv: wp.array2d[float], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: wp.array3d[float], + mf_rhs: wp.array2d[float], ): """Compute effective mass diagonal, H^-1*J^T, and RHS bias for MF constraints. 
@@ -3239,25 +3239,25 @@ def compute_mf_effective_mass_and_rhs( @wp.kernel def pgs_solve_mf_loop( - mf_constraint_count: wp.array(dtype=int), - mf_body_a: wp.array2d(dtype=int), - mf_body_b: wp.array2d(dtype=int), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_rhs: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_row_mu: wp.array2d(dtype=float), - body_to_articulation: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), + mf_constraint_count: wp.array[int], + mf_body_a: wp.array2d[int], + mf_body_b: wp.array2d[int], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: wp.array3d[float], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_eff_mass_inv: wp.array2d[float], + mf_rhs: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_row_mu: wp.array2d[float], + body_to_articulation: wp.array[int], + art_dof_start: wp.array[int], iterations: int, omega: float, # in/out - mf_impulses: wp.array2d(dtype=float), - v_out: wp.array(dtype=float), + mf_impulses: wp.array2d[float], + v_out: wp.array[float], ): """Matrix-free PGS solver for free rigid body contacts. @@ -3359,11 +3359,11 @@ def pgs_solve_mf_loop( @wp.kernel def finalize_mf_constraint_counts( - mf_slot_counter: wp.array(dtype=int), + mf_slot_counter: wp.array[int], mf_max_constraints: int, slots_per_contact: int, # outputs - mf_constraint_count: wp.array(dtype=int), + mf_constraint_count: wp.array[int], ): """Clamp MF slot counter to max and store as constraint count. 
@@ -3380,18 +3380,18 @@ def finalize_mf_constraint_counts( @wp.kernel def build_mf_body_map( - mf_constraint_count: wp.array(dtype=int), - mf_body_a: wp.array2d(dtype=int), - mf_body_b: wp.array2d(dtype=int), - body_to_articulation: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), + mf_constraint_count: wp.array[int], + mf_body_a: wp.array2d[int], + mf_body_b: wp.array2d[int], + body_to_articulation: wp.array[int], + art_dof_start: wp.array[int], max_mf_bodies: int, # outputs - mf_body_list: wp.array2d(dtype=int), - mf_body_dof_start: wp.array2d(dtype=int), - mf_body_count: wp.array(dtype=int), - mf_local_body_a: wp.array2d(dtype=int), - mf_local_body_b: wp.array2d(dtype=int), + mf_body_list: wp.array2d[int], + mf_body_dof_start: wp.array2d[int], + mf_body_count: wp.array[int], + mf_local_body_a: wp.array2d[int], + mf_local_body_b: wp.array2d[int], ): """Build per-world compact body table and local body index mapping. @@ -3448,16 +3448,16 @@ def build_mf_body_map( @wp.kernel def compute_mf_world_dof_offsets( - mf_constraint_count: wp.array(dtype=int), - mf_body_a: wp.array2d(dtype=int), - mf_body_b: wp.array2d(dtype=int), - body_to_articulation: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), - world_dof_start: wp.array(dtype=int), + mf_constraint_count: wp.array[int], + mf_body_a: wp.array2d[int], + mf_body_b: wp.array2d[int], + body_to_articulation: wp.array[int], + art_dof_start: wp.array[int], + world_dof_start: wp.array[int], mf_max_constraints: int, # outputs - mf_dof_a: wp.array2d(dtype=int), - mf_dof_b: wp.array2d(dtype=int), + mf_dof_a: wp.array2d[int], + mf_dof_b: wp.array2d[int], ): """Compute world-relative DOF offsets for each MF contact body. 
@@ -3485,17 +3485,17 @@ def compute_mf_world_dof_offsets( @wp.kernel def pgs_solve_loop( - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], max_constraints: int, - world_diag: wp.array2d(dtype=float), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_diag: wp.array2d[float], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], ): """ World-level Projected Gauss-Seidel solver. @@ -3565,18 +3565,18 @@ def pgs_solve_loop( @wp.kernel def apply_impulses_world_par_dof( - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - art_dof_start: wp.array(dtype=int), + group_to_art: wp.array[int], + art_to_world: wp.array[int], + art_dof_start: wp.array[int], n_dofs: int, n_arts: int, - world_constraint_count: wp.array(dtype=int), + world_constraint_count: wp.array[int], max_constraints: int, - Y_group: wp.array3d(dtype=float), - world_impulses: wp.array2d(dtype=float), - v_hat: wp.array(dtype=float), + Y_group: wp.array3d[float], + world_impulses: wp.array2d[float], + v_hat: wp.array[float], # outputs - v_out: wp.array(dtype=float), + v_out: wp.array[float], ): """ Accumulate velocity changes from world impulses for a single size group. 
@@ -3609,10 +3609,10 @@ def apply_impulses_world_par_dof( @wp.kernel def finalize_world_diag_cfm( - world_constraint_count: wp.array(dtype=int), - world_row_cfm: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_row_cfm: wp.array2d[float], # in/out - world_diag: wp.array2d(dtype=float), + world_diag: wp.array2d[float], ): """Add CFM to world diagonal after Delassus accumulation.""" world = wp.tid() @@ -3624,11 +3624,11 @@ def finalize_world_diag_cfm( @wp.kernel def add_dense_contact_compliance_to_diag( - world_constraint_count: wp.array(dtype=int), - world_row_type: wp.array2d(dtype=int), + world_constraint_count: wp.array[int], + world_row_type: wp.array2d[int], contact_alpha: float, # in/out - world_diag: wp.array2d(dtype=float), + world_diag: wp.array2d[float], ): """Add normal-contact compliance to the dense PGS diagonal. @@ -3655,19 +3655,19 @@ def add_dense_contact_compliance_to_diag( @wp.kernel def hinv_jt_par_row( # Grouped Cholesky factor storage [n_arts, n_dofs, n_dofs] - L_group: wp.array3d(dtype=float), + L_group: wp.array3d[float], # Size-grouped Jacobian [n_arts_of_size, max_constraints, n_dofs] - J_group: wp.array3d(dtype=float), + J_group: wp.array3d[float], # Indirection arrays - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], # Size parameters n_dofs: int, max_constraints: int, n_arts: int, # Output: Y = H^-1 * J^T [n_arts_of_size, max_constraints, n_dofs] - Y_group: wp.array3d(dtype=float), + Y_group: wp.array3d[float], ): """ Compute Y = H^-1 * J^T for one size group using forward/backward substitution. 
@@ -3746,19 +3746,19 @@ def hinv_jt_par_row( @wp.kernel def delassus_par_row_col( # Size-grouped arrays - J_group: wp.array3d(dtype=float), # [n_arts_of_size, max_constraints, n_dofs] - Y_group: wp.array3d(dtype=float), # [n_arts_of_size, max_constraints, n_dofs] + J_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] + Y_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] # Indirection arrays - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], # Size parameters n_dofs: int, max_constraints: int, n_arts: int, # Output: Delassus matrix C and diagonal (accumulated via atomics) - world_C: wp.array3d(dtype=float), # [world_count, max_constraints, max_constraints] - world_diag: wp.array2d(dtype=float), # [world_count, max_constraints] + world_C: wp.array3d[float], # [world_count, max_constraints, max_constraints] + world_diag: wp.array2d[float], # [world_count, max_constraints] ): """ Accumulate Delassus matrix contribution C += J * Y^T from one size group. 
@@ -3811,19 +3811,19 @@ def delassus_par_row_col( @wp.kernel def crba_fill_par_dof( - articulation_start: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), - mass_update_mask: wp.array(dtype=int), - joint_ancestor: wp.array(dtype=int), - joint_qd_start: wp.array(dtype=int), - joint_dof_dim: wp.array(dtype=int, ndim=2), - joint_S_s: wp.array(dtype=wp.spatial_vector), - body_I_c: wp.array(dtype=wp.spatial_matrix), + articulation_start: wp.array[int], + articulation_dof_start: wp.array[int], + mass_update_mask: wp.array[int], + joint_ancestor: wp.array[int], + joint_qd_start: wp.array[int], + joint_dof_dim: wp.array2d[int], + joint_S_s: wp.array[wp.spatial_vector], + body_I_c: wp.array[wp.spatial_matrix], # Size-group parameters - group_to_art: wp.array(dtype=int), + group_to_art: wp.array[int], n_dofs: int, # = TILE_DOF for tiled path # outputs - H_group: wp.array3d(dtype=float), # [n_arts_of_size, n_dofs, n_dofs] + H_group: wp.array3d[float], # [n_arts_of_size, n_dofs, n_dofs] ): """ CRBA fill kernel that writes directly to size-grouped H storage. @@ -3899,13 +3899,13 @@ def crba_fill_par_dof( @wp.kernel def trisolve_loop( - L_group: wp.array3d(dtype=float), # [n_arts_of_size, n_dofs, n_dofs] - group_to_art: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), + L_group: wp.array3d[float], # [n_arts_of_size, n_dofs, n_dofs] + group_to_art: wp.array[int], + articulation_dof_start: wp.array[int], n_dofs: int, - joint_tau: wp.array(dtype=float), # [total_dofs] + joint_tau: wp.array[float], # [total_dofs] # output - joint_qdd: wp.array(dtype=float), # [total_dofs] + joint_qdd: wp.array[float], # [total_dofs] ): """ Solve L * L^T * qdd = tau for grouped articulations using forward/backward substitution. 
@@ -3948,11 +3948,11 @@ def trisolve_loop( @wp.kernel def gather_tau_to_groups( - joint_tau: wp.array(dtype=float), # [total_dofs] - group_to_art: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), + joint_tau: wp.array[float], # [total_dofs] + group_to_art: wp.array[int], + articulation_dof_start: wp.array[int], n_dofs: int, - tau_group: wp.array3d(dtype=float), # [n_arts, n_dofs, 1] + tau_group: wp.array3d[float], # [n_arts, n_dofs, 1] ): """Gather joint_tau from 1D array into grouped 3D buffer for tiled solve. @@ -3967,11 +3967,11 @@ def gather_tau_to_groups( @wp.kernel def scatter_qdd_from_groups( - qdd_group: wp.array3d(dtype=float), # [n_arts, n_dofs, 1] - group_to_art: wp.array(dtype=int), - articulation_dof_start: wp.array(dtype=int), + qdd_group: wp.array3d[float], # [n_arts, n_dofs, 1] + group_to_art: wp.array[int], + articulation_dof_start: wp.array[int], n_dofs: int, - joint_qdd: wp.array(dtype=float), # [total_dofs] + joint_qdd: wp.array[float], # [total_dofs] ): """Scatter qdd from grouped 3D buffer back to 1D array after tiled solve. 
@@ -3985,7 +3985,7 @@ def scatter_qdd_from_groups( @wp.kernel -def vector_add_inplace(a: wp.array(dtype=float), b: wp.array(dtype=float)): +def vector_add_inplace(a: wp.array[float], b: wp.array[float]): """a[i] += b[i]""" i = wp.tid() a[i] = a[i] + b[i] @@ -3993,9 +3993,9 @@ def vector_add_inplace(a: wp.array(dtype=float), b: wp.array(dtype=float)): @wp.kernel def compute_delta_and_accumulate( - v_out: wp.array(dtype=float), - v_snap: wp.array(dtype=float), - v_accum: wp.array(dtype=float), + v_out: wp.array[float], + v_snap: wp.array[float], + v_accum: wp.array[float], ): """delta = v_out - v_snap; v_accum += delta; v_snap = delta (reuse buffer for rhs_accum input)""" i = wp.tid() @@ -4012,34 +4012,34 @@ def compute_delta_and_accumulate( @wp.kernel def pgs_convergence_diagnostic_velocity( # Dense constraints - constraint_count: wp.array(dtype=int), - world_dof_start: wp.array(dtype=int), - rhs: wp.array2d(dtype=float), - impulses: wp.array2d(dtype=float), - prev_impulses: wp.array2d(dtype=float), - row_type: wp.array2d(dtype=int), - row_parent: wp.array2d(dtype=int), - row_mu: wp.array2d(dtype=float), - J_world: wp.array3d(dtype=float), + constraint_count: wp.array[int], + world_dof_start: wp.array[int], + rhs: wp.array2d[float], + impulses: wp.array2d[float], + prev_impulses: wp.array2d[float], + row_type: wp.array2d[int], + row_parent: wp.array2d[int], + row_mu: wp.array2d[float], + J_world: wp.array3d[float], max_constraints: int, max_world_dofs: int, # MF constraints - mf_constraint_count: wp.array(dtype=int), - mf_rhs: wp.array2d(dtype=float), - mf_impulses: wp.array2d(dtype=float), - prev_mf_impulses: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_row_mu: wp.array2d(dtype=float), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_dof_a: wp.array2d(dtype=int), - mf_dof_b: wp.array2d(dtype=int), + mf_constraint_count: wp.array[int], + mf_rhs: wp.array2d[float], + mf_impulses: 
wp.array2d[float], + prev_mf_impulses: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_row_mu: wp.array2d[float], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_dof_a: wp.array2d[int], + mf_dof_b: wp.array2d[int], mf_max_constraints: int, # Velocity - v_out: wp.array(dtype=float), + v_out: wp.array[float], # Output: [worlds, 4] - metrics: wp.array2d(dtype=float), + metrics: wp.array2d[float], ): """Compute per-world PGS convergence metrics for velocity-space mode. diff --git a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py index ed58597ac4..2485a97bd1 100644 --- a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py +++ b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py @@ -50,7 +50,6 @@ compute_com_transforms, compute_composite_inertia, compute_contact_linear_force_from_impulses, - compute_delta_and_accumulate, compute_mf_body_Hinv, compute_mf_effective_mass_and_rhs, compute_mf_world_dof_offsets, @@ -89,16 +88,15 @@ update_articulation_root_com_offsets, update_body_qd_from_featherstone, update_qdd_from_velocity, - vector_add_inplace, ) @wp.kernel def localize_parent_indices( - counts: wp.array(dtype=int), + counts: wp.array[int], max_constraints: int, - parent_arr: wp.array(dtype=int), - parent_local_arr: wp.array(dtype=int), + parent_arr: wp.array[int], + parent_local_arr: wp.array[int], ): art = wp.tid() m = counts[art] @@ -118,6 +116,9 @@ class SolverFeatherPGS(SolverBase): on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics based on Featherstone's composite rigid body algorithm (CRBA). + This private solver branch keeps only the matrix-free contact solve path and + the current winner kernel strategy. + See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014. 
Instead of maximal coordinates :attr:`~newton.State.body_q` (rigid body positions) and :attr:`~newton.State.body_qd` @@ -173,19 +174,7 @@ def __init__( pgs_omega: float = 1.0, dense_max_constraints: int = 32, pgs_warmstart: bool = False, - pgs_mode: str = "split", mf_max_constraints: int = 512, - # Kernel selection per operation - cholesky_kernel: str = "auto", - trisolve_kernel: str = "auto", - hinv_jt_kernel: str = "auto", - delassus_kernel: str = "auto", - pgs_kernel: str = "tiled_row", - # Streaming kernel chunk sizes (None = auto-select) - delassus_chunk_size: int | None = None, - pgs_chunk_size: int | None = None, - # Auto selection threshold - small_dof_threshold: int = 12, # Parallelism options use_parallel_streams: bool = True, double_buffer: bool = True, @@ -200,9 +189,7 @@ def __init__( friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0. enable_contact_friction (bool, optional): Enables Coulomb friction contacts inside the PGS solve. Defaults to True. enable_joint_limits (bool, optional): Enforce joint position limits as unilateral PGS - constraints. Each violated limit adds one constraint row. Supported with - ``pgs_kernel="loop"`` and ``pgs_kernel="tiled_row"``; the ``"tiled_contact"`` - and ``"streaming"`` PGS kernels are *not* compatible. Defaults to False. + constraints. Each violated limit adds one constraint row. Defaults to False. pgs_iterations (int, optional): Number of Gauss-Seidel iterations to apply per frame. Defaults to 12. pgs_beta (float, optional): ERP style position correction factor. Defaults to 0.2. pgs_cfm (float, optional): Compliance/regularization added to the Delassus diagonal. Defaults to 1.0e-6. @@ -210,39 +197,13 @@ def __init__( only to dense articulated contact rows. Converted to an impulse-space diagonal term using ``compliance / dt^2``. Defaults to 0.0. 
pgs_omega (float, optional): Successive over-relaxation factor for the PGS sweep. Defaults to 1.0. - dense_max_constraints (int, optional): Maximum number of dense (articulation) contact constraint + dense_max_constraints (int, optional): Maximum number of articulated contact constraint rows stored per world. Free rigid body contacts are stored separately, bounded by mf_max_constraints. Defaults to 32. pgs_warmstart (bool, optional): Re-use impulses from the previous frame when contacts persist. Defaults to False. - pgs_mode (str, optional): PGS mode. "dense" builds the full Delassus matrix C = J*H^{-1}*J^T - and solves in impulse space (Gauss-Seidel) for all contacts. "split" uses the dense - path for articulated bodies and a cheaper matrix-free PGS path for free rigid body - contacts. "matrix_free" skips C entirely, recomputes J*v each iteration, and uses only - the diagonal for preconditioning — O(max_constraints) memory instead of - O(max_constraints^2). Defaults to "split". mf_max_constraints (int, optional): Maximum number of matrix-free constraints per world. Defaults to 512. - cholesky_kernel (str, optional): "tiled", "loop", or "auto" for Cholesky factorization. Defaults to "auto". - trisolve_kernel (str, optional): "tiled", "loop", or "auto" for triangular solve. Defaults to "auto". - hinv_jt_kernel (str, optional): "tiled", "par_row", or "auto" for H^{-1}J^T. Defaults to "auto". - delassus_kernel (str, optional): "tiled", "par_row_col", or "auto" for Delassus accumulation - (C = J * H^{-1} * J^T). "tiled" uses a streaming CUDA kernel that chunks shared memory - and scales to any constraint count. "par_row_col" launches one thread per matrix element. - "auto" selects "tiled" when DOFs exceed the threshold. Defaults to "auto". - pgs_kernel (str, optional): "loop", "tiled_row", or "tiled_contact" for PGS solve. Defaults to "tiled_row". - delassus_chunk_size (int, optional): Chunk size (in constraint rows) for the streaming Delassus - kernel. 
Controls how many rows of J and Y are loaded into shared memory at once. - None selects automatically based on shared memory heuristics. Defaults to None. - pgs_chunk_size (int, optional): Chunk size (in contacts, i.e. groups of 3 constraint rows) - for the streaming PGS kernel. Controls how many block-rows of the Delassus matrix are - preloaded into shared memory at once. 1 = current streaming behavior (one block-row - at a time). None defaults to 1. Defaults to None. - small_dof_threshold (int, optional): DOF threshold for "auto" kernel selection. Defaults to 12. use_parallel_streams (bool, optional): Dispatch size groups on separate CUDA streams. Defaults to True. - Auto selection behavior: - - auto: size > threshold -> tiled, else loop/par_row. - - Delassus auto/tiled: streaming kernel (handles any constraint count via chunking). - """ super().__init__(model) @@ -258,42 +219,15 @@ def __init__( self.pgs_omega = pgs_omega self.dense_max_constraints = dense_max_constraints self.pgs_warmstart = pgs_warmstart - if pgs_mode not in ("dense", "split", "matrix_free"): - raise ValueError(f"pgs_mode must be 'dense', 'split', or 'matrix_free', got {pgs_mode!r}") - self.pgs_mode = pgs_mode self.mf_max_constraints = mf_max_constraints self._double_buffer = double_buffer self._nvtx = nvtx self.pgs_debug = pgs_debug self._pgs_convergence_log: list[np.ndarray] = [] - valid_cholesky = {"tiled", "loop", "auto"} - if cholesky_kernel not in valid_cholesky: - raise ValueError(f"cholesky_kernel must be one of {sorted(valid_cholesky)}") - - valid_trisolve = {"tiled", "loop", "auto"} - if trisolve_kernel not in valid_trisolve: - raise ValueError(f"trisolve_kernel must be one of {sorted(valid_trisolve)}") - - valid_hinv_jt = {"tiled", "par_row", "auto"} - if hinv_jt_kernel not in valid_hinv_jt: - raise ValueError(f"hinv_jt_kernel must be one of {sorted(valid_hinv_jt)}") - - valid_delassus = {"tiled", "par_row_col", "auto"} - if delassus_kernel not in valid_delassus: - raise 
ValueError(f"delassus_kernel must be one of {sorted(valid_delassus)}") - - valid_pgs = {"loop", "tiled_row", "tiled_contact", "streaming"} - if pgs_kernel not in valid_pgs: - raise ValueError(f"pgs_kernel must be one of {sorted(valid_pgs)}") - - self.cholesky_kernel = cholesky_kernel - self.trisolve_kernel = trisolve_kernel - self.hinv_jt_kernel = hinv_jt_kernel - self.delassus_kernel = delassus_kernel - self.pgs_kernel = pgs_kernel - self.delassus_chunk_size = delassus_chunk_size - self.pgs_chunk_size = pgs_chunk_size if pgs_chunk_size is not None else 1 - self.small_dof_threshold = small_dof_threshold + self.small_dof_threshold = 12 + self.delassus_chunk_size = None + self.pgs_chunk_size = 1 + self.pgs_kernel = "tiled_row" self.use_parallel_streams = use_parallel_streams self._step = 0 @@ -807,34 +741,20 @@ def _allocate_world_buffers(self, model): requires_grad = model.requires_grad max_constraints = self.dense_max_constraints - # Per-world constraint matrices and vectors - if self.pgs_mode != "matrix_free": - self.C = wp.zeros( - (self.world_count, max_constraints, max_constraints), - dtype=wp.float32, - device=device, - requires_grad=requires_grad, - ) - else: - self.C = None - - if self.pgs_mode == "matrix_free": - self._compute_world_dof_mapping(model) - self.J_world = wp.zeros( - (self.world_count, max_constraints, self.max_world_dofs), - dtype=wp.float32, - device=device, - requires_grad=requires_grad, - ) - self.Y_world = wp.zeros( - (self.world_count, max_constraints, self.max_world_dofs), - dtype=wp.float32, - device=device, - requires_grad=requires_grad, - ) - else: - self.J_world = None - self.Y_world = None + self.C = None + self._compute_world_dof_mapping(model) + self.J_world = wp.zeros( + (self.world_count, max_constraints, self.max_world_dofs), + dtype=wp.float32, + device=device, + requires_grad=requires_grad, + ) + self.Y_world = wp.zeros( + (self.world_count, max_constraints, self.max_world_dofs), + dtype=wp.float32, + device=device, + 
requires_grad=requires_grad, + ) self.rhs = wp.zeros( (self.world_count, max_constraints), dtype=wp.float32, device=device, requires_grad=requires_grad @@ -876,7 +796,7 @@ def _allocate_world_buffers(self, model): def _allocate_mf_buffers(self, model): """Allocate buffers for matrix-free PGS path for free rigid body contacts.""" - if self.pgs_mode == "dense" or (not self._has_free_rigid_bodies and self.pgs_mode != "matrix_free"): + if not self._has_free_rigid_bodies: return device = model.device @@ -1074,9 +994,7 @@ def step( with wp.ScopedTimer("S2_Cholesky", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): with ctx: - use_tiled = (self.cholesky_kernel == "tiled") or ( - self.cholesky_kernel == "auto" and size > self.small_dof_threshold - ) + use_tiled = size > self.small_dof_threshold if use_tiled: self._stage2_cholesky_tiled(size) else: @@ -1088,9 +1006,7 @@ def step( self._stage3_zero_qdd(state_aug) for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): with ctx: - use_tiled = (self.trisolve_kernel == "tiled") or ( - self.trisolve_kernel == "auto" and size > self.small_dof_threshold - ) + use_tiled = size > self.small_dof_threshold if use_tiled: self._stage3_trisolve_tiled(size, state_aug) else: @@ -1107,95 +1023,45 @@ def step( with wp.ScopedTimer("S4_ContactBuild", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): self._stage4_build_rows(state_in, state_aug, contacts) - if self.pgs_mode == "matrix_free": - # Compute Y = H^-1 * J^T only (no Delassus C) - with wp.ScopedTimer("S4_HinvJt_Diag_RHS", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): - for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): - with ctx: - use_tiled = (self.hinv_jt_kernel == "tiled") or ( - self.hinv_jt_kernel == "auto" and size > self.small_dof_threshold - ) - if use_tiled: - self._stage4_hinv_jt_tiled(size) - else: - self._stage4_hinv_jt_par_row(size) - - # 
Diagonal from J*Y (no full Delassus) - self.diag.zero_() - for size in self.size_groups: - self._stage4_diag_from_JY(size) - self._stage4_finalize_world_diag_cfm() - self._stage4_add_dense_contact_compliance(dt) - - # RHS = bias only (J*v recomputed per iteration) - self._stage4_compute_rhs_world(dt) - # NOTE: skip _stage4_accumulate_rhs_world — J*v_hat not baked into rhs - - # MF: compute mf_MiJt, mf_rhs, mf_eff_mass_inv, body maps - if self._has_free_rigid_bodies: - with wp.ScopedTimer("S4_MF_Setup", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): - self._mf_pgs_setup(state_aug, dt) - - # MF: compute world-relative DOF offsets for two-phase GS kernel - wp.launch( - compute_mf_world_dof_offsets, - dim=self.world_count * self.mf_max_constraints, - inputs=[ - self.mf_constraint_count, - self.mf_body_a, - self.mf_body_b, - self.body_to_articulation, - self.articulation_dof_start, - self.world_dof_start, - self.mf_max_constraints, - ], - outputs=[self.mf_dof_a, self.mf_dof_b], - device=self.model.device, - ) - - else: - # Existing Delassus path (unchanged) - fused_ok = ( - self._is_one_art_per_world - and self.hinv_jt_kernel != "par_row" - and all( - (self.hinv_jt_kernel == "tiled") - or (self.hinv_jt_kernel == "auto" and size > self.small_dof_threshold) - for size in self.size_groups - ) - ) - - if fused_ok: - for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): - with ctx: - self._stage4_hinv_jt_tiled_fused(size) - else: - self._stage4_zero_world_C() - - for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): - with ctx: - use_tiled = (self.hinv_jt_kernel == "tiled") or ( - self.hinv_jt_kernel == "auto" and size > self.small_dof_threshold - ) - if use_tiled: - self._stage4_hinv_jt_tiled(size) - else: - self._stage4_hinv_jt_par_row(size) - - for size in self.size_groups: - use_tiled_delassus = self.delassus_kernel != "par_row_col" - if use_tiled_delassus: - self._stage4_delassus_tiled(size) + with 
wp.ScopedTimer("S4_HinvJt_Diag_RHS", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + for size, ctx in self._for_sizes(enabled=self.use_parallel_streams): + with ctx: + use_tiled = size > self.small_dof_threshold + if use_tiled: + self._stage4_hinv_jt_tiled(size) else: - self._stage4_delassus_par_row_col(size) - - self._stage4_finalize_world_diag_cfm() + self._stage4_hinv_jt_par_row(size) + # Diagonal from J*Y (no full Delassus) + self.diag.zero_() + for size in self.size_groups: + self._stage4_diag_from_JY(size) + self._stage4_finalize_world_diag_cfm() self._stage4_add_dense_contact_compliance(dt) + + # RHS = bias only (J*v recomputed per iteration) self._stage4_compute_rhs_world(dt) + # NOTE: skip _stage4_accumulate_rhs_world — J*v_hat not baked into rhs - for size in self.size_groups: - self._stage4_accumulate_rhs_world(size) + if self._has_free_rigid_bodies: + with wp.ScopedTimer("S4_MF_Setup", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + self._mf_pgs_setup(state_aug, dt) + + wp.launch( + compute_mf_world_dof_offsets, + dim=self.world_count * self.mf_max_constraints, + inputs=[ + self.mf_constraint_count, + self.mf_body_a, + self.mf_body_b, + self.body_to_articulation, + self.articulation_dof_start, + self.world_dof_start, + self.mf_max_constraints, + ], + outputs=[self.mf_dof_a, self.mf_dof_b], + device=self.model.device, + ) # ══════════════════════════════════════════════════════════════ # STAGE 5+6: PGS solve @@ -1203,155 +1069,67 @@ def step( with wp.ScopedTimer("S5_PGS_Prep", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): self._stage5_prepare_impulses_world() - if self.pgs_mode == "matrix_free": - with wp.ScopedTimer("S5_GatherJY", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): - # Gather J/Y from per-size-group arrays into world-indexed arrays - # No J_world/Y_world zeroing needed: gather writes all DOFs unconditionally - for size in self.size_groups: - n_arts = self.n_arts_by_size[size] - wp.launch( - 
gather_JY_to_world, - dim=int(n_arts * self.dense_max_constraints * size), - inputs=[ - self.group_to_art[size], - self.art_to_world, - self.articulation_dof_start, - self.constraint_count, - self.world_dof_start, - self.J_by_size[size], - self.Y_by_size[size], - size, - self.dense_max_constraints, - n_arts, - ], - outputs=[self.J_world, self.Y_world], - device=self.model.device, - ) - - # Initialize v_out = v_hat before GS loop - self._stage6_prepare_world_velocity() - - # Pack MF metadata into int4 structs for coalesced 128-bit loads - pack_kernel = TiledKernelFactory.get_pack_mf_meta_kernel(self.mf_max_constraints, self.model.device) - wp.launch_tiled( - pack_kernel, - dim=[self.world_count], + with wp.ScopedTimer("S5_GatherJY", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + for size in self.size_groups: + n_arts = self.n_arts_by_size[size] + wp.launch( + gather_JY_to_world, + dim=int(n_arts * self.dense_max_constraints * size), inputs=[ - self.mf_constraint_count, - self.mf_dof_a, - self.mf_dof_b, - self.mf_eff_mass_inv, - self.mf_rhs, - self.mf_row_type, - self.mf_row_parent, + self.group_to_art[size], + self.art_to_world, + self.articulation_dof_start, + self.constraint_count, + self.world_dof_start, + self.J_by_size[size], + self.Y_by_size[size], + size, + self.dense_max_constraints, + n_arts, ], - outputs=[self.mf_meta_packed], - block_dim=32, + outputs=[self.J_world, self.Y_world], device=self.model.device, ) - # Two-phase GS kernel: split-style dense + MF in one pass - with wp.ScopedTimer("S6_PGS_Solve", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): - mf_gs_kernel = TiledKernelFactory.get_pgs_solve_mf_gs_kernel( - self.dense_max_constraints, - self.mf_max_constraints, - self.max_world_dofs, - self.model.device, - ) + self._stage6_prepare_world_velocity() + + pack_kernel = TiledKernelFactory.get_pack_mf_meta_kernel(self.mf_max_constraints, self.model.device) + wp.launch_tiled( + pack_kernel, + dim=[self.world_count], + inputs=[ + 
self.mf_constraint_count, + self.mf_dof_a, + self.mf_dof_b, + self.mf_eff_mass_inv, + self.mf_rhs, + self.mf_row_type, + self.mf_row_parent, + ], + outputs=[self.mf_meta_packed], + block_dim=32, + device=self.model.device, + ) + + with wp.ScopedTimer("S6_PGS_Solve", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): + mf_gs_kernel = TiledKernelFactory.get_pgs_solve_mf_gs_kernel( + self.dense_max_constraints, + self.mf_max_constraints, + self.max_world_dofs, + self.model.device, + ) + + if self.pgs_debug: + self._pgs_convergence_log.append([]) + for _pgs_dbg_iter in range(self.pgs_iterations): + wp.copy(self._diag_prev_impulses, self.impulses) + if self._diag_prev_mf_impulses is not None: + wp.copy(self._diag_prev_mf_impulses, self.mf_impulses) - if self.pgs_debug: - self._pgs_convergence_log.append([]) - for _pgs_dbg_iter in range(self.pgs_iterations): - # Snapshot impulses before this iteration - wp.copy(self._diag_prev_impulses, self.impulses) - if self._diag_prev_mf_impulses is not None: - wp.copy(self._diag_prev_mf_impulses, self.mf_impulses) - - # Run 1 iteration - wp.launch_tiled( - mf_gs_kernel, - dim=[self.world_count], - inputs=[ - self.constraint_count, - self.world_dof_start, - self.rhs, - self.diag, - self.impulses, - self.J_world, - self.Y_world, - self.row_type, - self.row_parent, - self.row_mu, - self.mf_constraint_count, - self.mf_meta_packed, - self.mf_impulses, - self.mf_J_a, - self.mf_J_b, - self.mf_MiJt_a, - self.mf_MiJt_b, - self.mf_row_mu, - 1, # iterations=1 - self.pgs_omega, - ], - outputs=[self.v_out], - block_dim=32, - device=self.model.device, - ) - - # Diagnostic kernel - wp.launch( - pgs_convergence_diagnostic_velocity, - dim=self.world_count, - inputs=[ - self.constraint_count, - self.world_dof_start, - self.rhs, - self.impulses, - self._diag_prev_impulses, - self.row_type, - self.row_parent, - self.row_mu, - self.J_world, - self.dense_max_constraints, - self.max_world_dofs, - self.mf_constraint_count, - self.mf_rhs, - 
self.mf_impulses, - self._diag_prev_mf_impulses, - self.mf_row_type, - self.mf_row_parent, - self.mf_row_mu, - self.mf_J_a, - self.mf_J_b, - self.mf_dof_a, - self.mf_dof_b, - self.mf_max_constraints, - self.v_out, - ], - outputs=[self._diag_metrics], - device=self.model.device, - ) - - # Sync and reduce across worlds - metrics_np = self._diag_metrics.numpy() - row = np.array( - [ - np.max(metrics_np[:, 0]), # max|delta_lambda| - np.sum(metrics_np[:, 1]), # complementarity gap - np.sum(metrics_np[:, 2]), # tangent residual - np.sum(metrics_np[:, 3]), # FB merit - ] - ) - self._pgs_convergence_log[-1].append(row) - - self._pgs_convergence_log[-1] = np.array(self._pgs_convergence_log[-1]) - - else: wp.launch_tiled( mf_gs_kernel, dim=[self.world_count], inputs=[ - # Dense self.constraint_count, self.world_dof_start, self.rhs, @@ -1362,7 +1140,6 @@ def step( self.row_type, self.row_parent, self.row_mu, - # MF self.mf_constraint_count, self.mf_meta_packed, self.mf_impulses, @@ -1371,90 +1148,90 @@ def step( self.mf_MiJt_a, self.mf_MiJt_b, self.mf_row_mu, - # Shared - self.pgs_iterations, + 1, self.pgs_omega, ], outputs=[self.v_out], block_dim=32, device=self.model.device, ) - elif self.pgs_mode == "split" and self._has_mixed_contacts: - # Split mode with mixed contacts: interleaved dense and MF, 1 iteration each - self._mf_pgs_setup(state_aug, dt) - self.v_mf_accum.zero_() - - for _pgs_iter in range(self.pgs_iterations): - # Dense PGS (1 iteration, impulse space) - self._dispatch_dense_pgs_solve(iterations=1) - - # Rebuild v_out = v_hat + Y*impulses + MF_corrections - self._stage6_prepare_world_velocity() - for size in self.size_groups: - self._stage6_apply_impulses_world(size) - wp.launch( - vector_add_inplace, - dim=self.v_out.size, - inputs=[self.v_out, self.v_mf_accum], - device=self.model.device, - ) - - # Snapshot v_out, run MF, compute delta - wp.copy(self.v_out_snap, self.v_out) - self._mf_pgs_solve(iterations=1) - - # v_mf_accum += (v_out - v_out_snap); 
v_out_snap = delta - wp.launch( - compute_delta_and_accumulate, - dim=self.v_out.size, - inputs=[self.v_out, self.v_out_snap, self.v_mf_accum], - device=self.model.device, - ) - # Update dense rhs: world_rhs += J * delta_v_mf - for size in self.size_groups: - n_arts = self.n_arts_by_size[size] wp.launch( - rhs_accum_world_par_art, - dim=n_arts, + pgs_convergence_diagnostic_velocity, + dim=self.world_count, inputs=[ self.constraint_count, + self.world_dof_start, + self.rhs, + self.impulses, + self._diag_prev_impulses, + self.row_type, + self.row_parent, + self.row_mu, + self.J_world, self.dense_max_constraints, - self.art_to_world, - self.art_size, - self.art_group_idx, - self.articulation_dof_start, - self.v_out_snap, - self.group_to_art[size], - self.J_by_size[size], - size, + self.max_world_dofs, + self.mf_constraint_count, + self.mf_rhs, + self.mf_impulses, + self._diag_prev_mf_impulses, + self.mf_row_type, + self.mf_row_parent, + self.mf_row_mu, + self.mf_J_a, + self.mf_J_b, + self.mf_dof_a, + self.mf_dof_b, + self.mf_max_constraints, + self.v_out, ], - outputs=[self.rhs], + outputs=[self._diag_metrics], device=self.model.device, ) - # v_out is already final (includes both dense and MF corrections) + metrics_np = self._diag_metrics.numpy() + row = np.array( + [ + np.max(metrics_np[:, 0]), + np.sum(metrics_np[:, 1]), + np.sum(metrics_np[:, 2]), + np.sum(metrics_np[:, 3]), + ] + ) + self._pgs_convergence_log[-1].append(row) - else: - # Dense or split without mixed contacts: dense PGS, then optional MF - if self.pgs_debug: - self._pgs_convergence_log.append([]) - for _pgs_dbg_iter in range(self.pgs_iterations): - prev_np = self.impulses.numpy().copy() - self._dispatch_dense_pgs_solve(iterations=1) - cur_np = self.impulses.numpy() - max_delta = float(np.max(np.abs(cur_np - prev_np))) - self._pgs_convergence_log[-1].append(np.array([max_delta, 0.0, 0.0, 0.0])) self._pgs_convergence_log[-1] = np.array(self._pgs_convergence_log[-1]) - else: - 
self._dispatch_dense_pgs_solve(iterations=self.pgs_iterations) - - self._stage6_prepare_world_velocity() - for size in self.size_groups: - self._stage6_apply_impulses_world(size) - if self.pgs_mode == "split" and self._has_free_rigid_bodies: - self._stage6b_mf_pgs(state_aug, dt) + else: + wp.launch_tiled( + mf_gs_kernel, + dim=[self.world_count], + inputs=[ + self.constraint_count, + self.world_dof_start, + self.rhs, + self.diag, + self.impulses, + self.J_world, + self.Y_world, + self.row_type, + self.row_parent, + self.row_mu, + self.mf_constraint_count, + self.mf_meta_packed, + self.mf_impulses, + self.mf_J_a, + self.mf_J_b, + self.mf_MiJt_a, + self.mf_MiJt_b, + self.mf_row_mu, + self.pgs_iterations, + self.pgs_omega, + ], + outputs=[self.v_out], + block_dim=32, + device=self.model.device, + ) # ══════════════════════════════════════════════════════════════ # STAGE 7: Update qdd + integrate @@ -2078,7 +1855,7 @@ def _stage3_compute_v_hat(self, state_in: State, state_aug: State, dt: float): def _stage4_build_rows(self, state_in: State, state_aug: State, contacts: Contacts): model = self.model max_constraints = self.dense_max_constraints - mf_active = self._has_free_rigid_bodies and self.pgs_mode != "dense" + mf_active = self._has_free_rigid_bodies # Zero world-level buffers (only arrays that require it) self.slot_counter.zero_() # atomic-add counter @@ -2948,13 +2725,13 @@ def _build_hinv_jt_kernel(cls, n_dofs: int, max_constraints: int) -> "wp.Kernel" TILE_CONSTRAINTS_LOCAL = wp.constant(int(max_constraints)) def hinv_jt_tiled_template( - L_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] - J_group: wp.array3d(dtype=float), # [n_arts, max_c, n_dofs] - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), + L_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] + J_group: wp.array3d[float], # [n_arts, max_c, n_dofs] + group_to_art: wp.array[int], + art_to_world: wp.array[int], + 
world_constraint_count: wp.array[int], # output - Y_group: wp.array3d(dtype=float), # [n_arts, max_c, n_dofs] + Y_group: wp.array3d[float], # [n_arts, max_c, n_dofs] ): idx = wp.tid() art = group_to_art[idx] @@ -2992,16 +2769,16 @@ def _build_hinv_jt_fused_kernel(cls, n_dofs: int, max_constraints: int) -> "wp.K TILE_CONSTRAINTS_LOCAL = wp.constant(int(max_constraints)) def hinv_jt_tiled_fused_template( - L_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] - J_group: wp.array3d(dtype=float), # [n_arts, max_c, n_dofs] - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), - row_cfm: wp.array2d(dtype=float), + L_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] + J_group: wp.array3d[float], # [n_arts, max_c, n_dofs] + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], + row_cfm: wp.array2d[float], # outputs - world_C: wp.array3d(dtype=float), # [world_count, max_c, max_c] - world_diag: wp.array2d(dtype=float), # [world_count, max_c] - Y_group: wp.array3d(dtype=float), # [n_arts, max_c, n_dofs] + world_C: wp.array3d[float], # [world_count, max_c, max_c] + world_diag: wp.array2d[float], # [world_count, max_c] + Y_group: wp.array3d[float], # [n_arts, max_c, n_dofs] ): idx, thread = wp.tid() art = group_to_art[idx] @@ -3112,24 +2889,24 @@ def _build_delassus_kernel(cls, n_dofs: int, max_constraints: int, chunk_size: i @wp.func_native(snippet) def delassus_native( idx: int, - J_group: wp.array3d(dtype=float), - Y_group: wp.array3d(dtype=float), - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), - world_C: wp.array3d(dtype=float), - world_diag: wp.array2d(dtype=float), + J_group: wp.array3d[float], + Y_group: wp.array3d[float], + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], + world_C: wp.array3d[float], + world_diag: 
wp.array2d[float], ): ... def delassus_template( - J_group: wp.array3d(dtype=float), - Y_group: wp.array3d(dtype=float), - group_to_art: wp.array(dtype=int), - art_to_world: wp.array(dtype=int), - world_constraint_count: wp.array(dtype=int), + J_group: wp.array3d[float], + Y_group: wp.array3d[float], + group_to_art: wp.array[int], + art_to_world: wp.array[int], + world_constraint_count: wp.array[int], n_arts: int, - world_C: wp.array3d(dtype=float), - world_diag: wp.array2d(dtype=float), + world_C: wp.array3d[float], + world_diag: wp.array2d[float], ): idx, _lane = wp.tid() if idx < n_arts: @@ -3159,12 +2936,12 @@ def _build_cholesky_kernel(cls, n_dofs: int) -> "wp.Kernel": TILE_DOF_LOCAL = wp.constant(int(n_dofs)) def cholesky_tiled_template( - H_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] - R_group: wp.array2d(dtype=float), # [n_arts, n_dofs] armature - group_to_art: wp.array(dtype=int), - mass_update_mask: wp.array(dtype=int), + H_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] + R_group: wp.array2d[float], # [n_arts, n_dofs] armature + group_to_art: wp.array[int], + mass_update_mask: wp.array[int], # output - L_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] + L_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] ): idx = wp.tid() art = group_to_art[idx] @@ -3206,9 +2983,9 @@ def _build_triangular_solve_kernel(cls, n_dofs: int) -> "wp.Kernel": TILE_DOF_LOCAL = wp.constant(int(n_dofs)) def trisolve_tiled_template( - L_group: wp.array3d(dtype=float), # [n_arts, n_dofs, n_dofs] - tau_group: wp.array3d(dtype=float), # [n_arts, n_dofs, 1] - qdd_group: wp.array3d(dtype=float), # [n_arts, n_dofs, 1] + L_group: wp.array3d[float], # [n_arts, n_dofs, n_dofs] + tau_group: wp.array3d[float], # [n_arts, n_dofs, 1] + qdd_group: wp.array3d[float], # [n_arts, n_dofs, 1] ): idx = wp.tid() L_tile = wp.tile_load(L_group[idx], shape=(TILE_DOF_LOCAL, TILE_DOF_LOCAL), bounds_check=False) @@ -3410,29 +3187,29 @@ def gen_load_1d(dst, src): 
@wp.func_native(snippet) def pgs_solve_native( world: int, - world_constraint_count: wp.array(dtype=int), - world_diag: wp.array2d(dtype=float), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_diag: wp.array2d[float], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], ): ... def pgs_solve_tiled_template( - world_constraint_count: wp.array(dtype=int), - world_diag: wp.array2d(dtype=float), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_diag: wp.array2d[float], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], ): world, _lane = wp.tid() pgs_solve_native( @@ -3675,23 +3452,23 @@ def _build_pgs_solve_tiled_contact_kernel(cls, max_constraints: int) -> "wp.Kern @wp.func_native(snippet) def pgs_solve_contact_native( world: int, - world_constraint_count: wp.array(dtype=int), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_mu: wp.array2d(dtype=float), + world_row_mu: 
wp.array2d[float], ): ... def pgs_solve_tiled_contact_template( - world_constraint_count: wp.array(dtype=int), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_mu: wp.array2d(dtype=float), + world_row_mu: wp.array2d[float], ): world, _lane = wp.tid() pgs_solve_contact_native( @@ -3953,23 +3730,23 @@ def _build_pgs_solve_streaming_kernel(cls, max_constraints: int, pgs_chunk_size: @wp.func_native(snippet) def pgs_solve_streaming_native( world: int, - world_constraint_count: wp.array(dtype=int), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_mu: wp.array2d(dtype=float), + world_row_mu: wp.array2d[float], ): ... 
def pgs_solve_streaming_template( - world_constraint_count: wp.array(dtype=int), - world_C: wp.array3d(dtype=float), - world_rhs: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_C: wp.array3d[float], + world_rhs: wp.array2d[float], + world_impulses: wp.array2d[float], iterations: int, omega: float, - world_row_mu: wp.array2d(dtype=float), + world_row_mu: wp.array2d[float], ): world, _lane = wp.tid() pgs_solve_streaming_native( @@ -4203,43 +3980,43 @@ def _build_pgs_solve_mf_kernel(cls, mf_max_constraints: int, max_mf_bodies: int) @wp.func_native(snippet) def pgs_solve_mf_native( world: int, - mf_constraint_count: wp.array(dtype=int), - mf_body_count: wp.array(dtype=int), - mf_body_dof_start: wp.array2d(dtype=int), - mf_local_body_a: wp.array2d(dtype=int), - mf_local_body_b: wp.array2d(dtype=int), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_rhs: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_row_mu: wp.array2d(dtype=float), - mf_impulses: wp.array2d(dtype=float), - v_out: wp.array(dtype=float), + mf_constraint_count: wp.array[int], + mf_body_count: wp.array[int], + mf_body_dof_start: wp.array2d[int], + mf_local_body_a: wp.array2d[int], + mf_local_body_b: wp.array2d[int], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: wp.array3d[float], + mf_eff_mass_inv: wp.array2d[float], + mf_rhs: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_row_mu: wp.array2d[float], + mf_impulses: wp.array2d[float], + v_out: wp.array[float], iterations: int, omega: float, ): ... 
def pgs_solve_mf_template( - mf_constraint_count: wp.array(dtype=int), - mf_body_count: wp.array(dtype=int), - mf_body_dof_start: wp.array2d(dtype=int), - mf_local_body_a: wp.array2d(dtype=int), - mf_local_body_b: wp.array2d(dtype=int), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_rhs: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_row_mu: wp.array2d(dtype=float), - mf_impulses: wp.array2d(dtype=float), - v_out: wp.array(dtype=float), + mf_constraint_count: wp.array[int], + mf_body_count: wp.array[int], + mf_body_dof_start: wp.array2d[int], + mf_local_body_a: wp.array2d[int], + mf_local_body_b: wp.array2d[int], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: wp.array3d[float], + mf_eff_mass_inv: wp.array2d[float], + mf_rhs: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_row_mu: wp.array2d[float], + mf_impulses: wp.array2d[float], + v_out: wp.array[float], iterations: int, omega: float, ): @@ -4316,25 +4093,25 @@ def _build_pack_mf_meta_kernel(cls, mf_max_constraints: int) -> "wp.Kernel": @wp.func_native(snippet) def pack_mf_meta_native( world: int, - mf_constraint_count: wp.array(dtype=int), - mf_dof_a: wp.array2d(dtype=int), - mf_dof_b: wp.array2d(dtype=int), - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_rhs: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_meta: wp.array2d(dtype=int), + mf_constraint_count: wp.array[int], + mf_dof_a: wp.array2d[int], + mf_dof_b: wp.array2d[int], + mf_eff_mass_inv: wp.array2d[float], + mf_rhs: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_meta: wp.array2d[int], ): ... 
def pack_mf_meta_template( - mf_constraint_count: wp.array(dtype=int), - mf_dof_a: wp.array2d(dtype=int), - mf_dof_b: wp.array2d(dtype=int), - mf_eff_mass_inv: wp.array2d(dtype=float), - mf_rhs: wp.array2d(dtype=float), - mf_row_type: wp.array2d(dtype=int), - mf_row_parent: wp.array2d(dtype=int), - mf_meta: wp.array2d(dtype=int), + mf_constraint_count: wp.array[int], + mf_dof_a: wp.array2d[int], + mf_dof_b: wp.array2d[int], + mf_eff_mass_inv: wp.array2d[float], + mf_rhs: wp.array2d[float], + mf_row_type: wp.array2d[int], + mf_row_parent: wp.array2d[int], + mf_meta: wp.array2d[int], ): world, _lane = wp.tid() pack_mf_meta_native( @@ -4736,58 +4513,58 @@ def _build_pgs_solve_mf_gs_kernel( def pgs_solve_mf_gs_native( world: int, # Dense - world_constraint_count: wp.array(dtype=int), - world_dof_start: wp.array(dtype=int), - rhs_bias: wp.array2d(dtype=float), - world_diag: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), - J_world: wp.array3d(dtype=float), - Y_world: wp.array3d(dtype=float), - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_dof_start: wp.array[int], + rhs_bias: wp.array2d[float], + world_diag: wp.array2d[float], + world_impulses: wp.array2d[float], + J_world: wp.array3d[float], + Y_world: wp.array3d[float], + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], # MF - mf_constraint_count: wp.array(dtype=int), - mf_meta: wp.array2d(dtype=int), - mf_impulses: wp.array2d(dtype=float), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_row_mu: wp.array2d(dtype=float), + mf_constraint_count: wp.array[int], + mf_meta: wp.array2d[int], + mf_impulses: wp.array2d[float], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: 
wp.array3d[float], + mf_row_mu: wp.array2d[float], # Shared iterations: int, omega: float, # Output - v_out: wp.array(dtype=float), + v_out: wp.array[float], ): ... def pgs_solve_mf_gs_template( # Dense - world_constraint_count: wp.array(dtype=int), - world_dof_start: wp.array(dtype=int), - rhs_bias: wp.array2d(dtype=float), - world_diag: wp.array2d(dtype=float), - world_impulses: wp.array2d(dtype=float), - J_world: wp.array3d(dtype=float), - Y_world: wp.array3d(dtype=float), - world_row_type: wp.array2d(dtype=int), - world_row_parent: wp.array2d(dtype=int), - world_row_mu: wp.array2d(dtype=float), + world_constraint_count: wp.array[int], + world_dof_start: wp.array[int], + rhs_bias: wp.array2d[float], + world_diag: wp.array2d[float], + world_impulses: wp.array2d[float], + J_world: wp.array3d[float], + Y_world: wp.array3d[float], + world_row_type: wp.array2d[int], + world_row_parent: wp.array2d[int], + world_row_mu: wp.array2d[float], # MF - mf_constraint_count: wp.array(dtype=int), - mf_meta: wp.array2d(dtype=int), - mf_impulses: wp.array2d(dtype=float), - mf_J_a: wp.array3d(dtype=float), - mf_J_b: wp.array3d(dtype=float), - mf_MiJt_a: wp.array3d(dtype=float), - mf_MiJt_b: wp.array3d(dtype=float), - mf_row_mu: wp.array2d(dtype=float), + mf_constraint_count: wp.array[int], + mf_meta: wp.array2d[int], + mf_impulses: wp.array2d[float], + mf_J_a: wp.array3d[float], + mf_J_b: wp.array3d[float], + mf_MiJt_a: wp.array3d[float], + mf_MiJt_b: wp.array3d[float], + mf_row_mu: wp.array2d[float], # Shared iterations: int, omega: float, # Output - v_out: wp.array(dtype=float), + v_out: wp.array[float], ): world, _lane = wp.tid() pgs_solve_mf_gs_native( diff --git a/newton/tests/test_feather_pgs.py b/newton/tests/test_feather_pgs.py new file mode 100644 index 0000000000..fcb6df11a8 --- /dev/null +++ b/newton/tests/test_feather_pgs.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 + +import inspect 
+import unittest + +import numpy as np + +import newton +from newton._src.solvers.feather_pgs import SolverFeatherPGS + + +class TestFeatherPGS(unittest.TestCase): + def _make_model(self): + builder = newton.ModelBuilder(gravity=0.0) + builder.add_body(mass=1.0, inertia=np.eye(3, dtype=np.float32), label="body") + return builder.finalize(device="cpu") + + def test_constructor_signature_is_matrix_free_only(self): + params = inspect.signature(SolverFeatherPGS).parameters + + self.assertNotIn("pgs_mode", params) + self.assertNotIn("cholesky_kernel", params) + self.assertNotIn("trisolve_kernel", params) + self.assertNotIn("hinv_jt_kernel", params) + self.assertNotIn("delassus_kernel", params) + self.assertNotIn("pgs_kernel", params) + + def test_step_smoke_runs_matrix_free_solver(self): + model = self._make_model() + solver = SolverFeatherPGS(model, pgs_iterations=1) + + state_in = model.state() + state_out = model.state() + control = model.control() + contacts = model.contacts() + + solver.step(state_in, state_out, control, contacts, 1.0e-3) + + self.assertIsNone(solver.C) + self.assertIsNotNone(solver.J_world) + self.assertIsNotNone(solver.Y_world) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 147ebd5366c12c73c5c7211e2b748135530abdac Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 12:53:03 -0400 Subject: [PATCH 3/7] Remove dead FeatherPGS solver branches Drop the unused dense-only and standalone matrix-free helper paths\nfrom the private FeatherPGS line so the branch reflects the\nshipped winner path only.\n\nAlso update the living ExecPlan and PR draft polish to match this\nMilestone 2 pass and keep the review artifacts accurate. 
--- .../execplans/fpgs-private-api-matrix-free.md | 18 +- .../fpgs-private-api-pr-description-draft.md | 4 +- .../solvers/feather_pgs/solver_feather_pgs.py | 1268 +---------------- 3 files changed, 33 insertions(+), 1257 deletions(-) diff --git a/.agent/execplans/fpgs-private-api-matrix-free.md b/.agent/execplans/fpgs-private-api-matrix-free.md index 6d4065519b..6a4abe1d06 100644 --- a/.agent/execplans/fpgs-private-api-matrix-free.md +++ b/.agent/execplans/fpgs-private-api-matrix-free.md @@ -95,8 +95,16 @@ Progress update (2026-04-13, pass 1): solve only. - Added focused unit coverage for the stripped-down constructor surface and a minimal matrix-free smoke step in `newton/tests/test_feather_pgs.py`. -- Remaining cleanup for later passes: remove dead dense/split-only helper code - and supporting branch-local references that are no longer reachable. + +Progress update (2026-04-13, pass 2): +- Removed the remaining dead dense-only and standalone matrix-free solve + helpers from `solver_feather_pgs.py`, including obsolete dense kernel + factory entries and branch-local solver state that the private line no longer + reaches. +- Reworded the surviving GS kernel comments/docstrings to describe the unified + winner path rather than an ablation-era "dense + matrix-free" split. +- Milestone 2 is now functionally complete on the private solver surface; any + leftover cleanup is supporting-surface polish tracked in Milestone 3. Required work: - Remove `dense` and `split` support from the private API implementation. @@ -185,6 +193,12 @@ Checkpoint: constructor behavior. Dense/split-only helper code that becomes unreachable in that slice should be removed when practical; broader supporting cleanup stays in Milestone 3. 
+- Replan update (2026-04-13, pass 2): + The next slice tightens Milestone 2 itself instead of jumping ahead: remove + the now-dead dense-only / standalone-MF helper paths and rewrite the live GS + kernel commentary so the private branch reflects one coherent articulated + + free-rigid winner path. After that push, stop before Milestone 3 and leave + supporting-surface cleanup plus final validation recording for later passes. - If a milestone is too large to finish cleanly, begin the pass by tightening this ExecPlan with a short replan note and then complete one reviewable slice of that milestone. diff --git a/.agent/review/fpgs-private-api-pr-description-draft.md b/.agent/review/fpgs-private-api-pr-description-draft.md index 58a35e57cb..9e0fca87ca 100644 --- a/.agent/review/fpgs-private-api-pr-description-draft.md +++ b/.agent/review/fpgs-private-api-pr-description-draft.md @@ -20,8 +20,8 @@ removes obsolete top-level mode selection and kernel-selection API knobs from ## Why Matrix-Free Only The published nightly ablations already show that the matrix-free path is the - winner for the private line, while the dense/split modes mostly preserve - research history rather than current product intent. +winner for the private line, while the dense/split modes mostly preserve +research history rather than current product intent. 
Using the published nightly run `2026-04-01T20-49-30Z` (`summary.json`, commit `53b3188`) from the `gh-pages` artifacts: diff --git a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py index 2485a97bd1..306f2e9349 100644 --- a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py +++ b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py @@ -43,7 +43,6 @@ apply_impulses_world_par_dof, build_augmented_joint_rows, build_mass_update_mask, - build_mf_body_map, build_mf_contact_rows, cholesky_loop, clamp_joint_tau, @@ -76,8 +75,6 @@ integrate_generalized_joints, pack_contact_linear_force_as_spatial, pgs_convergence_diagnostic_velocity, - pgs_solve_loop, - pgs_solve_mf_loop, populate_joint_limit_J_for_size, populate_world_J_for_size, prepare_world_impulses, @@ -226,8 +223,6 @@ def __init__( self._pgs_convergence_log: list[np.ndarray] = [] self.small_dof_threshold = 12 self.delassus_chunk_size = None - self.pgs_chunk_size = 1 - self.pgs_kernel = "tiled_row" self.use_parallel_streams = use_parallel_streams self._step = 0 @@ -830,11 +825,11 @@ def _allocate_mf_buffers(self, model): self.mf_body_Hinv = wp.zeros((body_count,), dtype=wp.spatial_matrix, device=device, requires_grad=requires_grad) - # World-relative DOF offsets for two-phase GS kernel + # World-relative DOF offsets for the unified articulated + free-rigid GS kernel. 
self.mf_dof_a = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) self.mf_dof_b = wp.zeros((worlds, mf_max_c), dtype=wp.int32, device=device, requires_grad=requires_grad) - # Packed MF metadata for two-phase GS kernel (int4 per constraint): + # Packed MF metadata for the unified GS kernel (int4 per constraint): # .x = (dof_a << 16) | (dof_b & 0xFFFF) # .y = __float_as_int(eff_mass_inv) # .z = __float_as_int(rhs) @@ -2362,101 +2357,6 @@ def _stage5_prepare_impulses_world(self): device=self.model.device, ) - def _dispatch_dense_pgs_solve(self, iterations: int): - """Dispatch the dense PGS kernel with a given iteration count.""" - saved = self.pgs_iterations - self.pgs_iterations = iterations - if self.pgs_kernel == "tiled_row": - self._stage5_pgs_solve_world_tiled_row() - elif self.pgs_kernel == "tiled_contact": - self._stage5_pgs_solve_world_tiled_contact() - elif self.pgs_kernel == "streaming": - self._stage5_pgs_solve_world_streaming() - else: - self._stage5_pgs_solve_world_loop() - self.pgs_iterations = saved - - def _stage5_pgs_solve_world_tiled_row(self): - pgs_kernel = TiledKernelFactory.get_pgs_solve_tiled_row_kernel(self.dense_max_constraints, self.model.device) - wp.launch_tiled( - pgs_kernel, - dim=[self.world_count], - inputs=[ - self.constraint_count, - self.diag, - self.C, - self.rhs, - self.impulses, - self.pgs_iterations, - self.pgs_omega, - self.row_type, - self.row_parent, - self.row_mu, - ], - block_dim=32, - device=self.model.device, - ) - - def _stage5_pgs_solve_world_loop(self): - wp.launch( - pgs_solve_loop, - dim=self.world_count, - inputs=[ - self.constraint_count, - self.dense_max_constraints, - self.diag, - self.C, - self.rhs, - self.impulses, - self.pgs_iterations, - self.pgs_omega, - self.row_type, - self.row_parent, - self.row_mu, - ], - device=self.model.device, - ) - - def _stage5_pgs_solve_world_tiled_contact(self): - pgs_kernel = TiledKernelFactory.get_pgs_solve_tiled_contact_kernel( - 
self.dense_max_constraints, self.model.device - ) - wp.launch_tiled( - pgs_kernel, - dim=[self.world_count], - inputs=[ - self.constraint_count, - self.C, - self.rhs, - self.impulses, - self.pgs_iterations, - self.pgs_omega, - self.row_mu, - ], - block_dim=32, - device=self.model.device, - ) - - def _stage5_pgs_solve_world_streaming(self): - pgs_kernel = TiledKernelFactory.get_pgs_solve_streaming_kernel( - self.dense_max_constraints, self.model.device, pgs_chunk_size=self.pgs_chunk_size - ) - wp.launch_tiled( - pgs_kernel, - dim=[self.world_count], - inputs=[ - self.constraint_count, - self.C, - self.rhs, - self.impulses, - self.pgs_iterations, - self.pgs_omega, - self.row_mu, - ], - block_dim=32, - device=self.model.device, - ) - def _stage6_prepare_world_velocity(self): wp.copy(self.v_out, self.v_hat) @@ -2482,11 +2382,6 @@ def _stage6_apply_impulses_world(self, size: int): device=model.device, ) - def _stage6b_mf_pgs(self, state_aug: State, dt: float): - """Run matrix-free PGS for free rigid body contacts.""" - self._mf_pgs_setup(state_aug, dt) - self._mf_pgs_solve(self.pgs_iterations) - def _mf_pgs_setup(self, state_aug: State, dt: float): """MF PGS setup: compute Hinv, compute effective mass and RHS.""" model = self.model @@ -2533,92 +2428,6 @@ def _mf_pgs_setup(self, state_aug: State, dt: float): device=model.device, ) - def _mf_pgs_solve(self, iterations: int): - """MF PGS solve with given iteration count.""" - model = self.model - - # Build compact body map for standalone MF kernel - wp.launch( - build_mf_body_map, - dim=self.world_count, - inputs=[ - self.mf_constraint_count, - self.mf_body_a, - self.mf_body_b, - self.body_to_articulation, - self.articulation_dof_start, - self.max_mf_bodies, - ], - outputs=[ - self.mf_body_list, - self.mf_body_dof_start, - self.mf_body_count, - self.mf_local_body_a, - self.mf_local_body_b, - ], - device=model.device, - ) - - if model.device.is_cuda: - mf_pgs_kernel = TiledKernelFactory.get_pgs_solve_mf_kernel( - 
self.mf_max_constraints, self.max_mf_bodies, model.device - ) - wp.launch_tiled( - mf_pgs_kernel, - dim=[self.world_count], - inputs=[ - self.mf_constraint_count, - self.mf_body_count, - self.mf_body_dof_start, - self.mf_local_body_a, - self.mf_local_body_b, - self.mf_J_a, - self.mf_J_b, - self.mf_MiJt_a, - self.mf_MiJt_b, - self.mf_eff_mass_inv, - self.mf_rhs, - self.mf_row_type, - self.mf_row_parent, - self.mf_row_mu, - self.mf_impulses, - self.v_out, - iterations, - self.pgs_omega, - ], - block_dim=32, - device=model.device, - ) - else: - # CPU fallback: use the loop kernel - wp.launch( - pgs_solve_mf_loop, - dim=self.world_count, - inputs=[ - self.mf_constraint_count, - self.mf_body_a, - self.mf_body_b, - self.mf_MiJt_a, - self.mf_MiJt_b, - self.mf_J_a, - self.mf_J_b, - self.mf_eff_mass_inv, - self.mf_rhs, - self.mf_row_type, - self.mf_row_parent, - self.mf_row_mu, - self.body_to_articulation, - self.articulation_dof_start, - iterations, - self.pgs_omega, - ], - outputs=[ - self.mf_impulses, - self.v_out, - ], - device=model.device, - ) - def _stage6_update_qdd(self, state_in: State, state_aug: State, dt: float): model = self.model if self._has_root_free: @@ -2685,10 +2494,6 @@ class TiledKernelFactory: _hinv_jt_cache: ClassVar[dict[tuple[int, int, str], "wp.Kernel"]] = {} _hinv_jt_fused_cache: ClassVar[dict[tuple[int, int, str], "wp.Kernel"]] = {} _cholesky_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} - _pgs_solve_tiled_row_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} - _pgs_solve_tiled_contact_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} - _pgs_solve_streaming_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} - _pgs_solve_mf_cache: ClassVar[dict[tuple[int, int, str], "wp.Kernel"]] = {} _pgs_solve_mf_gs_cache: ClassVar[dict[tuple[int, int, str], "wp.Kernel"]] = {} _pack_mf_meta_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} _triangular_solve_cache: ClassVar[dict[tuple[int, str], "wp.Kernel"]] = {} @@ 
-3004,1050 +2809,6 @@ def trisolve_tiled_template( trisolve_tiled_template.__qualname__ = f"trisolve_tiled_{n_dofs}" return wp.kernel(enable_backward=False, module="unique")(trisolve_tiled_template) - @classmethod - def get_pgs_solve_tiled_row_kernel(cls, max_constraints: int, device: "wp.Device") -> "wp.Kernel": - """Get or create a tiled row-wise PGS world solve kernel for the given constraint count.""" - key = (max_constraints, device.arch) - if key not in cls._pgs_solve_tiled_row_cache: - cls._pgs_solve_tiled_row_cache[key] = cls._build_pgs_solve_tiled_row_kernel(max_constraints) - return cls._pgs_solve_tiled_row_cache[key] - - @classmethod - def _build_pgs_solve_tiled_row_kernel(cls, max_constraints: int) -> "wp.Kernel": - """PGS world solve kernel that stages only the LOWER triangle of Delassus. - - Shared memory footprint drops from M*M to M*(M+1)/2 floats. - Uses symmetry in dot: C(i,j) = L(i,j) if j<=i else L(j,i). - """ - TILE_M = max_constraints - TILE_M_SQ = TILE_M * TILE_M - TILE_TRI = TILE_M * (TILE_M + 1) // 2 - - ELEMS_PER_THREAD_1D = (TILE_M + 31) // 32 - - def gen_load_1d(dst, src): - return "\n".join( - [ - f" {dst}[lane + {k * 32}] = {src}.data[off1 + lane + {k * 32}];" - for k in range(ELEMS_PER_THREAD_1D) - if (k * 32) < TILE_M - ] - ) - - # Build a deterministic packed-lower-tri index order: row-major over (i, j<=i) - # idx = i*(i+1)/2 + j - tri_pairs = [] - for i in range(TILE_M): - base = i * (i + 1) // 2 - for j in range(i + 1): - tri_pairs.append((base + j, i, j)) - assert len(tri_pairs) == TILE_TRI - - load_code = "\n".join( - [ - gen_load_1d("s_lam", "world_impulses"), - gen_load_1d("s_rhs", "world_rhs"), - gen_load_1d("s_diag", "world_diag"), - gen_load_1d("s_rtype", "world_row_type"), - gen_load_1d("s_parent", "world_row_parent"), - gen_load_1d("s_mu", "world_row_mu"), - ] - ) - - # Precompute lane's column indices (j_k) and their triangular bases (j_k*(j_k+1)/2) - # so inside the dot we avoid multiply. 
- precompute_j = [] - for k in range(ELEMS_PER_THREAD_1D): - j = k * 32 - if j < TILE_M: - precompute_j.append( - f" const int j{k} = lane + {j};\n const int jb{k} = (j{k} * (j{k} + 1)) >> 1;" - ) - precompute_j_code = "\n".join(precompute_j) - - # Dot code: guarded on j_k < m - dot_terms = [] - for k in range(ELEMS_PER_THREAD_1D): - joff = k * 32 - if joff < TILE_M: - dot_terms.append( - f""" if (j{k} < m) {{ - // Use symmetry to fetch C(i, j{k}) from packed-lower shared. - // base_i = i*(i+1)/2 - float cij = (j{k} <= i) ? s_Ctri[base_i + j{k}] : s_Ctri[jb{k} + i]; - my_sum += cij * s_lam[j{k}]; - }}""" - ) - dot_code = "\n".join(["float my_sum = 0.0f;", "int base_i = (i * (i + 1)) >> 1;", *dot_terms]) - - store_code = "\n".join( - [ - f" world_impulses.data[off1 + lane + {k * 32}] = s_lam[lane + {k * 32}];" - for k in range(ELEMS_PER_THREAD_1D) - if (k * 32) < TILE_M - ] - ) - - snippet = f""" - #if defined(__CUDA_ARCH__) - const int TILE_M = {TILE_M}; - const int TILE_M_SQ = {TILE_M_SQ}; - const int TILE_TRI = {TILE_TRI}; - const unsigned MASK = 0xFFFFFFFF; - - int lane = threadIdx.x; - - int m = world_constraint_count.data[world]; - if (m == 0) return; - - // Packed LOWER triangle of C in row-major (i*(i+1)/2 + j), j<=i - __shared__ float s_Ctri[TILE_TRI]; - - __shared__ float s_lam[TILE_M]; - __shared__ float s_rhs[TILE_M]; - __shared__ float s_diag[TILE_M]; - __shared__ int s_rtype[TILE_M]; - __shared__ int s_parent[TILE_M]; - __shared__ float s_mu[TILE_M]; - - int off1 = world * TILE_M; - int off2 = world * TILE_M_SQ; - - {load_code} - - // Load only lower triangle from global full matrix into packed shared. - // Work distribution: each lane walks rows; for each row i, lane loads j = lane, lane+32, lane+64... 
- for (int i = 0; i < TILE_M; ++i) {{ - int base = (i * (i + 1)) >> 1; // packed base for row i - for (int j = lane; j <= i; j += 32) {{ - s_Ctri[base + j] = world_C.data[off2 + i * TILE_M + j]; - }} - }} - __syncwarp(); - - {precompute_j_code} - - for (int iter = 0; iter < iterations; iter++) {{ - for (int i = 0; i < m; i++) {{ - // NOTE: single-warp kernel; __syncwarp here is typically unnecessary unless divergence occurs - // before the dot. If you want max perf, try removing it after verifying correctness. - // __syncwarp(); - - {dot_code} - - // Warp reduce my_sum - my_sum += __shfl_down_sync(MASK, my_sum, 16); - my_sum += __shfl_down_sync(MASK, my_sum, 8); - my_sum += __shfl_down_sync(MASK, my_sum, 4); - my_sum += __shfl_down_sync(MASK, my_sum, 2); - my_sum += __shfl_down_sync(MASK, my_sum, 1); - float dot_sum = __shfl_sync(MASK, my_sum, 0); - - float denom = s_diag[i]; - if (denom <= 0.0f) continue; - - float w_val = s_rhs[i] + dot_sum; - float delta = -w_val / denom; - float new_impulse = s_lam[i] + omega * delta; - int row_type = s_rtype[i]; - - if (row_type == 0 || row_type == 3) {{ - if (new_impulse < 0.0f) new_impulse = 0.0f; - s_lam[i] = new_impulse; - }} else if (row_type == 2) {{ - int parent_idx = s_parent[i]; - float lambda_n = s_lam[parent_idx]; - float mu = s_mu[i]; - float radius = fmaxf(mu * lambda_n, 0.0f); - - if (radius <= 0.0f) {{ - s_lam[i] = 0.0f; - }} else {{ - s_lam[i] = new_impulse; - int sib = (i == parent_idx + 1) ? 
(parent_idx + 2) : (parent_idx + 1); - float a = s_lam[i]; - float b = s_lam[sib]; - float mag = sqrtf(a * a + b * b); - if (mag > radius) {{ - float scale = radius / mag; - s_lam[i] = a * scale; - s_lam[sib] = b * scale; - }} - }} - }} else {{ - s_lam[i] = new_impulse; - }} - }} - }} - - {store_code} - #endif - """ - - @wp.func_native(snippet) - def pgs_solve_native( - world: int, - world_constraint_count: wp.array[int], - world_diag: wp.array2d[float], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: float, - world_row_type: wp.array2d[int], - world_row_parent: wp.array2d[int], - world_row_mu: wp.array2d[float], - ): ... - - def pgs_solve_tiled_template( - world_constraint_count: wp.array[int], - world_diag: wp.array2d[float], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: float, - world_row_type: wp.array2d[int], - world_row_parent: wp.array2d[int], - world_row_mu: wp.array2d[float], - ): - world, _lane = wp.tid() - pgs_solve_native( - world, - world_constraint_count, - world_diag, - world_C, - world_rhs, - world_impulses, - iterations, - omega, - world_row_type, - world_row_parent, - world_row_mu, - ) - - pgs_solve_tiled_template.__name__ = f"pgs_solve_tiled_row_{max_constraints}" - pgs_solve_tiled_template.__qualname__ = f"pgs_solve_tiled_row_{max_constraints}" - return wp.kernel(enable_backward=False, module="unique")(pgs_solve_tiled_template) - - @classmethod - def get_pgs_solve_tiled_contact_kernel(cls, max_constraints: int, device: "wp.Device") -> "wp.Kernel": - """Get or create a tiled contact-wise PGS world solve kernel using 3x3 block formulation.""" - key = (max_constraints, device.arch) - if key not in cls._pgs_solve_tiled_contact_cache: - cls._pgs_solve_tiled_contact_cache[key] = cls._build_pgs_solve_tiled_contact_kernel(max_constraints) - return cls._pgs_solve_tiled_contact_cache[key] - - 
@classmethod - def _build_pgs_solve_tiled_contact_kernel(cls, max_constraints: int) -> "wp.Kernel": - """PGS world solve kernel using 3x3 block formulation. - - Stores only the LOWER triangle of block Delassus matrix. - Each contact is a 3-vector (normal, tangent1, tangent2). - Reduces serial depth from M to M/3. - - TILE_M can be any value (power of 2 recommended for other kernels). - Runtime m must be divisible by 3. - """ - TILE_M = max_constraints - # Max contacts we can handle (rounded down) - NUM_CONTACTS_MAX = TILE_M // 3 - # Actual max constraints we'll process (may be < TILE_M) - TILE_M_USABLE = NUM_CONTACTS_MAX * 3 - - # Lower triangle of block matrix (sized for max) - NUM_BLOCKS_TRI = NUM_CONTACTS_MAX * (NUM_CONTACTS_MAX + 1) // 2 - BLOCK_TRI_FLOATS = NUM_BLOCKS_TRI * 9 - - snippet = f""" - #if defined(__CUDA_ARCH__) - const int TILE_M = {TILE_M}; - const int TILE_M_USABLE = {TILE_M_USABLE}; - const int NUM_CONTACTS_MAX = {NUM_CONTACTS_MAX}; - const int BLOCK_TRI_FLOATS = {BLOCK_TRI_FLOATS}; - const unsigned MASK = 0xFFFFFFFF; - - int lane = threadIdx.x; - - int m = world_constraint_count.data[world]; - if (m == 0) return; - - // Clamp m to usable range and ensure divisible by 3 - if (m > TILE_M_USABLE) m = TILE_M_USABLE; - int num_contacts = m / 3; - - // Shared memory (sized for max) - __shared__ float s_Dtri[BLOCK_TRI_FLOATS]; - __shared__ float s_Dinv[NUM_CONTACTS_MAX * 9]; - __shared__ float s_lam[TILE_M_USABLE]; - __shared__ float s_rhs[TILE_M_USABLE]; - __shared__ float s_mu[NUM_CONTACTS_MAX]; - - int off1 = world * TILE_M; - int off2 = world * TILE_M * TILE_M; - - // ============ LOAD PHASE ============ - - // Load lambda and rhs - for (int i = lane; i < TILE_M_USABLE; i += 32) {{ - if (i < m) {{ - s_lam[i] = world_impulses.data[off1 + i]; - s_rhs[i] = world_rhs.data[off1 + i]; - }} else {{ - s_lam[i] = 0.0f; - s_rhs[i] = 0.0f; - }} - }} - - // Load mu (one per contact, stored on tangent1 row) - for (int c = lane; c < NUM_CONTACTS_MAX; c += 32) 
{{ - if (c < num_contacts) {{ - s_mu[c] = world_row_mu.data[off1 + c * 3 + 1]; - }} - }} - - // Load lower triangle of block Delassus - for (int c = 0; c < num_contacts; c++) {{ - int base_block = (c * (c + 1)) >> 1; - int floats_in_row = (c + 1) * 9; - - for (int f = lane; f < floats_in_row; f += 32) {{ - int j = f / 9; - int k = f % 9; - int lr = k / 3; - int lc = k % 3; - int gr = c * 3 + lr; - int gc = j * 3 + lc; - s_Dtri[(base_block + j) * 9 + k] = world_C.data[off2 + gr * TILE_M + gc]; - }} - }} - __syncwarp(); - - // Compute diagonal block inverses - for (int c = lane; c < num_contacts; c += 32) {{ - int diag_block_idx = ((c * (c + 1)) >> 1) + c; - const float* D = &s_Dtri[diag_block_idx * 9]; - float* Dinv = &s_Dinv[c * 9]; - - float det = D[0] * (D[4] * D[8] - D[5] * D[7]) - - D[1] * (D[3] * D[8] - D[5] * D[6]) - + D[2] * (D[3] * D[7] - D[4] * D[6]); - - float inv_det = 1.0f / det; - - Dinv[0] = (D[4] * D[8] - D[5] * D[7]) * inv_det; - Dinv[1] = (D[2] * D[7] - D[1] * D[8]) * inv_det; - Dinv[2] = (D[1] * D[5] - D[2] * D[4]) * inv_det; - Dinv[3] = (D[5] * D[6] - D[3] * D[8]) * inv_det; - Dinv[4] = (D[0] * D[8] - D[2] * D[6]) * inv_det; - Dinv[5] = (D[2] * D[3] - D[0] * D[5]) * inv_det; - Dinv[6] = (D[3] * D[7] - D[4] * D[6]) * inv_det; - Dinv[7] = (D[1] * D[6] - D[0] * D[7]) * inv_det; - Dinv[8] = (D[0] * D[4] - D[1] * D[3]) * inv_det; - }} - __syncwarp(); - - // ============ ITERATION PHASE ============ - - for (int iter = 0; iter < iterations; iter++) {{ - for (int c = 0; c < num_contacts; c++) {{ - float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f; - - for (int j = lane; j < num_contacts; j += 32) {{ - float l0 = s_lam[j * 3 + 0]; - float l1 = s_lam[j * 3 + 1]; - float l2 = s_lam[j * 3 + 2]; - - int block_off; - bool transpose; - if (j <= c) {{ - block_off = (((c * (c + 1)) >> 1) + j) * 9; - transpose = false; - }} else {{ - block_off = (((j * (j + 1)) >> 1) + c) * 9; - transpose = true; - }} - - const float* B = &s_Dtri[block_off]; - - if (!transpose) {{ - 
sum0 += B[0] * l0 + B[1] * l1 + B[2] * l2; - sum1 += B[3] * l0 + B[4] * l1 + B[5] * l2; - sum2 += B[6] * l0 + B[7] * l1 + B[8] * l2; - }} else {{ - sum0 += B[0] * l0 + B[3] * l1 + B[6] * l2; - sum1 += B[1] * l0 + B[4] * l1 + B[7] * l2; - sum2 += B[2] * l0 + B[5] * l1 + B[8] * l2; - }} - }} - - // Warp reduce - sum0 += __shfl_down_sync(MASK, sum0, 16); - sum1 += __shfl_down_sync(MASK, sum1, 16); - sum2 += __shfl_down_sync(MASK, sum2, 16); - sum0 += __shfl_down_sync(MASK, sum0, 8); - sum1 += __shfl_down_sync(MASK, sum1, 8); - sum2 += __shfl_down_sync(MASK, sum2, 8); - sum0 += __shfl_down_sync(MASK, sum0, 4); - sum1 += __shfl_down_sync(MASK, sum1, 4); - sum2 += __shfl_down_sync(MASK, sum2, 4); - sum0 += __shfl_down_sync(MASK, sum0, 2); - sum1 += __shfl_down_sync(MASK, sum1, 2); - sum2 += __shfl_down_sync(MASK, sum2, 2); - sum0 += __shfl_down_sync(MASK, sum0, 1); - sum1 += __shfl_down_sync(MASK, sum1, 1); - sum2 += __shfl_down_sync(MASK, sum2, 1); - - if (lane == 0) {{ - // Corrected sign: -(rhs + D*lambda) - float res0 = -(s_rhs[c * 3 + 0] + sum0); - float res1 = -(s_rhs[c * 3 + 1] + sum1); - float res2 = -(s_rhs[c * 3 + 2] + sum2); - - const float* Dinv = &s_Dinv[c * 9]; - float d0 = Dinv[0] * res0 + Dinv[1] * res1 + Dinv[2] * res2; - float d1 = Dinv[3] * res0 + Dinv[4] * res1 + Dinv[5] * res2; - float d2 = Dinv[6] * res0 + Dinv[7] * res1 + Dinv[8] * res2; - - float new_n = s_lam[c * 3 + 0] + omega * d0; - float new_t1 = s_lam[c * 3 + 1] + omega * d1; - float new_t2 = s_lam[c * 3 + 2] + omega * d2; - - // Friction cone projection - new_n = fmaxf(new_n, 0.0f); - - float mu = s_mu[c]; - float radius = mu * new_n; - - if (radius <= 0.0f) {{ - new_t1 = 0.0f; - new_t2 = 0.0f; - }} else {{ - float t_mag_sq = new_t1 * new_t1 + new_t2 * new_t2; - if (t_mag_sq > radius * radius) {{ - float scale = radius * rsqrtf(t_mag_sq); - new_t1 *= scale; - new_t2 *= scale; - }} - }} - - s_lam[c * 3 + 0] = new_n; - s_lam[c * 3 + 1] = new_t1; - s_lam[c * 3 + 2] = new_t2; - }} - 
__syncwarp(); - }} - }} - - // ============ STORE PHASE ============ - - for (int i = lane; i < TILE_M_USABLE; i += 32) {{ - if (i < m) {{ - world_impulses.data[off1 + i] = s_lam[i]; - }} - }} - #endif - """ - - @wp.func_native(snippet) - def pgs_solve_contact_native( - world: int, - world_constraint_count: wp.array[int], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: float, - world_row_mu: wp.array2d[float], - ): ... - - def pgs_solve_tiled_contact_template( - world_constraint_count: wp.array[int], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: float, - world_row_mu: wp.array2d[float], - ): - world, _lane = wp.tid() - pgs_solve_contact_native( - world, - world_constraint_count, - world_C, - world_rhs, - world_impulses, - iterations, - omega, - world_row_mu, - ) - - pgs_solve_tiled_contact_template.__name__ = f"pgs_solve_tiled_contact_{max_constraints}" - pgs_solve_tiled_contact_template.__qualname__ = f"pgs_solve_tiled_contact_{max_constraints}" - return wp.kernel(enable_backward=False, module="unique")(pgs_solve_tiled_contact_template) - - @classmethod - def get_pgs_solve_streaming_kernel( - cls, max_constraints: int, device: "wp.Device", pgs_chunk_size: int = 1 - ) -> "wp.Kernel": - """Get or create a streaming contact-wise PGS world solve kernel.""" - key = (max_constraints, device.arch, pgs_chunk_size) - if key not in cls._pgs_solve_streaming_cache: - cls._pgs_solve_streaming_cache[key] = cls._build_pgs_solve_streaming_kernel(max_constraints, pgs_chunk_size) - return cls._pgs_solve_streaming_cache[key] - - @classmethod - def _build_pgs_solve_streaming_kernel(cls, max_constraints: int, pgs_chunk_size: int = 1) -> "wp.Kernel": - """Streaming contact-wise PGS kernel that streams block-rows from global memory. 
- - Unlike tiled_contact which loads the entire Delassus matrix into shared memory, - this kernel keeps only lambda and auxiliaries in shared memory and streams - block-rows of C on demand. This enables handling much larger constraint counts - (hundreds of contacts) at the cost of increased global memory bandwidth. - - When pgs_chunk_size > 1, multiple block-rows are preloaded into shared memory - at once, reducing the number of global memory round-trips per PGS iteration. - - Algorithm: - - Load lambda, rhs, mu, and compute diagonal block inverses once - - For each PGS iteration: - - For each chunk of pgs_chunk_size contacts: - - Preload pgs_chunk_size block-rows of C into shared memory - - For each contact c in the chunk: - - Compute block-row dot product with lambda (warp-parallel) - - Update lambda[c] with friction cone projection (lane 0) - - Store final lambda back to global memory - """ - TILE_M = max_constraints - NUM_CONTACTS_MAX = TILE_M // 3 - TILE_M_USABLE = NUM_CONTACTS_MAX * 3 - PGS_CHUNK = pgs_chunk_size - - snippet = f""" - #if defined(__CUDA_ARCH__) - const int TILE_M = {TILE_M}; - const int TILE_M_USABLE = {TILE_M_USABLE}; - const int NUM_CONTACTS_MAX = {NUM_CONTACTS_MAX}; - const int PGS_CHUNK = {PGS_CHUNK}; - const unsigned MASK = 0xFFFFFFFF; - - int lane = threadIdx.x; - - int m = world_constraint_count.data[world]; - if (m == 0) return; - - // Clamp m to usable range and ensure divisible by 3 - if (m > TILE_M_USABLE) m = TILE_M_USABLE; - int num_contacts = m / 3; - - // ═══════════════════════════════════════════════════════════════ - // SHARED MEMORY: lambda, rhs, mu, diagonal inverses, and - // block-row buffer for PGS_CHUNK contacts at a time - // ═══════════════════════════════════════════════════════════════ - __shared__ float s_lam[{TILE_M_USABLE}]; - __shared__ float s_rhs[{TILE_M_USABLE}]; - __shared__ float s_mu[{NUM_CONTACTS_MAX}]; - __shared__ float s_Dinv[{NUM_CONTACTS_MAX} * 9]; - __shared__ float s_block_rows[{PGS_CHUNK} * 
{NUM_CONTACTS_MAX} * 9]; - - int off1 = world * TILE_M; - int off2 = world * TILE_M * TILE_M; - - // ═══════════════════════════════════════════════════════════════ - // LOAD PHASE: Load persistent data into shared memory - // ═══════════════════════════════════════════════════════════════ - - // Load lambda and rhs (coalesced) - for (int i = lane; i < TILE_M_USABLE; i += 32) {{ - if (i < m) {{ - s_lam[i] = world_impulses.data[off1 + i]; - s_rhs[i] = world_rhs.data[off1 + i]; - }} else {{ - s_lam[i] = 0.0f; - s_rhs[i] = 0.0f; - }} - }} - - // Load mu (one per contact, stored on tangent1 row) - for (int c = lane; c < NUM_CONTACTS_MAX; c += 32) {{ - if (c < num_contacts) {{ - s_mu[c] = world_row_mu.data[off1 + c * 3 + 1]; - }} - }} - __syncwarp(); - - // Compute diagonal block inverses (each thread handles one contact) - for (int c = lane; c < num_contacts; c += 32) {{ - // Load diagonal block D[c,c] from global memory - int diag_row = c * 3; - float D[9]; - for (int k = 0; k < 9; k++) {{ - int lr = k / 3; - int lc = k % 3; - D[k] = world_C.data[off2 + (diag_row + lr) * TILE_M + (diag_row + lc)]; - }} - - // Compute 3x3 inverse - float det = D[0] * (D[4] * D[8] - D[5] * D[7]) - - D[1] * (D[3] * D[8] - D[5] * D[6]) - + D[2] * (D[3] * D[7] - D[4] * D[6]); - - float inv_det = 1.0f / det; - float* Dinv = &s_Dinv[c * 9]; - - Dinv[0] = (D[4] * D[8] - D[5] * D[7]) * inv_det; - Dinv[1] = (D[2] * D[7] - D[1] * D[8]) * inv_det; - Dinv[2] = (D[1] * D[5] - D[2] * D[4]) * inv_det; - Dinv[3] = (D[5] * D[6] - D[3] * D[8]) * inv_det; - Dinv[4] = (D[0] * D[8] - D[2] * D[6]) * inv_det; - Dinv[5] = (D[2] * D[3] - D[0] * D[5]) * inv_det; - Dinv[6] = (D[3] * D[7] - D[4] * D[6]) * inv_det; - Dinv[7] = (D[1] * D[6] - D[0] * D[7]) * inv_det; - Dinv[8] = (D[0] * D[4] - D[1] * D[3]) * inv_det; - }} - __syncwarp(); - - // ═══════════════════════════════════════════════════════════════ - // ITERATION PHASE: Stream block-rows in chunks and solve - // 
═══════════════════════════════════════════════════════════════ - - for (int iter = 0; iter < iterations; iter++) {{ - for (int chunk_start = 0; chunk_start < num_contacts; chunk_start += PGS_CHUNK) {{ - int chunk_end = min(chunk_start + PGS_CHUNK, num_contacts); - int chunk_len = chunk_end - chunk_start; - - // ───────────────────────────────────────────────────────── - // STREAM: Preload chunk_len block-rows of Delassus matrix - // ───────────────────────────────────────────────────────── - for (int ci = 0; ci < chunk_len; ci++) {{ - int c = chunk_start + ci; - int c_row = c * 3; - float* row_base = &s_block_rows[ci * NUM_CONTACTS_MAX * 9]; - for (int j = lane; j < num_contacts; j += 32) {{ - int j_col = j * 3; - float* dst = &row_base[j * 9]; - for (int k = 0; k < 9; k++) {{ - int lr = k / 3; - int lc = k % 3; - dst[k] = world_C.data[off2 + (c_row + lr) * TILE_M + (j_col + lc)]; - }} - }} - }} - __syncwarp(); - - // ───────────────────────────────────────────────────────── - // SOLVE: Process each contact in the chunk sequentially - // ───────────────────────────────────────────────────────── - for (int ci = 0; ci < chunk_len; ci++) {{ - int c = chunk_start + ci; - const float* row_base = &s_block_rows[ci * NUM_CONTACTS_MAX * 9]; - - // Block-row dot product sum_j C[c,j] * lambda[j] - float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f; - - for (int j = lane; j < num_contacts; j += 32) {{ - float l0 = s_lam[j * 3 + 0]; - float l1 = s_lam[j * 3 + 1]; - float l2 = s_lam[j * 3 + 2]; - - const float* B = &row_base[j * 9]; - - sum0 += B[0] * l0 + B[1] * l1 + B[2] * l2; - sum1 += B[3] * l0 + B[4] * l1 + B[5] * l2; - sum2 += B[6] * l0 + B[7] * l1 + B[8] * l2; - }} - - // Warp reduce - sum0 += __shfl_down_sync(MASK, sum0, 16); - sum1 += __shfl_down_sync(MASK, sum1, 16); - sum2 += __shfl_down_sync(MASK, sum2, 16); - sum0 += __shfl_down_sync(MASK, sum0, 8); - sum1 += __shfl_down_sync(MASK, sum1, 8); - sum2 += __shfl_down_sync(MASK, sum2, 8); - sum0 += __shfl_down_sync(MASK, sum0, 
4); - sum1 += __shfl_down_sync(MASK, sum1, 4); - sum2 += __shfl_down_sync(MASK, sum2, 4); - sum0 += __shfl_down_sync(MASK, sum0, 2); - sum1 += __shfl_down_sync(MASK, sum1, 2); - sum2 += __shfl_down_sync(MASK, sum2, 2); - sum0 += __shfl_down_sync(MASK, sum0, 1); - sum1 += __shfl_down_sync(MASK, sum1, 1); - sum2 += __shfl_down_sync(MASK, sum2, 1); - - // Update: Solve and project (lane 0 only) - if (lane == 0) {{ - float res0 = -(s_rhs[c * 3 + 0] + sum0); - float res1 = -(s_rhs[c * 3 + 1] + sum1); - float res2 = -(s_rhs[c * 3 + 2] + sum2); - - const float* Dinv = &s_Dinv[c * 9]; - float d0 = Dinv[0] * res0 + Dinv[1] * res1 + Dinv[2] * res2; - float d1 = Dinv[3] * res0 + Dinv[4] * res1 + Dinv[5] * res2; - float d2 = Dinv[6] * res0 + Dinv[7] * res1 + Dinv[8] * res2; - - float new_n = s_lam[c * 3 + 0] + omega * d0; - float new_t1 = s_lam[c * 3 + 1] + omega * d1; - float new_t2 = s_lam[c * 3 + 2] + omega * d2; - - // Friction cone projection - new_n = fmaxf(new_n, 0.0f); - - float mu = s_mu[c]; - float radius = mu * new_n; - - if (radius <= 0.0f) {{ - new_t1 = 0.0f; - new_t2 = 0.0f; - }} else {{ - float t_mag_sq = new_t1 * new_t1 + new_t2 * new_t2; - if (t_mag_sq > radius * radius) {{ - float scale = radius * rsqrtf(t_mag_sq); - new_t1 *= scale; - new_t2 *= scale; - }} - }} - - s_lam[c * 3 + 0] = new_n; - s_lam[c * 3 + 1] = new_t1; - s_lam[c * 3 + 2] = new_t2; - }} - __syncwarp(); - }} - }} - }} - - // ═══════════════════════════════════════════════════════════════ - // STORE PHASE: Write final lambda back to global memory - // ═══════════════════════════════════════════════════════════════ - for (int i = lane; i < TILE_M_USABLE; i += 32) {{ - if (i < m) {{ - world_impulses.data[off1 + i] = s_lam[i]; - }} - }} - #endif - """ - - @wp.func_native(snippet) - def pgs_solve_streaming_native( - world: int, - world_constraint_count: wp.array[int], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: 
float, - world_row_mu: wp.array2d[float], - ): ... - - def pgs_solve_streaming_template( - world_constraint_count: wp.array[int], - world_C: wp.array3d[float], - world_rhs: wp.array2d[float], - world_impulses: wp.array2d[float], - iterations: int, - omega: float, - world_row_mu: wp.array2d[float], - ): - world, _lane = wp.tid() - pgs_solve_streaming_native( - world, - world_constraint_count, - world_C, - world_rhs, - world_impulses, - iterations, - omega, - world_row_mu, - ) - - pgs_solve_streaming_template.__name__ = f"pgs_solve_streaming_{max_constraints}_chunk{pgs_chunk_size}" - pgs_solve_streaming_template.__qualname__ = f"pgs_solve_streaming_{max_constraints}_chunk{pgs_chunk_size}" - return wp.kernel(enable_backward=False, module="unique")(pgs_solve_streaming_template) - - @classmethod - def get_pgs_solve_mf_kernel(cls, mf_max_constraints: int, max_mf_bodies: int, device: "wp.Device") -> "wp.Kernel": - """Get or create a streaming MF PGS kernel for free rigid body contacts.""" - key = (mf_max_constraints, max_mf_bodies, device.arch) - if key not in cls._pgs_solve_mf_cache: - cls._pgs_solve_mf_cache[key] = cls._build_pgs_solve_mf_kernel(mf_max_constraints, max_mf_bodies) - return cls._pgs_solve_mf_cache[key] - - @classmethod - def _build_pgs_solve_mf_kernel(cls, mf_max_constraints: int, max_mf_bodies: int) -> "wp.Kernel": - """Matrix-free PGS with body velocities and impulses in shared memory. - - Uses one warp (32 threads) per world. Body velocities and impulses live - in shared memory for the duration of all PGS iterations, eliminating - global memory round-trips. J, MiJt, eff_mass_inv, and rhs are read - from global memory per constraint (read-only, cache-friendly sequential access). 
- """ - MF_MAX_C = mf_max_constraints - MAX_BODIES = max_mf_bodies - - snippet = f""" - #if defined(__CUDA_ARCH__) - const int MF_MAX_C = {MF_MAX_C}; - const int MAX_BODIES = {MAX_BODIES}; - - int lane = threadIdx.x; - - int m = mf_constraint_count.data[world]; - if (m == 0) return; - if (m > MF_MAX_C) m = MF_MAX_C; - - int n_bodies = mf_body_count.data[world]; - if (n_bodies > MAX_BODIES) n_bodies = MAX_BODIES; - - // ═══════════════════════════════════════════════════════════════ - // SHARED MEMORY - // ═══════════════════════════════════════════════════════════════ - __shared__ float s_vel[{MAX_BODIES * 6}]; - __shared__ float s_impulse[{MF_MAX_C}]; - __shared__ int s_dof_start[{MAX_BODIES}]; - - int body_off = world * MAX_BODIES; - int c_off = world * MF_MAX_C; - - // ═══════════════════════════════════════════════════════════════ - // LOAD PHASE - // ═══════════════════════════════════════════════════════════════ - - // Load body DOF starts and velocities - for (int b = lane; b < n_bodies; b += 32) {{ - int dof = mf_body_dof_start.data[body_off + b]; - s_dof_start[b] = dof; - for (int k = 0; k < 6; k++) {{ - s_vel[b * 6 + k] = v_out.data[dof + k]; - }} - }} - - // Load impulses - for (int i = lane; i < m; i += 32) {{ - s_impulse[i] = mf_impulses.data[c_off + i]; - }} - __syncwarp(); - - // ═══════════════════════════════════════════════════════════════ - // SOLVE PHASE (lane 0) - // ═══════════════════════════════════════════════════════════════ - - if (lane == 0) {{ - for (int iter = 0; iter < iterations; iter++) {{ - for (int i = 0; i < m; i++) {{ - float eff_inv = mf_eff_mass_inv.data[c_off + i]; - if (eff_inv <= 0.0f) continue; - - int lba = mf_local_body_a.data[c_off + i]; - int lbb = mf_local_body_b.data[c_off + i]; - - // Load J from global memory - int j_base = (c_off + i) * 6; - float ja0 = mf_J_a.data[j_base + 0]; - float ja1 = mf_J_a.data[j_base + 1]; - float ja2 = mf_J_a.data[j_base + 2]; - float ja3 = mf_J_a.data[j_base + 3]; - float ja4 = 
mf_J_a.data[j_base + 4]; - float ja5 = mf_J_a.data[j_base + 5]; - - float jb0 = mf_J_b.data[j_base + 0]; - float jb1 = mf_J_b.data[j_base + 1]; - float jb2 = mf_J_b.data[j_base + 2]; - float jb3 = mf_J_b.data[j_base + 3]; - float jb4 = mf_J_b.data[j_base + 4]; - float jb5 = mf_J_b.data[j_base + 5]; - - // Compute J * v from shared memory - float jv = 0.0f; - if (lba >= 0) {{ - int va = lba * 6; - jv += ja0 * s_vel[va] + ja1 * s_vel[va+1] + ja2 * s_vel[va+2] - + ja3 * s_vel[va+3] + ja4 * s_vel[va+4] + ja5 * s_vel[va+5]; - }} - if (lbb >= 0) {{ - int vb = lbb * 6; - jv += jb0 * s_vel[vb] + jb1 * s_vel[vb+1] + jb2 * s_vel[vb+2] - + jb3 * s_vel[vb+3] + jb4 * s_vel[vb+4] + jb5 * s_vel[vb+5]; - }} - - // PGS update - float rhs_i = mf_rhs.data[c_off + i]; - float delta = -(jv + rhs_i) * eff_inv; - float old_impulse = s_impulse[i]; - float new_impulse = old_impulse + omega * delta; - - int row_type = mf_row_type.data[c_off + i]; - - // Project: contact or joint limit - if (row_type == 0 || row_type == 3) {{ - if (new_impulse < 0.0f) new_impulse = 0.0f; - }} - // Project: friction - else if (row_type == 2) {{ - int parent_idx = mf_row_parent.data[c_off + i]; - float lambda_n = s_impulse[parent_idx]; - float mu = mf_row_mu.data[c_off + i]; - float radius = fmaxf(mu * lambda_n, 0.0f); - - if (radius <= 0.0f) {{ - new_impulse = 0.0f; - }} else {{ - int sib = (i == parent_idx + 1) ? 
parent_idx + 2 : parent_idx + 1; - - s_impulse[i] = new_impulse; - float a_val = new_impulse; - float b_val = s_impulse[sib]; - float mag = sqrtf(a_val * a_val + b_val * b_val); - if (mag > radius) {{ - float scale = radius / mag; - new_impulse = a_val * scale; - float sib_new = b_val * scale; - float sib_delta = sib_new - b_val; - s_impulse[sib] = sib_new; - - // Apply sibling correction to body velocities - int sib_lba = mf_local_body_a.data[c_off + sib]; - int sib_lbb = mf_local_body_b.data[c_off + sib]; - int sib_j_base = (c_off + sib) * 6; - if (sib_lba >= 0) {{ - int sva = sib_lba * 6; - s_vel[sva+0] += mf_MiJt_a.data[sib_j_base+0] * sib_delta; - s_vel[sva+1] += mf_MiJt_a.data[sib_j_base+1] * sib_delta; - s_vel[sva+2] += mf_MiJt_a.data[sib_j_base+2] * sib_delta; - s_vel[sva+3] += mf_MiJt_a.data[sib_j_base+3] * sib_delta; - s_vel[sva+4] += mf_MiJt_a.data[sib_j_base+4] * sib_delta; - s_vel[sva+5] += mf_MiJt_a.data[sib_j_base+5] * sib_delta; - }} - if (sib_lbb >= 0) {{ - int svb = sib_lbb * 6; - s_vel[svb+0] += mf_MiJt_b.data[sib_j_base+0] * sib_delta; - s_vel[svb+1] += mf_MiJt_b.data[sib_j_base+1] * sib_delta; - s_vel[svb+2] += mf_MiJt_b.data[sib_j_base+2] * sib_delta; - s_vel[svb+3] += mf_MiJt_b.data[sib_j_base+3] * sib_delta; - s_vel[svb+4] += mf_MiJt_b.data[sib_j_base+4] * sib_delta; - s_vel[svb+5] += mf_MiJt_b.data[sib_j_base+5] * sib_delta; - }} - }} - }} - }} - - float delta_impulse = new_impulse - old_impulse; - s_impulse[i] = new_impulse; - - // Apply velocity correction: v += MiJt * delta_impulse - int mijt_base = (c_off + i) * 6; - if (lba >= 0) {{ - int va = lba * 6; - s_vel[va+0] += mf_MiJt_a.data[mijt_base+0] * delta_impulse; - s_vel[va+1] += mf_MiJt_a.data[mijt_base+1] * delta_impulse; - s_vel[va+2] += mf_MiJt_a.data[mijt_base+2] * delta_impulse; - s_vel[va+3] += mf_MiJt_a.data[mijt_base+3] * delta_impulse; - s_vel[va+4] += mf_MiJt_a.data[mijt_base+4] * delta_impulse; - s_vel[va+5] += mf_MiJt_a.data[mijt_base+5] * delta_impulse; - }} - if (lbb >= 
0) {{ - int vb = lbb * 6; - s_vel[vb+0] += mf_MiJt_b.data[mijt_base+0] * delta_impulse; - s_vel[vb+1] += mf_MiJt_b.data[mijt_base+1] * delta_impulse; - s_vel[vb+2] += mf_MiJt_b.data[mijt_base+2] * delta_impulse; - s_vel[vb+3] += mf_MiJt_b.data[mijt_base+3] * delta_impulse; - s_vel[vb+4] += mf_MiJt_b.data[mijt_base+4] * delta_impulse; - s_vel[vb+5] += mf_MiJt_b.data[mijt_base+5] * delta_impulse; - }} - }} - }} - }} - __syncwarp(); - - // ═══════════════════════════════════════════════════════════════ - // STORE PHASE - // ═══════════════════════════════════════════════════════════════ - - // Write body velocities back to v_out - for (int b = lane; b < n_bodies; b += 32) {{ - int dof = s_dof_start[b]; - for (int k = 0; k < 6; k++) {{ - v_out.data[dof + k] = s_vel[b * 6 + k]; - }} - }} - - // Write impulses back - for (int i = lane; i < m; i += 32) {{ - mf_impulses.data[c_off + i] = s_impulse[i]; - }} - #endif - """ - - @wp.func_native(snippet) - def pgs_solve_mf_native( - world: int, - mf_constraint_count: wp.array[int], - mf_body_count: wp.array[int], - mf_body_dof_start: wp.array2d[int], - mf_local_body_a: wp.array2d[int], - mf_local_body_b: wp.array2d[int], - mf_J_a: wp.array3d[float], - mf_J_b: wp.array3d[float], - mf_MiJt_a: wp.array3d[float], - mf_MiJt_b: wp.array3d[float], - mf_eff_mass_inv: wp.array2d[float], - mf_rhs: wp.array2d[float], - mf_row_type: wp.array2d[int], - mf_row_parent: wp.array2d[int], - mf_row_mu: wp.array2d[float], - mf_impulses: wp.array2d[float], - v_out: wp.array[float], - iterations: int, - omega: float, - ): ... 
- - def pgs_solve_mf_template( - mf_constraint_count: wp.array[int], - mf_body_count: wp.array[int], - mf_body_dof_start: wp.array2d[int], - mf_local_body_a: wp.array2d[int], - mf_local_body_b: wp.array2d[int], - mf_J_a: wp.array3d[float], - mf_J_b: wp.array3d[float], - mf_MiJt_a: wp.array3d[float], - mf_MiJt_b: wp.array3d[float], - mf_eff_mass_inv: wp.array2d[float], - mf_rhs: wp.array2d[float], - mf_row_type: wp.array2d[int], - mf_row_parent: wp.array2d[int], - mf_row_mu: wp.array2d[float], - mf_impulses: wp.array2d[float], - v_out: wp.array[float], - iterations: int, - omega: float, - ): - world, _lane = wp.tid() - pgs_solve_mf_native( - world, - mf_constraint_count, - mf_body_count, - mf_body_dof_start, - mf_local_body_a, - mf_local_body_b, - mf_J_a, - mf_J_b, - mf_MiJt_a, - mf_MiJt_b, - mf_eff_mass_inv, - mf_rhs, - mf_row_type, - mf_row_parent, - mf_row_mu, - mf_impulses, - v_out, - iterations, - omega, - ) - - name = f"pgs_solve_mf_{mf_max_constraints}_{max_mf_bodies}" - pgs_solve_mf_template.__name__ = name - pgs_solve_mf_template.__qualname__ = name - return wp.kernel(enable_backward=False, module="unique")(pgs_solve_mf_template) - @classmethod def get_pack_mf_meta_kernel(cls, mf_max_constraints: int, device: "wp.Device") -> "wp.Kernel": """Get or create a kernel to pack MF metadata into int4 format.""" @@ -4135,11 +2896,11 @@ def pack_mf_meta_template( def get_pgs_solve_mf_gs_kernel( cls, max_constraints: int, mf_max_constraints: int, max_world_dofs: int, device: "wp.Device" ) -> "wp.Kernel": - """Get or create a two-phase GS kernel for matrix-free articulated PGS. + """Get or create the unified GS kernel for the matrix-free winner path. - Phase 1 processes dense constraints (via J_world/Y_world at ``max_constraints``). - Phase 2 processes MF constraints (via mf_J/mf_MiJt at ``mf_max_constraints``). - Both phases share a single velocity vector in shared memory. 
+ Articulated rows use the gathered world-space `J_world` and `Y_world` + inputs, while free-rigid rows use the packed matrix-free metadata. Both + row sets share one world-velocity tile in shared memory. """ key = (max_constraints, mf_max_constraints, max_world_dofs, device.arch) if key not in cls._pgs_solve_mf_gs_cache: @@ -4152,17 +2913,18 @@ def get_pgs_solve_mf_gs_kernel( def _build_pgs_solve_mf_gs_kernel( cls, max_constraints: int, mf_max_constraints: int, max_world_dofs: int ) -> "wp.Kernel": - """Two-phase GS PGS kernel: dense + matrix-free in one pass. + """Build the unified GS kernel for articulated and free-rigid rows. Uses one warp (32 threads) per world. - Phase 1 (dense): warp-parallel dot/update over D DOFs using J_world/Y_world. - Phase 2 (MF): lanes 0-5 handle body_a, lanes 6-11 handle body_b (6 DOFs each). + The articulated-contact rows use warp-parallel dot/update work over the + world DOF tile. The free-rigid rows use 6-DOF body slices for body A and + body B while reusing the same shared velocity state. 
Shared memory layout: s_v[D] — world velocity - s_lam_dense[M_D] + metadata — dense impulses and constraint info - s_lam_mf[M_MF] — MF impulses (metadata read from global per constraint) + s_lam_dense[M_D] + metadata — articulated-row impulses and row info + s_lam_mf[M_MF] — free-rigid impulses (metadata read from global) """ M_D = max_constraints M_MF = mf_max_constraints @@ -4171,7 +2933,7 @@ def _build_pgs_solve_mf_gs_kernel( # How many DOF elements each lane handles (ceil(D/32)) ELEMS_PER_LANE = (D + 31) // 32 - # --- Code generation for dense phase (D-wide dot/update, software-pipelined) --- + # --- Code generation for articulated rows (D-wide dot/update, software-pipelined) --- # Pipeline register declarations dense_pipe_decl = "\n".join( @@ -4292,7 +3054,7 @@ def _build_pgs_solve_mf_gs_kernel( for (int iter = 0; iter < iterations; iter++) {{ - // ── Phase 1: Dense constraints (D-DOF warp-parallel, software-pipelined) ── + // ── Articulated rows (D-DOF warp-parallel, software-pipelined) ── // Prefetch constraint 0 if (m_dense > 0) {{ @@ -4368,7 +3130,7 @@ def _build_pgs_solve_mf_gs_kernel( __syncwarp(); }} - // ── Phase 2: MF constraints (6-DOF per body, software-pipelined) ── + // ── Free-rigid rows (6-DOF per body, software-pipelined) ── // Pipeline registers: prefetch next constraint's global data int4 pre_meta; From cf82b4386e7f0c40c742863afda4d1a817d592e1 Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 12:56:39 -0400 Subject: [PATCH 4/7] Close FeatherPGS support cleanup --- .../execplans/fpgs-private-api-matrix-free.md | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.agent/execplans/fpgs-private-api-matrix-free.md b/.agent/execplans/fpgs-private-api-matrix-free.md index 6a4abe1d06..f8e6aa3cb5 100644 --- a/.agent/execplans/fpgs-private-api-matrix-free.md +++ b/.agent/execplans/fpgs-private-api-matrix-free.md @@ -128,6 +128,22 @@ Deliverable: - Bench/test/supporting code on the private API branch reflects the 
single-path implementation instead of the old ablation-heavy surface. +Progress update (2026-04-13, pass 3): +- Searched the branch for leftover private FeatherPGS callers and stale + constructor knobs (`pgs_mode`, per-stage kernel selectors). No branch-local + tests, helpers, or docs outside the focused private test and PR draft still + referenced the removed surface. +- Milestone 3 therefore closes as a support-surface verification pass rather + than a broad code-edit pass: the earlier solver cleanup already left the + private branch with one surviving test entry point in + `newton/tests/test_feather_pgs.py`, and no extra helper or benchmark cleanup + was required on this branch. +- Validation: + - `uv run --extra dev -m newton.tests -k test_feather_pgs` -> passed + (`Ran 2 tests in 14.754s`, `OK`). + - `uvx pre-commit run -a` -> passed (`ruff`, `ruff format`, `uv-lock`, + `typos`, and `check warp array syntax`). + Required work: - Update any tests, helper code, or branch-local references that still assume mode or kernel multiplicity. @@ -199,6 +215,12 @@ Checkpoint: kernel commentary so the private branch reflects one coherent articulated + free-rigid winner path. After that push, stop before Milestone 3 and leave supporting-surface cleanup plus final validation recording for later passes. +- Replan update (2026-04-13, pass 3): + A branch-wide search found no remaining private FeatherPGS callers or + support-code references to the removed constructor knobs outside the focused + private test and PR draft already added in prior passes. This pass therefore + closes Milestone 3 by recording that no further branch-local support-surface + edits are needed, then stops before the final validation milestone. - If a milestone is too large to finish cleanly, begin the pass by tightening this ExecPlan with a short replan note and then complete one reviewable slice of that milestone. 
From 002900723b84a890c43768fa1242ce455c69d5b4 Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 13:00:00 -0400 Subject: [PATCH 5/7] Close ExecPlan final validation --- .../execplans/fpgs-private-api-matrix-free.md | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/.agent/execplans/fpgs-private-api-matrix-free.md b/.agent/execplans/fpgs-private-api-matrix-free.md index f8e6aa3cb5..9017c9058d 100644 --- a/.agent/execplans/fpgs-private-api-matrix-free.md +++ b/.agent/execplans/fpgs-private-api-matrix-free.md @@ -10,7 +10,7 @@ the GitHub PR description directly. ## Branch and Workspace - Working branch: `dturpin/fpgs-private-api-matrix-free` -- Tracking branch: `origin/fpgs-private-api` +- Tracking branch: `origin/dturpin/fpgs-private-api-matrix-free` - Workspace root: repository root of this worktree ## Constraints @@ -185,17 +185,34 @@ Checkpoint: Deliverable: - Reviewable branch state with validation evidence and no pending plan items. -Required work: -- Re-run the final focused test set. -- Re-run `uvx pre-commit run -a`. -- Update this ExecPlan to reflect what shipped and any intentional omissions. +Progress update (2026-04-13, pass 4): +- Repointed the local branch to track + `origin/dturpin/fpgs-private-api-matrix-free`, removing the stale + `origin/fpgs-private-api` tracking ambiguity from `git status`. +- Re-ran the final focused solver validation and repository-wide pre-commit + checks, then updated this ExecPlan to reflect the shipped end state. +- Milestone 5 is closed. No milestone items remain on this ExecPlan. Validation: -- Record exact commands run and results. +- `git rev-list --left-right --count upstream/main...HEAD` -> `28 4`. + This confirms the branch is rebased onto `upstream/main` and now sits 4 + commits ahead with no missing upstream-main commits. +- `uv run --extra dev -m newton.tests -k test_feather_pgs` -> passed + (`Ran 2 tests in 14.586s`, `OK`). 
+- `uvx pre-commit run -a` -> passed (`ruff`, `ruff format`, `uv-lock`, + `typos`, and `check warp array syntax`). +- `git status --short --branch` after upstream fix -> + `## dturpin/fpgs-private-api-matrix-free...origin/dturpin/fpgs-private-api-matrix-free` + +Intentional omissions / review notes: +- The actual GitHub PR description was not edited in this workflow; only the + local human-review draft at + `.agent/review/fpgs-private-api-pr-description-draft.md` was maintained. + That non-edit is a process guarantee from this workspace flow rather than a + property that can be independently verified from local files alone. Checkpoint: -- Commit and push the final polishing pass. -- Stop and wait for human review. +- Commit and push the final closeout state, then stop for human review. ## Notes for the Implementor @@ -221,6 +238,11 @@ Checkpoint: private test and PR draft already added in prior passes. This pass therefore closes Milestone 3 by recording that no further branch-local support-surface edits are needed, then stops before the final validation milestone. +- Replan update (2026-04-13, pass 4): + The remaining work is strictly closeout: rerun the final validation, fix the + stale local tracking branch so the pushed review state is obvious from + `git status`, and rewrite Milestone 5 as shipped state with exact results and + the one remaining non-verifiable GitHub PR-description note. - If a milestone is too large to finish cleanly, begin the pass by tightening this ExecPlan with a short replan note and then complete one reviewable slice of that milestone. From 3fe23ff62573342409a6d017fc83339e98885f22 Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 13:08:02 -0400 Subject: [PATCH 6/7] Finish private API MF cleanup Rebase the private FeatherPGS branch onto upstream/main and remove the\nremaining dense-mode presentation from the solver surface. 
Rename the\nlast dense_* constructor and helper names to mode-neutral matrix-free\narticulated-row terms, refresh the focused signature coverage, and\nupdate the ExecPlan and PR draft so their validation evidence matches\nthe rebased branch state. --- .../execplans/fpgs-private-api-matrix-free.md | 41 ++++++++-- .../fpgs-private-api-pr-description-draft.md | 11 ++- newton/_src/solvers/feather_pgs/kernels.py | 16 ++-- .../solvers/feather_pgs/solver_feather_pgs.py | 80 +++++++++---------- newton/_src/solvers/featherstone/kernels.py | 4 +- newton/tests/test_feather_pgs.py | 4 + 6 files changed, 98 insertions(+), 58 deletions(-) diff --git a/.agent/execplans/fpgs-private-api-matrix-free.md b/.agent/execplans/fpgs-private-api-matrix-free.md index 9017c9058d..9ab5827fbe 100644 --- a/.agent/execplans/fpgs-private-api-matrix-free.md +++ b/.agent/execplans/fpgs-private-api-matrix-free.md @@ -106,6 +106,19 @@ Progress update (2026-04-13, pass 2): - Milestone 2 is now functionally complete on the private solver surface; any leftover cleanup is supporting-surface polish tracked in Milestone 3. +Progress update (2026-04-13, pass 5): +- Rebased the branch onto `upstream/main` and resolved the shared + `featherstone/kernels.py` conflict by keeping upstream's newer FK/twist + writeback behavior while preserving the helper this branch needs. +- Removed the last stale dense-path presentation from the private FeatherPGS + surface: `SolverFeatherPGS` now exposes `contact_compliance` and + `max_constraints` instead of `dense_contact_compliance` and + `dense_max_constraints`, and the live stage-4 helpers/kernel entry points now + use matrix-free/articulated-row names rather than `dense_*`. +- Expanded the focused signature test to lock in the absence of the old + `dense_*` constructor knobs alongside the already-removed mode and + kernel-selection knobs. + Required work: - Remove `dense` and `split` support from the private API implementation. 
- Remove retained kernel-selection and intra-mode multi-path knobs whose only @@ -193,14 +206,20 @@ Progress update (2026-04-13, pass 4): checks, then updated this ExecPlan to reflect the shipped end state. - Milestone 5 is closed. No milestone items remain on this ExecPlan. +Progress update (2026-04-13, pass 5): +- Reopened this milestone after judge review found stale dense-path naming and + validation evidence. Rebased onto the current `upstream/main`, applied the + remaining matrix-free-only naming cleanup, and refreshed the review artifacts + to match the rebased branch state. + Validation: -- `git rev-list --left-right --count upstream/main...HEAD` -> `28 4`. - This confirms the branch is rebased onto `upstream/main` and now sits 4 - commits ahead with no missing upstream-main commits. +- `git rev-list --left-right --count upstream/main...HEAD` -> `0 6`. + This confirms the branch is rebased onto the current `upstream/main` and now + sits 6 commits ahead with no missing upstream-main commits. - `uv run --extra dev -m newton.tests -k test_feather_pgs` -> passed - (`Ran 2 tests in 14.586s`, `OK`). -- `uvx pre-commit run -a` -> passed (`ruff`, `ruff format`, `uv-lock`, - `typos`, and `check warp array syntax`). + (`Ran 2 tests in 15.157s`, `OK`). +- `uvx pre-commit run -a` -> passed + (`ruff`, `ruff format`, `uv-lock`, `typos`, and `check warp array syntax`). - `git status --short --branch` after upstream fix -> `## dturpin/fpgs-private-api-matrix-free...origin/dturpin/fpgs-private-api-matrix-free` @@ -243,6 +262,16 @@ Checkpoint: stale local tracking branch so the pushed review state is obvious from `git status`, and rewrite Milestone 5 as shipped state with exact results and the one remaining non-verifiable GitHub PR-description note. 
+- Replan update (2026-04-13, pass 5): + Judge review found the private line still presented dense-path leftovers via + `dense_contact_compliance`, `dense_max_constraints`, and helper names such as + `_stage4_add_dense_contact_compliance()`, even though the live solver no + longer materializes a full Delassus matrix. This pass tightens Milestone 2's + final reviewable slice on top of a fresh rebase onto `upstream/main`: remove + the stale `dense_*` public surface, rename the surviving articulated-row + diagonal/compliance helpers to match the matrix-free implementation, refresh + focused tests to guard that constructor surface, and then rerun validation, + update the PR draft, and rewrite the stale final-validation evidence. - If a milestone is too large to finish cleanly, begin the pass by tightening this ExecPlan with a short replan note and then complete one reviewable slice of that milestone. diff --git a/.agent/review/fpgs-private-api-pr-description-draft.md b/.agent/review/fpgs-private-api-pr-description-draft.md index 9e0fca87ca..17e87b4e37 100644 --- a/.agent/review/fpgs-private-api-pr-description-draft.md +++ b/.agent/review/fpgs-private-api-pr-description-draft.md @@ -5,14 +5,17 @@ Simplify the private FeatherPGS API line to the current matrix-free path. This branch stops presenting the private solver as a bundle of retained -ablations. Instead, it bakes in the current matrix-free contact solve path and +ablations. Instead, it bakes in the current matrix-free contact solve path, removes obsolete top-level mode selection and kernel-selection API knobs from -`SolverFeatherPGS`. +`SolverFeatherPGS`, and drops the stale `dense_*` constructor naming that made +the private line still look like a dense-mode fork. ## What Changed - Remove the private solver constructor knobs for `pgs_mode` and per-stage kernel selection. +- Rename the remaining articulated-row tuning surface from `dense_*` to + mode-neutral `contact_compliance` and `max_constraints`. 
- Run the private FeatherPGS step path as matrix-free only. - Add focused unit coverage for the stripped-down constructor surface and a minimal step smoke test. @@ -54,6 +57,10 @@ Interpretation: - `uv run --extra dev -m newton.tests -k test_feather_pgs` - `uvx pre-commit run -a` +Current branch state after rebasing onto `upstream/main`: + +- `git rev-list --left-right --count upstream/main...HEAD` -> `0 6` + ## Notes For Review - This PR intentionally updates only the private API line. It does not touch diff --git a/newton/_src/solvers/feather_pgs/kernels.py b/newton/_src/solvers/feather_pgs/kernels.py index aab14167e9..6232af33ad 100644 --- a/newton/_src/solvers/feather_pgs/kernels.py +++ b/newton/_src/solvers/feather_pgs/kernels.py @@ -2629,7 +2629,7 @@ def clamp_joint_tau( TILE_DOF = wp.constant(49) # Max constraints per articulation we support in the tiled path. -# dense_max_constraints must be <= TILE_CONSTRAINTS or we use fall back +# max_constraints must be <= TILE_CONSTRAINTS or we use fall back TILE_CONSTRAINTS = wp.constant(128) # Threads per tile/block for tile kernels @@ -2773,7 +2773,7 @@ def prepare_world_impulses( @wp.kernel -def diag_from_JY_par_art( +def extract_diag_from_JY_par_art( J_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] Y_group: wp.array3d[float], # [n_arts_of_size, max_constraints, n_dofs] group_to_art: wp.array[int], @@ -3623,19 +3623,19 @@ def finalize_world_diag_cfm( @wp.kernel -def add_dense_contact_compliance_to_diag( +def add_contact_compliance_to_diag( world_constraint_count: wp.array[int], world_row_type: wp.array2d[int], contact_alpha: float, # in/out world_diag: wp.array2d[float], ): - """Add normal-contact compliance to the dense PGS diagonal. + """Add normal-contact compliance to the articulated-row PGS diagonal. - The dense articulated contact path uses a Delassus diagonal in impulse - space. 
A compliance ``alpha = compliance / dt^2`` contributes an additional - diagonal term for normal contact rows only, yielding a softer normal - response without changing friction or joint-limit rows. + The matrix-free articulated contact path still extracts a Delassus diagonal + in impulse space. A compliance ``alpha = compliance / dt^2`` contributes an + additional diagonal term for normal contact rows only, yielding a softer + normal response without changing friction or joint-limit rows. """ world = wp.tid() m = world_constraint_count[world] diff --git a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py index 306f2e9349..b6f510f3ce 100644 --- a/newton/_src/solvers/feather_pgs/solver_feather_pgs.py +++ b/newton/_src/solvers/feather_pgs/solver_feather_pgs.py @@ -35,7 +35,7 @@ from ..solver import SolverBase from .kernels import ( TILE_THREADS, - add_dense_contact_compliance_to_diag, + add_contact_compliance_to_diag, allocate_joint_limit_slots, allocate_world_contact_slots, apply_augmented_joint_tau, @@ -61,11 +61,11 @@ crba_fill_par_dof, delassus_par_row_col, detect_limit_count_changes, - diag_from_JY_par_art, eval_rigid_fk, eval_rigid_id, eval_rigid_mass, eval_rigid_tau, + extract_diag_from_JY_par_art, finalize_mf_constraint_counts, finalize_world_constraint_counts, finalize_world_diag_cfm, @@ -167,9 +167,9 @@ def __init__( pgs_iterations: int = 12, pgs_beta: float = 0.2, pgs_cfm: float = 1.0e-6, - dense_contact_compliance: float = 0.0, + contact_compliance: float = 0.0, pgs_omega: float = 1.0, - dense_max_constraints: int = 32, + max_constraints: int = 32, pgs_warmstart: bool = False, mf_max_constraints: int = 512, # Parallelism options @@ -190,11 +190,11 @@ def __init__( pgs_iterations (int, optional): Number of Gauss-Seidel iterations to apply per frame. Defaults to 12. pgs_beta (float, optional): ERP style position correction factor. Defaults to 0.2. 
pgs_cfm (float, optional): Compliance/regularization added to the Delassus diagonal. Defaults to 1.0e-6. - dense_contact_compliance (float, optional): Normal contact compliance [m/N] applied - only to dense articulated contact rows. Converted to an impulse-space diagonal - term using ``compliance / dt^2``. Defaults to 0.0. + contact_compliance (float, optional): Normal contact compliance [m/N] applied + to articulated contact rows. Converted to an impulse-space diagonal term using + ``compliance / dt^2``. Defaults to 0.0. pgs_omega (float, optional): Successive over-relaxation factor for the PGS sweep. Defaults to 1.0. - dense_max_constraints (int, optional): Maximum number of articulated contact constraint + max_constraints (int, optional): Maximum number of articulated contact constraint rows stored per world. Free rigid body contacts are stored separately, bounded by mf_max_constraints. Defaults to 32. pgs_warmstart (bool, optional): Re-use impulses from the previous frame when contacts persist. Defaults to False. 
@@ -212,9 +212,9 @@ def __init__( self.pgs_iterations = pgs_iterations self.pgs_beta = pgs_beta self.pgs_cfm = pgs_cfm - self.dense_contact_compliance = dense_contact_compliance + self.contact_compliance = contact_compliance self.pgs_omega = pgs_omega - self.dense_max_constraints = dense_max_constraints + self.max_constraints = max_constraints self.pgs_warmstart = pgs_warmstart self.mf_max_constraints = mf_max_constraints self._double_buffer = double_buffer @@ -636,7 +636,7 @@ def _allocate_buffers(self, model): device = model.device requires_grad = model.requires_grad - max_constraints = self.dense_max_constraints + max_constraints = self.max_constraints self.L_by_size = {} self.Y_by_size = {} @@ -734,7 +734,7 @@ def _allocate_world_buffers(self, model): device = model.device requires_grad = model.requires_grad - max_constraints = self.dense_max_constraints + max_constraints = self.max_constraints self.C = None self._compute_world_dof_mapping(model) @@ -854,7 +854,7 @@ def _allocate_debug_buffers(self, model): return device = model.device worlds = self.world_count - max_c = self.dense_max_constraints + max_c = self.max_constraints mf_max_c = self.mf_max_constraints self._diag_metrics = wp.zeros((worlds, 4), dtype=wp.float32, device=device) @@ -1027,12 +1027,12 @@ def step( else: self._stage4_hinv_jt_par_row(size) - # Diagonal from J*Y (no full Delassus) + # Extract only the world diagonal from J*Y; do not assemble the full Delassus matrix. 
self.diag.zero_() for size in self.size_groups: - self._stage4_diag_from_JY(size) + self._stage4_extract_diag_from_JY(size) self._stage4_finalize_world_diag_cfm() - self._stage4_add_dense_contact_compliance(dt) + self._stage4_add_contact_compliance(dt) # RHS = bias only (J*v recomputed per iteration) self._stage4_compute_rhs_world(dt) @@ -1069,7 +1069,7 @@ def step( n_arts = self.n_arts_by_size[size] wp.launch( gather_JY_to_world, - dim=int(n_arts * self.dense_max_constraints * size), + dim=int(n_arts * self.max_constraints * size), inputs=[ self.group_to_art[size], self.art_to_world, @@ -1079,7 +1079,7 @@ def step( self.J_by_size[size], self.Y_by_size[size], size, - self.dense_max_constraints, + self.max_constraints, n_arts, ], outputs=[self.J_world, self.Y_world], @@ -1108,7 +1108,7 @@ def step( with wp.ScopedTimer("S6_PGS_Solve", print=False, use_nvtx=self._nvtx, synchronize=self._nvtx): mf_gs_kernel = TiledKernelFactory.get_pgs_solve_mf_gs_kernel( - self.dense_max_constraints, + self.max_constraints, self.mf_max_constraints, self.max_world_dofs, self.model.device, @@ -1164,7 +1164,7 @@ def step( self.row_parent, self.row_mu, self.J_world, - self.dense_max_constraints, + self.max_constraints, self.max_world_dofs, self.mf_constraint_count, self.mf_rhs, @@ -1849,7 +1849,7 @@ def _stage3_compute_v_hat(self, state_in: State, state_aug: State, dt: float): def _stage4_build_rows(self, state_in: State, state_aug: State, contacts: Contacts): model = self.model - max_constraints = self.dense_max_constraints + max_constraints = self.max_constraints mf_active = self._has_free_rigid_bodies # Zero world-level buffers (only arrays that require it) @@ -2167,7 +2167,7 @@ def _stage4_zero_world_C(self): def _stage4_hinv_jt_tiled(self, size: int): model = self.model n_arts = self.n_arts_by_size[size] - hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_kernel(size, self.dense_max_constraints, model.device) + hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_kernel(size, 
self.max_constraints, model.device) wp.launch_tiled( hinv_jt_kernel, dim=[n_arts], @@ -2186,7 +2186,7 @@ def _stage4_hinv_jt_tiled(self, size: int): def _stage4_hinv_jt_tiled_fused(self, size: int): model = self.model n_arts = self.n_arts_by_size[size] - hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_fused_kernel(size, self.dense_max_constraints, model.device) + hinv_jt_kernel = TiledKernelFactory.get_hinv_jt_fused_kernel(size, self.max_constraints, model.device) wp.launch_tiled( hinv_jt_kernel, dim=[n_arts], @@ -2208,7 +2208,7 @@ def _stage4_hinv_jt_par_row(self, size: int): n_arts = self.n_arts_by_size[size] wp.launch( hinv_jt_par_row, - dim=n_arts * self.dense_max_constraints, + dim=n_arts * self.max_constraints, inputs=[ self.L_by_size[size], self.J_by_size[size], @@ -2216,7 +2216,7 @@ def _stage4_hinv_jt_par_row(self, size: int): self.art_to_world, self.constraint_count, size, - self.dense_max_constraints, + self.max_constraints, n_arts, ], outputs=[self.Y_by_size[size]], @@ -2228,7 +2228,7 @@ def _stage4_delassus_par_row_col(self, size: int): n_arts = self.n_arts_by_size[size] wp.launch( delassus_par_row_col, - dim=n_arts * self.dense_max_constraints * self.dense_max_constraints, + dim=n_arts * self.max_constraints * self.max_constraints, inputs=[ self.J_by_size[size], self.Y_by_size[size], @@ -2236,7 +2236,7 @@ def _stage4_delassus_par_row_col(self, size: int): self.art_to_world, self.constraint_count, size, - self.dense_max_constraints, + self.max_constraints, n_arts, ], outputs=[self.C, self.diag], @@ -2247,7 +2247,7 @@ def _stage4_delassus_tiled(self, size: int): model = self.model n_arts = self.n_arts_by_size[size] delassus_kernel = TiledKernelFactory.get_delassus_kernel( - size, self.dense_max_constraints, model.device, chunk_size=self.delassus_chunk_size + size, self.max_constraints, model.device, chunk_size=self.delassus_chunk_size ) wp.launch_tiled( delassus_kernel, @@ -2275,24 +2275,24 @@ def _stage4_finalize_world_diag_cfm(self): 
device=model.device, ) - def _stage4_add_dense_contact_compliance(self, dt: float): - if self.dense_contact_compliance <= 0.0: + def _stage4_add_contact_compliance(self, dt: float): + if self.contact_compliance <= 0.0: return - contact_alpha = float(self.dense_contact_compliance / (dt * dt)) + contact_alpha = float(self.contact_compliance / (dt * dt)) wp.launch( - add_dense_contact_compliance_to_diag, + add_contact_compliance_to_diag, dim=self.world_count, inputs=[self.constraint_count, self.row_type, contact_alpha], outputs=[self.diag], device=self.model.device, ) - def _stage4_diag_from_JY(self, size: int): + def _stage4_extract_diag_from_JY(self, size: int): n_arts = self.n_arts_by_size[size] wp.launch( - diag_from_JY_par_art, - dim=n_arts * self.dense_max_constraints, + extract_diag_from_JY_par_art, + dim=n_arts * self.max_constraints, inputs=[ self.J_by_size[size], self.Y_by_size[size], @@ -2300,7 +2300,7 @@ def _stage4_diag_from_JY(self, size: int): self.art_to_world, self.constraint_count, size, - self.dense_max_constraints, + self.max_constraints, n_arts, ], outputs=[self.diag], @@ -2314,7 +2314,7 @@ def _stage4_compute_rhs_world(self, dt: float): dim=self.world_count, inputs=[ self.constraint_count, - self.dense_max_constraints, + self.max_constraints, self.phi, self.row_beta, self.row_type, @@ -2333,7 +2333,7 @@ def _stage4_accumulate_rhs_world(self, size: int): dim=n_arts, inputs=[ self.constraint_count, - self.dense_max_constraints, + self.max_constraints, self.art_to_world, self.art_size, self.art_group_idx, @@ -2352,7 +2352,7 @@ def _stage5_prepare_impulses_world(self): wp.launch( prepare_world_impulses, dim=self.world_count, - inputs=[self.constraint_count, self.dense_max_constraints, warmstart_flag], + inputs=[self.constraint_count, self.max_constraints, warmstart_flag], outputs=[self.impulses], device=self.model.device, ) @@ -2373,7 +2373,7 @@ def _stage6_apply_impulses_world(self, size: int): size, n_arts, self.constraint_count, - 
self.dense_max_constraints, + self.max_constraints, self.Y_by_size[size], self.impulses, self.v_hat, diff --git a/newton/_src/solvers/featherstone/kernels.py b/newton/_src/solvers/featherstone/kernels.py index 6132a00c45..8d4abff6fe 100644 --- a/newton/_src/solvers/featherstone/kernels.py +++ b/newton/_src/solvers/featherstone/kernels.py @@ -8,10 +8,8 @@ from ...math import transform_twist from ...sim import BodyFlags, JointType, Model, State from ...sim.articulation import ( - com_twist_to_point_velocity, compute_2d_rotational_dofs, compute_3d_rotational_dofs, - origin_twist_to_com_twist, ) from ..semi_implicit.kernels_body import joint_force @@ -2078,6 +2076,8 @@ def eval_fk_with_velocity_conversion( ], device=model.device, ) + + def eval_fk_with_velocity_conversion_from_joint_starts( model: Model, articulation_indices: wp.array[int], diff --git a/newton/tests/test_feather_pgs.py b/newton/tests/test_feather_pgs.py index fcb6df11a8..8cca83464f 100644 --- a/newton/tests/test_feather_pgs.py +++ b/newton/tests/test_feather_pgs.py @@ -20,11 +20,15 @@ def test_constructor_signature_is_matrix_free_only(self): params = inspect.signature(SolverFeatherPGS).parameters self.assertNotIn("pgs_mode", params) + self.assertNotIn("dense_contact_compliance", params) + self.assertNotIn("dense_max_constraints", params) self.assertNotIn("cholesky_kernel", params) self.assertNotIn("trisolve_kernel", params) self.assertNotIn("hinv_jt_kernel", params) self.assertNotIn("delassus_kernel", params) self.assertNotIn("pgs_kernel", params) + self.assertIn("contact_compliance", params) + self.assertIn("max_constraints", params) def test_step_smoke_runs_matrix_free_solver(self): model = self._make_model() From 59c97f2959338703cfe4a6b2922c5e501d9101de Mon Sep 17 00:00:00 2001 From: Dylan Turpin Date: Mon, 13 Apr 2026 15:28:16 -0400 Subject: [PATCH 7/7] Drop checked-in PR draft artifact --- .../fpgs-private-api-pr-description-draft.md | 70 ------------------- 1 file changed, 70 deletions(-) delete 
mode 100644 .agent/review/fpgs-private-api-pr-description-draft.md diff --git a/.agent/review/fpgs-private-api-pr-description-draft.md b/.agent/review/fpgs-private-api-pr-description-draft.md deleted file mode 100644 index 17e87b4e37..0000000000 --- a/.agent/review/fpgs-private-api-pr-description-draft.md +++ /dev/null @@ -1,70 +0,0 @@ -# Draft PR Description - -## Summary - -Simplify the private FeatherPGS API line to the current matrix-free path. - -This branch stops presenting the private solver as a bundle of retained -ablations. Instead, it bakes in the current matrix-free contact solve path, -removes obsolete top-level mode selection and kernel-selection API knobs from -`SolverFeatherPGS`, and drops the stale `dense_*` constructor naming that made -the private line still look like a dense-mode fork. - -## What Changed - -- Remove the private solver constructor knobs for `pgs_mode` and per-stage - kernel selection. -- Rename the remaining articulated-row tuning surface from `dense_*` to - mode-neutral `contact_compliance` and `max_constraints`. -- Run the private FeatherPGS step path as matrix-free only. -- Add focused unit coverage for the stripped-down constructor surface and a - minimal step smoke test. - -## Why Matrix-Free Only - -The published nightly ablations already show that the matrix-free path is the -winner for the private line, while the dense/split modes mostly preserve -research history rather than current product intent. 
- -Using the published nightly run `2026-04-01T20-49-30Z` (`summary.json`, -commit `53b3188`) from the `gh-pages` artifacts: - -| Scenario | Hardware | Baseline | Split | Matrix-free | Matrix-free vs split | -| --- | --- | ---: | ---: | ---: | ---: | -| `h1_tabletop_ablation` | RTX 5090 | 4,262 env_fps | 41,496 env_fps | 118,547 env_fps | 2.86x | -| `h1_tabletop_ablation` | RTX PRO 6000 Server | 4,942 env_fps | 47,528 env_fps | 114,531 env_fps | 2.41x | -| `h1_tabletop_ablation` | B200 | 3,765 env_fps | 60,684 env_fps | 107,262 env_fps | 1.77x | - -The same nightly run also supports keeping the current winner kernel choices on -the private line instead of exposing them as API: - -| Scenario | Hardware | FeatherPGS baseline | tiled `hinv_jt` | tiled PGS | parallel streams | -| --- | --- | ---: | ---: | ---: | ---: | -| `g1_flat_ablation` | RTX 5090 | 590,152 env_fps | 1,358,725 env_fps | 1,461,194 env_fps | 1,461,373 env_fps | -| `g1_flat_ablation` | RTX PRO 6000 Server | 504,100 env_fps | 1,182,672 env_fps | 1,276,901 env_fps | 1,277,128 env_fps | -| `g1_flat_ablation` | B200 | 760,212 env_fps | 1,389,444 env_fps | 1,551,859 env_fps | 1,550,952 env_fps | - -Interpretation: - -- `matrix-free` materially outperforms `split` on the published tabletop - ablation across all listed GPUs. -- `tiled hinv_jt`, tiled PGS, and parallel streams are already the winning - direction in the published flat-scene ablation, so the private API does not - need to keep exposing these as branch-local tuning knobs. - -## Validation - -- `uv run --extra dev -m newton.tests -k test_feather_pgs` -- `uvx pre-commit run -a` - -Current branch state after rebasing onto `upstream/main`: - -- `git rev-list --left-right --count upstream/main...HEAD` -> `0 6` - -## Notes For Review - -- This PR intentionally updates only the private API line. It does not touch - `gh-pages` content and does not move benchmark rationale into public docs. 
-- The benchmark figures above are copied from published nightly artifacts - already present on `origin/gh-pages`, not from new benchmark runs in this - branch.