diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 29618b4945..6bb9b95f61 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -168,20 +168,23 @@ def preprocess_observation( if train: # Convert from [-1, 1] to [0, 1] for augmax. image = image / 2.0 + 0.5 + sub_rngs = jax.random.split(rng, image.shape[0]) - transforms = [] + # Stage 1: Apply spatial transforms (base cameras only) if "wrist" not in key: height, width = image.shape[1:3] - transforms += [ + spatial_rngs = jax.vmap(lambda k: jax.random.fold_in(k, 0))(sub_rngs) + image = jax.vmap(augmax.Chain( augmax.RandomCrop(int(width * 0.95), int(height * 0.95)), augmax.Resize(width, height), augmax.Rotate((-5, 5)), - ] - transforms += [ + ))(spatial_rngs, image) + + # Stage 2: Apply color transforms (all cameras, consistent RNG) + color_rngs = jax.vmap(lambda k: jax.random.fold_in(k, 1))(sub_rngs) + image = jax.vmap(augmax.Chain( augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), - ] - sub_rngs = jax.random.split(rng, image.shape[0]) - image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image) + ))(color_rngs, image) # Back to [-1, 1]. image = image * 2.0 - 1.0