Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions scripts/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,23 @@ def train_loop(config: _config.TrainConfig):
batch_size = next(iter(sample_batch["image"].values())).shape[0]
for i in range(min(5, batch_size)):
# Concatenate all camera views horizontally for this batch item
# Convert from NCHW to NHWC format for wandb
img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
frames = []
for img in sample_batch["image"].values():
frame = img[i]
# Detect layout of the single frame (batch dim already indexed away): CHW has channels (3) as first dim, HWC has channels as last dim
if frame.shape[0] == 3:
# CHW layout -> convert to HWC
frame = frame.permute(1, 2, 0)
# else: already HWC layout, no conversion needed
frames.append(frame)
# Concatenate along the width dimension (dim=1 of a single HWC frame)
img_concatenated = torch.cat(frames, dim=1)
# Convert to numpy and ensure values are in [0, 1] range for wandb
img_concatenated = img_concatenated.cpu().numpy()
if img_concatenated.min() < 0:
# Values in [-1, 1] range, rescale to [0, 1]
img_concatenated = (img_concatenated + 1) / 2
img_concatenated = np.clip(img_concatenated, 0, 1)
images_to_log.append(wandb.Image(img_concatenated))

wandb.log({"camera_views": images_to_log}, step=0)
Expand Down