diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index eac52dafc17..628cefb6c11 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -561,6 +561,29 @@ def test_convert_image_dtype_save_load(tmpdir):
     _test_fn_save_load(fn, tmpdir)
 
 
+@pytest.mark.parametrize("device", cpu_and_cuda())
+@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("out_dtype", [torch.uint8, torch.int8, torch.int16, torch.int32])
+def test_convert_image_dtype_half_precision(device, in_dtype, out_dtype):
+    # Half-precision inputs must map the [0, 1] endpoints onto the full integer range
+    # (1.0 must not wrap around to 0 or to a negative value).
+    image = torch.tensor([0.0, 0.5, 1.0], dtype=in_dtype, device=device).reshape(1, 1, 3)
+    result = F.convert_image_dtype(image, out_dtype)
+    max_val = torch.iinfo(out_dtype).max
+    assert result.dtype == out_dtype, f"expected {out_dtype}, got {result.dtype}"
+    assert result[0, 0, 0] == 0, f"0.0 should map to 0, got {result[0, 0, 0]}"
+    assert result[0, 0, 2] == max_val, f"1.0 should map to {max_val}, got {result[0, 0, 2]}"
+
+
+@pytest.mark.parametrize("device", cpu_and_cuda())
+@pytest.mark.parametrize("in_dtype", [torch.int32, torch.int64])
+def test_convert_image_dtype_int_to_float16_raises(device, in_dtype):
+    # Converting int32/int64 images to float16 is rejected as unsafe with a RuntimeError.
+    image = torch.tensor([0, 1, 2], dtype=in_dtype, device=device).reshape(1, 1, 3)
+    with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
+        F.convert_image_dtype(image, torch.float16)
+
+
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
 @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 1a9830450d5..ab7240d59b5 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -85,6 +85,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
         # when float is exactly 1.0.
         # `max + 1 - epsilon` provides more evenly distributed mapping of
         # ranges of floats to ints.
+        # Upcast half inputs: 255.999 rounds to 256.0 in float16 (wraps to 0 as uint8); int32 needs float64.
+        if image.dtype == torch.float16 or image.dtype == torch.bfloat16:
+            image = image.to(torch.float64 if dtype == torch.int32 else torch.float32)
         eps = 1e-3
         max_val = float(_max_value(dtype))
         result = image.mul(max_val + 1.0 - eps)
@@ -95,6 +98,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
         # int to float
         # TODO: replace with dtype.is_floating_point when torchscript supports it
         if torch.tensor(0, dtype=dtype).is_floating_point():
+            if dtype == torch.float16 and image.dtype in (torch.int32, torch.int64):
+                msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
+                raise RuntimeError(msg)
             image = image.to(dtype)
             return image / input_max
 