diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 000000000..ce67e34d5
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,18 @@
+# EditorConfig
+# https://editorconfig.org/
+
+root = true
+
+[*]
+indent_style = space
+indent_size = 2
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.{py,sh,ipynb}]
+indent_size = 4
+
+[*.{md,rst}]
+indent_size = 3
diff --git a/.github/workflows/paddle_wheel.yaml b/.github/workflows/paddle_wheel.yaml
new file mode 100644
index 000000000..7cf150825
--- /dev/null
+++ b/.github/workflows/paddle_wheel.yaml
@@ -0,0 +1,189 @@
+name: Build Wheels for Paddle
+
+on:
+  push:
+    branches: [main]
+    tags: ["v*"]
+  pull_request:
+  merge_group:
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true
+
+permissions:
+  id-token: write
+  contents: write
+
+defaults:
+  run:
+    shell: bash -l -eo pipefail {0}
+
+jobs:
+  build-paddlecodec-wheel:
+    runs-on: ubuntu-latest
+    container:
+      image: pytorch/manylinux2_28-builder:cpu
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v6
+
+      - name: Setup conda environment
+        uses: conda-incubator/setup-miniconda@v3
+        with:
+          auto-update-conda: true
+          miniforge-version: latest
+          activate-environment: build
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install build dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build wheel setuptools
+
+      - name: Install PaddlePaddle nightly
+        run: |
+          pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
+
+      - name: Run pre-build script
+        run: |
+          bash packaging/pre_build_script.sh
+
+      - name: Build wheel
+        run: |
+          # Use pre-built FFmpeg from PyTorch S3
+          export BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1
+          python -m build --wheel -vvv --no-isolation
+
+      - name: Upload wheel artifact
+        uses: actions/upload-artifact@v5
+        with:
+          name: paddlecodec-wheel-linux-py${{ matrix.python-version }}
+          path: dist/*.whl
+
+      - name: Run post-build script
+        run: |
+          bash packaging/post_build_script.sh
+
+      - name: List wheel contents
+        run: |
+          wheel_path=$(find dist -type f -name "*.whl")
+          echo "Wheel path: $wheel_path"
+          unzip -l $wheel_path
+
+  test-paddlecodec-wheel:
+    needs: build-paddlecodec-wheel
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        # Python 3.13 is skipped until the PaddlePaddle self-hosted index provides setuptools
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        # FFmpeg 8.0 depends on libopenvino.so.2520, PaddlePaddle CPU depends on libopenvino.so.2500
+        # The conflict causes test failures, although it works with PaddlePaddle GPU.
+        # We skip FFmpeg 8.0 tests for PaddlePaddle CPU builds for now.
+        ffmpeg-version: ["4.4.2", "5.1.2", "6.1.1", "7.0.1"]
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v6
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Download wheel artifact
+        uses: actions/download-artifact@v4
+        with:
+          name: paddlecodec-wheel-linux-py${{ matrix.python-version }}
+          path: dist/
+
+      - name: Install FFmpeg via conda
+        uses: conda-incubator/setup-miniconda@v3
+        with:
+          auto-update-conda: true
+          miniforge-version: latest
+          activate-environment: test
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install FFmpeg from conda-forge
+        run: |
+          conda install "ffmpeg=${{ matrix.ffmpeg-version }}" -c conda-forge -y
+          ffmpeg -version
+
+      - name: Install PaddlePaddle nightly in conda env
+        run: |
+          pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
+
+      - name: Install paddlecodec from wheel
+        run: |
+          wheel_path=$(find dist -type f -name "*.whl")
+          echo "Installing $wheel_path"
+          pip install $wheel_path -vvv
+
+      - name: Install test dependencies
+        run: |
+          pip install numpy pytest pillow
+
+      - name: Delete src folder
+        run: |
+          # Delete src/ to ensure we're testing the installed wheel, not source code
+          rm -rf src/
+          ls -la
+
+      - name: Run tests
+        run: |
+          pytest --override-ini="addopts=-v" -s test_paddle
+
+  publish-pypi:
+    runs-on: ubuntu-latest
+    name: Publish to PyPI
+    if: "startsWith(github.ref, 'refs/tags/')"
+    needs:
+      - test-paddlecodec-wheel
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v6
+        with:
+          # Wheels are uploaded as one artifact per Python version; gather them all into dist/.
+          pattern: paddlecodec-wheel-linux-py*
+          merge-multiple: true
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+
+  publish-release:
+    runs-on: ubuntu-latest
+    name: Publish to GitHub
+    if: "startsWith(github.ref, 'refs/tags/')"
+    needs:
+      - test-paddlecodec-wheel
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/download-artifact@v6
+        with:
+          # Wheels are uploaded as one artifact per Python version; gather them all into dist/.
+          pattern: paddlecodec-wheel-linux-py*
+          merge-multiple: true
+          path: dist/
+      - name: Get tag name
+        run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
+      - name: Publish to GitHub
+        uses: softprops/action-gh-release@v2
+        with:
+          draft: true
+          files: dist/*
+          tag_name: ${{ env.RELEASE_VERSION }}
diff --git a/examples/decoding/sampling.py b/examples/decoding/sampling.py
index 5b0f87819..f2a9c9a8b 100644
--- a/examples/decoding/sampling.py
+++ b/examples/decoding/sampling.py
@@ -20,7 +20,7 @@
 # :ref:`sampling_tuto_start`.
 
 import paddle
-paddle.compat.enable_torch_proxy()
+paddle.compat.enable_torch_proxy(scope={"torchcodec"})
 from typing import Optional
 
 import torch
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 4df8d1b6d..f6a02596a 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -12,9 +12,8 @@ set(
     TORCH_LIBRARIES
     "${PADDLE_PATH}/base/libpaddle.so"
     "${PADDLE_PATH}/libs/libcommon.so"
-    "${PADDLE_PATH}/libs/libphi.so"
+    # "${PADDLE_PATH}/libs/libphi.so" # currently libphi.so is static linked, we need remove it when it's shared linked
     "${PADDLE_PATH}/libs/libphi_core.so"
-    "${PADDLE_PATH}/libs/libphi_gpu.so"
 )
 set(
     TORCH_INSTALL_PREFIX
diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py
index 804d536fd..9ea928890 100644
--- a/src/torchcodec/_core/ops.py
+++ b/src/torchcodec/_core/ops.py
@@ -91,6 +91,10 @@ def disallow_in_graph(self, fn):
         return fn
 torch._dynamo = FakeDynamo("torch._dynamo")
 torch._C._log_api_usage_once = lambda *args, **kwargs: None
+# TODO: torch.__setattr__ should trigger paddle.__setattr__
+import paddle
+paddle._dynamo = FakeDynamo("torch._dynamo")
+paddle._C._log_api_usage_once = lambda *args, **kwargs: None
 # Note: We use disallow_in_graph because PyTorch does constant propagation of
 # factory functions.
 create_from_file = torch._dynamo.disallow_in_graph(
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 730bfa0e5..6b524f119 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -102,7 +102,7 @@ def __init__(
         stream_index: Optional[int] = None,
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
-        device: Optional[Union[str, torch_device]] = "cpu",
+        device: Optional[Union[str, "torch_device"]] = "cpu",
         seek_mode: Literal["exact", "approximate"] = "exact",
         custom_frame_mappings: Optional[
             Union[str, bytes, io.RawIOBase, io.BufferedReader]
diff --git a/test_paddle/test_video_decode.py b/test_paddle/test_video_decode.py
new file mode 100644
index 000000000..dc9a125f7
--- /dev/null
+++ b/test_paddle/test_video_decode.py
@@ -0,0 +1,204 @@
+import paddle
+paddle.compat.enable_torch_proxy(scope={"torchcodec"})
+
+import pytest
+from dataclasses import dataclass, fields
+from io import BytesIO
+from typing import Callable, Mapping, Optional, Union
+
+import os
+import httpx
+import numpy as np
+
+
+@dataclass
+class VideoMetadata(Mapping):
+    """Video properties plus the sampled frame indices, readable like a mapping."""
+
+    total_num_frames: int
+    fps: Optional[float] = None
+    width: Optional[int] = None
+    height: Optional[int] = None
+    duration: Optional[float] = None
+    video_backend: Optional[str] = None
+    frames_indices: Optional[list[int]] = None
+
+    def __iter__(self):
+        return (f.name for f in fields(self))
+
+    def __len__(self):
+        return len(fields(self))
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    @property
+    def timestamps(self) -> list[float]:
+        """Timestamps of the sampled frames in seconds."""
+        if self.fps is None or self.frames_indices is None:
+            raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
+        return [frame_idx / self.fps for frame_idx in self.frames_indices]
+
+    def update(self, dictionary):
+        """Copy values from `dictionary` for keys that are existing attributes; ignore the rest."""
+        for key, value in dictionary.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+
+
+def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
+    """Return the frame indices to decode.
+
+    Samples `num_frames` evenly spaced frames, or derives that count from the
+    target `fps`; with neither given, every frame is selected.
+    """
+    total_num_frames = metadata.total_num_frames
+    video_fps = metadata.fps
+
+    if num_frames is None and fps is not None:
+        if not video_fps:
+            raise ValueError("Cannot sample by `fps` when the video metadata has no frame rate.")
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
+                f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
+            )
+
+    if num_frames is not None:
+        # np.arange with a float step and dtype=int truncates the step itself, so
+        # the samples would bunch up at the start of the video; linspace does not.
+        indices = np.linspace(0, total_num_frames, num_frames, endpoint=False).astype(int)
+    else:
+        indices = np.arange(0, total_num_frames, dtype=int)
+    return indices
+
+
+def read_video_decord(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """Decode the sampled frames with decord; returns `(frames, metadata)`."""
+    from decord import VideoReader, cpu
+
+    vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
+    video_fps = vr.get_avg_fps()
+    total_num_frames = len(vr)
+    duration = total_num_frames / video_fps if video_fps else 0
+    metadata = VideoMetadata(
+        total_num_frames=int(total_num_frames),
+        fps=float(video_fps),
+        duration=float(duration),
+        video_backend="decord",
+    )
+
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+    video = vr.get_batch(indices).asnumpy()
+
+    metadata.update(
+        {
+            "frames_indices": indices,
+            "height": video.shape[1],
+            "width": video.shape[2],
+        }
+    )
+    return video, metadata
+
+
+def read_video_torchcodec(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """Decode the sampled frames with torchcodec; returns `(frames, metadata)`."""
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder(
+        video_path,
+        seek_mode="exact",
+        num_ffmpeg_threads=0,
+    )
+    metadata = VideoMetadata(
+        total_num_frames=decoder.metadata.num_frames,
+        fps=decoder.metadata.average_fps,
+        duration=decoder.metadata.duration_seconds,
+        video_backend="torchcodec",
+        height=decoder.metadata.height,
+        width=decoder.metadata.width,
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = decoder.get_frames_at(indices=indices).data
+    video = video.contiguous()
+    metadata.frames_indices = indices
+    return video, metadata
+
+
+VIDEO_DECODERS = {
+    "decord": read_video_decord,
+    "torchcodec": read_video_torchcodec,
+}
+
+
+def load_video(
+    video,
+    num_frames: Optional[int] = None,
+    fps: Optional[Union[int, float]] = None,
+    backend: str = "decord",
+    sample_indices_fn: Optional[Callable] = None,
+    **kwargs,
+) -> tuple:
+    """Load a video from a URL or local path and return `(frames, metadata)`.
+
+    Non-string inputs are treated as already-decoded frames and returned
+    unchanged, paired with a list of `None` placeholders as metadata.
+    """
+    if fps is not None and num_frames is not None and sample_indices_fn is None:
+        raise ValueError(
+            "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+        )
+
+    # If user didn't pass a sampling function, create one on the fly with default logic
+    if sample_indices_fn is None:
+
+        def sample_indices_fn_func(metadata, **fn_kwargs):
+            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
+
+        sample_indices_fn = sample_indices_fn_func
+
+    # Early exit if provided an array or `PIL` frames
+    if not isinstance(video, str):
+        metadata = [None] * len(video)
+        return video, metadata
+
+    if video.startswith(("http://", "https://")):
+        response = httpx.get(video, follow_redirects=True)
+        # Fail on HTTP errors instead of handing an error page to the decoder.
+        response.raise_for_status()
+        file_obj = BytesIO(response.content)
+    elif os.path.isfile(video):
+        file_obj = video
+    else:
+        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
+
+    video_decoder = VIDEO_DECODERS[backend]
+    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
+    return video, metadata
+
+
+def test_video_decode():
+    """Decode a reference clip with torchcodec and check its pixel sum and metadata."""
+    url = "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_video/example_video.mp4"
+    video, metadata = load_video(url, backend="torchcodec")
+    assert video.to(paddle.int64).sum().item() == 247759890390
+    assert metadata.total_num_frames == 263
+    assert metadata.fps == pytest.approx(29.99418249715141)
+    assert metadata.width == 1920
+    assert metadata.height == 1080
+    assert metadata.duration == pytest.approx(8.768367)
+    for i, idx in enumerate(metadata.frames_indices):
+        assert idx == i