diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 291be9f28..04588e191 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,94 +13,91 @@ env: jobs: prek: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - name: prek check uses: j178/prek-action@v1 with: extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy - lint: - runs-on: ubuntu-latest strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y pandoc gsfonts - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install '.[doc,test]' - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -r docs/requirements.txt - pip freeze + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade '.[doc,test]' + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -r docs/requirements.txt + uv pip freeze - name: Lint with mypy and ruff run: | - make lint + uv run make lint - name: Build documentation run: | - make docs + uv run make docs - name: Test documentation run: | - make doctest - python -m doctest -v README.md - + uv run make doctest + uv run python -m doctest -v README.md test-modeling: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + update-path: true + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y graphviz - python -m pip install --upgrade pip # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,test]' + uv pip freeze - name: Test with pytest run: | - CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ + CI=1 uv run pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum" + JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum" - name: Test tracer leak if: matrix.python-version == '3.13' env: JAX_CHECK_TRACER_LEAKS: 1 run: | - pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit - pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke - pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run - pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths - pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - + uv run pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit + uv run pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke + uv run pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run + uv run pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths + uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -109,51 +106,50 @@ jobs: parallel: true flag-name: test-modeling - test-inference: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,test]' + uv pip freeze - name: Test with pytest run: | - pytest -vs --durations=20 test/infer/test_mcmc.py - pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py - pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py + uv run pytest -vs --durations=20 test/infer/test_mcmc.py + uv run pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py + uv run pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 + JAX_ENABLE_X64=1 uv run pytest -vs test/infer/test_mcmc.py -k x64 - name: Test chains run: | - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/test_tfp.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/stochastic_support/test_dcc.py + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_hmc_gibbs.py -k "chain" - name: Test custom prng run: | - JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py + JAX_ENABLE_CUSTOM_PRNG=1 uv run pytest -vs test/infer/test_mcmc.py - name: Test nested sampling run: | - JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py + JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -162,32 +158,32 @@ jobs: parallel: true flag-name: test-inference - examples: - runs-on: ubuntu-latest needs: [lint, prek] strategy: matrix: python-version: ["3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,examples,test]' - pip freeze + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,examples,test]' + uv pip freeze - name: Test with pytest run: | - CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example + CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -196,9 +192,7 @@ jobs: parallel: true flag-name: examples - finish: - needs: [test-modeling, test-inference, examples] runs-on: ubuntu-latest steps: @@ -208,4 +202,3 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true carryforward: "test-modeling,test-inference,examples" -