Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 62 additions & 57 deletions predicators/approaches/human_low_level_control_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,15 @@ def _get_action_from_keyboard(self, state: State) -> Action:
"""
self._step_count += 1

# Get physics client ID from state
# Some envs (e.g. pybullet_blocks) pass simulator_state as a raw
# joint-positions list rather than the dict that PyBulletEnv
# builds. Only the mobile-base branch below needs
# physics_client_id/robot_id, so look them up lazily.
assert isinstance(state, utils.PyBulletState)
assert state.simulator_state is not None
physics_client_id = state.simulator_state.get("physics_client_id")
if physics_client_id is None:
raise ValueError(
"physics_client_id not found in state.simulator_state. "
"Make sure the environment adds it to PyBulletState.")
sim_state_dict = (state.simulator_state if isinstance(
state.simulator_state, dict) else {})
physics_client_id = sim_state_dict.get("physics_client_id")

# Get the most recent pressed key (only one key per frame)
key = self._get_pressed_key()
Expand Down Expand Up @@ -257,47 +258,38 @@ def _get_action_from_keyboard(self, state: State) -> Action:
# Get robot for IK
robot = self._get_robot()
if hasattr(robot, "base_action_dim") and robot.base_action_dim > 0:
robot_id = state.simulator_state.get("robot_id")
robot_id = sim_state_dict.get("robot_id")
if robot_id is not None:
if physics_client_id is None:
raise ValueError(
"physics_client_id not found in simulator_state; "
"mobile-base envs must populate it.")
base_pos, base_orn = p.getBasePositionAndOrientation(
robot_id, physicsClientId=physics_client_id)
robot.set_base_pose( # type: ignore[attr-defined]
Pose(base_pos, base_orn))

# Find robot object in state
robot_obj = None
for obj in state.data.keys():
if obj.type.name == "robot":
robot_obj = obj
break
# Get current EE pose from forward kinematics on the shadow
# robot. This avoids relying on env-specific state feature
# names (some envs use x/y/z, others pose_x/pose_z, cover has
# no y at all) and gives the true world-frame pose.
current_pose = robot.forward_kinematics(current_joint_positions)
current_x, current_y, current_z = current_pose.position

if robot_obj is None:
raise ValueError("No robot object found in state")

# Get current pose
current_x = state.get(robot_obj, "x")
current_y = state.get(robot_obj, "y")
current_z = state.get(robot_obj, "z")
current_tilt = state.get(
robot_obj,
"tilt") if "tilt" in robot_obj.type.feature_names else 0.0
current_wrist = state.get(
robot_obj,
"wrist") if "wrist" in robot_obj.type.feature_names else 0.0

# Compute target pose
target_x = current_x + dx
target_y = current_y + dy
target_z = current_z + dz
target_tilt = current_tilt + d_tilt
target_wrist = current_wrist + d_wrist

# Create poses
current_orn = p.getQuaternionFromEuler(
[0, current_tilt, current_wrist])
target_orn = p.getQuaternionFromEuler([0, target_tilt, target_wrist])
# Apply tilt/wrist deltas to the current orientation. tilt = pitch,
# wrist = yaw. Roll is preserved.
if d_tilt or d_wrist:
roll, pitch, yaw = p.getEulerFromQuaternion(
current_pose.orientation)
target_orn = p.getQuaternionFromEuler(
[roll, pitch + d_tilt, yaw + d_wrist])
else:
target_orn = current_pose.orientation

current_pose = Pose((current_x, current_y, current_z), current_orn)
target_pose = Pose((target_x, target_y, target_z), target_orn)

# Finger status
Expand Down Expand Up @@ -352,28 +344,41 @@ def _pad_base_action(self, action_arr: np.ndarray) -> Action:


def _get_shadow_robot_for_env() -> SingleArmPyBulletRobot:
"""Create a shadow robot for IK calculations based on current
environment."""
env_name = CFG.env

"""Create a shadow robot for IK calculations.

IK is base-pose-dependent (each env may translate/rotate the fetch's
base), so we instantiate a fetch at the active env's base pose. We
deliberately bypass each subclass's ``initialize_pybullet`` override
to avoid loading env-specific bodies (tables/blocks/fans/etc.) that
the IK does not need; we just connect a fresh DIRECT client, drop a
ground plane, and ask the env class for a robot.
"""
# pylint: disable=import-outside-toplevel
# Map environment names to their classes
if env_name.startswith("pybullet_circuit"):
from predicators.envs.pybullet_circuit import PyBulletCircuitEnv
_, robot, _ = PyBulletCircuitEnv.initialize_pybullet(using_gui=False)
elif env_name.startswith("pybullet_fan"):
from predicators.envs.pybullet_fan import PyBulletFanEnv
_, robot, _ = PyBulletFanEnv.initialize_pybullet(using_gui=False)
elif env_name.startswith("pybullet_blocks"):
from predicators.envs.pybullet_blocks import PyBulletBlocksEnv
_, robot, _ = PyBulletBlocksEnv.initialize_pybullet(using_gui=False)
elif env_name.startswith("pybullet_coffee"):
from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv
_, robot, _ = PyBulletCoffeeEnv.initialize_pybullet(using_gui=False)
else:
# Default fallback - use PyBulletFanEnv (most common)
from predicators.envs.pybullet_fan import PyBulletFanEnv
_, robot, _ = PyBulletFanEnv.initialize_pybullet(using_gui=False)
from predicators.envs.base_env import BaseEnv
from predicators.envs.pybullet_env import PyBulletEnv

# pylint: enable=import-outside-toplevel

return robot
env_name = CFG.env
env_cls = None
for cls in utils.get_all_subclasses(BaseEnv):
if cls.__abstractmethods__:
continue
if not issubclass(cls, PyBulletEnv):
continue
if cls.get_name() == env_name:
env_cls = cls
break
if env_cls is None:
raise NotImplementedError(
f"human_low_level_control: no PyBulletEnv subclass registered "
f"for env name {env_name!r}.")

physics_client_id = p.connect(p.DIRECT)
p.resetSimulation(physicsClientId=physics_client_id)
p.loadURDF(utils.get_env_asset_path("urdf/plane.urdf"), [0, 0, 0],
useFixedBase=True,
physicsClientId=physics_client_id)
p.setGravity(0., 0., -10., physicsClientId=physics_client_id)
return env_cls._create_pybullet_robot( # pylint: disable=protected-access
physics_client_id)
Loading