Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 70 additions & 77 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,94 +13,91 @@ env:

jobs:
prek:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v6
- name: prek check
uses: j178/prek-action@v1
with:
extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy


lint:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
update-path: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt install -y pandoc gsfonts
python -m pip install --upgrade pip
pip install jaxlib
pip install jax
pip install '.[doc,test]'
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -r docs/requirements.txt
pip freeze
uv pip install --system --upgrade jaxlib jax
uv pip install --system --upgrade '.[doc,test]'
uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --system --upgrade -r docs/requirements.txt
uv pip freeze
- name: Lint with mypy and ruff
run: |
make lint
uv run make lint
- name: Build documentation
run: |
make docs
uv run make docs
- name: Test documentation
run: |
make doctest
python -m doctest -v README.md

uv run make doctest
uv run python -m doctest -v README.md

test-modeling:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.11", "3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
update-path: true
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt install -y graphviz
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,test]'
pip freeze
uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip
uv pip install --system --upgrade jaxlib jax
uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --system --upgrade -e '.[dev,test]'
uv pip freeze
- name: Test with pytest
run: |
CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
CI=1 uv run pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum"
JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum"
- name: Test tracer leak
if: matrix.python-version == '3.13'
env:
JAX_CHECK_TRACER_LEAKS: 1
run: |
pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
pytest -vs test/test_distributions.py::test_mean_var -k Gompertz

uv run pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
uv run pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
uv run pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
uv run pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -109,51 +106,50 @@ jobs:
parallel: true
flag-name: test-modeling


test-inference:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.11", "3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
update-path: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,test]'
pip freeze
uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip
uv pip install --system --upgrade jaxlib jax
uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --system --upgrade -e '.[dev,test]'
uv pip freeze
- name: Test with pytest
run: |
pytest -vs --durations=20 test/infer/test_mcmc.py
pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
uv run pytest -vs --durations=20 test/infer/test_mcmc.py
uv run pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
uv run pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
JAX_ENABLE_X64=1 uv run pytest -vs test/infer/test_mcmc.py -k x64
- name: Test chains
run: |
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/stochastic_support/test_dcc.py
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
- name: Test custom prng
run: |
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
JAX_ENABLE_CUSTOM_PRNG=1 uv run pytest -vs test/infer/test_mcmc.py
- name: Test nested sampling
run: |
JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py
JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -162,32 +158,32 @@ jobs:
parallel: true
flag-name: test-inference


examples:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
update-path: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,examples,test]'
pip freeze
uv pip install --system --upgrade jaxlib jax
uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --system --upgrade -e '.[dev,examples,test]'
uv pip freeze
- name: Test with pytest
run: |
CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example
CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -196,9 +192,7 @@ jobs:
parallel: true
flag-name: examples


finish:

needs: [test-modeling, test-inference, examples]
runs-on: ubuntu-latest
steps:
Expand All @@ -208,4 +202,3 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel-finished: true
carryforward: "test-modeling,test-inference,examples"

Loading