From ee3fe4f164b9de354bc79341ac803504746cccdb Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 17:43:15 +0100 Subject: [PATCH 01/70] Refactor _validate_plan_forward to use option model directly Delegate option execution to option_model.get_next_state_and_num_actions instead of duplicating its termination logic (stuck detection, Wait atom-change checks) and directly accessing its simulator. --- .../approaches/agent_bilevel_approach.py | 79 ++++--------------- 1 file changed, 15 insertions(+), 64 deletions(-) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 7dbb7819c..17aaa8967 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -574,87 +574,38 @@ def _validate_plan_forward( Returns True if the plan reaches the goal, False otherwise. """ state = task.init - option_names = cast( # pylint: disable=protected-access - Any, self._option_model)._name_to_parameterized_option predicates = self._get_all_predicates() total_actions = 0 for i, grounded in enumerate(plan): - # Create a fresh option copy (same as the option model does). - env_param_opt = option_names.get(grounded.parent.name, - grounded.parent) - option_copy = env_param_opt.ground(grounded.objects, - grounded.params.copy()) - # Propagate Wait target atoms through re-grounding - for key in ("wait_target_atoms", "wait_target_neg_atoms"): - if key in grounded.memory: - option_copy.memory[key] = grounded.memory[key] - - if not option_copy.initiable(state): + if not grounded.initiable(state): logging.info(f"Forward validation: step {i} " - f"({option_copy.name}) not initiable.") - return False - - # Build a terminal condition that mirrors the option model: - # 1. The option's own terminal - # 2. terminate_on_repeat (stuck detection) - # 3. wait_option_terminate_on_atom_change - last_state_ref: List[Optional[State]] = [None] - abstract_fn = lambda s, _p=predicates: utils.abstract(s, _p) - - def _terminal( # pylint: disable=cell-var-from-loop - s: State, - oc: _Option = option_copy, - _abs: Callable = abstract_fn) -> bool: - if oc.terminal(s): - return True - prev = last_state_ref[0] - if prev is not None: - if (CFG.option_model_terminate_on_repeat - and prev.allclose(s)): - raise utils.OptionExecutionFailure( - f"Option '{oc.name}' got stuck.") - if (CFG.wait_option_terminate_on_atom_change - and oc.name == "Wait"): - result = utils.check_wait_target_atoms(oc, s, _abs) - if result is True: - last_state_ref[0] = s - return True - if result is None: - cur_atoms = _abs(s) - prev_atoms = _abs(prev) - if cur_atoms != prev_atoms: - last_state_ref[0] = s - return True - last_state_ref[0] = s + f"({grounded.name}) not initiable.") return False try: - sim = cast( # pylint: disable=protected-access - Any, self._option_model)._simulator - traj = utils.run_policy_with_simulator( - option_copy.policy, - sim, - state, - _terminal, - max_num_steps=CFG.max_num_steps_option_rollout) - except (utils.OptionExecutionFailure, - utils.EnvironmentFailure) as e: + next_state, num_actions = \ + self._option_model.get_next_state_and_num_actions( + state, grounded) + except utils.EnvironmentFailure as e: logging.info(f"Forward validation: step {i} " - f"({option_copy.name}) failed: {e}") + f"({grounded.name}) failed: {e}") return False - if len(traj.actions) == 0: + if num_actions == 0: + reason = cast(Any, self._option_model) \ + .last_execution_failure or \ + "produced 0 actions" logging.info(f"Forward validation: step {i} " - f"({option_copy.name}) produced 0 actions.") + f"({grounded.name}) failed: {reason}") return False - total_actions += len(traj.actions) - state = traj.states[-1] + total_actions += num_actions + state = next_state atoms = utils.abstract(state, predicates) logging.debug( f"Forward validation: step {i} " - f"({option_copy.name}) OK, {len(traj.actions)} actions. " + f"({grounded.name}) OK, {num_actions} actions. " f"Atoms: {sorted(str(a) for a in atoms)}") if not task.goal_holds(state): From 58b86cd9deeafb874f1750f93c4e06d26a15a37a Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 18:46:04 +0100 Subject: [PATCH 02/70] Unify backtracking refinement search into shared run_backtracking_refinement Extract the duplicated backtracking loop from run_low_level_search (SeSamE) and _refine_sketch (agent bilevel) into a single run_backtracking_refinement function in planning.py. Both callers now delegate to it with their own sample_fn and validate_fn callbacks, eliminating ~80 lines of duplicated loop/backtracking logic. --- .../approaches/agent_bilevel_approach.py | 166 +++----- predicators/planning.py | 368 ++++++++++-------- 2 files changed, 256 insertions(+), 278 deletions(-) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 17aaa8967..368d83c10 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -23,6 +23,7 @@ from predicators import utils from predicators.approaches import ApproachFailure from predicators.approaches.agent_planner_approach import AgentPlannerApproach +from predicators.planning import run_backtracking_refinement from predicators.settings import CFG from predicators.structs import Action, GroundAtom, Object, \ ParameterizedOption, Predicate, State, Task, _Option @@ -409,6 +410,8 @@ def _refine_sketch( Returns ``(plan, success)``. On success, ``plan`` is a list of grounded options that achieves the task goal. On failure, ``plan`` is the longest partial refinement found. + + Delegates to ``run_backtracking_refinement`` for the core loop. """ if not sketch: return [], False @@ -416,131 +419,65 @@ def _refine_sketch( rng = np.random.default_rng(CFG.seed) max_samples = CFG.agent_bilevel_max_samples_per_step check_subgoals = CFG.agent_bilevel_check_subgoals - start_time = time.perf_counter() - n = len(sketch) - cur_idx = 0 - num_tries = [0] * n max_tries = [ max_samples if step.option.params_space.shape[0] > 0 else 1 for step in sketch ] - plan: List[Optional[_Option]] = [None] * n - traj: List[Optional[State]] = [task.init] + [None] * n - - total_samples = 0 - - while cur_idx < n: - elapsed = time.perf_counter() - start_time - if elapsed > timeout: - logging.info( - f"Sketch refinement timed out after {elapsed:.1f}s " - f"at step {cur_idx}/{n}, {total_samples} total samples.") - return [p for p in plan if p is not None], False - - step = sketch[cur_idx] - num_tries[cur_idx] += 1 - total_samples += 1 - step_name = (f"{step.option.name}" - f"({', '.join(o.name for o in step.objects)})") - - # Optionally log state before sampling - cur_state = traj[cur_idx] - assert cur_state is not None, f"traj[{cur_idx}] should not be None" + predicates = self._get_all_predicates() + def sample_fn(idx: int, state: State, + rng_: np.random.Generator) -> _Option: + step = sketch[idx] if CFG.agent_bilevel_log_state: + step_name = (f"{step.option.name}" + f"({', '.join(o.name for o in step.objects)})") logging.debug(f" State before {step_name}:\n" - f"{cur_state.pretty_str()}") - - # Sample continuous parameters and ground option - params = self._sample_params(step.option, cur_state, rng) + f"{state.pretty_str()}") + params = self._sample_params(step.option, state, rng_) grounded = step.option.ground(step.objects, params) - # Inject Wait target atoms from sketch annotations if grounded.name == "Wait": if step.subgoal_atoms is not None: - grounded.memory["wait_target_atoms"] = step.subgoal_atoms + grounded.memory["wait_target_atoms"] = \ + step.subgoal_atoms if step.subgoal_neg_atoms is not None: grounded.memory["wait_target_neg_atoms"] = \ step.subgoal_neg_atoms - plan[cur_idx] = grounded - - state = cur_state - can_continue = False - fail_reason = "not initiable" - - if grounded.initiable(state): - try: - next_state, num_actions = \ - self._option_model.get_next_state_and_num_actions( - state, grounded) - except utils.EnvironmentFailure as e: - fail_reason = f"env failure: {e}" - else: - if num_actions == 0: - model = self._option_model - fail_reason = ( - getattr( # type: ignore[attr-defined] - model, "last_execution_failure", None) - or "0 actions") - else: - traj[cur_idx + 1] = next_state - # Check subgoals if specified - if (check_subgoals and step.subgoal_atoms is not None): - current_atoms = utils.abstract( - next_state, self._get_all_predicates()) - if step.subgoal_atoms.issubset(current_atoms): - can_continue = True - else: - missing = step.subgoal_atoms - current_atoms - fail_reason = ( - f"subgoal missing: " - f"{{{', '.join(str(a) for a in missing)}}}" - ) - else: - can_continue = True - # Final step: also check task goal - if can_continue and cur_idx == n - 1: - if not task.goal_holds(next_state): - can_continue = False - fail_reason = "goal not reached" - - if can_continue: - logging.info( - f" Step {cur_idx}/{n} {step_name} OK " - f"(sample {num_tries[cur_idx]}/{max_tries[cur_idx]})\n") - if CFG.agent_bilevel_log_state: - next_st = traj[cur_idx + 1] - assert next_st is not None - logging.debug(f" State after {step_name}:\n" - f"{next_st.pretty_str()}") - cur_idx += 1 - else: - logging.debug( - f" Step {cur_idx}/{n} {step_name} FAIL " - f"(sample {num_tries[cur_idx]}/{max_tries[cur_idx]})" - f": {fail_reason}") - # Backtrack: re-try current step or go back further - while num_tries[cur_idx] >= max_tries[cur_idx]: - bt_objs = ", ".join(o.name - for o in sketch[cur_idx].objects) - bt_name = (f"{sketch[cur_idx].option.name}" - f"({bt_objs})") - logging.info(f" Step {cur_idx}/{n} {bt_name} exhausted " - f"{max_tries[cur_idx]} samples, backtracking") - num_tries[cur_idx] = 0 - plan[cur_idx] = None - traj[cur_idx + 1] = None - cur_idx -= 1 - if cur_idx < 0: - logging.info(f"Sketch refinement exhausted after " - f"{total_samples} total samples.") - return [], False - - # All steps succeeded - assert all(p is not None for p in plan) - logging.info(f"Refinement complete: {total_samples} total samples " - f"for {n} steps.") - return cast(List[_Option], plan), True + return grounded + + def validate_fn(idx: int, _pre_state: State, _option: _Option, + post_state: State, + _num_actions: int) -> Tuple[bool, str]: + step = sketch[idx] + if check_subgoals and step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + missing = step.subgoal_atoms - current_atoms + return False, (f"subgoal missing: " + f"{{{', '.join(str(a) for a in missing)}}}") + if idx == n - 1: + if not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + plan, success, total_samples = run_backtracking_refinement( + init_state=task.init, + option_model=self._option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + ) + + logging.info(f"Refinement {'succeeded' if success else 'failed'}: " + f"{total_samples} samples for {n} steps.") + + filtered = [p for p in plan if p is not None] + if success: + return cast(List[_Option], filtered), True + return filtered, False def _sample_params(self, option: ParameterizedOption, _state: State, rng: np.random.Generator) -> np.ndarray: @@ -603,10 +540,9 @@ def _validate_plan_forward( total_actions += num_actions state = next_state atoms = utils.abstract(state, predicates) - logging.debug( - f"Forward validation: step {i} " - f"({grounded.name}) OK, {num_actions} actions. " - f"Atoms: {sorted(str(a) for a in atoms)}") + logging.debug(f"Forward validation: step {i} " + f"({grounded.name}) OK, {num_actions} actions. " + f"Atoms: {sorted(str(a) for a in atoms)}") if not task.goal_holds(state): atoms = utils.abstract(state, predicates) diff --git a/predicators/planning.py b/predicators/planning.py index 76e9a3906..162e69443 100644 --- a/predicators/planning.py +++ b/predicators/planning.py @@ -16,8 +16,8 @@ from collections import defaultdict from dataclasses import dataclass from itertools import islice -from typing import Any, Collection, Dict, FrozenSet, Iterator, List, \ - Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Collection, Dict, FrozenSet, Iterator, \ + List, Optional, Sequence, Set, Tuple, Union, cast import numpy as np @@ -26,7 +26,7 @@ from predicators.refinement_estimators import BaseRefinementEstimator from predicators.settings import CFG from predicators.structs import NSRT, AbstractPolicy, CausalProcess, \ - DefaultState, DummyOption, GroundAtom, Metrics, Object, OptionSpec, \ + DefaultState, GroundAtom, Metrics, Object, OptionSpec, \ ParameterizedOption, Predicate, State, STRIPSOperator, Task, Type, \ _GroundCausalProcess, _GroundNSRT, _GroundSTRIPSOperator, _Option from predicators.utils import EnvironmentFailure, _TaskPlanningHeuristic @@ -506,6 +506,107 @@ def _skeleton_generator( raise _SkeletonSearchTimeout +def run_backtracking_refinement( + init_state: State, + option_model: _OptionModelBase, + n_steps: int, + max_tries: List[int], + sample_fn: Callable[[int, State, np.random.Generator], _Option], + validate_fn: Callable[[int, State, _Option, State, int], Tuple[bool, str]], + rng: np.random.Generator, + timeout: float, + on_env_failure: Optional[Callable[[int, _Option, EnvironmentFailure], + None]] = None, + on_step_fail: Optional[Callable[[int, List[Optional[_Option]], str], + None]] = None, + on_exhausted: Optional[Callable[[List[Optional[_Option]]], None]] = None, + step_times: Optional[List[float]] = None, +) -> Tuple[List[Optional[_Option]], bool, int]: + """Backtracking search over continuous parameters. + + Core loop shared by SeSamE low-level search and agent bilevel + refinement. Samples options via ``sample_fn``, executes them through + ``option_model``, and validates transitions via ``validate_fn``. + Backtracks when a step exhausts its sampling budget. + + Returns ``(plan, success, total_samples)`` where plan entries are + ``None`` for unrefined steps. + + Callbacks ``on_env_failure``, ``on_step_fail``, and ``on_exhausted`` + may raise to abort the search (e.g. for failure propagation). + """ + start_time = time.perf_counter() + cur_idx = 0 + num_tries_arr = [0] * n_steps + plan: List[Optional[_Option]] = [None] * n_steps + traj: List[Optional[State]] = [init_state] + [None] * n_steps + total_samples = 0 + + while cur_idx < n_steps: + if time.perf_counter() - start_time > timeout: + logging.debug( + "Backtracking refinement timed out at step " + "%d/%d.", cur_idx, n_steps) + return plan, False, total_samples + + attempt_start = time.perf_counter() + num_tries_arr[cur_idx] += 1 + total_samples += 1 + state = traj[cur_idx] + assert state is not None + + option = sample_fn(cur_idx, state, rng) + plan[cur_idx] = option + + can_continue = False + fail_reason = "not initiable" + + if option.initiable(state): + try: + next_state, num_actions = \ + option_model.get_next_state_and_num_actions( + state, option) + except EnvironmentFailure as e: + fail_reason = f"env failure: {e}" + if on_env_failure is not None: + on_env_failure(cur_idx, option, e) + else: + if num_actions == 0: + fail_reason = (getattr(option_model, + 'last_execution_failure', None) + or "0 actions") + else: + traj[cur_idx + 1] = next_state + can_continue, fail_reason = validate_fn( + cur_idx, state, option, next_state, num_actions) + + if step_times is not None: + step_times[cur_idx] += time.perf_counter() - attempt_start + + if can_continue: + cur_idx += 1 + else: + logging.debug(" Step %d/%d FAIL (attempt %d/%d): %s", cur_idx, + n_steps, num_tries_arr[cur_idx], max_tries[cur_idx], + fail_reason) + if on_step_fail is not None: + on_step_fail(cur_idx, plan, fail_reason) + while num_tries_arr[cur_idx] >= max_tries[cur_idx]: + logging.debug( + " Step %d/%d exhausted %d samples, " + "backtracking", cur_idx, n_steps, max_tries[cur_idx]) + num_tries_arr[cur_idx] = 0 + plan[cur_idx] = None + traj[cur_idx + 1] = None + cur_idx -= 1 + if cur_idx < 0: + if on_exhausted is not None: + on_exhausted(plan) + return plan, False, total_samples + + return plan, True, total_samples + + def run_low_level_search( task: Task, option_model: _OptionModelBase, @@ -525,182 +626,123 @@ def run_low_level_search( failed refinement, where the last step did not satisfy the skeleton, but all previous steps did. Note that there are multiple low-level plans in general; we return the first one found (arbitrarily). + + Delegates to ``run_backtracking_refinement`` for the core loop. """ - start_time = time.perf_counter() - rng_sampler = np.random.default_rng(seed) + if not skeleton: + return [], True + assert CFG.sesame_propagate_failures in \ {"after_exhaust", "immediately", "never"} - cur_idx = 0 - num_tries = [0 for _ in skeleton] - # Optimization: if the params_space for the NSRT option is empty, only - # sample it once, because all samples are just empty (so equivalent). + + rng = np.random.default_rng(seed) + n = len(skeleton) max_tries = [ CFG.sesame_max_samples_per_step if nsrt.option.params_space.shape[0] > 0 else 1 for nsrt in skeleton ] - plan: List[_Option] = [DummyOption for _ in skeleton] - # If refinement_time list is passed, record the refinement time - # distributed across each step of the skeleton + + # Per-step timing if refinement_time is not None: assert len(refinement_time) == 0 for _ in skeleton: refinement_time.append(0) - # The number of actions taken by each option in the plan. This is to - # make sure that we do not exceed the task horizon. - num_actions_per_option = [0 for _ in plan] - traj: List[State] = [task.init] + [DefaultState for _ in skeleton] + + # State captured by closures + discovered_failures: List[Optional[_DiscoveredFailure]] = [None] * n longest_failed_refinement: List[_Option] = [] - # We'll use a maximum of one discovered failure per step, since - # resampling can render old discovered failures obsolete. - discovered_failures: List[Optional[_DiscoveredFailure]] = [ - None for _ in skeleton - ] - plan_found = False - while cur_idx < len(skeleton): - if time.perf_counter() - start_time > timeout: - logging.debug("Exiting low-level search due to timeout.") - return longest_failed_refinement, False - assert num_tries[cur_idx] < max_tries[cur_idx] - try_start_time = time.perf_counter() - # Good debug point #2: if you have a skeleton that you think is - # reasonable, but sampling isn't working, print num_tries here to - # see at what step the backtracking search is getting stuck. - num_tries[cur_idx] += 1 - state = traj[cur_idx] - nsrt = skeleton[cur_idx] - # Ground the NSRT's ParameterizedOption into an _Option. - # This invokes the NSRT's sampler. - option = nsrt.sample_option(state, task.goal, rng_sampler) - plan[cur_idx] = option - # Increment num_samples metric by 1 + num_actions_per_option = [0] * n + + # -- callbacks -------------------------------------------------------- + + def sample_fn(idx: int, state: State, + rng_: np.random.Generator) -> _Option: + discovered_failures[idx] = None metrics["num_samples"] += 1 - # Increment cur_idx. It will be decremented later on if we get stuck. - cur_idx += 1 - if option.initiable(state): - try: - logging.info(f"Running option {option}") - next_state, num_actions = \ - option_model.get_next_state_and_num_actions(state, option) - except EnvironmentFailure as e: - logging.debug(f"Discovered a failure: {e}") - can_continue_on = False - # Remember only the most recent failure. - discovered_failures[cur_idx - 1] = _DiscoveredFailure(e, nsrt) - else: # an EnvironmentFailure was not raised - discovered_failures[cur_idx - 1] = None - num_actions_per_option[cur_idx - 1] = num_actions - traj[cur_idx] = next_state - # Check if objects that were outside the scope had a change - # in state. - static_obj_changed = False - if CFG.sesame_check_static_object_changes: - static_objs = set(state) - set(nsrt.objects) - for obj in sorted(static_objs): - if not np.allclose( - traj[cur_idx][obj], - traj[cur_idx - 1][obj], - atol=CFG.sesame_static_object_change_tol): - static_obj_changed = True - break - if static_obj_changed: - logging.debug("Cannot continue: static object changed.") - can_continue_on = False - # Check if we have exceeded the horizon in total. - elif np.sum(num_actions_per_option[:cur_idx]) > max_horizon: - logging.debug("Cannot continue: exceeded total horizon.") - can_continue_on = False - # Check if we have exceeded the horizon individually. - elif num_actions >= CFG.max_num_steps_option_rollout: - logging.debug("Cannot continue: exceeded individual " - "horizon.") - can_continue_on = False - # Check if the option was effectively a noop. - elif num_actions == 0: - logging.debug("Cannot continue: an noop") - can_continue_on = False - elif CFG.sesame_check_expected_atoms: - # Check atoms against expected atoms_sequence constraint. - assert len(traj) == len(atoms_sequence) - # The expected atoms are ones that we definitely expect to - # be true at this point in the plan. They are not *all* the - # atoms that could be true. - expected_atoms = { - atom - for atom in atoms_sequence[cur_idx] - if atom.predicate.name != _NOT_CAUSES_FAILURE - } - # This "if all" statement is equivalent to, but faster - # than, checking whether expected_atoms is a subset of - # utils.abstract(traj[cur_idx], predicates). - if all(a.holds(traj[cur_idx]) for a in expected_atoms): - can_continue_on = True - if cur_idx == len(skeleton): - plan_found = True - else: - logging.debug("Cannot continue: expected atoms not " - "hold.") - can_continue_on = False - else: - # If we're not checking expected_atoms, we need to - # explicitly check the goal on the final timestep. - can_continue_on = True - if cur_idx == len(skeleton): - if task.goal_holds(traj[cur_idx]): - plan_found = True - else: - can_continue_on = False - else: - # The option is not initiable. - logging.debug("Cannot continue: option not initiable.") - can_continue_on = False - if refinement_time is not None: - try_end_time = time.perf_counter() - refinement_time[cur_idx - 1] += try_end_time - try_start_time - if plan_found: - return plan, True # success! - if not can_continue_on: # we got stuck, time to resample / backtrack! - # Update the longest_failed_refinement found so far. - if cur_idx > len(longest_failed_refinement): - longest_failed_refinement = list(plan[:cur_idx]) - # If we're immediately propagating failures, and we got a failure, - # raise it now. We don't do this right after catching the - # EnvironmentFailure because we want to make sure to update - # the longest_failed_refinement first. - possible_failure = discovered_failures[cur_idx - 1] - if possible_failure is not None and \ + option = skeleton[idx].sample_option(state, task.goal, rng_) + logging.info(f"Running option {option}") + return option + + def validate_fn(idx: int, pre_state: State, _option: _Option, + post_state: State, num_actions: int) -> Tuple[bool, str]: + num_actions_per_option[idx] = num_actions + nsrt = skeleton[idx] + # Static object change check. + if CFG.sesame_check_static_object_changes: + static_objs = set(pre_state) - set(nsrt.objects) + for obj in sorted(static_objs): + if not np.allclose(post_state[obj], + pre_state[obj], + atol=CFG.sesame_static_object_change_tol): + return False, "static object changed" + # Horizon checks. + total_actions = sum(num_actions_per_option[:idx]) + num_actions + if total_actions > max_horizon: + return False, "exceeded total horizon" + if num_actions >= CFG.max_num_steps_option_rollout: + return False, "exceeded individual horizon" + # Expected-atoms check. + if CFG.sesame_check_expected_atoms: + expected_atoms = { + atom + for atom in atoms_sequence[idx + 1] + if atom.predicate.name != _NOT_CAUSES_FAILURE + } + if all(a.holds(post_state) for a in expected_atoms): + return True, "" + return False, "expected atoms not hold" + # No atoms check — verify goal on final step. + if idx == n - 1: + if not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + def on_env_failure(idx: int, _option: _Option, + e: EnvironmentFailure) -> None: + logging.debug(f"Discovered a failure: {e}") + discovered_failures[idx] = _DiscoveredFailure(e, skeleton[idx]) + + def on_step_fail(idx: int, plan: List[Optional[_Option]], + _reason: str) -> None: + nonlocal longest_failed_refinement + partial = [p for p in plan[:idx + 1] if p is not None] + if len(partial) > len(longest_failed_refinement): + longest_failed_refinement = list(partial) + pf = discovered_failures[idx] + if pf is not None and \ CFG.sesame_propagate_failures == "immediately": + raise _DiscoveredFailureException( + "Discovered a failure", pf, + {"longest_failed_refinement": longest_failed_refinement}) + + def on_exhausted(_plan: List[Optional[_Option]]) -> None: + for pf in discovered_failures: + if pf is not None and \ + CFG.sesame_propagate_failures == "after_exhaust": raise _DiscoveredFailureException( - "Discovered a failure", possible_failure, + "Discovered a failure", pf, {"longest_failed_refinement": longest_failed_refinement}) - # Decrement cur_idx to re-do the step we just did. If num_tries - # is exhausted, backtrack. - cur_idx -= 1 - assert cur_idx >= 0 - while num_tries[cur_idx] == max_tries[cur_idx]: - num_tries[cur_idx] = 0 - plan[cur_idx] = DummyOption - num_actions_per_option[cur_idx] = 0 - traj[cur_idx + 1] = DefaultState - cur_idx -= 1 - if cur_idx < 0: - # Backtracking exhausted. If we're only propagating failures - # after exhaustion, and if there are any failures, - # propagate up the EARLIEST one so that high-level search - # restarts. Otherwise, return a partial refinement so that - # high-level search continues. - for possible_failure in discovered_failures: - if possible_failure is not None and \ - CFG.sesame_propagate_failures == "after_exhaust": - raise _DiscoveredFailureException( - "Discovered a failure", possible_failure, { - "longest_failed_refinement": - longest_failed_refinement - }) - return longest_failed_refinement, False - logging.debug("Option succeed!") - # Should only get here if the skeleton was empty. - assert not skeleton - return [], True + + # -- run -------------------------------------------------------------- + + plan, success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + on_env_failure=on_env_failure, + on_step_fail=on_step_fail, + on_exhausted=on_exhausted, + step_times=refinement_time, + ) + + if success: + return [cast(_Option, p) for p in plan], True + return longest_failed_refinement, False def _update_nsrts_with_failure( From e7eaf058f695500e62eb7b5842f4957b791e8f7d Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 18:53:33 +0100 Subject: [PATCH 03/70] Simplify _validate_plan_forward to use run_backtracking_refinement Replace 60 lines of manual option-model execution with a call to run_backtracking_refinement using max_tries=[1] and a sample_fn that returns the pre-grounded options. Remove unused Any import. --- .../approaches/agent_bilevel_approach.py | 77 +++++++------------ 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 368d83c10..98b8c1df8 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -16,7 +16,7 @@ import logging import re import time -from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, cast +from typing import Callable, List, Optional, Sequence, Set, Tuple, cast import numpy as np @@ -501,62 +501,37 @@ def _validate_plan_forward( task: Task, plan: List[_Option], ) -> bool: - """Re-execute the plan continuously in the option model's env. + """Re-execute the plan continuously in the option model. - Unlike refinement (which resets state between steps via - ``_reset_state``), this runs all options sequentially so that the - physics state carries forward naturally — matching how the main - env will execute during the real episode. + Runs all options sequentially so that state carries forward + naturally — matching how the real env will execute. Returns True if the plan reaches the goal, False otherwise. """ - state = task.init - predicates = self._get_all_predicates() - total_actions = 0 + n = len(plan) + if n == 0: + return task.goal_holds(task.init) - for i, grounded in enumerate(plan): - if not grounded.initiable(state): - logging.info(f"Forward validation: step {i} " - f"({grounded.name}) not initiable.") - return False + def sample_fn(i: int, _s: State, _r: np.random.Generator) -> _Option: + return plan[i] - try: - next_state, num_actions = \ - self._option_model.get_next_state_and_num_actions( - state, grounded) - except utils.EnvironmentFailure as e: - logging.info(f"Forward validation: step {i} " - f"({grounded.name}) failed: {e}") - return False - - if num_actions == 0: - reason = cast(Any, self._option_model) \ - .last_execution_failure or \ - "produced 0 actions" - logging.info(f"Forward validation: step {i} " - f"({grounded.name}) failed: {reason}") - return False - - total_actions += num_actions - state = next_state - atoms = utils.abstract(state, predicates) - logging.debug(f"Forward validation: step {i} " - f"({grounded.name}) OK, {num_actions} actions. " - f"Atoms: {sorted(str(a) for a in atoms)}") - - if not task.goal_holds(state): - atoms = utils.abstract(state, predicates) - goal_atoms = task.goal - missing = goal_atoms - atoms - logging.info( - f"Forward validation: goal not reached. " - f"Missing: {{{', '.join(str(a) for a in sorted(missing))}}}. " - f"State:\n{state.pretty_str()}") - return False - - logging.info(f"Forward validation succeeded: {total_actions} " - f"actions from {len(plan)} steps.") - return True + def validate_fn(i: int, _s: State, _o: _Option, post: State, + _n: int) -> Tuple[bool, str]: + if i == n - 1 and not task.goal_holds(post): + return False, "goal not reached" + return True, "" + + _, success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=self._option_model, + n_steps=n, + max_tries=[1] * n, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=np.random.default_rng(0), + timeout=float('inf'), + ) + return success # ------------------------------------------------------------------ # # Helpers From 43acbf2dd86c68a154df4276c953779ec4b1ea5f Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 2 Apr 2026 11:17:20 +0100 Subject: [PATCH 04/70] Refactor _current_observation/_current_state usage in pybullet_env Move the _current_observation assignment into _reset_state so callers don't need to remember the two-step pattern. Clarify the relationship between _current_observation (backing field) and _current_state (typed read accessor) in docstrings and comments. --- predicators/envs/base_env.py | 10 +++++++++- predicators/envs/pybullet_env.py | 32 +++++++++++++++----------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/predicators/envs/base_env.py b/predicators/envs/base_env.py index f62ce3e30..a88eae29e 100644 --- a/predicators/envs/base_env.py +++ b/predicators/envs/base_env.py @@ -198,7 +198,15 @@ def get_test_tasks(self) -> List[EnvironmentTask]: @property def _current_state(self) -> State: - """Default for environments where states are observations.""" + """Typed accessor for _current_observation when it is a State. + + _current_observation is the raw Observation (which may not be a + State in vision-based envs). _current_state provides a + convenience accessor with a type assertion for the common case + where observations are States. Use _current_observation for + assignment (it is the backing field); use _current_state for + reads when you need a State. + """ assert isinstance(self._current_observation, State) return self._current_observation diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index b07e31b39..e228b1712 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -340,13 +340,12 @@ def action_space(self) -> Box: return self._pybullet_robot.action_space def simulate(self, state: State, action: Action) -> State: - # Optimization: check if we're already in the right state. - # self._current_observation is None at the beginning - # state is not allclose to self._current_state when the state has been - # updated, so it first calls _reset_state to update the pybullet state + # Optimization: skip _reset_state if pybullet is already in this state. + # _current_observation is None before the first reset() call. + # Check it (not _current_state) because _current_state would fail + # its type assertion on None. if self._current_observation is None or \ not state.allclose(self._current_state): - self._current_observation = state self._reset_state(state) return self.step(action) @@ -381,6 +380,9 @@ def _reset_state(self, state: State) -> None: Used in initialization (reset(), _add_pybullet_state_to_tasks()) and bilevel planning (when creating the option model)). """ + # Keep _current_observation in sync so that step() can read it + # (e.g. for finger-delta computation). + self._current_observation = state self._objects = list(state.data) # 1) Clear old constraint if we had a held object if self._held_constraint_id is not None: @@ -694,17 +696,18 @@ def render_segmented_obj( def get_observation(self, render: bool = False) -> Observation: """Get the current observation of this environment. - Currently, this just return a copy of the state and optionally a - rendered image. + Reads the current state from pybullet, updates _current_observation + (the backing field), and returns a copy optionally with rendered images. """ - self._current_observation = self._get_state() - assert isinstance(self._current_observation, PyBulletState) - state_copy = self._current_observation.copy() + state = self._get_state() + assert isinstance(state, PyBulletState) + self._current_observation = state + obs = state.copy() if render: - state_copy.add_images_and_masks(*self.render_segmented_obj()) + obs.add_images_and_masks(*self.render_segmented_obj()) - return state_copy + return obs def step(self, action: Action, render_obs: bool = False) -> Observation: """Execute one environment step with the given action. @@ -926,11 +929,6 @@ def _add_pybullet_state_to_tasks( for task in tasks: # Reset the robot. init = task.init - # Extract the joints. - # YC: Probably need to reset_state here so I can then get an - # observation, would it work without the reset_state? - # Attempt 2: First reset it. - self._current_observation = init self._reset_state(init) # Cast _current_observation from type State to PybulletState joint_positions = self._pybullet_robot.get_joints() From 57ef4b8b80ed2e6c5759201542af98a2e5947a54 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 2 Apr 2026 12:26:19 +0100 Subject: [PATCH 05/70] Add CFG option to load plan sketch from file instead of LLM Adds agent_bilevel_plan_sketch_file setting that, when set to a file path, loads the plan sketch directly from that file, bypassing the foundation model query. Includes test data files and a unit test. --- .../approaches/agent_bilevel_approach.py | 18 +++++++----- predicators/settings.py | 1 + .../predicatorv3/approaches/agents.yaml | 1 + .../approaches/test_agent_bilevel_approach.py | 29 +++++++++++++++++++ .../approaches/test_data/boil_plan_sketch.txt | 10 +++++++ .../test_data/simple_plan_sketch.txt | 2 ++ 6 files changed, 53 insertions(+), 8 deletions(-) create mode 100644 tests/approaches/test_data/boil_plan_sketch.txt create mode 100644 tests/approaches/test_data/simple_plan_sketch.txt diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 98b8c1df8..de60d98d4 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -261,16 +261,18 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: """Query agent for a plan sketch and parse it.""" - prompt = self._build_solve_prompt(task) - responses = self._query_agent_sync(prompt) - plan_text = self._extract_option_plan_text(responses) + sketch_file = CFG.agent_bilevel_plan_sketch_file + if sketch_file: + with open(sketch_file, "r") as f: + plan_text = f.read().strip() + logging.info("Loaded plan sketch from file: %s", sketch_file) + else: + prompt = self._build_solve_prompt(task) + responses = self._query_agent_sync(prompt) + plan_text = self._extract_option_plan_text(responses) if not plan_text: - n_responses = len(responses) - types = [r.get("type") for r in responses] - raise ApproachFailure( - f"Agent returned empty plan text. " - f"Got {n_responses} responses with types: {types}") + raise ApproachFailure("Agent returned empty plan text.") cleaned_text = self._strip_code_fences(plan_text) diff --git a/predicators/settings.py b/predicators/settings.py index caefb43be..22bee6d3d 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1015,6 +1015,7 @@ class GlobalSettings: agent_bilevel_check_subgoals = True # check subgoal atoms after each step # log state pretty_str before/after each step agent_bilevel_log_state = False + agent_bilevel_plan_sketch_file = "" # load sketch from file instead of LLM @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/scripts/configs/predicatorv3/approaches/agents.yaml b/scripts/configs/predicatorv3/approaches/agents.yaml index c43ca6125..9e9d82d8a 100644 --- a/scripts/configs/predicatorv3/approaches/agents.yaml +++ b/scripts/configs/predicatorv3/approaches/agents.yaml @@ -26,6 +26,7 @@ APPROACHES: agent_planner_use_annotate_scene: True option_model_use_gui: True agent_bilevel_log_state: False + agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 4d399883d..57808f594 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -1,5 +1,6 @@ """Tests for AgentBilevelApproach -- parsing and refinement logic.""" # pylint: disable=protected-access,import-outside-toplevel +import os from unittest.mock import MagicMock, patch import numpy as np @@ -12,6 +13,8 @@ from predicators.structs import Action, GroundAtom, Object, \ ParameterizedOption, Predicate, State, Task, Type +_TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data") + # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @@ -804,6 +807,32 @@ def test_no_valid_options_raises(self): with pytest.raises(ApproachFailure, match="Parsed empty"): approach._query_agent_for_plan_sketch(task) + def test_sketch_from_file(self): + """Load sketch from a saved text file via CFG option.""" + approach, _, task = _make_approach() + sketch_path = os.path.join(_TEST_DATA_DIR, "simple_plan_sketch.txt") + + utils.reset_config({ + "env": "cover", + "approach": "agent_bilevel", + "num_train_tasks": 1, + "num_test_tasks": 1, + "seed": 42, + "agent_bilevel_plan_sketch_file": sketch_path, + }) + + sketch = approach._query_agent_for_plan_sketch(task) + + assert len(sketch) == 2 + assert sketch[0].option.name == "Pick" + assert list(sketch[0].objects) == [_block0] + assert sketch[0].subgoal_atoms is not None + assert GroundAtom(_Holding, [_block0]) in sketch[0].subgoal_atoms + assert sketch[1].option.name == "Place" + assert list(sketch[1].objects) == [_block0, _block1] + assert sketch[1].subgoal_atoms is not None + assert GroundAtom(_On, [_block0, _block1]) in sketch[1].subgoal_atoms + # --------------------------------------------------------------------------- # Tests: _sample_params diff --git a/tests/approaches/test_data/boil_plan_sketch.txt b/tests/approaches/test_data/boil_plan_sketch.txt new file mode 100644 index 000000000..8c8e9f828 --- /dev/null +++ b/tests/approaches/test_data/boil_plan_sketch.txt @@ -0,0 +1,10 @@ +PickJug(robot:robot, jug0:jug) -> {Holding(robot:robot, jug0:jug)} +Place(robot:robot) -> {JugAtFaucet(jug0:jug, faucet:faucet), NoJugAtFaucetOrAtFaucetAndFilled(jug0:jug, faucet:faucet)} +SwitchFaucetOn(robot:robot, faucet:faucet) -> {FaucetOn(faucet:faucet)} +Wait(robot:robot) -> {JugFilled(jug0:jug)} +SwitchFaucetOff(robot:robot, faucet:faucet) -> {FaucetOff(faucet:faucet)} +PickJug(robot:robot, jug0:jug) -> {Holding(robot:robot, jug0:jug)} +Place(robot:robot) -> {JugAtBurner(jug0:jug, burner0:burner)} +SwitchBurnerOn(robot:robot, burner0:burner) -> {BurnerOn(burner0:burner)} +Wait(robot:robot) -> {WaterBoiled(jug0:jug)} +SwitchBurnerOff(robot:robot, burner0:burner) -> {BurnerOff(burner0:burner)} diff --git a/tests/approaches/test_data/simple_plan_sketch.txt b/tests/approaches/test_data/simple_plan_sketch.txt new file mode 100644 index 000000000..c14ff2dd5 --- /dev/null +++ b/tests/approaches/test_data/simple_plan_sketch.txt @@ -0,0 +1,2 @@ +Pick(block0:block) -> {Holding(block0:block)} +Place(block0:block, block1:block) -> {On(block0:block, block1:block)} From d0ac199c55af1bb37f426378f0d783d143d95c32 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 2 Apr 2026 13:07:36 +0100 Subject: [PATCH 06/70] Remove redundant conditions from Place action in boil_plan_sketch --- tests/approaches/test_data/boil_plan_sketch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/approaches/test_data/boil_plan_sketch.txt b/tests/approaches/test_data/boil_plan_sketch.txt index 8c8e9f828..3553f3af8 100644 --- a/tests/approaches/test_data/boil_plan_sketch.txt +++ b/tests/approaches/test_data/boil_plan_sketch.txt @@ -1,5 +1,5 @@ PickJug(robot:robot, jug0:jug) -> {Holding(robot:robot, jug0:jug)} -Place(robot:robot) -> {JugAtFaucet(jug0:jug, faucet:faucet), NoJugAtFaucetOrAtFaucetAndFilled(jug0:jug, faucet:faucet)} +Place(robot:robot) -> {JugAtFaucet(jug0:jug, faucet:faucet)} SwitchFaucetOn(robot:robot, faucet:faucet) -> {FaucetOn(faucet:faucet)} Wait(robot:robot) -> {JugFilled(jug0:jug)} SwitchFaucetOff(robot:robot, faucet:faucet) -> {FaucetOff(faucet:faucet)} From 0cafcd8f180abad5b31256b61ddf85d5ef99ad04 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 2 Apr 2026 13:20:06 +0100 Subject: [PATCH 07/70] Scale target joint value based on switch_joint_scale in PyBulletBoilEnv --- predicators/envs/pybullet_boil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 76561aabb..e7af53342 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -927,7 +927,7 @@ def _set_switch_on(self, switch_id: int, power_on: bool) -> None: j_id, physicsClientId=self._physics_client_id) j_min, j_max = info[8], info[9] - target_val = j_max if power_on else j_min + target_val = (j_max if power_on else j_min) * self.switch_joint_scale p.resetJointState(switch_id, j_id, target_val, From 3808337144c712b6728fc4bbab97ab51af5183ae Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 7 Apr 2026 12:11:00 +0100 Subject: [PATCH 08/70] Refactor _terminal in option model to deduplicate wait-termination logic Extract repeated wait-termination check into _check_wait_termination helper and unify the three _terminal branches into a single definition with config checks inside the function body. --- predicators/option_model.py | 112 +++++++++++++++--------------------- 1 file changed, 45 insertions(+), 67 deletions(-) diff --git a/predicators/option_model.py b/predicators/option_model.py index 9af23cd51..93cba43d4 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -20,6 +20,27 @@ ParameterizedOption, State, _Option +def _check_wait_termination(option: _Option, state: State, + last_state: State, + abstract_fn: Callable[[State], Set]) -> bool: + """Check if a Wait option should terminate based on target atoms or atom + change. Returns True if it should terminate.""" + result = utils.check_wait_target_atoms(option, state, abstract_fn) + if result is True: + logging.info("Wait terminating: target atoms satisfied") + return True + if result is None: + cur_atoms = abstract_fn(state) + prev_atoms = abstract_fn(last_state) + if cur_atoms != prev_atoms: + logging.info( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + return True + return False + + def create_option_model(name: str, use_gui: Optional[bool] = None) -> _OptionModelBase: """Create an option model given its name. @@ -115,78 +136,35 @@ def get_next_state_and_num_actions(self, state: State, # if it does. This is a helpful optimization for planning with # fine-grained options over long horizons. # Note: mypy complains if this is None instead of DefaultState. - if CFG.option_model_terminate_on_repeat: - last_state = DefaultState + last_state = DefaultState - def _terminal(s: State) -> bool: - nonlocal last_state - if option_copy.terminal(s): + def _terminal(s: State) -> bool: + nonlocal last_state + if option_copy.terminal(s): + if CFG.option_model_terminate_on_repeat: logging.debug("Option reached terminal state.") - return True - if last_state is not DefaultState and last_state.allclose(s): - logging.debug("Option got stuck.") - raise utils.OptionExecutionFailure( - f"Option '{option_copy.name}' got stuck: the " - f"policy's action did not change the state. " - f"This usually means the first motion phase " - f"produced a no-op (e.g. IK returned current " - f"joints, or finger command matched current " - f"finger state).") - # Terminate Wait on target atoms or any atom change. - if (CFG.wait_option_terminate_on_atom_change - and option_copy.name == "Wait" - and last_state is not DefaultState - and self._abstract_function is not None): - result = utils.check_wait_target_atoms( - option_copy, s, self._abstract_function) - if result is True: - logging.info( - "Wait terminating: target atoms satisfied") - last_state = s - return True - if result is None: - cur_atoms = self._abstract_function(s) - prev_atoms = self._abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") - last_state = s - return True - last_state = s - return False - else: + return True + if (CFG.option_model_terminate_on_repeat + and last_state is not DefaultState + and last_state.allclose(s)): + logging.debug("Option got stuck.") + raise utils.OptionExecutionFailure( + f"Option '{option_copy.name}' got stuck: the " + f"policy's action did not change the state. " + f"This usually means the first motion phase " + f"produced a no-op (e.g. IK returned current " + f"joints, or finger command matched current " + f"finger state).") if (CFG.wait_option_terminate_on_atom_change and option_copy.name == "Wait" + and last_state is not DefaultState and self._abstract_function is not None): - last_state_ref = [DefaultState] - abstract_fn = self._abstract_function - - def _terminal(s: State) -> bool: - if option_copy.terminal(s): - return True - if last_state_ref[0] is not DefaultState: - result = utils.check_wait_target_atoms( - option_copy, s, abstract_fn) - if result is True: - logging.info( - "Wait terminating: target atoms satisfied") - return True - if result is None: - cur_atoms = abstract_fn(s) - prev_atoms = abstract_fn(last_state_ref[0]) - if cur_atoms != prev_atoms: - logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") - return True - last_state_ref[0] = s - return False - else: - # mypy complains without the lambda, pylint complains with it! - _terminal = lambda s: option_copy.terminal(s) # pylint: disable=unnecessary-lambda + if _check_wait_termination(option_copy, s, last_state, + self._abstract_function): + last_state = s + return True + last_state = s + return False try: traj = utils.run_policy_with_simulator( From 3624d01225a3d2add5fa457173796cc252649976 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 7 Apr 2026 12:43:53 +0100 Subject: [PATCH 09/70] Refactor terminal state logging in _OracleOptionModel to simplify condition checks --- predicators/option_model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/predicators/option_model.py b/predicators/option_model.py index 93cba43d4..1a3826efb 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -141,8 +141,7 @@ def get_next_state_and_num_actions(self, state: State, def _terminal(s: State) -> bool: nonlocal last_state if option_copy.terminal(s): - if CFG.option_model_terminate_on_repeat: - logging.debug("Option reached terminal state.") + logging.debug("Option reached terminal state.") return True if (CFG.option_model_terminate_on_repeat and last_state is not DefaultState @@ -158,11 +157,11 @@ def _terminal(s: State) -> bool: if (CFG.wait_option_terminate_on_atom_change and option_copy.name == "Wait" and last_state is not DefaultState - and self._abstract_function is not None): - if _check_wait_termination(option_copy, s, last_state, - self._abstract_function): - last_state = s - return True + and self._abstract_function is not None + and _check_wait_termination(option_copy, s, last_state, + self._abstract_function)): + logging.debug("Wait option terminating early.") + return True last_state = s return False From 80c81101f6225d1770b722ea0e6b390ac2fa1da9 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 7 Apr 2026 12:54:32 +0100 Subject: [PATCH 10/70] Format docstring in get_observation method for improved readability --- predicators/envs/pybullet_env.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index e228b1712..1578f3bed 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -696,8 +696,9 @@ def render_segmented_obj( def get_observation(self, render: bool = False) -> Observation: """Get the current observation of this environment. - Reads the current state from pybullet, updates _current_observation - (the backing field), and returns a copy optionally with rendered images. + Reads the current state from pybullet, updates + _current_observation (the backing field), and returns a copy + optionally with rendered images. """ state = self._get_state() assert isinstance(state, PyBulletState) From d3ad2095eeb1076cb22e655e5bba118cff7c1d6d Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 7 Apr 2026 13:49:43 +0100 Subject: [PATCH 11/70] Refactor PyBulletEnv for readability and better naming - Remove dead/commented-out code and stale self-question comments - Add _VIRTUAL_OBJECT_TYPES constant to replace hardcoded type-name skip lists in _set_state and _get_state - Move env-specific _get_robot_state_dict branches to subclass overrides in pybullet_cover and pybullet_blocks - Extract _get_camera_matrices helper to deduplicate render methods - Extract _get_object_state_dict from _get_state for per-object logic - Move create_pybullet_block/sphere to pybullet_helpers/objects.py - Merge _create_task_specific_objects into _set_domain_specific_state - Rename: _reset_state -> _set_state, _reset_custom_env_state -> _set_domain_specific_state, _extract_feature -> _get_domain_specific_feature - Add docstrings explaining where each method is called from --- predicators/agent_sdk/tools.py | 4 +- predicators/envs/mara_adapter.py | 4 +- predicators/envs/pybullet_ants.py | 13 +- predicators/envs/pybullet_balance.py | 16 +- predicators/envs/pybullet_barrier.py | 14 +- predicators/envs/pybullet_blocks.py | 24 +- predicators/envs/pybullet_boil.py | 46 +- predicators/envs/pybullet_circuit.py | 9 +- predicators/envs/pybullet_coffee.py | 29 +- predicators/envs/pybullet_cover.py | 37 +- .../components/ball_component.py | 5 +- .../components/domino_component.py | 4 +- .../components/stairs_component.py | 2 +- .../envs/pybullet_domino/composed_env.py | 9 +- predicators/envs/pybullet_env.py | 624 ++++++++---------- predicators/envs/pybullet_fan.py | 17 +- predicators/envs/pybullet_float.py | 13 +- predicators/envs/pybullet_grow.py | 16 +- predicators/envs/pybullet_laser.py | 9 +- predicators/envs/pybullet_magic_bin.py | 14 +- predicators/envs/pybullet_switch.py | 9 +- .../ground_truth_models/boil/options.py | 2 +- .../skill_factories/base.py | 2 +- predicators/pybullet_helpers/objects.py | 105 +++ scripts/run_blocks_perception.py | 2 +- tests/envs/test_pybullet_blocks.py | 2 +- tests/envs/test_pybullet_cover.py | 2 +- .../pybullet_helpers/test_motion_planning.py | 2 +- tests/test_skill_factories_integration.py | 2 +- 29 files changed, 500 insertions(+), 537 deletions(-) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 01ea16fc3..bb5f98c32 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -166,7 +166,7 @@ def _render_pybullet_image( from PIL import Image as PILImage if state is not None: - ctx.env._reset_state(state) # pylint: disable=protected-access + ctx.env._set_state(state) # pylint: disable=protected-access video = ctx.env.render() if not video: @@ -1767,7 +1767,7 @@ async def annotate_scene(args: Dict[str, Any]) -> Dict[str, Any]: render_state = ctx.visualized_state or (ctx.current_task.init if ctx.current_task else None) if render_state is not None: - ctx.env._reset_state(render_state) # pylint: disable=protected-access + ctx.env._set_state(render_state) # pylint: disable=protected-access physics_id = ctx.env._physics_client_id # pylint: disable=protected-access annotations = args.get("annotations", []) diff --git a/predicators/envs/mara_adapter.py b/predicators/envs/mara_adapter.py index 047a91502..a861923fb 100644 --- a/predicators/envs/mara_adapter.py +++ b/predicators/envs/mara_adapter.py @@ -361,7 +361,7 @@ def reset(self, train_or_test: str, task_idx: int) -> PredState: return self._current_observation.copy() def step(self, action: PredAction) -> PredState: - """Step the mara env directly, avoiding a full _reset_state.""" + """Step the mara env directly, avoiding a full _set_state.""" from mara_robosim.structs import Action as MaraAction mara_obs = self._mara_env.step(MaraAction(action.arr)) @@ -375,7 +375,7 @@ def simulate(self, state: PredState, action: PredAction) -> PredState: # Reset PyBullet from the feature vectors, then get a proper # PyBulletState observation before stepping. # pylint: disable=protected-access - self._mara_env._reset_state(mara_state) + self._mara_env._set_state(mara_state) self._mara_env._current_observation = ( self._mara_env.get_observation()) # pylint: enable=protected-access diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index 4d22d6b07..35d4f82f5 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -5,9 +5,9 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.objects import create_object, \ - sample_collision_free_2d_positions, update_object + create_pybullet_block, sample_collision_free_2d_positions, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -215,10 +215,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: # If we support robot picking up food blocks, return those IDs. return [f.id for f in self._blocks] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._food_type: if feature == "attractive": @@ -229,7 +226,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: if CFG.ants_ants_attracted_to_points: self._ant_to_xy = {} # type: ignore[no-redef] @@ -533,7 +530,7 @@ def _make_tasks( # pylint: disable=redefined-outer-name env = PyBulletAntsEnv(use_gui=True) rng = np.random.default_rng(CFG.seed) task = env._make_tasks(1, rng)[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 6b69ee4ad..76b4e6586 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -15,8 +15,9 @@ import numpy as np import pybullet as p -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, Array, ConceptPredicate, \ @@ -320,10 +321,7 @@ def get_name(cls) -> str: # ------------------------------------------------------------------------- # State Management: Get, (Re)Set, Step - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._block_type: visual_data = p.getVisualShapeData( @@ -368,10 +366,10 @@ def step( # pylint: disable=redefined-outer-name return state - def _reset_custom_env_state(self, state: State) -> None: - """Replace the old `_reset_state` environment-specific logic. + def _set_domain_specific_state(self, state: State) -> None: + """Replace the old `_set_state` environment-specific logic. - The base `_reset_state` has already handled standard features + The base `_set_state` has already handled standard features for objects that appear in _get_all_objects(), so here we just do custom domain-specific tasks: setting plates/blocks if we aren't letting the base class handle them, updating button @@ -961,7 +959,7 @@ def _table_xy_is_clear(self, x: float, y: float, CFG.num_test_tasks = 1 env = PyBulletBalanceEnv(use_gui=True) task = env._generate_test_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_barrier.py b/predicators/envs/pybullet_barrier.py index c1a7f3132..8041c6dd7 100644 --- a/predicators/envs/pybullet_barrier.py +++ b/predicators/envs/pybullet_barrier.py @@ -15,9 +15,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -217,7 +218,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (none in this env).""" return [] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._switch_type and feature == "is_on": return float(self._is_switch_on(obj)) @@ -229,10 +230,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return current_z - obj.base_z raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment state from a State object.""" # Set switch states and positions for switch in self._switches: @@ -474,7 +472,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletBarrierEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("PyBullet Barrier Environment Test") print("Barriers should animate when switches are toggled.") diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index 0aaa0afe2..b0abf0e24 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -9,8 +9,9 @@ from predicators import utils from predicators.envs.blocks import BlocksEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, Object, State @@ -93,11 +94,8 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: for blk, blk_id in zip(self._blocks, self._block_ids): blk.id = blk_id - def _create_task_specific_objects(self, state: State) -> None: - """No additional environment assets needed per-task.""" - - def _reset_custom_env_state(self, state: State) -> None: - """After the parent `_reset_state()` has reset the robot, set the block + def _set_domain_specific_state(self, state: State) -> None: + """After the parent `_set_state()` has reset the robot, set the block positions/colors and handle constraints for any 'held' block.""" block_objs = state.get_objects(self._block_type) self._block_id_to_block.clear() @@ -141,7 +139,7 @@ def _reset_custom_env_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Called by the parent class when constructing the `PyBulletState`. We read off the relevant block or robot features from PyBullet. @@ -233,6 +231,16 @@ def _extract_robot_state(self, state: State) -> np.ndarray: qx, qy, qz, qw = self.get_robot_ee_home_orn() return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) + def _get_robot_state_dict(self) -> Dict[str, float]: + rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() + fingers = self._fingers_joint_to_state(self._pybullet_robot, rf) + return { + "pose_x": rx, + "pose_y": ry, + "pose_z": rz, + "fingers": fingers, + } + def _get_object_ids_for_held_check(self) -> List[int]: """Return the IDs of blocks for which we might be checking 'held' contact.""" @@ -272,7 +280,7 @@ def _force_grasp_object(self, block: Object) -> None: """Manually create a fixed constraint for a block that is marked 'held' in the State. - Called from _reset_custom_env_state(). + Called from _set_domain_specific_state(). """ # Find block's pybullet ID block_id = None diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index e7af53342..c1485c53e 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -9,9 +9,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, DerivedPredicate, EnvironmentTask, \ @@ -212,6 +213,10 @@ def __init__(self, use_gui: bool = False) -> None: # Keep track of the spilled water block (None if no spill yet) self._spilled_water_id: Optional[int] = None + # When True, step() skips process dynamics (water filling, heating, + # happiness) so that a learned simulator can provide them instead. + self._skip_process_dynamics: bool = False + super().__init__(use_gui) # Optionally, define some relevant predicates @@ -491,11 +496,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: jug_ids = [j.id for j in self._jugs if j.id is not None] return jug_ids - def _create_task_specific_objects(self, state: State) -> None: - """If you wanted additional objects depending on a given state, add - them here.""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Map from environment object + feature name -> a float feature in the State.""" # Faucet @@ -558,8 +559,8 @@ def _extract_feature(self, obj: Object, feature: str) -> float: # Otherwise, rely on defaults (like the base PyBulletEnv) for x,y,z,... raise ValueError(f"Unknown feature {feature} for object {obj}.") - def _reset_custom_env_state(self, state: State) -> None: - """Called in _reset_state to do any environment-specific resetting. + def _set_domain_specific_state(self, state: State) -> None: + """Called in _set_state to do any environment-specific resetting. This environment only supports resetting the state at the beginning, because the state dict doesn't include all features @@ -654,23 +655,24 @@ def step(self, action: Action, render_obs: bool = False) -> State: # First let the base environment perform the usual PyBullet step next_state = super().step(action, render_obs=False) - # 1) Handle faucet filling/spillage - self._handle_faucet_logic(next_state) + if not self._skip_process_dynamics: + # 1) Handle faucet filling/spillage + self._handle_faucet_logic(next_state) - # 2) Handle burner heating - self._handle_heating_logic(next_state) + # 2) Handle burner heating + self._handle_heating_logic(next_state) - # 3) Update jug colors based on their 'heat' - self._update_jug_colors(next_state) + # 3) Update jug colors based on their 'heat' + self._update_jug_colors(next_state) - # 4) Update burner colors based on their on/off state - self._update_burner_colors(next_state) + # 4) Update burner colors based on their on/off state + self._update_burner_colors(next_state) - # 5) Update the human's happiness level - self._update_human_happiness(next_state) + # 5) Update the human's happiness level + self._update_human_happiness(next_state) - # 6) Update prev_on states for next step - self._update_prev_on_states(next_state) + # 6) Update prev_on states for next step + self._update_prev_on_states(next_state) # Re-read final state final_state = self.get_observation(render=render_obs) @@ -1445,7 +1447,7 @@ def _main() -> None: # pylint: disable=too-many-locals burner1, faucet) for task in tasks: - env._reset_state(task.init) + env._set_state(task.init) for _ in range(20000): action = Action( np.array(env._pybullet_robot.initial_joint_positions)) diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index 6c1f414cc..35c3dd695 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -297,7 +297,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of wires (assuming the robot can pick them up).""" return [self._wire1.id, self._wire2.id] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._light_type and feature == "is_on": return int(self._is_bulb_on(obj.id)) @@ -305,10 +305,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return int(self._is_switch_on()) raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: is_light_on = state.get(self._light, "is_on") if is_light_on: @@ -775,7 +772,7 @@ def _main() -> None: CFG.num_train_tasks = 1 env = PyBulletCircuitEnv(use_gui=True) task = env._generate_train_tasks()[0] - env._reset_state(task.init) + env._set_state(task.init) while True: action = Action( diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 73f429322..a447996bb 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -315,14 +315,6 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def get_name(cls) -> str: return "pybullet_coffee" - def _create_task_specific_objects(self, state: State) -> None: - """Remove/rebuild cups, liquids, and cords so each new task can have - different cups and states.""" - self._remake_jug_liquid(state) - self._remake_cup_liquids(state) - self._remake_cups(state) - self._remake_cord() - def _remake_cups(self, state: State) -> None: """Re-load cup URDFs with appropriate scaling and color for each new cup.""" @@ -403,14 +395,17 @@ def _remake_cord(self) -> None: self._physics_client_id) self._plug.id = self._cord_ids[-1] - def _reset_custom_env_state(self, state: State) -> None: - """Handles extra coffee-specific reset steps: spawning cups from - scratch, adding liquid visuals, adjusting jug fill color, toggling the - machine button, etc. - - The base `_reset_state` has already done the standard - position/orientation resets for objects in `_get_all_objects()`. + def _set_domain_specific_state(self, state: State) -> None: + """Coffee-specific state setup: rebuild task-specific objects + (cups, liquids, cords), then set visual state (button color, + liquid fills, etc.). """ + # Rebuild objects that vary per task + self._remake_jug_liquid(state) + self._remake_cup_liquids(state) + self._remake_cups(state) + self._remake_cord() + # Machine button color # Check if the machine is on and the jug is in place: if self._MachineOn_holds(state, [self._machine]) and \ @@ -439,7 +434,7 @@ def _reset_custom_env_state(self, state: State) -> None: rgbaColor=plate_color, physicsClientId=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._jug_type: if feature == "is_filled": @@ -1275,7 +1270,7 @@ def _main() -> None: env = PyBulletCoffeeEnv(use_gui=True) rng = np.random.default_rng(CFG.seed) task = env._make_tasks(1, rng)[0] # type: ignore[attr-defined] # pylint: disable=no-member - env._reset_state(task.init) + env._set_state(task.init) while True: # Robot does nothing diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index 24ea5d5d0..32f680bcf 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -13,9 +13,10 @@ from predicators import utils from predicators.envs.cover import CoverEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import update_object +from predicators.pybullet_helpers.objects import create_pybullet_block, \ + update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, Object, State @@ -61,7 +62,7 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): def __init__(self, use_gui: bool = False) -> None: super().__init__(use_gui) # Store block/target IDs (from initialize_pybullet) so that we can - # reset their positions in _reset_custom_env_state(). + # reset their positions in _set_domain_specific_state(). self._table_id: int = -1 # self._block_ids: list[int] = [] # self._target_ids: list[int] = [] @@ -151,10 +152,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: for tgt, tgt_id in zip(self._targets, pybullet_bodies["target_ids"]): tgt.id = tgt_id - def _create_task_specific_objects(self, state: State) -> None: - """No domain-specific extra creation needed here.""" - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """After the parent class has reset the robot, handle the block/target positions. @@ -299,24 +297,13 @@ def _extract_robot_state(self, state: State) -> np.ndarray: return np.array([rx, ry, rz, qx, qy, qz, qw, fingers], dtype=np.float32) - def _extract_feature(self, obj: Object, feature: str) -> float: - """Domain-specific feature extraction for blocks, targets, and the - (robot).""" - # # 1) If it's the robot - # if obj.type == self._robot_type: - # # The parent's _get_robot_state_dict() will set x,y,z,fingers - # # We can handle additional features here: - # rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - # if feature == "hand": - # # Re-normalize the y coordinate - # return (ry - self.y_lb) / (self.y_ub - self.y_lb) - # elif feature == "pose_x": - # return rx - # elif feature == "pose_z": - # return rz - # raise ValueError(f"Unknown robot feature: {feature}") - - # 2) If it's a block + def _get_robot_state_dict(self) -> Dict[str, float]: + rx, ry, rz, _, _, _, _, _rf = self._pybullet_robot.get_state() + hand = (ry - self.y_lb) / (self.y_ub - self.y_lb) + return {"hand": hand, "pose_x": rx, "pose_z": rz} + + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: + """Domain-specific feature extraction for blocks and targets.""" if obj.type == self._block_type: block_id = obj.id if feature == "is_block": diff --git a/predicators/envs/pybullet_domino/components/ball_component.py b/predicators/envs/pybullet_domino/components/ball_component.py index a3fccb2d4..9d0e44677 100644 --- a/predicators/envs/pybullet_domino/components/ball_component.py +++ b/predicators/envs/pybullet_domino/components/ball_component.py @@ -14,9 +14,8 @@ from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block, \ - create_pybullet_sphere -from predicators.pybullet_helpers.objects import update_object +from predicators.pybullet_helpers.objects import create_pybullet_block, \ + create_pybullet_sphere, update_object from predicators.settings import CFG from predicators.structs import Object, Predicate, State, Type diff --git a/predicators/envs/pybullet_domino/components/domino_component.py b/predicators/envs/pybullet_domino/components/domino_component.py index 54d9cca85..8375ffba3 100644 --- a/predicators/envs/pybullet_domino/components/domino_component.py +++ b/predicators/envs/pybullet_domino/components/domino_component.py @@ -18,9 +18,9 @@ from predicators import utils from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.settings import CFG from predicators.structs import Object, Predicate, State, Type diff --git a/predicators/envs/pybullet_domino/components/stairs_component.py b/predicators/envs/pybullet_domino/components/stairs_component.py index ff966467c..24e32cc00 100644 --- a/predicators/envs/pybullet_domino/components/stairs_component.py +++ b/predicators/envs/pybullet_domino/components/stairs_component.py @@ -12,7 +12,7 @@ from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.structs import Object, State, Type diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index a30846dba..46620b3d0 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -277,10 +277,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: ids.extend(comp.get_object_ids_for_held_check()) return ids - def _create_task_specific_objects(self, state: State) -> None: - """Create any task-specific objects (not used in current impl).""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract state feature for an object.""" # Try each component for comp in self._components: @@ -290,7 +287,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment to match the given state.""" # Update ball component's state reference for is_hit feature if self._ball_component is not None: @@ -699,7 +696,7 @@ def goal_predicates(self) -> Set[Predicate]: print(f"{'=' * 60}") # Reset to initial state - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("\nGoal atoms:") for atom in task.goal: diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 1578f3bed..5572cb091 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -12,9 +12,8 @@ - initialize_pybullet(using_gui) -> (physics_id, robot, bodies_dict) - _store_pybullet_bodies(bodies_dict) - _get_object_ids_for_held_check() -> List[int] - - _create_task_specific_objects(state) - - _reset_custom_env_state(state) - - _extract_feature(obj, feature) -> float + - _set_domain_specific_state(state) + - _get_domain_specific_feature(obj, feature) -> float """ import abc @@ -94,6 +93,11 @@ class PyBulletEnv(BaseEnv): _out_of_view_xy: ClassVar[Sequence[float]] = [10.0, 10.0] _default_orn: ClassVar[Sequence[float]] = [0.0, 0.0, 0.0, 1.0] + # Object types that have no PyBullet body — features managed + # entirely by _get_domain_specific_feature(). + _VIRTUAL_OBJECT_TYPES: ClassVar[frozenset] = frozenset( + {"loc", "angle", "human", "side", "direction"}) + # Camera parameters. _camera_distance: ClassVar[float] = 0.8 _camera_yaw: ClassVar[float] = 90.0 @@ -120,21 +124,25 @@ def __init__(self, use_gui: bool = False) -> None: self.initialize_pybullet(self.using_gui) self._store_pybullet_bodies(pybullet_bodies) - # What are they used for?? - # It's used in get_state, reset_state and labeling state. - # Should be populated at reset or reset state. + # Populated by reset() / _set_state(); used by _get_state(), + # _set_state(), and render_segmented_obj() for iteration. self._objects: List[Object] = [] def get_extra_collision_ids(self) -> Sequence[int]: """Return extra PyBullet body IDs to treat as collision obstacles. - Override in subclasses for bodies not tracked as state Objects - (e.g. liquid blocks in Grow). + Called by the motion planner (skill factories) when computing + collision-free paths. Override in subclasses for bodies not + tracked as state Objects (e.g. liquid blocks in Grow). """ return () def get_object_by_id(self, obj_id: int) -> Object: - """Get object by id.""" + """Look up an Object by its PyBullet body ID. + + Used by agent tools and skill factories to map from a PyBullet + collision/contact result back to the predicators Object. + """ for obj in self._objects: if obj.id == obj_id: return obj @@ -175,11 +183,11 @@ def initialize_pybullet( loading. - Task-specific objects that need to be loaded with different sizes or other properties should be handled in the - `_create_task_specific_objects` method, which is called during each + `_set_domain_specific_state` method, which is called during each task's reset. - Subclasses may override this method to load additional assets. In the subclass, register all object IDs here and move them out of view - in the `reset_custom_env_state` method. + in the `_set_domain_specific_state` method. """ # Skip test coverage because GUI is too expensive to use in unit tests # and cannot be used in headless mode. @@ -221,6 +229,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: @classmethod def _create_pybullet_robot( cls, physics_client_id: int) -> SingleArmPyBulletRobot: + """Instantiate the robot model. Called by initialize_pybullet().""" robot_ee_orn = cls.get_robot_ee_home_orn() ee_home = Pose((cls.robot_init_x, cls.robot_init_y, cls.robot_init_z), robot_ee_orn) @@ -235,11 +244,13 @@ def _create_pybullet_robot( base_pose) def _extract_robot_state(self, state: State) -> Array: - """Given a State, extract the robot state, to be passed into - self._pybullet_robot.reset_state(). + """State -> robot array: extract robot features for PyBullet. + + Converts the robot's features in a State into the array format + expected by self._pybullet_robot.reset_state() + (same format as self._pybullet_robot.get_state()). - This should be the same type as the return value of - self._pybullet_robot.get_state(). + Called by _set_state() to position the robot. """ # EE Position @@ -277,14 +288,20 @@ def get_pos_feature( @abc.abstractmethod def _get_object_ids_for_held_check(self) -> List[int]: - """Return a list of pybullet IDs corresponding to objects in the - simulator that should be checked when determining whether one is - held.""" + """Return PyBullet body IDs of objects that can be grasped. + + Called by _detect_held_object() (inside step()) to decide which + bodies to check for finger contact. Subclasses return only the + IDs of graspable objects (e.g. blocks, not tables). + """ raise NotImplementedError("Override me!") def _get_expected_finger_normals(self) -> Dict[int, Array]: - # Get the current state of the robot, including the orientation - # quaternion + """Compute the expected inward-facing normal for each finger. + + Called by _detect_held_object() to distinguish objects between + the fingers (valid grasp) from objects touching the outside. + """ _rx, _ry, _rz, qx, qy, qz, qw, _rf = self._pybullet_robot.get_state() # Convert the quaternion to a rotation matrix @@ -314,8 +331,11 @@ def _get_expected_finger_normals(self) -> Dict[int, Array]: @classmethod def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, finger_state: float) -> float: - """Map the fingers in the given *State* to joint values for - PyBullet.""" + """Map finger value in a State (e.g. open_fingers=0.04) to the + corresponding PyBullet joint position. + + Called by _extract_robot_state() when writing State -> PyBullet. + """ # If open_fingers is undefined, use 1.0 as the default. subs = { cls.open_fingers: pybullet_robot.open_fingers, @@ -327,7 +347,10 @@ def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, @classmethod def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, finger_joint: float) -> float: - """Inverse of _fingers_state_to_joint().""" + """Inverse of _fingers_state_to_joint(). + + Called by _get_robot_state_dict() when reading PyBullet -> State. + """ subs = { pybullet_robot.open_fingers: cls.open_fingers, pybullet_robot.closed_fingers: cls.closed_fingers, @@ -340,13 +363,24 @@ def action_space(self) -> Box: return self._pybullet_robot.action_space def simulate(self, state: State, action: Action) -> State: - # Optimization: skip _reset_state if pybullet is already in this state. - # _current_observation is None before the first reset() call. - # Check it (not _current_state) because _current_state would fail - # its type assertion on None. + """Apply an action to a state using the PyBullet simulator. + + Called by the option model during bilevel planning to forward- + simulate candidate action sequences without touching the real + environment. + + The _set_state guard handles two cases: + - Skipped (common): during a sequential rollout the option model + calls simulate(s1, a1) -> s2, then simulate(s2, a2) -> s3, etc. + After each call, _current_state already equals the next input + state, so _set_state is unnecessary. + - Taken: when the planner jumps to a different state (e.g. trying + a new skeleton or backtracking), or on the very first call + before any reset() (_current_observation is None). + """ if self._current_observation is None or \ not state.allclose(self._current_state): - self._reset_state(state) + self._set_state(state) return self.step(action) def render_state_plt( @@ -370,15 +404,21 @@ def reset(self, task_idx: int, render: bool = False) -> Observation: state = super().reset(train_or_test, task_idx) - self._reset_state(state) + self._set_state(state) observation = self.get_observation(render=render) return observation - def _reset_state(self, state: State) -> None: - """Reset the PyBullet state to match the given state. + def _set_state(self, state: State) -> None: + """State -> PyBullet: set the simulator to match a State. - Used in initialization (reset(), _add_pybullet_state_to_tasks()) - and bilevel planning (when creating the option model)). + Converts the agent-facing State representation (feature dicts + keyed by Object) into the corresponding PyBullet scene (joint + positions, body poses, grasp constraints, etc.). + + Call sites: + - reset() / _add_pybullet_state_to_tasks(): initialization + - simulate(): option-model / bilevel-planning rollouts + - external callers (skill factories, agent tools, tests) """ # Keep _current_observation in sync so that step() can read it # (e.g. for finger-delta computation). @@ -395,20 +435,15 @@ def _reset_state(self, state: State) -> None: # 2) Reset robot pose self._pybullet_robot.reset_state(self._extract_robot_state(state)) - # I want to have a step that creates task specific objects before reset - # their positions, what should I call this? - self._create_task_specific_objects(state) - # 3) Reset all known objects (position, orientation, etc.) for obj in self._objects: - if obj.type.name in [ - "robot", "loc", "angle", "human", "side", "direction" - ]: + if obj.type.name == "robot" or \ + obj.type.name in self._VIRTUAL_OBJECT_TYPES: continue self._reset_single_object(obj, state) - # 4) Let the subclass do any additional specialized resetting - self._reset_custom_env_state(state) + # 4) Let the subclass do any domain-specific state setup + self._set_domain_specific_state(state) # 5) Check for reconstruction mismatch. # Only raise for envs that override _get_state(). @@ -418,12 +453,12 @@ def _reset_state(self, state: State) -> None: raise ValueError("Could not reconstruct state.") logging.warning("Could not reconstruct state exactly in reset.") - @abc.abstractmethod - def _create_task_specific_objects(self, state: State) -> None: - raise NotImplementedError("Override me!") - def _reset_single_object(self, obj: Object, state: State) -> None: - """Shared logic for setting position/orientation and constraints.""" + """Set a single physical object's pose and grasp constraint in + PyBullet to match the given State. + + Called by _set_state() for every non-robot, non-virtual object. + """ # Skip objects without pybullet IDs (handled by subclass). if obj.id is None: return @@ -432,8 +467,6 @@ def _reset_single_object(self, obj: Object, state: State) -> None: features = obj.type.feature_names cur_x, cur_y, cur_z = p.getBasePositionAndOrientation( obj.id, physicsClientId=self._physics_client_id)[0] - # except: - # breakpoint() px = state.get(obj, "x") if "x" in obj.type.feature_names else cur_x py = state.get(obj, "y") if "y" in obj.type.feature_names else cur_y pz = state.get(obj, "z") if "z" in obj.type.feature_names else cur_z @@ -464,95 +497,44 @@ def _reset_single_object(self, obj: Object, state: State) -> None: # and stores _held_obj_to_base_link. @abc.abstractmethod - def _reset_custom_env_state(self, state: State) -> None: - """Hook for environment-specific resetting (colors, water, etc.). - - Subclasses can override or extend this if needed. + def _set_domain_specific_state(self, state: State) -> None: + """Set simulator state for features that the base class doesn't + handle — e.g. switch on/off, liquid levels, button colors, + balance beam positions. + + Called at the end of _set_state(), after the base class has + already set robot joints, object poses, and grasp constraints. + Subclasses must override. """ raise NotImplementedError("Override me!") + # Features handled by _get_object_state_dict via PyBullet queries. + _PYBULLET_FEATURES: ClassVar[frozenset] = frozenset({ + "x", "y", "z", "rot", "yaw", "roll", "pitch", "is_held", "r", "g", "b" + }) + def _get_state(self, _render_obs: bool = False) -> State: - """Reads the PyBullet scene into a `State` (PyBulletState). It takes - care of: + """PyBullet -> State: read the simulator into a PyBulletState. - * robot features [x, y, z, tilt, wrist, fingers] - * object features [x, y, z, rot, is_held] - the other feature extractors should be implemented in the subclasses via - `_extract_feature`. - """ - state_dict: Dict[Object, Dict[str, float]] = {} + Queries PyBullet for the current scene (joint positions, body + poses, visual data, etc.) and packs the values into the + agent-facing State representation. - # --- 1) Robot --- - robot_state = self._get_robot_state_dict() - state_dict[self._robot] = robot_state + Handles common features (robot pose, object x/y/z/rot/is_held, + color); subclass-specific features are delegated to + `_get_domain_specific_feature`. - # --- 2) Other Objects --- + Called by get_observation() (after reset/step) and by + _set_state() to verify reconstruction fidelity. + """ + state_dict: Dict[Object, Dict[str, float]] = {} + state_dict[self._robot] = self._get_robot_state_dict() for obj in self._objects: - if obj.type.name in ["robot"]: + if obj.type.name == "robot": continue + state_dict[obj] = self._get_object_state_dict(obj) - obj_features = obj.type.feature_names - obj_dict = {} - - if obj.type.name in ["loc", "angle", "human", "side", "direction"]: - for feature in obj_features: - obj_dict[feature] = self._extract_feature(obj, feature) - state_dict[obj] = obj_dict - continue - - # Basic features - try: - (px, py, pz), orn = p.getBasePositionAndOrientation( - obj.id, physicsClientId=self._physics_client_id) - except Exception as e: - raise RuntimeError(f"Failed to get pose for object {obj.name} " - f"(id={obj.id})") from e - if "x" in obj_features: - obj_dict["x"] = px - if "y" in obj_features: - obj_dict["y"] = py - if "z" in obj_features: - obj_dict["z"] = pz - if "rot" in obj_features or "yaw" in obj_features or \ - "roll" in obj_features or "pitch" in obj_features: - roll, pitch, yaw = p.getEulerFromQuaternion(orn) - if "rot" in obj_features: - obj_dict["rot"] = yaw - if "yaw" in obj_features: - obj_dict["yaw"] = yaw - if "roll" in obj_features: - obj_dict["roll"] = roll - if "pitch" in obj_features: - obj_dict["pitch"] = pitch - if "is_held" in obj_features: - obj_dict["is_held"] = 1.0 if obj.id == self._held_obj_id \ - else 0.0 - - if "r" in obj_features or "b" in obj_features or \ - "g" in obj_features: - # Note: also handle color_r, color_b, ... - visual_data = p.getVisualShapeData( - obj.id, physicsClientId=self._physics_client_id)[0] - (r, g, b, _a) = visual_data[7] - obj_dict["r"] = r - obj_dict["g"] = g - obj_dict["b"] = b - - # Additional features - for feature in obj_features: - if feature not in [ - "x", "y", "z", "rot", "yaw", "roll", "pitch", - "is_held", "r", "g", "b" - ]: - obj_dict[feature] = self._extract_feature(obj, feature) - - state_dict[obj] = obj_dict - - # Convert to a PyBulletState - # try: state = utils.create_state_from_dict(state_dict) - # except: - # breakpoint() joint_positions = self._pybullet_robot.get_joints() pyb_state = PyBulletState(state.data, simulator_state={ @@ -564,45 +546,100 @@ def _get_state(self, _render_obs: bool = False) -> State: }) return pyb_state + def _get_object_state_dict(self, obj: Object) -> Dict[str, float]: + """Build a feature dict for a single non-robot object. + + Virtual objects (loc, angle, etc.) delegate all features to + _get_domain_specific_feature. Physical objects get + pose/color/is_held from PyBullet; the rest are delegated. + """ + obj_features = obj.type.feature_names + obj_dict: Dict[str, float] = {} + + if obj.type.name in self._VIRTUAL_OBJECT_TYPES: + for feature in obj_features: + obj_dict[feature] = \ + self._get_domain_specific_feature(obj, feature) + return obj_dict + + # Physical object — query PyBullet for pose + try: + (px, py, pz), orn = p.getBasePositionAndOrientation( + obj.id, physicsClientId=self._physics_client_id) + except Exception as e: + raise RuntimeError(f"Failed to get pose for object {obj.name} " + f"(id={obj.id})") from e + if "x" in obj_features: + obj_dict["x"] = px + if "y" in obj_features: + obj_dict["y"] = py + if "z" in obj_features: + obj_dict["z"] = pz + + if {"rot", "yaw", "roll", "pitch"} & set(obj_features): + roll, pitch, yaw = p.getEulerFromQuaternion(orn) + if "rot" in obj_features: + obj_dict["rot"] = yaw + if "yaw" in obj_features: + obj_dict["yaw"] = yaw + if "roll" in obj_features: + obj_dict["roll"] = roll + if "pitch" in obj_features: + obj_dict["pitch"] = pitch + + if "is_held" in obj_features: + obj_dict["is_held"] = 1.0 if obj.id == self._held_obj_id else 0.0 + + if {"r", "g", "b"} & set(obj_features): + visual_data = p.getVisualShapeData( + obj.id, physicsClientId=self._physics_client_id)[0] + (r, g, b, _a) = visual_data[7] + obj_dict["r"] = r + obj_dict["g"] = g + obj_dict["b"] = b + + # Remaining features delegated to subclass + for feature in obj_features: + if feature not in self._PYBULLET_FEATURES: + obj_dict[feature] = \ + self._get_domain_specific_feature( + obj, feature) + + return obj_dict + @abc.abstractmethod - def _extract_feature(self, obj: Object, feature: str) -> float: - """Called in _get_state() to extract a feature from an object.""" + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: + """Return a single feature value for a non-robot object. + + Called by _get_object_state_dict() for: + - All features of virtual objects (those in _VIRTUAL_OBJECT_TYPES) + - Non-standard features of physical objects (anything not in + _PYBULLET_FEATURES, e.g. is_on, growth, water_height) + """ raise NotImplementedError("Override me!") def _get_robot_state_dict(self) -> Dict[str, float]: - """Get dict state of the robot.""" - r_dict = {} + """Build a feature dict for the robot from PyBullet state. + + Called by _get_state() to populate the robot entry in the State. + Subclasses with non-standard robot features (e.g. cover's + normalized hand, blocks' pose_x/y/z) should override this. + """ + rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() + r_dict: Dict[str, float] = {"x": rx, "y": ry, "z": rz, "fingers": rf} + _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) r_features = self._robot.type.feature_names - if CFG.env == "pybullet_cover": - rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - hand = (ry - self.y_lb) / (self.y_ub - self.y_lb) - r_dict.update({"hand": hand, "pose_x": rx, "pose_z": rz}) - elif CFG.env == "pybullet_blocks": - rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - fingers = self._fingers_joint_to_state(self._pybullet_robot, rf) - r_dict.update({ - "pose_x": rx, - "pose_y": ry, - "pose_z": rz, - "fingers": fingers - }) - else: - rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() - r_dict.update({"x": rx, "y": ry, "z": rz, "fingers": rf}) - _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) - if "tilt" in r_features: - r_dict["tilt"] = tilt - if "wrist" in r_features: - r_dict["wrist"] = wrist + if "tilt" in r_features: + r_dict["tilt"] = tilt + if "wrist" in r_features: + r_dict["wrist"] = wrist return r_dict - def render(self, - action: Optional[Action] = None, - caption: Optional[str] = None) -> Video: # pragma: no cover - # Skip test coverage because GUI is too expensive to use in unit tests - # and cannot be used in headless mode. - del action, caption # unused + def _get_camera_matrices(self) -> Tuple[Any, Any, int, int]: + """Return (view_matrix, proj_matrix, width, height) for rendering. + Called by render() and render_segmented_obj(). + """ view_matrix = p.computeViewMatrixFromYawPitchRoll( cameraTargetPosition=self._camera_target, distance=self._camera_distance, @@ -611,17 +648,23 @@ def render(self, roll=0, upAxisIndex=2, physicsClientId=self._physics_client_id) - width = CFG.pybullet_camera_width height = CFG.pybullet_camera_height - proj_matrix = p.computeProjectionMatrixFOV( fov=self._camera_fov, aspect=float(width / height), nearVal=0.1, farVal=100.0, physicsClientId=self._physics_client_id) + return view_matrix, proj_matrix, width, height + def render(self, + action: Optional[Action] = None, + caption: Optional[str] = None) -> Video: # pragma: no cover + # Skip test coverage because GUI is too expensive to use in unit tests + # and cannot be used in headless mode. + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() (_, _, px, _, _) = p.getCameraImage(width=width, height=height, @@ -629,7 +672,6 @@ def render(self, projectionMatrix=proj_matrix, renderer=p.ER_BULLET_HARDWARE_OPENGL, physicsClientId=self._physics_client_id) - rgb_array = np.array(px).reshape((height, width, 4)) rgb_array = rgb_array[:, :, :3] return [rgb_array] @@ -639,36 +681,13 @@ def render_segmented_obj( action: Optional[Action] = None, caption: Optional[str] = None, ) -> Tuple[Image.Image, Dict[Object, Mask]]: - """Render the scene and the segmented objects in the scene.""" - del action, caption # unused - # if not self.using_gui: - # raise Exception( - # "Rendering only works with GUI on. See " - # "https://github.com/bulletphysics/bullet3/issues/1157") - - view_matrix = p.computeViewMatrixFromYawPitchRoll( - cameraTargetPosition=self._camera_target, - distance=self._camera_distance, - yaw=self._camera_yaw, - pitch=self._camera_pitch, - roll=0, - upAxisIndex=2, - physicsClientId=self._physics_client_id) - - width = CFG.pybullet_camera_width - height = CFG.pybullet_camera_height - - proj_matrix = p.computeProjectionMatrixFOV( - fov=60, - aspect=float(width / height), - nearVal=0.1, - farVal=100.0, - physicsClientId=self._physics_client_id) - - # Initialize an empty dictionary - mask_dict: Dict[Object, Mask] = {} + """Render the scene and return per-object segmentation masks. - # Get the original image and segmentation mask + Called by get_observation(render=True) to attach RGB images and + masks to the observation (used for VLM predicate grounding). + """ + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() (_, _, rgbImg, _, segImg) = p.getCameraImage(width=width, height=height, @@ -676,21 +695,14 @@ def render_segmented_obj( projectionMatrix=proj_matrix, renderer=p.ER_BULLET_HARDWARE_OPENGL, physicsClientId=self._physics_client_id) - - # Convert to numpy arrays original_image: np.ndarray = np.array(rgbImg, dtype=np.uint8).reshape( (height, width, 4)) seg_image = np.array(segImg).reshape((height, width)) - state_img = Image.fromarray( # type: ignore[no-untyped-call] original_image[:, :, :3]) - - # Iterate over all bodies to be labeled + mask_dict: Dict[Object, Mask] = {} for obj in self._objects: - body_id = obj.id - mask = seg_image == body_id - mask_dict[obj] = mask - + mask_dict[obj] = (seg_image == obj.id) return state_img, mask_dict def get_observation(self, render: bool = False) -> Observation: @@ -768,11 +780,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: # If not currently holding something, and fingers are closing, check # for a new grasp. if self._held_constraint_id is None and self._fingers_closing(action): - # logging.debug("Finger closing") - # Detect if an object is held. If so, create a grasp constraint. self._held_obj_id = self._detect_held_object() - # logging.debug(f"Detected held object: {self._held_obj_id}") - # breakpoint() if self._held_obj_id is not None: self._create_grasp_constraint() @@ -782,7 +790,6 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: p.removeConstraint(self._held_constraint_id, physicsClientId=self._physics_client_id) self._held_constraint_id = None - # logging.debug("Finger opening") self._held_obj_id = None # Depending on the observation mode, either return object-centric state @@ -793,10 +800,13 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: return observation def _detect_held_object(self) -> Optional[int]: - """Return the PyBullet object ID of the held object if one exists. + """Return the PyBullet body ID of the grasped object, or None. - If multiple objects are within the grasp tolerance, return the - one that is closest. + Called by step() when fingers are closing and no object is + currently held. Checks contact between each finger and every + graspable body (from _get_object_ids_for_held_check()), using + contact-normal alignment to reject touches on the outside of + the gripper. If multiple objects qualify, returns the closest. """ expected_finger_normals = self._get_expected_finger_normals() closest_held_obj = None @@ -840,6 +850,12 @@ def _detect_held_object(self) -> Optional[int]: return closest_held_obj def _create_grasp_constraint(self) -> None: + """Create a fixed PyBullet constraint between the end-effector + and _held_obj_id so the object moves with the gripper. + + Called by step() after _detect_held_object() finds a grasp, + and by _reset_single_object() when restoring a held state. + """ assert self._held_obj_id is not None base_link_to_world = np.r_[p.invertTransform( *p.getLinkState(self._pybullet_robot.robot_id, @@ -864,32 +880,48 @@ def _create_grasp_constraint(self) -> None: physicsClientId=self._physics_client_id) def _fingers_closing(self, action: Action) -> bool: - """Check whether this action is working toward closing the fingers.""" + """True if this action's finger target is below current position. + + Called by step() to decide whether to check for a new grasp. + """ f_delta = self._action_to_finger_delta(action) return f_delta < -self._finger_action_tol def _fingers_opening(self, action: Action) -> bool: - """Check whether this action is working toward opening the fingers.""" + """True if this action's finger target is above current position. + + Called by step() to decide whether to release a held object. + """ f_delta = self._action_to_finger_delta(action) - # logging.debug(f"Finger delta: {f_delta}") return f_delta > self._finger_action_tol def _get_finger_position(self, state: State) -> float: - # Arbitrarily use the left finger as reference. + """Return the current left-finger joint position from state. + + Called by _action_to_finger_delta() to compute the delta + between current and target finger positions. + """ state = cast(utils.PyBulletState, state) finger_joint_idx = self._pybullet_robot.left_finger_joint_idx return state.joint_positions[finger_joint_idx] def _action_to_finger_delta(self, action: Action) -> float: + """Compute (target - current) finger joint position. + + Called by _fingers_closing() and _fingers_opening(). + """ assert isinstance(self._current_observation, State) finger_position = self._get_finger_position(self._current_observation) joint_positions, _ = self._split_action(action) target = joint_positions[self._pybullet_robot.left_finger_joint_idx] - # logging.debug(f"Finger position: {finger_position}, target: {target}") return target - finger_position def _split_action(self, action: Action) -> Tuple[np.ndarray, np.ndarray]: - """Split an action into joint targets and an optional base delta.""" + """Split an action into (arm_joint_targets, base_delta). + + Called by step() and _action_to_finger_delta(). For robots + without a mobile base, base_delta is an empty array. + """ action_arr = action.arr base_dim = int(getattr(self._pybullet_robot, "base_action_dim", 0)) if base_dim > 0: @@ -905,7 +937,10 @@ def _split_action(self, action: Action) -> Tuple[np.ndarray, np.ndarray]: return action_arr, np.zeros(0, dtype=action_arr.dtype) def _apply_base_delta(self, base_delta: np.ndarray) -> None: - """Apply a delta (dx, dy, dtheta) to the robot base if supported.""" + """Apply a delta (dx, dy, dtheta) to the robot base. + + Called by step() for mobile robots (e.g. mobile_fetch). + """ robot = self._pybullet_robot assert hasattr(robot, 'get_base_pose'), \ "Robot does not support base pose operations" @@ -922,29 +957,23 @@ def _apply_base_delta(self, base_delta: np.ndarray) -> None: def _add_pybullet_state_to_tasks( self, tasks: List[EnvironmentTask]) -> List[EnvironmentTask]: - """Converts the task initial states into PyBulletStates. + """Convert plain-State tasks into PyBulletState tasks. - This is used in generating tasks. + Called by _generate_train/test_tasks() in subclasses. Sets up + the simulator for each task's init state so that joint positions + and (optionally) rendered images are captured into the task. """ pybullet_tasks = [] for task in tasks: # Reset the robot. init = task.init - self._reset_state(init) + self._set_state(init) # Cast _current_observation from type State to PybulletState joint_positions = self._pybullet_robot.get_joints() self._current_observation = utils.PyBulletState( init.data.copy(), simulator_state=joint_positions) - # Attempt 1: Let's try to get a rendering directly first pybullet_init = self.get_observation(render=CFG.render_init_state) - pybullet_init.option_history = [ - ] # useful for vlm predicate grounding - # # + pybullet_init.option_history = [] pybullet_task = EnvironmentTask(pybullet_init, task.goal, goal_nl=task.goal_nl) @@ -953,143 +982,10 @@ def _add_pybullet_state_to_tasks( @classmethod def get_robot_ee_home_orn(cls) -> Quaternion: - """Public for use by oracle options.""" + """Return the default end-effector orientation for this env. + + Used by initialize_pybullet() to set the robot's home pose, + and by oracle options to compute motion-planning targets. + """ robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] return robot_ee_orns[CFG.pybullet_robot] - - -def create_pybullet_block( - color: Tuple[float, float, float, float], - half_extents: Tuple[float, float, float], - mass: float, - friction: float, - position: Pose3D = (0.0, 0.0, 0.0), - orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), - physics_client_id: int = 0, - add_top_triangle: bool = False, -) -> int: - """A generic utility for creating a new block. - - Returns the PyBullet ID of the newly created block. - """ - # The poses here are not important because they are overwritten by - - # Create the collision shape. - collision_id = p.createCollisionShape(p.GEOM_BOX, - halfExtents=half_extents, - physicsClientId=physics_client_id) - - # Create the visual_shape. - visual_id = p.createVisualShape(p.GEOM_BOX, - halfExtents=half_extents, - rgbaColor=color, - physicsClientId=physics_client_id) - - # Create the body. - block_id = p.createMultiBody(baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - physicsClientId=physics_client_id) - p.changeDynamics( - block_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - rollingFriction=friction, - physicsClientId=physics_client_id) - - if add_top_triangle: - # 1. Create the triangle's visual shape - triangle_size = min(half_extents[0], half_extents[1]) - triangle_vertices = [ - [triangle_size, 0, 0], # Tip pointing in +X - [-triangle_size, triangle_size, 0], # Back left - [-triangle_size, -triangle_size, 0] # Back right - ] - triangle_visual_id = p.createVisualShape( - p.GEOM_MESH, - vertices=triangle_vertices, - indices=[0, 1, 2], # <-- FIX: Added this line - rgbaColor=[1, 1, 0, - 1], # <-- CHANGE: Set to yellow (R=1, G=1, B=0, A=1) - physicsClientId=physics_client_id) - - # 2. Re-create the body, but this time WITH a link for the triangle - p.removeBody( - block_id, - physicsClientId=physics_client_id) # Remove the old simple block - - block_id = p.createMultiBody( - baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - # --- Link Parameters for the Triangle --- - linkMasses=[0], # Massless link - linkCollisionShapeIndices=[-1], # No collision for the link - linkVisualShapeIndices=[triangle_visual_id - ], # Visual shape for the link - # Position the link's origin on top of the block's base - linkPositions=[[0, 0, half_extents[2] + 0.001]], - linkOrientations=[[0, 0, 0, 1]], # No relative rotation - linkInertialFramePositions=[[0, 0, 0]], - linkInertialFrameOrientations=[[0, 0, 0, 1]], - linkParentIndices=[0], # Link is attached to the base (index 0) - linkJointTypes=[p.JOINT_FIXED], # Link is fixed to the base - linkJointAxis=[[0, 0, - 1]], # Axis for the joint (not relevant for fixed) - physicsClientId=physics_client_id) - - # Re-apply dynamics to the new multi-body object - p.changeDynamics( - block_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - physicsClientId=physics_client_id) - - return block_id - - -def create_pybullet_sphere( - color: Tuple[float, float, float, float], - radius: float, - mass: float, - friction: float, - position: Pose3D = (0.0, 0.0, 0.0), - orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), - physics_client_id: int = 0, -) -> int: - """A generic utility for creating a new sphere. - - Returns the PyBullet ID of the newly created sphere. - """ - # Create the collision shape. - collision_id = p.createCollisionShape(p.GEOM_SPHERE, - radius=radius, - physicsClientId=physics_client_id) - - # Create the visual shape. - visual_id = p.createVisualShape(p.GEOM_SPHERE, - radius=radius, - rgbaColor=color, - physicsClientId=physics_client_id) - - # Create the body. - sphere_id = p.createMultiBody(baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - physicsClientId=physics_client_id) - p.changeDynamics( - sphere_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - physicsClientId=physics_client_id) - - return sphere_id diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index d4acbdfec..4059c9122 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -6,10 +6,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block, \ - create_pybullet_sphere +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, create_pybullet_sphere, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -610,7 +610,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: self._target.id = pybullet_bodies["target_id"] # Initialize boundary wall IDs list (will be populated - # in _reset_custom_env_state) + # in _set_domain_specific_state) # pylint: disable=attribute-defined-outside-init self._boundary_wall_ids: List[int] = [] @@ -620,10 +620,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def _get_object_ids_for_held_check(self) -> List[int]: return [] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: for switch_obj in self._switches: is_on_val = state.get(switch_obj, "is_on") self._set_switch_on(switch_obj.id, bool(is_on_val > 0.5)) @@ -838,7 +835,7 @@ def _position_fans_on_sides(self) -> None: orientation=p.getQuaternionFromEuler(rot), physics_client_id=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._fan_type: if feature == "facing_side": @@ -1633,7 +1630,7 @@ def _has_valid_path(self, start_pos: Tuple[int, CFG.fan_train_num_walls_per_task, _rng) for _task in _tasks: - env._reset_state(_task.init) # pylint: disable=protected-access + env._set_state(_task.init) # pylint: disable=protected-access for _ in range(5000): _action = Action( np.array(env._pybullet_robot # pylint: disable=protected-access diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index fef0830d3..4b95df4f6 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -13,10 +13,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ - sample_collision_free_2d_positions, update_object + create_pybullet_block, sample_collision_free_2d_positions, update_object from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type @@ -229,10 +229,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def _get_object_ids_for_held_check(self) -> List[int]: return [block_obj.id for block_obj in self._blocks] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._block_type: # if feature == "is_light": @@ -255,7 +252,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return self._current_water_height raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: # Initialize water level self._current_water_height = state.get(self._vessel, "water_height") @@ -617,7 +614,7 @@ def _make_tasks(self, num_tasks: int, CFG.pybullet_sim_steps_per_action = 1 env = PyBulletFloatEnv(use_gui=True) task = env._make_tasks(1, np.random.default_rng(0))[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: action = Action(np.array(env._pybullet_robot.initial_joint_positions)) # pylint: disable=protected-access diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index e1bc394a0..395e10428 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -14,9 +14,10 @@ from predicators import utils from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -265,10 +266,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: jug_ids = [jug.id for jug in self._jugs if jug.id is not None] return jug_ids - def _create_task_specific_objects(self, state: State) -> None: - """No extra objects to create beyond cups and jugs.""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" # For growth, we look up the height of the liquid body if obj.type == self._cup_type and feature == "growth": @@ -285,8 +283,8 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: - """Called in _reset_state to handle any custom resetting.""" + def _set_domain_specific_state(self, state: State) -> None: + """Called in _set_state to handle any custom resetting.""" # Remove existing "liquid bodies" for liquid_id in self._cup_to_liquid_id.values(): if liquid_id is not None: @@ -724,7 +722,7 @@ def _create_pybullet_liquid_for_cup( _rng = np.random.default_rng(CFG.seed) _task = env._get_tasks( # pylint: disable=protected-access 1, CFG.grow_num_cups_test, CFG.grow_num_jugs_test, _rng)[0] - env._reset_state(_task.init) # pylint: disable=protected-access + env._set_state(_task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index 6a71cda18..9b4e58c09 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -282,14 +282,11 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: # ------------------------------------------------------------------------- # State Reading/Writing # ------------------------------------------------------------------------- - def _create_task_specific_objects(self, state: State) -> None: - pass - def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of wires (assuming the robot can pick them up).""" return [m.id for m in self._normal_mirrors + self._split_mirrors] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._station_type: if feature == "is_on": @@ -302,7 +299,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return 1.0 if self._is_target_hit(obj) else 0.0 raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: oov_x, oov_y = self._out_of_view_xy lasers_copy = _laser_ids.copy() @@ -822,7 +819,7 @@ def create_laser_cylinder(start: Any, CFG.laser_zero_reflection_angle = True env = PyBulletLaserEnv(use_gui=True) task = env._make_tasks(1, np.random.default_rng(CFG.seed), True)[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index 583fe1294..2c6d8bfd6 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -16,9 +16,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -235,7 +236,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (blocks).""" return [block.id for block in self._blocks] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._switch_type and feature == "is_on": return float(self._is_switch_on()) @@ -246,10 +247,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return float(pos[0] > 5.0) # Out of view if x > 5 raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment state from a State object.""" # Set switch state switch_on = state.get(self._switch, "is_on") > 0.5 @@ -481,7 +479,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletMagicBinEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("PyBullet Magic Bin Environment Test") print("Blocks should vanish when in bin with switch ON.") diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index bd5ac59d1..ed4bb858b 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -223,7 +223,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (none in this env).""" return [] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._light_type and feature == "is_on": return float(self._is_power_switch_on()) @@ -236,10 +236,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return float(self._is_switch_on(self._color_switch)) raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment state from a State object.""" # Set power switch state power_on = state.get(self._power_switch, "is_on") > 0.5 @@ -465,7 +462,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletSwitchEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: _joints = env._pybullet_robot.initial_joint_positions # pylint: disable=protected-access diff --git a/predicators/ground_truth_models/boil/options.py b/predicators/ground_truth_models/boil/options.py index 769edbcbd..59b2ccd48 100644 --- a/predicators/ground_truth_models/boil/options.py +++ b/predicators/ground_truth_models/boil/options.py @@ -88,7 +88,7 @@ def _get_options_skill_factories( # --------------------------------------------------------------- # Helper: find the switch object associated with a faucet/burner. - # The env sets obj.switch_id in _reset_state. + # The env sets obj.switch_id in _set_state. # --------------------------------------------------------------- def _get_switch_pose( state: State, diff --git a/predicators/ground_truth_models/skill_factories/base.py b/predicators/ground_truth_models/skill_factories/base.py index 8cdf73f48..64ef19541 100644 --- a/predicators/ground_truth_models/skill_factories/base.py +++ b/predicators/ground_truth_models/skill_factories/base.py @@ -543,7 +543,7 @@ def _plan_with_simulator( new_state_data, simulator_state=pb_state.simulator_state) # 3. Reset simulator to current state - sim._reset_state(remapped_state) # pylint: disable=protected-access + sim._set_state(remapped_state) # pylint: disable=protected-access # 4. Collect collision body IDs (exclude held objects and # non-physical types) and find the held object. diff --git a/predicators/pybullet_helpers/objects.py b/predicators/pybullet_helpers/objects.py index 42941e9c1..6b226deac 100644 --- a/predicators/pybullet_helpers/objects.py +++ b/predicators/pybullet_helpers/objects.py @@ -157,3 +157,108 @@ def create_geom(px: float, py: float) -> _Geom2D: else: # We successfully placed all shapes return positions + + +def create_pybullet_block( + color: Tuple[float, float, float, float], + half_extents: Tuple[float, float, float], + mass: float, + friction: float, + position: Pose3D = (0.0, 0.0, 0.0), + orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), + physics_client_id: int = 0, + add_top_triangle: bool = False, +) -> int: + """Create a box-shaped PyBullet body and return its ID.""" + collision_id = p.createCollisionShape(p.GEOM_BOX, + halfExtents=half_extents, + physicsClientId=physics_client_id) + visual_id = p.createVisualShape(p.GEOM_BOX, + halfExtents=half_extents, + rgbaColor=color, + physicsClientId=physics_client_id) + block_id = p.createMultiBody(baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + physicsClientId=physics_client_id) + p.changeDynamics(block_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + rollingFriction=friction, + physicsClientId=physics_client_id) + + if add_top_triangle: + triangle_size = min(half_extents[0], half_extents[1]) + triangle_vertices = [ + [triangle_size, 0, 0], + [-triangle_size, triangle_size, 0], + [-triangle_size, -triangle_size, 0], + ] + triangle_visual_id = p.createVisualShape( + p.GEOM_MESH, + vertices=triangle_vertices, + indices=[0, 1, 2], + rgbaColor=[1, 1, 0, 1], + physicsClientId=physics_client_id) + + p.removeBody(block_id, physicsClientId=physics_client_id) + + block_id = p.createMultiBody( + baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + linkMasses=[0], + linkCollisionShapeIndices=[-1], + linkVisualShapeIndices=[triangle_visual_id], + linkPositions=[[0, 0, half_extents[2] + 0.001]], + linkOrientations=[[0, 0, 0, 1]], + linkInertialFramePositions=[[0, 0, 0]], + linkInertialFrameOrientations=[[0, 0, 0, 1]], + linkParentIndices=[0], + linkJointTypes=[p.JOINT_FIXED], + linkJointAxis=[[0, 0, 1]], + physicsClientId=physics_client_id) + + p.changeDynamics(block_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + physicsClientId=physics_client_id) + + return block_id + + +def create_pybullet_sphere( + color: Tuple[float, float, float, float], + radius: float, + mass: float, + friction: float, + position: Pose3D = (0.0, 0.0, 0.0), + orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), + physics_client_id: int = 0, +) -> int: + """Create a sphere-shaped PyBullet body and return its ID.""" + collision_id = p.createCollisionShape(p.GEOM_SPHERE, + radius=radius, + physicsClientId=physics_client_id) + visual_id = p.createVisualShape(p.GEOM_SPHERE, + radius=radius, + rgbaColor=color, + physicsClientId=physics_client_id) + sphere_id = p.createMultiBody(baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + physicsClientId=physics_client_id) + p.changeDynamics(sphere_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + physicsClientId=physics_client_id) + return sphere_id diff --git a/scripts/run_blocks_perception.py b/scripts/run_blocks_perception.py index 82b8e2693..585d4d067 100644 --- a/scripts/run_blocks_perception.py +++ b/scripts/run_blocks_perception.py @@ -98,9 +98,9 @@ from predicators import utils from predicators.envs.pybullet_blocks import PyBulletBlocksEnv -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose3D +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import \ create_single_arm_pybullet_robot from predicators.settings import CFG diff --git a/tests/envs/test_pybullet_blocks.py b/tests/envs/test_pybullet_blocks.py index 40922fed6..39512c703 100644 --- a/tests/envs/test_pybullet_blocks.py +++ b/tests/envs/test_pybullet_blocks.py @@ -70,7 +70,7 @@ def set_state(self, state): simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) + self._set_state(state_with_sim) def get_state(self): """Expose get_state().""" diff --git a/tests/envs/test_pybullet_cover.py b/tests/envs/test_pybullet_cover.py index fe012bd94..376b88d71 100644 --- a/tests/envs/test_pybullet_cover.py +++ b/tests/envs/test_pybullet_cover.py @@ -43,7 +43,7 @@ def set_state(self, state): simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) + self._set_state(state_with_sim) def get_state(self): """Expose get_state().""" diff --git a/tests/pybullet_helpers/test_motion_planning.py b/tests/pybullet_helpers/test_motion_planning.py index f471ff83d..7eb04e37f 100644 --- a/tests/pybullet_helpers/test_motion_planning.py +++ b/tests/pybullet_helpers/test_motion_planning.py @@ -6,12 +6,12 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose from predicators.pybullet_helpers.joint import JointPositions from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.motion_planning import run_motion_planning +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import \ create_single_arm_pybullet_robot diff --git a/tests/test_skill_factories_integration.py b/tests/test_skill_factories_integration.py index 40a685fec..54f56cde9 100644 --- a/tests/test_skill_factories_integration.py +++ b/tests/test_skill_factories_integration.py @@ -78,7 +78,7 @@ def set_state(self, state: Any) -> None: simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) # type: ignore[attr-defined] + self._set_state(state_with_sim) # type: ignore[attr-defined] def get_state(self) -> Any: """Get state.""" From 5bf6af3f1946391ffe8e31a82f5db3185619cf87 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 7 Apr 2026 15:23:16 +0100 Subject: [PATCH 12/70] Regroup PyBulletEnv methods by responsibility and update docstring Reorganize methods into labeled sections (Setup, Public API, Core Loop, State Write/Read, Grasp Management, Action Helpers, Rendering, Utilities) so related functions are adjacent. Update module docstring to document the main public API and state synchronization methods. --- predicators/envs/pybullet_env.py | 689 ++++++++++++++++--------------- 1 file changed, 360 insertions(+), 329 deletions(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 5572cb091..56e8d887e 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -7,7 +7,20 @@ For a comprehensive guide on creating new PyBullet environments, see: docs/pybullet_env_guide.md -Quick reference - required methods to implement: +Main public API: + reset(train_or_test, task_idx) — reset env to a task, returns observation + simulate(state, action) — forward-simulate without touching real env + step(action) — execute action, manage grasps, return observation + get_observation() — read PyBullet state, optionally attach images/masks + +State synchronization: + _set_state(state) — write a State into PyBullet (robot pose, object + poses, grasp constraints). Delegates domain-specific setup to + _set_domain_specific_state(). + _get_state() — read PyBullet into a PyBulletState. Delegates + domain-specific features to _get_domain_specific_feature(). + +Required overrides in subclasses: - get_name() -> str - initialize_pybullet(using_gui) -> (physics_id, robot, bodies_dict) - _store_pybullet_bodies(bodies_dict) @@ -128,25 +141,7 @@ def __init__(self, use_gui: bool = False) -> None: # _set_state(), and render_segmented_obj() for iteration. self._objects: List[Object] = [] - def get_extra_collision_ids(self) -> Sequence[int]: - """Return extra PyBullet body IDs to treat as collision obstacles. - - Called by the motion planner (skill factories) when computing - collision-free paths. Override in subclasses for bodies not - tracked as state Objects (e.g. liquid blocks in Grow). - """ - return () - - def get_object_by_id(self, obj_id: int) -> Object: - """Look up an Object by its PyBullet body ID. - - Used by agent tools and skill factories to map from a PyBullet - collision/contact result back to the predicators Object. - """ - for obj in self._objects: - if obj.id == obj_id: - return obj - raise ValueError(f"Object with ID {obj_id} not found") + # ── Setup & Initialization ────────────────────────────────── @classmethod def initialize_pybullet( @@ -243,124 +238,52 @@ def _create_pybullet_robot( physics_client_id, ee_home, base_pose) - def _extract_robot_state(self, state: State) -> Array: - """State -> robot array: extract robot features for PyBullet. - - Converts the robot's features in a State into the array format - expected by self._pybullet_robot.reset_state() - (same format as self._pybullet_robot.get_state()). + @classmethod + def get_robot_ee_home_orn(cls) -> Quaternion: + """Return the default end-effector orientation for this env. - Called by _set_state() to position the robot. + Used by initialize_pybullet() to set the robot's home pose, + and by oracle options to compute motion-planning targets. """ + robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] + return robot_ee_orns[CFG.pybullet_robot] - # EE Position - def get_pos_feature( - state: State, - feature_name: str) -> float: # type: ignore[no-untyped-def] - if feature_name in self._robot.type.feature_names: - return state.get(self._robot, feature_name) - if f"pose_{feature_name}" in self._robot.type.feature_names: - return state.get(self._robot, f"pose_{feature_name}") - raise ValueError(f"Cannot find robot pos '{feature_name}'") - - rx = get_pos_feature(state, "x") - ry = get_pos_feature(state, "y") - rz = get_pos_feature(state, "z") - - # EE Orientation - _, default_tilt, default_wrist = p.getEulerFromQuaternion( - self.get_robot_ee_home_orn()) - if "tilt" in self._robot.type.feature_names: - tilt = state.get(self._robot, "tilt") - else: - tilt = default_tilt - if "wrist" in self._robot.type.feature_names: - wrist = state.get(self._robot, "wrist") - else: - wrist = default_wrist - qx, qy, qz, qw = p.getQuaternionFromEuler([0.0, tilt, wrist]) - - # Fingers - f = state.get(self._robot, "fingers") - f = self._fingers_state_to_joint(self._pybullet_robot, f) - - return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) - - @abc.abstractmethod - def _get_object_ids_for_held_check(self) -> List[int]: - """Return PyBullet body IDs of objects that can be grasped. + # ── Public API & Properties ───────────────────────────────── - Called by _detect_held_object() (inside step()) to decide which - bodies to check for finger contact. Subclasses return only the - IDs of graspable objects (e.g. blocks, not tables). - """ - raise NotImplementedError("Override me!") + @property + def action_space(self) -> Box: + return self._pybullet_robot.action_space - def _get_expected_finger_normals(self) -> Dict[int, Array]: - """Compute the expected inward-facing normal for each finger. + def get_extra_collision_ids(self) -> Sequence[int]: + """Return extra PyBullet body IDs to treat as collision obstacles. - Called by _detect_held_object() to distinguish objects between - the fingers (valid grasp) from objects touching the outside. + Called by the motion planner (skill factories) when computing + collision-free paths. Override in subclasses for bodies not + tracked as state Objects (e.g. liquid blocks in Grow). """ - _rx, _ry, _rz, qx, qy, qz, qw, _rf = self._pybullet_robot.get_state() - - # Convert the quaternion to a rotation matrix - rotation_matrix = p.getMatrixFromQuaternion([qx, qy, qz, qw]) - rotation_matrix = np.array(rotation_matrix).reshape(3, 3) - - # Define the initial normal vectors for the fingers - if CFG.pybullet_robot == "panda": - # gripper rotated 90deg so parallel to x-axis - normal = np.array([1., 0., 0.], dtype=np.float32) - elif CFG.pybullet_robot in {"fetch", "mobile_fetch"}: - # gripper parallel to y-axis - normal = np.array([0., 1., 0.], dtype=np.float32) - else: # pragma: no cover - # Shouldn't happen unless we introduce a new robot. - raise ValueError(f"Unknown robot {CFG.pybullet_robot}") - - # Transform the normal vectors using the rotation matrix - transformed_normal = rotation_matrix.dot(normal) - transformed_normal_neg = rotation_matrix.dot(-1 * normal) - - return { - self._pybullet_robot.left_finger_id: transformed_normal, - self._pybullet_robot.right_finger_id: transformed_normal_neg, - } + return () - @classmethod - def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, - finger_state: float) -> float: - """Map finger value in a State (e.g. open_fingers=0.04) to the - corresponding PyBullet joint position. + def get_object_by_id(self, obj_id: int) -> Object: + """Look up an Object by its PyBullet body ID. - Called by _extract_robot_state() when writing State -> PyBullet. + Used by agent tools and skill factories to map from a PyBullet + collision/contact result back to the predicators Object. """ - # If open_fingers is undefined, use 1.0 as the default. - subs = { - cls.open_fingers: pybullet_robot.open_fingers, - cls.closed_fingers: pybullet_robot.closed_fingers, - } - match = min(subs, key=lambda k: abs(k - finger_state)) - return subs[match] - - @classmethod - def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, - finger_joint: float) -> float: - """Inverse of _fingers_state_to_joint(). + for obj in self._objects: + if obj.id == obj_id: + return obj + raise ValueError(f"Object with ID {obj_id} not found") - Called by _get_robot_state_dict() when reading PyBullet -> State. - """ - subs = { - pybullet_robot.open_fingers: cls.open_fingers, - pybullet_robot.closed_fingers: cls.closed_fingers, - } - match = min(subs, key=lambda k: abs(k - finger_joint)) - return subs[match] + # ── Core Loop (Reset / Simulate / Step) ───────────────────── - @property - def action_space(self) -> Box: - return self._pybullet_robot.action_space + def reset(self, + train_or_test: str, + task_idx: int, + render: bool = False) -> Observation: + state = super().reset(train_or_test, task_idx) + self._set_state(state) + observation = self.get_observation(render=render) + return observation def simulate(self, state: State, action: Action) -> State: """Apply an action to a state using the PyBullet simulator. @@ -383,31 +306,85 @@ def simulate(self, state: State, action: Action) -> State: self._set_state(state) return self.step(action) - def render_state_plt( - self, - state: State, - task: EnvironmentTask, - action: Optional[Action] = None, - caption: Optional[str] = None) -> matplotlib.figure.Figure: - raise NotImplementedError("This env does not use Matplotlib") + def step(self, action: Action, render_obs: bool = False) -> Observation: + """Execute one environment step with the given action. - def render_state(self, - state: State, - task: EnvironmentTask, - action: Optional[Action] = None, - caption: Optional[str] = None) -> Video: - raise NotImplementedError("A PyBullet environment cannot render " - "arbitrary states.") + This method handles: + 1. Robot joint control by converting action to target positions + 2. Management of held objects and grasping constraints + 3. Physics simulation stepping + 4. Object grasp detection and constraint creation/removal + 5. `self._current_observation` update + + Args: + action (Action): The action to execute, containing target joint + positions + render_obs (bool, optional): Whether to include RGB observation. + Defaults to False. + + Returns: + Observation: Updated environment observation after executing the + action. May include an image if render_obs=True or + CFG.rgb_observation=True. + """ + # Send the action to the robot. + target_joint_positions, base_delta = self._split_action(action) + if base_delta.size: + self._apply_base_delta(base_delta) + self._pybullet_robot.set_motors(target_joint_positions.tolist()) + + # If we are setting the robot joints directly, and if there is a held + # object, we need to reset the pose of the held object directly. This + # is because the PyBullet constraints don't seem to play nicely with + # resetJointState (the robot will sometimes drop the object). + if CFG.pybullet_control_mode == "reset" and \ + self._held_obj_id is not None: + world_to_base_link = get_link_state( + self._pybullet_robot.robot_id, + self._pybullet_robot.end_effector_id, + physics_client_id=self._physics_client_id).com_pose + base_link_to_held_obj = p.invertTransform( + *self._held_obj_to_base_link) + world_to_held_obj = p.multiplyTransforms(world_to_base_link[0], + world_to_base_link[1], + base_link_to_held_obj[0], + base_link_to_held_obj[1]) + p.resetBasePositionAndOrientation( + self._held_obj_id, + world_to_held_obj[0], + world_to_held_obj[1], + physicsClientId=self._physics_client_id) + + # Step the simulation here before adding or removing constraints + # because detect_held_object() should use the updated state. + if CFG.pybullet_control_mode != "reset": + for _ in range(CFG.pybullet_sim_steps_per_action): + p.stepSimulation(physicsClientId=self._physics_client_id) + + # If not currently holding something, and fingers are closing, check + # for a new grasp. + if self._held_constraint_id is None and self._fingers_closing(action): + self._held_obj_id = self._detect_held_object() + if self._held_obj_id is not None: + self._create_grasp_constraint() + + # If placing, remove the grasp constraint. + if self._held_constraint_id is not None and \ + self._fingers_opening(action): + p.removeConstraint(self._held_constraint_id, + physicsClientId=self._physics_client_id) + self._held_constraint_id = None + self._held_obj_id = None + + # Depending on the observation mode, either return object-centric state + # or object_centric + rgb observation + observation = self.get_observation(render=CFG.rgb_observation or\ + render_obs) - def reset(self, - train_or_test: str, - task_idx: int, - render: bool = False) -> Observation: - state = super().reset(train_or_test, task_idx) - self._set_state(state) - observation = self.get_observation(render=render) return observation + # ── State Write (State → PyBullet) ────────────────────────── + def _set_state(self, state: State) -> None: """State -> PyBullet: set the simulator to match a State. @@ -482,7 +459,7 @@ def _reset_single_object(self, obj: Object, state: State) -> None: else: orn = self._default_orn # e.g. (0,0,0,1) - # 2) Update the object’s position/orientation in PyBullet + # 2) Update the object's position/orientation in PyBullet update_object(obj.id, (px, py, pz), orn, physics_client_id=self._physics_client_id) @@ -508,10 +485,71 @@ def _set_domain_specific_state(self, state: State) -> None: """ raise NotImplementedError("Override me!") - # Features handled by _get_object_state_dict via PyBullet queries. - _PYBULLET_FEATURES: ClassVar[frozenset] = frozenset({ - "x", "y", "z", "rot", "yaw", "roll", "pitch", "is_held", "r", "g", "b" - }) + def _extract_robot_state(self, state: State) -> Array: + """State -> robot array: extract robot features for PyBullet. + + Converts the robot's features in a State into the array format + expected by self._pybullet_robot.reset_state() + (same format as self._pybullet_robot.get_state()). + + Called by _set_state() to position the robot. + """ + + # EE Position + def get_pos_feature( + state: State, + feature_name: str) -> float: # type: ignore[no-untyped-def] + if feature_name in self._robot.type.feature_names: + return state.get(self._robot, feature_name) + if f"pose_{feature_name}" in self._robot.type.feature_names: + return state.get(self._robot, f"pose_{feature_name}") + raise ValueError(f"Cannot find robot pos '{feature_name}'") + + rx = get_pos_feature(state, "x") + ry = get_pos_feature(state, "y") + rz = get_pos_feature(state, "z") + + # EE Orientation + _, default_tilt, default_wrist = p.getEulerFromQuaternion( + self.get_robot_ee_home_orn()) + if "tilt" in self._robot.type.feature_names: + tilt = state.get(self._robot, "tilt") + else: + tilt = default_tilt + if "wrist" in self._robot.type.feature_names: + wrist = state.get(self._robot, "wrist") + else: + wrist = default_wrist + qx, qy, qz, qw = p.getQuaternionFromEuler([0.0, tilt, wrist]) + + # Fingers + f = state.get(self._robot, "fingers") + f = self._fingers_state_to_joint(self._pybullet_robot, f) + + return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) + + @classmethod + def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, + finger_state: float) -> float: + """Map finger value in a State (e.g. open_fingers=0.04) to the + corresponding PyBullet joint position. + + Called by _extract_robot_state() when writing State -> PyBullet. + """ + # If open_fingers is undefined, use 1.0 as the default. + subs = { + cls.open_fingers: pybullet_robot.open_fingers, + cls.closed_fingers: pybullet_robot.closed_fingers, + } + match = min(subs, key=lambda k: abs(k - finger_state)) + return subs[match] + + # ── State Read (PyBullet → State) ─────────────────────────── + + # Features handled by _get_object_state_dict via PyBullet queries. + _PYBULLET_FEATURES: ClassVar[frozenset] = frozenset({ + "x", "y", "z", "rot", "yaw", "roll", "pitch", "is_held", "r", "g", "b" + }) def _get_state(self, _render_obs: bool = False) -> State: """PyBullet -> State: read the simulator into a PyBulletState. @@ -546,6 +584,23 @@ def _get_state(self, _render_obs: bool = False) -> State: }) return pyb_state + def _get_robot_state_dict(self) -> Dict[str, float]: + """Build a feature dict for the robot from PyBullet state. + + Called by _get_state() to populate the robot entry in the State. + Subclasses with non-standard robot features (e.g. cover's + normalized hand, blocks' pose_x/y/z) should override this. + """ + rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() + r_dict: Dict[str, float] = {"x": rx, "y": ry, "z": rz, "fingers": rf} + _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) + r_features = self._robot.type.feature_names + if "tilt" in r_features: + r_dict["tilt"] = tilt + if "wrist" in r_features: + r_dict["wrist"] = wrist + return r_dict + def _get_object_state_dict(self, obj: Object) -> Dict[str, float]: """Build a feature dict for a single non-robot object. @@ -618,186 +673,63 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """ raise NotImplementedError("Override me!") - def _get_robot_state_dict(self) -> Dict[str, float]: - """Build a feature dict for the robot from PyBullet state. - - Called by _get_state() to populate the robot entry in the State. - Subclasses with non-standard robot features (e.g. cover's - normalized hand, blocks' pose_x/y/z) should override this. - """ - rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() - r_dict: Dict[str, float] = {"x": rx, "y": ry, "z": rz, "fingers": rf} - _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) - r_features = self._robot.type.feature_names - if "tilt" in r_features: - r_dict["tilt"] = tilt - if "wrist" in r_features: - r_dict["wrist"] = wrist - return r_dict - - def _get_camera_matrices(self) -> Tuple[Any, Any, int, int]: - """Return (view_matrix, proj_matrix, width, height) for rendering. + @classmethod + def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, + finger_joint: float) -> float: + """Inverse of _fingers_state_to_joint(). - Called by render() and render_segmented_obj(). + Called by _get_robot_state_dict() when reading PyBullet -> State. """ - view_matrix = p.computeViewMatrixFromYawPitchRoll( - cameraTargetPosition=self._camera_target, - distance=self._camera_distance, - yaw=self._camera_yaw, - pitch=self._camera_pitch, - roll=0, - upAxisIndex=2, - physicsClientId=self._physics_client_id) - width = CFG.pybullet_camera_width - height = CFG.pybullet_camera_height - proj_matrix = p.computeProjectionMatrixFOV( - fov=self._camera_fov, - aspect=float(width / height), - nearVal=0.1, - farVal=100.0, - physicsClientId=self._physics_client_id) - return view_matrix, proj_matrix, width, height - - def render(self, - action: Optional[Action] = None, - caption: Optional[str] = None) -> Video: # pragma: no cover - # Skip test coverage because GUI is too expensive to use in unit tests - # and cannot be used in headless mode. - del action, caption # unused - view_matrix, proj_matrix, width, height = self._get_camera_matrices() - (_, _, px, _, - _) = p.getCameraImage(width=width, - height=height, - viewMatrix=view_matrix, - projectionMatrix=proj_matrix, - renderer=p.ER_BULLET_HARDWARE_OPENGL, - physicsClientId=self._physics_client_id) - rgb_array = np.array(px).reshape((height, width, 4)) - rgb_array = rgb_array[:, :, :3] - return [rgb_array] - - def render_segmented_obj( - self, - action: Optional[Action] = None, - caption: Optional[str] = None, - ) -> Tuple[Image.Image, Dict[Object, Mask]]: - """Render the scene and return per-object segmentation masks. + subs = { + pybullet_robot.open_fingers: cls.open_fingers, + pybullet_robot.closed_fingers: cls.closed_fingers, + } + match = min(subs, key=lambda k: abs(k - finger_joint)) + return subs[match] - Called by get_observation(render=True) to attach RGB images and - masks to the observation (used for VLM predicate grounding). - """ - del action, caption # unused - view_matrix, proj_matrix, width, height = self._get_camera_matrices() - (_, _, rgbImg, _, - segImg) = p.getCameraImage(width=width, - height=height, - viewMatrix=view_matrix, - projectionMatrix=proj_matrix, - renderer=p.ER_BULLET_HARDWARE_OPENGL, - physicsClientId=self._physics_client_id) - original_image: np.ndarray = np.array(rgbImg, dtype=np.uint8).reshape( - (height, width, 4)) - seg_image = np.array(segImg).reshape((height, width)) - state_img = Image.fromarray( # type: ignore[no-untyped-call] - original_image[:, :, :3]) - mask_dict: Dict[Object, Mask] = {} - for obj in self._objects: - mask_dict[obj] = (seg_image == obj.id) - return state_img, mask_dict + # ── Grasp Detection & Constraint Management ───────────────── - def get_observation(self, render: bool = False) -> Observation: - """Get the current observation of this environment. + @abc.abstractmethod + def _get_object_ids_for_held_check(self) -> List[int]: + """Return PyBullet body IDs of objects that can be grasped. - Reads the current state from pybullet, updates - _current_observation (the backing field), and returns a copy - optionally with rendered images. + Called by _detect_held_object() (inside step()) to decide which + bodies to check for finger contact. Subclasses return only the + IDs of graspable objects (e.g. blocks, not tables). """ - state = self._get_state() - assert isinstance(state, PyBulletState) - self._current_observation = state - obs = state.copy() - - if render: - obs.add_images_and_masks(*self.render_segmented_obj()) - - return obs - - def step(self, action: Action, render_obs: bool = False) -> Observation: - """Execute one environment step with the given action. - - This method handles: - 1. Robot joint control by converting action to target positions - 2. Management of held objects and grasping constraints - 3. Physics simulation stepping - 4. Object grasp detection and constraint creation/removal - 5. `self._current_observation` update + raise NotImplementedError("Override me!") - Args: - action (Action): The action to execute, containing target joint - positions - render_obs (bool, optional): Whether to include RGB observation. - Defaults to False. + def _get_expected_finger_normals(self) -> Dict[int, Array]: + """Compute the expected inward-facing normal for each finger. - Returns: - Observation: Updated environment observation after executing the - action. May include an image if render_obs=True or - CFG.rgb_observation=True. + Called by _detect_held_object() to distinguish objects between + the fingers (valid grasp) from objects touching the outside. """ - # Send the action to the robot. - target_joint_positions, base_delta = self._split_action(action) - if base_delta.size: - self._apply_base_delta(base_delta) - self._pybullet_robot.set_motors(target_joint_positions.tolist()) - - # If we are setting the robot joints directly, and if there is a held - # object, we need to reset the pose of the held object directly. This - # is because the PyBullet constraints don't seem to play nicely with - # resetJointState (the robot will sometimes drop the object). - if CFG.pybullet_control_mode == "reset" and \ - self._held_obj_id is not None: - world_to_base_link = get_link_state( - self._pybullet_robot.robot_id, - self._pybullet_robot.end_effector_id, - physics_client_id=self._physics_client_id).com_pose - base_link_to_held_obj = p.invertTransform( - *self._held_obj_to_base_link) - world_to_held_obj = p.multiplyTransforms(world_to_base_link[0], - world_to_base_link[1], - base_link_to_held_obj[0], - base_link_to_held_obj[1]) - p.resetBasePositionAndOrientation( - self._held_obj_id, - world_to_held_obj[0], - world_to_held_obj[1], - physicsClientId=self._physics_client_id) - - # Step the simulation here before adding or removing constraints - # because detect_held_object() should use the updated state. - if CFG.pybullet_control_mode != "reset": - for _ in range(CFG.pybullet_sim_steps_per_action): - p.stepSimulation(physicsClientId=self._physics_client_id) + _rx, _ry, _rz, qx, qy, qz, qw, _rf = self._pybullet_robot.get_state() - # If not currently holding something, and fingers are closing, check - # for a new grasp. - if self._held_constraint_id is None and self._fingers_closing(action): - self._held_obj_id = self._detect_held_object() - if self._held_obj_id is not None: - self._create_grasp_constraint() + # Convert the quaternion to a rotation matrix + rotation_matrix = p.getMatrixFromQuaternion([qx, qy, qz, qw]) + rotation_matrix = np.array(rotation_matrix).reshape(3, 3) - # If placing, remove the grasp constraint. - if self._held_constraint_id is not None and \ - self._fingers_opening(action): - p.removeConstraint(self._held_constraint_id, - physicsClientId=self._physics_client_id) - self._held_constraint_id = None - self._held_obj_id = None + # Define the initial normal vectors for the fingers + if CFG.pybullet_robot == "panda": + # gripper rotated 90deg so parallel to x-axis + normal = np.array([1., 0., 0.], dtype=np.float32) + elif CFG.pybullet_robot in {"fetch", "mobile_fetch"}: + # gripper parallel to y-axis + normal = np.array([0., 1., 0.], dtype=np.float32) + else: # pragma: no cover + # Shouldn't happen unless we introduce a new robot. + raise ValueError(f"Unknown robot {CFG.pybullet_robot}") - # Depending on the observation mode, either return object-centric state - # or object_centric + rgb observation - observation = self.get_observation(render=CFG.rgb_observation or\ - render_obs) + # Transform the normal vectors using the rotation matrix + transformed_normal = rotation_matrix.dot(normal) + transformed_normal_neg = rotation_matrix.dot(-1 * normal) - return observation + return { + self._pybullet_robot.left_finger_id: transformed_normal, + self._pybullet_robot.right_finger_id: transformed_normal_neg, + } def _detect_held_object(self) -> Optional[int]: """Return the PyBullet body ID of the grasped object, or None. @@ -916,6 +848,8 @@ def _action_to_finger_delta(self, action: Action) -> float: target = joint_positions[self._pybullet_robot.left_finger_joint_idx] return target - finger_position + # ── Action Helpers ────────────────────────────────────────── + def _split_action(self, action: Action) -> Tuple[np.ndarray, np.ndarray]: """Split an action into (arm_joint_targets, base_delta). @@ -955,6 +889,113 @@ def _apply_base_delta(self, base_delta: np.ndarray) -> None: ) robot.set_base_pose(new_pose) # type: ignore[attr-defined] + # ── Rendering & Observation ───────────────────────────────── + + def _get_camera_matrices(self) -> Tuple[Any, Any, int, int]: + """Return (view_matrix, proj_matrix, width, height) for rendering. + + Called by render() and render_segmented_obj(). + """ + view_matrix = p.computeViewMatrixFromYawPitchRoll( + cameraTargetPosition=self._camera_target, + distance=self._camera_distance, + yaw=self._camera_yaw, + pitch=self._camera_pitch, + roll=0, + upAxisIndex=2, + physicsClientId=self._physics_client_id) + width = CFG.pybullet_camera_width + height = CFG.pybullet_camera_height + proj_matrix = p.computeProjectionMatrixFOV( + fov=self._camera_fov, + aspect=float(width / height), + nearVal=0.1, + farVal=100.0, + physicsClientId=self._physics_client_id) + return view_matrix, proj_matrix, width, height + + def render(self, + action: Optional[Action] = None, + caption: Optional[str] = None) -> Video: # pragma: no cover + # Skip test coverage because GUI is too expensive to use in unit tests + # and cannot be used in headless mode. + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() + (_, _, px, _, + _) = p.getCameraImage(width=width, + height=height, + viewMatrix=view_matrix, + projectionMatrix=proj_matrix, + renderer=p.ER_BULLET_HARDWARE_OPENGL, + physicsClientId=self._physics_client_id) + rgb_array = np.array(px).reshape((height, width, 4)) + rgb_array = rgb_array[:, :, :3] + return [rgb_array] + + def render_segmented_obj( + self, + action: Optional[Action] = None, + caption: Optional[str] = None, + ) -> Tuple[Image.Image, Dict[Object, Mask]]: + """Render the scene and return per-object segmentation masks. + + Called by get_observation(render=True) to attach RGB images and + masks to the observation (used for VLM predicate grounding). + """ + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() + (_, _, rgbImg, _, + segImg) = p.getCameraImage(width=width, + height=height, + viewMatrix=view_matrix, + projectionMatrix=proj_matrix, + renderer=p.ER_BULLET_HARDWARE_OPENGL, + physicsClientId=self._physics_client_id) + original_image: np.ndarray = np.array(rgbImg, dtype=np.uint8).reshape( + (height, width, 4)) + seg_image = np.array(segImg).reshape((height, width)) + state_img = Image.fromarray( # type: ignore[no-untyped-call] + original_image[:, :, :3]) + mask_dict: Dict[Object, Mask] = {} + for obj in self._objects: + mask_dict[obj] = (seg_image == obj.id) + return state_img, mask_dict + + def render_state_plt( + self, + state: State, + task: EnvironmentTask, + action: Optional[Action] = None, + caption: Optional[str] = None) -> matplotlib.figure.Figure: + raise NotImplementedError("This env does not use Matplotlib") + + def render_state(self, + state: State, + task: EnvironmentTask, + action: Optional[Action] = None, + caption: Optional[str] = None) -> Video: + raise NotImplementedError("A PyBullet environment cannot render " + "arbitrary states.") + + def get_observation(self, render: bool = False) -> Observation: + """Get the current observation of this environment. + + Reads the current state from pybullet, updates + _current_observation (the backing field), and returns a copy + optionally with rendered images. + """ + state = self._get_state() + assert isinstance(state, PyBulletState) + self._current_observation = state + obs = state.copy() + + if render: + obs.add_images_and_masks(*self.render_segmented_obj()) + + return obs + + # ── Task Utilities ────────────────────────────────────────── + def _add_pybullet_state_to_tasks( self, tasks: List[EnvironmentTask]) -> List[EnvironmentTask]: """Convert plain-State tasks into PyBulletState tasks. @@ -979,13 +1020,3 @@ def _add_pybullet_state_to_tasks( goal_nl=task.goal_nl) pybullet_tasks.append(pybullet_task) return pybullet_tasks - - @classmethod - def get_robot_ee_home_orn(cls) -> Quaternion: - """Return the default end-effector orientation for this env. - - Used by initialize_pybullet() to set the robot's home pose, - and by oracle options to compute motion-planning targets. - """ - robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] - return robot_ee_orns[CFG.pybullet_robot] From 59aac0140f9a2e362d73e604debcfc86d9776bb4 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 8 Apr 2026 12:12:27 +0100 Subject: [PATCH 13/70] Refactor PyBulletEnv: extract _domain_specific_step from step() Add _step_base() and _domain_specific_step() to PyBulletEnv base class. step() now calls _step_base (robot control, physics, grasp) then _domain_specific_step (water filling, heating, etc.), gated by _skip_domain_specific_dynamics flag for kinematics-only mode. Migrate all 15 domain envs to override _domain_specific_step() instead of step(). Envs with pre-step logic (coffee, switch, blocks, cover) still override step() for the pre-step part only. --- predicators/envs/pybullet_ants.py | 50 ++++------- predicators/envs/pybullet_balance.py | 55 ++++-------- predicators/envs/pybullet_blocks.py | 45 +++++----- predicators/envs/pybullet_boil.py | 50 +++-------- predicators/envs/pybullet_circuit.py | 30 +++---- predicators/envs/pybullet_coffee.py | 26 +++--- predicators/envs/pybullet_cover.py | 14 ++-- .../envs/pybullet_domino/composed_env.py | 23 ++--- predicators/envs/pybullet_env.py | 83 ++++++++++--------- predicators/envs/pybullet_fan.py | 15 +--- predicators/envs/pybullet_float.py | 31 ++----- predicators/envs/pybullet_grow.py | 57 ++++++------- predicators/envs/pybullet_laser.py | 31 +++---- predicators/envs/pybullet_magic_bin.py | 13 +-- predicators/envs/pybullet_switch.py | 24 ++---- 15 files changed, 202 insertions(+), 345 deletions(-) diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index 35d4f82f5..a8ba2f162 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -174,7 +174,7 @@ def initialize_pybullet( food_ids = [] for _ in range(cls.num_food): fid = create_pybullet_block( - color=(0.5, 0.5, 0.5, 1.0), # We’ll override color later + color=(0.5, 0.5, 0.5, 1.0), # We'll override color later half_extents=cls.food_half_extents, mass=cls.food_mass, friction=0.5, @@ -227,31 +227,29 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - - if CFG.ants_ants_attracted_to_points: - self._ant_to_xy = {} # type: ignore[no-redef] - for ant_obj in state.get_objects(self._ant_type): - self._ants_to_xy[ant_obj] = (self._train_rng.uniform( - self.one_third_x, self.two_third_x), - self._train_rng.uniform( - self.y_lb, self.y_ub)) - - # Hide irrelevant objects + """Hide unused objects, set attraction points, food colors, and + ant target references.""" oov_x, oov_y = self._out_of_view_xy block_objs = state.get_objects(self._food_type) for i in range(len(block_objs), len(self._blocks)): - # Hide the remaining blocks update_object(self._blocks[i].id, position=(oov_x, oov_y, self.z_lb), physics_client_id=self._physics_client_id) ant_objs = state.get_objects(self._ant_type) for i in range(len(ant_objs), len(self._ants)): - # Hide the remaining ants update_object(self._ants[i].id, position=(oov_x, oov_y, self.z_lb), physics_client_id=self._physics_client_id) + if CFG.ants_ants_attracted_to_points: + self._ant_to_xy = {} # type: ignore[no-redef] + for ant_obj in state.get_objects(self._ant_type): + self._ants_to_xy[ant_obj] = (self._train_rng.uniform( + self.one_third_x, self.two_third_x), + self._train_rng.uniform( + self.y_lb, self.y_ub)) + for food in state.get_objects(self._food_type): r = state.get(food, "r") g = state.get(food, "g") @@ -262,7 +260,6 @@ def _set_domain_specific_state(self, state: State) -> None: physics_client_id=self._physics_client_id) food.attractive = attractive - # Set ant's attractive food for ant_obj in state.get_objects(self._ant_type): food_id = state.get(ant_obj, "target_food") for food_obj in state.get_objects(self._food_type): @@ -270,25 +267,10 @@ def _set_domain_specific_state(self, state: State) -> None: ant_obj.target_food = food_obj break - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - """Override to (1) do usual robot step, (2) move ants toward attracted - food with noise, and then (3) return the final state.""" - # Step the robot normally - next_state = super().step(action, render_obs=render_obs) - - # Move ants. For each ant, find a target food - # object that is “attractive.” If there’s more - # than one attractive block, pick the one it’s - # “assigned” to, or the first in the list. Then - # move a small step toward it with noise. - self._update_ant_positions(next_state) - - final_state = self._get_state() - self._current_observation = final_state - return final_state + def _domain_specific_step(self) -> None: + """Move ants toward attracted food with noise.""" + state = self._get_state() + self._update_ant_positions(state) def _update_ant_positions(self, state: State) -> None: """For each ant, move it a small step toward its assigned attractive @@ -301,7 +283,7 @@ def _update_ant_positions(self, state: State) -> None: if CFG.ants_ants_attracted_to_points: fx, fy = self._ants_to_xy[ant_obj] else: - # Retrieve this ant’s assigned food + # Retrieve this ant's assigned food target_food_obj = None for food_obj in state.get_objects(self._food_type): if food_obj.id == state.get(ant_obj, "target_food"): diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 76b4e6586..197da0174 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -346,14 +346,9 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - state = super().step(action, render_obs=render_obs) - + def _domain_specific_step(self) -> None: + state = self._get_state() self._update_balance_beam(state) - # Turn machine on if self._PressingButton_holds(state, [self._robot, self._machine]): if self._Balanced_holds(state, [self._plate1, self._plate3]): @@ -361,29 +356,28 @@ def step( # pylint: disable=redefined-outer-name -1, rgbaColor=self._button_color_on, physicsClientId=self._physics_client_id) - self._current_observation = self._get_state() - state = self._current_observation.copy() - - return state def _set_domain_specific_state(self, state: State) -> None: - """Replace the old `_set_state` environment-specific logic. - - The base `_set_state` has already handled standard features - for objects that appear in _get_all_objects(), so here we just - do custom domain-specific tasks: setting plates/blocks if we - aren't letting the base class handle them, updating button - color, and running the beam-balancing update. - """ - # block objs in the state + """Set block placement, balance beam, block colors, ID mapping, and + button color.""" block_objs = state.get_objects(self._block_type) + + # Put unused blocks out of view + h = self._block_size + oov_x, oov_y = self._out_of_view_xy + for i in range(len(block_objs), len(self._blocks)): + p.resetBasePositionAndOrientation( + self._blocks[i].id, [oov_x, oov_y, i * h], + self._default_orn, + physicsClientId=self._physics_client_id) + + self._prev_diff = 0 + self._update_balance_beam(state) + self._block_id_to_block.clear() - # Suppose we want to manually update each block's color or remove them - # if not used. For example: for i, block_obj in enumerate(block_objs): self._block_id_to_block[block_obj.id] = block_obj - # Manually set color if needed: r = state.get(block_obj, "color_r") g = state.get(block_obj, "color_g") b = state.get(block_obj, "color_b") @@ -392,20 +386,7 @@ def _set_domain_specific_state(self, state: State) -> None: rgbaColor=(r, g, b, 1.0), physicsClientId=self._physics_client_id) - # For blocks beyond the number actually in the state, put them out of - # view: - h = self._block_size - oov_x, oov_y = self._out_of_view_xy - for i in range(len(block_objs), len(self._blocks)): - p.resetBasePositionAndOrientation( - self._blocks[i].id, [oov_x, oov_y, i * h], - self._default_orn, - physicsClientId=self._physics_client_id) - - self._prev_diff = 0 # reset difference - self._update_balance_beam(state) - - # Update button color for whether the machine is on + # Update button color if self._MachineOn_holds(state, [self._machine, self._robot]): button_color = self._button_color_on else: diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index b0abf0e24..b3d2d55d6 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -95,17 +95,13 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: blk.id = blk_id def _set_domain_specific_state(self, state: State) -> None: - """After the parent `_set_state()` has reset the robot, set the block - positions/colors and handle constraints for any 'held' block.""" + """Set block positions, grasp constraints, out-of-view placement, + ID mapping, and block colors.""" block_objs = state.get_objects(self._block_type) - self._block_id_to_block.clear() # Place the relevant blocks for i, block_obj in enumerate(block_objs): - block_id = self._block_ids[i] # re-use the i-th block ID - self._block_id_to_block[block_id] = block_obj - - # Position/orientation from the state's block features + block_id = self._block_ids[i] bx = state.get(block_obj, "pose_x") by = state.get(block_obj, "pose_y") bz = state.get(block_obj, "pose_z") @@ -114,19 +110,9 @@ def _set_domain_specific_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - # Update color - r = state.get(block_obj, "color_r") - g = state.get(block_obj, "color_g") - b = state.get(block_obj, "color_b") - p.changeVisualShape(block_id, - linkIndex=-1, - rgbaColor=(r, g, b, 1.0), - physicsClientId=self._physics_client_id) - # If there is a held block, create the constraint held_block = self._get_held_block(state) if held_block is not None: - # Force grasp the relevant block self._force_grasp_object(held_block) # Teleport any leftover blocks out of view @@ -139,6 +125,19 @@ def _set_domain_specific_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) + self._block_id_to_block.clear() + + for i, block_obj in enumerate(block_objs): + block_id = self._block_ids[i] + self._block_id_to_block[block_id] = block_obj + r = state.get(block_obj, "color_r") + g = state.get(block_obj, "color_g") + b = state.get(block_obj, "color_b") + p.changeVisualShape(block_id, + linkIndex=-1, + rgbaColor=(r, g, b, 1.0), + physicsClientId=self._physics_client_id) + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Called by the parent class when constructing the `PyBulletState`. @@ -202,17 +201,13 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: f"{feature}") def step(self, action: Action, render_obs: bool = False) -> State: - self._prev_held_obj_id = self._held_obj_id - # Otherwise, proceed with normal PyBullet step - next_state = super().step(action, render_obs=render_obs) + return super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: if CFG.blocks_high_towers_are_unstable: - self._apply_force_to_high_towers(next_state) - next_state = self._get_state() - self._current_observation = next_state - - return next_state + state = self._get_state() + self._apply_force_to_high_towers(state) def _extract_robot_state(self, state: State) -> np.ndarray: """As needed, parse from the robot's `pose_x`, `pose_y`, `pose_z`, diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index c1485c53e..9957013c5 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -213,10 +213,6 @@ def __init__(self, use_gui: bool = False) -> None: # Keep track of the spilled water block (None if no spill yet) self._spilled_water_id: Optional[int] = None - # When True, step() skips process dynamics (water filling, heating, - # happiness) so that a learned simulator can provide them instead. - self._skip_process_dynamics: bool = False - super().__init__(use_gui) # Optionally, define some relevant predicates @@ -571,7 +567,7 @@ def _set_domain_specific_state(self, state: State) -> None: for i, burner_obj in enumerate(burners): on_val = state.get(burner_obj, "is_on") burner_obj.switch_id = self._burner_switches[i].id - burner_obj.prev_on = 0.0 # Initialize prev_on to 0 + burner_obj.prev_on = 0.0 self._set_switch_on(self._burner_switches[i].id, bool(on_val > 0.5)) @@ -601,7 +597,7 @@ def _set_domain_specific_state(self, state: State) -> None: # Faucet on/off self._faucet.switch_id = self._faucet_switch.id - self._faucet.prev_on = 0.0 # Initialize prev_on to 0 + self._faucet.prev_on = 0.0 f_on = state.get(self._faucet, "is_on") self._set_switch_on(self._faucet_switch.id, bool(f_on > 0.5)) @@ -616,7 +612,6 @@ def _set_domain_specific_state(self, state: State) -> None: self._faucet._spilled_level = -self.water_fill_speed * 20 spilled_level = max(0.0, self._faucet._spilled_level) # pylint: enable=protected-access - # If there's already some spillage in the state, recreate a block if spilled_level > 0.0: self._spilled_water_id = self._create_spilled_water_block( spilled_level, state) @@ -628,17 +623,14 @@ def _set_domain_specific_state(self, state: State) -> None: # Move irrelevant jugs and burners out of the way oov_x, oov_y = self._out_of_view_xy - jugs = state.get_objects(self._jug_type) for i in range(len(jugs), len(self._jugs)): update_object(self._jugs[i].id, position=(oov_x, oov_y, 0.0), physics_client_id=self._physics_client_id) - burners = state.get_objects(self._burner_type) for i in range(len(burners), len(self._burners)): update_object(self._burners[i].id, position=(oov_x, oov_y, 0.0), physics_client_id=self._physics_client_id) - # Also move the corresponding switch update_object(self._burner_switches[i].id, position=(oov_x, oov_y, self.switch_height), physics_client_id=self._physics_client_id) @@ -649,35 +641,15 @@ def _set_domain_specific_state(self, state: State) -> None: # ------------------------------------------------------------------------- # Step Logic # ------------------------------------------------------------------------- - def step(self, action: Action, render_obs: bool = False) -> State: - """Execute a low-level action (robot controls), then handle water - filling/spillage and heating.""" - # First let the base environment perform the usual PyBullet step - next_state = super().step(action, render_obs=False) - - if not self._skip_process_dynamics: - # 1) Handle faucet filling/spillage - self._handle_faucet_logic(next_state) - - # 2) Handle burner heating - self._handle_heating_logic(next_state) - - # 3) Update jug colors based on their 'heat' - self._update_jug_colors(next_state) - - # 4) Update burner colors based on their on/off state - self._update_burner_colors(next_state) - - # 5) Update the human's happiness level - self._update_human_happiness(next_state) - - # 6) Update prev_on states for next step - self._update_prev_on_states(next_state) - - # Re-read final state - final_state = self.get_observation(render=render_obs) - self._current_observation = final_state - return final_state + def _domain_specific_step(self) -> None: + """Handle water filling/spillage, heating, and happiness.""" + state = self._get_state() + self._handle_faucet_logic(state) + self._handle_heating_logic(state) + self._update_jug_colors(state) + self._update_burner_colors(state) + self._update_human_happiness(state) + self._update_prev_on_states(state) def _handle_faucet_logic(self, state: State) -> None: """If faucet is on, fill any jug that is properly aligned; otherwise, diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index 35c3dd695..e43e594eb 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -306,6 +306,9 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: + """Set switch position and bulb on/off state.""" + is_switch_on = state.get(self._battery, "is_on") + self._set_switch_on(self._battery, is_switch_on) is_light_on = state.get(self._light, "is_on") if is_light_on: @@ -313,29 +316,23 @@ def _set_domain_specific_state(self, state: State) -> None: else: self._turn_bulb_off() - is_switch_on = state.get(self._battery, "is_on") - self._set_switch_on(self._battery, is_switch_on) - - def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step. - - If the battery is connected to the light, turn the bulb on. - """ - next_state = super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: + """If the battery is connected to the light, turn the bulb on.""" + state = self._get_state() # Check basic conditions for turning on the bulb - switch_on = self._SwitchedOn_holds(next_state, [self._battery]) + switch_on = self._SwitchedOn_holds(state, [self._battery]) basic_conditions = switch_on and ( CFG.circuit_light_doesnt_need_battery or self._CircuitClosed_holds( - next_state, [self._light, self._battery])) + state, [self._light, self._battery])) # Additional condition: if not using battery_in_box mode, # both C batteries must be in the battery box if not CFG.circuit_battery_in_box and self._c_battery1 is not None \ and self._c_battery2 is not None: both_batteries_in_box = ( - self._InBatteryBox_holds(next_state, [self._c_battery1]) - and self._InBatteryBox_holds(next_state, [self._c_battery2])) + self._InBatteryBox_holds(state, [self._c_battery1]) + and self._InBatteryBox_holds(state, [self._c_battery2])) can_turn_on = basic_conditions and both_batteries_in_box else: can_turn_on = basic_conditions @@ -345,13 +342,8 @@ def step(self, action: Action, render_obs: bool = False) -> State: else: self._turn_bulb_off() - final_state = self._get_state() - # Draw debug lines to visualize battery box region - self._draw_battery_box_debug_lines(final_state) - - self._current_observation = final_state - return final_state + self._draw_battery_box_debug_lines(state) # ------------------------------------------------------------------------- # Predicates diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index a447996bb..5f5474f05 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -396,18 +396,13 @@ def _remake_cord(self) -> None: self._plug.id = self._cord_ids[-1] def _set_domain_specific_state(self, state: State) -> None: - """Coffee-specific state setup: rebuild task-specific objects - (cups, liquids, cords), then set visual state (button color, - liquid fills, etc.). - """ - # Rebuild objects that vary per task + """Reset liquid visuals, cup geometry, cord, and button colors.""" self._remake_jug_liquid(state) self._remake_cup_liquids(state) self._remake_cups(state) self._remake_cord() # Machine button color - # Check if the machine is on and the jug is in place: if self._MachineOn_holds(state, [self._machine]) and \ self._JugInMachine_holds(state, [self._jug, self._machine]): button_color = self.button_color_on @@ -475,21 +470,20 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def step(self, action: Action, render_obs: bool = False) -> State: - # Save current end-effector roll-pitch-yaw for later comparison - current_ee_rpy = self._pybullet_robot.forward_kinematics( + # Save pre-kinematics state for _domain_specific_step. + self._pre_step_ee_rpy = self._pybullet_robot.forward_kinematics( self._pybullet_robot.get_joints()).rpy - state = super().step(action, render_obs=render_obs) - # self._update_jug_liquid_position() + self._last_action = action + return super().step(action, render_obs=render_obs) + + def _domain_specific_step(self) -> None: + state = self._get_state() if CFG.coffee_machine_has_plug: self._check_and_apply_plug_in_constraint(state) self._handle_machine_on_and_jug_filling(state) self._handle_pouring(state) - self._handle_twisting(state, current_ee_rpy, action) - # Refresh current observation - self._current_observation = self._get_state(_render_obs=False) - state = self._current_observation.copy() - - return state + self._handle_twisting(state, self._pre_step_ee_rpy, + self._last_action) def _update_jug_liquid_position(self) -> None: """If the jug is filled, move its liquid to match the jug's pose. diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index 32f680bcf..31dbdd715 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -370,19 +370,15 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: # Step logic (unchanged except for removing direct calls to _get_state()) # ----------------------------------------------------------------------- def step(self, action: Action, render_obs: bool = False) -> State: - """Override to handle the Cover domain's 'hand region' constraint - before calling the parent's step().""" - # Check if the pick/place position satisfies the hand constraints + """Check hand region constraint before kinematics.""" if not self._satisfies_hand_contraints(action): - # Constraint violated => no-op return self._current_state.copy() + return super().step(action, render_obs=render_obs) - # Otherwise, proceed with normal PyBullet step - next_state = super().step(action, render_obs=render_obs) - + def _domain_specific_step(self) -> None: if CFG.cover_blocks_change_color_when_cover: - self._change_block_color_when_cover(next_state) - return next_state + state = self._get_state() + self._change_block_color_when_cover(state) def _change_block_color_when_cover(self, state: State) -> None: """If a block is now covering a target, change it's color to diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index 46620b3d0..4e82718e9 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -288,31 +288,22 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - """Reset environment to match the given state.""" - # Update ball component's state reference for is_hit feature - if self._ball_component is not None: - self._ball_component.set_current_state(state) - - # Reset each component + """Reset each component and update ball state reference.""" for comp in self._components: comp.reset_state(state) - def step(self, action: Action, render_obs: bool = False) -> State: - """Execute action and run component physics updates.""" - super().step(action, render_obs=render_obs) + if self._ball_component is not None: + self._ball_component.set_current_state(state) - # Run component step functions (e.g., fan wind simulation) + def _domain_specific_step(self) -> None: + """Run component physics updates (e.g., fan wind simulation).""" for comp in self._components: comp.step() - final_state = self._get_state() - self._current_observation = final_state - # Update ball component's state reference if self._ball_component is not None: - self._ball_component.set_current_state(final_state) - - return final_state + state = self._get_state() + self._ball_component.set_current_state(state) # ========================================================================= # PREDICATE HOLD FUNCTIONS diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 56e8d887e..25053ba2d 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -132,6 +132,10 @@ def __init__(self, use_gui: bool = False) -> None: self._held_obj_to_base_link: Optional[Any] = None self._held_obj_id: Optional[int] = None + # When True, _domain_specific_step() is skipped in step(). + # Used by sim-learning to create kinematics-only envs. + self._skip_domain_specific_dynamics: bool = False + # Set up all the static PyBullet content. self._physics_client_id, self._pybullet_robot, pybullet_bodies = \ self.initialize_pybullet(self.using_gui) @@ -224,7 +228,10 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: @classmethod def _create_pybullet_robot( cls, physics_client_id: int) -> SingleArmPyBulletRobot: - """Instantiate the robot model. Called by initialize_pybullet().""" + """Instantiate the robot model. + + Called by initialize_pybullet(). + """ robot_ee_orn = cls.get_robot_ee_home_orn() ee_home = Pose((cls.robot_init_x, cls.robot_init_y, cls.robot_init_z), robot_ee_orn) @@ -242,8 +249,8 @@ def _create_pybullet_robot( def get_robot_ee_home_orn(cls) -> Quaternion: """Return the default end-effector orientation for this env. - Used by initialize_pybullet() to set the robot's home pose, - and by oracle options to compute motion-planning targets. + Used by initialize_pybullet() to set the robot's home pose, and + by oracle options to compute motion-planning targets. """ robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] return robot_ee_orns[CFG.pybullet_robot] @@ -309,24 +316,20 @@ def simulate(self, state: State, action: Action) -> State: def step(self, action: Action, render_obs: bool = False) -> Observation: """Execute one environment step with the given action. - This method handles: - 1. Robot joint control by converting action to target positions - 2. Management of held objects and grasping constraints - 3. Physics simulation stepping - 4. Object grasp detection and constraint creation/removal - 5. `self._current_observation` update - - Args: - action (Action): The action to execute, containing target joint - positions - render_obs (bool, optional): Whether to include RGB observation. - Defaults to False. - - Returns: - Observation: Updated environment observation after executing the - action. May include an image if render_obs=True or - CFG.rgb_observation=True. + Flow: kinematics → domain-specific dynamics → observation. + Subclasses override ``_domain_specific_step`` (not this method) + to add post-kinematics dynamics (water filling, heating, etc.). """ + self._step_base(action) + if not self._skip_domain_specific_dynamics: + self._domain_specific_step() + observation = self.get_observation( + render=CFG.rgb_observation or render_obs) + self._current_observation = observation + return observation + + def _step_base(self, action: Action) -> None: + """Run robot control, physics stepping, and grasp management.""" # Send the action to the robot. target_joint_positions, base_delta = self._split_action(action) if base_delta.size: @@ -376,12 +379,13 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: self._held_constraint_id = None self._held_obj_id = None - # Depending on the observation mode, either return object-centric state - # or object_centric + rgb observation - observation = self.get_observation(render=CFG.rgb_observation or\ - render_obs) + def _domain_specific_step(self) -> None: + """Apply domain-specific dynamics after kinematics. - return observation + Override in subclasses to add post-kinematics effects + (water filling, heating, balance beam physics, etc.). + Skipped when ``_skip_domain_specific_dynamics`` is True. + """ # ── State Write (State → PyBullet) ────────────────────────── @@ -431,8 +435,8 @@ def _set_state(self, state: State) -> None: logging.warning("Could not reconstruct state exactly in reset.") def _reset_single_object(self, obj: Object, state: State) -> None: - """Set a single physical object's pose and grasp constraint in - PyBullet to match the given State. + """Set a single physical object's pose and grasp constraint in PyBullet + to match the given State. Called by _set_state() for every non-robot, non-virtual object. """ @@ -475,9 +479,9 @@ def _reset_single_object(self, obj: Object, state: State) -> None: @abc.abstractmethod def _set_domain_specific_state(self, state: State) -> None: - """Set simulator state for features that the base class doesn't - handle — e.g. switch on/off, liquid levels, button colors, - balance beam positions. + """Set simulator state for features that the base class doesn't handle + — e.g. switch on/off, liquid levels, button colors, balance beam + positions. Called at the end of _set_state(), after the base class has already set robot joints, object poses, and grasp constraints. @@ -678,7 +682,8 @@ def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, finger_joint: float) -> float: """Inverse of _fingers_state_to_joint(). - Called by _get_robot_state_dict() when reading PyBullet -> State. + Called by _get_robot_state_dict() when reading PyBullet -> + State. """ subs = { pybullet_robot.open_fingers: cls.open_fingers, @@ -737,8 +742,8 @@ def _detect_held_object(self) -> Optional[int]: Called by step() when fingers are closing and no object is currently held. Checks contact between each finger and every graspable body (from _get_object_ids_for_held_check()), using - contact-normal alignment to reject touches on the outside of - the gripper. If multiple objects qualify, returns the closest. + contact-normal alignment to reject touches on the outside of the + gripper. If multiple objects qualify, returns the closest. """ expected_finger_normals = self._get_expected_finger_normals() closest_held_obj = None @@ -782,11 +787,11 @@ def _detect_held_object(self) -> Optional[int]: return closest_held_obj def _create_grasp_constraint(self) -> None: - """Create a fixed PyBullet constraint between the end-effector - and _held_obj_id so the object moves with the gripper. + """Create a fixed PyBullet constraint between the end-effector and + _held_obj_id so the object moves with the gripper. - Called by step() after _detect_held_object() finds a grasp, - and by _reset_single_object() when restoring a held state. + Called by step() after _detect_held_object() finds a grasp, and + by _reset_single_object() when restoring a held state. """ assert self._held_obj_id is not None base_link_to_world = np.r_[p.invertTransform( @@ -830,8 +835,8 @@ def _fingers_opening(self, action: Action) -> bool: def _get_finger_position(self, state: State) -> float: """Return the current left-finger joint position from state. - Called by _action_to_finger_delta() to compute the delta - between current and target finger positions. + Called by _action_to_finger_delta() to compute the delta between + current and target finger positions. """ state = cast(utils.PyBulletState, state) finger_joint_idx = self._pybullet_robot.left_finger_joint_idx diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index 4059c9122..bc6f41fdc 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -872,18 +872,12 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: # ------------------------------------------------------------------------- # Step # ------------------------------------------------------------------------- - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - """Execute a low-level action, then spin fans & blow the ball.""" - super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: + """Spin fans & blow the ball.""" self._simulate_fans() - final_state = self._get_state() - self._current_observation = final_state + state = self._get_state() # Draw a debug line at the ball's position - bx, by = final_state.get(self._ball, - "x"), final_state.get(self._ball, "y") + bx, by = state.get(self._ball, "x"), state.get(self._ball, "y") p.addUserDebugLine( [bx, by, self.table_height], [bx, by, self.table_height + self.debug_line_height], @@ -891,7 +885,6 @@ def step( # pylint: disable=redefined-outer-name lifeTime=self. debug_line_lifetime, # short lifetime so each step refreshes physicsClientId=self._physics_client_id) - return final_state # ------------------------------------------------------------------------- # Fan Simulation diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index 4b95df4f6..907b78339 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -253,9 +253,10 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - - # Initialize water level + """Set water height and redraw water bodies, block colors, and + displacement tracking.""" self._current_water_height = state.get(self._vessel, "water_height") + # Clear old water for wid in self._water_ids.values(): if wid is not None: @@ -264,17 +265,9 @@ def _set_domain_specific_state(self, state: State) -> None: # Reset blocks for blk in self._blocks: - # Set block's color based on is_light - # update_object(blk.id, - # color=PyBulletFloatEnv.block_color_light \ - # if state.get(blk, "is_light") > 0.5 - # else PyBulletFloatEnv.block_color_heavy, - # physics_client_id=self._physics_client_id) - # Set block's color randomly update_object(blk.id, color=self._train_rng.choice(self._obj_colors), physics_client_id=self._physics_client_id) - # Re-initialize displacing to False self._block_is_displacing[blk] = False # Re-draw water @@ -290,21 +283,13 @@ def _set_domain_specific_state(self, state: State) -> None: color=[0.5, 0.5, 1, 0.5], physics_client_id=self._physics_client_id) - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - next_state = super().step(action, render_obs=render_obs) - # Check if blocks entering/exiting water changed its level - changed = self._update_water_level_if_needed(next_state) + def _domain_specific_step(self) -> None: + """Update water level and float light blocks.""" + state = self._get_state() + changed = self._update_water_level_if_needed(state) if changed: self._create_or_update_water(force_redraw=True) - # Keep light blocks floating on water surface - self._float_light_blocks(next_state) - - final_state = self._get_state() - self._current_observation = final_state - return final_state + self._float_light_blocks(state) def _float_light_blocks(self, state: State) -> None: """Force each light, unheld block in a container compartment to float diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index 395e10428..205f5d6de 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -284,8 +284,28 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - """Called in _set_state to handle any custom resetting.""" - # Remove existing "liquid bodies" + """Set out-of-view positioning, jug init positions, liquid bodies, and + cup/jug colors.""" + cups = state.get_objects(self._cup_type) + jugs = state.get_objects(self._jug_type) + + # Store jug initial positions + for jug in jugs: + jug.init_x = state.get(jug, "x") + jug.init_y = state.get(jug, "y") + jug.init_z = state.get(jug, "z") + + oov_x, oov_y = self._out_of_view_xy + for i in range(len(cups), len(self._cups)): + update_object(self._cups[i].id, + position=(oov_x, oov_y, 0.0), + physics_client_id=self._physics_client_id) + for i in range(len(jugs), len(self._jugs)): + update_object(self._jugs[i].id, + position=(oov_x, oov_y, 0.0), + physics_client_id=self._physics_client_id) + + # Remove existing liquid bodies for liquid_id in self._cup_to_liquid_id.values(): if liquid_id is not None: p.removeBody(liquid_id, @@ -293,13 +313,11 @@ def _set_domain_specific_state(self, state: State) -> None: self._cup_to_liquid_id.clear() # Recreate the liquid bodies as needed - cups = state.get_objects(self._cup_type) for cup in cups: liquid_id = self._create_pybullet_liquid_for_cup(cup, state) self._cup_to_liquid_id[cup] = liquid_id - # Also update the PyBullet color on each cup/jug to match the (r,g,b) in - # the state + # Update colors for cup in cups: if cup.id is not None: r = state.get(cup, "r") @@ -308,7 +326,6 @@ def _set_domain_specific_state(self, state: State) -> None: update_object(cup.id, color=(r, g, b, 1.0), physics_client_id=self._physics_client_id) - jugs = state.get_objects(self._jug_type) for jug in jugs: if jug.id is not None: r = state.get(jug, "r") @@ -317,34 +334,14 @@ def _set_domain_specific_state(self, state: State) -> None: update_object(jug.id, color=(r, g, b, 1.0), physics_client_id=self._physics_client_id) - # set the sim_feature position to the initial position - jug.init_x = state.get(jug, "x") - jug.init_y = state.get(jug, "y") - jug.init_z = state.get(jug, "z") - - oov_x, oov_y = self._out_of_view_xy - for i in range(len(cups), len(self._cups)): - update_object(self._cups[i].id, - position=(oov_x, oov_y, 0.0), - physics_client_id=self._physics_client_id) - for i in range(len(jugs), len(self._jugs)): - update_object(self._jugs[i].id, - position=(oov_x, oov_y, 0.0), - physics_client_id=self._physics_client_id) # ------------------------------------------------------------------------- # Pouring logic - def step(self, action: Action, render_obs: bool = False) -> State: - """Let parent handle the robot stepping, then apply custom pouring - logic.""" - next_state = super().step(action, render_obs=render_obs) - - self._handle_pouring(next_state) - - final_state = self._get_state() - self._current_observation = final_state.copy() - return final_state + def _domain_specific_step(self) -> None: + """Apply custom pouring logic.""" + state = self._get_state() + self._handle_pouring(state) def _handle_pouring(self, state: State) -> None: if self._held_obj_id is None: diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index 9b4e58c09..9815a1c56 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -300,17 +300,10 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: + """Set target/mirror positioning, station switch, and remove old laser + beams.""" oov_x, oov_y = self._out_of_view_xy - lasers_copy = _laser_ids.copy() - for beam_id, creation_time, client_id in lasers_copy: - p.removeBody(beam_id, physicsClientId=client_id) - # Remove the beam from the list - _laser_ids.remove((beam_id, creation_time, client_id)) - logging.debug(f"[reset] removing beam_id: {beam_id} " - f"in sim{client_id}, remaining beams " - f"{[bid for bid, _, _ in _laser_ids]}") - # Move targets out of view if needed target_objs = state.get_objects(self._target_type) for i in range(len(target_objs), len(self._targets)): @@ -341,27 +334,29 @@ def _set_domain_specific_state(self, state: State) -> None: switch_on = state.get(self._station, "is_on") > 0.5 self._set_station_powered_on(switch_on) + lasers_copy = _laser_ids.copy() + for beam_id, creation_time, client_id in lasers_copy: + p.removeBody(beam_id, physicsClientId=client_id) + _laser_ids.remove((beam_id, creation_time, client_id)) + logging.debug(f"[reset] removing beam_id: {beam_id} " + f"in sim{client_id}, remaining beams " + f"{[bid for bid, _, _ in _laser_ids]}") + # ------------------------------------------------------------------------- # Step # ------------------------------------------------------------------------- - def step(self, action: Action, render_obs: bool = False) -> State: - next_state = super().step(action, render_obs=render_obs) - - # After any motion, we simulate the laser - self._simulate_laser(next_state) + def _domain_specific_step(self) -> None: + state = self._get_state() + self._simulate_laser(state) lasers_copy = _laser_ids.copy() for beam_id, creation_time, client_id in lasers_copy: if time.time() - creation_time > self._laser_life_time: p.removeBody(beam_id, physicsClientId=client_id) - # Remove the beam from the list _laser_ids.remove((beam_id, creation_time, client_id)) logging.debug(f"[step] removing beam_id: {beam_id} " f"in sim{client_id}, remaining beams " f"{[bid for bid, _, _ in _laser_ids]}") - final_state = self._get_state() - self._current_observation = final_state - return final_state # ------------------------------------------------------------------------- # Laser Simulation diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index 2c6d8bfd6..6bae8a02b 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -265,12 +265,8 @@ def _set_domain_specific_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step.""" - # Execute the action - super().step(action, render_obs=render_obs) - - # Check magic bin logic: if switch is on and block is in bin, vanish it + def _domain_specific_step(self) -> None: + """If switch is on and block is in bin, vanish it.""" if self._is_switch_on(): bin_pos, _ = p.getBasePositionAndOrientation( self._bin.id, physicsClientId=self._physics_client_id) @@ -301,11 +297,6 @@ def step(self, action: Action, render_obs: bool = False) -> State: self._default_orn, physicsClientId=self._physics_client_id) - # Get updated state - final_state = self._get_state() - self._current_observation = final_state - return final_state - # ------------------------------------------------------------------------- # Switch helpers def _is_switch_on(self) -> bool: diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index ed4bb858b..e2f1be09e 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -237,38 +237,31 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - """Reset environment state from a State object.""" - # Set power switch state + """Set switch positions, tracking vars, color count, and light visual.""" power_on = state.get(self._power_switch, "is_on") > 0.5 self._set_switch_state(self._power_switch, power_on) - # Set color switch state color_switch_on = state.get(self._color_switch, "is_on") > 0.5 self._set_switch_state(self._color_switch, color_switch_on) - # Track previous color switch state for edge detection self._prev_color_switch_on = color_switch_on - # Initialize color_count from light's color_index color_index = int(state.get(self._light, "color_index")) self._color_switch.color_count = color_index - # Update light visual self._update_light_visual(power_on, color_index) def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step.""" - # Get current color_count from sim_feature - prev_color_count = self._color_switch.color_count - - # Execute the action - super().step(action, render_obs=render_obs) + """Save pre-step color count before kinematics.""" + self._pre_step_color_count = self._color_switch.color_count + return super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: # Detect color switch toggle (OFF -> ON transition) curr_color_switch_on = self._is_switch_on(self._color_switch) if not self._prev_color_switch_on and curr_color_switch_on: # Rising edge detected - increment color count - self._color_switch.color_count = prev_color_count + 1 + self._color_switch.color_count = self._pre_step_color_count + 1 self._prev_color_switch_on = curr_color_switch_on @@ -282,11 +275,6 @@ def step(self, action: Action, render_obs: bool = False) -> State: # Update light visual self._update_light_visual(power_on, color_index) - # Get updated state with correct light values - final_state = self._get_state() - self._current_observation = final_state - return final_state - # ------------------------------------------------------------------------- # Switch helpers def _is_switch_on(self, switch_obj: Object) -> bool: From f86c0ea5235e36d5430e41436cb7896a1188064b Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 8 Apr 2026 12:13:51 +0100 Subject: [PATCH 14/70] Update PyBulletEnv module docstring for step() refactoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Document the step_base → domain_specific_step → get_observation flow, _skip_domain_specific_dynamics flag, and _domain_specific_step as an optional override. --- predicators/envs/pybullet_env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 25053ba2d..cf5d46b2e 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -10,7 +10,10 @@ Main public API: reset(train_or_test, task_idx) — reset env to a task, returns observation simulate(state, action) — forward-simulate without touching real env - step(action) — execute action, manage grasps, return observation + step(action) — _step_base (robot control, physics, grasps) + → _domain_specific_step (water filling, heating, etc.) + → get_observation. Domain dynamics are skipped when + _skip_domain_specific_dynamics is True (kinematics-only mode). get_observation() — read PyBullet state, optionally attach images/masks State synchronization: @@ -27,6 +30,7 @@ - _get_object_ids_for_held_check() -> List[int] - _set_domain_specific_state(state) - _get_domain_specific_feature(obj, feature) -> float + - _domain_specific_step() (optional, default no-op) """ import abc From 9cddb03497e4baa615938951b2ef1eb36823316b Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 8 Apr 2026 20:15:58 +0100 Subject: [PATCH 15/70] Add skip_process_dynamics constructor param to PyBulletEnv Replace direct access to private _skip_domain_specific_dynamics attribute with a public constructor parameter, so callers declare kinematics-only mode at creation time instead of mutating internal state after construction. --- predicators/envs/__init__.py | 7 ++++-- predicators/envs/pybullet_ants.py | 5 +++-- predicators/envs/pybullet_balance.py | 4 ++-- predicators/envs/pybullet_barrier.py | 4 ++-- predicators/envs/pybullet_blocks.py | 4 ++-- predicators/envs/pybullet_boil.py | 4 ++-- predicators/envs/pybullet_circuit.py | 4 ++-- predicators/envs/pybullet_coffee.py | 4 ++-- predicators/envs/pybullet_cover.py | 4 ++-- .../envs/pybullet_domino/composed_env.py | 22 ++++++++++--------- predicators/envs/pybullet_env.py | 10 +++++---- predicators/envs/pybullet_fan.py | 4 ++-- predicators/envs/pybullet_float.py | 4 ++-- predicators/envs/pybullet_grow.py | 4 ++-- predicators/envs/pybullet_laser.py | 4 ++-- predicators/envs/pybullet_magic_bin.py | 4 ++-- predicators/envs/pybullet_switch.py | 4 ++-- 17 files changed, 52 insertions(+), 44 deletions(-) diff --git a/predicators/envs/__init__.py b/predicators/envs/__init__.py index 66a497845..a986a0628 100644 --- a/predicators/envs/__init__.py +++ b/predicators/envs/__init__.py @@ -2,6 +2,8 @@ import logging +from typing import Any + from predicators import utils from predicators.envs.base_env import BaseEnv @@ -14,7 +16,8 @@ def create_new_env(name: str, do_cache: bool = True, - use_gui: bool = False) -> BaseEnv: + use_gui: bool = False, + **kwargs: Any) -> BaseEnv: """Create a new instance of an environment from its name. If do_cache is True, then cache this env instance so that it can @@ -22,7 +25,7 @@ def create_new_env(name: str, """ for cls in utils.get_all_subclasses(BaseEnv): if not cls.__abstractmethods__ and cls.get_name() == name: - env = cls(use_gui) + env = cls(use_gui, **kwargs) break else: raise NotImplementedError(f"Unknown env: {name}") diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index a8ba2f162..9d68ec92a 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -91,7 +91,8 @@ class PyBulletAntsEnv(PyBulletEnv): def __init__(self, use_gui: bool = False, - debug_layout: bool = True) -> None: + debug_layout: bool = True, + **kwargs) -> None: # Create single robot self._robot = Object("robot", self._robot_type) @@ -113,7 +114,7 @@ def __init__(self, if CFG.ants_ants_attracted_to_points: self._ants_to_xy: Dict[Object, Tuple[float, float]] = {} - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) self._debug_layout = debug_layout # Define predicates if needed (some are placeholders) diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 197da0174..07b1aad06 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -88,7 +88,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _num_blocks_train = CFG.balance_num_blocks_train _num_blocks_test = CFG.balance_num_blocks_test - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Types # bbox_features = ["bbox_left", "bbox_right", # "bbox_upper", "bbox_lower"] @@ -116,7 +116,7 @@ def __init__(self, use_gui: bool = False) -> None: self._prev_diff = 0 - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._DirectlyOn = Predicate( diff --git a/predicators/envs/pybullet_barrier.py b/predicators/envs/pybullet_barrier.py index 8041c6dd7..9a64714e5 100644 --- a/predicators/envs/pybullet_barrier.py +++ b/predicators/envs/pybullet_barrier.py @@ -91,7 +91,7 @@ class PyBulletBarrierEnv(PyBulletEnv): _barrier_type = Type("barrier", ["x", "y", "rot", "height"], sim_features=["id", "base_z"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Objects self._robot = Object("robot", self._robot_type) self._switches: List[Object] = [ @@ -103,7 +103,7 @@ def __init__(self, use_gui: bool = False) -> None: for i in range(self.num_barriers) ] - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._SwitchOn = Predicate("SwitchOn", [self._switch_type], diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index b3d2d55d6..d3ebfb1bb 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -27,8 +27,8 @@ class PyBulletBlocksEnv(PyBulletEnv, BlocksEnv): _table_pose: ClassVar[Pose3D] = (1.35, 0.75, table_height / 2) _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) - def __init__(self, use_gui: bool = False) -> None: - super().__init__(use_gui) + def __init__(self, use_gui: bool = False, **kwargs) -> None: + super().__init__(use_gui, **kwargs) # Store references self._table_id: int = -1 # self._block_ids: List[int] = [] diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 9957013c5..3bbf2a2b9 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -174,7 +174,7 @@ def water_fill_speed(self) -> float: _human_type = Type("human", ["happiness_level"], sim_features=["id", "happiness_level"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Create the robot as an Object self._robot = Object("robot", self._robot_type) @@ -213,7 +213,7 @@ def __init__(self, use_gui: bool = False) -> None: # Keep track of the spilled water block (None if no spill yet) self._spilled_water_id: Optional[int] = None - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Optionally, define some relevant predicates self._JugFilled = Predicate("JugFilled", [self._jug_type], diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index e43e594eb..e1fec79bb 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -104,7 +104,7 @@ class PyBulletCircuitEnv(PyBulletEnv): _c_battery_type = Type("c_battery", ["x", "y", "z", "yaw", "pitch", "roll"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Objects self._robot = Object("robot", self._robot_type) @@ -120,7 +120,7 @@ def __init__(self, use_gui: bool = False) -> None: self._c_battery1 = Object("c_battery1", self._c_battery_type) self._c_battery2 = Object("c_battery2", self._c_battery_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._Holding = Predicate("Holding", diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 5f5474f05..4d5c221f0 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -217,7 +217,7 @@ def pour_z_offset(cls) -> float: _camera_pitch: ClassVar[float] _camera_target: ClassVar[Pose3D] - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: if CFG.coffee_render_grid_world: # Camera parameters for grid world PyBulletCoffeeEnv._camera_distance = 3 @@ -238,7 +238,7 @@ def __init__(self, use_gui: bool = False) -> None: # PyBulletCoffeeEnv._camera_pitch = 0 # even lower PyBulletCoffeeEnv._camera_target = (0.75, 1.25, 0.42) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Create the cups lazily because they can change size and color. # self._cup_id_to_cup: Dict[int, Object] = {} diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index 31dbdd715..ec6e63501 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -59,8 +59,8 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): float]]] = [(0, 0, 0, 1.), (1, 1, 1, 1.)] - def __init__(self, use_gui: bool = False) -> None: - super().__init__(use_gui) + def __init__(self, use_gui: bool = False, **kwargs) -> None: + super().__init__(use_gui, **kwargs) # Store block/target IDs (from initialize_pybullet) so that we can # reset their positions in _set_domain_specific_state(). self._table_id: int = -1 diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index 4e82718e9..04f0de983 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -102,7 +102,8 @@ class PyBulletDominoComposedEnv(PyBulletEnv): def __init__(self, components: List[DominoEnvComponent], - use_gui: bool = False) -> None: + use_gui: bool = False, + **kwargs: Any) -> None: """Initialize the composed domino environment. Args: @@ -134,7 +135,7 @@ def __init__(self, # Wire up fan -> ball wind connection if both present # (done after PyBullet init in _store_pybullet_bodies) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) def _create_robot_predicates(self) -> None: """Create robot-specific predicates.""" @@ -404,7 +405,7 @@ def _make_tasks(self, class PyBulletDominoEnvNew(PyBulletDominoComposedEnv): """Backward-compatible domino environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -426,7 +427,8 @@ def __init__(self, use_gui: bool = False) -> None: num_pivots_max=max_pivots, workspace_bounds=workspace_bounds) - super().__init__(components=[domino_comp], use_gui=use_gui) + super().__init__(components=[domino_comp], use_gui=use_gui, + **kwargs) @classmethod def get_name(cls) -> str: @@ -436,7 +438,7 @@ def get_name(cls) -> str: class PyBulletDominoFanEnvNew(PyBulletDominoComposedEnv): """Backward-compatible domino + fan + ball environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -466,7 +468,7 @@ def __init__(self, use_gui: bool = False) -> None: table_height=self.table_height) super().__init__(components=[domino_comp, fan_comp, ball_comp], - use_gui=use_gui) + use_gui=use_gui, **kwargs) @classmethod def get_name(cls) -> str: @@ -492,7 +494,7 @@ def goal_predicates(self) -> Set[Predicate]: class PyBulletDominoFanRampEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -527,7 +529,7 @@ def __init__(self, use_gui: bool = False) -> None: super().__init__( components=[domino_comp, fan_comp, ball_comp, ramp_comp], - use_gui=use_gui) + use_gui=use_gui, **kwargs) @classmethod def get_name(cls) -> str: @@ -553,7 +555,7 @@ def goal_predicates(self) -> Set[Predicate]: class PyBulletDominoFanRampStairsEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp + stairs environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -595,7 +597,7 @@ def __init__(self, use_gui: bool = False) -> None: super().__init__(components=[ domino_comp, fan_comp, ball_comp, ramp_comp, stairs_comp ], - use_gui=use_gui) + use_gui=use_gui, **kwargs) # Store reference to stairs component self._stairs_component = stairs_comp diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index cf5d46b2e..6f30b7895 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -13,7 +13,7 @@ step(action) — _step_base (robot control, physics, grasps) → _domain_specific_step (water filling, heating, etc.) → get_observation. Domain dynamics are skipped when - _skip_domain_specific_dynamics is True (kinematics-only mode). + skip_process_dynamics=True is passed to the constructor. get_observation() — read PyBullet state, optionally attach images/masks State synchronization: @@ -123,7 +123,9 @@ class PyBulletEnv(BaseEnv): _camera_fov: ClassVar[float] = 60 _debug_text_position: ClassVar[Pose3D] = (1.65, 0.25, 0.75) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, + use_gui: bool = False, + skip_process_dynamics: bool = False) -> None: super().__init__(use_gui) # Forward declaration: subclasses must define _robot @@ -138,7 +140,7 @@ def __init__(self, use_gui: bool = False) -> None: # When True, _domain_specific_step() is skipped in step(). # Used by sim-learning to create kinematics-only envs. - self._skip_domain_specific_dynamics: bool = False + self._skip_domain_specific_dynamics: bool = skip_process_dynamics # Set up all the static PyBullet content. self._physics_client_id, self._pybullet_robot, pybullet_bodies = \ @@ -388,7 +390,7 @@ def _domain_specific_step(self) -> None: Override in subclasses to add post-kinematics effects (water filling, heating, balance beam physics, etc.). - Skipped when ``_skip_domain_specific_dynamics`` is True. + Skipped when ``skip_process_dynamics=True`` is passed to the constructor. """ # ── State Write (State → PyBullet) ────────────────────────── diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index bc6f41fdc..5c45eed48 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -257,7 +257,7 @@ def get_configuration_dict(cls) -> Dict[str, Any]: # ------------------------------------------------------------------------- # Environment initialization # ------------------------------------------------------------------------- - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: self._robot = Object("robot", self._robot_type) # Fans - create one fan object per side instead of multiple @@ -300,7 +300,7 @@ def __init__(self, use_gui: bool = False) -> None: # Target self._target = Object("target", self._target_type) - super().__init__(use_gui=use_gui) + super().__init__(use_gui=use_gui, **kwargs) # Define new predicates if desired self._FanOn = Predicate( diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index 907b78339..fcad5973a 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -120,7 +120,7 @@ class PyBulletFloatEnv(PyBulletEnv): _block_type = Type("block", ["x", "y", "z", "in_water", "is_held"], sim_features=["id", "is_light"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: self._robot = Object("robot", self._robot_type) self._vessel = Object("vessel", self._vessel_type) self._block0 = Object("block0", self._block_type) @@ -128,7 +128,7 @@ def __init__(self, use_gui: bool = False) -> None: self._block2 = Object("block2", self._block_type) self._blocks = [self._block0, self._block1, self._block2] - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) self._InWater = Predicate("InWater", [self._block_type], self._InWater_holds) diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index 205f5d6de..9187ac6cc 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -110,7 +110,7 @@ class PyBulletGrowEnv(PyBulletEnv): _jug_type = Type("jug", ["x", "y", "z", "rot", "is_held", "r", "g", "b"], sim_features=["id", "init_x", "init_y", "init_z"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Create the single robot Object self._robot = Object("robot", self._robot_type) @@ -133,7 +133,7 @@ def __init__(self, use_gui: bool = False) -> None: # For tracking the "liquid bodies" we create for each cup self._cup_to_liquid_id: Dict[Object, Optional[int]] = {} - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Define Predicates self._Grown = Predicate("Grown", [self._cup_type], self._Grown_holds) diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index 9815a1c56..a9ee740a2 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -121,7 +121,7 @@ class PyBulletLaserEnv(PyBulletEnv): ["x", "y", "z", "rot", "split_mirror", "is_held"]) _target_type = Type("target", ["x", "y", "z", "rot", "is_hit"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Create environment objects (logic-level) self._robot = Object("robot", self._robot_type) self._station = Object("station", self._station_type) @@ -140,7 +140,7 @@ def __init__(self, use_gui: bool = False) -> None: ] # Initialize PyBullet - super().__init__(use_gui=use_gui) + super().__init__(use_gui=use_gui, **kwargs) # Define predicates # Example: "StationOn" checks whether the station is toggled on diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index 6bae8a02b..dc755286c 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -86,7 +86,7 @@ class PyBulletMagicBinEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale"]) _bin_type = Type("bin", ["x", "y", "z", "rot"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Objects self._robot = Object("robot", self._robot_type) self._blocks: List[Object] = [ @@ -96,7 +96,7 @@ def __init__(self, use_gui: bool = False) -> None: self._switch = Object("switch", self._switch_type) self._bin = Object("bin", self._bin_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._HandEmpty = Predicate("HandEmpty", [self._robot_type], diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index e2f1be09e..8fec02ccc 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -89,14 +89,14 @@ class PyBulletSwitchEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale", "color_count"]) _light_type = Type("light", ["x", "y", "z", "rot", "is_on", "color_index"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs) -> None: # Objects self._robot = Object("robot", self._robot_type) self._power_switch = Object("power_switch", self._power_switch_type) self._color_switch = Object("color_switch", self._color_switch_type) self._light = Object("light", self._light_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Track previous switch states for edge detection self._prev_color_switch_on: bool = False From 989cf4e4a70134c10204a35c748b861890ccbc0f Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Mon, 13 Apr 2026 20:42:27 +0100 Subject: [PATCH 16/70] Extract run_query_sync helper to remove duplicated async-to-sync bridging Both AgentSessionMixin and AgentExplorer had near-identical wrappers that ran session.query() synchronously via nest_asyncio or asyncio.run. Move that logic into a module-level run_query_sync helper in session_manager and have both callers delegate to it. --- predicators/agent_sdk/session_manager.py | 18 +++++++++++++++++ predicators/approaches/agent_session_mixin.py | 14 +++---------- predicators/explorers/agent_explorer.py | 20 +++---------------- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/predicators/agent_sdk/session_manager.py b/predicators/agent_sdk/session_manager.py index f56063a25..84c6ce880 100644 --- a/predicators/agent_sdk/session_manager.py +++ b/predicators/agent_sdk/session_manager.py @@ -1,4 +1,5 @@ """Agent session lifecycle management for Claude SDK.""" +import asyncio import datetime import json import logging @@ -211,3 +212,20 @@ def save_session_info(self) -> None: with open(path, "w", encoding="utf-8") as f: json.dump(info, f, indent=2) logging.info("Saved session info to %s", path) + + +def run_query_sync(session: Any, message: str) -> List[Dict[str, Any]]: + """Synchronously run ``session.query(message)``. + + Reuses a running event loop via nest_asyncio when one is active, + otherwise falls back to ``asyncio.run``. + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import nest_asyncio # type: ignore[import-untyped,import-not-found] # pylint: disable=import-outside-toplevel + nest_asyncio.apply() + return loop.run_until_complete(session.query(message)) + return loop.run_until_complete(session.query(message)) + except RuntimeError: + return asyncio.run(session.query(message)) diff --git a/predicators/approaches/agent_session_mixin.py b/predicators/approaches/agent_session_mixin.py index f90697340..fd41f9531 100644 --- a/predicators/approaches/agent_session_mixin.py +++ b/predicators/approaches/agent_session_mixin.py @@ -8,7 +8,8 @@ import os from typing import Any, Dict, List, Optional, Set, Union -from predicators.agent_sdk.session_manager import AgentSessionManager +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync from predicators.agent_sdk.tools import ToolContext, create_mcp_tools, \ get_allowed_tool_list from predicators.explorers import create_explorer @@ -179,16 +180,7 @@ def _query_agent_sync(self, message: str) -> List[Dict[str, Any]]: """Synchronous wrapper for async agent query.""" self._ensure_agent_session() assert self._agent_session is not None - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - import nest_asyncio # type: ignore[import-untyped,import-not-found] # pylint: disable=import-outside-toplevel - nest_asyncio.apply() - return loop.run_until_complete( - self._agent_session.query(message)) - return loop.run_until_complete(self._agent_session.query(message)) - except RuntimeError: - return asyncio.run(self._agent_session.query(message)) + return run_query_sync(self._agent_session, message) def _create_agent_explorer( self, diff --git a/predicators/explorers/agent_explorer.py b/predicators/explorers/agent_explorer.py index 31b675ab4..014e78cf7 100644 --- a/predicators/explorers/agent_explorer.py +++ b/predicators/explorers/agent_explorer.py @@ -1,6 +1,5 @@ """An explorer that queries a Claude agent to generate option plans.""" -import asyncio import logging from typing import Any, Dict, List, Set @@ -8,7 +7,8 @@ from gym.spaces import Box from predicators import utils -from predicators.agent_sdk.session_manager import AgentSessionManager +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync from predicators.agent_sdk.tools import ToolContext from predicators.explorers.base_explorer import BaseExplorer from predicators.settings import CFG @@ -38,7 +38,7 @@ def _get_exploration_strategy(self, train_task_idx: int, task = self._train_tasks[train_task_idx] try: prompt = self._build_exploration_prompt(train_task_idx) - responses = self._query_agent_sync(prompt) + responses = run_query_sync(self._agent_session, prompt) plan_text = self._extract_option_plan_text(responses) if plan_text: option_plan = self._parse_and_ground_plan(plan_text, task) @@ -185,20 +185,6 @@ def _build_trajectory_summary(self) -> str: return "\n".join(lines) - def _query_agent_sync(self, message: str) -> List[Dict[str, Any]]: - """Synchronous wrapper for async agent query.""" - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # pylint: disable-next=import-outside-toplevel - import nest_asyncio # type: ignore[import-untyped] - nest_asyncio.apply() - return loop.run_until_complete( - self._agent_session.query(message)) - return loop.run_until_complete(self._agent_session.query(message)) - except RuntimeError: - return asyncio.run(self._agent_session.query(message)) - def _extract_option_plan_text(self, responses: List[Dict[str, Any]]) -> str: """Extract plan text from the last assistant text response. From 87bbe1c1b3279bfdd80f2a0e06878048a6ab1e07 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 14 Apr 2026 17:04:51 +0100 Subject: [PATCH 17/70] Refactor main function: extract and modularize setup logic for clarity and maintainability --- predicators/main.py | 90 ++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/predicators/main.py b/predicators/main.py index 0fd55c6e3..a50591fd4 100644 --- a/predicators/main.py +++ b/predicators/main.py @@ -65,6 +65,53 @@ "Please add `export PYTHONHASHSEED=0` to your bash profile!" +def main() -> None: + """Main entry point for running approaches in environments.""" + script_start = time.perf_counter() + + # Parse & validate args + args = utils.parse_args() + utils.update_config(args) + str_args = " ".join(sys.argv) + + # Setup logging and directories + utils.configure_logging() + os.makedirs(CFG.results_dir, exist_ok=True) + os.makedirs(CFG.eval_trajectories_dir, exist_ok=True) + + # Log initial info + utils.log_initial_info(str_args) + + # Setup environment and tasks + env, approach_train_tasks, train_tasks = setup_environment() + + # Setup predicates + included_preds, excluded_preds = utils.parse_config_excluded_predicates( + env) + preds = utils.replace_goals_with_agent_specific_goals( + included_preds, excluded_preds, + env) if CFG.approach != "oracle" else included_preds + + # Create approach + approach = setup_approach(env, preds, approach_train_tasks) + + # Create dataset and cognitive manager + offline_dataset = create_offline_dataset(env, train_tasks, preds, approach) + execution_monitor = create_execution_monitor(CFG.execution_monitor) + cogman = CogMan(approach, create_perceiver(CFG.perceiver), + execution_monitor) + + # Run pipeline + _run_pipeline(env, cogman, approach_train_tasks, offline_dataset) + + # Log completion + script_time = time.perf_counter() - script_start + logging.info(f"\n\nMain script terminated in {script_time:.5f} seconds") + + +# ── Setup helpers ──────────────────────────────────────────────── + + def setup_environment() -> Tuple[BaseEnv, List[Task], List[Task]]: """Create and setup the environment and tasks. @@ -141,48 +188,7 @@ def create_offline_dataset(env: BaseEnv, train_tasks: List[Task], preds: set, return None -def main() -> None: - """Main entry point for running approaches in environments.""" - script_start = time.perf_counter() - - # Parse & validate args - args = utils.parse_args() - utils.update_config(args) - str_args = " ".join(sys.argv) - - # Setup logging and directories - utils.configure_logging() - os.makedirs(CFG.results_dir, exist_ok=True) - os.makedirs(CFG.eval_trajectories_dir, exist_ok=True) - - # Log initial info - utils.log_initial_info(str_args) - - # Setup environment and tasks - env, approach_train_tasks, train_tasks = setup_environment() - - # Setup predicates - included_preds, excluded_preds = utils.parse_config_excluded_predicates( - env) - preds = utils.replace_goals_with_agent_specific_goals( - included_preds, excluded_preds, - env) if CFG.approach != "oracle" else included_preds - - # Create approach - approach = setup_approach(env, preds, approach_train_tasks) - - # Create dataset and cognitive manager - offline_dataset = create_offline_dataset(env, train_tasks, preds, approach) - execution_monitor = create_execution_monitor(CFG.execution_monitor) - cogman = CogMan(approach, create_perceiver(CFG.perceiver), - execution_monitor) - - # Run pipeline - _run_pipeline(env, cogman, approach_train_tasks, offline_dataset) - - # Log completion - script_time = time.perf_counter() - script_start - logging.info(f"\n\nMain script terminated in {script_time:.5f} seconds") +# ── Pipeline ───────────────────────────────────────────────────── def _run_pipeline(env: BaseEnv, From 10f010bf41b02ec24604ce9c14ad111edb1b5729 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 14 Apr 2026 17:21:20 +0100 Subject: [PATCH 18/70] Rename agent explorer to agent_plan for clearer naming Distinguishes the grounded-plan explorer from upcoming bilevel variants. AgentExplorer -> AgentPlanExplorer, get_name() 'agent' -> 'agent_plan', file moved to agent_plan_explorer.py, and all callers / docstrings / YAML config examples updated accordingly. --- .../agent_abstraction_learning_approach.py | 2 +- predicators/approaches/agent_bilevel_approach.py | 2 +- .../approaches/agent_closed_loop_approach.py | 2 +- predicators/approaches/agent_planner_approach.py | 8 ++++---- predicators/approaches/agent_session_mixin.py | 2 +- predicators/explorers/__init__.py | 2 +- .../{agent_explorer.py => agent_plan_explorer.py} | 14 ++++++++++---- .../configs/predicatorv3/approaches/agents.yaml | 6 +++--- scripts/configs/predicatorv3/predicator_v3.yaml | 8 ++++---- 9 files changed, 26 insertions(+), 20 deletions(-) rename predicators/explorers/{agent_explorer.py => agent_plan_explorer.py} (95%) diff --git a/predicators/approaches/agent_abstraction_learning_approach.py b/predicators/approaches/agent_abstraction_learning_approach.py index b76fcf7de..f0df3c966 100644 --- a/predicators/approaches/agent_abstraction_learning_approach.py +++ b/predicators/approaches/agent_abstraction_learning_approach.py @@ -477,7 +477,7 @@ def _build_solve_prompt(self, task: Task) -> str: def _create_explorer(self) -> BaseExplorer: """Create explorer, passing agent context if using agent explorer.""" - if CFG.explorer == "agent": + if CFG.explorer == "agent_plan": all_trajs = (self._offline_dataset.trajectories + self._online_dataset.trajectories) self._sync_tool_context(all_trajs) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index de60d98d4..3b75f082f 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -10,7 +10,7 @@ python predicators/main.py --env pybullet_domino \ --approach agent_bilevel --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ import dataclasses import logging diff --git a/predicators/approaches/agent_closed_loop_approach.py b/predicators/approaches/agent_closed_loop_approach.py index 8e38ebf33..1bf7805b1 100644 --- a/predicators/approaches/agent_closed_loop_approach.py +++ b/predicators/approaches/agent_closed_loop_approach.py @@ -9,7 +9,7 @@ python predicators/main.py --env pybullet_domino \ --approach agent_closed_loop --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ import logging from typing import Callable, List diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index f178fc76b..cf918adc2 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -1,6 +1,6 @@ """Agent planner approach: fixed-vocabulary open-loop planning. -Combines online trajectory collection (via AgentExplorer) with open-loop +Combines online trajectory collection (via AgentPlanExplorer) with open-loop option plan generation (via Claude Agent SDK). No predicate/process/type invention — just stores trajectories and generates plans. @@ -8,7 +8,7 @@ python predicators/main.py --env pybullet_domino \ --approach agent_planner --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ import datetime import inspect as _inspect @@ -37,7 +37,7 @@ class AgentPlannerApproach(AgentSessionMixin, BaseApproach): """Fixed-vocabulary open-loop planning via Claude Agent SDK. - - Collects trajectories online using AgentExplorer + - Collects trajectories online using AgentPlanExplorer - At solve time, queries the agent for an option plan - No predicate/process/type invention """ @@ -705,7 +705,7 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: def _create_explorer(self) -> BaseExplorer: """Create explorer for interaction requests.""" - if CFG.explorer == "agent": + if CFG.explorer == "agent_plan": self._sync_tool_context() return self._create_agent_explorer(self._get_all_predicates(), self._get_all_options()) diff --git a/predicators/approaches/agent_session_mixin.py b/predicators/approaches/agent_session_mixin.py index fd41f9531..f3578db5a 100644 --- a/predicators/approaches/agent_session_mixin.py +++ b/predicators/approaches/agent_session_mixin.py @@ -190,7 +190,7 @@ def _create_agent_explorer( """Create an agent explorer with tool_context and agent_session.""" self._ensure_agent_session() return create_explorer( - "agent", + "agent_plan", predicates, options, self._types, # type: ignore[attr-defined] diff --git a/predicators/explorers/__init__.py b/predicators/explorers/__init__.py index 560c840d6..191a39cf9 100644 --- a/predicators/explorers/__init__.py +++ b/predicators/explorers/__init__.py @@ -109,7 +109,7 @@ def create_explorer( action_space, train_tasks, max_steps_before_termination, nsrts, maple_q_function) - elif name == "agent": + elif name == "agent_plan": assert tool_context is not None assert agent_session is not None explorer = cls(initial_predicates, initial_options, types, diff --git a/predicators/explorers/agent_explorer.py b/predicators/explorers/agent_plan_explorer.py similarity index 95% rename from predicators/explorers/agent_explorer.py rename to predicators/explorers/agent_plan_explorer.py index 014e78cf7..2de8a404a 100644 --- a/predicators/explorers/agent_explorer.py +++ b/predicators/explorers/agent_plan_explorer.py @@ -1,4 +1,10 @@ -"""An explorer that queries a Claude agent to generate option plans.""" +"""Agent plan explorer: Claude agent generates grounded option plans. + +Produces fully-grounded option plans (including continuous parameters) and +rolls them out in the real environment. The agent is expected to provide +complete parameters itself; this explorer does not run backtracking +refinement against a learned option model. +""" import logging from typing import Any, Dict, List, Set @@ -16,8 +22,8 @@ ParameterizedOption, Predicate, State, Task, Type -class AgentExplorer(BaseExplorer): - """Queries a Claude agent to produce option plans for exploration.""" +class AgentPlanExplorer(BaseExplorer): + """Queries a Claude agent to produce grounded option plans.""" def __init__(self, predicates: Set[Predicate], options: Set[ParameterizedOption], types: Set[Type], @@ -31,7 +37,7 @@ def __init__(self, predicates: Set[Predicate], @classmethod def get_name(cls) -> str: - return "agent" + return "agent_plan" def _get_exploration_strategy(self, train_task_idx: int, timeout: int) -> ExplorationStrategy: diff --git a/scripts/configs/predicatorv3/approaches/agents.yaml b/scripts/configs/predicatorv3/approaches/agents.yaml index 9e9d82d8a..946a30713 100644 --- a/scripts/configs/predicatorv3/approaches/agents.yaml +++ b/scripts/configs/predicatorv3/approaches/agents.yaml @@ -2,7 +2,7 @@ APPROACHES: # agent_planner: # NAME: "agent_planner" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True @@ -15,7 +15,7 @@ APPROACHES: agent_bilevel: NAME: "agent_bilevel" FLAGS: - explorer: "agent" + explorer: "agent_plan" demonstrator: "oracle_process_planning" terminate_on_goal_reached_and_option_terminated: True agent_sdk_use_local_sandbox: True @@ -30,7 +30,7 @@ APPROACHES: # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # option_learner: "agent" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True diff --git a/scripts/configs/predicatorv3/predicator_v3.yaml b/scripts/configs/predicatorv3/predicator_v3.yaml index 29f0a5398..9678225af 100644 --- a/scripts/configs/predicatorv3/predicator_v3.yaml +++ b/scripts/configs/predicatorv3/predicator_v3.yaml @@ -18,7 +18,7 @@ APPROACHES: # agent_planner: # NAME: "agent_planner" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # # agent_sdk_use_docker_sandbox: True @@ -32,7 +32,7 @@ APPROACHES: # agent_bilevel: # NAME: "agent_bilevel" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # # agent_sdk_use_docker_sandbox: True @@ -46,7 +46,7 @@ APPROACHES: # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # option_learner: "agent" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True @@ -60,7 +60,7 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # bilevel_plan_without_sim: True # max_initial_demos: 0 - # explorer: "agent" + # explorer: "agent_plan" # num_online_learning_cycles: 4 # online_nsrt_learning_requests_per_cycle: 1 ENVS: From 4076abd2201cb9eeea49f7498cff2d5a1367a022 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 14 Apr 2026 17:50:10 +0100 Subject: [PATCH 19/70] Move AgentSessionMixin into agent_sdk package The mixin is pure agent-session plumbing (session creation, lifecycle, explorer factory) and has no approach-specific logic, so it belongs next to session_manager.py, tools.py, and the sandbox managers rather than in approaches/. --- .../{approaches => agent_sdk}/agent_session_mixin.py | 9 +++++++-- .../approaches/agent_abstraction_learning_approach.py | 2 +- predicators/approaches/agent_planner_approach.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) rename predicators/{approaches => agent_sdk}/agent_session_mixin.py (96%) diff --git a/predicators/approaches/agent_session_mixin.py b/predicators/agent_sdk/agent_session_mixin.py similarity index 96% rename from predicators/approaches/agent_session_mixin.py rename to predicators/agent_sdk/agent_session_mixin.py index f3578db5a..1f518e356 100644 --- a/predicators/approaches/agent_session_mixin.py +++ b/predicators/agent_sdk/agent_session_mixin.py @@ -128,12 +128,16 @@ def _ensure_agent_session(self) -> None: tools=tools, ) + extra_names = [ + getattr(t, "name", "") for t in + self._tool_context.extra_mcp_tools] self._agent_session = AgentSessionManager( system_prompt=self._get_agent_system_prompt(), mcp_server=mcp_server, log_dir=self._get_log_dir(), model_name=CFG.agent_sdk_model_name, - allowed_tools=get_allowed_tool_list(tool_names), + allowed_tools=get_allowed_tool_list( + tool_names, extra_names=extra_names or None), ) if self._agent_session_id is not None: @@ -186,11 +190,12 @@ def _create_agent_explorer( self, predicates: Set[Predicate], options: Set[ParameterizedOption], + name: str = "agent_plan", ) -> BaseExplorer: """Create an agent explorer with tool_context and agent_session.""" self._ensure_agent_session() return create_explorer( - "agent_plan", + name, predicates, options, self._types, # type: ignore[attr-defined] diff --git a/predicators/approaches/agent_abstraction_learning_approach.py b/predicators/approaches/agent_abstraction_learning_approach.py index f0df3c966..96e4ab11f 100644 --- a/predicators/approaches/agent_abstraction_learning_approach.py +++ b/predicators/approaches/agent_abstraction_learning_approach.py @@ -16,7 +16,7 @@ from predicators.agent_sdk.proposal_parser import ProposalBundle, \ build_exec_context, exec_code_safely from predicators.approaches.agent_planner_approach import AgentPlannerApproach -from predicators.approaches.agent_session_mixin import AgentSessionMixin +from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.approaches.pp_online_process_learning_approach import \ OnlineProcessLearningAndPlanningApproach from predicators.approaches.pp_predicate_invention_approach import \ diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index cf918adc2..1e8c5d5c1 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -23,7 +23,7 @@ from predicators import utils from predicators.approaches import ApproachFailure -from predicators.approaches.agent_session_mixin import AgentSessionMixin +from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.approaches.base_approach import BaseApproach from predicators.explorers import create_explorer from predicators.explorers.base_explorer import BaseExplorer From b26429153d94828fd75c4e7cafbd25aeb99b2a24 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Tue, 14 Apr 2026 20:58:37 +0100 Subject: [PATCH 20/70] Add AgentBilevelExplorer for sim-learning experiments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The explorer asks a Claude agent for a plan sketch, refines it against the approach's current (possibly learned) option model, and rolls the refined plan out in the real env. When the mental model disagrees with reality — e.g. the sketch expects JugFilled after a Wait but the mental model's process dynamics can't produce it — the explorer truncates the plan at the deepest unsatisfiable subgoal (inclusive) so the real-env rollout ends exactly where the disagreement occurs, maximising signal per experiment. Key pieces: - predicators/agent_sdk/bilevel_sketch.py: extracted the sketch build / parse / refine helpers from AgentBilevelApproach as module-level functions so both the approach (solve path) and the new explorer (exploration path) can share them. refine_sketch gains truncate_on_subgoal_fail: the on_step_fail callback snapshots the deepest subgoal failure seen during backtracking, and on exhaustion the captured prefix is returned as the experiment plan. - predicators/explorers/agent_bilevel_explorer.py: new explorer. Reads option_model from tool_context (synced by the approach), builds the sketch prompt via bilevel_sketch, runs refine_sketch with check_subgoals=True, check_final_goal=False, truncate_on_subgoal_fail =True, wraps the result in an option_plan_to_policy that converts OptionExecutionFailure into RequestActPolicyFailure so the episode cleanly terminates at the point of real-env divergence. Stashes the sketch subgoals/options on ToolContext for downstream diffing by the learning approach. - predicators/approaches/agent_bilevel_approach.py: shim methods over bilevel_sketch; behaviour unchanged. - predicators/approaches/agent_planner_approach.py: _create_explorer dispatches both "agent_plan" and "agent_bilevel" through the agent factory path and forwards CFG.explorer as the name. - predicators/explorers/__init__.py: factory branch merged for the two agent-session-backed explorers. - predicators/agent_sdk/tools.py: ToolContext gains last_sketch_subgoals / last_sketch_options fields, populated by the explorer and marked TODO for the learning approach to consume. - tests/explorers/test_agent_bilevel_explorer.py: happy-path, fallback, wait-memory-injection, and deepest-subgoal-failure truncation tests. --- predicators/agent_sdk/bilevel_sketch.py | 427 ++++++++++++++++++ predicators/agent_sdk/tools.py | 5 + .../approaches/agent_bilevel_approach.py | 347 ++------------ .../approaches/agent_planner_approach.py | 9 +- predicators/explorers/__init__.py | 2 +- .../explorers/agent_bilevel_explorer.py | 223 +++++++++ predicators/explorers/agent_plan_explorer.py | 6 +- .../explorers/test_agent_bilevel_explorer.py | 330 ++++++++++++++ 8 files changed, 1037 insertions(+), 312 deletions(-) create mode 100644 predicators/agent_sdk/bilevel_sketch.py create mode 100644 predicators/explorers/agent_bilevel_explorer.py create mode 100644 tests/explorers/test_agent_bilevel_explorer.py diff --git a/predicators/agent_sdk/bilevel_sketch.py b/predicators/agent_sdk/bilevel_sketch.py new file mode 100644 index 000000000..f088ee0b5 --- /dev/null +++ b/predicators/agent_sdk/bilevel_sketch.py @@ -0,0 +1,427 @@ +"""Shared helpers for bilevel plan-sketch construction and refinement. + +Extracted from ``AgentBilevelApproach`` so both the approach (at solve +time) and ``AgentBilevelExplorer`` (at exploration time) can build plan +sketches, parse subgoal annotations, and run backtracking refinement +against an arbitrary ``_OptionModelBase``. + +The helpers are pure module-level functions — they take their +dependencies (option_model, predicates, rng, settings) explicitly so +neither approaches nor explorers need to subclass one another. +""" +import dataclasses +import logging +import re +from typing import Callable, List, Optional, Sequence, Set, Tuple, cast + +import numpy as np + +from predicators import utils +from predicators.option_model import _OptionModelBase +from predicators.planning import run_backtracking_refinement +from predicators.structs import GroundAtom, Object, ParameterizedOption, \ + Predicate, State, Task, Type, _Option + + +@dataclasses.dataclass +class SketchStep: + """One step in an agent-produced plan sketch. + + ``subgoal_atoms`` / ``subgoal_neg_atoms`` are optional: ``None`` + means "no subgoal constraint at this step"; an empty set means "the + annotation was present but contained no atoms of that polarity". + """ + option: ParameterizedOption + objects: Sequence[Object] + subgoal_atoms: Optional[Set[GroundAtom]] + subgoal_neg_atoms: Optional[Set[GroundAtom]] = None + + +def strip_code_fences(text: str) -> str: + """Strip markdown code fences wrapping plan text.""" + lines = text.split('\n') + while lines and lines[0].strip().startswith('```'): + lines.pop(0) + while lines and lines[-1].strip().startswith('```'): + lines.pop() + return '\n'.join(lines) + + +def sample_params(option: ParameterizedOption, + rng: np.random.Generator) -> np.ndarray: + """Sample continuous parameters uniformly from the option's box.""" + if option.params_space.shape[0] == 0: + return np.array([], dtype=np.float32) + low = option.params_space.low + high = option.params_space.high + return rng.uniform(low, high).astype(np.float32) + + +def build_solve_prompt( + task: Task, + *, + all_predicates: Set[Predicate], + all_options: Set[ParameterizedOption], + trajectory_summary: str = "", + tool_names: Optional[Sequence[str]] = None, +) -> str: + """Build the bilevel solve/explore prompt asking for a plan sketch. + + Mirrors ``AgentBilevelApproach._build_solve_prompt`` but takes + dependencies explicitly so explorers can reuse it. + """ + init_state = task.init + objects = list(init_state) + + obj_strs = [] + for obj in sorted(objects, key=lambda o: o.name): + obj_strs.append(f" {obj.name}: {obj.type.name}") + + goal_strs = [str(a) for a in sorted(task.goal, key=str)] + + option_strs = [] + for opt in sorted(all_options, key=lambda o: o.name): + type_sig = ", ".join(t.name for t in opt.types) + params_dim = opt.params_space.shape[0] + if params_dim > 0: + low = opt.params_space.low.tolist() + high = opt.params_space.high.tolist() + if opt.params_description: + desc = ", ".join(opt.params_description) + param_info = (f" [auto-searched params: {desc}, " + f"range {low} to {high}]") + else: + param_info = (f" [auto-searched: {params_dim}d, " + f"range {low} to {high}]") + else: + param_info = "" + option_strs.append(f" {opt.name}({type_sig}){param_info}") + + atoms = utils.abstract(init_state, all_predicates) + atom_strs = [str(a) for a in sorted(atoms, key=str)] + + state_str = init_state.dict_str(indent=2) + + tools_str = "" + if tool_names: + tool_list = "\n".join(f" - {t}" for t in tool_names) + tools_str = f"\n## Available Tools\n{tool_list}\n" + + goal_nl_section = "" + if task.goal_nl: + goal_nl_section = f"\n## Goal Description\n{task.goal_nl}\n" + + pred_strs = [] + for pred in sorted(all_predicates, key=lambda p: p.name): + type_sig = ", ".join(t.name for t in pred.types) + pred_strs.append(f" {pred.name}({type_sig})") + + prompt = f"""You are solving a task. \ +Generate a plan sketch to achieve the goal. +{goal_nl_section} +## Goal Atoms +{chr(10).join(goal_strs)} + +## Initial State Atoms +{chr(10).join(atom_strs)} + +## Initial State Features +{state_str} + +## Objects +{chr(10).join(obj_strs)} + +## Available Options +{chr(10).join(option_strs)} + +## Available Predicates (for subgoal annotations) +{chr(10).join(pred_strs)} +{trajectory_summary}{tools_str} +## Instructions +Use your available tools to inspect the environment before producing the plan. + +Generate a plan SKETCH — the sequence of options with object arguments, but \ +WITHOUT continuous parameters. Continuous parameters will be found \ +automatically by a backtracking search procedure. + +Optionally annotate subgoal atoms that should hold after each step. This \ +helps the search verify progress. Use `-> {{atoms}}` after each step. + +After any action whose desired subgoal depends on a delayed process (e.g. \ +water filling, dominoes cascading, heating), insert a Wait action. For Wait \ +steps, annotate with the atoms the process should produce — this tells the \ +system exactly when the Wait should end rather than terminating on any \ +incidental atom change. Use `NOT Pred(...)` for atoms that should become false. + +Output the plan sketch with one option per line in this format: + OptionName(obj1:type1, obj2:type2) -> \ +{{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} + Wait(robot:Robot) -> {{Boiled(water:water_type)}} + Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} + +Always use typed references (obj:type) in both option arguments AND subgoal \ +atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ +only check that the option executed successfully (non-zero actions). + +Output ONLY the plan sketch lines at the end, after any analysis.""" + + return prompt + + +def parse_subgoal_annotations( + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + option_names: Set[str], +) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. + + Returns a list parallel to the option lines in ``text``. Each entry + is ``None`` for a line with no annotation, or + ``(positive_atoms, negative_atoms)`` otherwise. + """ + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + + for line in text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + + sg_match = subgoal_re.search(stripped) + if not sg_match: + results.append(None) + continue + + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] + for n in atom_match.group(3).split(',') + ] + + if pred_name not in pred_map: + logging.warning(f"Unknown predicate in subgoal: {pred_name}") + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError as e: + logging.warning(f"Unknown object in subgoal: {e}") + continue + if len(objs) != len(pred.types): + logging.warning( + f"Arity mismatch for {pred_name}: expected " + f"{len(pred.types)}, got {len(objs)}") + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + + if pos_atoms or neg_atoms: + results.append((pos_atoms, neg_atoms)) + else: + results.append(None) + + return results + + +def parse_sketch_from_text( + plan_text: str, + task: Task, + *, + predicates: Set[Predicate], + options: Set[ParameterizedOption], + types: Set[Type], +) -> List[SketchStep]: + """Parse plan-sketch text into ``SketchStep``s. + + Applies ``strip_code_fences`` first, then delegates option-plan + parsing to ``utils.parse_model_output_into_option_plan`` and subgoal + annotation parsing to ``parse_subgoal_annotations``. + """ + cleaned_text = strip_code_fences(plan_text) + objects = list(task.init) + option_names = {o.name for o in options} + + parsed = utils.parse_model_output_into_option_plan( + cleaned_text, + objects, + types, + options, + parse_continuous_params=False) + + if not parsed: + return [] + + subgoals = parse_subgoal_annotations(cleaned_text, predicates, objects, + option_names) + + sketch: List[SketchStep] = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + SketchStep(option=option, + objects=objs, + subgoal_atoms=None)) + return sketch + + +def refine_sketch( + task: Task, + sketch: List[SketchStep], + option_model: _OptionModelBase, + *, + predicates: Set[Predicate], + timeout: float, + rng: np.random.Generator, + max_samples_per_step: int, + check_subgoals: bool, + check_final_goal: bool = True, + truncate_on_subgoal_fail: bool = False, + log_state: bool = False, + run_id: str = "bilevel", + on_step_fail: Optional[Callable[[int, List[Optional[_Option]], str], + None]] = None, +) -> Tuple[List[_Option], bool, int]: + """Backtracking search over continuous parameters for a plan sketch. + + Returns ``(refined_plan, success, total_samples)``. On success the + plan is fully refined; on failure it is the longest prefix of + refined options (``None`` entries dropped). + + ``check_subgoals`` gates per-step subgoal-atom validation. + ``check_final_goal`` gates the task-goal check on the final step. + ``truncate_on_subgoal_fail`` (explorer mode) lets backtracking run + to exhaustion with subgoal checks enabled, then — if the search + fails — returns the consistent plan prefix captured at the deepest + subgoal failure seen during backtracking (inclusive of the failing + step). Use this to build *experiment* plans that probe a single + mental-model disagreement: upstream steps get their standard + backtracking retries, but once the deepest unresolvable subgoal is + identified, subsequent sketch steps are dropped (they would be + built on a false mental-model state). + + Wait steps inject ``wait_target_atoms`` / ``wait_target_neg_atoms`` + from the sketch's subgoal annotations into ``grounded.memory`` so + that ``WaitOption`` terminates on the intended atom change rather + than the first incidental one. + """ + if not sketch: + return [], False, 0 + + n = len(sketch) + max_tries = [ + max_samples_per_step if step.option.params_space.shape[0] > 0 else 1 + for step in sketch + ] + # Snapshot of the deepest subgoal failure seen during backtracking. + # Tracks (idx, plan_prefix_snapshot). Updated whenever on_step_fail + # reports a subgoal failure at a strictly deeper index than before. + # The snapshot is taken at the moment of failure, so it is a + # *consistent* trajectory: run_backtracking_refinement has already + # written plan[idx] for that attempt and the prefix plan[:idx+1] + # reflects the exact grounded options that led to this failure. + deepest_subgoal_fail_idx: List[int] = [-1] + deepest_subgoal_fail_prefix: List[List[Optional[_Option]]] = [[]] + + def sample_fn(idx: int, state: State, + rng_: np.random.Generator) -> _Option: + step = sketch[idx] + if log_state: + step_name = (f"{step.option.name}" + f"({', '.join(o.name for o in step.objects)})") + logging.debug(f"[{run_id}] State before {step_name}:\n" + f"{state.pretty_str()}") + params = sample_params(step.option, rng_) + grounded = step.option.ground(list(step.objects), params) + if grounded.name == "Wait": + if step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + if step.subgoal_neg_atoms is not None: + grounded.memory["wait_target_neg_atoms"] = \ + step.subgoal_neg_atoms + return grounded + + def validate_fn(idx: int, _pre_state: State, _option: _Option, + post_state: State, + _num_actions: int) -> Tuple[bool, str]: + step = sketch[idx] + if check_subgoals and step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + missing = step.subgoal_atoms - current_atoms + return False, (f"subgoal missing: " + f"{{{', '.join(str(a) for a in missing)}}}") + if check_final_goal and idx == n - 1: + if not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + def wrapped_on_step_fail(idx: int, cur_plan: List[Optional[_Option]], + fail_reason: str) -> None: + # run_backtracking_refinement calls this BEFORE clearing + # plan[idx] (planning.py lines 592-599), so cur_plan[0..idx] is + # still populated with the grounded options that produced this + # exact failure trajectory. Record the deepest subgoal failure + # seen so far along with a consistent snapshot of the prefix. + if (truncate_on_subgoal_fail + and fail_reason.startswith("subgoal missing") + and idx > deepest_subgoal_fail_idx[0]): + deepest_subgoal_fail_idx[0] = idx + deepest_subgoal_fail_prefix[0] = list(cur_plan[:idx + 1]) + if on_step_fail is not None: + on_step_fail(idx, cur_plan, fail_reason) + + plan, success, total_samples = run_backtracking_refinement( + init_state=task.init, + option_model=option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + on_step_fail=wrapped_on_step_fail, + ) + + logging.info( + f"[{run_id}] Refinement {'succeeded' if success else 'failed'}: " + f"{total_samples} samples for {n} steps.") + + if (truncate_on_subgoal_fail and not success + and deepest_subgoal_fail_idx[0] >= 0): + snapshot = deepest_subgoal_fail_prefix[0] + refined = [p for p in snapshot if p is not None] + logging.info( + f"[{run_id}] Truncating at deepest subgoal failure " + f"(step {deepest_subgoal_fail_idx[0]}): " + f"{len(refined)}/{n} steps in experiment plan.") + return cast(List[_Option], refined), False, total_samples + + refined = [p for p in plan if p is not None] + if success: + return cast(List[_Option], refined), True, total_samples + return refined, False, total_samples diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index bb5f98c32..583a537c3 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -114,6 +114,11 @@ class ToolContext: turn_id: int = 0 # current query/turn within the session test_call_id: int = 0 # incremented per test_option_plan call visualized_state: Optional[State] = None # last state from visualize_state + # Populated by AgentBilevelExplorer so learning approaches can diff + # mental-model subgoals against real trajectories. + # TODO(sim-learning): consume these in learn_from_interaction_results. + last_sketch_subgoals: Optional[Any] = None + last_sketch_options: Optional[Any] = None def _text_result(text: str) -> Dict[str, Any]: diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 3b75f082f..6461bea60 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -12,15 +12,15 @@ --num_train_tasks 1 --num_test_tasks 1 \ --num_online_learning_cycles 1 --explorer agent_plan """ -import dataclasses import logging -import re import time -from typing import Callable, List, Optional, Sequence, Set, Tuple, cast +from typing import Callable, List, Optional, Sequence, Set, Tuple import numpy as np from predicators import utils +from predicators.agent_sdk import bilevel_sketch +from predicators.agent_sdk.bilevel_sketch import SketchStep as _SketchStep from predicators.approaches import ApproachFailure from predicators.approaches.agent_planner_approach import AgentPlannerApproach from predicators.planning import run_backtracking_refinement @@ -29,16 +29,6 @@ ParameterizedOption, Predicate, State, Task, _Option -@dataclasses.dataclass -class _SketchStep: - """One step in an agent-produced plan sketch.""" - option: ParameterizedOption - objects: Sequence[Object] - subgoal_atoms: Optional[Set[GroundAtom]] # None = no subgoal constraint - # Atoms that must be FALSE after this step. - subgoal_neg_atoms: Optional[Set[GroundAtom]] = None - - class AgentBilevelApproach(AgentPlannerApproach): """Bilevel planning: agent proposes discrete skeleton, search refines continuous parameters. @@ -90,114 +80,13 @@ def _get_agent_system_prompt(self) -> str: def _build_solve_prompt(self, task: Task) -> str: """Build prompt asking for a plan sketch without continuous params.""" - init_state = task.init - objects = list(init_state) - - # Objects - obj_strs = [] - for obj in sorted(objects, key=lambda o: o.name): - obj_strs.append(f" {obj.name}: {obj.type.name}") - - # Goal - goal_strs = [str(a) for a in sorted(task.goal, key=str)] - - # Options (show params_space info so agent understands what's tunable) - option_strs = [] - for opt in sorted(self._get_all_options(), key=lambda o: o.name): - type_sig = ", ".join(t.name for t in opt.types) - params_dim = opt.params_space.shape[0] - if params_dim > 0: - low = opt.params_space.low.tolist() - high = opt.params_space.high.tolist() - if opt.params_description: - desc = ", ".join(opt.params_description) - param_info = (f" [auto-searched params: {desc}, " - f"range {low} to {high}]") - else: - param_info = (f" [auto-searched: {params_dim}d, " - f"range {low} to {high}]") - else: - param_info = "" - option_strs.append(f" {opt.name}({type_sig}){param_info}") - - # Current atoms - atoms = utils.abstract(init_state, self._get_all_predicates()) - atom_strs = [str(a) for a in sorted(atoms, key=str)] - - # Trajectory summary - traj_summary = self._build_trajectory_summary() - - # State features - state_str = init_state.dict_str(indent=2) - - # Available tools - tool_names = self._get_agent_tool_names() - tools_str = "" - if tool_names: - tool_list = "\n".join(f" - {t}" for t in tool_names) - tools_str = f"\n## Available Tools\n{tool_list}\n" - - # Natural language goal - goal_nl_section = "" - if task.goal_nl: - goal_nl_section = f"\n## Goal Description\n{task.goal_nl}\n" - - # Available predicates for subgoal annotations - pred_strs = [] - for pred in sorted(self._get_all_predicates(), key=lambda p: p.name): - type_sig = ", ".join(t.name for t in pred.types) - pred_strs.append(f" {pred.name}({type_sig})") - - prompt = f"""You are solving a task. \ -Generate a plan sketch to achieve the goal. -{goal_nl_section} -## Goal Atoms -{chr(10).join(goal_strs)} - -## Initial State Atoms -{chr(10).join(atom_strs)} - -## Initial State Features -{state_str} - -## Objects -{chr(10).join(obj_strs)} - -## Available Options -{chr(10).join(option_strs)} - -## Available Predicates (for subgoal annotations) -{chr(10).join(pred_strs)} -{traj_summary}{tools_str} -## Instructions -Use your available tools to inspect the environment before producing the plan. - -Generate a plan SKETCH — the sequence of options with object arguments, but \ -WITHOUT continuous parameters. Continuous parameters will be found \ -automatically by a backtracking search procedure. - -Optionally annotate subgoal atoms that should hold after each step. This \ -helps the search verify progress. Use `-> {{atoms}}` after each step. - -After any action whose desired subgoal depends on a delayed process (e.g. \ -water filling, dominoes cascading, heating), insert a Wait action. For Wait \ -steps, annotate with the atoms the process should produce — this tells the \ -system exactly when the Wait should end rather than terminating on any \ -incidental atom change. Use `NOT Pred(...)` for atoms that should become false. - -Output the plan sketch with one option per line in this format: - OptionName(obj1:type1, obj2:type2) -> \ -{{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} - Wait(robot:Robot) -> {{Boiled(water:water_type)}} - Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} - -Always use typed references (obj:type) in both option arguments AND subgoal \ -atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ -only check that the option executed successfully (non-zero actions). - -Output ONLY the plan sketch lines at the end, after any analysis.""" - - return prompt + return bilevel_sketch.build_solve_prompt( + task, + all_predicates=self._get_all_predicates(), + all_options=self._get_all_options(), + trajectory_summary=self._build_trajectory_summary(), + tool_names=self._get_agent_tool_names(), + ) # ------------------------------------------------------------------ # # Solving @@ -274,129 +163,26 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: if not plan_text: raise ApproachFailure("Agent returned empty plan text.") - cleaned_text = self._strip_code_fences(plan_text) - - # Phase 1: parse options + objects (no continuous params) - objects = list(task.init) - parsed = utils.parse_model_output_into_option_plan( - cleaned_text, - objects, - self._types, - self._get_all_options(), - parse_continuous_params=False) + sketch = bilevel_sketch.parse_sketch_from_text( + plan_text, + task, + predicates=self._get_all_predicates(), + options=self._get_all_options(), + types=self._types, + ) - if not parsed: + if not sketch: option_names = sorted(o.name for o in self._get_all_options()) raise ApproachFailure(f"Parsed empty plan sketch from agent.\n" f" Plan text:\n{plan_text}\n" f" Available option names: {option_names}") - # Phase 2: parse subgoal annotations from raw text - subgoals = self._parse_subgoal_annotations(cleaned_text, - self._get_all_predicates(), - objects) - - # Zip into sketch steps - sketch = [] - for i, (option, objs, _) in enumerate(parsed): - sg = subgoals[i] if i < len(subgoals) else None - if sg is not None: - pos, neg = sg - sketch.append( - _SketchStep(option=option, - objects=objs, - subgoal_atoms=pos if pos else None, - subgoal_neg_atoms=neg if neg else None)) - else: - sketch.append( - _SketchStep(option=option, - objects=objs, - subgoal_atoms=None)) - logging.info(f"[{self._run_id}] Agent produced sketch with " f"{len(sketch)} steps, " f"{sum(1 for s in sketch if s.subgoal_atoms)} " f"with subgoals.") return sketch - def _parse_subgoal_annotations( - self, - text: str, - predicates: Set[Predicate], - objects: Sequence[Object], - ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: - """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. - - Returns a list parallel to the option lines. Entries are None - for lines without annotations. Each non-None entry is - ``(positive_atoms, negative_atoms)``. - """ - pred_map = {p.name: p for p in predicates} - obj_map = {o.name: o for o in objects} - - # Regex: match -> { ... } after the option line - subgoal_re = re.compile(r'->\s*\{([^}]*)\}') - # Regex: match individual atoms, optionally prefixed with NOT - atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') - - results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] - option_names = {o.name for o in self._get_all_options()} - - for line in text.split('\n'): - stripped = line.strip() - if not stripped: - continue - # Check if this line starts with a valid option name - first_token = stripped.split('(')[0] - if first_token not in option_names: - continue - - # This is an option line — check for subgoal annotation - sg_match = subgoal_re.search(stripped) - if not sg_match: - results.append(None) - continue - - atoms_text = sg_match.group(1) - pos_atoms: Set[GroundAtom] = set() - neg_atoms: Set[GroundAtom] = set() - for atom_match in atom_re.finditer(atoms_text): - is_neg = atom_match.group(1) is not None - pred_name = atom_match.group(2) - # Handle both "obj" and "obj:type" formats - obj_names = [ - n.strip().split(':')[0] - for n in atom_match.group(3).split(',') - ] - - if pred_name not in pred_map: - logging.warning(f"Unknown predicate in subgoal: " - f"{pred_name}") - continue - pred = pred_map[pred_name] - try: - objs = [obj_map[n] for n in obj_names] - except KeyError as e: - logging.warning(f"Unknown object in subgoal: {e}") - continue - if len(objs) != len(pred.types): - logging.warning( - f"Arity mismatch for {pred_name}: expected " - f"{len(pred.types)}, got {len(objs)}") - continue - atom = GroundAtom(pred, objs) - if is_neg: - neg_atoms.add(atom) - else: - pos_atoms.add(atom) - - if pos_atoms or neg_atoms: - results.append((pos_atoms, neg_atoms)) - else: - results.append(None) - - return results - # ------------------------------------------------------------------ # # Backtracking refinement # ------------------------------------------------------------------ # @@ -413,86 +199,37 @@ def _refine_sketch( grounded options that achieves the task goal. On failure, ``plan`` is the longest partial refinement found. - Delegates to ``run_backtracking_refinement`` for the core loop. + Delegates to ``bilevel_sketch.refine_sketch``. """ - if not sketch: - return [], False - - rng = np.random.default_rng(CFG.seed) - max_samples = CFG.agent_bilevel_max_samples_per_step - check_subgoals = CFG.agent_bilevel_check_subgoals - n = len(sketch) - max_tries = [ - max_samples if step.option.params_space.shape[0] > 0 else 1 - for step in sketch - ] - predicates = self._get_all_predicates() - - def sample_fn(idx: int, state: State, - rng_: np.random.Generator) -> _Option: - step = sketch[idx] - if CFG.agent_bilevel_log_state: - step_name = (f"{step.option.name}" - f"({', '.join(o.name for o in step.objects)})") - logging.debug(f" State before {step_name}:\n" - f"{state.pretty_str()}") - params = self._sample_params(step.option, state, rng_) - grounded = step.option.ground(step.objects, params) - if grounded.name == "Wait": - if step.subgoal_atoms is not None: - grounded.memory["wait_target_atoms"] = \ - step.subgoal_atoms - if step.subgoal_neg_atoms is not None: - grounded.memory["wait_target_neg_atoms"] = \ - step.subgoal_neg_atoms - return grounded - - def validate_fn(idx: int, _pre_state: State, _option: _Option, - post_state: State, - _num_actions: int) -> Tuple[bool, str]: - step = sketch[idx] - if check_subgoals and step.subgoal_atoms is not None: - current_atoms = utils.abstract(post_state, predicates) - if not step.subgoal_atoms.issubset(current_atoms): - missing = step.subgoal_atoms - current_atoms - return False, (f"subgoal missing: " - f"{{{', '.join(str(a) for a in missing)}}}") - if idx == n - 1: - if not task.goal_holds(post_state): - return False, "goal not reached" - return True, "" - - plan, success, total_samples = run_backtracking_refinement( - init_state=task.init, - option_model=self._option_model, - n_steps=n, - max_tries=max_tries, - sample_fn=sample_fn, - validate_fn=validate_fn, - rng=rng, + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + self._option_model, + predicates=self._get_all_predicates(), timeout=timeout, + rng=np.random.default_rng(CFG.seed), + max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, + check_subgoals=CFG.agent_bilevel_check_subgoals, + log_state=CFG.agent_bilevel_log_state, + run_id=self._run_id, ) - - logging.info(f"Refinement {'succeeded' if success else 'failed'}: " - f"{total_samples} samples for {n} steps.") - - filtered = [p for p in plan if p is not None] - if success: - return cast(List[_Option], filtered), True - return filtered, False + return plan, success def _sample_params(self, option: ParameterizedOption, _state: State, rng: np.random.Generator) -> np.ndarray: - """Sample continuous parameters for an option. + """Sample continuous parameters for an option.""" + return bilevel_sketch.sample_params(option, rng) - Currently uniform random; hook point for future learned - samplers. - """ - if option.params_space.shape[0] == 0: - return np.array([], dtype=np.float32) - low = option.params_space.low - high = option.params_space.high - return rng.uniform(low, high).astype(np.float32) + def _parse_subgoal_annotations( + self, + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Shim over ``bilevel_sketch.parse_subgoal_annotations``.""" + option_names = {o.name for o in self._get_all_options()} + return bilevel_sketch.parse_subgoal_annotations( + text, predicates, objects, option_names) # ------------------------------------------------------------------ # # Forward validation diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 1e8c5d5c1..88d4a4698 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -705,10 +705,13 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: def _create_explorer(self) -> BaseExplorer: """Create explorer for interaction requests.""" - if CFG.explorer == "agent_plan": + if CFG.explorer in ("agent_plan", "agent_bilevel"): self._sync_tool_context() - return self._create_agent_explorer(self._get_all_predicates(), - self._get_all_options()) + return self._create_agent_explorer( + self._get_all_predicates(), + self._get_all_options(), + name=CFG.explorer, + ) return create_explorer( CFG.explorer, self._get_all_predicates(), diff --git a/predicators/explorers/__init__.py b/predicators/explorers/__init__.py index 191a39cf9..644138648 100644 --- a/predicators/explorers/__init__.py +++ b/predicators/explorers/__init__.py @@ -109,7 +109,7 @@ def create_explorer( action_space, train_tasks, max_steps_before_termination, nsrts, maple_q_function) - elif name == "agent_plan": + elif name in ("agent_plan", "agent_bilevel"): assert tool_context is not None assert agent_session is not None explorer = cls(initial_predicates, initial_options, types, diff --git a/predicators/explorers/agent_bilevel_explorer.py b/predicators/explorers/agent_bilevel_explorer.py new file mode 100644 index 000000000..0b2adf8e6 --- /dev/null +++ b/predicators/explorers/agent_bilevel_explorer.py @@ -0,0 +1,223 @@ +"""Agent bilevel explorer: sketch → refine against mental model → execute real. + +Produces a plan *sketch* via a Claude agent, runs backtracking refinement +against the approach's currently-learned option model (read from +``tool_context.option_model``), then rolls the refined plan out in the +real environment. When the mental model disagrees with reality (e.g. a +subgoal atom the mental model expected after a Wait doesn't actually +hold), the resulting trajectory provides a targeted learning signal for +online simulator synthesis. + +Parallels ``AgentPlanExplorer`` for session plumbing and +``AgentBilevelApproach`` for the sketch/refine workflow. +""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +import numpy as np +from gym.spaces import Box + +from predicators import utils +from predicators.agent_sdk import bilevel_sketch +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync +from predicators.agent_sdk.tools import ToolContext +from predicators.explorers.base_explorer import BaseExplorer +from predicators.settings import CFG +from predicators.structs import Action, ExplorationStrategy, \ + ParameterizedOption, Predicate, State, Task, Type + + +class AgentBilevelExplorer(BaseExplorer): + """Queries a Claude agent for a plan sketch, refines it, and executes.""" + + def __init__(self, predicates: Set[Predicate], + options: Set[ParameterizedOption], types: Set[Type], + action_space: Box, train_tasks: List[Task], + max_steps_before_termination: int, tool_context: ToolContext, + agent_session: AgentSessionManager) -> None: + super().__init__(predicates, options, types, action_space, train_tasks, + max_steps_before_termination) + self._tool_context = tool_context + self._agent_session = agent_session + + @classmethod + def get_name(cls) -> str: + return "agent_bilevel" + + # ------------------------------------------------------------------ # + # Exploration strategy + # ------------------------------------------------------------------ # + + def _get_exploration_strategy(self, train_task_idx: int, + timeout: int) -> ExplorationStrategy: + task = self._train_tasks[train_task_idx] + # The approach syncs tool_context.option_model right before + # constructing this explorer, so reading here picks up the most + # recently learned model. + option_model = self._tool_context.option_model + assert option_model is not None, \ + "agent_bilevel explorer needs a synced option_model" + + try: + prompt = bilevel_sketch.build_solve_prompt( + task, + all_predicates=self._predicates, + all_options=self._options, + trajectory_summary=self._build_trajectory_summary(), + tool_names=self._agent_tool_names(), + ) + responses = run_query_sync(self._agent_session, prompt) + plan_text = self._extract_option_plan_text(responses) + if not plan_text: + raise ValueError("agent returned empty plan text") + + sketch = bilevel_sketch.parse_sketch_from_text( + plan_text, + task, + predicates=self._predicates, + options=self._options, + types=self._types, + ) + if not sketch: + raise ValueError("parsed empty plan sketch") + + self._tool_context.last_sketch_subgoals = [ + (s.subgoal_atoms, s.subgoal_neg_atoms) for s in sketch + ] + self._tool_context.last_sketch_options = [ + (s.option.name, [o.name for o in s.objects]) for s in sketch + ] + + # Explorer mode: keep subgoal validation ON so the mental + # model can tell us which step it can't predict, but when + # that happens, truncate the plan at that step (inclusive) + # instead of backtracking. Steps beyond the first + # disagreement are built on a false mental-model state, so + # executing them in the real env adds noise rather than + # signal. The truncated plan — Pick → ... → first failing + # step — is the experiment we want to run. Final-goal check + # is also off: the explorer isn't trying to solve the task + # in the mental model. + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + option_model, + predicates=self._predicates, + timeout=float(timeout), + rng=np.random.default_rng(CFG.seed), + max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, + check_subgoals=True, + check_final_goal=False, + truncate_on_subgoal_fail=True, + log_state=CFG.agent_bilevel_log_state, + run_id="agent_bilevel_explorer", + ) + logging.info( + f"agent_bilevel explorer: sketch has {len(sketch)} steps, " + f"refined {len(plan)} " + f"({'success' if success else 'partial'}).") + + if plan: + policy = utils.option_plan_to_policy( + plan, + abstract_function=lambda s: utils.abstract( + s, self._predicates)) + return self._wrap_policy(policy), lambda _: False + + logging.info("agent_bilevel explorer: refinement produced zero " + "steps, falling back to random.") + except Exception as e: # pylint: disable=broad-except + logging.warning(f"agent_bilevel explorer failed: {e}. " + "Falling back to random options.") + + if not CFG.agent_explorer_fallback_to_random: + raise utils.RequestActPolicyFailure( + "agent_bilevel explorer failed and fallback disabled.") + return self._random_options_fallback() + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + + def _wrap_policy( + self, policy: Callable[[State], Action] + ) -> Callable[[State], Action]: + """Convert OptionExecutionFailure into RequestActPolicyFailure. + + This lets the main loop cleanly terminate the episode when the + refined plan finishes or fails mid-execution (which is exactly + the disagreement signal we want to collect). + """ + + def _wrapped(state: State) -> Action: + try: + return policy(state) + except utils.OptionExecutionFailure as e: + raise utils.RequestActPolicyFailure(e.args[0], e.info) from e + + return _wrapped + + def _random_options_fallback(self) -> ExplorationStrategy: + """Fall back to random option sampling.""" + + def fallback_policy(state: State) -> Action: + del state + raise utils.RequestActPolicyFailure( + "Random option sampling failed!") + + policy = utils.create_random_option_policy(self._options, self._rng, + fallback_policy) + return policy, lambda _: False + + def _agent_tool_names(self) -> Optional[List[str]]: + """Return tool names exposed by the current session, if any.""" + return getattr(self._agent_session, "tool_names", None) + + def _build_trajectory_summary(self) -> str: + """Summarize trajectory data for the agent.""" + all_trajs = (self._tool_context.offline_trajectories + + self._tool_context.online_trajectories) + if not all_trajs: + return "" + + max_trajs = CFG.agent_sdk_max_trajectories_in_context + recent = all_trajs[-max_trajs:] + lines = [ + f"\n## Trajectory Summary ({len(all_trajs)} total, " + f"showing last {len(recent)})" + ] + + for i, traj in enumerate(recent): + n_steps = len(traj.actions) + init_atoms = utils.abstract(traj.states[0], self._predicates) + final_atoms = utils.abstract(traj.states[-1], self._predicates) + new_atoms = final_atoms - init_atoms + lost_atoms = init_atoms - final_atoms + lines.append(f"\nTrajectory {i}: {n_steps} steps") + if new_atoms: + lines.append( + " Gained: " + + f"{', '.join(str(a) for a in sorted(new_atoms, key=str))}") + if lost_atoms: + lines.append( + " Lost: " + + f"{', '.join(str(a) for a in sorted(lost_atoms, key=str))}" + ) + + return "\n".join(lines) + + def _extract_option_plan_text( + self, responses: List[Dict[str, Any]]) -> str: + """Extract plan text from the last assistant text response.""" + last_text_parts: List[str] = [] + for resp in responses: + if resp.get("type") == "assistant": + parts = [ + block.get("text", "") for block in resp.get("content", []) + if isinstance(block, dict) and block.get("type") == "text" + ] + if parts: + last_text_parts = parts + return "\n".join(last_text_parts) diff --git a/predicators/explorers/agent_plan_explorer.py b/predicators/explorers/agent_plan_explorer.py index 2de8a404a..f693c273f 100644 --- a/predicators/explorers/agent_plan_explorer.py +++ b/predicators/explorers/agent_plan_explorer.py @@ -1,9 +1,9 @@ """Agent plan explorer: Claude agent generates grounded option plans. Produces fully-grounded option plans (including continuous parameters) and -rolls them out in the real environment. The agent is expected to provide -complete parameters itself; this explorer does not run backtracking -refinement against a learned option model. +rolls them out in the real environment. Unlike ``AgentBilevelExplorer``, it +does not run backtracking refinement against a learned option model — the +agent is expected to provide complete parameters itself. """ import logging diff --git a/tests/explorers/test_agent_bilevel_explorer.py b/tests/explorers/test_agent_bilevel_explorer.py new file mode 100644 index 000000000..33a651cad --- /dev/null +++ b/tests/explorers/test_agent_bilevel_explorer.py @@ -0,0 +1,330 @@ +"""Tests for AgentBilevelExplorer.""" +# pylint: disable=protected-access + +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest +from gym.spaces import Box + +from predicators import utils +from predicators.agent_sdk.tools import ToolContext +from predicators.explorers import create_explorer +from predicators.explorers.agent_bilevel_explorer import AgentBilevelExplorer +from predicators.explorers.base_explorer import BaseExplorer +from predicators.structs import Action, GroundAtom, Object, \ + ParameterizedOption, Predicate, State, Task, Type + +# --------------------------------------------------------------------------- +# Fixtures (parallel the bilevel approach tests) +# --------------------------------------------------------------------------- + +_block_type = Type("block", ["x", "y", "held"]) +_robot_type = Type("robot", ["x", "y"]) + +_block0 = Object("block0", _block_type) +_block1 = Object("block1", _block_type) +_robot = Object("robot0", _robot_type) + +_Holding = Predicate("Holding", [_block_type], + lambda s, o: s.get(o[0], "held") > 0.5) +_On = Predicate("On", [_block_type, _block_type], + lambda s, o: abs(s.get(o[0], "x") - s.get(o[1], "x")) < 0.1) +_HandEmpty = Predicate("HandEmpty", [_robot_type], lambda s, o: True) + +_ALL_PREDICATES = {_Holding, _On, _HandEmpty} +_ALL_TYPES = {_block_type, _robot_type} + + +def _noop_policy(_s, _m, _o, _p): + return Action(np.zeros(1, dtype=np.float32)) + + +def _always_true(_s, _m, _o, _p): + return True + + +def _always_false(_s, _m, _o, _p): + return False + + +_Pick = ParameterizedOption( + "Pick", + types=[_block_type], + params_space=Box(low=np.array([0.0], dtype=np.float32), + high=np.array([1.0], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_Place = ParameterizedOption( + "Place", + types=[_block_type, _block_type], + params_space=Box(low=np.array([0.0, 0.0], dtype=np.float32), + high=np.array([1.0, 1.0], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_Wait = ParameterizedOption( + "Wait", + types=[_robot_type], + params_space=Box(low=np.array([], dtype=np.float32), + high=np.array([], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_ALL_OPTIONS = {_Pick, _Place, _Wait} + + +def _make_state(overrides=None): + data = { + _block0: np.array([0.1, 0.2, 0.0], dtype=np.float32), + _block1: np.array([0.5, 0.6, 0.0], dtype=np.float32), + _robot: np.array([0.0, 0.0], dtype=np.float32), + } + if overrides: + for obj, vals in overrides.items(): + data[obj] = np.array(vals, dtype=np.float32) + return State(data) + + +def _make_task(): + state = _make_state() + goal = {GroundAtom(_On, [_block0, _block1])} + return Task(state, goal) + + +def _assistant_response(text: str): + return [{ + "type": "assistant", + "content": [{ + "type": "text", + "text": text + }], + }] + + +def _make_explorer(option_model, query_impl): + """Build an AgentBilevelExplorer with stubbed session + tool_context.""" + tool_context = ToolContext( + types=_ALL_TYPES, + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + train_tasks=[_make_task()], + option_model=option_model, + ) + agent_session = MagicMock() + agent_session.query = query_impl + agent_session.tool_names = None + explorer = AgentBilevelExplorer( + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + types=_ALL_TYPES, + action_space=Box(low=-1, high=1, shape=(1, )), + train_tasks=[_make_task()], + max_steps_before_termination=50, + tool_context=tool_context, + agent_session=agent_session, + ) + return explorer, tool_context + + +def _reset_config(**overrides): + base = { + "env": "cover", + "approach": "agent_bilevel", + "num_train_tasks": 1, + "num_test_tasks": 1, + "seed": 42, + "agent_bilevel_max_samples_per_step": 5, + "agent_bilevel_max_retries": 0, + "agent_bilevel_check_subgoals": True, + "agent_bilevel_log_state": False, + "agent_explorer_fallback_to_random": True, + "agent_sdk_max_trajectories_in_context": 5, + } + base.update(overrides) + utils.reset_config(base) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_factory_registration(): + """AgentBilevelExplorer is reachable through create_explorer.""" + _reset_config() + tool_context = ToolContext( + types=_ALL_TYPES, + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + train_tasks=[_make_task()], + option_model=MagicMock(), + ) + agent_session = MagicMock() + explorer = create_explorer( + "agent_bilevel", + _ALL_PREDICATES, + _ALL_OPTIONS, + _ALL_TYPES, + Box(low=-1, high=1, shape=(1, )), + [_make_task()], + tool_context=tool_context, + agent_session=agent_session, + ) + assert isinstance(explorer, BaseExplorer) + assert isinstance(explorer, AgentBilevelExplorer) + + +def test_happy_path_returns_policy_and_stashes_subgoals(): + """Canned sketch → refined plan → policy and stashed subgoals.""" + _reset_config() + + goal_state = _make_state({_block0: [0.5, 0.6, 0.0]}) + option_model = MagicMock() + option_model.get_next_state_and_num_actions.return_value = (goal_state, 3) + + plan_text = ("Pick(block0:block)\n" + "Place(block0:block, block1:block) -> " + "{On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + + explorer, tool_context = _make_explorer(option_model, query) + policy, term_fn = explorer._get_exploration_strategy(0, timeout=5) + + assert callable(policy) + assert term_fn(_make_state()) is False + assert tool_context.last_sketch_subgoals is not None + assert len(tool_context.last_sketch_subgoals) == 2 + # Second step's positive subgoal should be {On(block0, block1)}. + pos2, _neg2 = tool_context.last_sketch_subgoals[1] + assert pos2 == {GroundAtom(_On, [_block0, _block1])} + assert tool_context.last_sketch_options == [ + ("Pick", ["block0"]), + ("Place", ["block0", "block1"]), + ] + assert query.await_count == 1 + + +def test_wait_memory_injection_on_refine(): + """Wait step with subgoal should have wait_target_atoms injected.""" + _reset_config() + + captured: list = [] + + def side_effect(_state, option): + captured.append(option) + return (_make_state({_block0: [0.5, 0.6, 0.0]}), 3) + + option_model = MagicMock() + option_model.get_next_state_and_num_actions.side_effect = side_effect + + plan_text = ("Wait(robot0:robot) -> {On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + explorer, _ = _make_explorer(option_model, query) + + explorer._get_exploration_strategy(0, timeout=5) + assert captured, "option_model was not invoked" + wait_opt = captured[0] + assert wait_opt.name == "Wait" + assert "wait_target_atoms" in wait_opt.memory + assert wait_opt.memory["wait_target_atoms"] == { + GroundAtom(_On, [_block0, _block1]) + } + + +def test_plan_truncates_at_deepest_subgoal_failure_after_backtracking(): + """Regression: explorer returns the prefix up to (and including) the + deepest step whose subgoal backtracking couldn't satisfy. + + Reproduces the boil-task bug: the agent sketches ``Pick → Wait(Holding) + → Place`` and the mental model's Wait does NOT produce ``Holding``. + Backtracking runs normally — it retries Pick with different params + and re-runs Wait each time — but since the mental model simply can't + produce Holding under any params, Wait's subgoal keeps failing. + After exhaustion, the explorer returns ``[Pick, Wait]`` with the last + grounded attempts. Place is NEVER executed because refinement never + gets past Wait. + """ + _reset_config() + + # Mental model post-state: Holding(block0) NEVER holds (held=0). + no_holding_state = _make_state({_block0: [0.1, 0.2, 0.0]}) + option_model = MagicMock() + option_model.get_next_state_and_num_actions.return_value = ( + no_holding_state, 3) + + plan_text = ("Pick(block0:block)\n" + "Wait(robot0:robot) -> {Holding(block0:block)}\n" + "Place(block0:block, block1:block) -> " + "{On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + explorer, tool_context = _make_explorer(option_model, query) + + policy, _ = explorer._get_exploration_strategy(0, timeout=5) + assert callable(policy) + + # All three sketch steps recorded in metadata — the SKETCH is the full + # agent output; the TRUNCATION only applies to the refined plan. + assert tool_context.last_sketch_options == [ + ("Pick", ["block0"]), + ("Wait", ["robot0"]), + ("Place", ["block0", "block1"]), + ] + + executed_names = [ + call.args[1].name + for call in option_model.get_next_state_and_num_actions.call_args_list + ] + # Pick and Wait were each executed at least once (backtracking likely + # retried Pick multiple times). + assert "Pick" in executed_names + assert "Wait" in executed_names + # Place must NEVER be executed in the mental model: backtracking never + # got past the Wait subgoal failure, so Place never reached sample_fn. + assert "Place" not in executed_names, ( + "Place must not be executed in the mental model — refinement " + f"should have stalled at Wait's unsatisfiable subgoal, got " + f"{executed_names}") + # Pick has params (5 max_samples_per_step in test config), Wait has none. + # Each backtracking cycle runs Pick + Wait once, so we expect roughly + # 2 * max_samples_per_step mental-model calls — confirm backtracking + # actually exercised the upstream retries (at least 2 Picks). + assert executed_names.count("Pick") >= 2, ( + "Backtracking should have retried Pick at least twice before " + f"giving up, got {executed_names}") + + +def test_fallback_when_query_fails_and_flag_on(): + """Agent raises → random options fallback when flag enabled.""" + _reset_config(agent_explorer_fallback_to_random=True) + + option_model = MagicMock() + + async def failing_query(_msg): + raise RuntimeError("boom") + + explorer, _ = _make_explorer(option_model, failing_query) + policy, term_fn = explorer._get_exploration_strategy(0, timeout=5) + assert callable(policy) + assert term_fn(_make_state()) is False + + +def test_fallback_disabled_raises(): + """Agent raises → RequestActPolicyFailure when fallback flag off.""" + _reset_config(agent_explorer_fallback_to_random=False) + + option_model = MagicMock() + + async def failing_query(_msg): + raise RuntimeError("boom") + + explorer, _ = _make_explorer(option_model, failing_query) + with pytest.raises(utils.RequestActPolicyFailure): + explorer._get_exploration_strategy(0, timeout=5) From ee0a2b70974ecec937b5521d3a84751a4d086899 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 10:16:53 +0100 Subject: [PATCH 21/70] Add explorer-specific sample budget and experiment-plan logging - New setting agent_bilevel_explorer_max_samples_per_step (default 50), separate from the solve-path budget, so the explorer's backtracking cost is independently tunable. - Log the actual experiment plan (option names, objects, params) after refinement so the explorer's output is visible alongside the existing sketch/truncation log lines. - Test config updated to set both budgets explicitly. --- predicators/explorers/agent_bilevel_explorer.py | 13 ++++++++++++- predicators/settings.py | 5 +++++ tests/explorers/test_agent_bilevel_explorer.py | 1 + 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/predicators/explorers/agent_bilevel_explorer.py b/predicators/explorers/agent_bilevel_explorer.py index 0b2adf8e6..d71344693 100644 --- a/predicators/explorers/agent_bilevel_explorer.py +++ b/predicators/explorers/agent_bilevel_explorer.py @@ -107,7 +107,8 @@ def _get_exploration_strategy(self, train_task_idx: int, predicates=self._predicates, timeout=float(timeout), rng=np.random.default_rng(CFG.seed), - max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, + max_samples_per_step=CFG. + agent_bilevel_explorer_max_samples_per_step, check_subgoals=True, check_final_goal=False, truncate_on_subgoal_fail=True, @@ -118,6 +119,16 @@ def _get_exploration_strategy(self, train_task_idx: int, f"agent_bilevel explorer: sketch has {len(sketch)} steps, " f"refined {len(plan)} " f"({'success' if success else 'partial'}).") + if plan: + plan_strs = [] + for i, opt in enumerate(plan): + obj_s = ", ".join(o.name for o in opt.objects) + par_s = ", ".join(f"{p:.4f}" for p in opt.params) + plan_strs.append( + f" {i}: {opt.name}({obj_s})[{par_s}]") + logging.info( + "agent_bilevel explorer: experiment plan:\n" + + "\n".join(plan_strs)) if plan: policy = utils.option_plan_to_policy( diff --git a/predicators/settings.py b/predicators/settings.py index 22bee6d3d..c1b23423a 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1016,6 +1016,11 @@ class GlobalSettings: # log state pretty_str before/after each step agent_bilevel_log_state = False agent_bilevel_plan_sketch_file = "" # load sketch from file instead of LLM + # Agent bilevel explorer settings. Separate from the solve-path budget + # above because the explorer runs full backtracking while looking for + # the deepest subgoal-failure to truncate at, and each exhausted + # upstream step multiplies the cost. + agent_bilevel_explorer_max_samples_per_step = 50 @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/tests/explorers/test_agent_bilevel_explorer.py b/tests/explorers/test_agent_bilevel_explorer.py index 33a651cad..0db0dc237 100644 --- a/tests/explorers/test_agent_bilevel_explorer.py +++ b/tests/explorers/test_agent_bilevel_explorer.py @@ -142,6 +142,7 @@ def _reset_config(**overrides): "num_test_tasks": 1, "seed": 42, "agent_bilevel_max_samples_per_step": 5, + "agent_bilevel_explorer_max_samples_per_step": 5, "agent_bilevel_max_retries": 0, "agent_bilevel_check_subgoals": True, "agent_bilevel_log_state": False, From a8fb2dd94cf471fee7c431162b7f224308f4a302 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 10:17:37 +0100 Subject: [PATCH 22/70] Add sim-learning approach and synthesis tooling AgentSimLearningApproach extends AgentBilevelApproach to learn process dynamics online. Each cycle: the agent synthesizes parameterized process rules via Claude (using run_python / evaluate_simulator / test_simulator MCP tools), parameters are fitted via emcee MCMC, and the learned dynamics are composed with a kinematics-only PyBullet oracle into a combined option model for plan refinement. Key pieces: - predicators/approaches/agent_sim_learning_approach.py: the approach. Initialises with a kinematics-only option model (so AgentBilevelExplorer sees disagreements at process-dynamic subgoals like JugFilled/Boiled), and replaces it with the kin+learned model after each successful synthesis cycle. - predicators/agent_sdk/tools.py: create_synthesis_tools() builds the three MCP tools the synthesis agent uses; extra_mcp_tools field and get_allowed_tool_list(extra_names=) plumbing lets the approach inject them into the session. - predicators/code_sim_learning/: ParamSpec, fit_params (emcee MCMC), compute_mse, LearnedSimulator. - predicators/ground_truth_models/boil/gt_simulator.py: ground-truth process-dynamics simulator for the boil environment. - tests/: approach and param-fitting tests. --- predicators/agent_sdk/tools.py | 231 +++++++- .../approaches/agent_sim_learning_approach.py | 550 ++++++++++++++++++ predicators/code_sim_learning/__init__.py | 1 + predicators/code_sim_learning/training.py | 156 +++++ predicators/code_sim_learning/utils.py | 38 ++ .../ground_truth_models/boil/gt_simulator.py | 165 ++++++ .../test_agent_sim_learning_approach.py | 365 ++++++++++++ tests/code_sim_learning/test_param_fitting.py | 321 ++++++++++ 8 files changed, 1824 insertions(+), 3 deletions(-) create mode 100644 predicators/approaches/agent_sim_learning_approach.py create mode 100644 predicators/code_sim_learning/__init__.py create mode 100644 predicators/code_sim_learning/training.py create mode 100644 predicators/code_sim_learning/utils.py create mode 100644 predicators/ground_truth_models/boil/gt_simulator.py create mode 100644 tests/approaches/test_agent_sim_learning_approach.py create mode 100644 tests/code_sim_learning/test_param_fitting.py diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 583a537c3..b375e0580 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -72,7 +72,10 @@ PLANNING_TOOL_NAMES + SCENE_TOOL_NAMES) -def get_allowed_tool_list(tool_names: Optional[List[str]] = None) -> List[str]: +def get_allowed_tool_list( + tool_names: Optional[List[str]] = None, + extra_names: Optional[List[str]] = None, +) -> List[str]: """Compute the allowed_tools list for the agent SDK. Args: @@ -82,6 +85,8 @@ def get_allowed_tool_list(tool_names: Optional[List[str]] = None) -> List[str]: prefix = f"mcp__{MCP_SERVER_NAME}__" names = ALL_TOOL_NAMES if tool_names is None else \ [n for n in tool_names if n in set(ALL_TOOL_NAMES)] + if extra_names: + names = list(names) + list(extra_names) return [f"{prefix}{n}" for n in names] @@ -114,6 +119,7 @@ class ToolContext: turn_id: int = 0 # current query/turn within the session test_call_id: int = 0 # incremented per test_option_plan call visualized_state: Optional[State] = None # last state from visualize_state + extra_mcp_tools: list = field(default_factory=list) # injected by subclass # Populated by AgentBilevelExplorer so learning approaches can diff # mental-model subgoals against real trajectories. # TODO(sim-learning): consume these in learn_from_interaction_results. @@ -1950,5 +1956,224 @@ async def visualize_state(args: Dict[str, Any]) -> Dict[str, Any]: "visualize_state": visualize_state, } if tool_names is None: - return list(_all.values()) - return [_all[n] for n in tool_names if n in _all] + tools = list(_all.values()) + else: + tools = [_all[n] for n in tool_names if n in _all] + tools.extend(ctx.extra_mcp_tools) + return tools + + +# ── Sim-learning tools ─────────────────────────────────────────── + + +def create_synthesis_tools( + exec_ns: Dict[str, Any], + step_transitions: list, + process_features: Dict[str, List[str]], + kin_env: Any = None, + save_dir: Optional[str] = None, +) -> list: + """Create MCP tools for the sim-learning synthesis agent. + + Returns ``[run_python, evaluate_simulator, test_simulator]``. + + * ``run_python`` — executes arbitrary Python in a persistent + namespace pre-loaded with trajectory data. + * ``evaluate_simulator`` — fits parameters via MCMC on + ``PROCESS_RULES`` / ``PARAM_SPECS`` defined in the namespace. + * ``test_simulator`` — tests predictions vs observations. + + Args: + exec_ns: Persistent namespace for ``run_python``. Should + contain ``trajectories``, ``np``, ``ParamSpec``. + step_transitions: ``(State, Action, State)`` triples. + process_features: ``{type_name: [feat_names]}`` for MSE. + kin_env: Kinematics-only environment. When provided, + evaluate/test tools run kinematics before learned rules. + save_dir: Directory to save simulator source code to. + Each ``run_python`` call appends code to + ``save_dir/simulator_code.py``. + """ + import io # pylint: disable=import-outside-toplevel + import sys # pylint: disable=import-outside-toplevel + import traceback # pylint: disable=import-outside-toplevel + + from claude_agent_sdk import \ + tool # pylint: disable=import-outside-toplevel + + from predicators.approaches.agent_sim_learning_approach import ( # pylint: disable=import-outside-toplevel + AgentSimLearningApproach) + + _run_count = [0] # mutable counter in closure + + def _text(msg: str) -> Dict[str, Any]: + return {"type": "text", "text": msg} + + # ── run_python ────────────────────────────────────────── + + @tool( + "run_python", + "Execute Python code with trajectory data in scope. " + "Available variables: trajectories (List[LowLevelTrajectory])," + " np, ParamSpec. print() output is returned. " + "The namespace persists across calls.", + { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute.", + } + }, + "required": ["code"], + }, + ) + async def run_python(args: Dict[str, Any]) -> Dict[str, Any]: + code = args["code"] + old_stdout = sys.stdout + sys.stdout = captured = io.StringIO() + try: + exec(code, exec_ns) # pylint: disable=exec-used + except Exception: # pylint: disable=broad-except + tb = traceback.format_exc() + return _text(f"Error:\n{tb}") + finally: + sys.stdout = old_stdout + + # Save each successful run_python call as a versioned file; + # _load_simulator_from_file replays these in order. + if save_dir is not None: + _run_count[0] += 1 + os.makedirs(save_dir, exist_ok=True) + filename = f"{_run_count[0]:03d}_run_python.py" + filepath = os.path.join(save_dir, filename) + with open(filepath, "w", encoding="utf-8") as f: + f.write(code) + + output = captured.getvalue() + return _text(output or "(no output)") + + # ── evaluate_simulator ────────────────────────────────── + + @tool( + "evaluate_simulator", + "Fit parameters using PROCESS_RULES and PARAM_SPECS " + "from the run_python namespace. Reports MSE and fitted " + "parameter values.", + {"type": "object", "properties": {}}, + ) + async def evaluate_simulator( + args: Dict[str, Any]) -> Dict[str, Any]: + rules = exec_ns.get("PROCESS_RULES") + specs = exec_ns.get("PARAM_SPECS") + if not isinstance(rules, list) or not rules: + return _text( + "Error: PROCESS_RULES not defined. Use " + "run_python to define it first.") + if not isinstance(specs, list) or not specs: + return _text( + "Error: PARAM_SPECS not defined. Use " + "run_python to define it first.") + + try: + fitted_params, mse = ( + AgentSimLearningApproach._fit_parameters( + rules, specs, step_transitions, process_features, + kin_env)) + except Exception as e: # pylint: disable=broad-except + return _text(f"Error: fit_params failed:\n{e}") + + lines = [ + f"MSE: {mse:.6f} on " + f"{len(step_transitions)} step transitions.", + "", "Fitted parameters:", + ] + for name, val in fitted_params.items(): + lines.append(f" {name}: {val:.6f}") + + return _text("\n".join(lines)) + + # ── test_simulator ────────────────────────────────────── + + @tool( + "test_simulator", + "Test PROCESS_RULES predictions vs observations on " + "step transitions. Shows mismatches.", + { + "type": "object", + "properties": { + "max_transitions": { + "type": "integer", + "description": + "Max transitions to test (default 100).", + }, + "tolerance": { + "type": "number", + "description": + "Absolute tolerance for mismatch " + "(default 1e-4).", + }, + }, + }, + ) + async def test_simulator( + args: Dict[str, Any]) -> Dict[str, Any]: + rules = exec_ns.get("PROCESS_RULES") + specs = exec_ns.get("PARAM_SPECS") + if not isinstance(rules, list) or not rules: + return _text("Error: PROCESS_RULES not defined.") + + max_n = args.get("max_transitions", 100) + tol = args.get("tolerance", 1e-4) + pairs = step_transitions[:max_n] + + # Use init params if not yet fitted. + if specs: + t_params = {s.name: s.init_value for s in specs} + else: + t_params = {} + + lines: list = [] + n_tested = 0 + n_mismatch = 0 + + for s_t, action, s_next_obs in pairs: + # Run kinematics first so rules see post-kin state. + kin_state = (kin_env.simulate(s_t, action) + if kin_env is not None else s_t) + updates: Dict = {} + for rule in rules: + updates = rule(kin_state, updates, t_params) + + entry: list = [] + for obj in s_t: + type_name = obj.type.name + for feat in process_features.get(type_name, []): + if obj in updates and feat in updates[obj]: + pred = updates[obj][feat] + pred = (pred.item() + if hasattr(pred, "item") + else float(pred)) + else: + pred = s_t.get(obj, feat) + obs = s_next_obs.get(obj, feat) + err = abs(pred - obs) + if err > tol: + entry.append( + f" {obj.name}.{feat}: " + f"pred={pred:.6f} obs={obs:.6f} " + f"err={err:.6f}") + + n_tested += 1 + if entry: + n_mismatch += 1 + lines.append(f"Step {n_tested}:") + lines.extend(entry) + lines.append("") + + lines.append( + f"Tested {n_tested} steps: {n_mismatch} mismatches, " + f"{n_tested - n_mismatch} correct.") + return _text("\n".join(lines)) + + return [run_python, evaluate_simulator, test_simulator] diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py new file mode 100644 index 000000000..a7a656dd0 --- /dev/null +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -0,0 +1,550 @@ +"""Agent sim-learning approach: learns a simulator program online. + +Extends AgentBilevelApproach to learn process dynamics via an +agent-synthesized step-level simulator with parameterized process +rules. Parameters are fitted via emcee ensemble MCMC (training.py). + +The approach creates a kinematics-only oracle (PyBullet with process +dynamics disabled) and composes it with the learned step-level +dynamics into a single simulator function, plugged into a standard +_OracleOptionModel for true per-step interleaving. + +Example command:: + + python predicators/main.py --env pybullet_boil \ + --approach agent_sim_learning --seed 0 \ + --num_train_tasks 10 --num_test_tasks 5 \ + --num_online_learning_cycles 5 --explorer agent_plan +""" + +import inspect +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple + +import numpy as np +from gym.spaces import Box + +from predicators import utils +from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach +from predicators.agent_sdk.tools import create_synthesis_tools +from predicators.code_sim_learning.training import (ParamSpec, compute_mse, + fit_params) +from predicators.code_sim_learning.utils import LearnedSimulator +from predicators.envs import create_new_env +from predicators.option_model import _OptionModelBase, _OracleOptionModel +from predicators.settings import CFG +from predicators.structs import (Action, InteractionResult, + LowLevelTrajectory, ParameterizedOption, + Predicate, State, Task, Type) + +logger = logging.getLogger(__name__) + + +# ── Helpers ─────────────────────────────────────────────────────── + + +def _build_fitted_step_fn( + process_rules: List, + fitted_params: Dict[str, float], +) -> Callable[[State], Dict]: + """Create a step function from fitted process rules + parameters.""" + + def step_fn(state: State) -> Dict: + updates: Dict = {} + for rule in process_rules: + updates = rule(state, updates, fitted_params) + result: Dict = {} + for obj, feat_dict in updates.items(): + result[obj] = {} + for feat, val in feat_dict.items(): + result[obj][feat] = float(val) + return result + + return step_fn + + +def merge_process_updates( + base_state: State, + updates: Dict, + process_features: Dict[str, List[str]], +) -> State: + """Apply learned process updates on top of a base state. + + Args: + base_state: The state to merge into (e.g. from kinematics). + updates: {Object: {feat_name: new_value}} from learned dynamics. + process_features: {type_name: [feat_names]} identifying which + features to overwrite. + + Returns: + A copy of base_state with process features overwritten. + """ + if not updates: + return base_state + + new_data = {} + for obj in base_state: + arr = base_state[obj].copy() + type_name = obj.type.name + process_feats = set(process_features.get(type_name, [])) + + if obj in updates: + for feat_name, new_val in updates[obj].items(): + if feat_name in process_feats: + idx = obj.type.feature_names.index(feat_name) + arr[idx] = new_val + + new_data[obj] = arr + + merged = base_state.copy() + merged.data = new_data + return merged + + +# ── Approach ───────────────────────────────────────────────────── + + +class AgentSimLearningApproach(AgentBilevelApproach): + """Bilevel planning with a learned step-level simulator. + + During online learning: + 1. Collect trajectories (inherited from AgentBilevelApproach) + 2. Segment into option-level transitions + 3. Synthesize parameterized process rules via Claude agent + 4. Fit rule parameters via emcee ensemble MCMC + 5. Compose with kinematics-only oracle into a combined simulator + 6. Build _OracleOptionModel with the combined simulator + + During solving: + - Uses the learned model for plan validation in backtracking + refinement. + """ + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + *args: Any, + option_model: Optional[_OptionModelBase] = None, + **kwargs: Any) -> None: + # Build the kinematics-only env BEFORE super().__init__ and pass + # the resulting option model in via option_model=. This stops + # AgentPlannerApproach.__init__ from spinning up its own full- + # process env (which would conflict with this one over PyBullet + # GUI connections) and is the only env this approach holds. + # learn_from_interaction_results later wraps a kin+learned + # combined simulator around the same env. + self._base_env = create_new_env(CFG.env, do_cache=False, + use_gui=CFG.option_model_use_gui, + skip_process_dynamics=True) + if option_model is None: + # Use initial_options directly rather than get_gt_options(CFG.env) + # — the latter calls get_or_create_env which would create a + # second cached env (without GUI, with full dynamics) and the + # two PyBullet connections then fight over the physics server, + # producing "Not connected to physics server" mid-rollout. + option_model = _OracleOptionModel(initial_options, + self._base_env.simulate) + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + *args, + option_model=option_model, + **kwargs) + self._types = types + self._simulator: Optional[LearnedSimulator] = None + # Persistent state across learning cycles. + self._process_rules: Optional[List] = None + self._fitted_params: Optional[Dict[str, float]] = None + self._fit_mse: float = float("inf") + # True during simulator synthesis (learning); False during + # plan generation (decision-making). + self._learning_mode: bool = False + + @classmethod + def get_name(cls) -> str: + return "agent_sim_learning" + + # ── Agent session hooks ────────────────────────────────────── + + def _get_agent_system_prompt(self) -> str: + if self._learning_mode: + return self._build_synthesis_system_prompt() + return super()._get_agent_system_prompt() + + # ── Online learning ────────────────────────────────────────── + + def learn_from_interaction_results( + self, results: Sequence[InteractionResult]) -> None: + super().learn_from_interaction_results(results) + + if not self._online_trajectories: + logger.warning("No transitions, skipping.") + return + + logger.info("Sim-learning cycle %d: %d total trajectories.", + self._online_learning_cycle, + len(self._online_trajectories)) + + # Include all features so the agent can synthesize rules for any + # feature, not just pre-identified "process" features. + process_features: Dict[str, List[str]] = {} + for t in self._types: + if t.feature_names: + process_features[t.name] = list(t.feature_names) + + # synthesize via agent. + self._synthesize_with_agent(process_features) + + # Build simulator from fitted rules. + if self._process_rules is not None and self._fitted_params is not None: + step_fn = _build_fitted_step_fn( + self._process_rules, self._fitted_params) + self._simulator = LearnedSimulator( + step_fn=step_fn, + name="agent_synthesized") + elif self._simulator is None: + logger.warning("Synthesis produced no simulator, skipping.") + return + + # Build combined simulator: kinematics → learned dynamics. + combined_sim = self._build_combined_simulator( + self._base_env, self._simulator, process_features) + + # Wrap in an option model with interleaved per-step simulation. + self._option_model = self._build_option_model(combined_sim) + logger.info("Built learned option model (MSE: %.6f).", + self._fit_mse) + + def _build_option_model( + self, + simulator_fn: Callable[[State, Action], State], + ) -> _OracleOptionModel: + """Wrap a simulator function in an OracleOptionModel. + + Plumbs ``_abstract_function`` for Wait-target atom-change + termination so the model behaves identically whether it's + wrapping the bare kin-only simulator (init) or the learned + kin+process combined simulator (post learn_from_interaction). + Uses ``self._get_all_options()`` rather than + ``get_gt_options(CFG.env)`` to avoid spawning a second cached + PyBullet env via ``get_or_create_env``. + """ + model = _OracleOptionModel(self._get_all_options(), simulator_fn) + if CFG.wait_option_terminate_on_atom_change: + preds = self._get_all_predicates() + model._abstract_function = ( # pylint: disable=protected-access + lambda s, _p=preds: utils.abstract(s, _p)) + return model + + # ── Agent-based synthesis ──────────────────────────────────── + + def _synthesize_with_agent( + self, + process_features: Dict[str, List[str]], + ) -> None: + """Synthesize parameterized process rules via a Claude agent. + + Provides ``run_python``, ``evaluate_simulator``, and + ``test_simulator`` tools. The agent explores trajectory data + via ``run_python`` (which has a persistent namespace with + ``trajectories`` pre-loaded), then defines ``PROCESS_RULES`` + and ``PARAM_SPECS``. Each ``run_python`` call appends code + to a saved file; after the session we reload from that file. + """ + step_transitions = self._extract_step_transitions( + self._online_trajectories) + + # Directory for saving simulator source code. + base = self._tool_context.sandbox_dir or self._get_log_dir() + save_dir = os.path.join(base, "simulator_code") + + # Persistent exec namespace — the agent's "scratch-pad". + exec_ns: Dict[str, Any] = { + "trajectories": self._online_trajectories, + "np": np, + "ParamSpec": ParamSpec, + } + + # Build synthesis tools (run_python, evaluate, test). + tools = create_synthesis_tools( + exec_ns, step_transitions, process_features, self._base_env, + save_dir=save_dir) + self._tool_context.extra_mcp_tools = tools + self._learning_mode = True + + # Force a fresh session so the synthesis system prompt and + # tool set take effect. + self._close_agent_session() + self._ensure_agent_session() + + # Write data-structure reference for the agent to Read. + structs_ref = self._write_structs_reference() + + n_trajs = len(self._online_trajectories) + message = f"""\ +Synthesize a process dynamics simulator for this environment. \ +There are {n_trajs} trajectories ({len(step_transitions)} step \ +transitions) available. + +Data-structure source code is at: {structs_ref} +Read that file first, then explore the trajectory data with \ +`run_python` and define PROCESS_RULES and PARAM_SPECS.""" + + try: + self._query_agent_sync(message) + finally: + self._tool_context.extra_mcp_tools = [] + self._learning_mode = False + self._close_agent_session() + + # Load results from saved versioned files. + rules, specs = self._load_simulator_from_file( + save_dir, self._online_trajectories) + if rules is None or specs is None: + return + + self._process_rules = rules + + # Fit parameters via MCMC. + self._fitted_params, self._fit_mse = self._fit_parameters( + rules, specs, step_transitions, process_features, + self._base_env) + logger.info( + "Agent synthesized %d rules, %d params (MSE: %.6f).", + len(rules), len(specs), self._fit_mse) + + # ── Parameter fitting ──────────────────────────────────────── + + @staticmethod + def _fit_parameters( + rules: List, + specs: List[ParamSpec], + step_transitions: List[Tuple[State, Action, State]], + process_features: Dict[str, List[str]], + kin_env: Any = None, + ) -> Tuple[Dict[str, float], float]: + """Fit parameters for the synthesized rules via MCMC. + + Args: + kin_env: Kinematics-only environment. When provided the + simulator runs kinematics first so learned rules see + the post-kinematics state (consistent with inference). + + Returns: + (fitted_params, mse) tuple. + """ + + def sim_fn(state: State, action: Action, + params: Dict[str, float]) -> Dict: + if kin_env is not None: + state = kin_env.simulate(state, action) + updates: Dict = {} + for rule in rules: + updates = rule(state, updates, params) + return updates + + result = fit_params( + simulator_fn=sim_fn, + transitions=step_transitions, + param_specs=specs, + process_features=process_features, + ) + + mse = compute_mse( + sim_fn, step_transitions, result.point_estimate, process_features) + return result.point_estimate, mse + + @staticmethod + def _load_simulator_from_file( + save_dir: str, + trajectories: Optional[List[LowLevelTrajectory]] = None, + ) -> Tuple[Optional[List], Optional[List[ParamSpec]]]: + """Load PROCESS_RULES and PARAM_SPECS from versioned code files. + + Executes all ``NNN_run_python.py`` files in ``save_dir`` in + order, accumulating into a single namespace. + + Returns (rules, specs), either of which may be None on failure. + """ + if not os.path.isdir(save_dir): + logger.warning("No simulator code dir at %s.", save_dir) + return None, None + + files = sorted( + f for f in os.listdir(save_dir) + if f.endswith(".py") and f[0].isdigit()) + if not files: + logger.warning("No code files in %s.", save_dir) + return None, None + + ns: Dict[str, Any] = { + "np": np, + "ParamSpec": ParamSpec, + "trajectories": trajectories or [], + } + for fname in files: + fpath = os.path.join(save_dir, fname) + with open(fpath, "r", encoding="utf-8") as f: + code = f.read() + try: + exec(code, ns) # pylint: disable=exec-used + except Exception: + logger.warning("Failed to exec %s, skipping.", fpath, + exc_info=True) + + rules = ns.get("PROCESS_RULES") + specs = ns.get("PARAM_SPECS") + if not isinstance(rules, list) or not rules: + logger.warning("Saved code did not define PROCESS_RULES.") + return None, None + if not isinstance(specs, list) or not specs: + logger.warning("Saved code did not define PARAM_SPECS.") + return None, None + + logger.info("Loaded %d rules, %d param specs from %d files in %s.", + len(rules), len(specs), len(files), save_dir) + return rules, specs + + # ── Static helpers ─────────────────────────────────────────── + + def _write_structs_reference(self) -> str: + """Write extracted source of key structs to the sandbox. + + Returns the path the agent should Read. + """ + from predicators.structs import ( # pylint: disable=import-outside-toplevel + Action as _Action, LowLevelTrajectory as _LLT, + Object as _Object, State as _State, Type as _Type) + + source = "\n\n".join( + inspect.getsource(cls) + for cls in [_Type, _Object, _State, _Action, _LLT]) + + # Write into sandbox reference dir if available, else log dir. + base = self._tool_context.sandbox_dir or self._get_log_dir() + ref_dir = os.path.join(base, "reference") + os.makedirs(ref_dir, exist_ok=True) + ref_path = os.path.join(ref_dir, "structs.py") + with open(ref_path, "w", encoding="utf-8") as f: + f.write(source) + + # In Docker sandbox the agent sees /sandbox/reference/structs.py. + if self._tool_context.sandbox_dir: + return "/sandbox/reference/structs.py" + return ref_path + + @staticmethod + def _extract_step_transitions( + trajectories: List[LowLevelTrajectory], + ) -> List[Tuple[State, Action, State]]: + """Extract consecutive (s_t, action_t, s_{t+1}) triples.""" + triples: List[Tuple[State, Action, State]] = [] + for traj in trajectories: + for i in range(len(traj.actions)): + triples.append( + (traj.states[i], traj.actions[i], traj.states[i + 1])) + return triples + + @staticmethod + def _build_combined_simulator( + kin_env: Any, + simulator: LearnedSimulator, + process_features: Dict[str, List[str]], + ) -> Callable[[State, Action], State]: + """Compose kinematics-only env with learned step-level dynamics.""" + + def combined_simulate(state: State, action: Action) -> State: + kin_state = kin_env.simulate(state, action) + updates = simulator.predict_step(kin_state) + if not updates: + return kin_state + return merge_process_updates(kin_state, updates, process_features) + + return combined_simulate + + @staticmethod + def _build_synthesis_system_prompt() -> str: + """Build the system prompt for the synthesis agent.""" + return """\ +You are synthesizing a parameterized process dynamics simulator for a \ +robotic manipulation environment. + +A separate physics engine (PyBullet) handles kinematics (robot movement, \ +grasping, rigid body physics). Your simulator handles **process dynamics**: \ +non-kinematic features that change due to ongoing physical or causal processes. + +## Tools + +- `run_python(code)` — execute Python in a persistent namespace. `print()` \ +output is returned. The namespace persists across calls. +- `evaluate_simulator` — fit parameters using PROCESS_RULES and PARAM_SPECS \ +from the namespace. Reports MSE. +- `test_simulator` — test predictions vs observations on step transitions. \ +Shows mismatches. + +### Pre-loaded variables + +- `trajectories`: List[LowLevelTrajectory] — the collected trajectory data +- `np`, `ParamSpec` — standard imports + +### Data structures + +The trajectory data uses classes from `predicators.structs` (Type, Object, \ +State, Action, LowLevelTrajectory). Their source code is provided as a \ +reference file — Read the path given in the first message. + +## Goal + +Define two variables in the `run_python` namespace: + +- `PROCESS_RULES`: list of rule functions +- `PARAM_SPECS`: list of ParamSpec objects + +Parameters are fitted automatically after the session ends. + +### Process rule signature + +```python +def rule(state, updates, params): + \"\"\"Apply one process for a single simulation step. + + Args: + state: Current env state. + updates: Dict[Object, Dict[str, value]] accumulated from prior rules. + params: Dict[str, float] of learned parameters. + + Returns: + The (possibly modified) updates dict. + \"\"\" +``` + +### ParamSpec + +```python +ParamSpec(name: str, init_value: float) +``` + +## Workflow + +1. Explore the trajectory data with `run_python`: types, features, \ +state changes over time +2. Identify which features change due to process dynamics (not kinematics) +3. Define `PROCESS_RULES` and `PARAM_SPECS` in the namespace via `run_python` +4. Call `evaluate_simulator` to fit parameters and check MSE +5. Call `test_simulator` to see prediction mismatches +6. Iterate if needed + +## Tips + +- Each trajectory is a sequence of states from one episode. Compare \ +consecutive states to see per-step changes. +- Group objects by type: \ +`groups = {}; for o in state: groups.setdefault(o.type.name, []).append(o)` +- Accumulate updates: `updates.setdefault(obj, {})[feat] = new_value` +""" diff --git a/predicators/code_sim_learning/__init__.py b/predicators/code_sim_learning/__init__.py new file mode 100644 index 000000000..685d11353 --- /dev/null +++ b/predicators/code_sim_learning/__init__.py @@ -0,0 +1 @@ +"""Compositional world modeling via code""" diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py new file mode 100644 index 000000000..bffb8dd8c --- /dev/null +++ b/predicators/code_sim_learning/training.py @@ -0,0 +1,156 @@ +"""Training utilities for the sim-learning approach. + +Parameter fitting via emcee (affine-invariant ensemble MCMC). +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Callable, Dict, List, Tuple + +import numpy as np + +from predicators.structs import Action, State + +logger = logging.getLogger(__name__) + + +# Step-level simulator: (State, Action, params_dict) -> {Object: {feat: val}} +StepSimulatorFn = Callable[[State, Action, Dict[str, float]], Dict] + + +@dataclass +class ParamSpec: + """Specification for a single learnable parameter.""" + + name: str + init_value: float + + +@dataclass +class FitResult: + """Result of parameter fitting.""" + + names: List[str] + samples: np.ndarray # (num_samples, num_params) + log_probs: np.ndarray # (num_samples,) + + @property + def point_estimate(self) -> Dict[str, float]: + """Posterior mean.""" + mean = self.samples.mean(axis=0) + return {n: float(mean[i]) for i, n in enumerate(self.names)} + + +def compute_mse( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], +) -> float: + """Compute MSE between predicted and observed process features.""" + total_se = 0.0 + count = 0 + + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + + for obj, feat_dict in updates.items(): + type_name = obj.type.name + allowed_feats = process_features.get(type_name, []) + for feat_name, pred_val in feat_dict.items(): + if feat_name not in allowed_feats: + continue + v = pred_val.item() if hasattr(pred_val, 'item') else pred_val + obs_val = float(s_next_obs.get(obj, feat_name)) + total_se += (v - obs_val) ** 2 + count += 1 + + # Penalize unpredicted features (model predicts no change). + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + continue + pred_val = float(s_t.get(obj, feat_name)) + obs_val = float(s_next_obs.get(obj, feat_name)) + total_se += (pred_val - obs_val) ** 2 + count += 1 + + if count == 0: + return 0.0 + return total_se / count + + +def fit_params( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + param_specs: List[ParamSpec], + process_features: Dict[str, List[str]], + num_walkers: int = 32, + num_steps: int = 500, + burn_in: int = 200, + noise_sigma: float = 0.05, + prior_sigma_scale: float = 2.0, +) -> FitResult: + """Fit simulator parameters via emcee ensemble MCMC. + + Gradient-free — handles all parameter types (rates, thresholds, + capacities) uniformly. Returns full posterior with uncertainty. + + Args: + simulator_fn: Simulator(state, action, params_dict) -> updates. + Should run kinematics internally if needed. + transitions: List of (s_t, action, s_{t+1}_obs) triples. + param_specs: Parameter specifications (name, init_value). + process_features: {type_name: [feat_names]} to fit. + num_walkers: Number of ensemble walkers (>= 2*ndim). + num_steps: Total MCMC steps per walker. + burn_in: Steps to discard as burn-in. + noise_sigma: Observation noise std dev for likelihood. + prior_sigma_scale: Prior width as multiple of init_value. + + Returns: + FitResult with posterior samples and log-probabilities. + """ + import emcee # pylint: disable=import-outside-toplevel + + names = [s.name for s in param_specs] + init_values = np.array([s.init_value for s in param_specs]) + ndim = len(param_specs) + num_walkers = max(num_walkers, 2 * ndim + 2) + prior_sigma = init_values * prior_sigma_scale + + def log_posterior(theta: np.ndarray) -> float: + # Reject negative values + if np.any(theta <= 0): + return -np.inf + params = {n: float(theta[i]) for i, n in enumerate(names)} + # Broad Gaussian prior centered on init values + log_prior = -0.5 * np.sum( + ((theta - init_values) / prior_sigma) ** 2) + # Likelihood + mse = compute_mse(simulator_fn, transitions, + params, process_features) + return log_prior + (-0.5 * mse / (noise_sigma ** 2)) + + # Initialize walkers in a small ball around init values. + p0 = init_values * (1.0 + 0.01 * np.random.randn(num_walkers, ndim)) + + sampler = emcee.EnsembleSampler(num_walkers, ndim, log_posterior) + + logger.info("Running emcee: %d walkers, %d steps, %d burn-in.", + num_walkers, num_steps, burn_in) + sampler.run_mcmc(p0, num_steps, progress=False) + + # Discard burn-in, flatten chains. + samples = sampler.get_chain(discard=burn_in, flat=True) + log_probs = sampler.get_log_prob(discard=burn_in, flat=True) + + result = FitResult(names=names, samples=samples, log_probs=log_probs) + + logger.info("emcee done. Posterior mean: %s", + {k: f"{v:.4f}" for k, v in result.point_estimate.items()}) + + return result diff --git a/predicators/code_sim_learning/utils.py b/predicators/code_sim_learning/utils.py new file mode 100644 index 000000000..0d541f6f2 --- /dev/null +++ b/predicators/code_sim_learning/utils.py @@ -0,0 +1,38 @@ +"""Utilities for the code sim-learning module.""" + +from __future__ import annotations + +import logging +from typing import Callable, Dict + +from predicators.structs import Object, State + +logger = logging.getLogger(__name__) + +# Type alias: {Object: {feature_name: new_value}} +ProcessUpdate = Dict[Object, Dict[str, float]] + + +class LearnedSimulator: + """Wraps a step-level simulator function (handwritten or LLM-synthesized). + + The function predicts process dynamics — features like water_volume, + heat_level, spilled_level that aren't captured by rigid body + physics. + """ + + StepFn = Callable[[State], ProcessUpdate] + + def __init__(self, + step_fn: StepFn, + name: str = "learned_simulator") -> None: + self._step_fn = step_fn + self.name = name + + def predict_step(self, state: State) -> ProcessUpdate: + """Predict process feature updates for a single timestep.""" + try: + return self._step_fn(state) + except Exception as e: # pylint: disable=broad-except + logger.warning("Simulator '%s' step raised: %s", self.name, e) + return {} diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py new file mode 100644 index 000000000..9e3c46054 --- /dev/null +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -0,0 +1,165 @@ +"""Ground-truth simulator program for pybullet_boil process dynamics. + +Reproduces the custom step logic from pybullet_boil.py as composable +process rules using plain numpy/float arithmetic. +""" + +from __future__ import annotations + +from typing import Dict, List + +import numpy as np + +from predicators.code_sim_learning.training import ParamSpec +from predicators.code_sim_learning.utils import ProcessUpdate +from predicators.structs import Object, State + +# Constants matching pybullet_boil.py exactly. +WATER_FILL_SPEED = 0.02 # 0.002 * water_height_to_level_ratio(10) +HEATING_SPEED = 0.03 +HAPPINESS_SPEED = 0.05 +MAX_JUG_WATER_CAPACITY = 1.3 +WATER_FILLED_HEIGHT = 0.8 +MAX_WATER_SPILL_WIDTH = 0.3 +FAUCET_ALIGN_THRESHOLD = 0.1 +BURNER_ALIGN_THRESHOLD = 0.05 +FAUCET_X_LEN = 0.15 + +# Parameter specs for fitting. +BOIL_PARAM_SPECS: List[ParamSpec] = [ + ParamSpec("water_fill_speed", WATER_FILL_SPEED), + ParamSpec("heating_speed", HEATING_SPEED), + ParamSpec("happiness_speed", HAPPINESS_SPEED), + ParamSpec("max_jug_water_capacity", MAX_JUG_WATER_CAPACITY), + ParamSpec("water_filled_height", WATER_FILLED_HEIGHT), + ParamSpec("max_water_spill_width", MAX_WATER_SPILL_WIDTH), + ParamSpec("faucet_x_len", FAUCET_X_LEN), + ParamSpec("faucet_align_threshold", FAUCET_ALIGN_THRESHOLD), + ParamSpec("burner_align_threshold", BURNER_ALIGN_THRESHOLD), +] + +Params = Dict[str, float] + + +def _objs_by_type(state: State) -> Dict[str, List[Object]]: + """Group state objects by type name.""" + groups: Dict[str, List[Object]] = {} + for o in state: + groups.setdefault(o.type.name, []).append(o) + return groups + + +def _water_filling(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Faucet on + jug aligned → fill jug; otherwise spill.""" + objs = _objs_by_type(state) + for faucet in objs.get("faucet", []): + if state.get(faucet, "is_on") <= 0.5: + continue + + fx = float(state.get(faucet, "x")) + fy = float(state.get(faucet, "y")) + frot = float(state.get(faucet, "rot")) + out_x = fx + params["faucet_x_len"] * np.cos(frot) + out_y = fy - params["faucet_x_len"] * np.sin(frot) + + jug_catching = False + for jug in objs.get("jug", []): + if state.get(jug, "is_held") > 0.5: + continue + jx = float(state.get(jug, "x")) + jy = float(state.get(jug, "y")) + dist = float(np.hypot(out_x - jx, out_y - jy)) + + if dist < params["faucet_align_threshold"]: + water = float(state.get(jug, "water_volume")) + if water < params["max_jug_water_capacity"]: + new_water = min(params["max_jug_water_capacity"], + water + params["water_fill_speed"]) + updates.setdefault(jug, {})["water_volume"] = new_water + jug_catching = True + else: + spill = float(state.get(faucet, "spilled_level")) + new_spill = min(params["max_water_spill_width"], + spill + params["water_fill_speed"]) + updates.setdefault( + faucet, {})["spilled_level"] = new_spill + break + + if not jug_catching: + spill = float(state.get(faucet, "spilled_level")) + new_spill = min(params["max_water_spill_width"], + spill + params["water_fill_speed"]) + updates.setdefault(faucet, {})["spilled_level"] = new_spill + + return updates + + +def _heating(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Burner on + jug with water aligned → heat jug.""" + objs = _objs_by_type(state) + for burner in objs.get("burner", []): + if state.get(burner, "is_on") <= 0.5: + continue + bx = float(state.get(burner, "x")) + by = float(state.get(burner, "y")) + + for jug in objs.get("jug", []): + if state.get(jug, "is_held") > 0.5: + continue + if state.get(jug, "water_volume") <= 0.0: + continue + jx = float(state.get(jug, "x")) + jy = float(state.get(jug, "y")) + dist = float(np.hypot(bx - jx, by - jy)) + + if dist < params["burner_align_threshold"]: + heat = float(state.get(jug, "heat_level")) + new_heat = min(1.0, heat + params["heating_speed"]) + updates.setdefault(jug, {})["heat_level"] = new_heat + + return updates + + +def _happiness(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Jug filled + boiled + no spill + burner off → human happy.""" + objs = _objs_by_type(state) + faucets = objs.get("faucet", []) + burners = objs.get("burner", []) + + def _get_val(obj: Object, feat: str) -> float: + val = updates.get(obj, {}).get(feat, None) + if val is not None: + return float(val) if hasattr(val, 'item') else val + return float(state.get(obj, feat)) + + any_spill = any(_get_val(f, "spilled_level") > 0 for f in faucets) + any_burner_on = any(state.get(b, "is_on") > 0.5 for b in burners) + + if any_spill or any_burner_on: + return updates + + for jug in objs.get("jug", []): + water = _get_val(jug, "water_volume") + heat = _get_val(jug, "heat_level") + if water >= params["water_filled_height"] and heat >= 1.0: + for human in objs.get("human", []): + h = float(state.get(human, "happiness_level")) + new_h = min(1.0, h + params["happiness_speed"]) + updates.setdefault(human, {})["happiness_level"] = new_h + + return updates + + +PROCESS_RULES = [_water_filling, _heating, _happiness] + + +def get_gt_process_features() -> Dict[str, List[str]]: + """Process features handled by the simulator (not PyBullet).""" + return { + "jug": ["water_volume", "heat_level"], + "faucet": ["spilled_level"], + "human": ["happiness_level"], + } diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py new file mode 100644 index 000000000..4e1367fa5 --- /dev/null +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -0,0 +1,365 @@ +"""Integration test: GT simulator + backtracking refinement solves boil. + +Verifies that given a correct plan sketch (from a real agent run) and a +ground-truth simulator program, the hybrid learned option model +(PyBullet + learned process dynamics) can find continuous parameters +that solve a pybullet_boil task. +""" +# pylint: disable=protected-access +import logging +import os +import re +from typing import List, Optional, Sequence, Set, Tuple + +import numpy as np +import pytest + +from predicators import utils +from predicators.approaches.agent_bilevel_approach import _SketchStep +from predicators.approaches.agent_sim_learning_approach import \ + merge_process_updates +from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_options +from predicators.ground_truth_models.boil.gt_simulator import \ + BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features +from predicators.option_model import _OracleOptionModel +from predicators.planning import run_backtracking_refinement +from predicators.structs import GroundAtom, Object, ParameterizedOption, \ + Predicate + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _setup_env(): + """Create boil env and return (env, task, options_dict, objects_dict).""" + utils.reset_config({ + "env": "pybullet_boil", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "boil_goal": "simple", + "boil_num_jugs_train": [1], + "boil_num_jugs_test": [1], + "boil_num_burner_train": [1], + "boil_num_burner_test": [1], + "option_model_use_gui": False, + "wait_option_terminate_on_atom_change": True, + }) + env = create_new_env("pybullet_boil", do_cache=False, use_gui=False) + task = [t.task for t in env.get_train_tasks()][0] + options = get_gt_options(env.get_name()) + options_dict = {o.name: o for o in options} + objects_dict = {obj.name: obj for obj in task.init} + return env, task, options_dict, objects_dict + + +def _build_oracle_model(env): + """Build an oracle option model.""" + options = get_gt_options(env.get_name()) + oracle = _OracleOptionModel(options, env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) + return oracle + + +def _build_kinematics_only_oracle(env): + """Build an oracle that only handles kinematics (no process dynamics). + + Creates a separate env instance with process dynamics disabled, so + that water filling, heating, and happiness are not simulated. + """ + kin_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, + skip_process_dynamics=True) + options = get_gt_options(kin_env.get_name()) + oracle = _OracleOptionModel(options, kin_env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) + return oracle + + +def _build_combined_model(env): + """Build a combined model: kinematics-only env + GT step-level dynamics. + + This mirrors the approach's design: compose a kinematics-only + env.simulate with a step-level dynamics function into a single + simulator, then plug into a standard _OracleOptionModel. + """ + kin_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, + skip_process_dynamics=True) + process_features = get_gt_process_features() + gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} + + def combined_simulate(state, action): + kin_state = kin_env.simulate(state, action) + updates = {} + for rule in PROCESS_RULES: + updates = rule(kin_state, updates, gt_params) + if not updates: + return kin_state + return merge_process_updates(kin_state, updates, process_features) + + options = get_gt_options(env.get_name()) + model = _OracleOptionModel(options, combined_simulate) + preds = env.predicates + model._abstract_function = lambda s: utils.abstract(s, preds) + return model + + +def _parse_sketch_from_file( + sketch_file: str, + options: Set[ParameterizedOption], + types: Set, + predicates: Set[Predicate], + objects: Sequence[Object], +) -> List[_SketchStep]: + """Parse a plan sketch from a text file, same as agent_bilevel_approach.""" + with open(sketch_file, "r") as f: + plan_text = f.read().strip() + + # Phase 1: parse options + objects (no continuous params) + parsed = utils.parse_model_output_into_option_plan( + plan_text, objects, types, options, parse_continuous_params=False) + assert parsed, f"Parsed empty plan sketch from {sketch_file}" + + # Phase 2: parse subgoal annotations + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + option_names = {o.name for o in options} + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + subgoals: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + for line in plan_text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + sg_match = subgoal_re.search(stripped) + if not sg_match: + subgoals.append(None) + continue + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] for n in atom_match.group(3).split(',') + ] + if pred_name not in pred_map: + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError: + continue + if len(objs) != len(pred.types): + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + if pos_atoms or neg_atoms: + subgoals.append((pos_atoms, neg_atoms)) + else: + subgoals.append(None) + + # Zip into sketch steps + sketch = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + _SketchStep(option=option, objects=objs, subgoal_atoms=None)) + return sketch + + +def _informed_place_params(pre_state, sketch, step_idx, rng, n): + """Sample Place params biased toward the contextual target.""" + step = sketch[step_idx] + low = step.option.params_space.low + high = step.option.params_space.high + eps = 1e-4 + + next_step = sketch[step_idx + 1] if step_idx + 1 < n else None + + if next_step and "Faucet" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "faucet": + fx = pre_state.get(obj, "x") + fy = pre_state.get(obj, "y") + frot = pre_state.get(obj, "rot") + # The jug has a physics offset after drop, so target + # slightly past the faucet output to compensate. + out_x = fx + 0.15 * np.cos(frot) + out_y = fy - 0.15 * np.sin(frot) + # Target near faucet output x but lower y (IK-reachable). + x = np.clip(out_x + rng.normal(0, 0.02), low[0] + eps, + high[0] - eps) + y = np.clip(out_y - 0.05 + rng.normal(0, 0.03), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + # Negative yaw helps place jug closer to faucet output. + yaw = np.clip(rng.normal(-0.3, 0.5), low[3] + eps, + high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + if next_step and "Burner" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "burner": + bx = pre_state.get(obj, "x") + by = pre_state.get(obj, "y") + x = np.clip(bx + rng.normal(0, 0.05), low[0] + eps, + high[0] - eps) + y = np.clip(by + rng.normal(0, 0.05), low[1] + eps, + high[1] - eps) + # Bias z toward low end for reliable IK. + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = rng.uniform(low[3] + eps, high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + return rng.uniform(low + eps, high - eps).astype(np.float32) + + +def _refine(task, + sketch, + option_model, + predicates, + seed=0, + max_samples=200, + timeout=600.0): + """Run backtracking refinement with informed Place sampling.""" + rng = np.random.default_rng(seed) + n = len(sketch) + max_tries = [ + max_samples if step.option.params_space.shape[0] > 0 else 1 + for step in sketch + ] + + def sample_fn(idx, state, rng_): + step = sketch[idx] + if step.option.params_space.shape[0] == 0: + params = np.array([], dtype=np.float32) + elif step.option.name == "Place": + params = _informed_place_params(state, sketch, idx, rng_, n) + else: + low = step.option.params_space.low + high = step.option.params_space.high + params = rng_.uniform(low, high).astype(np.float32) + grounded = step.option.ground(step.objects, params) + if grounded.name == "Wait" and step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + return grounded + + def validate_fn(idx, _pre, _opt, post_state, _n_acts): + step = sketch[idx] + if step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + missing = step.subgoal_atoms - current_atoms + return False, f"subgoal missing: {missing}" + if idx == n - 1 and not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + plan, success, total_samples = run_backtracking_refinement( + init_state=task.init, + option_model=option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + ) + logger.info("Refinement: %s, %d total samples", + "success" if success else "failed", total_samples) + return [p for p in plan if p is not None], success + + +SKETCH_FILE = os.path.join(os.path.dirname(__file__), "test_data", + "boil_plan_sketch.txt") + + +@pytest.mark.parametrize("model_type", ["oracle", "combined"]) +def test_boil_sketch_refinement(model_type): + """Test that backtracking refinement solves a boil task.""" + env, task, options_dict, objects_dict = _setup_env() + predicates = env.predicates + options = get_gt_options(env.get_name()) + + if model_type == "oracle": + option_model = _build_oracle_model(env) + else: + option_model = _build_combined_model(env) + + sketch = _parse_sketch_from_file(SKETCH_FILE, options, env.types, + predicates, list(task.init)) + plan, success = _refine(task, + sketch, + option_model, + predicates, + max_samples=500, + timeout=1200.0) + + logger.info("Model=%s, success=%s, plan_len=%d", model_type, success, + len(plan)) + if success: + for i, opt in enumerate(plan): + objs = ", ".join(o.name for o in opt.objects) + params = ", ".join(f"{p:.3f}" for p in opt.params) + logger.info(" %d: %s(%s)[%s]", i, opt.name, objs, params) + + assert success, (f"Refinement failed with {model_type} model. " + f"Partial plan: {len(plan)} steps.") + + # Forward validation: re-execute the plan in the oracle model (full + # env dynamics) to verify the plan actually solves the task. + # Always uses the oracle regardless of which model found the plan. + oracle_model = _build_oracle_model(env) + n = len(plan) + + def fwd_sample_fn(i, _s, _r): + return plan[i] + + def fwd_validate_fn(i, _s, _o, post, _n): + if i == n - 1 and not task.goal_holds(post): + return False, "goal not reached" + return True, "" + + _, fwd_success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=oracle_model, + n_steps=n, + max_tries=[1] * n, + sample_fn=fwd_sample_fn, + validate_fn=fwd_validate_fn, + rng=np.random.default_rng(0), + timeout=600.0, + ) + if fwd_success: + logger.info("Forward validation passed for %s model.", model_type) + else: + logger.warning( + "Forward validation failed for %s model " + "(PyBullet state reconstruction is imperfect).", model_type) + + +if __name__ == "__main__": + import sys + model = sys.argv[1] if len(sys.argv) > 1 else "oracle" + test_boil_sketch_refinement(model) diff --git a/tests/code_sim_learning/test_param_fitting.py b/tests/code_sim_learning/test_param_fitting.py new file mode 100644 index 000000000..82853b9ce --- /dev/null +++ b/tests/code_sim_learning/test_param_fitting.py @@ -0,0 +1,321 @@ +"""Test parameter fitting recovers GT simulator parameters. + +Uses step-level transitions from a real oracle trajectory (boil env), +then fits from perturbed initial values via emcee. +""" + +import logging +import os +import re +from typing import Dict, List, Optional, Sequence, Set, Tuple + +import predicators.approaches # noqa: F401 (bootstrap circular import) +import numpy as np + +from predicators import utils +from predicators.approaches.agent_bilevel_approach import _SketchStep +from predicators.code_sim_learning.training import ParamSpec, fit_params +from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_options +from predicators.ground_truth_models.boil.gt_simulator import ( + BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features) +from predicators.option_model import _OracleOptionModel +from predicators.planning import run_backtracking_refinement +from predicators.structs import Action, GroundAtom, Object, \ + ParameterizedOption, Predicate, State + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Ground-truth parameter values (from BOIL_PARAM_SPECS). +GT_PARAMS = {s.name: s.init_value for s in BOIL_PARAM_SPECS} + +SKETCH_FILE = os.path.join(os.path.dirname(__file__), "..", "approaches", + "test_data", "boil_plan_sketch.txt") + + +def _setup_env(): + """Create boil env and return (env, task, options, predicates).""" + utils.reset_config({ + "env": "pybullet_boil", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "boil_goal": "simple", + "boil_num_jugs_train": [1], + "boil_num_jugs_test": [1], + "boil_num_burner_train": [1], + "boil_num_burner_test": [1], + "option_model_use_gui": False, + "wait_option_terminate_on_atom_change": True, + }) + env = create_new_env("pybullet_boil", do_cache=False, use_gui=False) + task = [t.task for t in env.get_train_tasks()][0] + options = get_gt_options(env.get_name()) + return env, task, options + + +def _build_oracle_model(env): + """Build an oracle option model.""" + options = get_gt_options(env.get_name()) + oracle = _OracleOptionModel(options, env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) + return oracle + + +def _parse_sketch_from_file( + sketch_file: str, + options: Set[ParameterizedOption], + types: Set, + predicates: Set[Predicate], + objects: Sequence[Object], +) -> List[_SketchStep]: + """Parse a plan sketch from a text file.""" + with open(sketch_file, "r") as f: + plan_text = f.read().strip() + + parsed = utils.parse_model_output_into_option_plan( + plan_text, objects, types, options, parse_continuous_params=False) + assert parsed, f"Parsed empty plan sketch from {sketch_file}" + + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + option_names = {o.name for o in options} + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + subgoals: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + for line in plan_text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + sg_match = subgoal_re.search(stripped) + if not sg_match: + subgoals.append(None) + continue + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] for n in atom_match.group(3).split(',') + ] + if pred_name not in pred_map: + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError: + continue + if len(objs) != len(pred.types): + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + if pos_atoms or neg_atoms: + subgoals.append((pos_atoms, neg_atoms)) + else: + subgoals.append(None) + + sketch = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + _SketchStep(option=option, objects=objs, subgoal_atoms=None)) + return sketch + + +def _informed_place_params(pre_state, sketch, step_idx, rng, n): + """Sample Place params biased toward the contextual target.""" + step = sketch[step_idx] + low = step.option.params_space.low + high = step.option.params_space.high + eps = 1e-4 + + next_step = sketch[step_idx + 1] if step_idx + 1 < n else None + + if next_step and "Faucet" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "faucet": + fx = pre_state.get(obj, "x") + fy = pre_state.get(obj, "y") + frot = pre_state.get(obj, "rot") + out_x = fx + 0.15 * np.cos(frot) + out_y = fy - 0.15 * np.sin(frot) + x = np.clip(out_x + rng.normal(0, 0.02), low[0] + eps, + high[0] - eps) + y = np.clip(out_y - 0.05 + rng.normal(0, 0.03), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = np.clip(rng.normal(-0.3, 0.5), low[3] + eps, + high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + if next_step and "Burner" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "burner": + bx = pre_state.get(obj, "x") + by = pre_state.get(obj, "y") + x = np.clip(bx + rng.normal(0, 0.05), low[0] + eps, + high[0] - eps) + y = np.clip(by + rng.normal(0, 0.05), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = rng.uniform(low[3] + eps, high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + return rng.uniform(low + eps, high - eps).astype(np.float32) + + +def _generate_oracle_transitions( + env, task, options, oracle, +) -> List[Tuple[State, Action, State]]: + """Generate (s, a, s') triples by running the oracle on the boil task. + + Parses the plan sketch, runs backtracking refinement to find + continuous parameters, then replays the plan through the oracle + model to collect step-level transitions with real actions. + """ + predicates = env.predicates + sketch = _parse_sketch_from_file(SKETCH_FILE, options, env.types, + predicates, list(task.init)) + n = len(sketch) + rng = np.random.default_rng(0) + max_tries = [ + 500 if step.option.params_space.shape[0] > 0 else 1 + for step in sketch + ] + + def sample_fn(idx, state, rng_): + step = sketch[idx] + if step.option.params_space.shape[0] == 0: + params = np.array([], dtype=np.float32) + elif step.option.name == "Place": + params = _informed_place_params(state, sketch, idx, rng_, n) + else: + low = step.option.params_space.low + high = step.option.params_space.high + params = rng_.uniform(low, high).astype(np.float32) + grounded = step.option.ground(step.objects, params) + if grounded.name == "Wait" and step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + return grounded + + def validate_fn(idx, _pre, _opt, post_state, _n_acts): + step = sketch[idx] + if step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + return False, "subgoal missing" + if idx == n - 1 and not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + # Collect trajectories during refinement (not replay, since + # PyBullet state reconstruction is imperfect). + step_trajectories: Dict[int, object] = {} + + orig_validate = validate_fn + + def collecting_validate_fn(idx, pre, opt, post_state, n_acts): + ok, reason = orig_validate(idx, pre, opt, post_state, n_acts) + if ok and oracle.last_trajectory is not None: + step_trajectories[idx] = oracle.last_trajectory + return ok, reason + + plan, success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=oracle, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=collecting_validate_fn, + rng=rng, + timeout=1200.0, + ) + assert success, "Need a successful plan to generate transitions" + + # Extract step-level transitions from collected trajectories. + transitions: List[Tuple[State, Action, State]] = [] + for idx in sorted(step_trajectories.keys()): + traj = step_trajectories[idx] + for i in range(len(traj.actions)): + transitions.append( + (traj.states[i], traj.actions[i], traj.states[i + 1])) + + logger.info("Collected %d step-level transitions from oracle.", + len(transitions)) + return transitions + + +def test_emcee_recovers_rate_params(): + """Fit perturbed rate params from oracle-generated data.""" + np.random.seed(42) + env, task, options = _setup_env() + oracle = _build_oracle_model(env) + transitions = _generate_oracle_transitions(env, task, options, oracle) + process_features = get_gt_process_features() + + logger.info("Generated %d oracle transitions.", len(transitions)) + + def simulator_fn(state, action, params): + updates = {} + for rule in PROCESS_RULES: + updates = rule(state, updates, params) + return updates + + # Perturb rate params (50%), keep others at true. + param_specs = [] + for s in BOIL_PARAM_SPECS: + if s.name in ("water_fill_speed", "heating_speed", + "happiness_speed"): + param_specs.append(ParamSpec(s.name, s.init_value * 0.5)) + else: + param_specs.append(s) + + result = fit_params( + simulator_fn=simulator_fn, + transitions=transitions, + param_specs=param_specs, + process_features=process_features, + num_walkers=32, + num_steps=500, + burn_in=200, + noise_sigma=0.05, + ) + + fitted = result.point_estimate + logger.info("Fitted params (posterior mean):") + for name, val in fitted.items(): + true_val = GT_PARAMS[name] + rel_err = abs(val - true_val) / max(true_val, 1e-8) + logger.info(" %s: fitted=%.4f, true=%.4f, rel_err=%.1f%%", + name, val, true_val, rel_err * 100) + + for name in ["water_fill_speed", "heating_speed", "happiness_speed"]: + true_val = GT_PARAMS[name] + fitted_val = fitted[name] + rel_err = abs(fitted_val - true_val) / true_val + assert rel_err < 0.3, ( + f"{name}: fitted={fitted_val:.4f}, true={true_val:.4f}, " + f"rel_err={rel_err:.1%}") + + logger.info("All rate parameter recovery checks passed.") From f392458f73fc3c0588a26b941cfa7b307c135be0 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 10:17:46 +0100 Subject: [PATCH 23/70] Update experiment configs for sim-learning - agents.yaml: comment out agent_bilevel preset, add agent_sim_learning with explorer=agent_bilevel and skip_test_until_last_ite_or_early_stopping. - common.yaml: disable failure/test video recording, set num_online_learning_cycles=1 for faster iteration. --- .../predicatorv3/approaches/agents.yaml | 22 ++++++++++++++++--- scripts/configs/predicatorv3/common.yaml | 6 ++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/scripts/configs/predicatorv3/approaches/agents.yaml b/scripts/configs/predicatorv3/approaches/agents.yaml index 946a30713..52e0f3958 100644 --- a/scripts/configs/predicatorv3/approaches/agents.yaml +++ b/scripts/configs/predicatorv3/approaches/agents.yaml @@ -12,10 +12,25 @@ APPROACHES: # agent_planner_use_visualize_state: True # agent_planner_use_annotate_scene: True # option_model_use_gui: True - agent_bilevel: - NAME: "agent_bilevel" + # agent_bilevel: + # NAME: "agent_bilevel" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_sdk_max_agent_turns_per_iteration: 50 + # agent_planner_use_scratchpad: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: True + # agent_bilevel_log_state: False + # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + agent_sim_learning: + NAME: "agent_sim_learning" FLAGS: - explorer: "agent_plan" + explorer: "agent_bilevel" demonstrator: "oracle_process_planning" terminate_on_goal_reached_and_option_terminated: True agent_sdk_use_local_sandbox: True @@ -27,6 +42,7 @@ APPROACHES: option_model_use_gui: True agent_bilevel_log_state: False agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + skip_test_until_last_ite_or_early_stopping: True # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: diff --git a/scripts/configs/predicatorv3/common.yaml b/scripts/configs/predicatorv3/common.yaml index c4d2a9ab4..cbb09dc4c 100644 --- a/scripts/configs/predicatorv3/common.yaml +++ b/scripts/configs/predicatorv3/common.yaml @@ -1,8 +1,8 @@ ARGS: - "debug" # - "use_gui" - - "make_failure_videos" - - "make_test_videos" + # - "make_failure_videos" + # - "make_test_videos" # - "make_demo_videos" # - "make_demo_images" # support images # - "make_failure_images" # query images @@ -10,7 +10,7 @@ ARGS: # - "save_atoms" FLAGS: max_initial_demos: 0 - num_online_learning_cycles: 0 + num_online_learning_cycles: 1 online_nsrt_learning_requests_per_cycle: 1 skill_phase_use_motion_planning: True max_num_steps_interaction_request: 300 From 7663d05ba09724d2b5650718b2345ac059a8c896 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 10:47:31 +0100 Subject: [PATCH 24/70] Refactor sim-learning: extract primitives, add GT simulator factory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simulation primitives (code_sim_learning/utils.py): - apply_rules(state, rules, params) → ProcessUpdate - merge_updates(base_state, updates, process_features) → State - simulate_step(state, action, base_env, rules, params, features) → State These replace _build_fitted_step_fn, merge_process_updates, _sim_fn_from_rules, and the body of _build_combined_simulator. GT simulator factory (ground_truth_models): - GroundTruthSimulatorFactory ABC + get_gt_simulator(env_name) discovery, following the existing get_gt_options / get_gt_nsrts pattern. - PyBulletBoilGroundTruthSimulatorFactory registered in boil/. - Replaces the hardcoded _load_oracle_simulator in the approach. Oracle ablation flags (settings.py): - agent_sim_learn_oracle_sim_program: load GT rules, skip synthesis. - agent_sim_learn_oracle_sim_params: use GT param values, skip MCMC. Also: kin_env → base_env rename throughout, redundant self._types assignment removed, process_features computed once in __init__. --- predicators/agent_sdk/tools.py | 10 +- .../approaches/agent_sim_learning_approach.py | 247 +++++++----------- predicators/code_sim_learning/utils.py | 90 ++++++- predicators/ground_truth_models/__init__.py | 37 +++ .../ground_truth_models/boil/__init__.py | 4 +- .../ground_truth_models/boil/gt_simulator.py | 17 ++ predicators/settings.py | 7 + .../test_agent_sim_learning_approach.py | 15 +- 8 files changed, 261 insertions(+), 166 deletions(-) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index b375e0580..02c493329 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -1970,7 +1970,7 @@ def create_synthesis_tools( exec_ns: Dict[str, Any], step_transitions: list, process_features: Dict[str, List[str]], - kin_env: Any = None, + base_env: Any = None, save_dir: Optional[str] = None, ) -> list: """Create MCP tools for the sim-learning synthesis agent. @@ -1988,7 +1988,7 @@ def create_synthesis_tools( contain ``trajectories``, ``np``, ``ParamSpec``. step_transitions: ``(State, Action, State)`` triples. process_features: ``{type_name: [feat_names]}`` for MSE. - kin_env: Kinematics-only environment. When provided, + base_env: Kinematics-only environment. When provided, evaluate/test tools run kinematics before learned rules. save_dir: Directory to save simulator source code to. Each ``run_python`` call appends code to @@ -2079,7 +2079,7 @@ async def evaluate_simulator( fitted_params, mse = ( AgentSimLearningApproach._fit_parameters( rules, specs, step_transitions, process_features, - kin_env)) + base_env)) except Exception as e: # pylint: disable=broad-except return _text(f"Error: fit_params failed:\n{e}") @@ -2139,8 +2139,8 @@ async def test_simulator( for s_t, action, s_next_obs in pairs: # Run kinematics first so rules see post-kin state. - kin_state = (kin_env.simulate(s_t, action) - if kin_env is not None else s_t) + kin_state = (base_env.simulate(s_t, action) + if base_env is not None else s_t) updates: Dict = {} for rule in rules: updates = rule(kin_state, updates, t_params) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index a7a656dd0..695019c76 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -30,8 +30,10 @@ from predicators.agent_sdk.tools import create_synthesis_tools from predicators.code_sim_learning.training import (ParamSpec, compute_mse, fit_params) -from predicators.code_sim_learning.utils import LearnedSimulator +from predicators.code_sim_learning.utils import (LearnedSimulator, + apply_rules, merge_updates) from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_simulator from predicators.option_model import _OptionModelBase, _OracleOptionModel from predicators.settings import CFG from predicators.structs import (Action, InteractionResult, @@ -41,67 +43,6 @@ logger = logging.getLogger(__name__) -# ── Helpers ─────────────────────────────────────────────────────── - - -def _build_fitted_step_fn( - process_rules: List, - fitted_params: Dict[str, float], -) -> Callable[[State], Dict]: - """Create a step function from fitted process rules + parameters.""" - - def step_fn(state: State) -> Dict: - updates: Dict = {} - for rule in process_rules: - updates = rule(state, updates, fitted_params) - result: Dict = {} - for obj, feat_dict in updates.items(): - result[obj] = {} - for feat, val in feat_dict.items(): - result[obj][feat] = float(val) - return result - - return step_fn - - -def merge_process_updates( - base_state: State, - updates: Dict, - process_features: Dict[str, List[str]], -) -> State: - """Apply learned process updates on top of a base state. - - Args: - base_state: The state to merge into (e.g. from kinematics). - updates: {Object: {feat_name: new_value}} from learned dynamics. - process_features: {type_name: [feat_names]} identifying which - features to overwrite. - - Returns: - A copy of base_state with process features overwritten. - """ - if not updates: - return base_state - - new_data = {} - for obj in base_state: - arr = base_state[obj].copy() - type_name = obj.type.name - process_feats = set(process_features.get(type_name, [])) - - if obj in updates: - for feat_name, new_val in updates[obj].items(): - if feat_name in process_feats: - idx = obj.type.feature_names.index(feat_name) - arr[idx] = new_val - - new_data[obj] = arr - - merged = base_state.copy() - merged.data = new_data - return merged - - # ── Approach ───────────────────────────────────────────────────── @@ -141,11 +82,6 @@ def __init__(self, use_gui=CFG.option_model_use_gui, skip_process_dynamics=True) if option_model is None: - # Use initial_options directly rather than get_gt_options(CFG.env) - # — the latter calls get_or_create_env which would create a - # second cached env (without GUI, with full dynamics) and the - # two PyBullet connections then fight over the physics server, - # producing "Not connected to physics server" mid-rollout. option_model = _OracleOptionModel(initial_options, self._base_env.simulate) super().__init__(initial_predicates, @@ -156,8 +92,11 @@ def __init__(self, *args, option_model=option_model, **kwargs) - self._types = types self._simulator: Optional[LearnedSimulator] = None + self._process_features: Dict[str, List[str]] = { + t.name: list(t.feature_names) + for t in types if t.feature_names + } # Persistent state across learning cycles. self._process_rules: Optional[List] = None self._fitted_params: Optional[Dict[str, float]] = None @@ -183,43 +122,25 @@ def learn_from_interaction_results( self, results: Sequence[InteractionResult]) -> None: super().learn_from_interaction_results(results) - if not self._online_trajectories: - logger.warning("No transitions, skipping.") - return - - logger.info("Sim-learning cycle %d: %d total trajectories.", - self._online_learning_cycle, - len(self._online_trajectories)) - - # Include all features so the agent can synthesize rules for any - # feature, not just pre-identified "process" features. - process_features: Dict[str, List[str]] = {} - for t in self._types: - if t.feature_names: - process_features[t.name] = list(t.feature_names) + self._synthesize_with_agent(self._process_features) - # synthesize via agent. - self._synthesize_with_agent(process_features) - - # Build simulator from fitted rules. + # Build learned simulator. if self._process_rules is not None and self._fitted_params is not None: - step_fn = _build_fitted_step_fn( - self._process_rules, self._fitted_params) + rules, params = self._process_rules, self._fitted_params self._simulator = LearnedSimulator( - step_fn=step_fn, + step_fn=lambda s, _r=rules, _p=params: apply_rules(s, _r, _p), name="agent_synthesized") elif self._simulator is None: logger.warning("Synthesis produced no simulator, skipping.") return - # Build combined simulator: kinematics → learned dynamics. + # Build combined simulator. combined_sim = self._build_combined_simulator( - self._base_env, self._simulator, process_features) + self._base_env, self._simulator, self._process_features) - # Wrap in an option model with interleaved per-step simulation. + # Build learned option model self._option_model = self._build_option_model(combined_sim) - logger.info("Built learned option model (MSE: %.6f).", - self._fit_mse) + logger.info("Built learned option model (MSE: %.6f).", self._fit_mse) def _build_option_model( self, @@ -256,38 +177,59 @@ def _synthesize_with_agent( ``trajectories`` pre-loaded), then defines ``PROCESS_RULES`` and ``PARAM_SPECS``. Each ``run_python`` call appends code to a saved file; after the session we reload from that file. + + Behaviour is modified by two CFG flags: + + - ``agent_sim_learn_oracle_sim_program``: skip agent synthesis + and load GT rules/specs instead (init_values perturbed so + MCMC has non-trivial work). + - ``agent_sim_learn_oracle_sim_params``: skip MCMC fitting and + use the GT parameter values directly. """ step_transitions = self._extract_step_transitions( self._online_trajectories) - # Directory for saving simulator source code. - base = self._tool_context.sandbox_dir or self._get_log_dir() - save_dir = os.path.join(base, "simulator_code") - - # Persistent exec namespace — the agent's "scratch-pad". - exec_ns: Dict[str, Any] = { - "trajectories": self._online_trajectories, - "np": np, - "ParamSpec": ParamSpec, - } - - # Build synthesis tools (run_python, evaluate, test). - tools = create_synthesis_tools( - exec_ns, step_transitions, process_features, self._base_env, - save_dir=save_dir) - self._tool_context.extra_mcp_tools = tools - self._learning_mode = True - - # Force a fresh session so the synthesis system prompt and - # tool set take effect. - self._close_agent_session() - self._ensure_agent_session() + # ── Obtain rules + specs ──────────────────────────────── + if CFG.agent_sim_learn_oracle_sim_program: + rules, specs = get_gt_simulator(CFG.env) + if not CFG.agent_sim_learn_oracle_sim_params: + rng = np.random.default_rng(CFG.seed) + specs = [ + ParamSpec(s.name, s.init_value + rng.normal( + 0, max(abs(s.init_value) * 0.2, 1e-4))) + for s in specs + ] + logger.info("Loaded oracle sim program (%d rules, %d params).", + len(rules), len(specs)) + else: + # Directory for saving simulator source code. + base = self._tool_context.sandbox_dir or self._get_log_dir() + save_dir = os.path.join(base, "simulator_code") + + # Persistent exec namespace — the agent's "scratch-pad". + exec_ns: Dict[str, Any] = { + "trajectories": self._online_trajectories, + "np": np, + "ParamSpec": ParamSpec, + } + + # Build synthesis tools (run_python, evaluate, test). + tools = create_synthesis_tools( + exec_ns, step_transitions, process_features, self._base_env, + save_dir=save_dir) + self._tool_context.extra_mcp_tools = tools + self._learning_mode = True + + # Force a fresh session so the synthesis system prompt and + # tool set take effect. + self._close_agent_session() + self._ensure_agent_session() - # Write data-structure reference for the agent to Read. - structs_ref = self._write_structs_reference() + # Write data-structure reference for the agent to Read. + structs_ref = self._write_structs_reference() - n_trajs = len(self._online_trajectories) - message = f"""\ + n_trajs = len(self._online_trajectories) + message = f"""\ Synthesize a process dynamics simulator for this environment. \ There are {n_trajs} trajectories ({len(step_transitions)} step \ transitions) available. @@ -296,28 +238,38 @@ def _synthesize_with_agent( Read that file first, then explore the trajectory data with \ `run_python` and define PROCESS_RULES and PARAM_SPECS.""" - try: - self._query_agent_sync(message) - finally: - self._tool_context.extra_mcp_tools = [] - self._learning_mode = False - self._close_agent_session() + try: + self._query_agent_sync(message) + finally: + self._tool_context.extra_mcp_tools = [] + self._learning_mode = False + self._close_agent_session() - # Load results from saved versioned files. - rules, specs = self._load_simulator_from_file( - save_dir, self._online_trajectories) - if rules is None or specs is None: - return + # Load results from saved versioned files. + rules, specs = self._load_simulator_from_file( + save_dir, self._online_trajectories) + if rules is None or specs is None: + return + + logger.info("Agent synthesized %d rules, %d params.", + len(rules), len(specs)) self._process_rules = rules - # Fit parameters via MCMC. - self._fitted_params, self._fit_mse = self._fit_parameters( - rules, specs, step_transitions, process_features, - self._base_env) - logger.info( - "Agent synthesized %d rules, %d params (MSE: %.6f).", - len(rules), len(specs), self._fit_mse) + # ── Obtain fitted parameters ──────────────────────────── + base = self._base_env + if CFG.agent_sim_learn_oracle_sim_params: + self._fitted_params = {s.name: s.init_value for s in specs} + self._fit_mse = compute_mse( + lambda s, a, p: apply_rules(base.simulate(s, a), rules, p), + step_transitions, self._fitted_params, process_features) + logger.info("Using oracle params (MSE: %.6f).", self._fit_mse) + else: + self._fitted_params, self._fit_mse = self._fit_parameters( + rules, specs, step_transitions, process_features, + base) + logger.info("Fitted %d params (MSE: %.6f).", + len(specs), self._fit_mse) # ── Parameter fitting ──────────────────────────────────────── @@ -327,12 +279,12 @@ def _fit_parameters( specs: List[ParamSpec], step_transitions: List[Tuple[State, Action, State]], process_features: Dict[str, List[str]], - kin_env: Any = None, + base_env: Any = None, ) -> Tuple[Dict[str, float], float]: """Fit parameters for the synthesized rules via MCMC. Args: - kin_env: Kinematics-only environment. When provided the + base_env: Kinematics-only environment. When provided the simulator runs kinematics first so learned rules see the post-kinematics state (consistent with inference). @@ -342,12 +294,9 @@ def _fit_parameters( def sim_fn(state: State, action: Action, params: Dict[str, float]) -> Dict: - if kin_env is not None: - state = kin_env.simulate(state, action) - updates: Dict = {} - for rule in rules: - updates = rule(state, updates, params) - return updates + if base_env is not None: + state = base_env.simulate(state, action) + return apply_rules(state, rules, params) result = fit_params( simulator_fn=sim_fn, @@ -453,18 +402,18 @@ def _extract_step_transitions( @staticmethod def _build_combined_simulator( - kin_env: Any, + base_env: Any, simulator: LearnedSimulator, process_features: Dict[str, List[str]], ) -> Callable[[State, Action], State]: """Compose kinematics-only env with learned step-level dynamics.""" def combined_simulate(state: State, action: Action) -> State: - kin_state = kin_env.simulate(state, action) + kin_state = base_env.simulate(state, action) updates = simulator.predict_step(kin_state) if not updates: return kin_state - return merge_process_updates(kin_state, updates, process_features) + return merge_updates(kin_state, updates, process_features) return combined_simulate diff --git a/predicators/code_sim_learning/utils.py b/predicators/code_sim_learning/utils.py index 0d541f6f2..e0cdeea36 100644 --- a/predicators/code_sim_learning/utils.py +++ b/predicators/code_sim_learning/utils.py @@ -1,11 +1,20 @@ -"""Utilities for the code sim-learning module.""" +"""Utilities for the code sim-learning module. + +Core primitives for process-dynamics simulation: + +* ``apply_rules`` — run a list of rule functions on a state, return + feature updates (``ProcessUpdate``). +* ``merge_updates`` — overwrite process features in a ``State`` with + values from a ``ProcessUpdate``. +* ``simulate_step`` — full pipeline: kinematics → rules → merge. +""" from __future__ import annotations import logging -from typing import Callable, Dict +from typing import Any, Callable, Dict, List -from predicators.structs import Object, State +from predicators.structs import Action, Object, State logger = logging.getLogger(__name__) @@ -13,6 +22,81 @@ ProcessUpdate = Dict[Object, Dict[str, float]] +# ── Primitives ──────────────────────────────────────────────────── + + +def apply_rules(state: State, rules: List, + params: Dict[str, float]) -> ProcessUpdate: + """Apply process rules sequentially and return feature updates. + + Each rule has signature ``rule(state, updates, params) -> updates``. + Values are normalised to plain floats (rules may return numpy + scalars). + """ + updates: ProcessUpdate = {} + for rule in rules: + updates = rule(state, updates, params) + return { + obj: {feat: float(val) for feat, val in feat_dict.items()} + for obj, feat_dict in updates.items() + } + + +def merge_updates( + base_state: State, + updates: ProcessUpdate, + process_features: Dict[str, List[str]], +) -> State: + """Overwrite process features in *base_state* with *updates*. + + Only features listed in ``process_features[type_name]`` are + overwritten; all other features are preserved from *base_state*. + """ + if not updates: + return base_state + + new_data = {} + for obj in base_state: + arr = base_state[obj].copy() + type_name = obj.type.name + process_feats = set(process_features.get(type_name, [])) + + if obj in updates: + for feat_name, new_val in updates[obj].items(): + if feat_name in process_feats: + idx = obj.type.feature_names.index(feat_name) + arr[idx] = new_val + + new_data[obj] = arr + + merged = base_state.copy() + merged.data = new_data + return merged + + +def simulate_step( + state: State, + action: Action, + base_env: Any, + rules: List, + params: Dict[str, float], + process_features: Dict[str, List[str]], +) -> State: + """Full simulation pipeline: kinematics → rules → merge. + + Runs ``base_env.simulate`` for kinematics, ``apply_rules`` for + process dynamics, and ``merge_updates`` to combine them. + """ + kin_state = base_env.simulate(state, action) + updates = apply_rules(kin_state, rules, params) + if not updates: + return kin_state + return merge_updates(kin_state, updates, process_features) + + +# ── LearnedSimulator ────────────────────────────────────────────── + + class LearnedSimulator: """Wraps a step-level simulator function (handwritten or LLM-synthesized). diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index 2aa01dff4..3ca0f04ca 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -68,6 +68,28 @@ def get_processes( raise NotImplementedError("Override me!") +class GroundTruthSimulatorFactory(abc.ABC): + """Parent class for ground-truth process-dynamics simulator programs.""" + + @classmethod + @abc.abstractmethod + def get_env_names(cls) -> Set[str]: + """Get the env names that this factory builds simulators for.""" + raise NotImplementedError("Override me!") + + @classmethod + @abc.abstractmethod + def get_rules(cls) -> list: + """Return the list of process rule functions.""" + raise NotImplementedError("Override me!") + + @classmethod + @abc.abstractmethod + def get_param_specs(cls) -> list: + """Return the list of ParamSpec objects.""" + raise NotImplementedError("Override me!") + + class GroundTruthLDLBridgePolicyFactory(abc.ABC): """Ground-truth policies implemented with LDLs saved in text files.""" @@ -251,6 +273,21 @@ def get_gt_processes(env_name: str, return final_processes +def get_gt_simulator(env_name: str) -> tuple: + """Load ground-truth process rules and param specs for an env. + + Returns ``(rules, param_specs)`` where *rules* is a list of + process rule functions and *param_specs* is a list of + ``ParamSpec`` objects whose ``init_value`` is the GT value. + """ + gt_name = _normalize_env_name_for_gt(env_name) + for cls in utils.get_all_subclasses(GroundTruthSimulatorFactory): + if not cls.__abstractmethods__ and gt_name in cls.get_env_names(): + return cls.get_rules(), cls.get_param_specs() + raise NotImplementedError("Ground-truth simulator not implemented for " + f"env: {env_name}") + + def get_gt_ldl_bridge_policy(env_name: str, types: Set[Type], predicates: Set[Predicate], options: Set[ParameterizedOption], diff --git a/predicators/ground_truth_models/boil/__init__.py b/predicators/ground_truth_models/boil/__init__.py index cde72a21a..12fb982f8 100644 --- a/predicators/ground_truth_models/boil/__init__.py +++ b/predicators/ground_truth_models/boil/__init__.py @@ -1,5 +1,6 @@ """Ground-truth models for coffee environment and variants.""" +from .gt_simulator import PyBulletBoilGroundTruthSimulatorFactory from .nsrts import PyBulletBoilGroundTruthNSRTFactory from .options import PyBulletBoilGroundTruthOptionFactory from .processes import PyBulletBoilGroundTruthProcessFactory @@ -7,5 +8,6 @@ __all__ = [ "PyBulletBoilGroundTruthNSRTFactory", "PyBulletBoilGroundTruthOptionFactory", - "PyBulletBoilGroundTruthProcessFactory" + "PyBulletBoilGroundTruthProcessFactory", + "PyBulletBoilGroundTruthSimulatorFactory", ] diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index 9e3c46054..22573a5ab 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -12,6 +12,7 @@ from predicators.code_sim_learning.training import ParamSpec from predicators.code_sim_learning.utils import ProcessUpdate +from predicators.ground_truth_models import GroundTruthSimulatorFactory from predicators.structs import Object, State # Constants matching pybullet_boil.py exactly. @@ -156,6 +157,22 @@ def _get_val(obj: Object, feat: str) -> float: PROCESS_RULES = [_water_filling, _heating, _happiness] +class PyBulletBoilGroundTruthSimulatorFactory(GroundTruthSimulatorFactory): + """GT process-dynamics simulator for pybullet_boil.""" + + @classmethod + def get_env_names(cls): + return {"pybullet_boil"} + + @classmethod + def get_rules(cls): + return list(PROCESS_RULES) + + @classmethod + def get_param_specs(cls): + return list(BOIL_PARAM_SPECS) + + def get_gt_process_features() -> Dict[str, List[str]]: """Process features handled by the simulator (not PyBullet).""" return { diff --git a/predicators/settings.py b/predicators/settings.py index c1b23423a..ef898e028 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1022,6 +1022,13 @@ class GlobalSettings: # upstream step multiplies the cost. agent_bilevel_explorer_max_samples_per_step = 50 + # Sim-learning oracle flags (for ablation / debugging). + # When True, load GT process rules instead of running agent synthesis. + # Parameters init_values are perturbed so MCMC still has work to do. + agent_sim_learn_oracle_sim_program = False + # When True, use GT parameter values directly, skipping MCMC fitting. + agent_sim_learn_oracle_sim_params = False + @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: """A workaround for global settings that are derived from the diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index 4e1367fa5..55d68fbf3 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -16,8 +16,7 @@ from predicators import utils from predicators.approaches.agent_bilevel_approach import _SketchStep -from predicators.approaches.agent_sim_learning_approach import \ - merge_process_updates +from predicators.code_sim_learning.utils import merge_updates from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_options from predicators.ground_truth_models.boil.gt_simulator import \ @@ -69,10 +68,10 @@ def _build_kinematics_only_oracle(env): Creates a separate env instance with process dynamics disabled, so that water filling, heating, and happiness are not simulated. """ - kin_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, + base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, skip_process_dynamics=True) - options = get_gt_options(kin_env.get_name()) - oracle = _OracleOptionModel(options, kin_env.simulate) + options = get_gt_options(base_env.get_name()) + oracle = _OracleOptionModel(options, base_env.simulate) preds = env.predicates oracle._abstract_function = lambda s: utils.abstract(s, preds) return oracle @@ -85,19 +84,19 @@ def _build_combined_model(env): env.simulate with a step-level dynamics function into a single simulator, then plug into a standard _OracleOptionModel. """ - kin_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, + base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, skip_process_dynamics=True) process_features = get_gt_process_features() gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} def combined_simulate(state, action): - kin_state = kin_env.simulate(state, action) + kin_state = base_env.simulate(state, action) updates = {} for rule in PROCESS_RULES: updates = rule(kin_state, updates, gt_params) if not updates: return kin_state - return merge_process_updates(kin_state, updates, process_features) + return merge_updates(kin_state, updates, process_features) options = get_gt_options(env.get_name()) model = _OracleOptionModel(options, combined_simulate) From 9970dd48f69dcb02eaa4a9d6576f73be1f6dac2e Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 11:15:16 +0100 Subject: [PATCH 25/70] Fix formatting, pylint, and mypy issues for CI compliance - yapf + isort autoformatting applied to all touched files. - pylint: fix logging-not-lazy in agent_bilevel_explorer, add broad-except and reimported disables in agent_sim_learning_approach. - mypy: fix base/env variable name collision, add type: ignore on lambda inference, add return type annotations to GT factory methods. --- predicators/agent_sdk/agent_session_mixin.py | 10 ++- predicators/agent_sdk/bilevel_sketch.py | 28 +++---- predicators/agent_sdk/tools.py | 59 +++++++------- .../approaches/agent_sim_learning_approach.py | 79 ++++++++++--------- predicators/code_sim_learning/utils.py | 4 +- .../explorers/agent_bilevel_explorer.py | 18 ++--- .../ground_truth_models/boil/gt_simulator.py | 9 +-- .../test_agent_sim_learning_approach.py | 12 ++- 8 files changed, 106 insertions(+), 113 deletions(-) diff --git a/predicators/agent_sdk/agent_session_mixin.py b/predicators/agent_sdk/agent_session_mixin.py index 1f518e356..325974882 100644 --- a/predicators/agent_sdk/agent_session_mixin.py +++ b/predicators/agent_sdk/agent_session_mixin.py @@ -129,15 +129,17 @@ def _ensure_agent_session(self) -> None: ) extra_names = [ - getattr(t, "name", "") for t in - self._tool_context.extra_mcp_tools] + getattr(t, "name", "") + for t in self._tool_context.extra_mcp_tools + ] self._agent_session = AgentSessionManager( system_prompt=self._get_agent_system_prompt(), mcp_server=mcp_server, log_dir=self._get_log_dir(), model_name=CFG.agent_sdk_model_name, - allowed_tools=get_allowed_tool_list( - tool_names, extra_names=extra_names or None), + allowed_tools=get_allowed_tool_list(tool_names, + extra_names=extra_names + or None), ) if self._agent_session_id is not None: diff --git a/predicators/agent_sdk/bilevel_sketch.py b/predicators/agent_sdk/bilevel_sketch.py index f088ee0b5..672f1bbd7 100644 --- a/predicators/agent_sdk/bilevel_sketch.py +++ b/predicators/agent_sdk/bilevel_sketch.py @@ -208,8 +208,7 @@ def parse_subgoal_annotations( is_neg = atom_match.group(1) is not None pred_name = atom_match.group(2) obj_names = [ - n.strip().split(':')[0] - for n in atom_match.group(3).split(',') + n.strip().split(':')[0] for n in atom_match.group(3).split(',') ] if pred_name not in pred_map: @@ -222,9 +221,8 @@ def parse_subgoal_annotations( logging.warning(f"Unknown object in subgoal: {e}") continue if len(objs) != len(pred.types): - logging.warning( - f"Arity mismatch for {pred_name}: expected " - f"{len(pred.types)}, got {len(objs)}") + logging.warning(f"Arity mismatch for {pred_name}: expected " + f"{len(pred.types)}, got {len(objs)}") continue atom = GroundAtom(pred, objs) if is_neg: @@ -259,11 +257,7 @@ def parse_sketch_from_text( option_names = {o.name for o in options} parsed = utils.parse_model_output_into_option_plan( - cleaned_text, - objects, - types, - options, - parse_continuous_params=False) + cleaned_text, objects, types, options, parse_continuous_params=False) if not parsed: return [] @@ -283,9 +277,7 @@ def parse_sketch_from_text( subgoal_neg_atoms=neg if neg else None)) else: sketch.append( - SketchStep(option=option, - objects=objs, - subgoal_atoms=None)) + SketchStep(option=option, objects=objs, subgoal_atoms=None)) return sketch @@ -366,8 +358,7 @@ def sample_fn(idx: int, state: State, return grounded def validate_fn(idx: int, _pre_state: State, _option: _Option, - post_state: State, - _num_actions: int) -> Tuple[bool, str]: + post_state: State, _num_actions: int) -> Tuple[bool, str]: step = sketch[idx] if check_subgoals and step.subgoal_atoms is not None: current_atoms = utils.abstract(post_state, predicates) @@ -415,10 +406,9 @@ def wrapped_on_step_fail(idx: int, cur_plan: List[Optional[_Option]], and deepest_subgoal_fail_idx[0] >= 0): snapshot = deepest_subgoal_fail_prefix[0] refined = [p for p in snapshot if p is not None] - logging.info( - f"[{run_id}] Truncating at deepest subgoal failure " - f"(step {deepest_subgoal_fail_idx[0]}): " - f"{len(refined)}/{n} steps in experiment plan.") + logging.info(f"[{run_id}] Truncating at deepest subgoal failure " + f"(step {deepest_subgoal_fail_idx[0]}): " + f"{len(refined)}/{n} steps in experiment plan.") return cast(List[_Option], refined), False, total_samples refined = [p for p in plan if p is not None] diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 02c493329..e56812599 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -2001,8 +2001,8 @@ def create_synthesis_tools( from claude_agent_sdk import \ tool # pylint: disable=import-outside-toplevel - from predicators.approaches.agent_sim_learning_approach import ( # pylint: disable=import-outside-toplevel - AgentSimLearningApproach) + from predicators.approaches.agent_sim_learning_approach import \ + AgentSimLearningApproach # pylint: disable=import-outside-toplevel _run_count = [0] # mutable counter in closure @@ -2060,33 +2060,32 @@ async def run_python(args: Dict[str, Any]) -> Dict[str, Any]: "Fit parameters using PROCESS_RULES and PARAM_SPECS " "from the run_python namespace. Reports MSE and fitted " "parameter values.", - {"type": "object", "properties": {}}, + { + "type": "object", + "properties": {} + }, ) - async def evaluate_simulator( - args: Dict[str, Any]) -> Dict[str, Any]: + async def evaluate_simulator(args: Dict[str, Any]) -> Dict[str, Any]: rules = exec_ns.get("PROCESS_RULES") specs = exec_ns.get("PARAM_SPECS") if not isinstance(rules, list) or not rules: - return _text( - "Error: PROCESS_RULES not defined. Use " - "run_python to define it first.") + return _text("Error: PROCESS_RULES not defined. Use " + "run_python to define it first.") if not isinstance(specs, list) or not specs: - return _text( - "Error: PARAM_SPECS not defined. Use " - "run_python to define it first.") + return _text("Error: PARAM_SPECS not defined. Use " + "run_python to define it first.") try: - fitted_params, mse = ( - AgentSimLearningApproach._fit_parameters( - rules, specs, step_transitions, process_features, - base_env)) + fitted_params, mse = (AgentSimLearningApproach._fit_parameters( + rules, specs, step_transitions, process_features, base_env)) except Exception as e: # pylint: disable=broad-except return _text(f"Error: fit_params failed:\n{e}") lines = [ f"MSE: {mse:.6f} on " f"{len(step_transitions)} step transitions.", - "", "Fitted parameters:", + "", + "Fitted parameters:", ] for name, val in fitted_params.items(): lines.append(f" {name}: {val:.6f}") @@ -2104,20 +2103,19 @@ async def evaluate_simulator( "properties": { "max_transitions": { "type": "integer", - "description": - "Max transitions to test (default 100).", + "description": "Max transitions to test (default 100).", }, "tolerance": { - "type": "number", + "type": + "number", "description": - "Absolute tolerance for mismatch " - "(default 1e-4).", + "Absolute tolerance for mismatch " + "(default 1e-4).", }, }, }, ) - async def test_simulator( - args: Dict[str, Any]) -> Dict[str, Any]: + async def test_simulator(args: Dict[str, Any]) -> Dict[str, Any]: rules = exec_ns.get("PROCESS_RULES") specs = exec_ns.get("PARAM_SPECS") if not isinstance(rules, list) or not rules: @@ -2152,17 +2150,15 @@ async def test_simulator( if obj in updates and feat in updates[obj]: pred = updates[obj][feat] pred = (pred.item() - if hasattr(pred, "item") - else float(pred)) + if hasattr(pred, "item") else float(pred)) else: pred = s_t.get(obj, feat) obs = s_next_obs.get(obj, feat) err = abs(pred - obs) if err > tol: - entry.append( - f" {obj.name}.{feat}: " - f"pred={pred:.6f} obs={obs:.6f} " - f"err={err:.6f}") + entry.append(f" {obj.name}.{feat}: " + f"pred={pred:.6f} obs={obs:.6f} " + f"err={err:.6f}") n_tested += 1 if entry: @@ -2171,9 +2167,8 @@ async def test_simulator( lines.extend(entry) lines.append("") - lines.append( - f"Tested {n_tested} steps: {n_mismatch} mismatches, " - f"{n_tested - n_mismatch} correct.") + lines.append(f"Tested {n_tested} steps: {n_mismatch} mismatches, " + f"{n_tested - n_mismatch} correct.") return _text("\n".join(lines)) return [run_python, evaluate_simulator, test_simulator] diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index 695019c76..95e730b7a 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -26,23 +26,21 @@ from gym.spaces import Box from predicators import utils -from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach from predicators.agent_sdk.tools import create_synthesis_tools -from predicators.code_sim_learning.training import (ParamSpec, compute_mse, - fit_params) -from predicators.code_sim_learning.utils import (LearnedSimulator, - apply_rules, merge_updates) +from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach +from predicators.code_sim_learning.training import ParamSpec, compute_mse, \ + fit_params +from predicators.code_sim_learning.utils import LearnedSimulator, \ + apply_rules, merge_updates from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_simulator from predicators.option_model import _OptionModelBase, _OracleOptionModel from predicators.settings import CFG -from predicators.structs import (Action, InteractionResult, - LowLevelTrajectory, ParameterizedOption, - Predicate, State, Task, Type) +from predicators.structs import Action, InteractionResult, \ + LowLevelTrajectory, ParameterizedOption, Predicate, State, Task, Type logger = logging.getLogger(__name__) - # ── Approach ───────────────────────────────────────────────────── @@ -78,9 +76,10 @@ def __init__(self, # GUI connections) and is the only env this approach holds. # learn_from_interaction_results later wraps a kin+learned # combined simulator around the same env. - self._base_env = create_new_env(CFG.env, do_cache=False, - use_gui=CFG.option_model_use_gui, - skip_process_dynamics=True) + self._base_env = create_new_env(CFG.env, + do_cache=False, + use_gui=CFG.option_model_use_gui, + skip_process_dynamics=True) if option_model is None: option_model = _OracleOptionModel(initial_options, self._base_env.simulate) @@ -128,15 +127,17 @@ def learn_from_interaction_results( if self._process_rules is not None and self._fitted_params is not None: rules, params = self._process_rules, self._fitted_params self._simulator = LearnedSimulator( - step_fn=lambda s, _r=rules, _p=params: apply_rules(s, _r, _p), + step_fn=lambda s, _r=rules, _p=params: # type: ignore[misc] + apply_rules(s, _r, _p), name="agent_synthesized") elif self._simulator is None: logger.warning("Synthesis produced no simulator, skipping.") return # Build combined simulator. - combined_sim = self._build_combined_simulator( - self._base_env, self._simulator, self._process_features) + combined_sim = self._build_combined_simulator(self._base_env, + self._simulator, + self._process_features) # Build learned option model self._option_model = self._build_option_model(combined_sim) @@ -195,8 +196,9 @@ def _synthesize_with_agent( if not CFG.agent_sim_learn_oracle_sim_params: rng = np.random.default_rng(CFG.seed) specs = [ - ParamSpec(s.name, s.init_value + rng.normal( - 0, max(abs(s.init_value) * 0.2, 1e-4))) + ParamSpec( + s.name, s.init_value + + rng.normal(0, max(abs(s.init_value) * 0.2, 1e-4))) for s in specs ] logger.info("Loaded oracle sim program (%d rules, %d params).", @@ -214,9 +216,11 @@ def _synthesize_with_agent( } # Build synthesis tools (run_python, evaluate, test). - tools = create_synthesis_tools( - exec_ns, step_transitions, process_features, self._base_env, - save_dir=save_dir) + tools = create_synthesis_tools(exec_ns, + step_transitions, + process_features, + self._base_env, + save_dir=save_dir) self._tool_context.extra_mcp_tools = tools self._learning_mode = True @@ -251,25 +255,26 @@ def _synthesize_with_agent( if rules is None or specs is None: return - logger.info("Agent synthesized %d rules, %d params.", - len(rules), len(specs)) + logger.info("Agent synthesized %d rules, %d params.", len(rules), + len(specs)) self._process_rules = rules # ── Obtain fitted parameters ──────────────────────────── - base = self._base_env if CFG.agent_sim_learn_oracle_sim_params: self._fitted_params = {s.name: s.init_value for s in specs} + env = self._base_env self._fit_mse = compute_mse( - lambda s, a, p: apply_rules(base.simulate(s, a), rules, p), + lambda s, a, p: apply_rules( # type: ignore[misc] + env.simulate(s, a), rules, p), step_transitions, self._fitted_params, process_features) logger.info("Using oracle params (MSE: %.6f).", self._fit_mse) else: self._fitted_params, self._fit_mse = self._fit_parameters( rules, specs, step_transitions, process_features, - base) - logger.info("Fitted %d params (MSE: %.6f).", - len(specs), self._fit_mse) + self._base_env) + logger.info("Fitted %d params (MSE: %.6f).", len(specs), + self._fit_mse) # ── Parameter fitting ──────────────────────────────────────── @@ -292,8 +297,8 @@ def _fit_parameters( (fitted_params, mse) tuple. """ - def sim_fn(state: State, action: Action, - params: Dict[str, float]) -> Dict: + def sim_fn(state: State, action: Action, params: Dict[str, + float]) -> Dict: if base_env is not None: state = base_env.simulate(state, action) return apply_rules(state, rules, params) @@ -305,8 +310,8 @@ def sim_fn(state: State, action: Action, process_features=process_features, ) - mse = compute_mse( - sim_fn, step_transitions, result.point_estimate, process_features) + mse = compute_mse(sim_fn, step_transitions, result.point_estimate, + process_features) return result.point_estimate, mse @staticmethod @@ -325,9 +330,8 @@ def _load_simulator_from_file( logger.warning("No simulator code dir at %s.", save_dir) return None, None - files = sorted( - f for f in os.listdir(save_dir) - if f.endswith(".py") and f[0].isdigit()) + files = sorted(f for f in os.listdir(save_dir) + if f.endswith(".py") and f[0].isdigit()) if not files: logger.warning("No code files in %s.", save_dir) return None, None @@ -343,8 +347,9 @@ def _load_simulator_from_file( code = f.read() try: exec(code, ns) # pylint: disable=exec-used - except Exception: - logger.warning("Failed to exec %s, skipping.", fpath, + except Exception: # pylint: disable=broad-except + logger.warning("Failed to exec %s, skipping.", + fpath, exc_info=True) rules = ns.get("PROCESS_RULES") @@ -367,7 +372,7 @@ def _write_structs_reference(self) -> str: Returns the path the agent should Read. """ - from predicators.structs import ( # pylint: disable=import-outside-toplevel + from predicators.structs import ( # pylint: disable=import-outside-toplevel,reimported Action as _Action, LowLevelTrajectory as _LLT, Object as _Object, State as _State, Type as _Type) diff --git a/predicators/code_sim_learning/utils.py b/predicators/code_sim_learning/utils.py index e0cdeea36..5436a36e8 100644 --- a/predicators/code_sim_learning/utils.py +++ b/predicators/code_sim_learning/utils.py @@ -21,7 +21,6 @@ # Type alias: {Object: {feature_name: new_value}} ProcessUpdate = Dict[Object, Dict[str, float]] - # ── Primitives ──────────────────────────────────────────────────── @@ -37,7 +36,8 @@ def apply_rules(state: State, rules: List, for rule in rules: updates = rule(state, updates, params) return { - obj: {feat: float(val) for feat, val in feat_dict.items()} + obj: {feat: float(val) + for feat, val in feat_dict.items()} for obj, feat_dict in updates.items() } diff --git a/predicators/explorers/agent_bilevel_explorer.py b/predicators/explorers/agent_bilevel_explorer.py index d71344693..8c50db54c 100644 --- a/predicators/explorers/agent_bilevel_explorer.py +++ b/predicators/explorers/agent_bilevel_explorer.py @@ -108,7 +108,7 @@ def _get_exploration_strategy(self, train_task_idx: int, timeout=float(timeout), rng=np.random.default_rng(CFG.seed), max_samples_per_step=CFG. - agent_bilevel_explorer_max_samples_per_step, + agent_bilevel_explorer_max_samples_per_step, check_subgoals=True, check_final_goal=False, truncate_on_subgoal_fail=True, @@ -124,11 +124,9 @@ def _get_exploration_strategy(self, train_task_idx: int, for i, opt in enumerate(plan): obj_s = ", ".join(o.name for o in opt.objects) par_s = ", ".join(f"{p:.4f}" for p in opt.params) - plan_strs.append( - f" {i}: {opt.name}({obj_s})[{par_s}]") - logging.info( - "agent_bilevel explorer: experiment plan:\n" + - "\n".join(plan_strs)) + plan_strs.append(f" {i}: {opt.name}({obj_s})[{par_s}]") + logging.info("agent_bilevel explorer: experiment plan:\n%s", + "\n".join(plan_strs)) if plan: policy = utils.option_plan_to_policy( @@ -153,8 +151,8 @@ def _get_exploration_strategy(self, train_task_idx: int, # ------------------------------------------------------------------ # def _wrap_policy( - self, policy: Callable[[State], Action] - ) -> Callable[[State], Action]: + self, policy: Callable[[State], + Action]) -> Callable[[State], Action]: """Convert OptionExecutionFailure into RequestActPolicyFailure. This lets the main loop cleanly terminate the episode when the @@ -219,8 +217,8 @@ def _build_trajectory_summary(self) -> str: return "\n".join(lines) - def _extract_option_plan_text( - self, responses: List[Dict[str, Any]]) -> str: + def _extract_option_plan_text(self, responses: List[Dict[str, + Any]]) -> str: """Extract plan text from the last assistant text response.""" last_text_parts: List[str] = [] for resp in responses: diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index 22573a5ab..03daa230b 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -83,8 +83,7 @@ def _water_filling(state: State, updates: ProcessUpdate, spill = float(state.get(faucet, "spilled_level")) new_spill = min(params["max_water_spill_width"], spill + params["water_fill_speed"]) - updates.setdefault( - faucet, {})["spilled_level"] = new_spill + updates.setdefault(faucet, {})["spilled_level"] = new_spill break if not jug_catching: @@ -161,15 +160,15 @@ class PyBulletBoilGroundTruthSimulatorFactory(GroundTruthSimulatorFactory): """GT process-dynamics simulator for pybullet_boil.""" @classmethod - def get_env_names(cls): + def get_env_names(cls) -> set: return {"pybullet_boil"} @classmethod - def get_rules(cls): + def get_rules(cls) -> list: return list(PROCESS_RULES) @classmethod - def get_param_specs(cls): + def get_param_specs(cls) -> list: return list(BOIL_PARAM_SPECS) diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index 55d68fbf3..ecfbcebaa 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -68,8 +68,10 @@ def _build_kinematics_only_oracle(env): Creates a separate env instance with process dynamics disabled, so that water filling, heating, and happiness are not simulated. """ - base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, - skip_process_dynamics=True) + base_env = create_new_env("pybullet_boil", + do_cache=False, + use_gui=False, + skip_process_dynamics=True) options = get_gt_options(base_env.get_name()) oracle = _OracleOptionModel(options, base_env.simulate) preds = env.predicates @@ -84,8 +86,10 @@ def _build_combined_model(env): env.simulate with a step-level dynamics function into a single simulator, then plug into a standard _OracleOptionModel. """ - base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, - skip_process_dynamics=True) + base_env = create_new_env("pybullet_boil", + do_cache=False, + use_gui=False, + skip_process_dynamics=True) process_features = get_gt_process_features() gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} From 8ff80a4c107415e88a0de47cab5e1c33fb524149 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 12:39:55 +0100 Subject: [PATCH 26/70] Update test setup to use test tasks for boil environment and refine test description --- tests/approaches/test_agent_sim_learning_approach.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index ecfbcebaa..74b9bef4b 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -46,7 +46,7 @@ def _setup_env(): "wait_option_terminate_on_atom_change": True, }) env = create_new_env("pybullet_boil", do_cache=False, use_gui=False) - task = [t.task for t in env.get_train_tasks()][0] + task = [t.task for t in env.get_test_tasks()][0] options = get_gt_options(env.get_name()) options_dict = {o.name: o for o in options} objects_dict = {obj.name: obj for obj in task.init} @@ -300,7 +300,7 @@ def validate_fn(idx, _pre, _opt, post_state, _n_acts): @pytest.mark.parametrize("model_type", ["oracle", "combined"]) def test_boil_sketch_refinement(model_type): - """Test that backtracking refinement solves a boil task.""" + """Test that backtracking refinement solves the first test task.""" env, task, options_dict, objects_dict = _setup_env() predicates = env.predicates options = get_gt_options(env.get_name()) From 54002dd0ac4784be1cef7ffc682df831b864f592 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Thu, 16 Apr 2026 12:52:29 +0100 Subject: [PATCH 27/70] Refactor combined model in GT simulator --- .../test_agent_sim_learning_approach.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index 74b9bef4b..31528aa69 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -16,11 +16,12 @@ from predicators import utils from predicators.approaches.agent_bilevel_approach import _SketchStep -from predicators.code_sim_learning.utils import merge_updates +from predicators.code_sim_learning.utils import LearnedSimulator, \ + apply_rules, merge_updates from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_options from predicators.ground_truth_models.boil.gt_simulator import \ - BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features + BOIL_PARAM_SPECS, PROCESS_RULES from predicators.option_model import _OracleOptionModel from predicators.planning import run_backtracking_refinement from predicators.structs import GroundAtom, Object, ParameterizedOption, \ @@ -82,22 +83,29 @@ def _build_kinematics_only_oracle(env): def _build_combined_model(env): """Build a combined model: kinematics-only env + GT step-level dynamics. - This mirrors the approach's design: compose a kinematics-only - env.simulate with a step-level dynamics function into a single - simulator, then plug into a standard _OracleOptionModel. + Uses the same construction as AgentSimLearningApproach: wraps GT + rules in a LearnedSimulator via apply_rules, composes with a + kinematics-only env, and derives process_features from env.types + (all features, not just GT process features). """ base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, skip_process_dynamics=True) - process_features = get_gt_process_features() + process_features = { + t.name: list(t.feature_names) + for t in env.types if t.feature_names + } gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} + rules = PROCESS_RULES + + simulator = LearnedSimulator( + step_fn=lambda s, _r=rules, _p=gt_params: apply_rules(s, _r, _p), + name="gt_combined") def combined_simulate(state, action): kin_state = base_env.simulate(state, action) - updates = {} - for rule in PROCESS_RULES: - updates = rule(kin_state, updates, gt_params) + updates = simulator.predict_step(kin_state) if not updates: return kin_state return merge_updates(kin_state, updates, process_features) From cb405d9f4e8d7efc1a4804c0e9e8abbfa1f6260b Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:13:52 +0100 Subject: [PATCH 28/70] Fix expected-atoms check to support DerivedPredicates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use utils.abstract to evaluate expected atoms in low-level search so that DerivedPredicates — which require a Set[GroundAtom] rather than a State — are handled correctly alongside regular predicates. --- predicators/planning.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/predicators/planning.py b/predicators/planning.py index 162e69443..14f3889a6 100644 --- a/predicators/planning.py +++ b/predicators/planning.py @@ -688,7 +688,17 @@ def validate_fn(idx: int, pre_state: State, _option: _Option, for atom in atoms_sequence[idx + 1] if atom.predicate.name != _NOT_CAUSES_FAILURE } - if all(a.holds(post_state) for a in expected_atoms): + # Use utils.abstract to evaluate atoms so that + # DerivedPredicates (which need a Set[GroundAtom], not a + # State) are handled correctly. + preds: Set[Predicate] = set() + for a in expected_atoms: + preds.add(a.predicate) + aux = getattr(a.predicate, "auxiliary_predicates", None) + if aux: + preds.update(aux) + current_atoms = utils.abstract(post_state, preds) + if expected_atoms.issubset(current_atoms): return True, "" return False, "expected atoms not hold" # No atoms check — verify goal on final step. From 6c925724e8339711c8cc20a31eab2dc959029126 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:13:59 +0100 Subject: [PATCH 29/70] Skip kinematic reset in PyBullet when only non-kinematic state changed When sequential simulate calls differ only in process features (as in the combined kinematic+learned simulator), reapplying joint positions and tearing down/recreating grasp constraints causes visible arm jitter. Compare robot poses first and skip the kinematic reset path when they already match. --- predicators/envs/pybullet_env.py | 80 ++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 6f30b7895..910fb3f68 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -388,9 +388,9 @@ def _step_base(self, action: Action) -> None: def _domain_specific_step(self) -> None: """Apply domain-specific dynamics after kinematics. - Override in subclasses to add post-kinematics effects - (water filling, heating, balance beam physics, etc.). - Skipped when ``skip_process_dynamics=True`` is passed to the constructor. + Override in subclasses to add post-kinematics effects (water + filling, heating, balance beam physics, etc.). Skipped when + ``skip_process_dynamics=True`` is passed to the constructor. """ # ── State Write (State → PyBullet) ────────────────────────── @@ -402,43 +402,74 @@ def _set_state(self, state: State) -> None: keyed by Object) into the corresponding PyBullet scene (joint positions, body poses, grasp constraints, etc.). + When robot and object poses already match (e.g. sequential + simulate calls where only process features changed), the + kinematic reset is skipped to avoid discontinuous joint resets + and grasp constraint teardown/recreation that cause visible + jitter. + Call sites: - reset() / _add_pybullet_state_to_tasks(): initialization - simulate(): option-model / bilevel-planning rollouts - external callers (skill factories, agent tools, tests) """ + # Check if kinematics already match before overwriting + # _current_observation. When only process features differ + # (e.g. combined kin+learned simulator), we can skip the + # expensive kinematic reset that causes robot arm jitter. + skip_kin = self._kinematics_match(state) + # Keep _current_observation in sync so that step() can read it # (e.g. for finger-delta computation). self._current_observation = state self._objects = list(state.data) - # 1) Clear old constraint if we had a held object - if self._held_constraint_id is not None: - p.removeConstraint(self._held_constraint_id, - physicsClientId=self._physics_client_id) - self._held_constraint_id = None - self._held_obj_to_base_link = None - self._held_obj_id = None - # 2) Reset robot pose - self._pybullet_robot.reset_state(self._extract_robot_state(state)) + if not skip_kin: + # 1) Clear old constraint if we had a held object + if self._held_constraint_id is not None: + p.removeConstraint(self._held_constraint_id, + physicsClientId=self._physics_client_id) + self._held_constraint_id = None + self._held_obj_to_base_link = None + self._held_obj_id = None + + # 2) Reset robot pose + self._pybullet_robot.reset_state(self._extract_robot_state(state)) - # 3) Reset all known objects (position, orientation, etc.) - for obj in self._objects: - if obj.type.name == "robot" or \ - obj.type.name in self._VIRTUAL_OBJECT_TYPES: - continue - self._reset_single_object(obj, state) + # 3) Reset all known objects (position, orientation, etc.) + for obj in self._objects: + if obj.type.name == "robot" or \ + obj.type.name in self._VIRTUAL_OBJECT_TYPES: + continue + self._reset_single_object(obj, state) # 4) Let the subclass do any domain-specific state setup self._set_domain_specific_state(state) # 5) Check for reconstruction mismatch. # Only raise for envs that override _get_state(). - reconstructed = self._get_state() - if not reconstructed.allclose(state): - if type(self)._get_state is not PyBulletEnv._get_state: - raise ValueError("Could not reconstruct state.") - logging.warning("Could not reconstruct state exactly in reset.") + if not skip_kin: + reconstructed = self._get_state() + if not reconstructed.allclose(state): + if type(self)._get_state is not PyBulletEnv._get_state: + raise ValueError("Could not reconstruct state.") + logging.warning( + "Could not reconstruct state exactly in reset.") + + def _kinematics_match(self, state: State) -> bool: + """Check if robot pose in *state* matches the current PyBullet state. + + Used by ``_set_state`` to skip the kinematic reset when only + non-kinematic features (process dynamics) have changed. + """ + if self._current_observation is None: + return False + try: + new_robot = self._extract_robot_state(state) + cur_robot = self._extract_robot_state(self._current_observation) + return bool(np.allclose(new_robot, cur_robot, atol=1e-3)) + except (KeyError, ValueError): + return False def _reset_single_object(self, obj: Object, state: State) -> None: """Set a single physical object's pose and grasp constraint in PyBullet @@ -485,7 +516,8 @@ def _reset_single_object(self, obj: Object, state: State) -> None: @abc.abstractmethod def _set_domain_specific_state(self, state: State) -> None: - """Set simulator state for features that the base class doesn't handle + """Set simulator state for features that the base class doesn't handle. + — e.g. switch on/off, liquid levels, button colors, balance beam positions. From c9723f24a4fdca33276811defe6731be0ae0851a Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:14:03 +0100 Subject: [PATCH 30/70] Support offline dataset learning in AgentSimLearningApproach Factor simulator synthesis into a shared _learn_simulator helper so that both learn_from_offline_dataset and learn_from_interaction_results can trigger it on their respective trajectory sources. Also create a separate headless env for parameter fitting so MCMC's thousands of _set_state calls don't thrash the GUI env during training. --- .../approaches/agent_sim_learning_approach.py | 52 +++++++++++++------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index 95e730b7a..c415cc4b2 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -36,7 +36,7 @@ from predicators.ground_truth_models import get_gt_simulator from predicators.option_model import _OptionModelBase, _OracleOptionModel from predicators.settings import CFG -from predicators.structs import Action, InteractionResult, \ +from predicators.structs import Action, Dataset, InteractionResult, \ LowLevelTrajectory, ParameterizedOption, Predicate, State, Task, Type logger = logging.getLogger(__name__) @@ -115,13 +115,24 @@ def _get_agent_system_prompt(self) -> str: return self._build_synthesis_system_prompt() return super()._get_agent_system_prompt() - # ── Online learning ────────────────────────────────────────── + # ── Learning ──────────────────────────────────────────────── + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + super().learn_from_offline_dataset(dataset) + self._learn_simulator(dataset.trajectories) def learn_from_interaction_results( self, results: Sequence[InteractionResult]) -> None: super().learn_from_interaction_results(results) + self._learn_simulator(self._online_trajectories) + + def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: + """Synthesize rules, fit parameters, and build the option model. - self._synthesize_with_agent(self._process_features) + Shared by ``learn_from_offline_dataset`` and + ``learn_from_interaction_results``. + """ + self._synthesize_with_agent(self._process_features, trajectories) # Build learned simulator. if self._process_rules is not None and self._fitted_params is not None: @@ -169,6 +180,7 @@ def _build_option_model( def _synthesize_with_agent( self, process_features: Dict[str, List[str]], + trajectories: List[LowLevelTrajectory], ) -> None: """Synthesize parameterized process rules via a Claude agent. @@ -187,8 +199,7 @@ def _synthesize_with_agent( - ``agent_sim_learn_oracle_sim_params``: skip MCMC fitting and use the GT parameter values directly. """ - step_transitions = self._extract_step_transitions( - self._online_trajectories) + step_transitions = self._extract_step_transitions(trajectories) # ── Obtain rules + specs ──────────────────────────────── if CFG.agent_sim_learn_oracle_sim_program: @@ -210,7 +221,7 @@ def _synthesize_with_agent( # Persistent exec namespace — the agent's "scratch-pad". exec_ns: Dict[str, Any] = { - "trajectories": self._online_trajectories, + "trajectories": trajectories, "np": np, "ParamSpec": ParamSpec, } @@ -232,7 +243,7 @@ def _synthesize_with_agent( # Write data-structure reference for the agent to Read. structs_ref = self._write_structs_reference() - n_trajs = len(self._online_trajectories) + n_trajs = len(trajectories) message = f"""\ Synthesize a process dynamics simulator for this environment. \ There are {n_trajs} trajectories ({len(step_transitions)} step \ @@ -251,7 +262,7 @@ def _synthesize_with_agent( # Load results from saved versioned files. rules, specs = self._load_simulator_from_file( - save_dir, self._online_trajectories) + save_dir, trajectories) if rules is None or specs is None: return @@ -261,18 +272,24 @@ def _synthesize_with_agent( self._process_rules = rules # ── Obtain fitted parameters ──────────────────────────── + # Use a headless env for fitting so the GUI env isn't + # thrashed by thousands of _set_state calls during MCMC. + fit_env = create_new_env(CFG.env, + do_cache=False, + use_gui=False, + skip_process_dynamics=True) if CFG.agent_sim_learn_oracle_sim_params: self._fitted_params = {s.name: s.init_value for s in specs} - env = self._base_env self._fit_mse = compute_mse( lambda s, a, p: apply_rules( # type: ignore[misc] - env.simulate(s, a), rules, p), - step_transitions, self._fitted_params, process_features) + fit_env.simulate(s, a), rules, p), + step_transitions, + self._fitted_params, + process_features) logger.info("Using oracle params (MSE: %.6f).", self._fit_mse) else: self._fitted_params, self._fit_mse = self._fit_parameters( - rules, specs, step_transitions, process_features, - self._base_env) + rules, specs, step_transitions, process_features, fit_env) logger.info("Fitted %d params (MSE: %.6f).", len(specs), self._fit_mse) @@ -372,9 +389,12 @@ def _write_structs_reference(self) -> str: Returns the path the agent should Read. """ - from predicators.structs import ( # pylint: disable=import-outside-toplevel,reimported - Action as _Action, LowLevelTrajectory as _LLT, - Object as _Object, State as _State, Type as _Type) + # pylint: disable=import-outside-toplevel,reimported + from predicators.structs import Action as _Action + from predicators.structs import LowLevelTrajectory as _LLT + from predicators.structs import Object as _Object + from predicators.structs import State as _State + from predicators.structs import Type as _Type source = "\n\n".join( inspect.getsource(cls) From cccb7e229066a9faf9f9b10c0919dcd9001a30ee Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:14:07 +0100 Subject: [PATCH 31/70] Log periodic progress during MCMC parameter fitting Replace the silent run_mcmc call with a manual sample loop that logs step count and best log-probability roughly five times per run, and flushes handlers so the updates appear promptly under buffered logging. --- predicators/code_sim_learning/training.py | 28 +++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index bffb8dd8c..4383aa64f 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -15,7 +15,6 @@ logger = logging.getLogger(__name__) - # Step-level simulator: (State, Action, params_dict) -> {Object: {feat: val}} StepSimulatorFn = Callable[[State, Action, Dict[str, float]], Dict] @@ -64,7 +63,7 @@ def compute_mse( continue v = pred_val.item() if hasattr(pred_val, 'item') else pred_val obs_val = float(s_next_obs.get(obj, feat_name)) - total_se += (v - obs_val) ** 2 + total_se += (v - obs_val)**2 count += 1 # Penalize unpredicted features (model predicts no change). @@ -75,7 +74,7 @@ def compute_mse( continue pred_val = float(s_t.get(obj, feat_name)) obs_val = float(s_next_obs.get(obj, feat_name)) - total_se += (pred_val - obs_val) ** 2 + total_se += (pred_val - obs_val)**2 count += 1 if count == 0: @@ -128,12 +127,10 @@ def log_posterior(theta: np.ndarray) -> float: return -np.inf params = {n: float(theta[i]) for i, n in enumerate(names)} # Broad Gaussian prior centered on init values - log_prior = -0.5 * np.sum( - ((theta - init_values) / prior_sigma) ** 2) + log_prior = -0.5 * np.sum(((theta - init_values) / prior_sigma)**2) # Likelihood - mse = compute_mse(simulator_fn, transitions, - params, process_features) - return log_prior + (-0.5 * mse / (noise_sigma ** 2)) + mse = compute_mse(simulator_fn, transitions, params, process_features) + return log_prior + (-0.5 * mse / (noise_sigma**2)) # Initialize walkers in a small ball around init values. p0 = init_values * (1.0 + 0.01 * np.random.randn(num_walkers, ndim)) @@ -142,7 +139,17 @@ def log_posterior(theta: np.ndarray) -> float: logger.info("Running emcee: %d walkers, %d steps, %d burn-in.", num_walkers, num_steps, burn_in) - sampler.run_mcmc(p0, num_steps, progress=False) + + # Run with periodic progress reports. + report_interval = max(1, num_steps // 5) + for i, _result in enumerate(sampler.sample(p0, iterations=num_steps), + start=1): + if i % report_interval == 0 or i == num_steps: + best_lp = sampler.get_log_prob()[:i].max() + logger.info(" emcee step %d/%d (best log-prob: %.2f)", i, + num_steps, best_lp) + for h in logger.handlers + logging.getLogger().handlers: + h.flush() # Discard burn-in, flatten chains. samples = sampler.get_chain(discard=burn_in, flat=True) @@ -151,6 +158,7 @@ def log_posterior(theta: np.ndarray) -> float: result = FitResult(names=names, samples=samples, log_probs=log_probs) logger.info("emcee done. Posterior mean: %s", - {k: f"{v:.4f}" for k, v in result.point_estimate.items()}) + {k: f"{v:.4f}" + for k, v in result.point_estimate.items()}) return result From ec3b9f3171a136b8f3f9504ceeedfedcb00743a3 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:44:41 +0100 Subject: [PATCH 32/70] Fix mypy and pylint errors for CI compliance Type-annotate **kwargs on PyBullet env __init__ overrides so mypy doesn't flag them. Initialize attrs used by _domain_specific_step in __init__ (pybullet_coffee, pybullet_switch) to silence defined-outside-init. Type-ignore the emcee import. Fix encoding, unused, protected-access, and redefined-outer-name warnings in the sim-learning tests and agent-SDK tooling. --- predicators/agent_sdk/tools.py | 10 +++-- .../approaches/agent_bilevel_approach.py | 2 +- predicators/code_sim_learning/training.py | 2 +- predicators/envs/pybullet_ants.py | 6 +-- predicators/envs/pybullet_balance.py | 2 +- predicators/envs/pybullet_barrier.py | 2 +- predicators/envs/pybullet_blocks.py | 6 +-- predicators/envs/pybullet_boil.py | 2 +- predicators/envs/pybullet_circuit.py | 6 +-- predicators/envs/pybullet_coffee.py | 10 +++-- predicators/envs/pybullet_cover.py | 2 +- predicators/envs/pybullet_fan.py | 2 +- predicators/envs/pybullet_float.py | 2 +- predicators/envs/pybullet_grow.py | 2 +- predicators/envs/pybullet_laser.py | 2 +- predicators/envs/pybullet_magic_bin.py | 2 +- predicators/envs/pybullet_switch.py | 6 ++- .../test_agent_sim_learning_approach.py | 10 ++--- tests/code_sim_learning/test_param_fitting.py | 37 ++++++++++--------- 19 files changed, 61 insertions(+), 52 deletions(-) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index e56812599..aeb15edff 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -1996,7 +1996,7 @@ def create_synthesis_tools( """ import io # pylint: disable=import-outside-toplevel import sys # pylint: disable=import-outside-toplevel - import traceback # pylint: disable=import-outside-toplevel + import traceback # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported from claude_agent_sdk import \ tool # pylint: disable=import-outside-toplevel @@ -2065,7 +2065,7 @@ async def run_python(args: Dict[str, Any]) -> Dict[str, Any]: "properties": {} }, ) - async def evaluate_simulator(args: Dict[str, Any]) -> Dict[str, Any]: + async def evaluate_simulator(_args: Dict[str, Any]) -> Dict[str, Any]: rules = exec_ns.get("PROCESS_RULES") specs = exec_ns.get("PARAM_SPECS") if not isinstance(rules, list) or not rules: @@ -2076,8 +2076,10 @@ async def evaluate_simulator(args: Dict[str, Any]) -> Dict[str, Any]: "run_python to define it first.") try: - fitted_params, mse = (AgentSimLearningApproach._fit_parameters( - rules, specs, step_transitions, process_features, base_env)) + fitted_params, mse = ( + AgentSimLearningApproach._fit_parameters( # pylint: disable=protected-access + rules, specs, step_transitions, process_features, + base_env)) except Exception as e: # pylint: disable=broad-except return _text(f"Error: fit_params failed:\n{e}") diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 6461bea60..1baf550a1 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -152,7 +152,7 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: """Query agent for a plan sketch and parse it.""" sketch_file = CFG.agent_bilevel_plan_sketch_file if sketch_file: - with open(sketch_file, "r") as f: + with open(sketch_file, "r", encoding="utf-8") as f: plan_text = f.read().strip() logging.info("Loaded plan sketch from file: %s", sketch_file) else: diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index 4383aa64f..a69fb2b0c 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -113,7 +113,7 @@ def fit_params( Returns: FitResult with posterior samples and log-probabilities. """ - import emcee # pylint: disable=import-outside-toplevel + import emcee # type: ignore[import-untyped] # pylint: disable=import-outside-toplevel names = [s.name for s in param_specs] init_values = np.array([s.init_value for s in param_specs]) diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index 9d68ec92a..d02063333 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -92,7 +92,7 @@ class PyBulletAntsEnv(PyBulletEnv): def __init__(self, use_gui: bool = False, debug_layout: bool = True, - **kwargs) -> None: + **kwargs: Any) -> None: # Create single robot self._robot = Object("robot", self._robot_type) @@ -228,8 +228,8 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - """Hide unused objects, set attraction points, food colors, and - ant target references.""" + """Hide unused objects, set attraction points, food colors, and ant + target references.""" oov_x, oov_y = self._out_of_view_xy block_objs = state.get_objects(self._food_type) for i in range(len(block_objs), len(self._blocks)): diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 07b1aad06..4206875c6 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -88,7 +88,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _num_blocks_train = CFG.balance_num_blocks_train _num_blocks_test = CFG.balance_num_blocks_test - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Types # bbox_features = ["bbox_left", "bbox_right", # "bbox_upper", "bbox_lower"] diff --git a/predicators/envs/pybullet_barrier.py b/predicators/envs/pybullet_barrier.py index 9a64714e5..c0e98ebe4 100644 --- a/predicators/envs/pybullet_barrier.py +++ b/predicators/envs/pybullet_barrier.py @@ -91,7 +91,7 @@ class PyBulletBarrierEnv(PyBulletEnv): _barrier_type = Type("barrier", ["x", "y", "rot", "height"], sim_features=["id", "base_z"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._switches: List[Object] = [ diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index d3ebfb1bb..d6b5f09ce 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -27,7 +27,7 @@ class PyBulletBlocksEnv(PyBulletEnv, BlocksEnv): _table_pose: ClassVar[Pose3D] = (1.35, 0.75, table_height / 2) _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: super().__init__(use_gui, **kwargs) # Store references self._table_id: int = -1 @@ -95,8 +95,8 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: blk.id = blk_id def _set_domain_specific_state(self, state: State) -> None: - """Set block positions, grasp constraints, out-of-view placement, - ID mapping, and block colors.""" + """Set block positions, grasp constraints, out-of-view placement, ID + mapping, and block colors.""" block_objs = state.get_objects(self._block_type) # Place the relevant blocks diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 3bbf2a2b9..af1a127ce 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -174,7 +174,7 @@ def water_fill_speed(self) -> float: _human_type = Type("human", ["happiness_level"], sim_features=["id", "happiness_level"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create the robot as an Object self._robot = Object("robot", self._robot_type) diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index e1fec79bb..4155c7a9d 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -104,7 +104,7 @@ class PyBulletCircuitEnv(PyBulletEnv): _c_battery_type = Type("c_battery", ["x", "y", "z", "yaw", "pitch", "roll"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) @@ -323,8 +323,8 @@ def _domain_specific_step(self) -> None: # Check basic conditions for turning on the bulb switch_on = self._SwitchedOn_holds(state, [self._battery]) basic_conditions = switch_on and ( - CFG.circuit_light_doesnt_need_battery or self._CircuitClosed_holds( - state, [self._light, self._battery])) + CFG.circuit_light_doesnt_need_battery + or self._CircuitClosed_holds(state, [self._light, self._battery])) # Additional condition: if not using battery_in_box mode, # both C batteries must be in the battery box diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 4d5c221f0..64f66f259 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -217,7 +217,7 @@ def pour_z_offset(cls) -> float: _camera_pitch: ClassVar[float] _camera_target: ClassVar[Pose3D] - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: if CFG.coffee_render_grid_world: # Camera parameters for grid world PyBulletCoffeeEnv._camera_distance = 3 @@ -254,6 +254,11 @@ def __init__(self, use_gui: bool = False, **kwargs) -> None: self._machine_plugged_in_id: Optional[int] = None self._last_jug_liquid_level: float = 0.0 + # Captured in step() before kinematics, consumed by + # _domain_specific_step() to detect twisting motions. + self._pre_step_ee_rpy: Tuple[float, float, float] = (0.0, 0.0, 0.0) + self._last_action: Action = Action(np.zeros(0, dtype=np.float32)) + @property def oracle_proposed_predicates(self) -> Set[Predicate]: """Return the predicates that the oracle can propose.""" @@ -482,8 +487,7 @@ def _domain_specific_step(self) -> None: self._check_and_apply_plug_in_constraint(state) self._handle_machine_on_and_jug_filling(state) self._handle_pouring(state) - self._handle_twisting(state, self._pre_step_ee_rpy, - self._last_action) + self._handle_twisting(state, self._pre_step_ee_rpy, self._last_action) def _update_jug_liquid_position(self) -> None: """If the jug is filled, move its liquid to match the jug's pose. diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index ec6e63501..97d288157 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -59,7 +59,7 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): float]]] = [(0, 0, 0, 1.), (1, 1, 1, 1.)] - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: super().__init__(use_gui, **kwargs) # Store block/target IDs (from initialize_pybullet) so that we can # reset their positions in _set_domain_specific_state(). diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index 5c45eed48..7876d9cdd 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -257,7 +257,7 @@ def get_configuration_dict(cls) -> Dict[str, Any]: # ------------------------------------------------------------------------- # Environment initialization # ------------------------------------------------------------------------- - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: self._robot = Object("robot", self._robot_type) # Fans - create one fan object per side instead of multiple diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index fcad5973a..3e566609e 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -120,7 +120,7 @@ class PyBulletFloatEnv(PyBulletEnv): _block_type = Type("block", ["x", "y", "z", "in_water", "is_held"], sim_features=["id", "is_light"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: self._robot = Object("robot", self._robot_type) self._vessel = Object("vessel", self._vessel_type) self._block0 = Object("block0", self._block_type) diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index 9187ac6cc..2d4f2f9ed 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -110,7 +110,7 @@ class PyBulletGrowEnv(PyBulletEnv): _jug_type = Type("jug", ["x", "y", "z", "rot", "is_held", "r", "g", "b"], sim_features=["id", "init_x", "init_y", "init_z"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create the single robot Object self._robot = Object("robot", self._robot_type) diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index a9ee740a2..0639de35a 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -121,7 +121,7 @@ class PyBulletLaserEnv(PyBulletEnv): ["x", "y", "z", "rot", "split_mirror", "is_held"]) _target_type = Type("target", ["x", "y", "z", "rot", "is_hit"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create environment objects (logic-level) self._robot = Object("robot", self._robot_type) self._station = Object("station", self._station_type) diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index dc755286c..aec2d27a0 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -86,7 +86,7 @@ class PyBulletMagicBinEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale"]) _bin_type = Type("bin", ["x", "y", "z", "rot"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._blocks: List[Object] = [ diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index 8fec02ccc..cefcaa4ef 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -89,7 +89,7 @@ class PyBulletSwitchEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale", "color_count"]) _light_type = Type("light", ["x", "y", "z", "rot", "is_on", "color_index"]) - def __init__(self, use_gui: bool = False, **kwargs) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._power_switch = Object("power_switch", self._power_switch_type) @@ -100,6 +100,7 @@ def __init__(self, use_gui: bool = False, **kwargs) -> None: # Track previous switch states for edge detection self._prev_color_switch_on: bool = False + self._pre_step_color_count: int = 0 # Predicates self._PowerOn = Predicate("PowerOn", [self._power_switch_type], @@ -237,7 +238,8 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: - """Set switch positions, tracking vars, color count, and light visual.""" + """Set switch positions, tracking vars, color count, and light + visual.""" power_on = state.get(self._power_switch, "is_on") > 0.5 self._set_switch_state(self._power_switch, power_on) diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index 31528aa69..d9d60734a 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -125,7 +125,7 @@ def _parse_sketch_from_file( objects: Sequence[Object], ) -> List[_SketchStep]: """Parse a plan sketch from a text file, same as agent_bilevel_approach.""" - with open(sketch_file, "r") as f: + with open(sketch_file, "r", encoding="utf-8") as f: plan_text = f.read().strip() # Phase 1: parse options + objects (no continuous params) @@ -165,7 +165,7 @@ def _parse_sketch_from_file( continue pred = pred_map[pred_name] try: - objs = [obj_map[n] for n in obj_names] + objs: Sequence[Object] = [obj_map[n] for n in obj_names] except KeyError: continue if len(objs) != len(pred.types): @@ -309,7 +309,7 @@ def validate_fn(idx, _pre, _opt, post_state, _n_acts): @pytest.mark.parametrize("model_type", ["oracle", "combined"]) def test_boil_sketch_refinement(model_type): """Test that backtracking refinement solves the first test task.""" - env, task, options_dict, objects_dict = _setup_env() + env, task, _options_dict, _objects_dict = _setup_env() predicates = env.predicates options = get_gt_options(env.get_name()) @@ -372,5 +372,5 @@ def fwd_validate_fn(i, _s, _o, post, _n): if __name__ == "__main__": import sys - model = sys.argv[1] if len(sys.argv) > 1 else "oracle" - test_boil_sketch_refinement(model) + _model = sys.argv[1] if len(sys.argv) > 1 else "oracle" + test_boil_sketch_refinement(_model) diff --git a/tests/code_sim_learning/test_param_fitting.py b/tests/code_sim_learning/test_param_fitting.py index 82853b9ce..742f795d9 100644 --- a/tests/code_sim_learning/test_param_fitting.py +++ b/tests/code_sim_learning/test_param_fitting.py @@ -9,20 +9,20 @@ import re from typing import Dict, List, Optional, Sequence, Set, Tuple -import predicators.approaches # noqa: F401 (bootstrap circular import) import numpy as np +import predicators.approaches # noqa: F401 # pylint: disable=unused-import from predicators import utils from predicators.approaches.agent_bilevel_approach import _SketchStep from predicators.code_sim_learning.training import ParamSpec, fit_params from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_options -from predicators.ground_truth_models.boil.gt_simulator import ( - BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features) +from predicators.ground_truth_models.boil.gt_simulator import \ + BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features from predicators.option_model import _OracleOptionModel from predicators.planning import run_backtracking_refinement -from predicators.structs import Action, GroundAtom, Object, \ - ParameterizedOption, Predicate, State +from predicators.structs import Action, GroundAtom, LowLevelTrajectory, \ + Object, ParameterizedOption, Predicate, State logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def _build_oracle_model(env): options = get_gt_options(env.get_name()) oracle = _OracleOptionModel(options, env.simulate) preds = env.predicates - oracle._abstract_function = lambda s: utils.abstract(s, preds) + oracle._abstract_function = lambda s: utils.abstract(s, preds) # pylint: disable=protected-access return oracle @@ -72,7 +72,7 @@ def _parse_sketch_from_file( objects: Sequence[Object], ) -> List[_SketchStep]: """Parse a plan sketch from a text file.""" - with open(sketch_file, "r") as f: + with open(sketch_file, "r", encoding="utf-8") as f: plan_text = f.read().strip() parsed = utils.parse_model_output_into_option_plan( @@ -110,7 +110,7 @@ def _parse_sketch_from_file( continue pred = pred_map[pred_name] try: - objs = [obj_map[n] for n in obj_names] + objs: Sequence[Object] = [obj_map[n] for n in obj_names] except KeyError: continue if len(objs) != len(pred.types): @@ -186,7 +186,10 @@ def _informed_place_params(pre_state, sketch, step_idx, rng, n): def _generate_oracle_transitions( - env, task, options, oracle, + env, + task, + options, + oracle, ) -> List[Tuple[State, Action, State]]: """Generate (s, a, s') triples by running the oracle on the boil task. @@ -200,8 +203,7 @@ def _generate_oracle_transitions( n = len(sketch) rng = np.random.default_rng(0) max_tries = [ - 500 if step.option.params_space.shape[0] > 0 else 1 - for step in sketch + 500 if step.option.params_space.shape[0] > 0 else 1 for step in sketch ] def sample_fn(idx, state, rng_): @@ -231,7 +233,7 @@ def validate_fn(idx, _pre, _opt, post_state, _n_acts): # Collect trajectories during refinement (not replay, since # PyBullet state reconstruction is imperfect). - step_trajectories: Dict[int, object] = {} + step_trajectories: Dict[int, LowLevelTrajectory] = {} orig_validate = validate_fn @@ -241,7 +243,7 @@ def collecting_validate_fn(idx, pre, opt, post_state, n_acts): step_trajectories[idx] = oracle.last_trajectory return ok, reason - plan, success, _ = run_backtracking_refinement( + _plan, success, _ = run_backtracking_refinement( init_state=task.init, option_model=oracle, n_steps=n, @@ -276,7 +278,7 @@ def test_emcee_recovers_rate_params(): logger.info("Generated %d oracle transitions.", len(transitions)) - def simulator_fn(state, action, params): + def simulator_fn(state, _action, params): updates = {} for rule in PROCESS_RULES: updates = rule(state, updates, params) @@ -285,8 +287,7 @@ def simulator_fn(state, action, params): # Perturb rate params (50%), keep others at true. param_specs = [] for s in BOIL_PARAM_SPECS: - if s.name in ("water_fill_speed", "heating_speed", - "happiness_speed"): + if s.name in ("water_fill_speed", "heating_speed", "happiness_speed"): param_specs.append(ParamSpec(s.name, s.init_value * 0.5)) else: param_specs.append(s) @@ -307,8 +308,8 @@ def simulator_fn(state, action, params): for name, val in fitted.items(): true_val = GT_PARAMS[name] rel_err = abs(val - true_val) / max(true_val, 1e-8) - logger.info(" %s: fitted=%.4f, true=%.4f, rel_err=%.1f%%", - name, val, true_val, rel_err * 100) + logger.info(" %s: fitted=%.4f, true=%.4f, rel_err=%.1f%%", name, val, + true_val, rel_err * 100) for name in ["water_fill_speed", "heating_speed", "happiness_speed"]: true_val = GT_PARAMS[name] From e8e3675080cb292db6d039b71e0106661751d305 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Fri, 17 Apr 2026 12:44:45 +0100 Subject: [PATCH 33/70] Apply yapf, isort, and docformatter across the codebase --- predicators/agent_sdk/bilevel_sketch.py | 4 ++-- .../agent_abstraction_learning_approach.py | 2 +- predicators/approaches/agent_planner_approach.py | 2 +- predicators/code_sim_learning/__init__.py | 2 +- predicators/envs/__init__.py | 1 - predicators/envs/pybullet_domino/composed_env.py | 12 +++++++----- predicators/explorers/agent_plan_explorer.py | 9 +++++---- predicators/ground_truth_models/__init__.py | 6 +++--- predicators/option_model.py | 15 ++++++++------- 9 files changed, 28 insertions(+), 25 deletions(-) diff --git a/predicators/agent_sdk/bilevel_sketch.py b/predicators/agent_sdk/bilevel_sketch.py index 672f1bbd7..25135af86 100644 --- a/predicators/agent_sdk/bilevel_sketch.py +++ b/predicators/agent_sdk/bilevel_sketch.py @@ -177,8 +177,8 @@ def parse_subgoal_annotations( """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. Returns a list parallel to the option lines in ``text``. Each entry - is ``None`` for a line with no annotation, or - ``(positive_atoms, negative_atoms)`` otherwise. + is ``None`` for a line with no annotation, or ``(positive_atoms, + negative_atoms)`` otherwise. """ pred_map = {p.name: p for p in predicates} obj_map = {o.name: o for o in objects} diff --git a/predicators/approaches/agent_abstraction_learning_approach.py b/predicators/approaches/agent_abstraction_learning_approach.py index 96e4ab11f..bf24a5def 100644 --- a/predicators/approaches/agent_abstraction_learning_approach.py +++ b/predicators/approaches/agent_abstraction_learning_approach.py @@ -13,10 +13,10 @@ from gym.spaces import Box from predicators import utils +from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.agent_sdk.proposal_parser import ProposalBundle, \ build_exec_context, exec_code_safely from predicators.approaches.agent_planner_approach import AgentPlannerApproach -from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.approaches.pp_online_process_learning_approach import \ OnlineProcessLearningAndPlanningApproach from predicators.approaches.pp_predicate_invention_approach import \ diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 88d4a4698..5797f6276 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -22,8 +22,8 @@ from gym.spaces import Box from predicators import utils -from predicators.approaches import ApproachFailure from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin +from predicators.approaches import ApproachFailure from predicators.approaches.base_approach import BaseApproach from predicators.explorers import create_explorer from predicators.explorers.base_explorer import BaseExplorer diff --git a/predicators/code_sim_learning/__init__.py b/predicators/code_sim_learning/__init__.py index 685d11353..5fba924ac 100644 --- a/predicators/code_sim_learning/__init__.py +++ b/predicators/code_sim_learning/__init__.py @@ -1 +1 @@ -"""Compositional world modeling via code""" +"""Compositional world modeling via code.""" diff --git a/predicators/envs/__init__.py b/predicators/envs/__init__.py index a986a0628..2510edd60 100644 --- a/predicators/envs/__init__.py +++ b/predicators/envs/__init__.py @@ -1,7 +1,6 @@ """Handle creation of environments.""" import logging - from typing import Any from predicators import utils diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index 04f0de983..34aa3da41 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -427,8 +427,7 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: num_pivots_max=max_pivots, workspace_bounds=workspace_bounds) - super().__init__(components=[domino_comp], use_gui=use_gui, - **kwargs) + super().__init__(components=[domino_comp], use_gui=use_gui, **kwargs) @classmethod def get_name(cls) -> str: @@ -468,7 +467,8 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: table_height=self.table_height) super().__init__(components=[domino_comp, fan_comp, ball_comp], - use_gui=use_gui, **kwargs) + use_gui=use_gui, + **kwargs) @classmethod def get_name(cls) -> str: @@ -529,7 +529,8 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: super().__init__( components=[domino_comp, fan_comp, ball_comp, ramp_comp], - use_gui=use_gui, **kwargs) + use_gui=use_gui, + **kwargs) @classmethod def get_name(cls) -> str: @@ -597,7 +598,8 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: super().__init__(components=[ domino_comp, fan_comp, ball_comp, ramp_comp, stairs_comp ], - use_gui=use_gui, **kwargs) + use_gui=use_gui, + **kwargs) # Store reference to stairs component self._stairs_component = stairs_comp diff --git a/predicators/explorers/agent_plan_explorer.py b/predicators/explorers/agent_plan_explorer.py index f693c273f..46fb2f98b 100644 --- a/predicators/explorers/agent_plan_explorer.py +++ b/predicators/explorers/agent_plan_explorer.py @@ -1,9 +1,10 @@ """Agent plan explorer: Claude agent generates grounded option plans. -Produces fully-grounded option plans (including continuous parameters) and -rolls them out in the real environment. Unlike ``AgentBilevelExplorer``, it -does not run backtracking refinement against a learned option model — the -agent is expected to provide complete parameters itself. +Produces fully-grounded option plans (including continuous parameters) +and rolls them out in the real environment. Unlike +``AgentBilevelExplorer``, it does not run backtracking refinement +against a learned option model — the agent is expected to provide +complete parameters itself. """ import logging diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index 3ca0f04ca..e1084954b 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -276,9 +276,9 @@ def get_gt_processes(env_name: str, def get_gt_simulator(env_name: str) -> tuple: """Load ground-truth process rules and param specs for an env. - Returns ``(rules, param_specs)`` where *rules* is a list of - process rule functions and *param_specs* is a list of - ``ParamSpec`` objects whose ``init_value`` is the GT value. + Returns ``(rules, param_specs)`` where *rules* is a list of process + rule functions and *param_specs* is a list of ``ParamSpec`` objects + whose ``init_value`` is the GT value. """ gt_name = _normalize_env_name_for_gt(env_name) for cls in utils.get_all_subclasses(GroundTruthSimulatorFactory): diff --git a/predicators/option_model.py b/predicators/option_model.py index 1a3826efb..1ca608393 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -20,11 +20,13 @@ ParameterizedOption, State, _Option -def _check_wait_termination(option: _Option, state: State, - last_state: State, +def _check_wait_termination(option: _Option, state: State, last_state: State, abstract_fn: Callable[[State], Set]) -> bool: """Check if a Wait option should terminate based on target atoms or atom - change. Returns True if it should terminate.""" + change. + + Returns True if it should terminate. + """ result = utils.check_wait_target_atoms(option, state, abstract_fn) if result is True: logging.info("Wait terminating: target atoms satisfied") @@ -33,10 +35,9 @@ def _check_wait_termination(option: _Option, state: State, cur_atoms = abstract_fn(state) prev_atoms = abstract_fn(last_state) if cur_atoms != prev_atoms: - logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") + logging.info(f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") return True return False From 328b4d7eaa1710c001f722a2a0569ad8811a9ef6 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 09:49:09 -0300 Subject: [PATCH 34/70] Inline approach configs into parent files in predicatorv3 --- scripts/configs/predicatorv3/agents.yaml | 55 ++++++++++++++++++- .../predicatorv3/approaches/agents.yaml | 54 ------------------ .../predicatorv3/approaches/oracle.yaml | 15 ----- scripts/configs/predicatorv3/oracle.yaml | 16 +++++- 4 files changed, 69 insertions(+), 71 deletions(-) delete mode 100644 scripts/configs/predicatorv3/approaches/agents.yaml delete mode 100644 scripts/configs/predicatorv3/approaches/oracle.yaml diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index d31968051..291d64160 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -3,5 +3,58 @@ --- includes: - common.yaml - - approaches/agents.yaml - envs/all.yaml +APPROACHES: + # agent_planner: + # NAME: "agent_planner" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_sdk_max_agent_turns_per_iteration: 50 + # agent_planner_use_scratchpad: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: True + # agent_bilevel: + # NAME: "agent_bilevel" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_sdk_max_agent_turns_per_iteration: 50 + # agent_planner_use_scratchpad: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: True + # agent_bilevel_log_state: False + # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + agent_sim_learning: + NAME: "agent_sim_learning" + FLAGS: + explorer: "agent_bilevel" + demonstrator: "oracle_process_planning" + terminate_on_goal_reached_and_option_terminated: True + agent_sdk_use_local_sandbox: True + option_model_terminate_on_repeat: False + agent_sdk_max_agent_turns_per_iteration: 50 + agent_planner_use_scratchpad: False + agent_planner_use_visualize_state: True + agent_planner_use_annotate_scene: True + option_model_use_gui: True + agent_bilevel_log_state: False + agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + skip_test_until_last_ite_or_early_stopping: True + # agent_option_learning: + # NAME: "agent_option_learning" + # FLAGS: + # explorer: "agent_plan" + # option_learner: "agent" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # agent_sdk_max_agent_turns_per_iteration: 50 diff --git a/scripts/configs/predicatorv3/approaches/agents.yaml b/scripts/configs/predicatorv3/approaches/agents.yaml deleted file mode 100644 index 52e0f3958..000000000 --- a/scripts/configs/predicatorv3/approaches/agents.yaml +++ /dev/null @@ -1,54 +0,0 @@ -APPROACHES: - # agent_planner: - # NAME: "agent_planner" - # FLAGS: - # explorer: "agent_plan" - # demonstrator: "oracle_process_planning" - # terminate_on_goal_reached_and_option_terminated: True - # agent_sdk_use_local_sandbox: True - # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 - # agent_planner_use_scratchpad: False - # agent_planner_use_visualize_state: True - # agent_planner_use_annotate_scene: True - # option_model_use_gui: True - # agent_bilevel: - # NAME: "agent_bilevel" - # FLAGS: - # explorer: "agent_plan" - # demonstrator: "oracle_process_planning" - # terminate_on_goal_reached_and_option_terminated: True - # agent_sdk_use_local_sandbox: True - # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 - # agent_planner_use_scratchpad: False - # agent_planner_use_visualize_state: True - # agent_planner_use_annotate_scene: True - # option_model_use_gui: True - # agent_bilevel_log_state: False - # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" - agent_sim_learning: - NAME: "agent_sim_learning" - FLAGS: - explorer: "agent_bilevel" - demonstrator: "oracle_process_planning" - terminate_on_goal_reached_and_option_terminated: True - agent_sdk_use_local_sandbox: True - option_model_terminate_on_repeat: False - agent_sdk_max_agent_turns_per_iteration: 50 - agent_planner_use_scratchpad: False - agent_planner_use_visualize_state: True - agent_planner_use_annotate_scene: True - option_model_use_gui: True - agent_bilevel_log_state: False - agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" - skip_test_until_last_ite_or_early_stopping: True - # agent_option_learning: - # NAME: "agent_option_learning" - # FLAGS: - # explorer: "agent_plan" - # option_learner: "agent" - # demonstrator: "oracle_process_planning" - # terminate_on_goal_reached_and_option_terminated: True - # agent_sdk_use_local_sandbox: True - # agent_sdk_max_agent_turns_per_iteration: 50 diff --git a/scripts/configs/predicatorv3/approaches/oracle.yaml b/scripts/configs/predicatorv3/approaches/oracle.yaml deleted file mode 100644 index 7501a44b3..000000000 --- a/scripts/configs/predicatorv3/approaches/oracle.yaml +++ /dev/null @@ -1,15 +0,0 @@ -APPROACHES: - oracle: - NAME: "oracle_process_planning" - FLAGS: - demonstrator: "oracle_process_planning" - terminate_on_goal_reached_and_option_terminated: True - bilevel_plan_without_sim: True - # human_interaction: - # NAME: "human_interaction" - # FLAGS: - # human_interaction_approach_use_scripted_option: True - # human_interaction_approach_use_all_options: True - # scripted_option_dir: "scripted_option_policies" - # skill_phase_use_motion_planning: True - # terminate_on_goal_reached_and_option_terminated: True diff --git a/scripts/configs/predicatorv3/oracle.yaml b/scripts/configs/predicatorv3/oracle.yaml index 1253eb4c1..45abe8371 100644 --- a/scripts/configs/predicatorv3/oracle.yaml +++ b/scripts/configs/predicatorv3/oracle.yaml @@ -3,5 +3,19 @@ --- includes: - common.yaml - - approaches/oracle.yaml - envs/all.yaml +APPROACHES: + oracle: + NAME: "oracle_process_planning" + FLAGS: + demonstrator: "oracle_process_planning" + terminate_on_goal_reached_and_option_terminated: True + bilevel_plan_without_sim: True + # human_interaction: + # NAME: "human_interaction" + # FLAGS: + # human_interaction_approach_use_scripted_option: True + # human_interaction_approach_use_all_options: True + # scripted_option_dir: "scripted_option_policies" + # skill_phase_use_motion_planning: True + # terminate_on_goal_reached_and_option_terminated: True From 6735ac835d557ca38edbbd16ef187ba8babdf8cf Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 17:10:42 -0300 Subject: [PATCH 35/70] Preserve robot joint config across PyBullet state save/restore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a held object's grasp constraint is recreated via _set_state, the gripper frame must match the original world pose exactly — otherwise the recorded base_link->object offset is rotated and the object lands at the wrong world position when the gripper next moves. The State representation only carries (x, y, z, tilt, wrist), so IK during reset can pick a different wrist-roll solution and corrupt the constraint. Thread joint_positions from PyBulletState.simulator_state through reset_state so we skip IK and restore the exact arm configuration. Falls back to IK when joints aren't available (plain State). Also wire wait-termination so refinement and execution can stop Wait when expected atoms hold instead of running to max_num_steps_option_rollout: set _abstract_function on the option model in BilevelPlanningApproach (mirrors AgentPlannerApproach), pass abstract_function into option_plan_to_policy in BilevelProcessPlanningApproach, and inject wait_target_atoms per sample in run_low_level_search. --- .../approaches/bilevel_planning_approach.py | 13 ++++++- .../approaches/process_planning_approach.py | 4 ++- predicators/envs/pybullet_env.py | 36 +++++++++++++++++-- predicators/planning.py | 6 ++++ .../pybullet_helpers/robots/single_arm.py | 32 +++++++++++------ 5 files changed, 77 insertions(+), 14 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 33ba29167..a0c288bdd 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -5,7 +5,7 @@ """ import abc import logging -from typing import Any, Callable, List, Optional, Set, Tuple +from typing import Any, Callable, List, Optional, Set, Tuple, cast from gym.spaces import Box @@ -47,6 +47,17 @@ def __init__(self, if option_model is None: option_model = create_option_model(CFG.option_model_name) self._option_model = option_model + # Let the option model terminate Wait on atom change. Without + # this, Wait runs to max_num_steps_option_rollout during + # refinement and the step is rejected for "exceeded individual + # horizon", even when the expected atoms have already become + # true. Mirrors AgentPlannerApproach.__init__. + if CFG.wait_option_terminate_on_atom_change: + preds = self._get_current_predicates() + cast( # pylint: disable=protected-access + Any, self._option_model + )._abstract_function = \ + lambda s, _p=preds: utils.abstract(s, _p) self._num_calls = 0 self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py index 40a8e644f..65770ce06 100644 --- a/predicators/approaches/process_planning_approach.py +++ b/predicators/approaches/process_planning_approach.py @@ -119,7 +119,9 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._last_option_plan = option_plan self._last_process_plan = process_plan # pylint: enable=attribute-defined-outside-init - policy = utils.option_plan_to_policy(option_plan) + policy = utils.option_plan_to_policy( + option_plan, + abstract_function=lambda s: utils.abstract(s, preds)) self._save_metrics(metrics, processes, preds) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 910fb3f68..31787baf8 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -47,6 +47,7 @@ from predicators.envs import BaseEnv from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion +from predicators.pybullet_helpers.joint import JointPositions from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.objects import update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot, \ @@ -433,8 +434,14 @@ def _set_state(self, state: State) -> None: self._held_obj_to_base_link = None self._held_obj_id = None - # 2) Reset robot pose - self._pybullet_robot.reset_state(self._extract_robot_state(state)) + # 2) Reset robot pose. Prefer exact joint positions when the + # State carries them in simulator_state — IK from (x, y, z, + # tilt, wrist) drops wrist roll, which corrupts the held- + # object offset that _create_grasp_constraint records below. + joint_positions = self._extract_robot_joint_positions(state) + self._pybullet_robot.reset_state( + self._extract_robot_state(state), + joint_positions=joint_positions) # 3) Reset all known objects (position, orientation, etc.) for obj in self._objects: @@ -570,6 +577,31 @@ def get_pos_feature( return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) + def _extract_robot_joint_positions( + self, state: State) -> Optional[JointPositions]: + """Pull arm joint positions out of a State's simulator_state. + + Returns None when the State doesn't carry them (plain State, or + a PyBulletState whose simulator_state has a different shape than + this robot's arm). Callers fall back to IK in that case. + """ + sim_state = getattr(state, "simulator_state", None) + jp: Any + if isinstance(sim_state, dict): + jp = sim_state.get("joint_positions") + else: + # Legacy: simulator_state is the joint_positions list itself. + jp = sim_state + if jp is None: + return None + try: + jp_list = list(jp) + except TypeError: + return None + if len(jp_list) != len(self._pybullet_robot.arm_joints): + return None + return cast(JointPositions, jp_list) + @classmethod def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, finger_state: float) -> float: diff --git a/predicators/planning.py b/predicators/planning.py index 14f3889a6..4aaf9fc80 100644 --- a/predicators/planning.py +++ b/predicators/planning.py @@ -660,6 +660,12 @@ def sample_fn(idx: int, state: State, discovered_failures[idx] = None metrics["num_samples"] += 1 option = skeleton[idx].sample_option(state, task.goal, rng_) + # Inject Wait target atoms so Wait terminates as soon as the + # expected atoms hold rather than running to + # max_num_steps_option_rollout. Without this, refinement keeps + # hitting "exceeded individual horizon" even when heating / + # filling / etc. has already completed. + utils.inject_wait_targets_for_option(option, idx, atoms_sequence) logging.info(f"Running option {option}") return option diff --git a/predicators/pybullet_helpers/robots/single_arm.py b/predicators/pybullet_helpers/robots/single_arm.py index a0ae333c4..f965d479d 100644 --- a/predicators/pybullet_helpers/robots/single_arm.py +++ b/predicators/pybullet_helpers/robots/single_arm.py @@ -239,11 +239,20 @@ def initial_joint_positions(self) -> JointPositions: joint_positions[self.right_finger_joint_idx] = self.open_fingers return joint_positions - def reset_state(self, robot_state: Array) -> None: + def reset_state( + self, + robot_state: Array, + joint_positions: Optional[JointPositions] = None, + ) -> None: """Reset the robot state to match the input state. The robot_state corresponds to the State vector for the robot - object. + object. If joint_positions is provided, the arm joints are set + directly from it; otherwise IK is run from the EE pose, which + loses information not encoded in (x, y, z, tilt, wrist) — most + importantly wrist roll. Preserving exact joints is required for + held-object grasps to round-trip through state save/restore + without geometric drift. """ rx, ry, rz, qx, qy, qz, qw, rf = robot_state p.resetBasePositionAndOrientation( @@ -252,14 +261,17 @@ def reset_state(self, robot_state: Array) -> None: self._base_pose.orientation, physicsClientId=self.physics_client_id, ) - # First, reset the joint values to initial joint positions, - # so that IK is consistent (less sensitive to initialization). - self.set_joints(self.initial_joint_positions) - - # Now run IK to get to the actual starting rx, ry, rz. We use - # validate=True to ensure that this initialization works. - pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) - self.inverse_kinematics(pose, validate=True) + if joint_positions is not None: + self.set_joints(list(joint_positions)) + else: + # First, reset the joint values to initial joint positions, + # so that IK is consistent (less sensitive to initialization). + self.set_joints(self.initial_joint_positions) + + # Now run IK to get to the actual starting rx, ry, rz. We use + # validate=True to ensure that this initialization works. + pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) + self.inverse_kinematics(pose, validate=True) # Handle setting the robot finger joints. for finger_id in [self.left_finger_id, self.right_finger_id]: From ebb33046584922e746015cfa651fe4f84a967ded Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 17:31:57 -0300 Subject: [PATCH 36/70] Add 'emcee' to the list of install_requires in setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 859d05e63..502446850 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "claude-agent-sdk", "nest_asyncio", "mara_robosim@git+https://github.com/yichao-liang/mara-robosim.git", + "emcee", ], include_package_data=True, extras_require={ From 0bc523483485d84995264bd79378440bd3c10968 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 22:27:11 -0300 Subject: [PATCH 37/70] Force PyBullet FK refresh and skip redundant finger snap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After resetJointState, PyBullet's getLinkState returns a stale link pose from the previous FK cache, producing 50-500μm drift in the EE pose readback. Pass computeForwardKinematics=1 so world poses are recomputed from current joints on every call. Also skip the explicit finger reset in reset_state when joint_positions are provided: arm_joints already includes the finger joints, so set_joints has restored them to their exact continuous values, and the subsequent loop was overwriting them with the discrete-snapped value from _fingers_state_to_joint. The finger reset still runs on the IK path where set_joints leaves fingers untouched. Together these eliminate the "Could not reconstruct state exactly in reset" warning noise (24 -> 0 on the boil-oracle run). --- predicators/pybullet_helpers/link.py | 19 +++++++++++-------- .../pybullet_helpers/robots/single_arm.py | 15 +++++++++------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/predicators/pybullet_helpers/link.py b/predicators/pybullet_helpers/link.py index b8680c408..c9f90adcd 100644 --- a/predicators/pybullet_helpers/link.py +++ b/predicators/pybullet_helpers/link.py @@ -41,15 +41,18 @@ def get_link_state( ) -> LinkState: """Get the state of a link in a given body. - Note: it is unclear what the computeForwardKinematics flag does as we - could not reproduce any difference in the resulting Cartesian world - position or orientation of the link after setting joint positions - with both the flag set to False or True. - - The default PyBullet flag is computeForwardKinematics=False, so we - will stick to that. + With ``computeForwardKinematics=False`` (PyBullet's default), + getLinkState returns the link's Cartesian pose from the last + physics-step / FK cache, which is stale immediately after + ``resetJointState``. After a state save/restore round-trip this + showed up as ~50-500μm drift in the reported EE pose. We pass + ``computeForwardKinematics=1`` so the world pose is recomputed + from current joint positions on every call. """ - link_state = p.getLinkState(body, link, physicsClientId=physics_client_id) + link_state = p.getLinkState(body, + link, + computeForwardKinematics=1, + physicsClientId=physics_client_id) return LinkState(*link_state) diff --git a/predicators/pybullet_helpers/robots/single_arm.py b/predicators/pybullet_helpers/robots/single_arm.py index f965d479d..454b1f7be 100644 --- a/predicators/pybullet_helpers/robots/single_arm.py +++ b/predicators/pybullet_helpers/robots/single_arm.py @@ -262,6 +262,9 @@ def reset_state( physicsClientId=self.physics_client_id, ) if joint_positions is not None: + # arm_joints includes fingers, so set_joints already + # restored both — skip the snapped-finger overwrite below + # so continuous finger values round-trip cleanly. self.set_joints(list(joint_positions)) else: # First, reset the joint values to initial joint positions, @@ -273,12 +276,12 @@ def reset_state( pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) self.inverse_kinematics(pose, validate=True) - # Handle setting the robot finger joints. - for finger_id in [self.left_finger_id, self.right_finger_id]: - p.resetJointState(self.robot_id, - finger_id, - rf, - physicsClientId=self.physics_client_id) + # IK does not touch fingers, so snap them from the EE state. + for finger_id in [self.left_finger_id, self.right_finger_id]: + p.resetJointState(self.robot_id, + finger_id, + rf, + physicsClientId=self.physics_client_id) def get_state(self) -> Array: """Get the robot state vector based on the current PyBullet state. From e84d7885c69a8a96ee5cce145ada27e54c5d5aed Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 22:36:20 -0300 Subject: [PATCH 38/70] Apply yapf/docformatter to satisfy CI autoformat check --- predicators/envs/pybullet_env.py | 5 ++--- predicators/pybullet_helpers/link.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 31787baf8..62dc75f68 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -439,9 +439,8 @@ def _set_state(self, state: State) -> None: # tilt, wrist) drops wrist roll, which corrupts the held- # object offset that _create_grasp_constraint records below. joint_positions = self._extract_robot_joint_positions(state) - self._pybullet_robot.reset_state( - self._extract_robot_state(state), - joint_positions=joint_positions) + self._pybullet_robot.reset_state(self._extract_robot_state(state), + joint_positions=joint_positions) # 3) Reset all known objects (position, orientation, etc.) for obj in self._objects: diff --git a/predicators/pybullet_helpers/link.py b/predicators/pybullet_helpers/link.py index c9f90adcd..b29d327da 100644 --- a/predicators/pybullet_helpers/link.py +++ b/predicators/pybullet_helpers/link.py @@ -46,8 +46,8 @@ def get_link_state( physics-step / FK cache, which is stale immediately after ``resetJointState``. After a state save/restore round-trip this showed up as ~50-500μm drift in the reported EE pose. We pass - ``computeForwardKinematics=1`` so the world pose is recomputed - from current joint positions on every call. + ``computeForwardKinematics=1`` so the world pose is recomputed from + current joint positions on every call. """ link_state = p.getLinkState(body, link, From 8333b0fb6fa90b93061a949fc7f6e520e2a86174 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Tue, 28 Apr 2026 22:49:24 -0300 Subject: [PATCH 39/70] Configure predicatorv3 demos for offline-only sim-learning runs common.yaml: switch to one demonstration per task with no online learning cycle so launch_simp.py exercises only the offline pipeline. agents.yaml (agent_sim_learning): turn on oracle_sim_program with oracle_sim_params disabled so synthesis fits parameters but starts from the ground-truth program structure. --- scripts/configs/predicatorv3/agents.yaml | 2 ++ scripts/configs/predicatorv3/common.yaml | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index 291d64160..a55df02c0 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -49,6 +49,8 @@ APPROACHES: agent_bilevel_log_state: False agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" skip_test_until_last_ite_or_early_stopping: True + agent_sim_learn_oracle_sim_program: True + agent_sim_learn_oracle_sim_params: False # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: diff --git a/scripts/configs/predicatorv3/common.yaml b/scripts/configs/predicatorv3/common.yaml index cbb09dc4c..581e5dd43 100644 --- a/scripts/configs/predicatorv3/common.yaml +++ b/scripts/configs/predicatorv3/common.yaml @@ -9,8 +9,8 @@ ARGS: # - "make_test_images" # query images # - "save_atoms" FLAGS: - max_initial_demos: 0 - num_online_learning_cycles: 1 + max_initial_demos: 1 + num_online_learning_cycles: 0 online_nsrt_learning_requests_per_cycle: 1 skill_phase_use_motion_planning: True max_num_steps_interaction_request: 300 From 0b6a4b0972596c8aba1e019da1cf5e012f9f9331 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Wed, 29 Apr 2026 21:06:40 +0100 Subject: [PATCH 40/70] Add jug orientation handling in PyBulletBoilEnv --- predicators/envs/pybullet_boil.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index af1a127ce..77cb1f805 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -1362,6 +1362,8 @@ def _create_liquid_for_jug( cx = state.get(jug, "x") cy = state.get(jug, "y") cz = self.z_lb + liquid_height / 2 + 0.02 # sits on table + jug_rot = state.get(jug, "rot") + orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) color = self.water_color return create_pybullet_block(color=color, @@ -1369,6 +1371,7 @@ def _create_liquid_for_jug( mass=0.01, friction=0.5, position=(cx, cy, cz), + orientation=orientation, physics_client_id=self._physics_client_id) From 1b6c5102db53800c26c19d92626a1c364a35cf58 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Thu, 30 Apr 2026 11:18:43 +0100 Subject: [PATCH 41/70] Revert getLinkState to PyBullet default (no computeForwardKinematics flag) Investigation found no measurable difference in reported Cartesian world position or orientation whether the flag is True or False, so the override introduced earlier was not needed. --- predicators/pybullet_helpers/link.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/predicators/pybullet_helpers/link.py b/predicators/pybullet_helpers/link.py index b29d327da..b8680c408 100644 --- a/predicators/pybullet_helpers/link.py +++ b/predicators/pybullet_helpers/link.py @@ -41,18 +41,15 @@ def get_link_state( ) -> LinkState: """Get the state of a link in a given body. - With ``computeForwardKinematics=False`` (PyBullet's default), - getLinkState returns the link's Cartesian pose from the last - physics-step / FK cache, which is stale immediately after - ``resetJointState``. After a state save/restore round-trip this - showed up as ~50-500μm drift in the reported EE pose. We pass - ``computeForwardKinematics=1`` so the world pose is recomputed from - current joint positions on every call. + Note: it is unclear what the computeForwardKinematics flag does as we + could not reproduce any difference in the resulting Cartesian world + position or orientation of the link after setting joint positions + with both the flag set to False or True. + + The default PyBullet flag is computeForwardKinematics=False, so we + will stick to that. """ - link_state = p.getLinkState(body, - link, - computeForwardKinematics=1, - physicsClientId=physics_client_id) + link_state = p.getLinkState(body, link, physicsClientId=physics_client_id) return LinkState(*link_state) From f0b4692ecbb9168a1d8713dd0d57e3010c1418b7 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Thu, 30 Apr 2026 11:18:55 +0100 Subject: [PATCH 42/70] Add lo/hi bounds to ParamSpec and skip-MCMC support in fit_params ParamSpec gains optional lo/hi fields for clamping sampled values. fit_params now reads num_steps from CFG.code_sim_learning_num_mcmc_steps by default; passing 0 (or setting the flag to 0) skips emcee entirely and returns the initial parameter values as the fit result. burn_in is also clamped to num_steps-1 to avoid emcee errors on very short runs. Adds a test covering the skip-MCMC path via CFG. --- predicators/code_sim_learning/training.py | 25 ++++++++++++++++++----- predicators/settings.py | 6 ++++++ tests/code_sim_learning/test_training.py | 23 +++++++++++++++++++++ 3 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 tests/code_sim_learning/test_training.py diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index a69fb2b0c..8ff469890 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -7,10 +7,11 @@ import logging from dataclasses import dataclass -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Optional, Tuple import numpy as np +from predicators.settings import CFG from predicators.structs import Action, State logger = logging.getLogger(__name__) @@ -25,6 +26,8 @@ class ParamSpec: name: str init_value: float + lo: Optional[float] = None + hi: Optional[float] = None @dataclass @@ -88,7 +91,7 @@ def fit_params( param_specs: List[ParamSpec], process_features: Dict[str, List[str]], num_walkers: int = 32, - num_steps: int = 500, + num_steps: Optional[int] = None, burn_in: int = 200, noise_sigma: float = 0.05, prior_sigma_scale: float = 2.0, @@ -105,7 +108,9 @@ def fit_params( param_specs: Parameter specifications (name, init_value). process_features: {type_name: [feat_names]} to fit. num_walkers: Number of ensemble walkers (>= 2*ndim). - num_steps: Total MCMC steps per walker. + num_steps: Total MCMC steps per walker. If None, defaults to + CFG.code_sim_learning_num_mcmc_steps. If 0, skip training and + use initial parameter values directly. burn_in: Steps to discard as burn-in. noise_sigma: Observation noise std dev for likelihood. prior_sigma_scale: Prior width as multiple of init_value. @@ -113,13 +118,23 @@ def fit_params( Returns: FitResult with posterior samples and log-probabilities. """ - import emcee # type: ignore[import-untyped] # pylint: disable=import-outside-toplevel - names = [s.name for s in param_specs] init_values = np.array([s.init_value for s in param_specs]) + if num_steps is None: + num_steps = CFG.code_sim_learning_num_mcmc_steps + if num_steps < 0: + raise ValueError("code_sim_learning_num_mcmc_steps must be " + "non-negative.") + if num_steps == 0: + logger.info("Skipping emcee; using initial parameter values.") + return FitResult(names, init_values[None, :], np.zeros(1)) + + import emcee # type: ignore[import-untyped] # pylint: disable=import-outside-toplevel + ndim = len(param_specs) num_walkers = max(num_walkers, 2 * ndim + 2) prior_sigma = init_values * prior_sigma_scale + burn_in = min(burn_in, max(num_steps - 1, 0)) def log_posterior(theta: np.ndarray) -> float: # Reject negative values diff --git a/predicators/settings.py b/predicators/settings.py index ef898e028..1a292fb9e 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1022,10 +1022,16 @@ class GlobalSettings: # upstream step multiplies the cost. agent_bilevel_explorer_max_samples_per_step = 50 + # Code sim-learning parameter fitting settings. + # Set to 0 to skip MCMC and use initial parameter values directly. + code_sim_learning_num_mcmc_steps = 500 + # Sim-learning oracle flags (for ablation / debugging). # When True, load GT process rules instead of running agent synthesis. # Parameters init_values are perturbed so MCMC still has work to do. agent_sim_learn_oracle_sim_program = False + # Relative scale for perturbing oracle parameter init_values before MCMC. + agent_sim_learn_oracle_sim_param_noise_scale = 0.2 # When True, use GT parameter values directly, skipping MCMC fitting. agent_sim_learn_oracle_sim_params = False diff --git a/tests/code_sim_learning/test_training.py b/tests/code_sim_learning/test_training.py new file mode 100644 index 000000000..4f294c3a3 --- /dev/null +++ b/tests/code_sim_learning/test_training.py @@ -0,0 +1,23 @@ +"""Tests for code sim-learning training utilities.""" + +import numpy as np + +from predicators import utils +from predicators.code_sim_learning.training import ParamSpec, fit_params + + +def test_fit_params_can_skip_training_with_cfg(): + """Test that CFG can disable parameter fitting.""" + utils.reset_config({"code_sim_learning_num_mcmc_steps": 0}) + param_specs = [ParamSpec("rate", 2.5), ParamSpec("threshold", 0.7)] + + result = fit_params( + simulator_fn=lambda _s, _a, _p: {}, + transitions=[], + param_specs=param_specs, + process_features={}, + ) + + assert result.point_estimate == {"rate": 2.5, "threshold": 0.7} + np.testing.assert_allclose(result.samples, np.array([[2.5, 0.7]])) + np.testing.assert_allclose(result.log_probs, np.array([0.0])) From 9c61f3e2520e27a18e41a1de3d67814fe828aa5f Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Thu, 30 Apr 2026 11:19:44 +0100 Subject: [PATCH 43/70] Build boil param specs dynamically from CFG with lo/hi bounds Replace the module-level BOIL_PARAM_SPECS list with _build_param_specs() so water_fill_speed is derived from CFG.boil_water_fill_speed at call time rather than import time. All specs now carry lo=0.0 to prevent MCMC from sampling physically invalid negative values. get_param_specs() is updated to call _build_param_specs() so per-run CFG values are always reflected. --- .../ground_truth_models/boil/gt_simulator.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index 03daa230b..ac6092de5 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -13,10 +13,12 @@ from predicators.code_sim_learning.training import ParamSpec from predicators.code_sim_learning.utils import ProcessUpdate from predicators.ground_truth_models import GroundTruthSimulatorFactory +from predicators.settings import CFG from predicators.structs import Object, State -# Constants matching pybullet_boil.py exactly. -WATER_FILL_SPEED = 0.02 # 0.002 * water_height_to_level_ratio(10) +# Constants matching pybullet_boil.py exactly. Note: water_fill_speed is +# derived from CFG at spec-build time (env uses +# CFG.boil_water_fill_speed * water_height_to_level_ratio). HEATING_SPEED = 0.03 HAPPINESS_SPEED = 0.05 MAX_JUG_WATER_CAPACITY = 1.3 @@ -25,19 +27,28 @@ FAUCET_ALIGN_THRESHOLD = 0.1 BURNER_ALIGN_THRESHOLD = 0.05 FAUCET_X_LEN = 0.15 +_WATER_HEIGHT_TO_LEVEL_RATIO = 10 -# Parameter specs for fitting. -BOIL_PARAM_SPECS: List[ParamSpec] = [ - ParamSpec("water_fill_speed", WATER_FILL_SPEED), - ParamSpec("heating_speed", HEATING_SPEED), - ParamSpec("happiness_speed", HAPPINESS_SPEED), - ParamSpec("max_jug_water_capacity", MAX_JUG_WATER_CAPACITY), - ParamSpec("water_filled_height", WATER_FILLED_HEIGHT), - ParamSpec("max_water_spill_width", MAX_WATER_SPILL_WIDTH), - ParamSpec("faucet_x_len", FAUCET_X_LEN), - ParamSpec("faucet_align_threshold", FAUCET_ALIGN_THRESHOLD), - ParamSpec("burner_align_threshold", BURNER_ALIGN_THRESHOLD), -] + +def _build_param_specs() -> List[ParamSpec]: + """Build at call time so CFG-driven values match the current run.""" + water_fill_speed = (CFG.boil_water_fill_speed * + _WATER_HEIGHT_TO_LEVEL_RATIO) + return [ + ParamSpec("water_fill_speed", water_fill_speed, lo=0.0), + ParamSpec("heating_speed", HEATING_SPEED, lo=0.0), + ParamSpec("happiness_speed", HAPPINESS_SPEED, lo=0.0), + ParamSpec("max_jug_water_capacity", MAX_JUG_WATER_CAPACITY, lo=0.0), + ParamSpec("water_filled_height", WATER_FILLED_HEIGHT, lo=0.0), + ParamSpec("max_water_spill_width", MAX_WATER_SPILL_WIDTH, lo=0.0), + ParamSpec("faucet_x_len", FAUCET_X_LEN, lo=0.0), + ParamSpec("faucet_align_threshold", FAUCET_ALIGN_THRESHOLD, lo=0.0), + ParamSpec("burner_align_threshold", BURNER_ALIGN_THRESHOLD, lo=0.0), + ] + + +# Static specs for tests / introspection (uses CFG defaults at import time). +BOIL_PARAM_SPECS: List[ParamSpec] = _build_param_specs() Params = Dict[str, float] @@ -169,7 +180,7 @@ def get_rules(cls) -> list: @classmethod def get_param_specs(cls) -> list: - return list(BOIL_PARAM_SPECS) + return _build_param_specs() def get_gt_process_features() -> Dict[str, List[str]]: From e08df545bc668036f5bc9e9d5b3a139f43880f4d Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Thu, 30 Apr 2026 11:19:48 +0100 Subject: [PATCH 44/70] Apply lo/hi clamping and configurable noise scale to oracle perturbation Oracle parameter perturbation now uses the relative scale from CFG.agent_sim_learn_oracle_sim_param_noise_scale (default 0.2) instead of a hard-coded 20 % figure, and clamps perturbed values to the lo/hi bounds declared in each ParamSpec. Also improves the log message when MCMC is skipped (num_mcmc_steps == 0) so it is clear no fitting occurred. --- .../approaches/agent_sim_learning_approach.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index c415cc4b2..d00e1bafa 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -191,11 +191,11 @@ def _synthesize_with_agent( and ``PARAM_SPECS``. Each ``run_python`` call appends code to a saved file; after the session we reload from that file. - Behaviour is modified by two CFG flags: - - ``agent_sim_learn_oracle_sim_program``: skip agent synthesis and load GT rules/specs instead (init_values perturbed so MCMC has non-trivial work). + - ``agent_sim_learn_oracle_sim_param_noise_scale``: adjust the + magnitude of the perturbation applied to oracle init_values. - ``agent_sim_learn_oracle_sim_params``: skip MCMC fitting and use the GT parameter values directly. """ @@ -206,12 +206,22 @@ def _synthesize_with_agent( rules, specs = get_gt_simulator(CFG.env) if not CFG.agent_sim_learn_oracle_sim_params: rng = np.random.default_rng(CFG.seed) - specs = [ - ParamSpec( - s.name, s.init_value + - rng.normal(0, max(abs(s.init_value) * 0.2, 1e-4))) - for s in specs - ] + noise_scale = CFG.agent_sim_learn_oracle_sim_param_noise_scale + if noise_scale < 0.0: + raise ValueError( + "agent_sim_learn_oracle_sim_param_noise_scale must " + "be non-negative.") + perturbed = [] + for s in specs: + val = s.init_value * ( + 1.0 + float(rng.normal(0, noise_scale))) + if s.lo is not None: + val = max(s.lo, val) + if s.hi is not None: + val = min(s.hi, val) + perturbed.append( + ParamSpec(s.name, val, lo=s.lo, hi=s.hi)) + specs = perturbed logger.info("Loaded oracle sim program (%d rules, %d params).", len(rules), len(specs)) else: @@ -290,8 +300,12 @@ def _synthesize_with_agent( else: self._fitted_params, self._fit_mse = self._fit_parameters( rules, specs, step_transitions, process_features, fit_env) - logger.info("Fitted %d params (MSE: %.6f).", len(specs), - self._fit_mse) + if CFG.code_sim_learning_num_mcmc_steps == 0: + logger.info("Skipped fitting; using %d initial params " + "(MSE: %.6f).", len(specs), self._fit_mse) + else: + logger.info("Fitted %d params (MSE: %.6f).", len(specs), + self._fit_mse) # ── Parameter fitting ──────────────────────────────────────── From e44a850d4a8e9e11fe3eaf0ba14929c2e46eb887 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Thu, 30 Apr 2026 18:36:04 +0100 Subject: [PATCH 45/70] Update installation instructions and add macOS setup script for PyBullet --- README.md | 3 ++- setup.py | 2 +- setup.sh | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) create mode 100755 setup.sh diff --git a/README.md b/README.md index 3819738dd..4d51fad4b 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,8 @@ A simple implementation of search-then-sample bilevel planning is provided in `p ## Installation * This repository uses Python versions 3.10-3.11. We recommend 3.10.14. -* Run `pip install -e .` to install dependencies. +* Run `./setup.sh` to install dependencies (handles macOS PyBullet source build automatically). +* Alternatively, run `pip install -e .` directly if not on macOS. ## Instructions For Running Code diff --git a/setup.py b/setup.py index 502446850..728343803 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "pillow==10.3.0", "requests", "slack_bolt", - "pybullet>=3.2.0", + "pybullet==3.2.5", "scikit-learn>=1.1.3", "graphlib-backport", "openai==1.19.0", diff --git a/setup.sh b/setup.sh new file mode 100755 index 000000000..50d172160 --- /dev/null +++ b/setup.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e +git submodule update --init --recursive + +if [[ "$OSTYPE" == "darwin"* ]]; then + echo "macOS detected: building PyBullet from source (workaround for macOS compatibility)..." + + # Initialize the virtual environment first so we can use its Python + uv venv + + VENV_PYTHON="$(pwd)/.venv/bin/python" + BULLET_TMP=$(mktemp -d) + trap 'rm -rf "$BULLET_TMP"' EXIT + + git clone https://github.com/bulletphysics/bullet3 "$BULLET_TMP/bullet3" + git -C "$BULLET_TMP/bullet3" checkout 3.25 + + # Comment out the line that causes build failure on recent macOS + sed -i '' \ + 's|^#define fdopen(fd, mode) NULL|// #define fdopen(fd, mode) NULL|' \ + "$BULLET_TMP/bullet3/examples/ThirdPartyLibs/zlib/zutil.h" + + uv pip install setuptools + pushd "$BULLET_TMP/bullet3" + "$VENV_PYTHON" setup.py build + "$VENV_PYTHON" setup.py install + popd + + # Install everything else; pybullet 3.2.5 is already installed from source + # above so pip will skip it + uv pip install -e . +else + uv pip install -e . +fi From b8df145659a580f6a884e46d20492d0e1999f54c Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 11:51:39 +0100 Subject: [PATCH 46/70] Update PyBullet version to 3.2.7 and simplify macOS setup script --- setup.py | 2 +- setup.sh | 17 ++++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 728343803..8cfdfa2e0 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "pillow==10.3.0", "requests", "slack_bolt", - "pybullet==3.2.5", + "pybullet==3.2.7", "scikit-learn>=1.1.3", "graphlib-backport", "openai==1.19.0", diff --git a/setup.sh b/setup.sh index 50d172160..d1cb6b8c8 100755 --- a/setup.sh +++ b/setup.sh @@ -5,30 +5,25 @@ git submodule update --init --recursive if [[ "$OSTYPE" == "darwin"* ]]; then echo "macOS detected: building PyBullet from source (workaround for macOS compatibility)..." - # Initialize the virtual environment first so we can use its Python - uv venv - - VENV_PYTHON="$(pwd)/.venv/bin/python" BULLET_TMP=$(mktemp -d) trap 'rm -rf "$BULLET_TMP"' EXIT git clone https://github.com/bulletphysics/bullet3 "$BULLET_TMP/bullet3" - git -C "$BULLET_TMP/bullet3" checkout 3.25 # Comment out the line that causes build failure on recent macOS sed -i '' \ 's|^#define fdopen(fd, mode) NULL|// #define fdopen(fd, mode) NULL|' \ "$BULLET_TMP/bullet3/examples/ThirdPartyLibs/zlib/zutil.h" - uv pip install setuptools + pip install setuptools pushd "$BULLET_TMP/bullet3" - "$VENV_PYTHON" setup.py build - "$VENV_PYTHON" setup.py install + python setup.py build + python setup.py install popd - # Install everything else; pybullet 3.2.5 is already installed from source + # Install everything else; pybullet 3.2.7 is already installed from source # above so pip will skip it - uv pip install -e . + pip install -e . else - uv pip install -e . + pip install -e . fi From c033f9c46ae3e2bc88710bfecb0bb40ebe66ab27 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 11:51:49 +0100 Subject: [PATCH 47/70] Refactor liquid color update logic and rename related methods for clarity --- predicators/envs/pybullet_boil.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 77cb1f805..7fa429b50 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -584,6 +584,8 @@ def _set_domain_specific_state(self, state: State) -> None: jug.heat_level = state.get(jug, "heat_level") liquid_id = self._create_liquid_for_jug(jug, state) self._jug_to_liquid_id[jug] = liquid_id + + self._update_liquid_colors(state) # Update jug body colors from state for jug in jugs: @@ -646,7 +648,7 @@ def _domain_specific_step(self) -> None: state = self._get_state() self._handle_faucet_logic(state) self._handle_heating_logic(state) - self._update_jug_colors(state) + self._update_liquid_colors(state) self._update_burner_colors(state) self._update_human_happiness(state) self._update_prev_on_states(state) @@ -764,7 +766,7 @@ def _handle_heating_logic(self, state: State) -> None: new_heat = min(1.0, old_heat + self.heating_speed) jug_obj.heat_level = new_heat - def _update_jug_colors(self, state: State) -> None: + def _update_liquid_colors(self, state: State) -> None: """Simple linear interpolation from blue (0.0) to red (1.0) based on jug.heat.""" jugs = state.get_objects(self._jug_type) @@ -1362,8 +1364,8 @@ def _create_liquid_for_jug( cx = state.get(jug, "x") cy = state.get(jug, "y") cz = self.z_lb + liquid_height / 2 + 0.02 # sits on table - jug_rot = state.get(jug, "rot") - orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) + # jug_rot = state.get(jug, "rot") + # orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) color = self.water_color return create_pybullet_block(color=color, @@ -1371,7 +1373,7 @@ def _create_liquid_for_jug( mass=0.01, friction=0.5, position=(cx, cy, cz), - orientation=orientation, + # orientation=orientation, physics_client_id=self._physics_client_id) From 20a310edca8d7890a6b6fa84b183ef7d40519d31 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 12:31:11 +0100 Subject: [PATCH 48/70] Add more debug logging for CogMan and option execution flow --- predicators/cogman.py | 20 +++++++++++++- predicators/utils.py | 62 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/predicators/cogman.py b/predicators/cogman.py index e35e27eb6..ebb8f8119 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -78,6 +78,7 @@ def step(self, observation: Observation) -> Optional[Action]: self._episode_state_history.append(state) if self._termination_fn is not None and self._termination_fn(state): logging.info("[CogMan] Termination triggered.") + logging.debug("[CogMan] step returning None: termination_fn fired") return None # Check if we should replan. if self._exec_monitor.step(state): @@ -227,8 +228,9 @@ def run_episode_and_get_observations( metrics["policy_call_time"] = 0.0 metrics["num_options_executed"] = 0.0 exception_raised_in_step = False + step_num = -1 if not (terminate_on_goal_reached and env.goal_reached()): - for _ in range(max_num_steps): + for step_num in range(max_num_steps): monitor_observed = False exception_raised_in_step = False try: @@ -236,6 +238,7 @@ def run_episode_and_get_observations( act = cogman.step(obs) metrics["policy_call_time"] += time.perf_counter() - start_time if act is None: + logging.debug("[CogMan] loop break: act is None") break if act.has_option() and act.get_option() != curr_option: curr_option = act.get_option() @@ -264,9 +267,14 @@ def run_episode_and_get_observations( any(issubclass(type(e), c) for c in exceptions_to_break_on): if monitor_observed: exception_raised_in_step = True + logging.debug( + f"[CogMan] loop break: exception in break_on set: {e}") break if CFG.terminate_on_goal_reached_and_option_terminated and \ env.goal_reached(): + logging.debug( + f"[CogMan] loop break: goal_reached+option_terminated " + f"(exception: {e})") break if monitor is not None and not monitor_observed: monitor.observe(obs, None) @@ -277,7 +285,17 @@ def run_episode_and_get_observations( return traj, solved, metrics raise e if terminate_on_goal_reached and env.goal_reached(): + logging.debug("[CogMan] loop break: terminate_on_goal_reached") break + else: + option_str = (None if curr_option is None else + curr_option.simple_str()) + logging.info("[CogMan] Reached max_num_steps=%d while executing " + "option %s.", max_num_steps, option_str) + logging.debug("[CogMan] Final loop step index before horizon: %d", + step_num) + logging.debug("[CogMan] Atoms at horizon: %s", + sorted(utils.abstract(obs, env.predicates))) if monitor is not None and not exception_raised_in_step: monitor.observe(obs, None) cogman.finish_episode(obs) diff --git a/predicators/utils.py b/predicators/utils.py index 7181522b0..cbe628f34 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1684,6 +1684,39 @@ def strip_wait_annotations(text: str) -> str: return re.sub(r'\s*->\s*\{[^}]*\}', '', text) +def _format_wait_target_debug( + state: State, target_atoms: Set[GroundAtom], + abstract_function: Callable[[State], Set[GroundAtom]]) -> str: + """Format state details for debugging why Wait has not terminated.""" + cur_atoms = abstract_function(state) + missing_targets = target_atoms - cur_atoms + target_objects = sorted({ + ent + for atom in target_atoms for ent in atom.entities + if isinstance(ent, Object) + }, + key=lambda o: o.name) + object_details = [] + for obj in target_objects: + feature_values = [] + for feature_name in obj.type.feature_names: + value = state.get(obj, feature_name) + if isinstance(value, float): + value_str = f"{value:.4f}" + else: + value_str = str(value) + feature_values.append(f"{feature_name}={value_str}") + object_details.append(f"{obj}: " + ", ".join(feature_values)) + details = [ + f"Targets: {sorted(target_atoms)}", + f"Missing: {sorted(missing_targets)}", + f"cur_atoms: {sorted(cur_atoms)}", + ] + if object_details: + details.append(f"target_objects: {'; '.join(object_details)}") + return "; ".join(details) + + def option_policy_to_policy( option_policy: Callable[[State], _Option], max_option_steps: Optional[int] = None, @@ -1728,11 +1761,26 @@ def _policy(state: State) -> Action: and cur_option.name == "Wait": assert abstract_function is not None assert last_state is not None + target_atoms = cur_option.memory.get("wait_target_atoms") result = check_wait_target_atoms(cur_option, state, abstract_function) if result is True: - logging.debug("Wait terminating: target atoms satisfied") + cur_atoms = abstract_function(state) + logging.debug( + "Wait terminating: target atoms satisfied. " + f"Targets: {target_atoms}, " + f"cur_atoms: {sorted(cur_atoms)}, " + f"num_option_steps={num_cur_option_steps}") wait_terminate = True + elif result is False: + assert target_atoms is not None + if num_cur_option_steps <= 1 or num_cur_option_steps % 25 == 0: + wait_debug = _format_wait_target_debug( + state, target_atoms, abstract_function) + logging.debug( + "Wait continuing: target atoms not yet satisfied. " + "%s, num_option_steps=%d", wait_debug, + num_cur_option_steps) elif result is None: # No targets specified: fall back to any-atom-change cur_atoms = abstract_function(state) @@ -1766,6 +1814,9 @@ def _policy(state: State) -> Action: raise OptionExecutionFailure( "Unsound option policy.", info={"last_failed_option": last_option}) + logging.debug( + f"[option_policy] Started option {cur_option.name}, " + f"initiable=True") num_cur_option_steps = 0 num_cur_option_steps += 1 @@ -1783,13 +1834,20 @@ def option_plan_to_policy( ) -> Callable[[State], Action]: """Create a policy that executes a sequence of options in order.""" queue = list(plan) # don't modify plan, just in case + total_options = len(queue) def _option_policy(state: State) -> _Option: del state # not used if not queue: + logging.info("Option plan exhausted after %d options.", + total_options) raise OptionExecutionFailure("Option plan exhausted!") option = queue.pop(0) - logging.info(f"Executing option {option.simple_str()}") + option_num = total_options - len(queue) + next_option = None if not queue else queue[0].simple_str() + logging.info("Executing option %d/%d: %s (remaining=%d, next=%s)", + option_num, total_options, option.simple_str(), + len(queue), next_option) return option return option_policy_to_policy( From 9d6b9e37860c2bf851339e037c55c16a6e1dc23d Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 13:20:16 +0100 Subject: [PATCH 49/70] Handle PyBullet physics server crashes with env recreation and retry Converts _build_combined_simulator to an instance method so it can capture self, recreate the base env on pybullet.error, and retry once. Also catches pybullet.error in the oracle option model alongside OptionExecutionFailure. Updates agents.yaml config for testing. --- .../approaches/agent_sim_learning_approach.py | 41 ++++++++++++++++--- predicators/option_model.py | 7 ++-- scripts/configs/predicatorv3/agents.yaml | 3 +- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index d00e1bafa..682478509 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -23,6 +23,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple import numpy as np +import pybullet from gym.spaces import Box from predicators import utils @@ -146,8 +147,7 @@ def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: return # Build combined simulator. - combined_sim = self._build_combined_simulator(self._base_env, - self._simulator, + combined_sim = self._build_combined_simulator(self._simulator, self._process_features) # Build learned option model @@ -439,16 +439,45 @@ def _extract_step_transitions( (traj.states[i], traj.actions[i], traj.states[i + 1])) return triples - @staticmethod + def _recreate_base_env(self) -> None: + """Reconnect after a PyBullet physics-server crash. + + Disconnects the dead client (best-effort), then spins up a fresh + env with the same settings so subsequent simulate() calls work. + """ + try: + pybullet.disconnect(self._base_env._physics_client_id) + except Exception: # client may already be dead + pass + logging.warning( + "PyBullet physics client crashed; recreating base env " + "(use_gui=%s).", CFG.option_model_use_gui) + self._base_env = create_new_env(CFG.env, + do_cache=False, + use_gui=CFG.option_model_use_gui, + skip_process_dynamics=True) + def _build_combined_simulator( - base_env: Any, + self, simulator: LearnedSimulator, process_features: Dict[str, List[str]], ) -> Callable[[State, Action], State]: - """Compose kinematics-only env with learned step-level dynamics.""" + """Compose kinematics-only env with learned step-level dynamics. + + Captures ``self`` so that if the PyBullet physics server crashes + (common on macOS Metal with GUI mode after many simulation steps), + the closure can recreate ``self._base_env`` and retry once. + """ def combined_simulate(state: State, action: Action) -> State: - kin_state = base_env.simulate(state, action) + try: + kin_state = self._base_env.simulate(state, action) + except pybullet.error as e: + logging.warning( + "PyBullet error in combined_simulate (%s); " + "recreating base env and retrying.", e) + self._recreate_base_env() + kin_state = self._base_env.simulate(state, action) updates = simulator.predict_step(kin_state) if not updates: return kin_state diff --git a/predicators/option_model.py b/predicators/option_model.py index 1ca608393..788f85b4e 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, Set, Tuple import numpy as np +import pybullet from predicators import utils from predicators.envs import create_new_env @@ -173,9 +174,9 @@ def _terminal(s: State) -> bool: state, _terminal, max_num_steps=CFG.max_num_steps_option_rollout) - except utils.OptionExecutionFailure as e: - # If there is a failure during the execution of the option, treat - # this as a noop. + except (utils.OptionExecutionFailure, pybullet.error) as e: + # Treat PyBullet physics engine errors the same as planned + # execution failures (e.g. GUI/Metal crash on macOS). self.last_execution_failure = str(e) return state, 0 # Note that in the case of using a PyBullet environment, the diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index a55df02c0..952045126 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -48,9 +48,10 @@ APPROACHES: option_model_use_gui: True agent_bilevel_log_state: False agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" - skip_test_until_last_ite_or_early_stopping: True + skip_test_until_last_ite_or_early_stopping: False agent_sim_learn_oracle_sim_program: True agent_sim_learn_oracle_sim_params: False + agent_sim_learn_oracle_sim_param_noise_scale: 1.0 # 1.0 allows successful planning but insatisficing plan; 0.8 produces satisficing plan # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: From 99b38b12b626d8ce881c11c3e6a5afd834914830 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 18:21:39 +0100 Subject: [PATCH 50/70] Fix jug orientation handling in PyBulletBoilEnv by restoring rotation logic --- predicators/envs/pybullet_boil.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 7fa429b50..f1ebb9164 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -1364,8 +1364,8 @@ def _create_liquid_for_jug( cx = state.get(jug, "x") cy = state.get(jug, "y") cz = self.z_lb + liquid_height / 2 + 0.02 # sits on table - # jug_rot = state.get(jug, "rot") - # orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) + jug_rot = state.get(jug, "rot") + orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) color = self.water_color return create_pybullet_block(color=color, @@ -1373,7 +1373,7 @@ def _create_liquid_for_jug( mass=0.01, friction=0.5, position=(cx, cy, cz), - # orientation=orientation, + orientation=orientation, physics_client_id=self._physics_client_id) From 8521882b55d2c75b7de8b0f15e85dbc565b56915 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 18:26:52 +0100 Subject: [PATCH 51/70] Update installation instructions and dependencies; remove macOS setup script --- README.md | 3 +-- setup.py | 2 +- setup.sh | 29 ----------------------------- 3 files changed, 2 insertions(+), 32 deletions(-) delete mode 100755 setup.sh diff --git a/README.md b/README.md index 4d51fad4b..3819738dd 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,7 @@ A simple implementation of search-then-sample bilevel planning is provided in `p ## Installation * This repository uses Python versions 3.10-3.11. We recommend 3.10.14. -* Run `./setup.sh` to install dependencies (handles macOS PyBullet source build automatically). -* Alternatively, run `pip install -e .` directly if not on macOS. +* Run `pip install -e .` to install dependencies. ## Instructions For Running Code diff --git a/setup.py b/setup.py index 8cfdfa2e0..c60c43852 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "pillow==10.3.0", "requests", "slack_bolt", - "pybullet==3.2.7", + "pybullet-arm64>=3.2.8", "scikit-learn>=1.1.3", "graphlib-backport", "openai==1.19.0", diff --git a/setup.sh b/setup.sh deleted file mode 100755 index d1cb6b8c8..000000000 --- a/setup.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -e -git submodule update --init --recursive - -if [[ "$OSTYPE" == "darwin"* ]]; then - echo "macOS detected: building PyBullet from source (workaround for macOS compatibility)..." - - BULLET_TMP=$(mktemp -d) - trap 'rm -rf "$BULLET_TMP"' EXIT - - git clone https://github.com/bulletphysics/bullet3 "$BULLET_TMP/bullet3" - - # Comment out the line that causes build failure on recent macOS - sed -i '' \ - 's|^#define fdopen(fd, mode) NULL|// #define fdopen(fd, mode) NULL|' \ - "$BULLET_TMP/bullet3/examples/ThirdPartyLibs/zlib/zutil.h" - - pip install setuptools - pushd "$BULLET_TMP/bullet3" - python setup.py build - python setup.py install - popd - - # Install everything else; pybullet 3.2.7 is already installed from source - # above so pip will skip it - pip install -e . -else - pip install -e . -fi From f998254d952b8bccdb49962b0b18af218f104816 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Fri, 1 May 2026 18:40:20 +0100 Subject: [PATCH 52/70] Remove mara_robosim dependency from setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index c60c43852..5ce859b2a 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,6 @@ "psutil", "claude-agent-sdk", "nest_asyncio", - "mara_robosim@git+https://github.com/yichao-liang/mara-robosim.git", "emcee", ], include_package_data=True, From 4e12a17ccd9d6f77a77c7465fe39285ec713c56b Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Sat, 2 May 2026 15:20:59 +0100 Subject: [PATCH 53/70] Fix get_gt_simulator to use env_name instead of normalized name --- predicators/ground_truth_models/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index e5fde702d..8359d1d18 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -270,9 +270,8 @@ def get_gt_simulator(env_name: str) -> tuple: rule functions and *param_specs* is a list of ``ParamSpec`` objects whose ``init_value`` is the GT value. """ - gt_name = _normalize_env_name_for_gt(env_name) for cls in utils.get_all_subclasses(GroundTruthSimulatorFactory): - if not cls.__abstractmethods__ and gt_name in cls.get_env_names(): + if not cls.__abstractmethods__ and env_name in cls.get_env_names(): return cls.get_rules(), cls.get_param_specs() raise NotImplementedError("Ground-truth simulator not implemented for " f"env: {env_name}") From a8105cf6b1b3f559cac65755c39c5a8359b101da Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Sat, 2 May 2026 16:52:16 +0100 Subject: [PATCH 54/70] Add before/after MSE, likelihood, and param-delta logging for parameter fitting --- .../approaches/agent_sim_learning_approach.py | 83 +++++++++++++------ 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index 682478509..ea7686c47 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -4,7 +4,7 @@ agent-synthesized step-level simulator with parameterized process rules. Parameters are fitted via emcee ensemble MCMC (training.py). -The approach creates a kinematics-only oracle (PyBullet with process +The approach creates a base oracle (PyBullet with process dynamics disabled) and composes it with the learned step-level dynamics into a single simulator function, plugged into a standard _OracleOptionModel for true per-step interleaving. @@ -53,7 +53,7 @@ class AgentSimLearningApproach(AgentBilevelApproach): 2. Segment into option-level transitions 3. Synthesize parameterized process rules via Claude agent 4. Fit rule parameters via emcee ensemble MCMC - 5. Compose with kinematics-only oracle into a combined simulator + 5. Compose with base oracle into a combined simulator 6. Build _OracleOptionModel with the combined simulator During solving: @@ -70,7 +70,7 @@ def __init__(self, *args: Any, option_model: Optional[_OptionModelBase] = None, **kwargs: Any) -> None: - # Build the kinematics-only env BEFORE super().__init__ and pass + # Build the base env BEFORE super().__init__ and pass # the resulting option model in via option_model=. This stops # AgentPlannerApproach.__init__ from spinning up its own full- # process env (which would conflict with this one over PyBullet @@ -162,7 +162,7 @@ def _build_option_model( Plumbs ``_abstract_function`` for Wait-target atom-change termination so the model behaves identically whether it's - wrapping the bare kin-only simulator (init) or the learned + wrapping the bare base simulator (init) or the learned kin+process combined simulator (post learn_from_interaction). Uses ``self._get_all_options()`` rather than ``get_gt_options(CFG.env)`` to avoid spawning a second cached @@ -282,12 +282,12 @@ def _synthesize_with_agent( self._process_rules = rules # ── Obtain fitted parameters ──────────────────────────── - # Use a headless env for fitting so the GUI env isn't - # thrashed by thousands of _set_state calls during MCMC. + # Use a headless env for fitting. fit_env = create_new_env(CFG.env, do_cache=False, use_gui=False, skip_process_dynamics=True) + _noise_sigma = 0.05 # matches fit_params default if CFG.agent_sim_learn_oracle_sim_params: self._fitted_params = {s.name: s.init_value for s in specs} self._fit_mse = compute_mse( @@ -296,16 +296,19 @@ def _synthesize_with_agent( step_transitions, self._fitted_params, process_features) - logger.info("Using oracle params (MSE: %.6f).", self._fit_mse) + fit_ll = -0.5 * self._fit_mse / (_noise_sigma**2) + logger.info("Oracle params — MSE: %.6f log-likelihood: %.2f", + self._fit_mse, fit_ll) + for name, val in sorted(self._fitted_params.items()): + logger.info(" %-30s %.4f", name, val) else: self._fitted_params, self._fit_mse = self._fit_parameters( rules, specs, step_transitions, process_features, fit_env) if CFG.code_sim_learning_num_mcmc_steps == 0: - logger.info("Skipped fitting; using %d initial params " - "(MSE: %.6f).", len(specs), self._fit_mse) + logger.info("Skipped MCMC; using %d initial params.", + len(specs)) else: - logger.info("Fitted %d params (MSE: %.6f).", len(specs), - self._fit_mse) + logger.info("Fitted %d params.", len(specs)) # ── Parameter fitting ──────────────────────────────────────── @@ -315,35 +318,65 @@ def _fit_parameters( specs: List[ParamSpec], step_transitions: List[Tuple[State, Action, State]], process_features: Dict[str, List[str]], - base_env: Any = None, + base_env: Any, ) -> Tuple[Dict[str, float], float]: """Fit parameters for the synthesized rules via MCMC. Args: - base_env: Kinematics-only environment. When provided the - simulator runs kinematics first so learned rules see - the post-kinematics state (consistent with inference). + base_env: Base environment. base_env.simulate(s, a) handles the + first half of each transition, leaving only the learned + process-rule updates for the MCMC loop to evaluate. Returns: (fitted_params, mse) tuple. """ - - def sim_fn(state: State, action: Action, params: Dict[str, - float]) -> Dict: - if base_env is not None: - state = base_env.simulate(state, action) + assert base_env is not None, "base_env required" + # base_env.simulate(s, a) is param-independent, so pre-compute it + # once here rather than inside every MCMC log-posterior call + # (num_walkers × num_steps × len(transitions) invocations). + # The MCMC loop then only evaluates the cheap apply_rules step. + logger.info("Pre-computing base states for %d transitions.", + len(step_transitions)) + base_transitions: List[Tuple[State, Action, State]] = [ + (base_env.simulate(s, a), a, s_next) + for s, a, s_next in step_transitions + ] + + def sim_fn(state: State, action: Action, + params: Dict[str, float]) -> Dict: return apply_rules(state, rules, params) + noise_sigma = 0.05 # matches fit_params default + init_params = {s.name: s.init_value for s in specs} + pre_mse = compute_mse(sim_fn, base_transitions, init_params, + process_features) + pre_ll = -0.5 * pre_mse / (noise_sigma**2) + logger.info("Before fitting — MSE: %.6f log-likelihood: %.2f", + pre_mse, pre_ll) + result = fit_params( simulator_fn=sim_fn, - transitions=step_transitions, + transitions=base_transitions, param_specs=specs, process_features=process_features, ) - mse = compute_mse(sim_fn, step_transitions, result.point_estimate, - process_features) - return result.point_estimate, mse + fitted_params = result.point_estimate + post_mse = compute_mse(sim_fn, base_transitions, fitted_params, + process_features) + post_ll = -0.5 * post_mse / (noise_sigma**2) + logger.info("After fitting — MSE: %.6f log-likelihood: %.2f", + post_mse, post_ll) + + for name in sorted(fitted_params): + init_val = init_params[name] + fit_val = fitted_params[name] + delta = fit_val - init_val + pct = (delta / init_val * 100) if init_val != 0 else float("nan") + logger.info(" %-30s %.4f -> %.4f (Δ=%.4f, %+.1f%%)", name, + init_val, fit_val, delta, pct) + + return fitted_params, post_mse @staticmethod def _load_simulator_from_file( @@ -462,7 +495,7 @@ def _build_combined_simulator( simulator: LearnedSimulator, process_features: Dict[str, List[str]], ) -> Callable[[State, Action], State]: - """Compose kinematics-only env with learned step-level dynamics. + """Compose base env with learned step-level dynamics. Captures ``self`` so that if the PyBullet physics server crashes (common on macOS Metal with GUI mode after many simulation steps), From 2f97798a0e6b863589f75f7146124dd62fddb26e Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Sun, 3 May 2026 15:45:49 +0100 Subject: [PATCH 55/70] Use SSE loss and wider walker init so MCMC parameter fitting actually moves Switch the fitting loss from per-feature MSE to total SSE (drop the /count in compute_sse) so the Gaussian log-likelihood -0.5*SSE/sigma^2 is in its correct iid form. The previous MSE form silently rescaled per-observation noise by sqrt(count), making walker proposals indistinguishable from each other. Pair this with a wider walker initialization (0.5 * prior_sigma instead of 1% of init_value) so the swarm covers the prior support and emcee stretch moves can actually explore. --- predicators/agent_sdk/tools.py | 6 +-- .../approaches/agent_sim_learning_approach.py | 40 +++++++++--------- predicators/code_sim_learning/training.py | 41 +++++++++++-------- scripts/configs/predicatorv3/agents.yaml | 1 + 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index aeb15edff..08418c5ab 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -2058,7 +2058,7 @@ async def run_python(args: Dict[str, Any]) -> Dict[str, Any]: @tool( "evaluate_simulator", "Fit parameters using PROCESS_RULES and PARAM_SPECS " - "from the run_python namespace. Reports MSE and fitted " + "from the run_python namespace. Reports SSE and fitted " "parameter values.", { "type": "object", @@ -2076,7 +2076,7 @@ async def evaluate_simulator(_args: Dict[str, Any]) -> Dict[str, Any]: "run_python to define it first.") try: - fitted_params, mse = ( + fitted_params, sse = ( AgentSimLearningApproach._fit_parameters( # pylint: disable=protected-access rules, specs, step_transitions, process_features, base_env)) @@ -2084,7 +2084,7 @@ async def evaluate_simulator(_args: Dict[str, Any]) -> Dict[str, Any]: return _text(f"Error: fit_params failed:\n{e}") lines = [ - f"MSE: {mse:.6f} on " + f"SSE: {sse:.6f} on " f"{len(step_transitions)} step transitions.", "", "Fitted parameters:", diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index ea7686c47..74874ce22 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -29,7 +29,7 @@ from predicators import utils from predicators.agent_sdk.tools import create_synthesis_tools from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach -from predicators.code_sim_learning.training import ParamSpec, compute_mse, \ +from predicators.code_sim_learning.training import ParamSpec, compute_sse, \ fit_params from predicators.code_sim_learning.utils import LearnedSimulator, \ apply_rules, merge_updates @@ -100,7 +100,7 @@ def __init__(self, # Persistent state across learning cycles. self._process_rules: Optional[List] = None self._fitted_params: Optional[Dict[str, float]] = None - self._fit_mse: float = float("inf") + self._fit_sse: float = float("inf") # True during simulator synthesis (learning); False during # plan generation (decision-making). self._learning_mode: bool = False @@ -152,7 +152,7 @@ def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: # Build learned option model self._option_model = self._build_option_model(combined_sim) - logger.info("Built learned option model (MSE: %.6f).", self._fit_mse) + logger.info("Built learned option model (SSE: %.6f).", self._fit_sse) def _build_option_model( self, @@ -290,19 +290,19 @@ def _synthesize_with_agent( _noise_sigma = 0.05 # matches fit_params default if CFG.agent_sim_learn_oracle_sim_params: self._fitted_params = {s.name: s.init_value for s in specs} - self._fit_mse = compute_mse( + self._fit_sse = compute_sse( lambda s, a, p: apply_rules( # type: ignore[misc] fit_env.simulate(s, a), rules, p), step_transitions, self._fitted_params, process_features) - fit_ll = -0.5 * self._fit_mse / (_noise_sigma**2) - logger.info("Oracle params — MSE: %.6f log-likelihood: %.2f", - self._fit_mse, fit_ll) + fit_ll = -0.5 * self._fit_sse / (_noise_sigma**2) + logger.info("Oracle params — SSE: %.6f log-likelihood: %.2f", + self._fit_sse, fit_ll) for name, val in sorted(self._fitted_params.items()): logger.info(" %-30s %.4f", name, val) else: - self._fitted_params, self._fit_mse = self._fit_parameters( + self._fitted_params, self._fit_sse = self._fit_parameters( rules, specs, step_transitions, process_features, fit_env) if CFG.code_sim_learning_num_mcmc_steps == 0: logger.info("Skipped MCMC; using %d initial params.", @@ -328,7 +328,7 @@ def _fit_parameters( process-rule updates for the MCMC loop to evaluate. Returns: - (fitted_params, mse) tuple. + (fitted_params, sse) tuple. """ assert base_env is not None, "base_env required" # base_env.simulate(s, a) is param-independent, so pre-compute it @@ -348,11 +348,11 @@ def sim_fn(state: State, action: Action, noise_sigma = 0.05 # matches fit_params default init_params = {s.name: s.init_value for s in specs} - pre_mse = compute_mse(sim_fn, base_transitions, init_params, + pre_sse = compute_sse(sim_fn, base_transitions, init_params, process_features) - pre_ll = -0.5 * pre_mse / (noise_sigma**2) - logger.info("Before fitting — MSE: %.6f log-likelihood: %.2f", - pre_mse, pre_ll) + pre_ll = -0.5 * pre_sse / (noise_sigma**2) + logger.info("Before fitting — SSE: %.6f log-likelihood: %.2f", + pre_sse, pre_ll) result = fit_params( simulator_fn=sim_fn, @@ -362,11 +362,11 @@ def sim_fn(state: State, action: Action, ) fitted_params = result.point_estimate - post_mse = compute_mse(sim_fn, base_transitions, fitted_params, + post_sse = compute_sse(sim_fn, base_transitions, fitted_params, process_features) - post_ll = -0.5 * post_mse / (noise_sigma**2) - logger.info("After fitting — MSE: %.6f log-likelihood: %.2f", - post_mse, post_ll) + post_ll = -0.5 * post_sse / (noise_sigma**2) + logger.info("After fitting — SSE: %.6f log-likelihood: %.2f", + post_sse, post_ll) for name in sorted(fitted_params): init_val = init_params[name] @@ -376,7 +376,7 @@ def sim_fn(state: State, action: Action, logger.info(" %-30s %.4f -> %.4f (Δ=%.4f, %+.1f%%)", name, init_val, fit_val, delta, pct) - return fitted_params, post_mse + return fitted_params, post_sse @staticmethod def _load_simulator_from_file( @@ -534,7 +534,7 @@ def _build_synthesis_system_prompt() -> str: - `run_python(code)` — execute Python in a persistent namespace. `print()` \ output is returned. The namespace persists across calls. - `evaluate_simulator` — fit parameters using PROCESS_RULES and PARAM_SPECS \ -from the namespace. Reports MSE. +from the namespace. Reports SSE. - `test_simulator` — test predictions vs observations on step transitions. \ Shows mismatches. @@ -586,7 +586,7 @@ def rule(state, updates, params): state changes over time 2. Identify which features change due to process dynamics (not kinematics) 3. Define `PROCESS_RULES` and `PARAM_SPECS` in the namespace via `run_python` -4. Call `evaluate_simulator` to fit parameters and check MSE +4. Call `evaluate_simulator` to fit parameters and check SSE 5. Call `test_simulator` to see prediction mismatches 6. Iterate if needed diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index 8ff469890..532fb0bbd 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -40,20 +40,27 @@ class FitResult: @property def point_estimate(self) -> Dict[str, float]: - """Posterior mean.""" - mean = self.samples.mean(axis=0) - return {n: float(mean[i]) for i, n in enumerate(self.names)} + """MAP (sample with highest log-probability).""" + best_idx = int(np.argmax(self.log_probs)) + return {n: float(self.samples[best_idx, i]) + for i, n in enumerate(self.names)} -def compute_mse( +def compute_sse( simulator_fn: StepSimulatorFn, transitions: List[Tuple[State, Action, State]], params: Dict[str, float], process_features: Dict[str, List[str]], ) -> float: - """Compute MSE between predicted and observed process features.""" + """Sum of squared errors between predicted and observed process features. + + Returns the total (un-normalized) SSE so that the Gaussian + log-likelihood ``-0.5 * SSE / noise_sigma**2`` is the correct + iid-observation form. Dividing by count would silently rescale the + per-observation noise by sqrt(count), making the chain insensitive + to parameter changes. + """ total_se = 0.0 - count = 0 for s_t, action, s_next_obs in transitions: updates = simulator_fn(s_t, action, params) @@ -67,7 +74,6 @@ def compute_mse( v = pred_val.item() if hasattr(pred_val, 'item') else pred_val obs_val = float(s_next_obs.get(obj, feat_name)) total_se += (v - obs_val)**2 - count += 1 # Penalize unpredicted features (model predicts no change). for obj in s_t: @@ -78,11 +84,8 @@ def compute_mse( pred_val = float(s_t.get(obj, feat_name)) obs_val = float(s_next_obs.get(obj, feat_name)) total_se += (pred_val - obs_val)**2 - count += 1 - if count == 0: - return 0.0 - return total_se / count + return total_se def fit_params( @@ -94,7 +97,7 @@ def fit_params( num_steps: Optional[int] = None, burn_in: int = 200, noise_sigma: float = 0.05, - prior_sigma_scale: float = 2.0, + prior_sigma_scale: float = 1.0, ) -> FitResult: """Fit simulator parameters via emcee ensemble MCMC. @@ -144,11 +147,15 @@ def log_posterior(theta: np.ndarray) -> float: # Broad Gaussian prior centered on init values log_prior = -0.5 * np.sum(((theta - init_values) / prior_sigma)**2) # Likelihood - mse = compute_mse(simulator_fn, transitions, params, process_features) - return log_prior + (-0.5 * mse / (noise_sigma**2)) - - # Initialize walkers in a small ball around init values. - p0 = init_values * (1.0 + 0.01 * np.random.randn(num_walkers, ndim)) + sse = compute_sse(simulator_fn, transitions, params, process_features) + return log_prior + (-0.5 * sse / (noise_sigma**2)) + + # Initialize walkers across the prior support (sigma = half the prior + # width). A tight ball around init traps the chain on flat plateaus + # of the likelihood (e.g., when threshold-based rules don't fire), + # because emcee stretch moves scale with the swarm's spread. + p0 = init_values + 0.5 * prior_sigma * np.random.randn(num_walkers, ndim) + p0 = np.clip(p0, 1e-6, None) sampler = emcee.EnsembleSampler(num_walkers, ndim, log_posterior) diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index 952045126..cc6eb545f 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -52,6 +52,7 @@ APPROACHES: agent_sim_learn_oracle_sim_program: True agent_sim_learn_oracle_sim_params: False agent_sim_learn_oracle_sim_param_noise_scale: 1.0 # 1.0 allows successful planning but insatisficing plan; 0.8 produces satisficing plan + code_sim_learning_num_mcmc_steps: 500 # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: From 9f09ff9952da74fc61f437fdbf29351295c4ab1a Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:07:15 +0100 Subject: [PATCH 56/70] Move GT simulator components onto module-globals contract Unifies oracle and agent-synthesized simulators behind one loader: read_simulator_components pulls PROCESS_RULES, PARAM_SPECS, and PROCESS_FEATURES out of any namespace (module dict for oracle, exec_ns for agent), and get_gt_simulator now returns the triple including features. merge_updates no longer takes process_features since the rule producer owns that scope. --- predicators/code_sim_learning/utils.py | 78 ++++++++++++------- predicators/ground_truth_models/__init__.py | 49 ++++++++---- .../ground_truth_models/boil/gt_simulator.py | 46 ++++++----- .../test_agent_sim_learning_approach.py | 13 +--- 4 files changed, 115 insertions(+), 71 deletions(-) diff --git a/predicators/code_sim_learning/utils.py b/predicators/code_sim_learning/utils.py index 5436a36e8..830a1e1ed 100644 --- a/predicators/code_sim_learning/utils.py +++ b/predicators/code_sim_learning/utils.py @@ -4,15 +4,18 @@ * ``apply_rules`` — run a list of rule functions on a state, return feature updates (``ProcessUpdate``). -* ``merge_updates`` — overwrite process features in a ``State`` with - values from a ``ProcessUpdate``. -* ``simulate_step`` — full pipeline: kinematics → rules → merge. +* ``merge_updates`` — overwrite features in a ``State`` with values + from a ``ProcessUpdate``. +* ``simulate_step`` — full pipeline: base → rules → merge. +* ``read_simulator_components`` — pull the ``PROCESS_RULES``, + ``PARAM_SPECS``, ``PROCESS_FEATURES`` triple out of a namespace + (oracle module globals or agent-synthesized exec namespace). """ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple from predicators.structs import Action, Object, State @@ -45,28 +48,18 @@ def apply_rules(state: State, rules: List, def merge_updates( base_state: State, updates: ProcessUpdate, - process_features: Dict[str, List[str]], ) -> State: - """Overwrite process features in *base_state* with *updates*. - - Only features listed in ``process_features[type_name]`` are - overwritten; all other features are preserved from *base_state*. - """ + """Overwrite features in *base_state* with values from *updates*.""" if not updates: return base_state new_data = {} for obj in base_state: arr = base_state[obj].copy() - type_name = obj.type.name - process_feats = set(process_features.get(type_name, [])) - if obj in updates: for feat_name, new_val in updates[obj].items(): - if feat_name in process_feats: - idx = obj.type.feature_names.index(feat_name) - arr[idx] = new_val - + idx = obj.type.feature_names.index(feat_name) + arr[idx] = new_val new_data[obj] = arr merged = base_state.copy() @@ -80,18 +73,51 @@ def simulate_step( base_env: Any, rules: List, params: Dict[str, float], - process_features: Dict[str, List[str]], ) -> State: - """Full simulation pipeline: kinematics → rules → merge. + """Full simulation pipeline: base → rules → merge.""" + base_state = base_env.simulate(state, action) + updates = apply_rules(base_state, rules, params) + if not updates: + return base_state + return merge_updates(base_state, updates) + + +# ── Module-namespace loader ─────────────────────────────────────── - Runs ``base_env.simulate`` for kinematics, ``apply_rules`` for - process dynamics, and ``merge_updates`` to combine them. + +def read_simulator_components( + ns: Mapping[str, Any], +) -> Tuple[Optional[List], Optional[List], Optional[Dict[str, List[str]]]]: + """Pull the simulator triple from a namespace (module or exec dict). + + Looks for three names by convention: + + * ``PROCESS_RULES`` — non-empty list of rule functions. + * ``PARAM_SPECS`` — list of ``ParamSpec``, **or** a zero-arg + callable returning such a list. The callable form lets oracle + modules defer CFG-dependent values until consumption time, so the + module can be imported before CFG is finalized; the agent's + saved-file form normally just uses a list. + * ``PROCESS_FEATURES`` — ``{type_name: [feature_names]}`` dict. + + Returns ``(rules, specs, features)`` with ``None`` for any + missing-or-malformed component; callers decide how to react. """ - kin_state = base_env.simulate(state, action) - updates = apply_rules(kin_state, rules, params) - if not updates: - return kin_state - return merge_updates(kin_state, updates, process_features) + rules = ns.get("PROCESS_RULES") + if not isinstance(rules, list) or not rules: + rules = None + + specs = ns.get("PARAM_SPECS") + if callable(specs): + specs = specs() + if not isinstance(specs, list) or not specs: + specs = None + + features = ns.get("PROCESS_FEATURES") + if features is not None and not isinstance(features, dict): + features = None + + return rules, specs, features # ── LearnedSimulator ────────────────────────────────────────────── diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index 8359d1d18..54b6155d9 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -1,5 +1,6 @@ """Implements ground-truth NSRTs and options.""" import abc +import sys from pathlib import Path from typing import Dict, List, Sequence, Set @@ -69,7 +70,15 @@ def get_processes( class GroundTruthSimulatorFactory(abc.ABC): - """Parent class for ground-truth process-dynamics simulator programs.""" + """Parent class for ground-truth process-dynamics simulator programs. + + The factory itself only pins an env-name binding. The actual + simulator components (``PROCESS_RULES``, ``PARAM_SPECS``, + ``PROCESS_FEATURES``) live as module-level globals on the same file + as the subclass, matching the contract used by agent-synthesized + simulators. ``get_gt_simulator`` reads them via + ``read_simulator_components``. + """ @classmethod @abc.abstractmethod @@ -77,18 +86,6 @@ def get_env_names(cls) -> Set[str]: """Get the env names that this factory builds simulators for.""" raise NotImplementedError("Override me!") - @classmethod - @abc.abstractmethod - def get_rules(cls) -> list: - """Return the list of process rule functions.""" - raise NotImplementedError("Override me!") - - @classmethod - @abc.abstractmethod - def get_param_specs(cls) -> list: - """Return the list of ParamSpec objects.""" - raise NotImplementedError("Override me!") - class GroundTruthLDLBridgePolicyFactory(abc.ABC): """Ground-truth policies implemented with LDLs saved in text files.""" @@ -266,13 +263,31 @@ def get_gt_processes(env_name: str, def get_gt_simulator(env_name: str) -> tuple: """Load ground-truth process rules and param specs for an env. - Returns ``(rules, param_specs)`` where *rules* is a list of process - rule functions and *param_specs* is a list of ``ParamSpec`` objects - whose ``init_value`` is the GT value. + Returns ``(rules, param_specs, process_features)``: *rules* is the + list of process rule functions, *param_specs* is the list of + ``ParamSpec`` objects whose ``init_value`` is the GT value, and + *process_features* is the ``{type_name: [feat_names]}`` mapping that + scopes which features the rules predict. + + Locates the right module via the ``GroundTruthSimulatorFactory`` + registry (env-name binding) and reads the three components from + that module's globals via ``read_simulator_components``. This + mirrors the loader used for agent-synthesized simulators. """ + # Local import to avoid pulling code_sim_learning into ground_truth_models + # at import time. + # pylint: disable=import-outside-toplevel + from predicators.code_sim_learning.utils import read_simulator_components + for cls in utils.get_all_subclasses(GroundTruthSimulatorFactory): if not cls.__abstractmethods__ and env_name in cls.get_env_names(): - return cls.get_rules(), cls.get_param_specs() + module = sys.modules[cls.__module__] + rules, specs, features = read_simulator_components(vars(module)) + if rules is None or specs is None or features is None: + raise RuntimeError( + f"GT simulator module {cls.__module__} is missing one " + "of PROCESS_RULES / PARAM_SPECS / PROCESS_FEATURES.") + return rules, specs, features raise NotImplementedError("Ground-truth simulator not implemented for " f"env: {env_name}") diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index ac6092de5..129afa5c8 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -47,7 +47,21 @@ def _build_param_specs() -> List[ParamSpec]: ] -# Static specs for tests / introspection (uses CFG defaults at import time). +# Module-level globals consumed by ``read_simulator_components`` (the +# same contract used by agent-synthesized simulator files). +# ``PARAM_SPECS`` is bound to the *callable* rather than its result so +# CFG-dependent defaults are evaluated when the loader pulls the value, +# after CFG has been finalized. +PARAM_SPECS = _build_param_specs + +PROCESS_FEATURES: Dict[str, List[str]] = { + "jug": ["water_volume", "heat_level"], + "faucet": ["spilled_level"], + "human": ["happiness_level"], +} + +# Backward-compat alias for tests that import a static, eagerly-built +# spec list (uses CFG defaults at import time). BOIL_PARAM_SPECS: List[ParamSpec] = _build_param_specs() Params = Dict[str, float] @@ -167,26 +181,20 @@ def _get_val(obj: Object, feat: str) -> float: PROCESS_RULES = [_water_filling, _heating, _happiness] +def get_gt_process_features() -> Dict[str, List[str]]: + """Backward-compat accessor; prefer the ``PROCESS_FEATURES`` global.""" + return dict(PROCESS_FEATURES) + + class PyBulletBoilGroundTruthSimulatorFactory(GroundTruthSimulatorFactory): - """GT process-dynamics simulator for pybullet_boil.""" + """GT process-dynamics simulator for pybullet_boil. + + The actual simulator components (``PROCESS_RULES``, ``PARAM_SPECS``, + ``PROCESS_FEATURES``) live as module globals above; this class only + pins the env-name binding so ``get_gt_simulator`` can locate the + right module via the factory registry. + """ @classmethod def get_env_names(cls) -> set: return {"pybullet_boil"} - - @classmethod - def get_rules(cls) -> list: - return list(PROCESS_RULES) - - @classmethod - def get_param_specs(cls) -> list: - return _build_param_specs() - - -def get_gt_process_features() -> Dict[str, List[str]]: - """Process features handled by the simulator (not PyBullet).""" - return { - "jug": ["water_volume", "heat_level"], - "faucet": ["spilled_level"], - "human": ["happiness_level"], - } diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index d9d60734a..d0fb5eb7b 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -83,19 +83,14 @@ def _build_kinematics_only_oracle(env): def _build_combined_model(env): """Build a combined model: kinematics-only env + GT step-level dynamics. - Uses the same construction as AgentSimLearningApproach: wraps GT - rules in a LearnedSimulator via apply_rules, composes with a - kinematics-only env, and derives process_features from env.types - (all features, not just GT process features). + Mirrors AgentSimLearningApproach: wraps GT rules in a + LearnedSimulator via apply_rules and composes with a + kinematics-only base env. """ base_env = create_new_env("pybullet_boil", do_cache=False, use_gui=False, skip_process_dynamics=True) - process_features = { - t.name: list(t.feature_names) - for t in env.types if t.feature_names - } gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} rules = PROCESS_RULES @@ -108,7 +103,7 @@ def combined_simulate(state, action): updates = simulator.predict_step(kin_state) if not updates: return kin_state - return merge_updates(kin_state, updates, process_features) + return merge_updates(kin_state, updates) options = get_gt_options(env.get_name()) model = _OracleOptionModel(options, combined_simulate) From c5d45c2eb9dde19728ae369bfe54bc174c8c462e Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:07:26 +0100 Subject: [PATCH 57/70] Soften boil parameter-dependent gates with sigmoid weights Replaces hard ``dist < threshold`` indicators in the boil rules with sigmoid-smoothed gates of width ``_SOFT_EPS``. Without smoothing, the LM finite-difference Jacobian is ~zero almost everywhere, and the Hessian identifiability diagnostic is uninformative; emcee also gets a non-flat likelihood as a side effect. State-dependent gates (faucet on/off, jug held) stay hard since they don't enter the parameter likelihood. --- .../ground_truth_models/boil/gt_simulator.py | 126 +++++++++++++----- 1 file changed, 91 insertions(+), 35 deletions(-) diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index 129afa5c8..3ffc82089 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -2,6 +2,16 @@ Reproduces the custom step logic from pybullet_boil.py as composable process rules using plain numpy/float arithmetic. + +Parameter-dependent gates (alignment thresholds, capacity caps, fill +height) are softened with sigmoid weights so the residual is +differentiable in those parameters. The primary consumer is the +Levenberg-Marquardt fit (and its Hessian identifiability diagnostic), +which builds a finite-difference Jacobian and would see J ~ 0 almost +everywhere with hard indicators. Smoothing also keeps MCMC walkers +from stalling on flat-likelihood plateaus, but emcee is gradient-free +and benefits less directly. State-dependent gates (faucet on/off, jug +held) remain hard since they don't enter the parameter likelihood. """ from __future__ import annotations @@ -29,6 +39,21 @@ FAUCET_X_LEN = 0.15 _WATER_HEIGHT_TO_LEVEL_RATIO = 10 +# Smoothing scale for parameter-dependent gates. Small enough that gates +# are ~99% saturated when the operand is one threshold-width into the +# active region, large enough to give MCMC a usable gradient near the +# cliff. 0.02 is in the right ballpark for both spatial thresholds +# (~0.05–0.15 m) and water-level thresholds (~0.3–1.3). +_SOFT_EPS = 0.02 + + +def _sigmoid(z: float) -> float: + """Numerically-stable scalar sigmoid.""" + if z >= 0: + return 1.0 / (1.0 + np.exp(-z)) + ez = np.exp(z) + return ez / (1.0 + ez) + def _build_param_specs() -> List[ParamSpec]: """Build at call time so CFG-driven values match the current run.""" @@ -77,7 +102,13 @@ def _objs_by_type(state: State) -> Dict[str, List[Object]]: def _water_filling(state: State, updates: ProcessUpdate, params: Params) -> ProcessUpdate: - """Faucet on + jug aligned → fill jug; otherwise spill.""" + """Faucet on + jug aligned → fill jug; otherwise spill. + + Alignment and capacity gates are soft (sigmoid-weighted) so the + residual is differentiable in ``faucet_align_threshold``, + ``faucet_x_len``, and ``max_jug_water_capacity`` — needed for the + LM Jacobian (and downstream Hessian diagnostic) to be informative. + """ objs = _objs_by_type(state) for faucet in objs.get("faucet", []): if state.get(faucet, "is_on") <= 0.5: @@ -89,40 +120,49 @@ def _water_filling(state: State, updates: ProcessUpdate, out_x = fx + params["faucet_x_len"] * np.cos(frot) out_y = fy - params["faucet_x_len"] * np.sin(frot) - jug_catching = False + # Closest non-held jug picks up the catch (matches the + # original "first aligned wins" semantics for single-jug tasks). + best_jug, best_dist = None, float("inf") for jug in objs.get("jug", []): if state.get(jug, "is_held") > 0.5: continue jx = float(state.get(jug, "x")) jy = float(state.get(jug, "y")) - dist = float(np.hypot(out_x - jx, out_y - jy)) - - if dist < params["faucet_align_threshold"]: - water = float(state.get(jug, "water_volume")) - if water < params["max_jug_water_capacity"]: - new_water = min(params["max_jug_water_capacity"], - water + params["water_fill_speed"]) - updates.setdefault(jug, {})["water_volume"] = new_water - jug_catching = True - else: - spill = float(state.get(faucet, "spilled_level")) - new_spill = min(params["max_water_spill_width"], - spill + params["water_fill_speed"]) - updates.setdefault(faucet, {})["spilled_level"] = new_spill - break - - if not jug_catching: - spill = float(state.get(faucet, "spilled_level")) - new_spill = min(params["max_water_spill_width"], - spill + params["water_fill_speed"]) - updates.setdefault(faucet, {})["spilled_level"] = new_spill + d = float(np.hypot(out_x - jx, out_y - jy)) + if d < best_dist: + best_jug, best_dist = jug, d + + catch_w = 0.0 + if best_jug is not None: + water = float(state.get(best_jug, "water_volume")) + align_w = _sigmoid( + (params["faucet_align_threshold"] - best_dist) / _SOFT_EPS) + cap_w = _sigmoid( + (params["max_jug_water_capacity"] - water) / _SOFT_EPS) + catch_w = align_w * cap_w + new_water = water + catch_w * params["water_fill_speed"] + updates.setdefault(best_jug, {})["water_volume"] = new_water + + # Uncaught water spills (clamped at max_water_spill_width). + spill = float(state.get(faucet, "spilled_level")) + new_spill = min( + params["max_water_spill_width"], + spill + (1.0 - catch_w) * params["water_fill_speed"]) + updates.setdefault(faucet, {})["spilled_level"] = new_spill return updates def _heating(state: State, updates: ProcessUpdate, params: Params) -> ProcessUpdate: - """Burner on + jug with water aligned → heat jug.""" + """Burner on + jug with water aligned → heat jug. + + Alignment gate is soft so the residual is differentiable in + ``burner_align_threshold`` (LM's finite-difference Jacobian needs + this; MCMC also avoids flat-likelihood plateaus as a side effect). + The heat cap at 1.0 stays hard since 1.0 is a constant boundary, + not a learned parameter. + """ objs = _objs_by_type(state) for burner in objs.get("burner", []): if state.get(burner, "is_on") <= 0.5: @@ -139,17 +179,25 @@ def _heating(state: State, updates: ProcessUpdate, jy = float(state.get(jug, "y")) dist = float(np.hypot(bx - jx, by - jy)) - if dist < params["burner_align_threshold"]: - heat = float(state.get(jug, "heat_level")) - new_heat = min(1.0, heat + params["heating_speed"]) - updates.setdefault(jug, {})["heat_level"] = new_heat + align_w = _sigmoid( + (params["burner_align_threshold"] - dist) / _SOFT_EPS) + heat = float(state.get(jug, "heat_level")) + new_heat = min(1.0, heat + align_w * params["heating_speed"]) + updates.setdefault(jug, {})["heat_level"] = new_heat return updates def _happiness(state: State, updates: ProcessUpdate, params: Params) -> ProcessUpdate: - """Jug filled + boiled + no spill + burner off → human happy.""" + """Jug filled + boiled + no spill + burner off → human happy. + + The water-filled gate is soft on ``water_filled_height`` so the + residual is differentiable in that parameter for LM (and emcee + gets a non-flat likelihood as a side effect). The heat>=1.0 gate + stays hard (1.0 is a constant cap, not a learned parameter). + Spill / burner-on gates are state-dependent. + """ objs = _objs_by_type(state) faucets = objs.get("faucet", []) burners = objs.get("burner", []) @@ -160,7 +208,12 @@ def _get_val(obj: Object, feat: str) -> float: return float(val) if hasattr(val, 'item') else val return float(state.get(obj, feat)) - any_spill = any(_get_val(f, "spilled_level") > 0 for f in faucets) + # Spilled-level prediction can be a tiny positive number under soft + # semantics even when the env reports zero, so treat anything below + # the smoothing scale as "no spill" to avoid spuriously gating + # happiness off. + any_spill = any( + _get_val(f, "spilled_level") > _SOFT_EPS for f in faucets) any_burner_on = any(state.get(b, "is_on") > 0.5 for b in burners) if any_spill or any_burner_on: @@ -169,11 +222,14 @@ def _get_val(obj: Object, feat: str) -> float: for jug in objs.get("jug", []): water = _get_val(jug, "water_volume") heat = _get_val(jug, "heat_level") - if water >= params["water_filled_height"] and heat >= 1.0: - for human in objs.get("human", []): - h = float(state.get(human, "happiness_level")) - new_h = min(1.0, h + params["happiness_speed"]) - updates.setdefault(human, {})["happiness_level"] = new_h + if heat < 1.0: + continue + filled_w = _sigmoid( + (water - params["water_filled_height"]) / _SOFT_EPS) + for human in objs.get("human", []): + h = float(state.get(human, "happiness_level")) + new_h = min(1.0, h + filled_w * params["happiness_speed"]) + updates.setdefault(human, {})["happiness_level"] = new_h return updates From 95e384fd87ef0f2c9fc205d4458050ae7e2d87bc Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:07:38 +0100 Subject: [PATCH 58/70] Add LM warm-start and Hessian identifiability diagnostic Adds fit_map_lm (Levenberg-Marquardt MAP estimate via SciPy TRF) and log_hessian_identifiability (eigendecompose J^T J/sigma^2 + prior precision to flag sloppy parameter directions). Both run as a single LM pass before MCMC; fit_params now centers walkers on theta_map when code_sim_learning_warm_start_with_lm is set, and short-circuits to it directly when num_mcmc_steps == 0. Also adds compute_residuals (per-feature residual vector LM consumes) and log_sse_breakdown (per-(type, feature) SSE so we can see which features dominate the loss). Two CFG flags gate the new behavior: warm_start_with_lm (default True), log_hessian_identifiability (default False). --- predicators/code_sim_learning/training.py | 300 +++++++++++++++++++++- predicators/settings.py | 6 + 2 files changed, 302 insertions(+), 4 deletions(-) diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index 532fb0bbd..494e274b2 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -88,6 +88,269 @@ def compute_sse( return total_se +def compute_residuals( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], +) -> np.ndarray: + """Per-feature residuals (predicted - observed) as a flat vector. + + Used by Levenberg-Marquardt, which needs the residual *vector* + rather than scalar SSE so it can build J = dr/dtheta. Iteration + order is deterministic so the same theta produces the same vector + across calls (required for finite-difference Jacobians). + """ + residuals: List[float] = [] + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + raw = updates[obj][feat_name] + pred = raw.item() if hasattr(raw, 'item') else float(raw) + else: + pred = float(s_t.get(obj, feat_name)) + obs = float(s_next_obs.get(obj, feat_name)) + residuals.append(pred - obs) + return np.asarray(residuals, dtype=float) + + +def log_sse_breakdown( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], + label: str = "", +) -> None: + """Log per-(type, feature) SSE so we can see which features dominate. + + Splits each feature's residual into two buckets: + * ``pred`` — transitions where the rule produced an update + (residual is sim's prediction error) + * ``no_pred`` — transitions where no rule fired + (residual is whatever the env changed on its own; + large values here mean the model is missing a + process for this feature) + """ + bucket: Dict[Tuple[str, str], Dict[str, float]] = {} + + def _slot(key: Tuple[str, str]) -> Dict[str, float]: + if key not in bucket: + bucket[key] = { + "sse_pred": 0.0, + "n_pred": 0, + "sse_no_pred": 0.0, + "n_no_pred": 0, + "max_abs_err": 0.0, + } + return bucket[key] + + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + + for obj, feat_dict in updates.items(): + type_name = obj.type.name + allowed_feats = process_features.get(type_name, []) + for feat_name, pred_val in feat_dict.items(): + if feat_name not in allowed_feats: + continue + v = pred_val.item() if hasattr(pred_val, 'item') else pred_val + obs_val = float(s_next_obs.get(obj, feat_name)) + err = float(v) - obs_val + slot = _slot((type_name, feat_name)) + slot["sse_pred"] += err * err + slot["n_pred"] += 1 + slot["max_abs_err"] = max(slot["max_abs_err"], abs(err)) + + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + continue + pred_val = float(s_t.get(obj, feat_name)) + obs_val = float(s_next_obs.get(obj, feat_name)) + err = pred_val - obs_val + slot = _slot((type_name, feat_name)) + slot["sse_no_pred"] += err * err + slot["n_no_pred"] += 1 + slot["max_abs_err"] = max(slot["max_abs_err"], abs(err)) + + if not bucket: + return + + total = sum(s["sse_pred"] + s["sse_no_pred"] for s in bucket.values()) + header = f"SSE breakdown{(' — ' + label) if label else ''} " \ + f"(total {total:.4f}):" + logger.info(header) + logger.info(" %-22s %10s %6s %10s %6s %10s", "type.feature", + "sse_pred", "n_pred", "sse_no_pred", "n_nop", "max|err|") + rows = sorted( + bucket.items(), + key=lambda kv: -(kv[1]["sse_pred"] + kv[1]["sse_no_pred"]), + ) + for (type_name, feat_name), s in rows: + logger.info( + " %-22s %10.4f %6d %10.4f %6d %10.4f", + f"{type_name}.{feat_name}", + s["sse_pred"], + int(s["n_pred"]), + s["sse_no_pred"], + int(s["n_no_pred"]), + s["max_abs_err"], + ) + + +def fit_map_lm( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + param_specs: List[ParamSpec], + process_features: Dict[str, List[str]], + max_nfev: int = 200, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Find a MAP estimate via Levenberg-Marquardt (trust-region reflective). + + Returns ``(theta_map, jacobian_at_optimum)``. Jacobian is ``None`` + only if the residual vector is empty or LM raises; in those cases + callers should treat the diagnostic as unavailable. + + How LM finds the MAP here: + * ``compute_residuals`` returns r(theta) = (s_{t+1}_obs - sim(s_t, a; + theta)) flattened over transitions and the features named in + ``process_features``. Minimizing 0.5 * ||r||^2 is exactly MLE + under iid Gaussian observation noise; with the broad Gaussian + prior used elsewhere in this module being effectively flat near + init, the least-squares minimizer coincides with the MAP. + * ``scipy.optimize.least_squares(method='trf')`` runs a + Levenberg-Marquardt step inside a trust region with box + constraints (``lo``/``hi`` from ``param_specs``). At each step + it numerically estimates the Jacobian J = dr/dtheta, solves the + damped normal equations (J^T J + lambda I) dtheta = -J^T r, and + adapts lambda based on whether the step reduces SSE. + * On exit, ``result.x`` is theta_map and ``result.jac`` is J at + the optimum. J^T J / sigma^2 is the Gauss-Newton approximation + to the negative log-likelihood Hessian — the input + ``log_hessian_identifiability`` eigendecomposes to flag flat + directions. + + Two callers (see ``fit_simulator_params``): + * Hessian identifiability diagnostic — eigendecompose J^T J. + * MCMC warm start — center emcee walkers on theta_map (and short- + circuit to it directly when ``num_mcmc_steps == 0``). + """ + from scipy.optimize import least_squares # pylint: disable=import-outside-toplevel + + names = [s.name for s in param_specs] + init = np.array([s.init_value for s in param_specs], dtype=float) + lo = np.array([s.lo if s.lo is not None else 1e-6 for s in param_specs]) + hi = np.array( + [s.hi if s.hi is not None else np.inf for s in param_specs]) + # Nudge init strictly into the interior so trf doesn't reject it. + init = np.maximum(init, lo + 1e-9) + safe_hi = np.where(np.isfinite(hi), hi - 1e-9, np.inf) + init = np.minimum(init, safe_hi) + + def residuals_fn(theta: np.ndarray) -> np.ndarray: + params = {n: float(theta[i]) for i, n in enumerate(names)} + return compute_residuals(simulator_fn, transitions, params, + process_features) + + init_residuals = residuals_fn(init) + if init_residuals.size == 0: + logger.warning("No residuals to fit (empty process_features); " + "skipping LM diagnostic.") + return init, None + + sse_init = float(np.sum(init_residuals**2)) + + try: + result = least_squares(residuals_fn, + init, + method='trf', + bounds=(lo, hi), + max_nfev=max_nfev) + except Exception as exc: # pylint: disable=broad-except + logger.warning("LM diagnostic raised %s; skipping Hessian log.", exc) + return init, None + + sse_lm = float(2.0 * result.cost) + delta = {names[i]: float(result.x[i] - init[i]) + for i in range(len(names))} + logger.info( + "LM diagnostic fit: SSE %.4f -> %.4f in %d fn-evals (status=%d, %s).", + sse_init, sse_lm, result.nfev, result.status, + "converged" if result.success else "max-evals") + logger.info("LM theta_map - init: %s", + {k: f"{v:+.4f}" for k, v in delta.items()}) + + jac = np.asarray(result.jac, dtype=float) + if jac.size == 0: + return np.asarray(result.x, dtype=float), None + return np.asarray(result.x, dtype=float), jac + + +def log_hessian_identifiability( + jacobian: np.ndarray, + param_names: List[str], + noise_sigma: float, + prior_sigma: np.ndarray, + top_k: int = 3, +) -> None: + """Eigendecompose the Hessian at the MAP and log identifiability. + + Under a Laplace approximation, the Hessian of the negative + log-posterior is the inverse posterior covariance. Its eigenvectors + are *combinations* of parameters (not individual params), and the + eigenvalues say how tightly the data constrains each combination: + + * Large eigenvalue -> stiff direction: data pins this down. + * Small eigenvalue -> sloppy direction: data is silent here. + + Sloppy directions point to parameter combinations no optimizer can + recover from the current data — typically structural rule-pair + degeneracy or under-excited input trajectories. The Gauss-Newton + approximation H ~= J^T J / sigma^2 + diag(1/prior_sigma^2) reuses + the LM Jacobian, so this analysis costs effectively nothing once + LM has run. + """ + H_data = jacobian.T @ jacobian / (noise_sigma**2) + H_prior = np.diag(1.0 / prior_sigma**2) + H = H_data + H_prior + + eigvals, eigvecs = np.linalg.eigh(H) # ascending + + cond = float(eigvals[-1] / max(eigvals[0], 1e-30)) + logger.info("Hessian eigenanalysis (cond %.2e, %d params):", + cond, len(param_names)) + + def _format(vec: np.ndarray) -> str: + order = np.argsort(-np.abs(vec)) + parts = [] + for j in order[:4]: + if abs(vec[j]) < 0.05: + break + parts.append(f"{vec[j]:+.2f} {param_names[j]}") + return " ".join(parts) if parts else "(uniform)" + + n = len(eigvals) + k = min(top_k, n) + stiff_idx = list(range(n - 1, n - 1 - k, -1)) + stiff_set = set(stiff_idx) + sloppy_idx = [i for i in range(k) if i not in stiff_set] + + logger.info(" Stiff (well-constrained):") + for i in stiff_idx: + logger.info(" lambda = %10.3e : %s", + eigvals[i], _format(eigvecs[:, i])) + + if sloppy_idx: + logger.info(" Sloppy (under-constrained):") + for i in sloppy_idx: + logger.info(" lambda = %10.3e : %s", + eigvals[i], _format(eigvecs[:, i])) + + def fit_params( simulator_fn: StepSimulatorFn, transitions: List[Tuple[State, Action, State]], @@ -128,15 +391,43 @@ def fit_params( if num_steps < 0: raise ValueError("code_sim_learning_num_mcmc_steps must be " "non-negative.") + prior_sigma = init_values * prior_sigma_scale + + # Optional one-shot LM fit. Two independent uses: + # * Hessian diagnostic — eigendecompose J^T J at the MAP. + # * Warm start — center MCMC walkers on theta_map (and short-circuit + # to it directly when num_steps == 0). + walker_center = init_values + if (CFG.code_sim_learning_log_hessian_identifiability + or CFG.code_sim_learning_warm_start_with_lm): + theta_map, jac = fit_map_lm(simulator_fn, transitions, param_specs, + process_features) + if (CFG.code_sim_learning_log_hessian_identifiability + and jac is not None and jac.size > 0): + log_hessian_identifiability(jac, names, noise_sigma, prior_sigma) + if CFG.code_sim_learning_warm_start_with_lm: + walker_center = np.asarray(theta_map, dtype=float) + logger.info("Warm-starting MCMC walkers from LM MAP estimate.") + lm_params = {n: float(walker_center[i]) for i, n in enumerate(names)} + lm_sse = compute_sse(simulator_fn, transitions, lm_params, + process_features) + lm_ll = -0.5 * lm_sse / (noise_sigma**2) + logger.info("After LM warm start — SSE: %.6f log-likelihood: %.2f", + lm_sse, lm_ll) + log_sse_breakdown(simulator_fn, transitions, lm_params, + process_features, label="lm-warm-start") + if num_steps == 0: - logger.info("Skipping emcee; using initial parameter values.") - return FitResult(names, init_values[None, :], np.zeros(1)) + if CFG.code_sim_learning_warm_start_with_lm: + logger.info("Skipping emcee; using LM warm-start parameters.") + else: + logger.info("Skipping emcee; using initial parameter values.") + return FitResult(names, walker_center[None, :], np.zeros(1)) import emcee # type: ignore[import-untyped] # pylint: disable=import-outside-toplevel ndim = len(param_specs) num_walkers = max(num_walkers, 2 * ndim + 2) - prior_sigma = init_values * prior_sigma_scale burn_in = min(burn_in, max(num_steps - 1, 0)) def log_posterior(theta: np.ndarray) -> float: @@ -154,7 +445,7 @@ def log_posterior(theta: np.ndarray) -> float: # width). A tight ball around init traps the chain on flat plateaus # of the likelihood (e.g., when threshold-based rules don't fire), # because emcee stretch moves scale with the swarm's spread. - p0 = init_values + 0.5 * prior_sigma * np.random.randn(num_walkers, ndim) + p0 = walker_center + 0.5 * prior_sigma * np.random.randn(num_walkers, ndim) p0 = np.clip(p0, 1e-6, None) sampler = emcee.EnsembleSampler(num_walkers, ndim, log_posterior) @@ -164,6 +455,7 @@ def log_posterior(theta: np.ndarray) -> float: # Run with periodic progress reports. report_interval = max(1, num_steps // 5) + report_interval = 100 for i, _result in enumerate(sampler.sample(p0, iterations=num_steps), start=1): if i % report_interval == 0 or i == num_steps: diff --git a/predicators/settings.py b/predicators/settings.py index 1a292fb9e..248b8c63e 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1025,6 +1025,12 @@ class GlobalSettings: # Code sim-learning parameter fitting settings. # Set to 0 to skip MCMC and use initial parameter values directly. code_sim_learning_num_mcmc_steps = 500 + # Diagnostic: log the Hessian eigendecomposition at the MAP to + # spot unidentifiable parameter combinations. Adds ~5-15s per fit. + code_sim_learning_log_hessian_identifiability = False + # If True, run an LM fit and center MCMC walkers on its MAP estimate + # instead of init_values. Adds ~5-15s per fit. + code_sim_learning_warm_start_with_lm = True # Sim-learning oracle flags (for ablation / debugging). # When True, load GT process rules instead of running agent synthesis. From 195e889656a4ca4b550e10d95993b02ae9048cbf Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:07:49 +0100 Subject: [PATCH 59/70] Infer process-feature scope from base-sim residuals The agent now declares its own PROCESS_FEATURES alongside PROCESS_RULES and PARAM_SPECS, and the loss is scoped to that declaration (instead of every feature on every type). Before synthesis, the approach runs the base sim on each transition and flags (type, feat) pairs whose prediction diverges from the observation on at least min_hits triples; this set is sent to the agent as a starting hint and used as the eval/test scope until the agent overrides it. The base-sim prediction is precomputed once into base_pred_triples so MCMC's inner loop only evaluates the cheap apply_rules step. create_synthesis_tools now takes the precomputed triples plus the inferred hint, drops the live base_env, and reads PROCESS_FEATURES from exec_ns each call (falling back to the hint when undeclared). --- predicators/agent_sdk/tools.py | 49 ++- .../approaches/agent_sim_learning_approach.py | 345 ++++++++++-------- 2 files changed, 226 insertions(+), 168 deletions(-) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 08418c5ab..685e73202 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -1968,9 +1968,8 @@ async def visualize_state(args: Dict[str, Any]) -> Dict[str, Any]: def create_synthesis_tools( exec_ns: Dict[str, Any], - step_transitions: list, - process_features: Dict[str, List[str]], - base_env: Any = None, + base_pred_triples: list, + inferred_process_features: Dict[str, List[str]], save_dir: Optional[str] = None, ) -> list: """Create MCP tools for the sim-learning synthesis agent. @@ -1983,13 +1982,18 @@ def create_synthesis_tools( ``PROCESS_RULES`` / ``PARAM_SPECS`` defined in the namespace. * ``test_simulator`` — tests predictions vs observations. + Both eval/test read ``PROCESS_FEATURES`` from ``exec_ns`` on each + call, falling back to ``inferred_process_features`` if the agent + hasn't declared it yet. + Args: - exec_ns: Persistent namespace for ``run_python``. Should + exec_ns: Persistent namespace for ``run_python``. Should contain ``trajectories``, ``np``, ``ParamSpec``. - step_transitions: ``(State, Action, State)`` triples. - process_features: ``{type_name: [feat_names]}`` for MSE. - base_env: Kinematics-only environment. When provided, - evaluate/test tools run kinematics before learned rules. + base_pred_triples: ``(s_base, action, s_next_obs)`` triples + with the base step already advanced — eval/test consume + ``s_base`` directly so no live env is needed. + inferred_process_features: Data-driven default scope used + until the agent defines ``PROCESS_FEATURES`` in exec_ns. save_dir: Directory to save simulator source code to. Each ``run_python`` call appends code to ``save_dir/simulator_code.py``. @@ -2075,17 +2079,23 @@ async def evaluate_simulator(_args: Dict[str, Any]) -> Dict[str, Any]: return _text("Error: PARAM_SPECS not defined. Use " "run_python to define it first.") + declared = exec_ns.get("PROCESS_FEATURES") + process_features = (declared if isinstance(declared, dict) else + inferred_process_features) + scope_note = ("PROCESS_FEATURES" if isinstance(declared, dict) else + "inferred (PROCESS_FEATURES not declared)") + try: fitted_params, sse = ( AgentSimLearningApproach._fit_parameters( # pylint: disable=protected-access - rules, specs, step_transitions, process_features, - base_env)) + rules, specs, base_pred_triples, process_features)) except Exception as e: # pylint: disable=broad-except return _text(f"Error: fit_params failed:\n{e}") lines = [ f"SSE: {sse:.6f} on " - f"{len(step_transitions)} step transitions.", + f"{len(base_pred_triples)} step transitions " + f"(scope: {scope_note}).", "", "Fitted parameters:", ] @@ -2123,9 +2133,13 @@ async def test_simulator(args: Dict[str, Any]) -> Dict[str, Any]: if not isinstance(rules, list) or not rules: return _text("Error: PROCESS_RULES not defined.") + declared = exec_ns.get("PROCESS_FEATURES") + process_features = (declared if isinstance(declared, dict) else + inferred_process_features) + max_n = args.get("max_transitions", 100) tol = args.get("tolerance", 1e-4) - pairs = step_transitions[:max_n] + pairs = base_pred_triples[:max_n] # Use init params if not yet fitted. if specs: @@ -2137,16 +2151,13 @@ async def test_simulator(args: Dict[str, Any]) -> Dict[str, Any]: n_tested = 0 n_mismatch = 0 - for s_t, action, s_next_obs in pairs: - # Run kinematics first so rules see post-kin state. - kin_state = (base_env.simulate(s_t, action) - if base_env is not None else s_t) + for base_state, _action, s_next_obs in pairs: updates: Dict = {} for rule in rules: - updates = rule(kin_state, updates, t_params) + updates = rule(base_state, updates, t_params) entry: list = [] - for obj in s_t: + for obj in base_state: type_name = obj.type.name for feat in process_features.get(type_name, []): if obj in updates and feat in updates[obj]: @@ -2154,7 +2165,7 @@ async def test_simulator(args: Dict[str, Any]) -> Dict[str, Any]: pred = (pred.item() if hasattr(pred, "item") else float(pred)) else: - pred = s_t.get(obj, feat) + pred = base_state.get(obj, feat) obs = s_next_obs.get(obj, feat) err = abs(pred - obs) if err > tol: diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index 74874ce22..0f29c8c19 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -30,9 +30,9 @@ from predicators.agent_sdk.tools import create_synthesis_tools from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach from predicators.code_sim_learning.training import ParamSpec, compute_sse, \ - fit_params + fit_params, log_sse_breakdown from predicators.code_sim_learning.utils import LearnedSimulator, \ - apply_rules, merge_updates + apply_rules, merge_updates, read_simulator_components from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_simulator from predicators.option_model import _OptionModelBase, _OracleOptionModel @@ -70,13 +70,9 @@ def __init__(self, *args: Any, option_model: Optional[_OptionModelBase] = None, **kwargs: Any) -> None: - # Build the base env BEFORE super().__init__ and pass - # the resulting option model in via option_model=. This stops - # AgentPlannerApproach.__init__ from spinning up its own full- - # process env (which would conflict with this one over PyBullet - # GUI connections) and is the only env this approach holds. - # learn_from_interaction_results later wraps a kin+learned - # combined simulator around the same env. + # Build the base env and pass the option model in so the parent + # __init__ doesn't spin up its own full-process env, which + # would fight this one for the PyBullet GUI client. self._base_env = create_new_env(CFG.env, do_cache=False, use_gui=CFG.option_model_use_gui, @@ -92,17 +88,12 @@ def __init__(self, *args, option_model=option_model, **kwargs) - self._simulator: Optional[LearnedSimulator] = None - self._process_features: Dict[str, List[str]] = { - t.name: list(t.feature_names) - for t in types if t.feature_names - } - # Persistent state across learning cycles. + self._learned_simulator: Optional[LearnedSimulator] = None + # Loss-scope mask for parameter fitting (compute_sse). + self._process_features: Dict[str, List[str]] = {} self._process_rules: Optional[List] = None self._fitted_params: Optional[Dict[str, float]] = None self._fit_sse: float = float("inf") - # True during simulator synthesis (learning); False during - # plan generation (decision-making). self._learning_mode: bool = False @classmethod @@ -128,29 +119,44 @@ def learn_from_interaction_results( self._learn_simulator(self._online_trajectories) def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: - """Synthesize rules, fit parameters, and build the option model. + """Synthesize rules, fit parameters, and build the option model.""" + # Two parallel triple lists drive the rest of this method: + # * obs_triples — raw (s_t, a, s_{t+1}) from the data. + # * base_pred_triples — same triples but s_t replaced by the + # base sim's one-step prediction. The rules run on top of + # that prediction; SSE compares against s_{t+1}. + obs_triples = self._extract_obs_triples(trajectories) + if not obs_triples: + logger.warning("No step transitions; skipping simulator learning.") + return + # Headless env for the pre-compute: reusing the GUI base_env + # corrupts its visual-shape state after a few hundred steps. + fit_env = create_new_env(CFG.env, + do_cache=False, + use_gui=False, + skip_process_dynamics=True) + logger.info("Pre-computing base states for %d transitions.", + len(obs_triples)) + base_pred_triples = self._compute_base_pred_triples( + obs_triples, fit_env) + inferred_hint = self._infer_process_features_from_residuals( + obs_triples, base_pred_triples) + logger.info("Process features (data-driven hint): %s", inferred_hint) - Shared by ``learn_from_offline_dataset`` and - ``learn_from_interaction_results``. - """ - self._synthesize_with_agent(self._process_features, trajectories) + self._synthesize_with_agent(trajectories, obs_triples, + base_pred_triples, inferred_hint) - # Build learned simulator. if self._process_rules is not None and self._fitted_params is not None: rules, params = self._process_rules, self._fitted_params - self._simulator = LearnedSimulator( + self._learned_simulator = LearnedSimulator( step_fn=lambda s, _r=rules, _p=params: # type: ignore[misc] apply_rules(s, _r, _p), name="agent_synthesized") - elif self._simulator is None: + elif self._learned_simulator is None: logger.warning("Synthesis produced no simulator, skipping.") return - # Build combined simulator. - combined_sim = self._build_combined_simulator(self._simulator, - self._process_features) - - # Build learned option model + combined_sim = self._build_combined_simulator(self._learned_simulator) self._option_model = self._build_option_model(combined_sim) logger.info("Built learned option model (SSE: %.6f).", self._fit_sse) @@ -160,10 +166,6 @@ def _build_option_model( ) -> _OracleOptionModel: """Wrap a simulator function in an OracleOptionModel. - Plumbs ``_abstract_function`` for Wait-target atom-change - termination so the model behaves identically whether it's - wrapping the bare base simulator (init) or the learned - kin+process combined simulator (post learn_from_interaction). Uses ``self._get_all_options()`` rather than ``get_gt_options(CFG.env)`` to avoid spawning a second cached PyBullet env via ``get_or_create_env``. @@ -179,31 +181,25 @@ def _build_option_model( def _synthesize_with_agent( self, - process_features: Dict[str, List[str]], trajectories: List[LowLevelTrajectory], + obs_triples: List[Tuple[State, Action, State]], + base_pred_triples: List[Tuple[State, Action, State]], + inferred_hint: Dict[str, List[str]], ) -> None: - """Synthesize parameterized process rules via a Claude agent. - - Provides ``run_python``, ``evaluate_simulator``, and - ``test_simulator`` tools. The agent explores trajectory data - via ``run_python`` (which has a persistent namespace with - ``trajectories`` pre-loaded), then defines ``PROCESS_RULES`` - and ``PARAM_SPECS``. Each ``run_python`` call appends code - to a saved file; after the session we reload from that file. - - - ``agent_sim_learn_oracle_sim_program``: skip agent synthesis - and load GT rules/specs instead (init_values perturbed so - MCMC has non-trivial work). - - ``agent_sim_learn_oracle_sim_param_noise_scale``: adjust the - magnitude of the perturbation applied to oracle init_values. - - ``agent_sim_learn_oracle_sim_params``: skip MCMC fitting and - use the GT parameter values directly. + """Synthesize PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES via agent. + + ``inferred_hint`` is passed to the agent as a starting point + and used as the eval/test scope until it declares its own + ``PROCESS_FEATURES``. CFG flags + ``agent_sim_learn_oracle_sim_program`` and + ``agent_sim_learn_oracle_sim_params`` short-circuit the agent + and/or MCMC by loading the GT simulator instead. """ - step_transitions = self._extract_step_transitions(trajectories) - # ── Obtain rules + specs ──────────────────────────────── if CFG.agent_sim_learn_oracle_sim_program: - rules, specs = get_gt_simulator(CFG.env) + rules, specs, process_features = get_gt_simulator(CFG.env) + self._log_feature_set_diff(inferred_hint, process_features, + "inferred", "oracle") if not CFG.agent_sim_learn_oracle_sim_params: rng = np.random.default_rng(CFG.seed) noise_scale = CFG.agent_sim_learn_oracle_sim_param_noise_scale @@ -225,43 +221,44 @@ def _synthesize_with_agent( logger.info("Loaded oracle sim program (%d rules, %d params).", len(rules), len(specs)) else: - # Directory for saving simulator source code. base = self._tool_context.sandbox_dir or self._get_log_dir() save_dir = os.path.join(base, "simulator_code") - # Persistent exec namespace — the agent's "scratch-pad". exec_ns: Dict[str, Any] = { "trajectories": trajectories, "np": np, "ParamSpec": ParamSpec, } - # Build synthesis tools (run_python, evaluate, test). tools = create_synthesis_tools(exec_ns, - step_transitions, - process_features, - self._base_env, + base_pred_triples, + inferred_hint, save_dir=save_dir) self._tool_context.extra_mcp_tools = tools self._learning_mode = True - # Force a fresh session so the synthesis system prompt and - # tool set take effect. + # Fresh session so the synthesis prompt + tools take effect. self._close_agent_session() self._ensure_agent_session() - # Write data-structure reference for the agent to Read. structs_ref = self._write_structs_reference() n_trajs = len(trajectories) message = f"""\ Synthesize a process dynamics simulator for this environment. \ -There are {n_trajs} trajectories ({len(step_transitions)} step \ +There are {n_trajs} trajectories ({len(obs_triples)} step \ transitions) available. Data-structure source code is at: {structs_ref} -Read that file first, then explore the trajectory data with \ -`run_python` and define PROCESS_RULES and PARAM_SPECS.""" + +A residual scan between the base simulator's prediction and the \ +observed next state suggests these features carry process dynamics \ +(starting hint, may include base-sim jitter — refine as you go): +{inferred_hint} + +Read the data-structures file first, then explore the trajectory \ +data with `run_python` and define PROCESS_RULES, PARAM_SPECS, and \ +PROCESS_FEATURES.""" try: self._query_agent_sync(message) @@ -270,40 +267,41 @@ def _synthesize_with_agent( self._learning_mode = False self._close_agent_session() - # Load results from saved versioned files. - rules, specs = self._load_simulator_from_file( + rules, specs, declared = self._load_simulator_from_file( save_dir, trajectories) if rules is None or specs is None: return - + assert declared is not None, ( + "Agent did not declare PROCESS_FEATURES; " + "synthesis output is incomplete.") + process_features = declared + self._log_feature_set_diff(inferred_hint, process_features, + "inferred", "declared") logger.info("Agent synthesized %d rules, %d params.", len(rules), len(specs)) self._process_rules = rules + self._process_features = process_features - # ── Obtain fitted parameters ──────────────────────────── - # Use a headless env for fitting. - fit_env = create_new_env(CFG.env, - do_cache=False, - use_gui=False, - skip_process_dynamics=True) _noise_sigma = 0.05 # matches fit_params default if CFG.agent_sim_learn_oracle_sim_params: self._fitted_params = {s.name: s.init_value for s in specs} - self._fit_sse = compute_sse( - lambda s, a, p: apply_rules( # type: ignore[misc] - fit_env.simulate(s, a), rules, p), - step_transitions, - self._fitted_params, - process_features) + oracle_sim_fn = lambda s, a, p: apply_rules( # noqa: E731 + s, rules, p) + self._fit_sse = compute_sse(oracle_sim_fn, base_pred_triples, + self._fitted_params, + process_features) fit_ll = -0.5 * self._fit_sse / (_noise_sigma**2) logger.info("Oracle params — SSE: %.6f log-likelihood: %.2f", self._fit_sse, fit_ll) for name, val in sorted(self._fitted_params.items()): logger.info(" %-30s %.4f", name, val) + log_sse_breakdown(oracle_sim_fn, base_pred_triples, + self._fitted_params, process_features, + label="oracle") else: self._fitted_params, self._fit_sse = self._fit_parameters( - rules, specs, step_transitions, process_features, fit_env) + rules, specs, base_pred_triples, process_features) if CFG.code_sim_learning_num_mcmc_steps == 0: logger.info("Skipped MCMC; using %d initial params.", len(specs)) @@ -316,31 +314,14 @@ def _synthesize_with_agent( def _fit_parameters( rules: List, specs: List[ParamSpec], - step_transitions: List[Tuple[State, Action, State]], + base_pred_triples: List[Tuple[State, Action, State]], process_features: Dict[str, List[str]], - base_env: Any, ) -> Tuple[Dict[str, float], float]: """Fit parameters for the synthesized rules via MCMC. - Args: - base_env: Base environment. base_env.simulate(s, a) handles the - first half of each transition, leaving only the learned - process-rule updates for the MCMC loop to evaluate. - - Returns: - (fitted_params, sse) tuple. + ``base_pred_triples`` must already have the base step applied; + precomputing avoids re-running it inside the MCMC inner loop. """ - assert base_env is not None, "base_env required" - # base_env.simulate(s, a) is param-independent, so pre-compute it - # once here rather than inside every MCMC log-posterior call - # (num_walkers × num_steps × len(transitions) invocations). - # The MCMC loop then only evaluates the cheap apply_rules step. - logger.info("Pre-computing base states for %d transitions.", - len(step_transitions)) - base_transitions: List[Tuple[State, Action, State]] = [ - (base_env.simulate(s, a), a, s_next) - for s, a, s_next in step_transitions - ] def sim_fn(state: State, action: Action, params: Dict[str, float]) -> Dict: @@ -348,25 +329,29 @@ def sim_fn(state: State, action: Action, noise_sigma = 0.05 # matches fit_params default init_params = {s.name: s.init_value for s in specs} - pre_sse = compute_sse(sim_fn, base_transitions, init_params, + pre_sse = compute_sse(sim_fn, base_pred_triples, init_params, process_features) pre_ll = -0.5 * pre_sse / (noise_sigma**2) logger.info("Before fitting — SSE: %.6f log-likelihood: %.2f", pre_sse, pre_ll) + log_sse_breakdown(sim_fn, base_pred_triples, init_params, + process_features, label="before") result = fit_params( simulator_fn=sim_fn, - transitions=base_transitions, + transitions=base_pred_triples, param_specs=specs, process_features=process_features, ) fitted_params = result.point_estimate - post_sse = compute_sse(sim_fn, base_transitions, fitted_params, + post_sse = compute_sse(sim_fn, base_pred_triples, fitted_params, process_features) post_ll = -0.5 * post_sse / (noise_sigma**2) logger.info("After fitting — SSE: %.6f log-likelihood: %.2f", post_sse, post_ll) + log_sse_breakdown(sim_fn, base_pred_triples, fitted_params, + process_features, label="after") for name in sorted(fitted_params): init_val = init_params[name] @@ -378,27 +363,92 @@ def sim_fn(state: State, action: Action, return fitted_params, post_sse + # ── Process-feature inference ──────────────────────────────── + + @staticmethod + def _compute_base_pred_triples( + obs_triples: List[Tuple[State, Action, State]], + base_env: Any, + ) -> List[Tuple[State, Action, State]]: + """Replace each ``s_t`` with the base sim's one-step prediction.""" + return [(base_env.simulate(s, a), a, s_next) + for s, a, s_next in obs_triples] + + @staticmethod + def _infer_process_features_from_residuals( + obs_triples: List[Tuple[State, Action, State]], + base_pred_triples: List[Tuple[State, Action, State]], + abs_tol: float = 1e-4, + rel_tol: float = 1e-3, + min_hits: int = 3, + ) -> Dict[str, List[str]]: + """Features whose base-sim prediction diverges from observation. + + Flags ``(type, feat)`` if ``|pred - obs| > rel_tol*|obs| + abs_tol`` + on at least ``min_hits`` triples. The ``min_hits`` floor keeps + one-off PyBullet jitter from leaking base-handled features into the set. + """ + hits: Dict[Tuple[str, str], int] = {} + for (s_t, _, _), (s_base, _, s_obs) in zip(obs_triples, + base_pred_triples): + for obj in s_t: + for feat in obj.type.feature_names: + pred = float(s_base.get(obj, feat)) + obs = float(s_obs.get(obj, feat)) + if abs(pred - obs) > rel_tol * abs(obs) + abs_tol: + key = (obj.type.name, feat) + hits[key] = hits.get(key, 0) + 1 + out: Dict[str, List[str]] = {} + for (t, f), n in hits.items(): + if n >= min_hits: + out.setdefault(t, []).append(f) + return {t: sorted(fs) for t, fs in out.items()} + + @staticmethod + def _log_feature_set_diff( + a: Dict[str, List[str]], + b: Dict[str, List[str]], + a_label: str, + b_label: str, + ) -> None: + """Log set-difference between two {type: [feats]} maps.""" + a_pairs = {(t, f) for t, fs in a.items() for f in fs} + b_pairs = {(t, f) for t, fs in b.items() for f in fs} + only_a = sorted(a_pairs - b_pairs) + only_b = sorted(b_pairs - a_pairs) + common = a_pairs & b_pairs + logger.info( + "Feature-set diff: %s vs %s (%d common, %d only-%s, %d only-%s)", + a_label, b_label, len(common), len(only_a), a_label, len(only_b), + b_label) + if only_a: + logger.info(" only in %s: %s", a_label, only_a) + if only_b: + logger.info(" only in %s: %s", b_label, only_b) + @staticmethod def _load_simulator_from_file( save_dir: str, trajectories: Optional[List[LowLevelTrajectory]] = None, - ) -> Tuple[Optional[List], Optional[List[ParamSpec]]]: - """Load PROCESS_RULES and PARAM_SPECS from versioned code files. - - Executes all ``NNN_run_python.py`` files in ``save_dir`` in - order, accumulating into a single namespace. - - Returns (rules, specs), either of which may be None on failure. + ) -> Tuple[Optional[List], Optional[List[ParamSpec]], + Optional[Dict[str, List[str]]]]: + """Load PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES from saved files. + + Execs all ``NNN_run_python.py`` files in ``save_dir`` in order + into one namespace. Returns ``(None, None, None)`` if rules or + specs are missing; ``features`` may be ``None`` independently, + in which case the caller asserts (PROCESS_FEATURES is required + from the agent). """ if not os.path.isdir(save_dir): logger.warning("No simulator code dir at %s.", save_dir) - return None, None + return None, None, None files = sorted(f for f in os.listdir(save_dir) if f.endswith(".py") and f[0].isdigit()) if not files: logger.warning("No code files in %s.", save_dir) - return None, None + return None, None, None ns: Dict[str, Any] = { "np": np, @@ -416,26 +466,22 @@ def _load_simulator_from_file( fpath, exc_info=True) - rules = ns.get("PROCESS_RULES") - specs = ns.get("PARAM_SPECS") - if not isinstance(rules, list) or not rules: + rules, specs, features = read_simulator_components(ns) + if rules is None: logger.warning("Saved code did not define PROCESS_RULES.") - return None, None - if not isinstance(specs, list) or not specs: + return None, None, None + if specs is None: logger.warning("Saved code did not define PARAM_SPECS.") - return None, None + return None, None, None logger.info("Loaded %d rules, %d param specs from %d files in %s.", len(rules), len(specs), len(files), save_dir) - return rules, specs + return rules, specs, features # ── Static helpers ─────────────────────────────────────────── def _write_structs_reference(self) -> str: - """Write extracted source of key structs to the sandbox. - - Returns the path the agent should Read. - """ + """Write key struct sources to the sandbox; return the agent-visible path.""" # pylint: disable=import-outside-toplevel,reimported from predicators.structs import Action as _Action from predicators.structs import LowLevelTrajectory as _LLT @@ -447,7 +493,6 @@ def _write_structs_reference(self) -> str: inspect.getsource(cls) for cls in [_Type, _Object, _State, _Action, _LLT]) - # Write into sandbox reference dir if available, else log dir. base = self._tool_context.sandbox_dir or self._get_log_dir() ref_dir = os.path.join(base, "reference") os.makedirs(ref_dir, exist_ok=True) @@ -455,16 +500,16 @@ def _write_structs_reference(self) -> str: with open(ref_path, "w", encoding="utf-8") as f: f.write(source) - # In Docker sandbox the agent sees /sandbox/reference/structs.py. + # Agent sees the sandbox-mounted path, not the host path. if self._tool_context.sandbox_dir: return "/sandbox/reference/structs.py" return ref_path @staticmethod - def _extract_step_transitions( + def _extract_obs_triples( trajectories: List[LowLevelTrajectory], ) -> List[Tuple[State, Action, State]]: - """Extract consecutive (s_t, action_t, s_{t+1}) triples.""" + """Extract observed (s_t, action_t, s_{t+1}) triples.""" triples: List[Tuple[State, Action, State]] = [] for traj in trajectories: for i in range(len(traj.actions)): @@ -473,11 +518,7 @@ def _extract_step_transitions( return triples def _recreate_base_env(self) -> None: - """Reconnect after a PyBullet physics-server crash. - - Disconnects the dead client (best-effort), then spins up a fresh - env with the same settings so subsequent simulate() calls work. - """ + """Reconnect after a PyBullet physics-server crash.""" try: pybullet.disconnect(self._base_env._physics_client_id) except Exception: # client may already be dead @@ -492,29 +533,27 @@ def _recreate_base_env(self) -> None: def _build_combined_simulator( self, - simulator: LearnedSimulator, - process_features: Dict[str, List[str]], + learned_simulator: LearnedSimulator, ) -> Callable[[State, Action], State]: """Compose base env with learned step-level dynamics. - Captures ``self`` so that if the PyBullet physics server crashes - (common on macOS Metal with GUI mode after many simulation steps), - the closure can recreate ``self._base_env`` and retry once. + Captures ``self`` so the closure can recreate ``_base_env`` and + retry once on a PyBullet crash (common on macOS Metal + GUI). """ def combined_simulate(state: State, action: Action) -> State: try: - kin_state = self._base_env.simulate(state, action) + base_state = self._base_env.simulate(state, action) except pybullet.error as e: logging.warning( "PyBullet error in combined_simulate (%s); " "recreating base env and retrying.", e) self._recreate_base_env() - kin_state = self._base_env.simulate(state, action) - updates = simulator.predict_step(kin_state) + base_state = self._base_env.simulate(state, action) + updates = learned_simulator.predict_step(base_state) if not updates: - return kin_state - return merge_updates(kin_state, updates, process_features) + return base_state + return merge_updates(base_state, updates) return combined_simulate @@ -525,9 +564,10 @@ def _build_synthesis_system_prompt() -> str: You are synthesizing a parameterized process dynamics simulator for a \ robotic manipulation environment. -A separate physics engine (PyBullet) handles kinematics (robot movement, \ -grasping, rigid body physics). Your simulator handles **process dynamics**: \ -non-kinematic features that change due to ongoing physical or causal processes. +A separate base physics engine (PyBullet) handles robot movement, grasping, \ +and rigid body physics. Your simulator handles **process dynamics**: features \ +that change due to ongoing physical or causal processes (e.g., water filling, \ +heat transfer) that the base sim doesn't model. ## Tools @@ -551,10 +591,17 @@ def _build_synthesis_system_prompt() -> str: ## Goal -Define two variables in the `run_python` namespace: +Define three variables in the `run_python` namespace: - `PROCESS_RULES`: list of rule functions - `PARAM_SPECS`: list of ParamSpec objects +- `PROCESS_FEATURES`: `Dict[str, List[str]]` — for each object type, \ +the feature names your rules predict. This is treated as the truth: \ +the loss only penalises mismatches on these features, and at test \ +time the learned simulator only overwrites these features on top of \ +the base sim's prediction. Be honest — listing features your rules \ +don't actually update will inflate the loss without giving MCMC \ +anything to optimise. Parameters are fitted automatically after the session ends. @@ -584,7 +631,7 @@ def rule(state, updates, params): 1. Explore the trajectory data with `run_python`: types, features, \ state changes over time -2. Identify which features change due to process dynamics (not kinematics) +2. Identify which features change due to process dynamics (not the base sim) 3. Define `PROCESS_RULES` and `PARAM_SPECS` in the namespace via `run_python` 4. Call `evaluate_simulator` to fit parameters and check SSE 5. Call `test_simulator` to see prediction mismatches From 124dd94dd8e6770fa93b65b3e16a47a9389b891f Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:07:56 +0100 Subject: [PATCH 60/70] Skip MCMC and use LM warm-start in boil agent config LM warm start alone matches the parameter fit for the current boil oracle program; emcee's MAP-of-walkers cannot improve on it in the time budgeted for 500 steps and routinely lands at higher SSE. Setting num_mcmc_steps to 0 and enabling warm_start_with_lm returns the LM theta_map directly. --- scripts/configs/predicatorv3/agents.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index cc6eb545f..6fd77ef5c 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -52,7 +52,8 @@ APPROACHES: agent_sim_learn_oracle_sim_program: True agent_sim_learn_oracle_sim_params: False agent_sim_learn_oracle_sim_param_noise_scale: 1.0 # 1.0 allows successful planning but insatisficing plan; 0.8 produces satisficing plan - code_sim_learning_num_mcmc_steps: 500 + code_sim_learning_num_mcmc_steps: 0 + code_sim_learning_warm_start_with_lm: True # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: From cc11084c5c19fe1cf2e0103ccf20aa2f2aa9109e Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:11:15 +0100 Subject: [PATCH 61/70] Apply yapf and docformatter formatting Cleans up line-wrap and docstring drift across the sim-learning branch so the autoformat CI check is satisfied. Bundles the formatting-only changes for cogman, pybullet_boil, and utils that earlier branch commits left behind, plus minor wraps across the new sim-learning code. --- .../approaches/agent_sim_learning_approach.py | 45 ++++++++++-------- predicators/code_sim_learning/training.py | 47 +++++++++++-------- predicators/cogman.py | 9 ++-- predicators/envs/pybullet_boil.py | 2 +- .../ground_truth_models/boil/gt_simulator.py | 24 +++++----- predicators/utils.py | 27 +++++------ .../test_agent_sim_learning_approach.py | 4 +- 7 files changed, 86 insertions(+), 72 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index 0f29c8c19..ed9d31e03 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -188,8 +188,8 @@ def _synthesize_with_agent( ) -> None: """Synthesize PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES via agent. - ``inferred_hint`` is passed to the agent as a starting point - and used as the eval/test scope until it declares its own + ``inferred_hint`` is passed to the agent as a starting point and + used as the eval/test scope until it declares its own ``PROCESS_FEATURES``. CFG flags ``agent_sim_learn_oracle_sim_program`` and ``agent_sim_learn_oracle_sim_params`` short-circuit the agent @@ -209,14 +209,13 @@ def _synthesize_with_agent( "be non-negative.") perturbed = [] for s in specs: - val = s.init_value * ( - 1.0 + float(rng.normal(0, noise_scale))) + val = s.init_value * (1.0 + + float(rng.normal(0, noise_scale))) if s.lo is not None: val = max(s.lo, val) if s.hi is not None: val = min(s.hi, val) - perturbed.append( - ParamSpec(s.name, val, lo=s.lo, hi=s.hi)) + perturbed.append(ParamSpec(s.name, val, lo=s.lo, hi=s.hi)) specs = perturbed logger.info("Loaded oracle sim program (%d rules, %d params).", len(rules), len(specs)) @@ -289,15 +288,16 @@ def _synthesize_with_agent( oracle_sim_fn = lambda s, a, p: apply_rules( # noqa: E731 s, rules, p) self._fit_sse = compute_sse(oracle_sim_fn, base_pred_triples, - self._fitted_params, - process_features) + self._fitted_params, process_features) fit_ll = -0.5 * self._fit_sse / (_noise_sigma**2) logger.info("Oracle params — SSE: %.6f log-likelihood: %.2f", self._fit_sse, fit_ll) for name, val in sorted(self._fitted_params.items()): logger.info(" %-30s %.4f", name, val) - log_sse_breakdown(oracle_sim_fn, base_pred_triples, - self._fitted_params, process_features, + log_sse_breakdown(oracle_sim_fn, + base_pred_triples, + self._fitted_params, + process_features, label="oracle") else: self._fitted_params, self._fit_sse = self._fit_parameters( @@ -323,8 +323,8 @@ def _fit_parameters( precomputing avoids re-running it inside the MCMC inner loop. """ - def sim_fn(state: State, action: Action, - params: Dict[str, float]) -> Dict: + def sim_fn(state: State, action: Action, params: Dict[str, + float]) -> Dict: return apply_rules(state, rules, params) noise_sigma = 0.05 # matches fit_params default @@ -334,8 +334,11 @@ def sim_fn(state: State, action: Action, pre_ll = -0.5 * pre_sse / (noise_sigma**2) logger.info("Before fitting — SSE: %.6f log-likelihood: %.2f", pre_sse, pre_ll) - log_sse_breakdown(sim_fn, base_pred_triples, init_params, - process_features, label="before") + log_sse_breakdown(sim_fn, + base_pred_triples, + init_params, + process_features, + label="before") result = fit_params( simulator_fn=sim_fn, @@ -350,8 +353,11 @@ def sim_fn(state: State, action: Action, post_ll = -0.5 * post_sse / (noise_sigma**2) logger.info("After fitting — SSE: %.6f log-likelihood: %.2f", post_sse, post_ll) - log_sse_breakdown(sim_fn, base_pred_triples, fitted_params, - process_features, label="after") + log_sse_breakdown(sim_fn, + base_pred_triples, + fitted_params, + process_features, + label="after") for name in sorted(fitted_params): init_val = init_params[name] @@ -430,8 +436,8 @@ def _log_feature_set_diff( def _load_simulator_from_file( save_dir: str, trajectories: Optional[List[LowLevelTrajectory]] = None, - ) -> Tuple[Optional[List], Optional[List[ParamSpec]], - Optional[Dict[str, List[str]]]]: + ) -> Tuple[Optional[List], Optional[List[ParamSpec]], Optional[Dict[ + str, List[str]]]]: """Load PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES from saved files. Execs all ``NNN_run_python.py`` files in ``save_dir`` in order @@ -481,7 +487,8 @@ def _load_simulator_from_file( # ── Static helpers ─────────────────────────────────────────── def _write_structs_reference(self) -> str: - """Write key struct sources to the sandbox; return the agent-visible path.""" + """Write key struct sources to the sandbox; return the agent-visible + path.""" # pylint: disable=import-outside-toplevel,reimported from predicators.structs import Action as _Action from predicators.structs import LowLevelTrajectory as _LLT diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index 494e274b2..ff85923ab 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -42,8 +42,10 @@ class FitResult: def point_estimate(self) -> Dict[str, float]: """MAP (sample with highest log-probability).""" best_idx = int(np.argmax(self.log_probs)) - return {n: float(self.samples[best_idx, i]) - for i, n in enumerate(self.names)} + return { + n: float(self.samples[best_idx, i]) + for i, n in enumerate(self.names) + } def compute_sse( @@ -239,13 +241,13 @@ def fit_map_lm( * MCMC warm start — center emcee walkers on theta_map (and short- circuit to it directly when ``num_mcmc_steps == 0``). """ - from scipy.optimize import least_squares # pylint: disable=import-outside-toplevel + from scipy.optimize import \ + least_squares # pylint: disable=import-outside-toplevel names = [s.name for s in param_specs] init = np.array([s.init_value for s in param_specs], dtype=float) lo = np.array([s.lo if s.lo is not None else 1e-6 for s in param_specs]) - hi = np.array( - [s.hi if s.hi is not None else np.inf for s in param_specs]) + hi = np.array([s.hi if s.hi is not None else np.inf for s in param_specs]) # Nudge init strictly into the interior so trf doesn't reject it. init = np.maximum(init, lo + 1e-9) safe_hi = np.where(np.isfinite(hi), hi - 1e-9, np.inf) @@ -275,14 +277,14 @@ def residuals_fn(theta: np.ndarray) -> np.ndarray: return init, None sse_lm = float(2.0 * result.cost) - delta = {names[i]: float(result.x[i] - init[i]) - for i in range(len(names))} + delta = {names[i]: float(result.x[i] - init[i]) for i in range(len(names))} logger.info( "LM diagnostic fit: SSE %.4f -> %.4f in %d fn-evals (status=%d, %s).", sse_init, sse_lm, result.nfev, result.status, "converged" if result.success else "max-evals") logger.info("LM theta_map - init: %s", - {k: f"{v:+.4f}" for k, v in delta.items()}) + {k: f"{v:+.4f}" + for k, v in delta.items()}) jac = np.asarray(result.jac, dtype=float) if jac.size == 0: @@ -321,8 +323,8 @@ def log_hessian_identifiability( eigvals, eigvecs = np.linalg.eigh(H) # ascending cond = float(eigvals[-1] / max(eigvals[0], 1e-30)) - logger.info("Hessian eigenanalysis (cond %.2e, %d params):", - cond, len(param_names)) + logger.info("Hessian eigenanalysis (cond %.2e, %d params):", cond, + len(param_names)) def _format(vec: np.ndarray) -> str: order = np.argsort(-np.abs(vec)) @@ -341,14 +343,14 @@ def _format(vec: np.ndarray) -> str: logger.info(" Stiff (well-constrained):") for i in stiff_idx: - logger.info(" lambda = %10.3e : %s", - eigvals[i], _format(eigvecs[:, i])) + logger.info(" lambda = %10.3e : %s", eigvals[i], + _format(eigvecs[:, i])) if sloppy_idx: logger.info(" Sloppy (under-constrained):") for i in sloppy_idx: - logger.info(" lambda = %10.3e : %s", - eigvals[i], _format(eigvecs[:, i])) + logger.info(" lambda = %10.3e : %s", eigvals[i], + _format(eigvecs[:, i])) def fit_params( @@ -408,14 +410,21 @@ def fit_params( if CFG.code_sim_learning_warm_start_with_lm: walker_center = np.asarray(theta_map, dtype=float) logger.info("Warm-starting MCMC walkers from LM MAP estimate.") - lm_params = {n: float(walker_center[i]) for i, n in enumerate(names)} + lm_params = { + n: float(walker_center[i]) + for i, n in enumerate(names) + } lm_sse = compute_sse(simulator_fn, transitions, lm_params, process_features) lm_ll = -0.5 * lm_sse / (noise_sigma**2) - logger.info("After LM warm start — SSE: %.6f log-likelihood: %.2f", - lm_sse, lm_ll) - log_sse_breakdown(simulator_fn, transitions, lm_params, - process_features, label="lm-warm-start") + logger.info( + "After LM warm start — SSE: %.6f log-likelihood: %.2f", + lm_sse, lm_ll) + log_sse_breakdown(simulator_fn, + transitions, + lm_params, + process_features, + label="lm-warm-start") if num_steps == 0: if CFG.code_sim_learning_warm_start_with_lm: diff --git a/predicators/cogman.py b/predicators/cogman.py index ebb8f8119..d573d2ad8 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -288,10 +288,11 @@ def run_episode_and_get_observations( logging.debug("[CogMan] loop break: terminate_on_goal_reached") break else: - option_str = (None if curr_option is None else - curr_option.simple_str()) - logging.info("[CogMan] Reached max_num_steps=%d while executing " - "option %s.", max_num_steps, option_str) + option_str = (None + if curr_option is None else curr_option.simple_str()) + logging.info( + "[CogMan] Reached max_num_steps=%d while executing " + "option %s.", max_num_steps, option_str) logging.debug("[CogMan] Final loop step index before horizon: %d", step_num) logging.debug("[CogMan] Atoms at horizon: %s", diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index f1ebb9164..1731ac0d1 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -584,7 +584,7 @@ def _set_domain_specific_state(self, state: State) -> None: jug.heat_level = state.get(jug, "heat_level") liquid_id = self._create_liquid_for_jug(jug, state) self._jug_to_liquid_id[jug] = liquid_id - + self._update_liquid_colors(state) # Update jug body colors from state diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py index 3ffc82089..b971d9992 100644 --- a/predicators/ground_truth_models/boil/gt_simulator.py +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -106,8 +106,8 @@ def _water_filling(state: State, updates: ProcessUpdate, Alignment and capacity gates are soft (sigmoid-weighted) so the residual is differentiable in ``faucet_align_threshold``, - ``faucet_x_len``, and ``max_jug_water_capacity`` — needed for the - LM Jacobian (and downstream Hessian diagnostic) to be informative. + ``faucet_x_len``, and ``max_jug_water_capacity`` — needed for the LM + Jacobian (and downstream Hessian diagnostic) to be informative. """ objs = _objs_by_type(state) for faucet in objs.get("faucet", []): @@ -145,9 +145,8 @@ def _water_filling(state: State, updates: ProcessUpdate, # Uncaught water spills (clamped at max_water_spill_width). spill = float(state.get(faucet, "spilled_level")) - new_spill = min( - params["max_water_spill_width"], - spill + (1.0 - catch_w) * params["water_fill_speed"]) + new_spill = min(params["max_water_spill_width"], + spill + (1.0 - catch_w) * params["water_fill_speed"]) updates.setdefault(faucet, {})["spilled_level"] = new_spill return updates @@ -160,8 +159,8 @@ def _heating(state: State, updates: ProcessUpdate, Alignment gate is soft so the residual is differentiable in ``burner_align_threshold`` (LM's finite-difference Jacobian needs this; MCMC also avoids flat-likelihood plateaus as a side effect). - The heat cap at 1.0 stays hard since 1.0 is a constant boundary, - not a learned parameter. + The heat cap at 1.0 stays hard since 1.0 is a constant boundary, not + a learned parameter. """ objs = _objs_by_type(state) for burner in objs.get("burner", []): @@ -193,10 +192,10 @@ def _happiness(state: State, updates: ProcessUpdate, """Jug filled + boiled + no spill + burner off → human happy. The water-filled gate is soft on ``water_filled_height`` so the - residual is differentiable in that parameter for LM (and emcee - gets a non-flat likelihood as a side effect). The heat>=1.0 gate - stays hard (1.0 is a constant cap, not a learned parameter). - Spill / burner-on gates are state-dependent. + residual is differentiable in that parameter for LM (and emcee gets + a non-flat likelihood as a side effect). The heat>=1.0 gate stays + hard (1.0 is a constant cap, not a learned parameter). Spill / + burner-on gates are state-dependent. """ objs = _objs_by_type(state) faucets = objs.get("faucet", []) @@ -212,8 +211,7 @@ def _get_val(obj: Object, feat: str) -> float: # semantics even when the env reports zero, so treat anything below # the smoothing scale as "no spill" to avoid spuriously gating # happiness off. - any_spill = any( - _get_val(f, "spilled_level") > _SOFT_EPS for f in faucets) + any_spill = any(_get_val(f, "spilled_level") > _SOFT_EPS for f in faucets) any_burner_on = any(state.get(b, "is_on") > 0.5 for b in burners) if any_spill or any_burner_on: diff --git a/predicators/utils.py b/predicators/utils.py index cbe628f34..48b8590bb 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1690,12 +1690,13 @@ def _format_wait_target_debug( """Format state details for debugging why Wait has not terminated.""" cur_atoms = abstract_function(state) missing_targets = target_atoms - cur_atoms - target_objects = sorted({ - ent - for atom in target_atoms for ent in atom.entities - if isinstance(ent, Object) - }, - key=lambda o: o.name) + target_objects = sorted( + { + ent + for atom in target_atoms + for ent in atom.entities if isinstance(ent, Object) + }, + key=lambda o: o.name) object_details = [] for obj in target_objects: feature_values = [] @@ -1766,11 +1767,10 @@ def _policy(state: State) -> Action: abstract_function) if result is True: cur_atoms = abstract_function(state) - logging.debug( - "Wait terminating: target atoms satisfied. " - f"Targets: {target_atoms}, " - f"cur_atoms: {sorted(cur_atoms)}, " - f"num_option_steps={num_cur_option_steps}") + logging.debug("Wait terminating: target atoms satisfied. " + f"Targets: {target_atoms}, " + f"cur_atoms: {sorted(cur_atoms)}, " + f"num_option_steps={num_cur_option_steps}") wait_terminate = True elif result is False: assert target_atoms is not None @@ -1814,9 +1814,8 @@ def _policy(state: State) -> Action: raise OptionExecutionFailure( "Unsound option policy.", info={"last_failed_option": last_option}) - logging.debug( - f"[option_policy] Started option {cur_option.name}, " - f"initiable=True") + logging.debug(f"[option_policy] Started option {cur_option.name}, " + f"initiable=True") num_cur_option_steps = 0 num_cur_option_steps += 1 diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index d0fb5eb7b..4ee5ea8a2 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -84,8 +84,8 @@ def _build_combined_model(env): """Build a combined model: kinematics-only env + GT step-level dynamics. Mirrors AgentSimLearningApproach: wraps GT rules in a - LearnedSimulator via apply_rules and composes with a - kinematics-only base env. + LearnedSimulator via apply_rules and composes with a kinematics-only + base env. """ base_env = create_new_env("pybullet_boil", do_cache=False, From 465177a972fc9000fab576ba3b74d1e8fa7d2f67 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:14:07 +0100 Subject: [PATCH 62/70] Silence mypy on PyBullet client-id attribute access ``BaseEnv`` doesn't declare ``_physics_client_id`` (only PyBullet subclasses do), and ``_recreate_base_env`` reads it best-effort inside a try block. Bind to a local with type:ignore so mypy stops flagging the access without affecting runtime. --- predicators/approaches/agent_sim_learning_approach.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index ed9d31e03..c7d5da49b 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -527,8 +527,9 @@ def _extract_obs_triples( def _recreate_base_env(self) -> None: """Reconnect after a PyBullet physics-server crash.""" try: - pybullet.disconnect(self._base_env._physics_client_id) - except Exception: # client may already be dead + client_id = self._base_env._physics_client_id # type: ignore[attr-defined] # pylint: disable=protected-access + pybullet.disconnect(client_id) + except Exception: # pylint: disable=broad-except # client may already be dead pass logging.warning( "PyBullet physics client crashed; recreating base env " From 6e76660e5c5d55aeba5022183615756459d733e0 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 13:18:16 +0100 Subject: [PATCH 63/70] Mark unused action arg in sim_fn to satisfy pylint The simulator callback signature must match StepSimulatorFn's (state, action, params) shape even though apply_rules doesn't use the action. Renaming to _action signals intent and silences pylint's unused-argument check. --- predicators/approaches/agent_sim_learning_approach.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index c7d5da49b..f1607e91a 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -323,8 +323,8 @@ def _fit_parameters( precomputing avoids re-running it inside the MCMC inner loop. """ - def sim_fn(state: State, action: Action, params: Dict[str, - float]) -> Dict: + def sim_fn(state: State, _action: Action, params: Dict[str, + float]) -> Dict: return apply_rules(state, rules, params) noise_sigma = 0.05 # matches fit_params default From 9415d12231a3eb74e9a268cbac2313862c9e14a2 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 20:34:43 +0100 Subject: [PATCH 64/70] Use per-component diff in _set_state to eliminate robot jitter Replace the all-or-nothing kinematic-match gate with a per-component diff: robot pose, each object pose, and held-object identity are each compared against the live PyBullet world and only re-written when they actually differ. _robot_matches_state now compares at the joint level (the prior EE-quaternion path hard-coded roll=0, which spuriously mismatched whenever the wrist had any roll and forced a full reset on every simulate() call). reset_state honors caller-provided joint_positions only when they reconstruct the requested EE pose, falling back to IK otherwise. --- predicators/envs/pybullet_env.py | 232 +++++++++++++----- .../pybullet_helpers/robots/single_arm.py | 39 +-- 2 files changed, 194 insertions(+), 77 deletions(-) diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 62dc75f68..1e78b9825 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -140,7 +140,7 @@ def __init__(self, self._held_obj_id: Optional[int] = None # When True, _domain_specific_step() is skipped in step(). - # Used by sim-learning to create kinematics-only envs. + # Used by sim-learning to create base-sim-only envs. self._skip_domain_specific_dynamics: bool = skip_process_dynamics # Set up all the static PyBullet content. @@ -323,9 +323,9 @@ def simulate(self, state: State, action: Action) -> State: def step(self, action: Action, render_obs: bool = False) -> Observation: """Execute one environment step with the given action. - Flow: kinematics → domain-specific dynamics → observation. + Flow: base sim → domain-specific dynamics → observation. Subclasses override ``_domain_specific_step`` (not this method) - to add post-kinematics dynamics (water filling, heating, etc.). + to add post-base-sim dynamics (water filling, heating, etc.). """ self._step_base(action) if not self._skip_domain_specific_dynamics: @@ -387,9 +387,9 @@ def _step_base(self, action: Action) -> None: self._held_obj_id = None def _domain_specific_step(self) -> None: - """Apply domain-specific dynamics after kinematics. + """Apply domain-specific dynamics after the base sim. - Override in subclasses to add post-kinematics effects (water + Override in subclasses to add post-base-sim effects (water filling, heating, balance beam physics, etc.). Skipped when ``skip_process_dynamics=True`` is passed to the constructor. """ @@ -397,64 +397,108 @@ def _domain_specific_step(self) -> None: # ── State Write (State → PyBullet) ────────────────────────── def _set_state(self, state: State) -> None: - """State -> PyBullet: set the simulator to match a State. - - Converts the agent-facing State representation (feature dicts - keyed by Object) into the corresponding PyBullet scene (joint - positions, body poses, grasp constraints, etc.). - - When robot and object poses already match (e.g. sequential - simulate calls where only process features changed), the - kinematic reset is skipped to avoid discontinuous joint resets - and grasp constraint teardown/recreation that cause visible - jitter. + """State -> PyBullet: write the requested State into the simulator. + + Per-component diff: each piece of the State (robot pose, each + object pose, held-object identity) is compared against the live + PyBullet world and only re-written when it actually differs. + This lets sequential rollouts (option model, learned process + simulators) advance without snapping the arm or rebuilding the + grasp constraint when only a subset of features changed — which + is what eliminates the visible robot jitter during combined + base+learned simulator calls. It also lets a learned rule move + an *unheld* object without disturbing the arm or any other body. Call sites: - reset() / _add_pybullet_state_to_tasks(): initialization - simulate(): option-model / bilevel-planning rollouts - external callers (skill factories, agent tools, tests) """ - # Check if kinematics already match before overwriting - # _current_observation. When only process features differ - # (e.g. combined kin+learned simulator), we can skip the - # expensive kinematic reset that causes robot arm jitter. - skip_kin = self._kinematics_match(state) + # Cohort change or the very first call forces a full reset: + # per-component compares assume the same set of bodies. + full_reset = (self._current_observation is None + or set(self._objects) != set(state.data)) - # Keep _current_observation in sync so that step() can read it + # Keep _current_observation in sync so step() can read it # (e.g. for finger-delta computation). self._current_observation = state self._objects = list(state.data) - if not skip_kin: - # 1) Clear old constraint if we had a held object + wrote_anything = False + + # 1) Robot pose diff. Skipping this branch when the live joints + # already match the requested pose is what eliminates arm + # jitter: resetJointState would otherwise hard-snap the arm + # on every simulate() call in a sequential rollout. + robot_changed = full_reset or not self._robot_matches_state(state) + + # 2) Object pose diff. Identify which non-virtual object bodies + # have moved relative to PyBullet. + objects_to_reset: List[Object] = [] + for obj in self._objects: + if obj.type.name == "robot" or \ + obj.type.name in self._VIRTUAL_OBJECT_TYPES or \ + obj.id is None: + continue + if full_reset or not self._object_pose_matches_state(obj, state): + objects_to_reset.append(obj) + + # 3) Held-object identity diff. The grasp constraint must be + # torn down and rebuilt whenever: + # - the held identity changes (including held → unheld and + # unheld → held), + # - the held object's recorded pose changes (the offset to + # the gripper moves), or + # - the gripper itself moves (resetJointState bypasses the + # constraint, so a kept constraint would leave the held + # body behind). + new_held_id = self._held_obj_id_in_state(state) + held_obj_moved = (self._held_obj_id is not None and any( + o.id == self._held_obj_id for o in objects_to_reset)) + rebuild_constraint = (full_reset + or new_held_id != self._held_obj_id + or (self._held_obj_id is not None and + (robot_changed or held_obj_moved))) + + # Tear down before robot/object resets so the held body is free + # while we move things around. + if rebuild_constraint: if self._held_constraint_id is not None: p.removeConstraint(self._held_constraint_id, physicsClientId=self._physics_client_id) - self._held_constraint_id = None + wrote_anything = True + self._held_constraint_id = None self._held_obj_to_base_link = None self._held_obj_id = None - # 2) Reset robot pose. Prefer exact joint positions when the - # State carries them in simulator_state — IK from (x, y, z, - # tilt, wrist) drops wrist roll, which corrupts the held- - # object offset that _create_grasp_constraint records below. + if robot_changed: + # Prefer exact joint positions when the State carries them in + # simulator_state — IK from (x, y, z, tilt, wrist) drops + # wrist roll, which corrupts the held-object offset that + # _create_grasp_constraint records below. joint_positions = self._extract_robot_joint_positions(state) self._pybullet_robot.reset_state(self._extract_robot_state(state), joint_positions=joint_positions) + wrote_anything = True - # 3) Reset all known objects (position, orientation, etc.) - for obj in self._objects: - if obj.type.name == "robot" or \ - obj.type.name in self._VIRTUAL_OBJECT_TYPES: - continue - self._reset_single_object(obj, state) + for obj in objects_to_reset: + self._reset_single_object(obj, state) + wrote_anything = True - # 4) Let the subclass do any domain-specific state setup + # Recreate the constraint after objects are repositioned so the + # recorded base_link → object offset matches the new pose. + if rebuild_constraint and new_held_id is not None: + self._held_obj_id = new_held_id + self._create_grasp_constraint() + wrote_anything = True + + # 4) Subclass-specific state always runs (idempotent and cheap). self._set_domain_specific_state(state) - # 5) Check for reconstruction mismatch. - # Only raise for envs that override _get_state(). - if not skip_kin: + # 5) Reconstruction check — only when we actually wrote + # something kinematic. Only raise for envs that override + # _get_state(). + if wrote_anything: reconstructed = self._get_state() if not reconstructed.allclose(state): if type(self)._get_state is not PyBulletEnv._get_state: @@ -462,26 +506,97 @@ def _set_state(self, state: State) -> None: logging.warning( "Could not reconstruct state exactly in reset.") - def _kinematics_match(self, state: State) -> bool: - """Check if robot pose in *state* matches the current PyBullet state. - - Used by ``_set_state`` to skip the kinematic reset when only - non-kinematic features (process dynamics) have changed. + def _robot_matches_state(self, + state: State, + atol: float = 1e-2) -> bool: + """True if PyBullet's live robot pose already equals state's. + + Compares at the joint level. The EE-quaternion path that + ``_extract_robot_state`` builds always uses ``roll=0``, so any + non-zero wrist roll in the live PyBullet pose would spuriously + fail an EE-pose comparison and trigger a full robot reset on + every simulate() call (visible jitter). + + Returns False when ``state`` has no joint_positions — the only + live caller in that situation is + ``_add_pybullet_state_to_tasks``, where forcing a reset is + exactly the desired behavior. """ - if self._current_observation is None: + jp = self._extract_robot_joint_positions(state) + if jp is None: return False try: - new_robot = self._extract_robot_state(state) - cur_robot = self._extract_robot_state(self._current_observation) - return bool(np.allclose(new_robot, cur_robot, atol=1e-3)) + cur_jp = self._pybullet_robot.get_joints() except (KeyError, ValueError): return False + return bool(np.allclose(jp, cur_jp, atol=atol)) + + def _object_pose_matches_state(self, + obj: Object, + state: State, + atol: float = 1e-2) -> bool: + """True if PyBullet's live pose for ``obj`` equals state[obj].""" + if obj.id is None: + return True + try: + features = obj.type.feature_names + (px, py, pz), orn = p.getBasePositionAndOrientation( + obj.id, physicsClientId=self._physics_client_id) + if "x" in features and \ + not np.isclose(state.get(obj, "x"), px, atol=atol): + return False + if "y" in features and \ + not np.isclose(state.get(obj, "y"), py, atol=atol): + return False + if "z" in features and \ + not np.isclose(state.get(obj, "z"), pz, atol=atol): + return False + if {"rot", "yaw", "roll", "pitch"} & set(features): + roll, pitch, yaw = p.getEulerFromQuaternion(orn) + if "rot" in features and not np.isclose( + state.get(obj, "rot"), yaw, atol=atol): + return False + if "yaw" in features and not np.isclose( + state.get(obj, "yaw"), yaw, atol=atol): + return False + if "roll" in features and not np.isclose( + state.get(obj, "roll"), roll, atol=atol): + return False + if "pitch" in features and not np.isclose( + state.get(obj, "pitch"), pitch, atol=atol): + return False + return True + except (KeyError, ValueError): + return False + + def _held_obj_id_in_state(self, state: State) -> Optional[int]: + """Which PyBullet body id is marked is_held > 0.5 in ``state``. + + Returns None if no object is held in ``state``. Mirrors the + per-object logic in _reset_single_object before constraint + management was hoisted out into _set_state. + """ + for obj in state.data: + if obj.id is None: + continue + if "is_held" not in obj.type.feature_names: + continue + try: + if state.get(obj, "is_held") > 0.5: + return obj.id + except (KeyError, ValueError): + continue + return None def _reset_single_object(self, obj: Object, state: State) -> None: - """Set a single physical object's pose and grasp constraint in PyBullet - to match the given State. + """Teleport a single physical object to match the given State. + + Pose only — grasp-constraint management is centralized in + _set_state so teardown/rebuild stays in one place. - Called by _set_state() for every non-robot, non-virtual object. + Called by _set_state() for every non-robot, non-virtual object + whose pose differs from PyBullet (or for all such objects on a + full reset). """ # Skip objects without pybullet IDs (handled by subclass). if obj.id is None: @@ -511,15 +626,6 @@ def _reset_single_object(self, obj: Object, state: State) -> None: orn, physics_client_id=self._physics_client_id) - # 3) If there's an is_held feature, reattach constraints if needed - if "is_held" in features: - if state.get(obj, "is_held") > 0.5: - # attach constraint - self._held_obj_id = obj.id - self._create_grasp_constraint() - # _create_grasp_constraint already correctly computes - # and stores _held_obj_to_base_link. - @abc.abstractmethod def _set_domain_specific_state(self, state: State) -> None: """Set simulator state for features that the base class doesn't handle. @@ -588,8 +694,12 @@ def _extract_robot_joint_positions( jp: Any if isinstance(sim_state, dict): jp = sim_state.get("joint_positions") + elif sim_state is None: + return None else: - # Legacy: simulator_state is the joint_positions list itself. + # PyBulletState also accepts simulator_state passed as a raw + # joint-positions sequence (see PyBulletState.joint_positions + # and tests/envs/test_pybullet_blocks.py:69-70). jp = sim_state if jp is None: return None diff --git a/predicators/pybullet_helpers/robots/single_arm.py b/predicators/pybullet_helpers/robots/single_arm.py index 454b1f7be..5e32c7812 100644 --- a/predicators/pybullet_helpers/robots/single_arm.py +++ b/predicators/pybullet_helpers/robots/single_arm.py @@ -261,27 +261,34 @@ def reset_state( self._base_pose.orientation, physicsClientId=self.physics_client_id, ) + target = np.array([rx, ry, rz, qx, qy, qz, qw, rf], dtype=np.float32) if joint_positions is not None: # arm_joints includes fingers, so set_joints already # restored both — skip the snapped-finger overwrite below # so continuous finger values round-trip cleanly. self.set_joints(list(joint_positions)) - else: - # First, reset the joint values to initial joint positions, - # so that IK is consistent (less sensitive to initialization). - self.set_joints(self.initial_joint_positions) - - # Now run IK to get to the actual starting rx, ry, rz. We use - # validate=True to ensure that this initialization works. - pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) - self.inverse_kinematics(pose, validate=True) - - # IK does not touch fingers, so snap them from the EE state. - for finger_id in [self.left_finger_id, self.right_finger_id]: - p.resetJointState(self.robot_id, - finger_id, - rf, - physicsClientId=self.physics_client_id) + # Some callers attach nominal joints to plain states as a reset + # hint. Preserve exact joints only when they really reconstruct the + # requested EE pose; otherwise fall back to IK, matching the legacy + # reset behavior. + if np.allclose(self.get_state()[:7], target[:7], atol=1e-3): + return + + # First, reset the joint values to initial joint positions, + # so that IK is consistent (less sensitive to initialization). + self.set_joints(self.initial_joint_positions) + + # Now run IK to get to the actual starting rx, ry, rz. We use + # validate=True to ensure that this initialization works. + pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) + self.inverse_kinematics(pose, validate=True) + + # IK does not touch fingers, so snap them from the EE state. + for finger_id in [self.left_finger_id, self.right_finger_id]: + p.resetJointState(self.robot_id, + finger_id, + rf, + physicsClientId=self.physics_client_id) def get_state(self) -> Array: """Get the robot state vector based on the current PyBullet state. From 418fd3038ba52b77cbca537ca0f7a1727b1bdd72 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 20:34:49 +0100 Subject: [PATCH 65/70] Reposition recreated cups and plugs in coffee _set_domain_specific_state _remake_cups creates fresh PyBullet bodies that need to be teleported to their state-specified poses; the per-component diff in _set_state now skips objects whose pose already matches PyBullet, so the explicit _reset_single_object calls ensure freshly-recreated bodies land in the right place. Same treatment for plugs when coffee_machine_has_plug. --- predicators/envs/pybullet_coffee.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 64f66f259..364318200 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -403,9 +403,14 @@ def _remake_cord(self) -> None: def _set_domain_specific_state(self, state: State) -> None: """Reset liquid visuals, cup geometry, cord, and button colors.""" self._remake_jug_liquid(state) - self._remake_cup_liquids(state) self._remake_cups(state) + for cup in state.get_objects(self._cup_type): + self._reset_single_object(cup, state) + self._remake_cup_liquids(state) self._remake_cord() + if CFG.coffee_machine_has_plug: + for plug in state.get_objects(self._plug_type): + self._reset_single_object(plug, state) # Machine button color if self._MachineOn_holds(state, [self._machine]) and \ From e82df9ddbc9db80350dfaed0492a1d4bb80fa7d9 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 20:34:56 +0100 Subject: [PATCH 66/70] Look up predicates lazily in option-model _abstract_function The lambda used to capture predicates at __init__ time, which missed predicates invented later (grammar search) and broke subclasses whose _get_current_predicates depends on attributes not yet set during super().__init__(). --- predicators/approaches/agent_planner_approach.py | 9 +++++---- predicators/approaches/agent_sim_learning_approach.py | 3 +-- predicators/approaches/bilevel_planning_approach.py | 9 ++++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 5797f6276..64bc2e350 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -60,13 +60,14 @@ def __init__(self, else: self._option_model = create_option_model(CFG.option_model_name) # Let the option model terminate Wait on atom change using the - # approach's predicates (which may include invented ones). + # approach's predicates (which may include invented ones). Looked + # up lazily so the lambda picks up predicates invented after + # __init__. if CFG.wait_option_terminate_on_atom_change: - preds = self._get_all_predicates() cast( # pylint: disable=protected-access Any, self._option_model - )._abstract_function = \ - lambda s, _p=preds: utils.abstract(s, _p) + )._abstract_function = ( + lambda s: utils.abstract(s, self._get_all_predicates())) self._online_learning_cycle = 0 self._requests_train_task_idxs: Optional[List[int]] = None self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index f1607e91a..f840e2781 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -172,9 +172,8 @@ def _build_option_model( """ model = _OracleOptionModel(self._get_all_options(), simulator_fn) if CFG.wait_option_terminate_on_atom_change: - preds = self._get_all_predicates() model._abstract_function = ( # pylint: disable=protected-access - lambda s, _p=preds: utils.abstract(s, _p)) + lambda s: utils.abstract(s, self._get_all_predicates())) return model # ── Agent-based synthesis ──────────────────────────────────── diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index a0c288bdd..31fcac44c 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -52,12 +52,15 @@ def __init__(self, # refinement and the step is rejected for "exceeded individual # horizon", even when the expected atoms have already become # true. Mirrors AgentPlannerApproach.__init__. + # Looked up lazily so subclasses whose _get_current_predicates + # depends on attributes set after super().__init__() (e.g. + # GrammarSearchInventionApproach._learned_predicates) don't break, + # and so predicates invented later are reflected at call time. if CFG.wait_option_terminate_on_atom_change: - preds = self._get_current_predicates() cast( # pylint: disable=protected-access Any, self._option_model - )._abstract_function = \ - lambda s, _p=preds: utils.abstract(s, _p) + )._abstract_function = ( + lambda s: utils.abstract(s, self._get_current_predicates())) self._num_calls = 0 self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim From 8b6d709943bd3fcf1597de450774811fd5ff914c Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 20:35:01 +0100 Subject: [PATCH 67/70] Rename 'kinematics-only' to 'base-sim-only' in docs and test names Terminology cleanup to match how skip_process_dynamics is described elsewhere; the env wraps the full base sim, not just kinematics. --- predicators/code_sim_learning/training.py | 2 +- tests/approaches/test_agent_sim_learning_approach.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py index ff85923ab..92ac98217 100644 --- a/predicators/code_sim_learning/training.py +++ b/predicators/code_sim_learning/training.py @@ -371,7 +371,7 @@ def fit_params( Args: simulator_fn: Simulator(state, action, params_dict) -> updates. - Should run kinematics internally if needed. + Should run the base sim internally if needed. transitions: List of (s_t, action, s_{t+1}_obs) triples. param_specs: Parameter specifications (name, init_value). process_features: {type_name: [feat_names]} to fit. diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py index 4ee5ea8a2..f5e808700 100644 --- a/tests/approaches/test_agent_sim_learning_approach.py +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -81,11 +81,11 @@ def _build_kinematics_only_oracle(env): def _build_combined_model(env): - """Build a combined model: kinematics-only env + GT step-level dynamics. + """Build a combined model: base-sim-only env + GT step-level dynamics. Mirrors AgentSimLearningApproach: wraps GT rules in a - LearnedSimulator via apply_rules and composes with a kinematics-only - base env. + LearnedSimulator via apply_rules and composes with a base-sim-only + env. """ base_env = create_new_env("pybullet_boil", do_cache=False, From abc448f240ef0d0b4b1587f436b4d12964757f10 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 21:51:01 +0100 Subject: [PATCH 68/70] Tighten _robot_matches_state atol so set_state hint forces reset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fast-path joint-match check used atol=1e-2, which let a caller's initial_joint_positions hint be silently treated as "already there" when live joints were within 1e-2 of initial — leaving the EE pose up to ~3e-3 off the requested state. State.allclose compares features at 1e-3, so the test then failed reconstruction. Match the State.allclose tolerance. Also pick up trailing yapf reformatting in two approach files. --- .../approaches/agent_planner_approach.py | 5 ++-- .../approaches/bilevel_planning_approach.py | 6 ++--- predicators/envs/pybullet_env.py | 23 +++++++++++-------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 64bc2e350..cfa164737 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -65,9 +65,8 @@ def __init__(self, # __init__. if CFG.wait_option_terminate_on_atom_change: cast( # pylint: disable=protected-access - Any, self._option_model - )._abstract_function = ( - lambda s: utils.abstract(s, self._get_all_predicates())) + Any, self._option_model)._abstract_function = ( + lambda s: utils.abstract(s, self._get_all_predicates())) self._online_learning_cycle = 0 self._requests_train_task_idxs: Optional[List[int]] = None self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 31fcac44c..cc9d7ce36 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,9 +58,9 @@ def __init__(self, # and so predicates invented later are reflected at call time. if CFG.wait_option_terminate_on_atom_change: cast( # pylint: disable=protected-access - Any, self._option_model - )._abstract_function = ( - lambda s: utils.abstract(s, self._get_current_predicates())) + Any, self._option_model)._abstract_function = ( + lambda s: utils.abstract(s, self._get_current_predicates()) + ) self._num_calls = 0 self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 1e78b9825..c788bedb0 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -453,10 +453,10 @@ def _set_state(self, state: State) -> None: # constraint, so a kept constraint would leave the held # body behind). new_held_id = self._held_obj_id_in_state(state) - held_obj_moved = (self._held_obj_id is not None and any( - o.id == self._held_obj_id for o in objects_to_reset)) - rebuild_constraint = (full_reset - or new_held_id != self._held_obj_id + held_obj_moved = (self._held_obj_id is not None + and any(o.id == self._held_obj_id + for o in objects_to_reset)) + rebuild_constraint = (full_reset or new_held_id != self._held_obj_id or (self._held_obj_id is not None and (robot_changed or held_obj_moved))) @@ -506,9 +506,7 @@ def _set_state(self, state: State) -> None: logging.warning( "Could not reconstruct state exactly in reset.") - def _robot_matches_state(self, - state: State, - atol: float = 1e-2) -> bool: + def _robot_matches_state(self, state: State, atol: float = 1e-3) -> bool: """True if PyBullet's live robot pose already equals state's. Compares at the joint level. The EE-quaternion path that @@ -517,6 +515,13 @@ def _robot_matches_state(self, fail an EE-pose comparison and trigger a full robot reset on every simulate() call (visible jitter). + ``atol`` matches ``State.allclose``'s feature tolerance: a looser + check would let the fast-path skip a reset even when the live EE + pose differs from the requested state by more than allclose + accepts (e.g. when a caller hands us + ``initial_joint_positions`` as a hint and the live joints are + only 1e-2 close). + Returns False when ``state`` has no joint_positions — the only live caller in that situation is ``_add_pybullet_state_to_tasks``, where forcing a reset is @@ -572,8 +577,8 @@ def _object_pose_matches_state(self, def _held_obj_id_in_state(self, state: State) -> Optional[int]: """Which PyBullet body id is marked is_held > 0.5 in ``state``. - Returns None if no object is held in ``state``. Mirrors the - per-object logic in _reset_single_object before constraint + Returns None if no object is held in ``state``. Mirrors the per- + object logic in _reset_single_object before constraint management was hoisted out into _set_state. """ for obj in state.data: From 58f44f63174640f1576f5ab2153e6b372076ffb7 Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 22:48:33 +0100 Subject: [PATCH 69/70] Fix flaky test_glib_explorer and test_demo_dataset_loading under pytest-split Both tests pass on master and in isolation but fail on shards 6/8 of CI on this branch. The branch's new tests shifted pytest-split's least_duration distribution so existing tests landed in different shards than on master, exposing pre-existing fragility: - test_glib_explorer[Holding]: score_fn returned 0 (not -inf) for non-target goals, so they weren't filtered. With cover's 7-atom dynamic universe and 10 babbles, ~3.5% of seeds sample no Holding goal and the explorer falls through to a Covers goal, leaving the final state without Holding. Bumped glib_num_babbles to 100 and switched the test's score_fn to return -inf for non-target so the explorer never plans toward an off-target predicate. - test_demo_dataset_loading[10-True-oracle-...]: _ensure_cover_demo_ data_exists only checked file existence. test_demo_dataset's max_initial_demos block writes a 3-trajectory dataset under the cover__demo__oracle__7__... name; the [10-...] case then loaded 3 + generated 3 = 6, expected 10. Added a trajectory-count check so the helper regenerates partial files. --- tests/datasets/test_datasets.py | 50 ++++++++++++++++++++++++++- tests/explorers/test_glib_explorer.py | 12 +++++-- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 11d03ca22..fdf922884 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,5 +1,6 @@ """Test cases for dataset generation.""" import os +import pickle as pkl import shutil from contextlib import nullcontext as does_not_raise @@ -304,6 +305,11 @@ def _ensure_cover_demo_data_exists(): this data file existing (for truncation and extension). When pytest- split distributes parametrized cases across groups, the generating case may not run first, so we ensure it here. + + Earlier tests (test_demo_dataset's max_initial_demos / impossible- + goal blocks) write a partial dataset under this same filename, so a + bare ``os.path.exists`` check is not enough — we also have to verify + the file actually carries 7 trajectories before trusting it. """ saved_cfg = { "env": CFG.env, @@ -323,7 +329,12 @@ def _ensure_cover_demo_data_exists(): }) dataset_fname, _ = utils.create_dataset_filename_str( saving_ground_atoms=False) - if not os.path.exists(dataset_fname): + has_full_dataset = False + if os.path.exists(dataset_fname): + with open(dataset_fname, "rb") as f: + existing = pkl.load(f) + has_full_dataset = len(existing.trajectories) == 7 + if not has_full_dataset: env = CoverEnv() train_tasks = [t.task for t in env.get_train_tasks()] predicates, _ = utils.parse_config_excluded_predicates(env) @@ -385,6 +396,43 @@ def test_demo_dataset_loading(num_train_tasks, load_data, demonstrator, assert "Cannot load data" in str(e) +def test_ensure_cover_demo_data_regenerates_partial_file(): + """A partial cover demo file under the 7-task name must be regenerated. + + Earlier tests in test_demo_dataset can write a 3-trajectory dataset + under ``cover__demo__oracle__7__...`` (e.g. the max_initial_demos + block). When pytest-split lands a downstream test that depends on a + 7-trajectory file (test_demo_dataset_loading[10-True-oracle-...]) in + a different shard, that downstream test loads the truncated file and + the load+extend path produces the wrong total. Lock in the helper's + "validate count, not just existence" contract. + """ + # Compute the 7-task filename in the default data_dir, since the + # helper resets data_dir during its reset_config call. + utils.reset_config({ + "env": "cover", + "approach": "random_actions", + "offline_data_method": "demo", + "offline_data_planning_timeout": 500, + "option_learner": "no_learning", + "num_train_tasks": 7, + "load_data": False, + "demonstrator": "oracle", + }) + dataset_fname, _ = utils.create_dataset_filename_str( + saving_ground_atoms=False) + os.makedirs(os.path.dirname(dataset_fname) or ".", exist_ok=True) + # Stage a stale empty dataset under the 7-task filename to simulate + # the leftover from earlier tests' partial writes. + stub = Dataset([]) + with open(dataset_fname, "wb") as f: + pkl.dump(stub, f) + _ensure_cover_demo_data_exists() + with open(dataset_fname, "rb") as f: + regenerated = pkl.load(f) + assert len(regenerated.trajectories) == 7 + + def _ensure_blocks_demo_data_exists(): """Generate the 10-task blocks demo dataset if it doesn't exist. diff --git a/tests/explorers/test_glib_explorer.py b/tests/explorers/test_glib_explorer.py index 89a70d507..5c9af5376 100644 --- a/tests/explorers/test_glib_explorer.py +++ b/tests/explorers/test_glib_explorer.py @@ -11,18 +11,26 @@ @pytest.mark.parametrize("target_predicate", ["Covers", "Holding"]) def test_glib_explorer(target_predicate): """Tests for GLIBExplorer class.""" + # Bump glib_num_babbles so we reliably sample at least one goal + # containing the target predicate. Default 10 babbles from cover's + # 7-atom dynamic universe gives a ~3.5% chance of zero Holding + # samples, which surfaces as a flake when test ordering shifts the + # shared explorer-RNG counter (predicators/explorers/base_explorer.py:15). utils.reset_config({ "env": "cover", "explorer": "glib", "cover_initial_holding_prob": 0.0, + "glib_num_babbles": 100, }) env = CoverEnv() options = get_gt_options(env.get_name()) nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] - # For testing purposes, score everything except target predicate low. - score_fn = lambda atoms: target_predicate in str(atoms) + # Filter out non-target goals so the explorer never falls through to + # plan toward a different predicate when target goals fail. + score_fn = lambda atoms: 1.0 if target_predicate in str(atoms) \ + else -float("inf") explorer = create_explorer("glib", env.predicates, get_gt_options(env.get_name()), From 7bc444396b94a17a160a3fcb46379a81936d339f Mon Sep 17 00:00:00 2001 From: Yichao Liang Date: Mon, 4 May 2026 22:48:47 +0100 Subject: [PATCH 70/70] Add unit tests for _robot_matches_state atol and pybullet_helpers.objects - test_robot_matches_state_atol_forces_reset_on_small_drift: locks in the 1e-3 atol regression. A ~5e-3 joint drift (within the previous 1e-2 tolerance, outside the new 1e-3) must NOT be treated as "already there" by the fast-path; _set_state must move the robot back to the requested EE pose at State.allclose precision. - tests/pybullet_helpers/test_objects.py (new): coverage for sample_collision_free_2d_positions, used by 3 PyBullet envs but previously without direct tests. Covers no-overlap (circles and rectangles), bounds, reproducibility across seeds, RuntimeError on impossible packing, and ValueError on unknown shape_type. --- tests/envs/test_pybullet_blocks.py | 40 +++++++++ tests/pybullet_helpers/test_objects.py | 112 +++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 tests/pybullet_helpers/test_objects.py diff --git a/tests/envs/test_pybullet_blocks.py b/tests/envs/test_pybullet_blocks.py index 39512c703..739334493 100644 --- a/tests/envs/test_pybullet_blocks.py +++ b/tests/envs/test_pybullet_blocks.py @@ -405,6 +405,46 @@ def test_pybullet_blocks_putontable_corners(env): assert abs(state.get(block, "pose_y") - by) < 1e-2 +def test_robot_matches_state_atol_forces_reset_on_small_drift(env): + """A small joint drift (~5e-3) must NOT be treated as "already there". + + Locks in the _robot_matches_state atol regression: with the prior + 1e-2 tolerance, a caller-supplied initial_joint_positions hint was + silently accepted whenever the live joints were within 1e-2 of + initial, leaving the EE pose ~3e-3 off the requested state — past + the 1e-3 State.allclose threshold. The fast-path must agree with + State.allclose precision. + """ + robot = env.robot + block = Object("block0", env.block_type) + bx = (env.x_lb + env.x_ub) / 2 + by = (env.y_lb + env.y_ub) / 2 + bz = env.table_height + 0.5 * env.block_size + rx, ry, rz = env.robot_init_x, env.robot_init_y, env.robot_init_z + rf = env.open_fingers + init_state = State({ + robot: np.array([rx, ry, rz, rf]), + block: np.array([bx, by, bz, 0.0, 1.0, 0.0, 0.0]), + }) + # First, get the env into the requested init pose. + env.set_state(init_state) + initial_joints = list(env._pybullet_robot.initial_joint_positions) # pylint: disable=protected-access + # Nudge the live joints by ~5e-3 (within old 1e-2 atol, outside new + # 1e-3 atol) so the fast-path *would* incorrectly accept under the + # old tolerance. + drifted_joints = [j + 5e-3 for j in initial_joints] + env._pybullet_robot.set_joints(drifted_joints) # pylint: disable=protected-access + # State carries the original initial joints as a "should be here" hint. + hint_state = utils.PyBulletState(init_state.data, + simulator_state=initial_joints) + # The fast-path comparison must reject the drift. + assert not env._robot_matches_state(hint_state) # pylint: disable=protected-access + # And calling _set_state must actually move the robot back to the + # requested EE pose at State.allclose precision (atol=1e-3). + env._set_state(hint_state) # pylint: disable=protected-access + assert env.get_state().allclose(init_state) + + def test_pybullet_blocks_close_pick_place(env): """Test a tricky case where we attempt to pick and place immediately next to a pile of blocks. diff --git a/tests/pybullet_helpers/test_objects.py b/tests/pybullet_helpers/test_objects.py new file mode 100644 index 000000000..fd743c3d1 --- /dev/null +++ b/tests/pybullet_helpers/test_objects.py @@ -0,0 +1,112 @@ +"""Unit tests for predicators.pybullet_helpers.objects.""" +import numpy as np +import pytest + +from predicators.pybullet_helpers.objects import \ + sample_collision_free_2d_positions +from predicators.utils import Circle, Rectangle + + +def test_sample_collision_free_2d_positions_circles_no_overlap(): + """Sampled circles never overlap with each other.""" + rng = np.random.default_rng(0) + radius = 0.05 + positions = sample_collision_free_2d_positions( + num_samples=8, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[radius], + rng=rng, + ) + assert len(positions) == 8 + circles = [Circle(x, y, radius) for x, y in positions] + for i, c1 in enumerate(circles): + for c2 in circles[i + 1:]: + assert not c1.intersects(c2) + + +def test_sample_collision_free_2d_positions_within_bounds(): + """Sampled positions stay inside the requested x/y range.""" + rng = np.random.default_rng(0) + positions = sample_collision_free_2d_positions( + num_samples=5, + x_range=(-0.5, 0.5), + y_range=(2.0, 3.0), + shape_type="circle", + shape_params=[0.05], + rng=rng, + ) + for x, y in positions: + assert -0.5 <= x <= 0.5 + assert 2.0 <= y <= 3.0 + + +def test_sample_collision_free_2d_positions_rectangles_no_overlap(): + """Sampled rectangles never overlap with each other.""" + rng = np.random.default_rng(1) + w, h, theta = 0.05, 0.05, 0.0 + positions = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="rectangle", + shape_params=[w, h, theta], + rng=rng, + ) + assert len(positions) == 4 + rects = [Rectangle(x, y, w, h, theta) for x, y in positions] + for i, r1 in enumerate(rects): + for r2 in rects[i + 1:]: + assert not r1.intersects(r2) + + +def test_sample_collision_free_2d_positions_reproducible(): + """Same seed produces the same positions.""" + pos_a = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.05], + rng=np.random.default_rng(123), + ) + pos_b = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.05], + rng=np.random.default_rng(123), + ) + assert pos_a == pos_b + + +def test_sample_collision_free_2d_positions_impossible_raises(): + """Asking for more shapes than fit raises RuntimeError.""" + # 4 disks of radius 0.5 cannot fit non-overlapping in [0,1]^2. + rng = np.random.default_rng(0) + with pytest.raises(RuntimeError, match="Max tries exceeded"): + sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.5], + rng=rng, + max_tries_total=200, + ) + + +def test_sample_collision_free_2d_positions_invalid_shape_raises(): + """An unknown shape_type raises ValueError.""" + rng = np.random.default_rng(0) + with pytest.raises(ValueError, match="Unsupported shape_type"): + sample_collision_free_2d_positions( + num_samples=1, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="triangle", + shape_params=[0.05], + rng=rng, + )