diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 22c96af1779e..f8d971317614 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -19,7 +19,7 @@ import isaaclab.utils.sensors as sensor_utils from isaaclab.app.settings_manager import get_settings_manager from isaaclab.renderers import BaseRenderer, Renderer -from isaaclab.sim.views import FrameView +from isaaclab.sim.views import UsdFrameView from isaaclab.utils import has_kit, to_camel_case from isaaclab.utils.math import ( convert_camera_frame_orientation_convention, @@ -405,9 +405,11 @@ def _initialize_impl(self): # references to prims located in the stage. self._renderer.prepare_stage(self.stage, self._num_envs) - # Create a view for the sensor with Fabric enabled for fast pose queries. - # TODO: remove sync_usd_on_fabric_write=True once the GPU Fabric sync bug is fixed. - self._view = FrameView(self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True) + # Camera uses UsdFrameView directly (not FrameView/FabricFrameView) because + # the RTX renderer / Replicator reads camera poses from USD prim paths, not + # from Fabric. Writing to Fabric + sync_usd_on_fabric_write was wasteful — + # this bypasses Fabric entirely for camera transforms. + self._view = UsdFrameView(self.cfg.prim_path, device=self._device, stage=self.stage) # Check that sizes are correct if self._view.count != self._num_envs: raise RuntimeError( diff --git a/source/isaaclab/isaaclab/sim/views/usd_frame_view.py b/source/isaaclab/isaaclab/sim/views/usd_frame_view.py index 4421fa5391ea..7e02e0b64d4a 100644 --- a/source/isaaclab/isaaclab/sim/views/usd_frame_view.py +++ b/source/isaaclab/isaaclab/sim/views/usd_frame_view.py @@ -71,8 +71,7 @@ def __init__( stage: USD stage to search for prims. Defaults to None, in which case the current active stage from the simulation context is used. **kwargs: Additional keyword arguments (ignored). Allows forward-compatible - construction when callers pass backend-specific options like - ``sync_usd_on_fabric_write``. + construction when callers pass backend-specific options. Raises: ValueError: If any matched prim is not Xformable or doesn't have standardized diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index 4ce62cd5336f..2c5487a53f64 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -195,3 +195,79 @@ def _populate_scene(): sim_utils.define_rigid_body_properties(prim_path, sim_utils.RigidBodyPropertiesCfg()) sim_utils.define_mass_properties(prim_path, sim_utils.MassPropertiesCfg(mass=5.0)) sim_utils.define_collision_properties(prim_path, sim_utils.CollisionPropertiesCfg()) + + +# ------------------------------------------------------------------ +# Camera pose → render validation (PrepareForReuse / Fabric path) +# ------------------------------------------------------------------ + + +@pytest.mark.parametrize( + "device, camera_cls", + [ + pytest.param("cpu", TiledCamera, id="cpu-tiled"), + pytest.param("cpu", Camera, id="cpu-non_tiled"), + pytest.param("cuda:0", TiledCamera, id="cuda:0-tiled"), + pytest.param("cuda:0", Camera, id="cuda:0-non_tiled"), + ], +) +def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls): + """Camera pose changes via FrameView should be visible in rendered depth. + + Moves camera close then far, renders depth, and verifies that the mean + valid depth from the far position is significantly larger (>1.5×) than + the close position. This validates that Fabric-side pose writes + (via PrepareForReuse) or USD writes are correctly propagated to the + RTX renderer. + """ + sim, _unused_cam_cfg, dt = setup_camera + + cam_cfg = CameraCfg( + prim_path="/World/PoseTestCam", + height=128, + width=256, + update_period=0, + update_latest_camera_pose=True, + data_types=["distance_to_camera"], + spawn=sim_utils.PinholeCameraCfg( + focal_length=24.0, + focus_distance=400.0, + horizontal_aperture=20.955, + clipping_range=(0.1, 1.0e5), + ), + ) + camera = camera_cls(cam_cfg) + sim.reset() + + target = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32, device=camera.device) + max_range = cam_cfg.spawn.clipping_range[1] + + # -- close position -- + eyes_close = torch.tensor([[2.0, 2.0, 2.0]], dtype=torch.float32, device=camera.device) + camera.set_world_poses_from_view(eyes_close, target) + sim.step() + camera.update(dt) + depth_close = camera.data.output["distance_to_camera"].clone() + + # -- far position -- + eyes_far = torch.tensor([[8.0, 8.0, 8.0]], dtype=torch.float32, device=camera.device) + camera.set_world_poses_from_view(eyes_far, target) + sim.step() + camera.update(dt) + depth_far = camera.data.output["distance_to_camera"].clone() + + # -- validate -- + valid_close = depth_close[depth_close < max_range] + valid_far = depth_far[depth_far < max_range] + + assert valid_close.numel() > 0, "No valid close-range depth pixels" + assert valid_far.numel() > 0, "No valid far-range depth pixels" + + mean_close = valid_close.mean().item() + mean_far = valid_far.mean().item() + + assert mean_far > mean_close * 1.5, ( + f"Far depth ({mean_far:.2f}) should be > 1.5× close depth ({mean_close:.2f}). " + "Camera pose change may not be reaching the renderer." + ) + del camera diff --git a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py index 87adad2238c4..20f70d8a3fde 100644 --- a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py +++ b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py @@ -12,7 +12,7 @@ import torch import warp as wp -from pxr import Usd +from pxr import Usd, UsdGeom import isaaclab.sim as sim_utils from isaaclab.app.settings_manager import SettingsManager @@ -50,6 +50,10 @@ class FabricFrameView(BaseFrameView): Warp kernels operating on ``omni:fabric:worldMatrix``. All other operations delegate to the internal USD view. + After every Fabric write, :meth:`PrepareForReuse` is called on the + ``PrimSelection`` to notify the renderer (FSD/Storm) that Fabric data + has changed. + All getters return ``wp.array``. Setters accept ``wp.array``. """ @@ -58,12 +62,11 @@ def __init__( prim_path: str, device: str = "cpu", validate_xform_ops: bool = True, - sync_usd_on_fabric_write: bool = False, stage: Usd.Stage | None = None, + **kwargs, ): self._usd_view = UsdFrameView(prim_path, device=device, validate_xform_ops=validate_xform_ops, stage=stage) self._device = device - self._sync_usd_on_fabric_write = sync_usd_on_fabric_write settings = SettingsManager.instance() self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) @@ -134,6 +137,8 @@ def set_world_poses(self, positions=None, orientations=None, indices=None): if not self._fabric_initialized: self._initialize_fabric() + self._prepare_for_reuse() + indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -165,8 +170,6 @@ def set_world_poses(self, positions=None, orientations=None, indices=None): self._fabric_hierarchy.update_world_xforms() self._fabric_usd_sync_done = True - if self._sync_usd_on_fabric_write: - self._usd_view.set_world_poses(positions, orientations, indices) def get_world_poses(self, indices=None): if not self._use_fabric: @@ -177,6 +180,8 @@ def get_world_poses(self, indices=None): if not self._fabric_usd_sync_done: self._sync_fabric_from_usd_once() + self._prepare_for_reuse() + indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -207,14 +212,66 @@ def get_world_poses(self, indices=None): return positions_wp, orientations_wp # ------------------------------------------------------------------ - # Local poses — USD fallback (Fabric only accelerates world poses) + # Local poses — computed from Fabric world poses when Fabric is active # ------------------------------------------------------------------ def set_local_poses(self, translations=None, orientations=None, indices=None): - self._usd_view.set_local_poses(translations, orientations, indices) + if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done: + self._usd_view.set_local_poses(translations, orientations, indices) + if self._use_fabric and self._fabric_initialized: + # After writing local to USD, recompute Fabric world matrices + self._fabric_hierarchy.update_world_xforms() + self._prepare_for_reuse() + return + + # Fabric path: compute child world = parent_world * local, then write to Fabric + import torch + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + indices_list = wp.to_torch(indices_wp).long().tolist() + + parent_pos, parent_ori = self._get_parent_world_poses(indices_list) + + if translations is not None: + local_pos = wp.to_torch(_to_float32_2d(translations)) + else: + local_pos = torch.zeros((count, 3), dtype=torch.float32, device=self._device) + + if orientations is not None: + local_ori = wp.to_torch(_to_float32_2d(orientations)) + else: + local_ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]] * count, dtype=torch.float32, device=self._device) + + child_pos, child_ori = self._compose_parent_local(parent_pos, parent_ori, local_pos, local_ori) + + self.set_world_poses( + wp.from_torch(child_pos.contiguous()), + wp.from_torch(child_ori.contiguous()), + indices, + ) def get_local_poses(self, indices=None): - return self._usd_view.get_local_poses(indices) + if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done: + return self._usd_view.get_local_poses(indices) + + # Fabric path: local = inv(parent_world) * child_world + + indices_wp = self._resolve_indices_wp(indices) + indices_list = wp.to_torch(indices_wp).long().tolist() + + child_pos_wp, child_ori_wp = self.get_world_poses(indices) + child_pos = wp.to_torch(child_pos_wp) + child_ori = wp.to_torch(child_ori_wp) + + parent_pos, parent_ori = self._get_parent_world_poses(indices_list) + + local_pos, local_ori = self._invert_parent_compose(parent_pos, parent_ori, child_pos, child_ori) + + return ( + wp.from_torch(local_pos.contiguous()), + wp.from_torch(local_ori.contiguous()), + ) # ------------------------------------------------------------------ # Scales — Fabric-accelerated or USD fallback @@ -228,6 +285,8 @@ def set_scales(self, scales, indices=None): if not self._fabric_initialized: self._initialize_fabric() + self._prepare_for_reuse() + indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -255,8 +314,6 @@ def set_scales(self, scales, indices=None): self._fabric_hierarchy.update_world_xforms() self._fabric_usd_sync_done = True - if self._sync_usd_on_fabric_write: - self._usd_view.set_scales(scales, indices) def get_scales(self, indices=None): if not self._use_fabric: @@ -267,6 +324,8 @@ def get_scales(self, indices=None): if not self._fabric_usd_sync_done: self._sync_fabric_from_usd_once() + self._prepare_for_reuse() + indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -294,6 +353,153 @@ def get_scales(self, indices=None): wp.synchronize() return scales_wp + # ------------------------------------------------------------------ + # Internal — PrepareForReuse (renderer notification + topology tracking) + # ------------------------------------------------------------------ + + def _prepare_for_reuse(self) -> None: + """Call PrepareForReuse on the PrimSelection to notify the renderer. + + PrepareForReuse serves two purposes: + + 1. **Renderer notification**: Tells FSD/Storm that Fabric data has + been (or will be) modified, so the next rendered frame reflects + the updated transforms. + 2. **Topology change detection**: Returns True when Fabric's + internal memory layout changed (e.g., prims added/removed). + In that case, view-to-fabric index mappings and fabricarrays + must be rebuilt. + """ + if self._fabric_selection is None: + return + + topology_changed = self._fabric_selection.PrepareForReuse() + if topology_changed: + logger.info("Fabric topology changed — rebuilding view-to-fabric index mapping.") + self._rebuild_fabric_arrays() + + def _rebuild_fabric_arrays(self) -> None: + """Rebuild fabricarray and view↔fabric mappings after a topology change.""" + self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._fabric_device) + self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) + + wp.launch( + kernel=fabric_utils.set_view_to_fabric_array, + dim=self._fabric_to_view.shape[0], + inputs=[self._fabric_to_view, self._view_to_fabric], + device=self._fabric_device, + ) + wp.synchronize() + + self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") + + # ------------------------------------------------------------------ + # Internal — Local/world pose helpers + # ------------------------------------------------------------------ + + def _get_parent_world_poses(self, indices_list: list[int]) -> tuple: + """Read parent world poses from USD for given child indices. + + Parents are not tracked in Fabric, so we read from USD XformCache. + Returns torch tensors ``(parent_pos[N,3], parent_ori[N,4])`` on self._device. + Orientation is ``(x, y, z, w)`` to match the convention used by FabricFrameView. + """ + import torch + + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + stage = self._usd_view._prims[0].GetStage() + + parent_positions = [] + parent_orientations = [] + for idx in indices_list: + child_path = self.prim_paths[idx] + parent_path = child_path.rsplit("/", 1)[0] + parent_prim = stage.GetPrimAtPath(parent_path) + if parent_prim and parent_prim.IsValid(): + parent_tf = xform_cache.GetLocalToWorldTransform(parent_prim) + parent_tf.Orthonormalize() + t = parent_tf.ExtractTranslation() + q = parent_tf.ExtractRotationQuat() + img = q.GetImaginary() + real = q.GetReal() + parent_positions.append([float(t[0]), float(t[1]), float(t[2])]) + # (x, y, z, w) convention + parent_orientations.append([float(img[0]), float(img[1]), float(img[2]), float(real)]) + else: + # No parent — identity + parent_positions.append([0.0, 0.0, 0.0]) + parent_orientations.append([0.0, 0.0, 0.0, 1.0]) + + return ( + torch.tensor(parent_positions, dtype=torch.float32, device=self._device), + torch.tensor(parent_orientations, dtype=torch.float32, device=self._device), + ) + + @staticmethod + def _compose_parent_local( + parent_pos: torch.Tensor, + parent_ori: torch.Tensor, + local_pos: torch.Tensor, + local_ori: torch.Tensor, + ) -> tuple: + """Compute child_world = parent_world * local. + + Orientations are ``(x, y, z, w)``. + Returns ``(child_world_pos, child_world_ori)``. + """ + child_pos = parent_pos + FabricFrameView._quat_rotate(parent_ori, local_pos) + child_ori = FabricFrameView._quat_mul(parent_ori, local_ori) + return child_pos, child_ori + + @staticmethod + def _invert_parent_compose( + parent_pos: torch.Tensor, + parent_ori: torch.Tensor, + child_pos: torch.Tensor, + child_ori: torch.Tensor, + ) -> tuple: + """Compute local = inv(parent_world) * child_world. + + Orientations are ``(x, y, z, w)``. + Returns ``(local_pos, local_ori)``. + """ + parent_ori_inv = FabricFrameView._quat_conjugate(parent_ori) + local_pos = FabricFrameView._quat_rotate(parent_ori_inv, child_pos - parent_pos) + local_ori = FabricFrameView._quat_mul(parent_ori_inv, child_ori) + return local_pos, local_ori + + @staticmethod + def _quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """Quaternion multiply (x,y,z,w) convention.""" + x1, y1, z1, w1 = q1[..., 0:1], q1[..., 1:2], q1[..., 2:3], q1[..., 3:4] + x2, y2, z2, w2 = q2[..., 0:1], q2[..., 1:2], q2[..., 2:3], q2[..., 3:4] + import torch + + return torch.cat( + [ + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + ], + dim=-1, + ) + + @staticmethod + def _quat_conjugate(q: torch.Tensor) -> torch.Tensor: + """Quaternion conjugate (x,y,z,w) convention.""" + return q * q.new_tensor([-1, -1, -1, 1]) + + @staticmethod + def _quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Rotate vector v by quaternion q. (x,y,z,w) convention.""" + import torch + + q_xyz = q[..., :3] + q_w = q[..., 3:4] + t = 2.0 * torch.linalg.cross(q_xyz, v) + return v + q_w * t + torch.linalg.cross(q_xyz, t) + # ------------------------------------------------------------------ # Internal — Fabric initialization # ------------------------------------------------------------------ @@ -384,11 +590,8 @@ def _sync_fabric_from_usd_once(self) -> None: positions_usd, orientations_usd = self._usd_view.get_world_poses() scales_usd = self._usd_view.get_scales() - prev_sync = self._sync_usd_on_fabric_write - self._sync_usd_on_fabric_write = False self.set_world_poses(positions_usd, orientations_usd) self.set_scales(scales_usd) - self._sync_usd_on_fabric_write = prev_sync self._fabric_usd_sync_done = True diff --git a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py index 0bc77ccf7223..ebe08478c567 100644 --- a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py +++ b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py @@ -21,6 +21,7 @@ import pytest # noqa: E402 import torch # noqa: E402 +import warp as wp # noqa: E402 from frame_view_contract_utils import * # noqa: F401, F403, E402 from frame_view_contract_utils import CHILD_OFFSET, ViewBundle # noqa: E402 from isaaclab_physx.sim.views import FabricFrameView as FrameView # noqa: E402 @@ -94,7 +95,7 @@ def factory(num_envs: int, device: str) -> ViewBundle: sim_utils.create_prim(f"/World/Parent_{i}/Child", "Camera", translation=CHILD_OFFSET, stage=stage) sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) - view = FrameView("/World/Parent_.*/Child", device=device, sync_usd_on_fabric_write=True) + view = FrameView("/World/Parent_.*/Child", device=device) return ViewBundle( view=view, get_parent_pos=_get_parent_positions, @@ -103,3 +104,133 @@ def factory(num_envs: int, device: str) -> ViewBundle: ) return factory + + +# ------------------------------------------------------------------ +# Override: ensure the shared contract test runs without xfail now that +# get_local_poses computes local from Fabric world matrices. +# ------------------------------------------------------------------ +# (No override needed — the shared test_set_world_updates_local from +# frame_view_contract_utils is imported via wildcard and will run as-is.) + + +# ------------------------------------------------------------------ +# Fabric-specific tests (not in shared contract) +# ------------------------------------------------------------------ + + +@wp.kernel +def _fill_position(out: wp.array(dtype=wp.float32, ndim=2), x: float, y: float, z: float): + i = wp.tid() + out[i, 0] = wp.float32(x) + out[i, 1] = wp.float32(y) + out[i, 2] = wp.float32(z) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_fabric_set_world_does_not_write_back_to_usd(device, view_factory): + """Verify that set_world_poses in Fabric mode does NOT sync back to USD. + + This confirms the removal of sync_usd_on_fabric_write. After calling + set_world_poses, the USD prim's xformOps should still contain the + original (stale) values. + """ + bundle = view_factory(1, device) + view = bundle.view + + # Capture the original USD world position BEFORE any Fabric write + stage = sim_utils.get_current_stage() + prim = stage.GetPrimAtPath(view.prim_paths[0]) + xform_cache = UsdGeom.XformCache() + usd_tf_before = xform_cache.GetLocalToWorldTransform(prim) + usd_t_before = usd_tf_before.ExtractTranslation() + orig_usd_pos = torch.tensor([float(usd_t_before[0]), float(usd_t_before[1]), float(usd_t_before[2])]) + + # Write to Fabric — move to (99, 99, 99) + new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 99.0, 99.0, 99.0], device=device) + view.set_world_poses(positions=new_pos) + + # Verify Fabric has the new position + fab_pos, _ = view.get_world_poses() + pos_torch = wp.to_torch(fab_pos) + assert torch.allclose(pos_torch, torch.tensor([[99.0, 99.0, 99.0]], device=device), atol=0.1), ( + f"Fabric should have new position, got {pos_torch}" + ) + + # Verify USD still has the ORIGINAL position (no writeback) + xform_cache_after = UsdGeom.XformCache() + usd_tf_after = xform_cache_after.GetLocalToWorldTransform(prim) + usd_t_after = usd_tf_after.ExtractTranslation() + usd_pos_after = torch.tensor([float(usd_t_after[0]), float(usd_t_after[1]), float(usd_t_after[2])]) + assert torch.allclose(usd_pos_after, orig_usd_pos, atol=0.1), ( + f"USD should still have original position {orig_usd_pos}, but got {usd_pos_after}. " + f"sync_usd_on_fabric_write may not have been fully removed." + ) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_prepare_for_reuse_detects_topology_change(device, view_factory): + """Verify PrepareForReuse() is callable and returns a bool. + + When no topology change has occurred, it should return False. + """ + bundle = view_factory(1, device) + view = bundle.view + view.get_world_poses() # trigger Fabric init + + assert view._fabric_selection is not None, "Fabric selection not initialized" + result = view._fabric_selection.PrepareForReuse() + assert isinstance(result, bool), f"PrepareForReuse should return bool, got {type(result)}" + assert not result, "PrepareForReuse should return False when no topology change" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_set_local_via_fabric_path(device, view_factory): + """Exercise the Fabric-native set_local_poses path. + + Ensures set_local_poses computes child_world = parent_world * local + entirely within Fabric (not falling back to USD) by first triggering + the Fabric sync via get_world_poses. + """ + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger Fabric init and sync (sets _fabric_usd_sync_done = True) + view.get_world_poses() + + # Now set_local_poses should take the Fabric path + new_local_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_pos, 1.0, 2.0, 3.0], device=device) + ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device) + new_local_ori = wp.from_torch(ori) + + view.set_local_poses(translations=new_local_pos, orientations=new_local_ori) + + # Verify: world = parent(0,0,1) + local(1,2,3) = (1,2,4) + world_pos, _ = view.get_world_poses() + pos_t = wp.to_torch(world_pos) + expected = torch.tensor([[1.0, 2.0, 4.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(pos_t, expected, atol=1e-4, rtol=0) + + # Verify get_local_poses returns the local offset + local_pos, _ = view.get_local_poses() + local_t = wp.to_torch(local_pos) + expected_local = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(local_t, expected_local, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_get_scales_fabric_path(device, view_factory): + """Exercise the Fabric-native get_scales path.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger Fabric init + view.get_world_poses() + + scales = view.get_scales() + scales_t = wp.to_torch(scales) + # Default scale should be (1, 1, 1) + expected = torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(scales_t, expected, atol=1e-4, rtol=0)