diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py
index c7ddd2b595..5c97a2381a 100644
--- a/scripts/train_pytorch.py
+++ b/scripts/train_pytorch.py
@@ -373,10 +373,33 @@ def train_loop(config: _config.TrainConfig):
             # Get batch size from the first image tensor
             batch_size = next(iter(sample_batch["image"].values())).shape[0]
             for i in range(min(5, batch_size)):
-                # Concatenate all camera views horizontally for this batch item
-                # Convert from NCHW to NHWC format for wandb
-                img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
-                img_concatenated = img_concatenated.cpu().numpy()
+                # Concatenate all camera views horizontally for this batch item.
+                # Images may be NCHW [C,H,W] or NHWC [H,W,C] depending on the
+                # dataloader backend; detect the layout and normalise to NHWC
+                # before handing off to wandb (see #877).
+                frames = []
+                for img in sample_batch["image"].values():
+                    frame = img[i]
+                    # Treat the frame as NCHW only when the leading dim looks like
+                    # channels AND the trailing dim does not -- an NHWC frame whose
+                    # height happens to be 1 or 3 must not be permuted.
+                    if frame.ndim == 3 and frame.shape[0] in (1, 3) and frame.shape[-1] not in (1, 3):
+                        # NCHW -> NHWC
+                        frame = frame.permute(1, 2, 0)
+                    frames.append(frame)
+                # Concatenate along the width axis (dim=1 in NHWC)
+                img_concatenated = torch.cat(frames, dim=1)
+                # Bring pixel values into [0, 1] for wandb without destroying
+                # raw-pixel inputs: [-1, 1] data is shifted, [0, 255] data is
+                # rescaled, and already-[0, 1] data passes through unchanged.
+                lo = img_concatenated.min()
+                hi = img_concatenated.max()
+                if lo < 0:
+                    # Assume [-1, 1] normalisation
+                    img_concatenated = (img_concatenated + 1.0) / 2.0
+                elif hi > 1.0:
+                    # Assume raw [0, 255] pixels (float or uint8)
+                    img_concatenated = img_concatenated.float() / 255.0
+                img_concatenated = img_concatenated.clamp(0, 1).cpu().numpy()
                 images_to_log.append(wandb.Image(img_concatenated))

             wandb.log({"camera_views": images_to_log}, step=0)