diff --git a/docs/api/newton.rst b/docs/api/newton.rst index 88d2c194e..2df167a4a 100644 --- a/docs/api/newton.rst +++ b/docs/api/newton.rst @@ -69,6 +69,7 @@ newton eval_ik eval_jacobian eval_mass_matrix + reset_state .. rubric:: Constants diff --git a/newton/__init__.py b/newton/__init__.py index 6b1517e1b..56d21828c 100644 --- a/newton/__init__.py +++ b/newton/__init__.py @@ -61,6 +61,7 @@ eval_ik, eval_jacobian, eval_mass_matrix, + reset_state, ) __all__ += [ @@ -78,6 +79,7 @@ "eval_ik", "eval_jacobian", "eval_mass_matrix", + "reset_state", ] # ================================================================================== diff --git a/newton/_src/sim/__init__.py b/newton/_src/sim/__init__.py index bd68e0a42..60539cbe2 100644 --- a/newton/_src/sim/__init__.py +++ b/newton/_src/sim/__init__.py @@ -30,4 +30,19 @@ "eval_ik", "eval_jacobian", "eval_mass_matrix", + "reset_state", ] + + +def reset_state(model: Model, state: State, eval_fk: bool = True) -> None: + """Reset a state to the model's initial configuration. + + Convenience wrapper for :meth:`Model.reset_state`. See that method for + full documentation. + + Args: + model: The model whose initial configuration to restore. + state: The state object to reset. + eval_fk: Whether to re-evaluate forward kinematics. + """ + model.reset_state(state, eval_fk=eval_fk) diff --git a/newton/_src/sim/model.py b/newton/_src/sim/model.py index a7539c66f..488259828 100644 --- a/newton/_src/sim/model.py +++ b/newton/_src/sim/model.py @@ -849,6 +849,48 @@ def state(self, requires_grad: bool | None = None) -> State: return s + def reset_state(self, state: State, eval_fk: bool = True) -> None: + """ + Reset a :class:`State` to this model's initial configuration in-place. + + Copies the model's initial position and velocity arrays into ``state`` + and zeroes all force arrays. Unlike :meth:`state`, this reuses the + existing GPU allocations -- no new arrays are created. + + Args: + state: The state object to reset (must have been created by this model). + eval_fk: If True and the model has joints, re-evaluate forward + kinematics so that :attr:`State.body_q` and :attr:`State.body_qd` + are consistent with the restored joint coordinates. + """ + if self.particle_count: + wp.copy(state.particle_q, self.particle_q) + wp.copy(state.particle_qd, self.particle_qd) + state.particle_f.zero_() + + if self.body_count: + wp.copy(state.body_q, self.body_q) + wp.copy(state.body_qd, self.body_qd) + state.body_f.zero_() + if getattr(state, "body_q_prev", None) is not None: + wp.copy(state.body_q_prev, self.body_q) + if getattr(state, "body_qdd", None) is not None: + state.body_qdd.zero_() + if getattr(state, "body_parent_f", None) is not None: + state.body_parent_f.zero_() + + if self.joint_count: + wp.copy(state.joint_q, self.joint_q) + wp.copy(state.joint_qd, self.joint_qd) + mujoco_ns = getattr(state, "mujoco", None) + if mujoco_ns is not None and getattr(mujoco_ns, "qfrc_actuator", None) is not None: + mujoco_ns.qfrc_actuator.zero_() + + if eval_fk and self.joint_count: + from .articulation import eval_fk as _eval_fk # noqa: PLC0415 + + _eval_fk(self, self.joint_q, self.joint_qd, state) + def control(self, requires_grad: bool | None = None, clone_variables: bool = True) -> Control: """ Create and return a new :class:`Control` object for this model. diff --git a/newton/_src/viewer/viewer.py b/newton/_src/viewer/viewer.py index d92ed044f..a01f45876 100644 --- a/newton/_src/viewer/viewer.py +++ b/newton/_src/viewer/viewer.py @@ -59,6 +59,32 @@ def is_paused(self) -> bool: """ return False + def is_reset_requested(self) -> bool: + """Report whether a simulation reset has been requested. + + The flag is set by viewer UI controls (e.g. the *R* key or + *Reset* button in :class:`ViewerGL`). Callers should check this + once per frame before stepping the simulation and call + :meth:`clear_reset_request` after handling the reset. + + Returns: + bool: True when a reset has been requested. + """ + return self._reset_requested + + def clear_reset_request(self) -> None: + """Clear the reset-requested flag after the reset has been handled.""" + self._reset_requested = False + + def request_reset(self) -> None: + """Request a simulation reset. + + Sets the internal flag that :meth:`is_reset_requested` queries. + The next frame's run-loop iteration (or user code) should detect + the flag, perform the reset, and call :meth:`clear_reset_request`. + """ + self._reset_requested = True + def is_key_down(self, key: str | int) -> bool: """Default key query API. Concrete viewers can override. @@ -111,6 +137,9 @@ def clear_model(self) -> None: # Picking self.picking_enabled = True + # Reset signal + self._reset_requested = False + # Display options self.show_joints = False self.show_com = False diff --git a/newton/_src/viewer/viewer_gl.py b/newton/_src/viewer/viewer_gl.py index b1b5339e1..b3f704cc0 100644 --- a/newton/_src/viewer/viewer_gl.py +++ b/newton/_src/viewer/viewer_gl.py @@ -1637,6 +1637,9 @@ def on_key_press(self, symbol: int, modifiers: int): elif symbol == pyglet.window.key.F: # Frame camera around model bounds self._frame_camera_on_model() + elif symbol == pyglet.window.key.R: + # Request simulation reset + self.request_reset() elif symbol == pyglet.window.key.ESCAPE: # Exit with Escape key self.renderer.close() @@ -1967,6 +1970,8 @@ def _render_left_panel(self): # Pause simulation checkbox changed, self._paused = imgui.checkbox("Pause", self._paused) + if imgui.button("Reset"): + self.request_reset() # Visualization Controls section imgui.set_next_item_open(True, imgui.Cond_.appearing) diff --git a/newton/examples/__init__.py b/newton/examples/__init__.py index a188a16b0..9e472fb35 100644 --- a/newton/examples/__init__.py +++ b/newton/examples/__init__.py @@ -187,7 +187,6 @@ class _ExampleBrowser: def __init__(self, viewer): self.viewer = viewer self.switch_target: str | None = None - self._reset_requested = False self.callback = None self._tree: dict[str, list[tuple[str, str]]] = {} @@ -214,7 +213,7 @@ def _browser_ui(imgui): imgui.tree_pop() imgui.separator() if imgui.button("Reset"): - self._reset_requested = True + self.viewer.request_reset() self.callback = _browser_ui viewer.register_ui_callback(_browser_ui, position="panel") @@ -240,7 +239,6 @@ def switch(self, example_class): def reset(self, example_class): """Reset the current example by re-creating it. Returns the new example or None.""" - self._reset_requested = False self.viewer.clear_model() try: parser = getattr(example_class, "create_parser", create_parser)() @@ -279,8 +277,17 @@ def run(example, args): example, example_class = browser.switch(example_class) continue - if browser is not None and browser._reset_requested: - example = browser.reset(example_class) + if viewer.is_reset_requested(): + viewer.clear_reset_request() + if hasattr(example, "reset"): + example.reset() + elif hasattr(example, "model"): + for attr in ("state_0", "state_1"): + s = getattr(example, attr, None) + if s is not None: + example.model.reset_state(s) + if hasattr(example, "sim_time"): + example.sim_time = 0.0 continue if example is None: diff --git a/newton/tests/test_reset_state.py b/newton/tests/test_reset_state.py new file mode 100644 index 000000000..0b84bbea0 --- /dev/null +++ b/newton/tests/test_reset_state.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Model.reset_state().""" + +import unittest + +import numpy as np +import warp as wp + +import newton + + +class TestResetState(unittest.TestCase): + """Tests that Model.reset_state() restores state arrays in-place.""" + + def _build_body_model(self): + """Build a model with one free body and a sphere shape.""" + builder = newton.ModelBuilder() + body = builder.add_body(mass=1.0) + builder.add_shape_sphere(body, radius=0.1) + return builder.finalize() + + def _build_particle_model(self): + """Build a model with 2 particles.""" + builder = newton.ModelBuilder() + builder.add_particle(pos=(1.0, 2.0, 3.0), vel=(0.1, 0.2, 0.3), mass=1.0) + builder.add_particle(pos=(4.0, 5.0, 6.0), vel=(0.4, 0.5, 0.6), mass=1.0) + return builder.finalize() + + def _build_articulation_model(self): + """Build a model with a revolute joint articulation. + + Sets non-zero joint offsets so that FK-computed body transforms + differ from the raw model.body_q defaults. + """ + builder = newton.ModelBuilder() + link0 = builder.add_link(mass=1.0) + builder.add_shape_sphere(link0, radius=0.1) + link1 = builder.add_link(mass=1.0) + builder.add_shape_sphere(link1, radius=0.1) + j0 = builder.add_joint_revolute( + parent=-1, + child=link0, + parent_xform=wp.transform((0.0, 1.0, 0.0), wp.quat_identity()), + ) + j1 = builder.add_joint_revolute( + parent=link0, + child=link1, + parent_xform=wp.transform((0.0, 1.0, 0.0), wp.quat_identity()), + ) + builder.add_articulation([j0, j1]) + model = builder.finalize() + + # Set raw model.body_q to zeros so it differs from FK-computed values + model.body_q.zero_() + model.body_qd.zero_() + + return model + + def test_reset_restores_body_state(self): + model = self._build_body_model() + state = model.state() + + # Save initial values + initial_body_q = state.body_q.numpy().copy() + initial_body_qd = state.body_qd.numpy().copy() + + # Mutate body arrays with 999.0 + junk_q = wp.array(np.full_like(initial_body_q, 999.0), dtype=state.body_q.dtype) + junk_qd = wp.array(np.full_like(initial_body_qd, 999.0), dtype=state.body_qd.dtype) + wp.copy(state.body_q, junk_q) + wp.copy(state.body_qd, junk_qd) + + # Seed body_f with non-zero to verify zeroing is meaningful + junk_f = wp.array(np.full_like(state.body_f.numpy(), 42.0), dtype=state.body_f.dtype) + wp.copy(state.body_f, junk_f) + self.assertTrue(np.any(state.body_f.numpy() != 0.0)) + + # Verify mutation took effect + np.testing.assert_array_equal(state.body_q.numpy(), junk_q.numpy()) + + # Reset + model.reset_state(state) + + # Verify body_q and body_qd restored + np.testing.assert_array_equal(state.body_q.numpy(), initial_body_q) + np.testing.assert_array_equal(state.body_qd.numpy(), initial_body_qd) + + # Verify body_f is zeroed + np.testing.assert_array_equal(state.body_f.numpy(), np.zeros_like(state.body_f.numpy())) + + def test_reset_restores_particle_state(self): + model = self._build_particle_model() + state = model.state() + + # Save initial values + initial_particle_q = state.particle_q.numpy().copy() + initial_particle_qd = state.particle_qd.numpy().copy() + + # Mutate both position and velocity + junk_q = wp.array(np.full_like(initial_particle_q, 999.0), dtype=state.particle_q.dtype) + junk_qd = wp.array(np.full_like(initial_particle_qd, 999.0), dtype=state.particle_qd.dtype) + wp.copy(state.particle_q, junk_q) + wp.copy(state.particle_qd, junk_qd) + + # Seed particle_f with non-zero to verify zeroing is meaningful + junk_f = wp.array(np.full_like(state.particle_f.numpy(), 42.0), dtype=state.particle_f.dtype) + wp.copy(state.particle_f, junk_f) + self.assertTrue(np.any(state.particle_f.numpy() != 0.0)) + + # Reset + model.reset_state(state) + + # Verify particle arrays restored + np.testing.assert_array_equal(state.particle_q.numpy(), initial_particle_q) + np.testing.assert_array_equal(state.particle_qd.numpy(), initial_particle_qd) + + # Verify particle_f is zeroed + np.testing.assert_array_equal(state.particle_f.numpy(), np.zeros_like(state.particle_f.numpy())) + + def test_reset_restores_joint_state(self): + model = self._build_articulation_model() + state = model.state() + + # Save initial joint values + initial_joint_q = state.joint_q.numpy().copy() + initial_joint_qd = state.joint_qd.numpy().copy() + + # Mutate both joint_q and joint_qd + junk_q = wp.array(np.full_like(initial_joint_q, 999.0), dtype=state.joint_q.dtype) + junk_qd = wp.array(np.full_like(initial_joint_qd, 999.0), dtype=state.joint_qd.dtype) + wp.copy(state.joint_q, junk_q) + wp.copy(state.joint_qd, junk_qd) + + # Reset + model.reset_state(state) + + # Verify joint arrays restored + np.testing.assert_array_equal(state.joint_q.numpy(), initial_joint_q) + np.testing.assert_array_equal(state.joint_qd.numpy(), initial_joint_qd) + + def test_reset_with_eval_fk(self): + model = self._build_articulation_model() + state = model.state() + + # Compute FK to get expected body transforms (differs from raw model.body_q) + newton.eval_fk(model, state.joint_q, state.joint_qd, state) + expected_body_q = state.body_q.numpy().copy() + expected_body_qd = state.body_qd.numpy().copy() + + # Verify FK result differs from raw model values + self.assertFalse(np.array_equal(expected_body_q, model.body_q.numpy())) + + # Mutate body_q + junk = wp.array(np.full_like(expected_body_q, 999.0), dtype=state.body_q.dtype) + wp.copy(state.body_q, junk) + + # Reset with eval_fk=True (the default) + model.reset_state(state, eval_fk=True) + + # Verify body_q matches FK-computed values, not raw model values + np.testing.assert_allclose(state.body_q.numpy(), expected_body_q, atol=1e-5) + np.testing.assert_allclose(state.body_qd.numpy(), expected_body_qd, atol=1e-5) + + def test_reset_without_eval_fk(self): + model = self._build_articulation_model() + state = model.state() + + # Compute FK first so we know what the FK result would be + newton.eval_fk(model, state.joint_q, state.joint_qd, state) + fk_body_q = state.body_q.numpy().copy() + + # Raw model values should differ from FK + raw_body_q = model.body_q.numpy().copy() + raw_body_qd = model.body_qd.numpy().copy() + self.assertFalse(np.array_equal(raw_body_q, fk_body_q)) + + # Mutate body_q + junk = wp.array(np.full_like(raw_body_q, 999.0), dtype=state.body_q.dtype) + wp.copy(state.body_q, junk) + + # Reset with eval_fk=False + model.reset_state(state, eval_fk=False) + + # Verify body_q matches raw model values, NOT FK-computed + np.testing.assert_array_equal(state.body_q.numpy(), raw_body_q) + np.testing.assert_array_equal(state.body_qd.numpy(), raw_body_qd) + + def test_reset_does_not_reallocate(self): + model = self._build_body_model() + state = model.state() + + # Record pointer + ptr_before = state.body_q.ptr + + # Reset + model.reset_state(state) + + # Verify pointer unchanged (no reallocation) + self.assertEqual(state.body_q.ptr, ptr_before) + + def test_reset_zeroes_extended_body_buffers(self): + """Extended state attributes (body_qdd, body_parent_f) are zeroed.""" + builder = newton.ModelBuilder() + builder.add_body(mass=1.0) + builder.add_shape_sphere(body=0, radius=0.1) + builder.request_state_attributes("body_qdd", "body_parent_f") + model = builder.finalize() + state = model.state() + + # Seed extended buffers with non-zero + junk = wp.array(np.full_like(state.body_qdd.numpy(), 42.0), dtype=state.body_qdd.dtype) + wp.copy(state.body_qdd, junk) + wp.copy(state.body_parent_f, junk) + self.assertTrue(np.any(state.body_qdd.numpy() != 0.0)) + + model.reset_state(state) + + np.testing.assert_array_equal(state.body_qdd.numpy(), np.zeros_like(state.body_qdd.numpy())) + np.testing.assert_array_equal(state.body_parent_f.numpy(), np.zeros_like(state.body_parent_f.numpy())) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/newton/tests/test_viewer_reset.py b/newton/tests/test_viewer_reset.py new file mode 100644 index 000000000..95a943cb1 --- /dev/null +++ b/newton/tests/test_viewer_reset.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from newton._src.viewer.viewer import ViewerBase + + +class _StubViewer(ViewerBase): + """Minimal concrete subclass of ViewerBase for testing.""" + + def end_frame(self): + pass + + def log_mesh(self, name, vertices, indices, colors=None, smooth_shading=True): + pass + + def log_instances(self, name, mesh_name, positions, rotations, colors=None, scalings=None): + pass + + def log_lines(self, name, vertices_start, vertices_end, colors=None, radius=0.001): + pass + + def log_points(self, name, positions, colors=None, radii=None, radius=0.01): + pass + + def log_array(self, name, array): + pass + + def log_scalar(self, name, value): + pass + + def apply_forces(self, state): + pass + + def close(self): + pass + + +class TestViewerResetSignal(unittest.TestCase): + """Tests for the ViewerBase reset signal API.""" + + def test_initial_state_not_requested(self): + """Fresh viewer has no reset requested.""" + viewer = _StubViewer() + self.assertFalse(viewer.is_reset_requested()) + + def test_request_reset(self): + """request_reset() sets the flag visible via is_reset_requested.""" + viewer = _StubViewer() + viewer.request_reset() + self.assertTrue(viewer.is_reset_requested()) + + def test_clear_reset_request(self): + """clear_reset_request resets the flag to False.""" + viewer = _StubViewer() + viewer.request_reset() + viewer.clear_reset_request() + self.assertFalse(viewer.is_reset_requested()) + + def test_clear_model_resets_flag(self): + """clear_model() should clear the reset flag.""" + viewer = _StubViewer() + viewer.request_reset() + viewer.clear_model() + self.assertFalse(viewer.is_reset_requested()) + + +if __name__ == "__main__": + unittest.main()