From 7d78ae8b710abb2efca860ed7f736adc1de31a7f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Mar 2026 01:43:07 -0700 Subject: [PATCH] [torchvision] Move video utilities to internal fb/io location (#9433) Summary: This diff moves the torchvision.io video utilities (read_video, write_video, read_video_timestamps) from the OSS location to the internal fb/io location. Changes: - Created fb/io/video.py with the full implementation of video functions - Updated fb/io/__init__.py to export video functions and internal helpers - Updated torchvision/io/__init__.py to import video functions from fb/io - Made torchvision/io/video.py a stub that re-exports from fb/io/video.py - Removed video functions from torchvision.io.__all__ (for OSS) - The (rare) TorchVision stuff that depended on video decoding utils has been migrated to TorchCodec (that's for TV only, not for the rest of the internal users). This change ensures: - Internal users can continue using 'from torchvision.io import read_video' - Internal users importing from torchvision.io.video directly still work - No BUCK changes are required - OSS/GitHub users will no longer have access to these deprecated video APIs Eventually we'll want to migrate all these internal users to TorchCodec, but that's for later. Note: in D95081771 I bluntly removed all the torchvision decoder stuff from both GH and fbcode, but that ended up leading to tons of failed tests (that weren't triggered on the diff!). This new diff is a softer version of that. 
Differential Revision: D95933713 Pulled By: NicolasHug --- .github/scripts/setup-env.sh | 5 - .github/workflows/docs.yml | 10 +- gallery/others/plot_optical_flow.py | 19 +- test/common_utils.py | 11 +- test/datasets_utils.py | 34 +- test/test_datasets_samplers.py | 15 +- test/test_datasets_video_utils.py | 19 +- test/test_io.py | 247 --------------- torchvision/datasets/video_utils.py | 47 ++- torchvision/io/__init__.py | 13 +- torchvision/io/video.py | 476 ++-------------------------- 11 files changed, 141 insertions(+), 755 deletions(-) diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index 929869caadd..0ad0f0cc286 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -34,11 +34,6 @@ conda activate ci conda install --quiet --yes libjpeg-turbo -c pytorch pip install --progress-bar=off --upgrade setuptools==72.1.0 -# See https://github.com/pytorch/vision/issues/6790 -if [[ "${PYTHON_VERSION}" != "3.11" ]]; then - pip install --progress-bar=off av!=10.0.0 -fi - echo '::endgroup::' if [[ "${OS_TYPE}" == windows && "${GPU_ARCH_TYPE}" == cuda ]]; then diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8b341622181..c030b2f7493 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -34,12 +34,12 @@ jobs: CONDA_PATH=$(which conda) eval "$(${CONDA_PATH} shell.bash hook)" conda activate ci - # FIXME: not sure why we need this. `ldd torchvision/video_reader.so` shows that it - # already links against the one pulled from conda. However, at runtime it pulls from - # /lib64 - # Should we maybe always do this in `./.github/scripts/setup-env.sh` so that we don't - # have to pay attention in all other workflows? 
+ + echo '::group::Install TorchCodec and ffmpeg' + conda install --quiet --yes ffmpeg + pip install --progress-bar=off --pre torchcodec --index-url="https://download.pytorch.org/whl/nightly/cpu" export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" + echo '::endgroup::' cd docs diff --git a/gallery/others/plot_optical_flow.py b/gallery/others/plot_optical_flow.py index 6296c8e667e..a80804e6db5 100644 --- a/gallery/others/plot_optical_flow.py +++ b/gallery/others/plot_optical_flow.py @@ -47,11 +47,10 @@ def plot(imgs, **imshow_kwargs): plt.tight_layout() # %% -# Reading Videos Using Torchvision +# Reading Videos Using TorchCodec # -------------------------------- -# We will first read a video using :func:`~torchvision.io.read_video`. -# Alternatively one can use the new :class:`~torchvision.io.VideoReader` API (if -# torchvision is built from source). +# We will first read a video using +# `TorchCodec `_. # The video we will use here is free of use from `pexels.com # `_, # credits go to `Pavel Danilyuk `_. @@ -67,16 +66,16 @@ def plot(imgs, **imshow_kwargs): _ = urlretrieve(video_url, video_path) # %% -# :func:`~torchvision.io.read_video` returns the video frames, audio frames and -# the metadata associated with the video. In our case, we only need the video -# frames. +# We use :class:`~torchcodec.decoders.VideoDecoder` to decode the video frames. +# TorchCodec returns frames in NCHW format by default. # # Here we will just make 2 predictions between 2 pre-selected pairs of frames, # namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a # single model input. 
-from torchvision.io import read_video -frames, _, _ = read_video(str(video_path), output_format="TCHW") +from torchcodec.decoders import VideoDecoder +decoder = VideoDecoder(str(video_path)) +frames = decoder[:] img1_batch = torch.stack([frames[100], frames[150]]) img2_batch = torch.stack([frames[101], frames[151]]) @@ -85,7 +84,7 @@ def plot(imgs, **imshow_kwargs): # %% # The RAFT model accepts RGB images. We first get the frames from -# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions +# the decoder and resize them to ensure their dimensions # are divisible by 8. Note that we explicitly use ``antialias=False``, because # this is how those models were trained. Then we use the transforms bundled into # the weights in order to preprocess the input and rescale its values to the diff --git a/test/common_utils.py b/test/common_utils.py index 24ebb1376c3..1459f52cbbe 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -18,7 +18,7 @@ import torch.testing from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair -from torchvision import io, tv_tensors +from torchvision import tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor @@ -166,6 +166,8 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): + from datasets_utils import create_video_file + names = [] for i in range(num_videos): if sizes is None: @@ -176,10 +178,9 @@ def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): f = 5 else: f = fps[i] - data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) - name = os.path.join(tmpdir, f"{i}.mp4") - names.append(name) 
- io.write_video(name, data, fps=f) + name = f"{i}.mp4" + create_video_file(tmpdir, name, size=(size, 3, 300, 400), fps=f) + names.append(os.path.join(tmpdir, name)) return names diff --git a/test/datasets_utils.py b/test/datasets_utils.py index cbfb26b6c6b..46d82f5e784 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -66,7 +66,7 @@ class LazyImporter: """ MODULES = ( - "av", + "torchcodec", "lmdb", "pycocotools", "requests", @@ -669,17 +669,24 @@ class VideoDatasetTestCase(DatasetTestCase): - Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as well as an integer label. - - Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``). + - Overwrites the 'REQUIRED_PACKAGES' class attribute to require TorchCodec (``torchcodec``). + - Skips on non-Linux platforms and CUDA-only environments. - Adds the 'DEFAULT_FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'inject_fake_data()' and it is the last parameter without a default value in the dataset constructor, the value of the 'DEFAULT_FRAMES_PER_CLIP' class attribute is appended to the output. """ FEATURE_TYPES = (torch.Tensor, torch.Tensor, int) - REQUIRED_PACKAGES = ("av",) + REQUIRED_PACKAGES = ("torchcodec",) FRAMES_PER_CLIP = 1 + @classmethod + def setUpClass(cls): + if platform.system() != "Linux": + raise unittest.SkipTest("Video dataset tests are only supported on Linux.") + super().setUpClass() + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.dataset_args = self._set_default_frames_per_clip(self.dataset_args) @@ -864,13 +871,12 @@ def shape_test_for_stereo( assert dw == mw -@requires_lazy_imports("av") +@requires_lazy_imports("torchcodec") def create_video_file( root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = (1, 3, 10, 10), fps: float = 25, - **kwargs: Any, ) -> pathlib.Path: """Create a video file from random data. 
@@ -881,14 +887,15 @@ def create_video_file( ``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width. If not provided, ``num_frames=1`` and ``num_channels=3`` are assumed. fps (float): Frame rate in frames per second. - kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`. Returns: - pathlib.Path: Path to the created image file. + pathlib.Path: Path to the created video file. Raises: - UsageError: If PyAV is not available. + UsageError: If TorchCodec is not available. """ + from torchcodec.encoders import VideoEncoder + if isinstance(size, int): size = (size, size) if len(size) == 2: @@ -902,11 +909,14 @@ def create_video_file( video = create_image_or_video_tensor(size) file = pathlib.Path(root) / name - torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs) + + encoder = VideoEncoder(video, frame_rate=fps) + encoder.to_file(str(file)) + return file -@requires_lazy_imports("av") +@requires_lazy_imports("torchcodec") def create_video_folder( root: Union[str, pathlib.Path], name: Union[str, pathlib.Path], @@ -933,7 +943,7 @@ def create_video_folder( List[pathlib.Path]: Paths to all created video files. Raises: - UsageError: If PyAV is not available. + UsageError: If TorchCodec is not available. .. seealso:: @@ -944,7 +954,7 @@ def create_video_folder( def size(idx): num_frames = 1 num_channels = 3 - # The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and + # The 'libx264' video codec requires the height and # width of the video to be divisible by 2. 
height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist() return (num_frames, num_channels, height, width) diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 9e3826b2c13..222890da20c 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,12 +1,23 @@ +import sys + import pytest import torch from common_utils import assert_equal, get_list_of_videos -from torchvision import io from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.datasets.video_utils import VideoClips +try: + import torchcodec # noqa: F401 + + _torchcodec_available = True +except ImportError: + _torchcodec_available = False + -@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") +@pytest.mark.skipif( + not (_torchcodec_available and sys.platform == "linux"), + reason="this test requires torchcodec (linux only)", +) class TestDatasetsSamplers: def test_random_clip_sampler(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index 51330911e50..6d066a382b3 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,9 +1,22 @@ +import sys + import pytest import torch from common_utils import assert_equal, get_list_of_videos -from torchvision import io from torchvision.datasets.video_utils import unfold, VideoClips +try: + import torchcodec # noqa: F401 + + _torchcodec_available = True +except ImportError: + _torchcodec_available = False + +_requires_torchcodec = pytest.mark.skipif( + not (_torchcodec_available and sys.platform == "linux"), + reason="this test requires torchcodec (linux only)", +) + class TestVideo: def test_unfold(self): @@ -31,7 +44,7 @@ def test_unfold(self): ) assert_equal(r, expected) - @pytest.mark.skipif(not io.video._av_available(), reason="this test requires 
av") + @_requires_torchcodec def test_video_clips(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3) video_clips = VideoClips(video_list, 5, 5, num_workers=2) @@ -55,7 +68,7 @@ def test_video_clips(self, tmpdir): assert video_idx == v_idx assert clip_idx == c_idx - @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") + @_requires_torchcodec def test_video_clips_custom_fps(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) num_frames = 4 diff --git a/test/test_io.py b/test/test_io.py index 84d30ee3297..5194421105a 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,251 +1,4 @@ -import contextlib -import os -import sys -import tempfile - import pytest -import torch -import torchvision.io as io -from common_utils import assert_equal, cpu_and_cuda - - -try: - import av - - # Do a version test too - io.video._check_av_available() -except ImportError: - av = None - - -VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") - - -def _create_video_frames(num_frames, height, width): - y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij") - data = [] - for i in range(num_frames): - xc = float(i) / num_frames - yc = 1 - float(i) / (2 * num_frames) - d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 - data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) - - return torch.stack(data, 0) - - -@contextlib.contextmanager -def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None): - if lossless: - if video_codec is not None: - raise ValueError("video_codec can't be specified together with lossless") - if options is not None: - raise ValueError("options can't be specified together with lossless") - video_codec = "libx264rgb" - options = {"crf": "0"} - - if video_codec is None: - video_codec = "libx264" - if options is None: - options = {} - - data = 
_create_video_frames(num_frames, height, width) - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.close() - io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) - yield f.name, data - os.unlink(f.name) - - -@pytest.mark.skipif(av is None, reason="PyAV unavailable") -class TestVideo: - # compression adds artifacts, thus we add a tolerance of - # 6 in 0-255 range - TOLERANCE = 6 - - def test_write_read_video(self): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name) - assert_equal(data, lv) - assert info["video_fps"] == 5 - - def test_read_timestamps(self): - with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - # note: not all formats/codecs provide accurate information for computing the - # timestamps. For the format that we use here, this information is available, - # so we use it as a baseline - with av.open(f_name) as container: - stream = container.streams[0] - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step for i in range(num_frames)] - - assert pts == expected_pts - - @pytest.mark.parametrize("start", range(5)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video(self, start, offset): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv) - - lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) - assert len(lv) == 4 - assert_equal(data[4:8], lv) - - @pytest.mark.parametrize("start", range(0, 80, 20)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video_bframes(self, start, offset): - # do not use 
lossless encoding, to test the presence of B-frames - options = {"bframes": "16", "keyint": "10", "min-keyint": "4"} - with temp_video(100, 300, 300, 5, options=options) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE) - - lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) - assert len(lv) == 4 - assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE) - - def test_read_packed_b_frames_divx_file(self): - name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi" - f_name = os.path.join(VIDEO_DIR, name) - pts, fps = io.read_video_timestamps(f_name) - - assert pts == sorted(pts) - assert fps == 30 - - def test_read_timestamps_from_packet(self): - with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - # note: not all formats/codecs provide accurate information for computing the - # timestamps. 
For the format that we use here, this information is available, - # so we use it as a baseline - with av.open(f_name) as container: - stream = container.streams[0] - # make sure we went through the optimized codepath - assert b"Lavc" in stream.codec_context.extradata - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step for i in range(num_frames)] - - assert pts == expected_pts - - def test_read_video_pts_unit_sec(self): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name, pts_unit="sec") - - assert_equal(data, lv) - assert info["video_fps"] == 5 - assert info == {"video_fps": 5} - - def test_read_timestamps_pts_unit_sec(self): - with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - - with av.open(f_name) as container: - stream = container.streams[0] - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)] - - assert pts == expected_pts - - @pytest.mark.parametrize("start", range(5)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video_pts_unit_sec(self, start, offset): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec") - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv) - - with av.open(f_name) as container: - stream = container.streams[0] - lv, _, _ = io.read_video( - f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec" - ) - assert len(lv) == 4 - 
assert_equal(data[4:8], lv) - - def test_read_video_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.write(b"This is not an mpg4 file") - video, audio, info = io.read_video(f.name) - assert isinstance(video, torch.Tensor) - assert isinstance(audio, torch.Tensor) - assert video.numel() == 0 - assert audio.numel() == 0 - assert info == {} - - def test_read_video_timestamps_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.write(b"This is not an mpg4 file") - video_pts, video_fps = io.read_video_timestamps(f.name) - assert video_pts == [] - assert video_fps is None - - @pytest.mark.skip(reason="Temporarily disabled due to new pyav") - def test_read_video_partially_corrupted_file(self): - with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): - with open(f_name, "r+b") as f: - size = os.path.getsize(f_name) - bytes_to_overwrite = size // 10 - # seek to the middle of the file - f.seek(5 * bytes_to_overwrite) - # corrupt 10% of the file from the middle - f.write(b"\xff" * bytes_to_overwrite) - # this exercises the container.decode assertion check - video, audio, info = io.read_video(f.name, pts_unit="sec") - # check that size is not equal to 5, but 3 - assert len(video) == 3 - # but the valid decoded content is still correct - assert_equal(video[:3], data[:3]) - # and the last few frames are wrong - with pytest.raises(AssertionError): - assert_equal(video, data) - - @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_write_video_with_audio(self, device, tmpdir): - f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") - video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") - - out_f_name = os.path.join(tmpdir, "testing.mp4") - io.video.write_video( - out_f_name, - video_tensor.to(device), - round(info["video_fps"]), - video_codec="libx264rgb", - options={"crf": "0"}, - 
audio_array=audio_tensor.to(device), - audio_fps=info["audio_fps"], - audio_codec="aac", - ) - - out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec") - - assert info["video_fps"] == out_info["video_fps"] - assert_equal(video_tensor, out_video_tensor) - - audio_stream = av.open(f_name).streams.audio[0] - out_audio_stream = av.open(out_f_name).streams.audio[0] - - assert info["audio_fps"] == out_info["audio_fps"] - assert audio_stream.rate == out_audio_stream.rate - assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames - assert audio_stream.frame_size == out_audio_stream.frame_size - - # TODO add tests for audio if __name__ == "__main__": diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 93257a1c482..a95737a571d 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -4,13 +4,24 @@ from typing import Any, Optional, TypeVar, Union import torch -from torchvision.io import read_video, read_video_timestamps from .utils import tqdm T = TypeVar("T") +def _get_torchcodec(): + try: + import torchcodec # type: ignore[import-not-found] + except ImportError: + raise ImportError( + "Video decoding capabilities were removed from torchvision and migrated " + "to TorchCodec. 
Please install TorchCodec following instructions at " + "https://github.com/pytorch/torchcodec#installing-torchcodec" + ) + return torchcodec + + def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor: """ similar to tensor.unfold, but with the dilation @@ -47,7 +58,11 @@ def __len__(self) -> int: return len(self.video_paths) def __getitem__(self, idx: int) -> tuple[list[int], Optional[float]]: - return read_video_timestamps(self.video_paths[idx]) + torchcodec = _get_torchcodec() + decoder = torchcodec.decoders.VideoDecoder(self.video_paths[idx]) + num_frames = decoder.metadata.num_frames + fps = decoder.metadata.average_fps + return list(range(num_frames)), fps def _collate_fn(x: T) -> T: @@ -292,9 +307,27 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] video_path = self.video_paths[video_idx] clip_pts = self.clips[video_idx][clip_idx] - start_pts = int(clip_pts[0].item()) - end_pts = int(clip_pts[-1].item()) - video, audio, info = read_video(video_path, start_pts, end_pts) + start_idx = int(clip_pts[0].item()) + end_idx = int(clip_pts[-1].item()) + + torchcodec = _get_torchcodec() + + dimension_order = "NHWC" if self.output_format == "THWC" else "NCHW" + decoder = torchcodec.decoders.VideoDecoder(video_path, dimension_order=dimension_order) + video = decoder.get_frames_at(indices=list(range(start_idx, end_idx + 1))).data + + # Audio via TorchCodec + fps = decoder.metadata.average_fps + start_sec = start_idx / fps + end_sec = (end_idx + 1) / fps + try: + audio_decoder = torchcodec.decoders.AudioDecoder(video_path) + audio_samples = audio_decoder.get_samples_played_in_range(start_seconds=start_sec, stop_seconds=end_sec) + audio = audio_samples.data + except Exception: + audio = torch.empty((1, 0), dtype=torch.float32) + + info = {"video_fps": fps} if self.frame_rate is not None: resampling_idx = self.resampling_idxs[video_idx][clip_idx] @@ -304,10 +337,6 @@ def get_clip(self, idx: int) -> 
tuple[torch.Tensor, torch.Tensor, dict[str, Any] info["video_fps"] = self.frame_rate assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" - if self.output_format == "TCHW": - # [T,H,W,C] --> [T,C,H,W] - video = video.permute(0, 3, 1, 2) - return video, audio, info, video_idx def __getstate__(self) -> dict[str, Any]: diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index a486b0275e1..02e28e107c6 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -18,6 +18,15 @@ except ImportError: pass +try: + from pytorch.vision.fb.io.video import ( # type: ignore[import-not-found] + read_video, + read_video_timestamps, + write_video, + ) +except ImportError: + pass + from .image import ( decode_avif, decode_gif, @@ -35,13 +44,9 @@ write_jpeg, write_png, ) -from .video import read_video, read_video_timestamps, write_video __all__ = [ - "write_video", - "read_video", - "read_video_timestamps", "ImageReadMode", "decode_image", "decode_jpeg", diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 5331b764d27..87fe36f2caa 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,453 +1,23 @@ -import gc -import math -import re -import warnings -from fractions import Fraction -from typing import Any, Optional, Union - -import numpy as np -import torch - -from ..utils import _log_api_usage_once -from ._video_deprecation_warning import _raise_video_deprecation_warning - -try: - import av - - av.logging.set_level(av.logging.ERROR) - if not hasattr(av.video.frame.VideoFrame, "pict_type"): - av = ImportError( - """\ -Your version of PyAV is too old for the necessary video operations in torchvision. -If you are on Python 3.5, you will have to build from source (the conda-forge -packages are not up-to-date). See -https://github.com/mikeboers/PyAV#installation for instructions on how to -install PyAV on your system. 
-""" - ) - try: - FFmpegError = av.FFmpegError # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst - except AttributeError: - FFmpegError = av.AVError -except ImportError: - av = ImportError( - """\ -PyAV is not installed, and is necessary for the video operations in torchvision. -See https://github.com/mikeboers/PyAV#installation for instructions on how to -install PyAV on your system. -""" - ) - - -def _check_av_available() -> None: - if isinstance(av, Exception): - raise av - - -def _av_available() -> bool: - return not isinstance(av, Exception) - - -# PyAV has some reference cycles -_CALLED_TIMES = 0 -_GC_COLLECTION_INTERVAL = 10 - - -def write_video( - filename: str, - video_array: torch.Tensor, - fps: float, - video_codec: str = "libx264", - options: Optional[dict[str, Any]] = None, - audio_array: Optional[torch.Tensor] = None, - audio_fps: Optional[float] = None, - audio_codec: Optional[str] = None, - audio_options: Optional[dict[str, Any]] = None, -) -> None: - """ - [DEPRECATED] Writes a 4d tensor in [T, H, W, C] format in a video file. - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec `__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - This function relies on PyAV (therefore, ultimately FFmpeg) to encode - videos, you can get more fine-grained control by referring to the other - options at your disposal within `the FFMpeg wiki - `_. - - Args: - filename (str): path where the video will be saved - video_array (Tensor[T, H, W, C]): tensor containing the individual frames, - as a uint8 tensor in [T, H, W, C] format - fps (Number): video frames per second - video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc. - options (Dict): dictionary containing options to be passed into the PyAV video stream. 
- The list of options is codec-dependent and can all - be found from `the FFMpeg wiki `_. - audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels - and N is the number of samples - audio_fps (Number): audio sample rate, typically 44100 or 48000 - audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc. - audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream. - The list of options is codec-dependent and can all - be found from `the FFMpeg wiki `_. - - Examples:: - >>> # Creating libx264 video with CRF 17, for visually lossless footage: - >>> - >>> from torchvision.io import write_video - >>> # 1000 frames of 100x100, 3-channel image. - >>> vid = torch.randn(1000, 100, 100, 3, dtype = torch.uint8) - >>> write_video("video.mp4", options = {"crf": "17"}) - - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_video) - _check_av_available() - video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True) - - # PyAV does not support floating point numbers with decimal point - # and will throw OverflowException in case this is not the case - if isinstance(fps, float): - fps = int(np.round(fps)) - - with av.open(filename, mode="w") as container: - stream = container.add_stream(video_codec, rate=fps) - stream.width = video_array.shape[2] - stream.height = video_array.shape[1] - stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24" - stream.options = options or {} - - if audio_array is not None: - audio_format_dtypes = { - "dbl": " 1 else "mono" - audio_sample_fmt = container.streams.audio[0].format.name - - format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt]) - audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype) - - frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout) - - frame.sample_rate = 
audio_fps - - for packet in a_stream.encode(frame): - container.mux(packet) - - for packet in a_stream.encode(): - container.mux(packet) - - for img in video_array: - frame = av.VideoFrame.from_ndarray(img, format="rgb24") - try: - frame.pict_type = "NONE" - except TypeError: - from av.video.frame import PictureType # noqa - - frame.pict_type = PictureType.NONE - - for packet in stream.encode(frame): - container.mux(packet) - - # Flush stream - for packet in stream.encode(): - container.mux(packet) - - -def _read_from_stream( - container: "av.container.Container", - start_offset: float, - end_offset: float, - pts_unit: str, - stream: "av.stream.Stream", - stream_name: dict[str, Optional[Union[int, tuple[int, ...], list[int]]]], -) -> list["av.frame.Frame"]: - global _CALLED_TIMES, _GC_COLLECTION_INTERVAL - _CALLED_TIMES += 1 - if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: - gc.collect() - - if pts_unit == "sec": - # TODO: we should change all of this from ground up to simply take - # sec and convert to MS in C++ - start_offset = int(math.floor(start_offset * (1 / stream.time_base))) - if end_offset != float("inf"): - end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) - else: - warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.") - - frames = {} - should_buffer = True - max_buffer_size = 5 - if stream.type == "video": - # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) - # so need to buffer some extra frames to sort everything - # properly - extradata = stream.codec_context.extradata - # overly complicated way of finding if `divx_packed` is set, following - # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 - if extradata and b"DivX" in extradata: - # can't use regex directly because of some weird characters sometimes... 
- pos = extradata.find(b"DivX") - d = extradata[pos:] - o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) - if o is None: - o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) - if o is not None: - should_buffer = o.group(3) == b"p" - seek_offset = start_offset - # some files don't seek to the right location, so better be safe here - seek_offset = max(seek_offset - 1, 0) - if should_buffer: - # FIXME this is kind of a hack, but we will jump to the previous keyframe - # so this will be safe - seek_offset = max(seek_offset - max_buffer_size, 0) - try: - # TODO check if stream needs to always be the video stream here or not - container.seek(seek_offset, any_frame=False, backward=True, stream=stream) - except FFmpegError: - # TODO add some warnings in this case - # print("Corrupted file?", container.name) - return [] - buffer_count = 0 - try: - for _idx, frame in enumerate(container.decode(**stream_name)): - frames[frame.pts] = frame - if frame.pts >= end_offset: - if should_buffer and buffer_count < max_buffer_size: - buffer_count += 1 - continue - break - except FFmpegError: - # TODO add a warning - pass - # ensure that the results are sorted wrt the pts - result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset] - if len(frames) > 0 and start_offset > 0 and start_offset not in frames: - # if there is no frame that exactly matches the pts of start_offset - # add the last frame smaller than start_offset, to guarantee that - # we will have all the necessary data. 
This is most useful for audio - preceding_frames = [i for i in frames if i < start_offset] - if len(preceding_frames) > 0: - first_frame_pts = max(preceding_frames) - result.insert(0, frames[first_frame_pts]) - return result - - -def _align_audio_frames( - aframes: torch.Tensor, audio_frames: list["av.frame.Frame"], ref_start: int, ref_end: float -) -> torch.Tensor: - start, end = audio_frames[0].pts, audio_frames[-1].pts - total_aframes = aframes.shape[1] - step_per_aframe = (end - start + 1) / total_aframes - s_idx = 0 - e_idx = total_aframes - if start < ref_start: - s_idx = int((ref_start - start) / step_per_aframe) - if end > ref_end: - e_idx = int((ref_end - end) / step_per_aframe) - return aframes[:, s_idx:e_idx] - - -def read_video( - filename: str, - start_pts: Union[float, Fraction] = 0, - end_pts: Optional[Union[float, Fraction]] = None, - pts_unit: str = "pts", - output_format: str = "THWC", -) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - """[DEPRECATED] Reads a video from a file, returning both the video frames and the audio frames - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - Args: - filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts. - start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): - The start presentation time of the video - end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): - The end presentation time - pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, - either 'pts' or 'sec'. Defaults to 'pts'. - output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
- - Returns: - vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames - aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points - info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_video) - - output_format = output_format.upper() - if output_format not in ("THWC", "TCHW"): - raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") - - _check_av_available() - - if end_pts is None: - end_pts = float("inf") - - if end_pts < start_pts: - raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") - - info = {} - video_frames = [] - audio_frames = [] - audio_timebase = Fraction(0, 1) - - try: - with av.open(filename, metadata_errors="ignore") as container: - if container.streams.audio: - audio_timebase = container.streams.audio[0].time_base - if container.streams.video: - video_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.video[0], - {"video": 0}, - ) - video_fps = container.streams.video[0].average_rate - # guard against potentially corrupted files - if video_fps is not None: - info["video_fps"] = float(video_fps) - - if container.streams.audio: - audio_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.audio[0], - {"audio": 0}, - ) - info["audio_fps"] = container.streams.audio[0].rate - - except FFmpegError: - # TODO raise a warning? 
- pass - - vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] - aframes_list = [frame.to_ndarray() for frame in audio_frames] - - if vframes_list: - vframes = torch.as_tensor(np.stack(vframes_list)) - else: - vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) - - if aframes_list: - aframes = np.concatenate(aframes_list, 1) - aframes = torch.as_tensor(aframes) - if pts_unit == "sec": - start_pts = int(math.floor(start_pts * (1 / audio_timebase))) - if end_pts != float("inf"): - end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) - aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) - else: - aframes = torch.empty((1, 0), dtype=torch.float32) - - if output_format == "TCHW": - # [T,H,W,C] --> [T,C,H,W] - vframes = vframes.permute(0, 3, 1, 2) - - return vframes, aframes, info - - -def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool: - extradata = container.streams[0].codec_context.extradata - if extradata is None: - return False - if b"Lavc" in extradata: - return True - return False - - -def _decode_video_timestamps(container: "av.container.Container") -> list[int]: - if _can_read_timestamps_from_packets(container): - # fast path - return [x.pts for x in container.demux(video=0) if x.pts is not None] - else: - return [x.pts for x in container.decode(video=0) if x.pts is not None] - - -def read_video_timestamps(filename: str, pts_unit: str = "pts") -> tuple[list[int], Optional[float]]: - """[DEPREACTED] List the video frames timestamps. - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - Note that the function decodes the whole video frame-by-frame.
- - Args: - filename (str): path to the video file - pts_unit (str, optional): unit in which timestamp values will be returned - either 'pts' or 'sec'. Defaults to 'pts'. - - Returns: - pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'): - presentation timestamps for each one of the frames in the video. - video_fps (float, optional): the frame rate for the video - - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_video_timestamps) - - _check_av_available() - - video_fps = None - pts = [] - - try: - with av.open(filename, metadata_errors="ignore") as container: - if container.streams.video: - video_stream = container.streams.video[0] - video_time_base = video_stream.time_base - try: - pts = _decode_video_timestamps(container) - except FFmpegError: - warnings.warn(f"Failed decoding frames for file {filename}") - video_fps = float(video_stream.average_rate) - except FFmpegError as e: - msg = f"Failed to open container for {filename}; Caught error: {e}" - warnings.warn(msg, RuntimeWarning) - - pts.sort() - - if pts_unit == "sec": - pts = [x * video_time_base for x in pts] - - return pts, video_fps +# This module re-exports video utilities from the internal fb location. +# The actual implementation lives in pytorch.vision.fb.io.video +from pytorch.vision.fb.io.video import ( # type: ignore[import-not-found] + _align_audio_frames, + _av_available, + _check_av_available, + _read_from_stream, + av, + read_video, + read_video_timestamps, + write_video, +) + +__all__ = [ + "read_video", + "read_video_timestamps", + "write_video", + "_read_from_stream", + "_align_audio_frames", + "_check_av_available", + "_av_available", + "av", +]