diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index c7ddd2b595..5d1e27b534 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -374,9 +374,23 @@ def train_loop(config: _config.TrainConfig): batch_size = next(iter(sample_batch["image"].values())).shape[0] for i in range(min(5, batch_size)): # Concatenate all camera views horizontally for this batch item - # Convert from NCHW to NHWC format for wandb - img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) + frames = [] + for img in sample_batch["image"].values(): + frame = img[i] + # Detect format: NCHW has channels (3) as first dim, NHWC has channels as last dim + if frame.shape[0] == 3: + # NCHW format -> convert to NHWC + frame = frame.permute(1, 2, 0) + # else: already NHWC format, no conversion needed + frames.append(frame) + # Concatenate along width dimension (dim=1 for NHWC) + img_concatenated = torch.cat(frames, dim=1) + # Convert to numpy and ensure values are in [0, 1] range for wandb img_concatenated = img_concatenated.cpu().numpy() + if img_concatenated.min() < 0: + # Values in [-1, 1] range, rescale to [0, 1] + img_concatenated = (img_concatenated + 1) / 2 + img_concatenated = np.clip(img_concatenated, 0, 1) images_to_log.append(wandb.Image(img_concatenated)) wandb.log({"camera_views": images_to_log}, step=0)