diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index e2ee7dd06b..c11836efd3 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -365,7 +365,20 @@ def create_rlds_data_loader(
         If not provided, will iterate over the dataset indefinitely.
     """
     if framework == "pytorch":
-        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
+        # Only rank 0 builds the heavy TF-based RLDS pipeline; other ranks
+        # receive batches via broadcast to avoid 4x resource contention.
+        is_rank0 = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+        if is_rank0:
+            dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
+            dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
+        else:
+            dataset = None
+        data_loader = PyTorchRLDSDataLoader(
+            dataset,
+            num_batches=num_batches,
+        )
+        return DataLoaderImpl(data_config, data_loader)
+
 
     dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
     dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
@@ -527,8 +540,98 @@ def __iter__(self):
             yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
 
 
+class PyTorchRLDSDataLoader:
+    """RLDS data loader for PyTorch DDP training.
+
+    Only rank 0 builds and iterates the heavy TF-based RLDS pipeline.
+    Each batch is broadcast to all other ranks, then sliced so every rank
+    gets its own shard (batch_size // world_size samples).
+    """
+
+    def __init__(
+        self,
+        dataset: DroidRldsDataset | None,
+        *,
+        num_batches: int | None = None,
+    ):
+        self._dataset = dataset  # Only non-None on rank 0
+        self._num_batches = num_batches
+
+        if torch.distributed.is_initialized():
+            self._rank = torch.distributed.get_rank()
+            self._world_size = torch.distributed.get_world_size()
+        else:
+            self._rank = 0
+            self._world_size = 1
+
+    def _broadcast_batch(self, batch: dict | None) -> dict:
+        """Broadcast a batch from rank 0 to all ranks using torch.distributed."""
+        import pickle
+
+        if self._rank == 0:
+            data = pickle.dumps(batch)
+            size_tensor = torch.tensor([len(data)], dtype=torch.long, device="cuda")
+        else:
+            size_tensor = torch.tensor([0], dtype=torch.long, device="cuda")
+
+        torch.distributed.broadcast(size_tensor, src=0)
+        size = size_tensor.item()
+
+        if self._rank == 0:
+            data_tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8).to("cuda")
+        else:
+            data_tensor = torch.empty(size, dtype=torch.uint8, device="cuda")
+
+        torch.distributed.broadcast(data_tensor, src=0)
+
+        if self._rank != 0:
+            batch = pickle.loads(data_tensor.cpu().numpy().tobytes())
+
+        return batch
+
+    def __iter__(self):
+        num_items = 0
+        while True:
+            if self._rank == 0:
+                data_iter = iter(self._dataset)
+
+            while True:
+                if self._num_batches is not None and num_items >= self._num_batches:
+                    return
+
+                # Rank 0 produces the batch; others will receive via broadcast.
+                batch = None
+                exhausted = False
+                if self._rank == 0:
+                    try:
+                        batch = next(data_iter)
+                    except StopIteration:
+                        exhausted = True
+
+                if self._world_size > 1:
+                    # Sync the exhausted flag across all ranks.
+                    flag = torch.tensor([int(exhausted) if self._rank == 0 else 0], dtype=torch.long, device="cuda")
+                    torch.distributed.broadcast(flag, src=0)
+                    exhausted = flag.item() == 1
+
+                if exhausted:
+                    break
+
+                if self._world_size > 1:
+                    batch = self._broadcast_batch(batch)
+
+                num_items += 1
+                # Shard the pre-batched data across DDP ranks.
+                if self._world_size > 1:
+                    shard_size = batch["actions"].shape[0] // self._world_size
+                    start = self._rank * shard_size
+                    end = start + shard_size
+                    batch = jax.tree.map(lambda x: x[start:end], batch)
+                yield jax.tree.map(torch.as_tensor, batch)
+
+
 class DataLoaderImpl(DataLoader):
-    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
+    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader | PyTorchRLDSDataLoader):
         self._data_config = data_config
         self._data_loader = data_loader
 