From a5d1fcf4d5f28db2403978e37ee3fde0ecb12bb8 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 30 Oct 2025 11:36:02 -0700 Subject: [PATCH 1/6] update randomcrop --- torchvision/transforms/transforms.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index e33b3e28194..9f2e3155732 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -631,12 +631,13 @@ class RandomCrop(torch.nn.Module): """ @staticmethod - def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int, int]: + def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[torch.Generator] = None) -> tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. + generator (torch.Generator, optional): Random number generator. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. @@ -650,11 +651,11 @@ def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,)).item() - j = torch.randint(0, w - tw + 1, size=(1,)).item() + i = torch.randint(0, h - th + 1, size=(1,), generator=generator).item() + j = torch.randint(0, w - tw + 1, size=(1,), generator=generator).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant", generator=None): super().__init__() _log_api_usage_once(self) @@ -664,6 +665,7 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode + self.generator = generator def forward(self, img): """ @@ -686,7 +688,7 @@ def forward(self, img): padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) - i, j, h, w = self.get_params(img, self.size) + i, j, h, w = self.get_params(img, self.size, self.generator) return F.crop(img, i, j, h, w) From 547b0c18a0310a45c14e789dd5119680a5df6749 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 15 Jan 2026 16:25:25 -0800 Subject: [PATCH 2/6] Use torch.thread_safe_generator --- torchvision/transforms/transforms.py | 12 ++-- torchvision/transforms/v2/_geometry.py | 83 ++++++++++++++++--------- torchvision/transforms/v2/_transform.py | 3 +- 3 files changed, 59 insertions(+), 39 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9f2e3155732..e33b3e28194 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -631,13 +631,12 @@ class RandomCrop(torch.nn.Module): """ @staticmethod - def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[torch.Generator] = None) -> tuple[int, int, int, int]: + def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. - generator (torch.Generator, optional): Random number generator. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. @@ -651,11 +650,11 @@ def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[to if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,), generator=generator).item() - j = torch.randint(0, w - tw + 1, size=(1,), generator=generator).item() + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant", generator=None): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): super().__init__() _log_api_usage_once(self) @@ -665,7 +664,6 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode - self.generator = generator def forward(self, img): """ @@ -688,7 +686,7 @@ def forward(self, img): padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) - i, j, h, w = self.get_params(img, self.size, self.generator) + i, j, h, w = self.get_params(img, self.size) return F.crop(img, i, j, h, w) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index fd156b80fbe..76a2f5f1eb4 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -275,13 +275,16 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) area = height * width + g = torch.thread_safe_generator() + log_ratio = self._log_ratio for _ in range(10): - target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=g, ) ).item() @@ -289,8 +292,8 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() + i = torch.randint(0, height - h + 1, size=(1,), generator=g).item() + j = torch.randint(0, width - w + 1, size=(1,), generator=g).item() break else: # Fallback to central crop @@ -541,11 +544,13 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + g = torch.thread_safe_generator() + + r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) canvas_height = int(orig_h * r) - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((canvas_width - orig_w) * r[0]) top = int((canvas_height - orig_h) * r[1]) right = canvas_width - (left + orig_w) @@ -629,7 +634,8 @@ def __init__( self.center = center def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() return dict(angle=angle) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -729,27 +735,29 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() if self.translate is not None: max_dx = float(self.translate[0] * width) max_dy = float(self.translate[1] * height) - tx = int(round(float(torch.empty(1).uniform_(-max_dx, max_dx).item()))) - ty = int(round(float(torch.empty(1).uniform_(-max_dy, max_dy).item()))) + tx = int(round(float(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item()))) + ty = int(round(float(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item()))) translate = (tx, ty) else: translate = (0, 0) if self.scale is not None: - scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() else: scale = 1.0 shear_x = shear_y = 0.0 if self.shear is not None: - shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()) + shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item()) if len(self.shear) == 4: - shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) + shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item()) shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) @@ -887,13 +895,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: padding = [pad_left, pad_top, pad_right, pad_bottom] needs_pad = any(padding) + g = torch.thread_safe_generator() + needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g))) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g))) if padded_width > cropped_width else (False, 0) ) @@ -972,21 +982,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: half_width = width // 2 bound_height = int(distortion_scale * half_height) + 1 bound_width = int(distortion_scale * half_width) + 1 + + g = torch.thread_safe_generator() + topleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] topright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] botright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] botleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] @@ -1067,7 +1080,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - dx = torch.rand(1, 1, height, width) * 2 - 1 + g = torch.thread_safe_generator() + + dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[0] > 0.0: kx = int(8 * self.sigma[0] + 1) # if kernel size is even we have to make it odd @@ -1076,7 +1091,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / width - dy = torch.rand(1, 1, height, width) * 2 - 1 + dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[1] > 0.0: ky = int(8 * self.sigma[1] + 1) # if kernel size is even we have to make it odd @@ -1159,16 +1174,18 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) bboxes = get_bounding_boxes(flat_inputs) + g = torch.thread_safe_generator() + while True: # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g)) min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() for _ in range(self.trials): # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g) new_w = int(orig_w * r[0]) new_h = int(orig_h * r[1]) aspect_ratio = new_w / new_h @@ -1176,7 +1193,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: continue # check for 0 area crops - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((orig_w - new_w) * r[0]) top = int((orig_h - new_h) * r[1]) right = left + new_w @@ -1208,7 +1225,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: - if len(params) < 1: return inpt @@ -1279,7 +1295,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + g = torch.thread_safe_generator() + + scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0]) r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale new_width = int(orig_width * r) new_height = int(orig_height * r) @@ -1345,7 +1363,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + g = torch.thread_safe_generator() + + min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))] r = min_size / min(orig_height, orig_width) if self.max_size is not None: r = min(r, self.max_size / max(orig_height, orig_width)) @@ -1423,7 +1443,8 @@ def __init__( self.antialias = antialias def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - size = int(torch.randint(self.min_size, self.max_size, ())) + g = torch.thread_safe_generator() + size = int(torch.randint(self.min_size, self.max_size, (), generator=g)) return dict(size=[size]) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..ae02e736f05 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any: self.check_inputs(flat_inputs) - if torch.rand(1) >= self.p: + g = torch.thread_safe_generator() + if torch.rand(1, generator=g) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs) From fa2b373eef73dc2ac7b07a41e6e6e48e84f2a7bb Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Wed, 25 Feb 2026 14:19:48 -0800 Subject: [PATCH 3/6] add unit test mocking torch.thread_safe_generator --- test/test_transforms_v2.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a87e601e1a6..8a0c8e369fb 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -8019,3 +8019,61 @@ def test_different_sizes(self, make_input1, make_input2, query): def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): query(["blah"]) + + +class TestThreadSafeGenerator: + """Test that transforms correctly use torch.thread_safe_generator(). + + For multiprocessing workers, thread_safe_generator() returns None, + so transforms use the default process global RNG, + i.e. for a multiprocessing worker the RNG of that process. + For thread workers, it returns a thread-local torch.Generator. + """ + + TRANSFORMS = [ + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomResizedCrop(size=(24, 24)), + transforms.RandomRotation(degrees=10), + transforms.RandomAffine(degrees=10), + transforms.RandomCrop(size=(24, 24), pad_if_needed=True), + transforms.RandomPerspective(p=1.0), + transforms.RandomErasing(p=1.0), + transforms.ScaleJitter(target_size=(24, 24)), + ] + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_multiprocessing_worker_uses_global_rng(self, transform): + """In multiprocessing workers, thread_safe_generator() returns None, + so transforms use the default global (per-process) RNG. Mimic two + workers with different seeds and verify they produce different results.""" + image = make_image((32, 32)) + + with mock.patch("torch.thread_safe_generator", return_value=None): + torch.manual_seed(0) + result_worker0 = transform(image) + + with mock.patch("torch.thread_safe_generator", return_value=None): + torch.manual_seed(1) + result_worker1 = transform(image) + + assert not torch.equal(result_worker0, result_worker1) + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_thread_worker_uses_thread_local_generator(self, transform): + """In thread workers, thread_safe_generator() returns a thread-local + Generator. Mimic two workers with differently seeded generators + and verify they produce different results.""" + image = make_image((32, 32)) + + g0 = torch.Generator() + g0.manual_seed(0) + with mock.patch("torch.thread_safe_generator", return_value=g0): + result_worker0 = transform(image) + + g1 = torch.Generator() + g1.manual_seed(1) + with mock.patch("torch.thread_safe_generator", return_value=g1): + result_worker1 = transform(image) + + assert not torch.equal(result_worker0, result_worker1) From 2ff6b5047fb1ee7a01b5d7cb9ee44b99cb8eb7c6 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Wed, 25 Feb 2026 14:46:41 -0800 Subject: [PATCH 4/6] run ufmt, update mp unit test --- test/test_transforms_v2.py | 38 ++++++++++++-------------- torchvision/transforms/v2/_geometry.py | 1 - 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8a0c8e369fb..d55d0a60437 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -15,11 +15,9 @@ import numpy as np import PIL.Image import pytest - import torch import torchvision.ops import torchvision.transforms.v2 as transforms - from common_utils import ( assert_equal, cache, @@ -38,14 +36,12 @@ needs_cuda, set_rng_seed, ) - from torch import nn from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors from torchvision.ops.boxes import box_iou - from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping, to_pil_image from torchvision.transforms.v2 import functional as F @@ -61,7 +57,6 @@ ) from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal - # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -8031,8 +8026,6 @@ class TestThreadSafeGenerator: """ TRANSFORMS = [ - transforms.RandomHorizontalFlip(p=0.5), - transforms.RandomVerticalFlip(p=0.5), transforms.RandomResizedCrop(size=(24, 24)), transforms.RandomRotation(degrees=10), transforms.RandomAffine(degrees=10), @@ -8042,22 +8035,27 @@ class TestThreadSafeGenerator: transforms.ScaleJitter(target_size=(24, 24)), ] - @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) - def test_multiprocessing_worker_uses_global_rng(self, transform): - """In multiprocessing workers, thread_safe_generator() returns None, - so transforms use the default global (per-process) RNG. Mimic two - workers with different seeds and verify they produce different results.""" - image = make_image((32, 32)) + class TransformDataset(torch.utils.data.Dataset): + def __init__(self, size, transform): + self.size = size + self.transform = transform + self.image = make_image((32, 32)) - with mock.patch("torch.thread_safe_generator", return_value=None): - torch.manual_seed(0) - result_worker0 = transform(image) + def __getitem__(self, idx): + return self.transform(self.image) - with mock.patch("torch.thread_safe_generator", return_value=None): - torch.manual_seed(1) - result_worker1 = transform(image) + def __len__(self): + return self.size - assert not torch.equal(result_worker0, result_worker1) + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_multiprocessing_workers(self, transform): + """With multiprocessing DataLoader workers, thread_safe_generator() + returns None and transforms use the per-process global RNG. + Each worker gets a different seed, so results should differ.""" + dataset = self.TransformDataset(size=2, transform=transform) + dl = DataLoader(dataset, batch_size=1, num_workers=2) + batch0, batch1 = list(dl) + assert not torch.equal(batch0, batch1) @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) def test_thread_worker_uses_thread_local_generator(self, transform): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 76a2f5f1eb4..2f921a65e6d 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -6,7 +6,6 @@ import PIL.Image import torch - from torchvision import transforms as _transforms, tv_tensors from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs From d935a9930fffef7f630ee1b8104317353034d46d Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Mon, 16 Mar 2026 15:21:08 -0700 Subject: [PATCH 5/6] cover more random transforms --- test/test_transforms_v2.py | 62 +++++++++++++++++++++- torchvision/transforms/v2/_augment.py | 18 ++++--- torchvision/transforms/v2/_auto_augment.py | 48 ++++++++++------- torchvision/transforms/v2/_color.py | 27 ++++++---- torchvision/transforms/v2/_container.py | 9 ++-- torchvision/transforms/v2/_misc.py | 3 +- 6 files changed, 126 insertions(+), 41 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index d55d0a60437..34e68c20461 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -8033,6 +8033,19 @@ class TestThreadSafeGenerator: transforms.RandomPerspective(p=1.0), transforms.RandomErasing(p=1.0), transforms.ScaleJitter(target_size=(24, 24)), + transforms.RandomZoomOut(), + transforms.ElasticTransform(), + transforms.RandomShortestSize(min_size=(20, 24)), + transforms.RandomResize(min_size=20, max_size=28), + transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), + transforms.RandomChannelPermutation(), + transforms.RandomPhotometricDistort(), + transforms.AutoAugment(), + transforms.RandAugment(), + transforms.TrivialAugmentWide(), + transforms.AugMix(), + transforms.JPEG(quality=(1, 100)), + transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), ] class TransformDataset(torch.utils.data.Dataset): @@ -8070,8 +8083,55 @@ def test_thread_worker_uses_thread_local_generator(self, transform): result_worker0 = transform(image) g1 = torch.Generator() - g1.manual_seed(1) + g1.manual_seed(5) with mock.patch("torch.thread_safe_generator", return_value=g1): result_worker1 = transform(image) assert not torch.equal(result_worker0, result_worker1) + + def test_thread_generator_random_iou_crop(self): + """RandomIoUCrop requires bounding boxes, so test it separately.""" + image = make_image((32, 32)) + bboxes = make_bounding_boxes(canvas_size=(32, 32), format="XYXY", num_boxes=3) + + transform = transforms.RandomIoUCrop() + + results = [] + for seed in (0, 1): + g = torch.Generator() + g.manual_seed(seed) + with mock.patch("torch.thread_safe_generator", return_value=g): + result = transform(image, bboxes) + results.append(result) + + # The image output should differ between different seeds + assert not torch.equal(results[0][0], results[1][0]) + + # Reproducibility test list: includes flips which are excluded from + # the divergence tests above. + ALL_TRANSFORMS = TRANSFORMS + [ + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + ] + + @pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=lambda t: type(t).__name__) + def test_thread_generator_reproducibility(self, transform): + """Verify transforms use the provided generator, not the global RNG. + Same seeded generator should produce identical results even when + the global RNG state changes between calls.""" + image = make_image((32, 32)) + + g1 = torch.Generator() + g1.manual_seed(42) + with mock.patch("torch.thread_safe_generator", return_value=g1): + result1 = transform(image) + + # Advance global RNG so it's in a different state + torch.rand(100) + + g2 = torch.Generator() + g2.manual_seed(42) + with mock.patch("torch.thread_safe_generator", return_value=g2): + result2 = transform(image) + + torch.testing.assert_close(result1, result2) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 4b1d15a91e1..b258d3cc845 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -106,12 +106,14 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: area = img_h * img_w log_ratio = self._log_ratio + g = torch.thread_safe_generator() for _ in range(10): - erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=g, ) ).item() @@ -121,12 +123,12 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: continue if self.value is None: - v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_(generator=g) else: v = torch.tensor(self.value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = torch.randint(0, img_h - h + 1, size=(1,), generator=g).item() + j = torch.randint(0, img_w - w + 1, size=(1,), generator=g).item() break else: i, j, h, w, v = 0, 0, img_h, img_w, None @@ -298,8 +300,9 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: H, W = query_size(flat_inputs) - r_x = torch.randint(W, size=(1,)) - r_y = torch.randint(H, size=(1,)) + g = torch.thread_safe_generator() + r_x = torch.randint(W, size=(1,), generator=g) + r_y = torch.randint(H, size=(1,), generator=g) r = 0.5 * math.sqrt(1.0 - lam) r_w_half = int(r * W) @@ -365,7 +368,8 @@ def __init__(self, quality: Union[int, Sequence[int]]): self.quality = quality def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item() + g = torch.thread_safe_generator() + quality = torch.randint(self.quality[0], self.quality[1] + 1, (), generator=g).item() return dict(quality=quality) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 52707af1f2e..714a18fc823 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -38,9 +38,11 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]: return params - def _get_random_item(self, dct: dict[str, tuple[Callable, bool]]) -> tuple[str, tuple[Callable, bool]]: + def _get_random_item( + self, dct: dict[str, tuple[Callable, bool]], generator: torch.Generator = None + ) -> tuple[str, tuple[Callable, bool]]: keys = tuple(dct.keys()) - key = keys[int(torch.randint(len(keys), ()))] + key = keys[int(torch.randint(len(keys), (), generator=generator))] return key, dct[key] def _flatten_and_extract_image_or_video( @@ -327,10 +329,11 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] - policy = self._policies[int(torch.randint(len(self._policies), ()))] + g = torch.thread_safe_generator() + policy = self._policies[int(torch.randint(len(self._policies), (), generator=g))] for transform_id, probability, magnitude_idx in policy: - if not torch.rand(()) <= probability: + if not torch.rand((), generator=g) <= probability: continue magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] @@ -338,7 +341,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes = magnitudes_fn(10, height, width) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -419,12 +422,13 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] + g = torch.thread_safe_generator() for _ in range(self.num_ops): - transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g) magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: magnitude = float(magnitudes[self.magnitude]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -488,12 +492,13 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] - transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + g = torch.thread_safe_generator() + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g) magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=g))]) + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -572,9 +577,9 @@ def __init__( self.alpha = alpha self.all_ops = all_ops - def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + def _sample_dirichlet(self, params: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor: # Must be on a separate method so that we can overwrite it in tests. - return torch._sample_dirichlet(params) + return torch._sample_dirichlet(params, generator) def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs) @@ -595,26 +600,33 @@ def forward(self, *inputs: Any) -> Any: # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of # augmented image or video. + g = torch.thread_safe_generator() m = self._sample_dirichlet( - torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1), + generator=g, ) # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos. combined_weights = self._sample_dirichlet( - torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1), + generator=g, ) * m[:, 1].reshape([batch_dims[0], -1]) mix = m[:, 0].reshape(batch_dims) * batch for i in range(self.mixture_width): aug = batch - depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + depth = ( + self.chain_depth + if self.chain_depth > 0 + else int(torch.randint(low=1, high=4, size=(1,), generator=g).item()) + ) for _ in range(depth): - transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space, generator=g) magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=g))]) + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index e249dbc0b1f..eff19da3a9c 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -140,16 +140,17 @@ def _check_input( return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod - def _generate_value(left: float, right: float) -> float: - return float(torch.empty(1).uniform_(left, right).item()) + def _generate_value(left: float, right: float, generator: torch.Generator = None) -> float: + return float(torch.empty(1).uniform_(left, right, generator=generator).item()) def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - fn_idx = torch.randperm(4) + g = torch.thread_safe_generator() + fn_idx = torch.randperm(4, generator=g) - b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) - c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) - s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) - h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1], g) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1], g) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1], g) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], g) return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) @@ -176,7 +177,8 @@ class RandomChannelPermutation(Transform): def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) - return dict(permutation=torch.randperm(num_channels)) + g = torch.thread_safe_generator() + return dict(permutation=torch.randperm(num_channels, generator=g)) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.permute_channels, inpt, params["permutation"]) @@ -223,8 +225,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) + g = torch.thread_safe_generator() params: dict[str, Any] = { - key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None + key: ColorJitter._generate_value(range[0], range[1], g) if torch.rand(1, generator=g) < self.p else None for key, range in [ ("brightness_factor", self.brightness), ("contrast_factor", self.contrast), @@ -232,8 +235,10 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: ("hue_factor", self.hue), ] } - params["contrast_before"] = bool(torch.rand(()) < 0.5) - params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None + params["contrast_before"] = bool(torch.rand((), generator=g) < 0.5) + params["channel_permutation"] = ( + torch.randperm(num_channels, generator=g) if torch.rand(1, generator=g) < self.p else None + ) return params def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 95ec25a22f8..dee01f4802e 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -101,7 +101,8 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]: def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - if torch.rand(1) >= self.p: + g = torch.thread_safe_generator() + if torch.rand(1, generator=g) >= self.p: return inputs if needs_unpacking else inputs[0] for transform in self.transforms: @@ -149,7 +150,8 @@ def __init__( self.p = [prob / total for prob in p] def forward(self, *inputs: Any) -> Any: - idx = int(torch.multinomial(torch.tensor(self.p), 1)) + g = torch.thread_safe_generator() + idx = int(torch.multinomial(torch.tensor(self.p), 1, generator=g)) transform = self.transforms[idx] return transform(*inputs) @@ -173,7 +175,8 @@ def __init__(self, transforms: Sequence[Callable]) -> None: def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - for idx in torch.randperm(len(self.transforms)): + g = torch.thread_safe_generator() + for idx in torch.randperm(len(self.transforms), generator=g): transform = self.transforms[idx] outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..66f4f8e18cf 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -207,7 +207,8 @@ def __init__( raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}") def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + g = torch.thread_safe_generator() + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1], generator=g).item() return dict(sigma=[sigma, sigma]) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: From 4b142d52241eeb9729f8447a6397e99143f18f3a Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Fri, 24 Apr 2026 16:15:05 -0700 Subject: [PATCH 6/6] linter --- torchvision/transforms/v2/_auto_augment.py | 4 ++-- torchvision/transforms/v2/_color.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 714a18fc823..cff64f515d5 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -39,7 +39,7 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]: return params def _get_random_item( - self, dct: dict[str, tuple[Callable, bool]], generator: torch.Generator = None + self, dct: dict[str, tuple[Callable, bool]], generator: Optional[torch.Generator] = None ) -> tuple[str, tuple[Callable, bool]]: keys = tuple(dct.keys()) key = keys[int(torch.randint(len(keys), (), generator=generator))] @@ -577,7 +577,7 @@ def __init__( self.alpha = alpha self.all_ops = all_ops - def _sample_dirichlet(self, params: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor: + def _sample_dirichlet(self, params: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor: # Must be on a separate method so that we can overwrite it in tests. return torch._sample_dirichlet(params, generator) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index eff19da3a9c..607199e441e 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -140,7 +140,7 @@ def _check_input( return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod - def _generate_value(left: float, right: float, generator: torch.Generator = None) -> float: + def _generate_value(left: float, right: float, generator: Optional[torch.Generator] = None) -> float: return float(torch.empty(1).uniform_(left, right, generator=generator).item()) def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: