src/openpi/training/data_loader.py: 107 changes (105 additions, 2 deletions)
@@ -365,7 +365,20 @@ def create_rlds_data_loader(
             If not provided, will iterate over the dataset indefinitely.
     """
     if framework == "pytorch":
-        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
+        # Only rank 0 builds the heavy TF-based RLDS pipeline; the other ranks
+        # receive each batch via broadcast, so the pipeline's cost is not duplicated per rank.
+        is_rank0 = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+        if is_rank0:
+            dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
+            dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
+        else:
+            dataset = None
+        data_loader = PyTorchRLDSDataLoader(
+            dataset,
+            num_batches=num_batches,
+        )
+        return DataLoaderImpl(data_config, data_loader)

     dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
     dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)

@@ -527,8 +540,98 @@ def __iter__(self):
             yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)


+class PyTorchRLDSDataLoader:
+    """RLDS data loader for PyTorch DDP training.
+
+    Only rank 0 builds and iterates the heavy TF-based RLDS pipeline.
+    Each batch is broadcast to all other ranks, then sliced so every rank
+    gets its own shard (batch_size // world_size samples).
+    """
+
+    def __init__(
+        self,
+        dataset: DroidRldsDataset | None,
+        *,
+        num_batches: int | None = None,
+    ):
+        self._dataset = dataset  # Only non-None on rank 0
+        self._num_batches = num_batches
+
+        if torch.distributed.is_initialized():
+            self._rank = torch.distributed.get_rank()
+            self._world_size = torch.distributed.get_world_size()
+        else:
+            self._rank = 0
+            self._world_size = 1
+
+    def _broadcast_batch(self, batch: dict | None) -> dict:
+        """Broadcast a batch from rank 0 to all ranks using torch.distributed."""
+        import pickle
+
+        if self._rank == 0:
+            data = pickle.dumps(batch)
+            size_tensor = torch.tensor([len(data)], dtype=torch.long, device="cuda")
+        else:
+            size_tensor = torch.tensor([0], dtype=torch.long, device="cuda")
+
+        torch.distributed.broadcast(size_tensor, src=0)
+        size = size_tensor.item()
+
+        if self._rank == 0:
+            data_tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8).to("cuda")
+        else:
+            data_tensor = torch.empty(size, dtype=torch.uint8, device="cuda")
+
+        torch.distributed.broadcast(data_tensor, src=0)
+
+        if self._rank != 0:
+            batch = pickle.loads(data_tensor.cpu().numpy().tobytes())
+
+        return batch
+
+    def __iter__(self):
+        num_items = 0
+        while True:
+            if self._rank == 0:
+                data_iter = iter(self._dataset)
+
+            while True:
+                if self._num_batches is not None and num_items >= self._num_batches:
+                    return
+
+                # Rank 0 produces the batch; others will receive via broadcast.
+                batch = None
+                exhausted = False
+                if self._rank == 0:
+                    try:
+                        batch = next(data_iter)
+                    except StopIteration:
+                        exhausted = True
+
+                if self._world_size > 1:
+                    # Sync the exhausted flag across all ranks.
+                    flag = torch.tensor([int(exhausted) if self._rank == 0 else 0], dtype=torch.long, device="cuda")
+                    torch.distributed.broadcast(flag, src=0)
+                    exhausted = flag.item() == 1
+
+                if exhausted:
+                    break
+
+                if self._world_size > 1:
+                    batch = self._broadcast_batch(batch)
+
+                num_items += 1
+                # Shard the pre-batched data across DDP ranks.
+                if self._world_size > 1:
+                    shard_size = batch["actions"].shape[0] // self._world_size
+                    start = self._rank * shard_size
+                    end = start + shard_size
+                    batch = jax.tree.map(lambda x: x[start:end], batch)
+                yield jax.tree.map(torch.as_tensor, batch)
+
+
 class DataLoaderImpl(DataLoader):
-    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
+    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader | PyTorchRLDSDataLoader):
         self._data_config = data_config
         self._data_loader = data_loader

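Editor's note: the snippet below is not part of the diff; it is a minimal, standalone sketch of the broadcast-then-shard pattern that PyTorchRLDSDataLoader implements, assuming an already-initialized process group (e.g. NCCL under torchrun) and a flat dict of batched arrays with an "actions" key. It uses torch.distributed.broadcast_object_list, which handles the pickling and size exchange that _broadcast_batch does by hand.

import torch.distributed as dist


def broadcast_and_shard(batch: dict | None) -> dict:
    """Rank 0 passes the full global batch; all other ranks pass None."""
    holder = [batch]
    # broadcast_object_list pickles on rank 0 and unpickles on the other ranks,
    # so every rank ends up holding the full global batch.
    dist.broadcast_object_list(holder, src=0)
    batch = holder[0]

    rank, world_size = dist.get_rank(), dist.get_world_size()
    shard_size = batch["actions"].shape[0] // world_size  # e.g. 256 // 4 = 64 samples per rank
    start = rank * shard_size
    # Slice every array along the batch dimension (flat dict assumed here;
    # the PR uses jax.tree.map to handle nested structures).
    return {k: v[start : start + shard_size] for k, v in batch.items()}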
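And a hedged usage sketch of the new framework="pytorch" path: the argument names mirror the call sites visible in the hunks above, but the full signature of create_rlds_data_loader is abbreviated, and the launch command, batch size, and iteration pattern are illustrative placeholders rather than details taken from the PR.

# Launched as e.g.:  torchrun --nproc_per_node=4 <training script>
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

loader = create_rlds_data_loader(
    data_config,            # _config.DataConfig for an RLDS dataset
    action_horizon,
    batch_size=256,         # global batch size; each of 4 ranks receives 256 // 4 = 64 samples
    shuffle=True,
    framework="pytorch",    # takes the rank-0 broadcast path added above
)
for batch in loader:
    ...  # per-rank torch tensors, already sharded by PyTorchRLDSDataLoader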