-
Notifications
You must be signed in to change notification settings - Fork 455
Add built-in simulation reset: Model.reset_state() + viewer R key #2468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
b4a0509
d7480db
14f3b01
d6c54f0
50afe72
896d803
7e9ec0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -849,6 +849,39 @@ def state(self, requires_grad: bool | None = None) -> State: | |
|
|
||
| return s | ||
|
|
||
| def reset_state(self, state: State, eval_fk: bool = True) -> None: | ||
| """ | ||
| Reset a :class:`State` to this model's initial configuration in-place. | ||
|
|
||
| Copies the model's initial position and velocity arrays into ``state`` | ||
| and zeroes all force arrays. Unlike :meth:`state`, this reuses the | ||
| existing GPU allocations -- no new arrays are created. | ||
|
|
||
| Args: | ||
| state: The state object to reset (must have been created by this model). | ||
| eval_fk: If True and the model has joints, re-evaluate forward | ||
| kinematics so that :attr:`State.body_q` and :attr:`State.body_qd` | ||
| are consistent with the restored joint coordinates. | ||
| """ | ||
| if self.particle_count: | ||
| wp.copy(state.particle_q, self.particle_q) | ||
| wp.copy(state.particle_qd, self.particle_qd) | ||
| state.particle_f.zero_() | ||
|
|
||
| if self.body_count: | ||
| wp.copy(state.body_q, self.body_q) | ||
| wp.copy(state.body_qd, self.body_qd) | ||
| state.body_f.zero_() | ||
|
|
||
| if self.joint_count: | ||
| wp.copy(state.joint_q, self.joint_q) | ||
| wp.copy(state.joint_qd, self.joint_qd) | ||
|
|
||
| if eval_fk and self.joint_count: | ||
| from .articulation import eval_fk as _eval_fk # noqa: PLC0415 | ||
|
|
||
| _eval_fk(self, self.joint_q, self.joint_qd, state) | ||
|
|
||
|
Comment on lines
+852
to
+893
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reset registered custom
🤖 Prompt for AI Agents |
||
| def control(self, requires_grad: bool | None = None, clone_variables: bool = True) -> Control: | ||
| """ | ||
| Create and return a new :class:`Control` object for this model. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,7 +187,6 @@ class _ExampleBrowser: | |
| def __init__(self, viewer): | ||
| self.viewer = viewer | ||
| self.switch_target: str | None = None | ||
| self._reset_requested = False | ||
| self.callback = None | ||
| self._tree: dict[str, list[tuple[str, str]]] = {} | ||
|
|
||
|
|
@@ -214,7 +213,7 @@ def _browser_ui(imgui): | |
| imgui.tree_pop() | ||
| imgui.separator() | ||
| if imgui.button("Reset"): | ||
| self._reset_requested = True | ||
| self.viewer._reset_requested = True | ||
|
|
||
| self.callback = _browser_ui | ||
| viewer.register_ui_callback(_browser_ui, position="panel") | ||
|
|
@@ -240,7 +239,6 @@ def switch(self, example_class): | |
|
|
||
| def reset(self, example_class): | ||
| """Reset the current example by re-creating it. Returns the new example or None.""" | ||
| self._reset_requested = False | ||
| self.viewer.clear_model() | ||
| try: | ||
| parser = getattr(example_class, "create_parser", create_parser)() | ||
|
|
@@ -279,8 +277,17 @@ def run(example, args): | |
| example, example_class = browser.switch(example_class) | ||
| continue | ||
|
|
||
| if browser is not None and browser._reset_requested: | ||
| example = browser.reset(example_class) | ||
| if viewer.is_reset_requested(): | ||
| viewer.clear_reset_request() | ||
| if hasattr(example, "reset"): | ||
| example.reset() | ||
| elif hasattr(example, "model"): | ||
| for attr in ("state_0", "state_1"): | ||
| s = getattr(example, attr, None) | ||
| if s is not None: | ||
| example.model.reset_state(s) | ||
| if hasattr(example, "sim_time"): | ||
| example.sim_time = 0.0 | ||
|
Comment on lines
+282
to
+290
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reset The generic reset only rewinds 💡 Minimal fix elif hasattr(example, "model"):
for attr in ("state_0", "state_1"):
s = getattr(example, attr, None)
if s is not None:
example.model.reset_state(s)
+ if getattr(example, "control", None) is not None:
+ example.control = example.model.control()
if hasattr(example, "sim_time"):
example.sim_time = 0.0🤖 Prompt for AI Agents |
||
| continue | ||
|
|
||
| if example is None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 The Newton Developers | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Tests for Model.reset_state().""" | ||
|
|
||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import warp as wp | ||
|
|
||
| import newton | ||
|
|
||
|
|
||
| class TestResetState(unittest.TestCase): | ||
| """Tests that Model.reset_state() restores state arrays in-place.""" | ||
|
|
||
| def _build_body_model(self): | ||
| """Build a model with one free body and a sphere shape.""" | ||
| builder = newton.ModelBuilder() | ||
| body = builder.add_body(mass=1.0) | ||
| builder.add_shape_sphere(body, radius=0.1) | ||
| return builder.finalize() | ||
|
|
||
| def _build_particle_model(self): | ||
| """Build a model with 2 particles.""" | ||
| builder = newton.ModelBuilder() | ||
| builder.add_particle(pos=(1.0, 2.0, 3.0), vel=(0.1, 0.2, 0.3), mass=1.0) | ||
| builder.add_particle(pos=(4.0, 5.0, 6.0), vel=(0.4, 0.5, 0.6), mass=1.0) | ||
| return builder.finalize() | ||
|
|
||
| def _build_articulation_model(self): | ||
| """Build a model with a revolute joint articulation.""" | ||
| builder = newton.ModelBuilder() | ||
| link0 = builder.add_link(mass=1.0) | ||
| builder.add_shape_sphere(link0, radius=0.1) | ||
| link1 = builder.add_link(mass=1.0) | ||
| builder.add_shape_sphere(link1, radius=0.1) | ||
| j0 = builder.add_joint_revolute(parent=-1, child=link0) | ||
| j1 = builder.add_joint_revolute(parent=link0, child=link1) | ||
| builder.add_articulation([j0, j1]) | ||
| return builder.finalize() | ||
|
|
||
| def test_reset_restores_body_state(self): | ||
| model = self._build_body_model() | ||
| state = model.state() | ||
|
|
||
| # Save initial values | ||
| initial_body_q = state.body_q.numpy().copy() | ||
| initial_body_qd = state.body_qd.numpy().copy() | ||
|
|
||
| # Mutate body arrays with 999.0 | ||
| junk_q = wp.array(np.full_like(initial_body_q, 999.0), dtype=state.body_q.dtype) | ||
| junk_qd = wp.array(np.full_like(initial_body_qd, 999.0), dtype=state.body_qd.dtype) | ||
| wp.copy(state.body_q, junk_q) | ||
| wp.copy(state.body_qd, junk_qd) | ||
|
|
||
| # Verify mutation took effect | ||
| np.testing.assert_array_equal(state.body_q.numpy(), junk_q.numpy()) | ||
|
|
||
| # Reset | ||
| model.reset_state(state) | ||
|
|
||
| # Verify body_q and body_qd restored | ||
| np.testing.assert_array_equal(state.body_q.numpy(), initial_body_q) | ||
| np.testing.assert_array_equal(state.body_qd.numpy(), initial_body_qd) | ||
|
|
||
| # Verify body_f is zeroed | ||
| np.testing.assert_array_equal(state.body_f.numpy(), np.zeros_like(state.body_f.numpy())) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| def test_reset_restores_particle_state(self): | ||
| model = self._build_particle_model() | ||
| state = model.state() | ||
|
|
||
| # Save initial values | ||
| initial_particle_q = state.particle_q.numpy().copy() | ||
| initial_particle_qd = state.particle_qd.numpy().copy() | ||
|
|
||
| # Mutate particle_q | ||
| junk = wp.array(np.full_like(initial_particle_q, 999.0), dtype=state.particle_q.dtype) | ||
| wp.copy(state.particle_q, junk) | ||
|
|
||
| # Reset | ||
| model.reset_state(state) | ||
|
|
||
| # Verify particle arrays restored | ||
| np.testing.assert_array_equal(state.particle_q.numpy(), initial_particle_q) | ||
| np.testing.assert_array_equal(state.particle_qd.numpy(), initial_particle_qd) | ||
|
|
||
| # Verify particle_f is zeroed | ||
| np.testing.assert_array_equal(state.particle_f.numpy(), np.zeros_like(state.particle_f.numpy())) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| def test_reset_restores_joint_state(self): | ||
| model = self._build_articulation_model() | ||
| state = model.state() | ||
|
|
||
| # Save initial joint values | ||
| initial_joint_q = state.joint_q.numpy().copy() | ||
| initial_joint_qd = state.joint_qd.numpy().copy() | ||
|
|
||
| # Mutate joint_q | ||
| junk = wp.array(np.full_like(initial_joint_q, 999.0), dtype=state.joint_q.dtype) | ||
| wp.copy(state.joint_q, junk) | ||
|
|
||
| # Reset | ||
| model.reset_state(state) | ||
|
|
||
| # Verify joint arrays restored | ||
| np.testing.assert_array_equal(state.joint_q.numpy(), initial_joint_q) | ||
| np.testing.assert_array_equal(state.joint_qd.numpy(), initial_joint_qd) | ||
|
|
||
| def test_reset_with_eval_fk(self): | ||
| model = self._build_articulation_model() | ||
| state = model.state() | ||
|
|
||
| # Run FK to get expected body_q | ||
| newton.eval_fk(model, state.joint_q, state.joint_qd, state) | ||
| expected_body_q = state.body_q.numpy().copy() | ||
|
|
||
| # Mutate body_q | ||
| junk = wp.array(np.full_like(expected_body_q, 999.0), dtype=state.body_q.dtype) | ||
| wp.copy(state.body_q, junk) | ||
|
|
||
| # Reset with eval_fk=True (the default) | ||
| model.reset_state(state, eval_fk=True) | ||
|
|
||
| # Verify body_q matches FK-computed values | ||
| np.testing.assert_allclose(state.body_q.numpy(), expected_body_q, atol=1e-5) | ||
|
|
||
| def test_reset_without_eval_fk(self): | ||
| model = self._build_articulation_model() | ||
| state = model.state() | ||
|
|
||
| # Get the raw model body_q (not FK-computed) | ||
| raw_body_q = model.body_q.numpy().copy() | ||
|
|
||
| # Mutate body_q | ||
| junk = wp.array(np.full_like(raw_body_q, 999.0), dtype=state.body_q.dtype) | ||
| wp.copy(state.body_q, junk) | ||
|
|
||
| # Reset with eval_fk=False | ||
| model.reset_state(state, eval_fk=False) | ||
|
|
||
| # Verify body_q matches raw model values, not FK-computed | ||
| np.testing.assert_array_equal(state.body_q.numpy(), raw_body_q) | ||
|
|
||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| def test_reset_does_not_reallocate(self): | ||
| model = self._build_body_model() | ||
| state = model.state() | ||
|
|
||
| # Record pointer | ||
| ptr_before = state.body_q.ptr | ||
|
|
||
| # Reset | ||
| model.reset_state(state) | ||
|
|
||
| # Verify pointer unchanged (no reallocation) | ||
| self.assertEqual(state.body_q.ptr, ptr_before) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main(verbosity=2) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: newton-physics/newton
Length of output: 504
🏁 Script executed:
Repository: newton-physics/newton
Length of output: 2281
🏁 Script executed:
Repository: newton-physics/newton
Length of output: 1430
Run
docs/generate_api.pyto include the newreset_stateexport in the API documentation.The symbol was added to the public API in
newton/__init__.pybut is missing from the generated API reference file (docs/api/newton.rst). Running the script will ensure the documentation stays in sync with the code.🤖 Prompt for AI Agents