Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9a163bb
added steps=n logic
alessandrofasse May 2, 2025
9873e5c
Adjust save-every-step logic.
patrick-kidger Jun 7, 2025
6739e19
fixed a comparison test between ::n and skip save at
alessandrofasse Jun 10, 2025
0730717
fixing linting
alessandrofasse Jun 13, 2025
5f89ab4
Added saveat-steps+event test
patrick-kidger Jun 18, 2025
0129941
Introduction of bidirectional vs. unidirectional triggering of events
LuggiStruggi Jun 18, 2025
1b3daf0
Extend event crossing tests
patrick-kidger Jun 18, 2025
49752b8
Fixes for pytree-valued condition functions.
patrick-kidger Jun 18, 2025
97b656e
Improve ConstantStepSize incrementation
jpbrodrick89 Jul 16, 2025
a1305ea
Tweaked layout of ConstantStepSize code
patrick-kidger Jul 30, 2025
5d9f6b9
Use 100 ULP's to clip timesteps close to t1 (#660)
jpbrodrick89 Aug 2, 2025
1ac005d
Tests that a jump at t1 is saved.
patrick-kidger Jul 13, 2025
b7dc392
adapt
lockwo Aug 3, 2025
8a72ee5
Fixes 681
patrick-kidger Aug 31, 2025
2fd3ef3
Added benchmarking FAQ
patrick-kidger Sep 12, 2025
43f82dc
Standardised infra
patrick-kidger Oct 3, 2025
6694c86
Fix failing test
patrick-kidger Oct 3, 2025
02d6b8a
fix float error in prev_dt step calculation that led to an infinite loop
philipwijesinghe Nov 5, 2025
b91138f
avoids accumulation of float precision errors in dt
philipwijesinghe Nov 6, 2025
62bf876
Fixed case in which t0 is prevbefore a jump time
patrick-kidger Dec 22, 2025
fdfecc7
Fix #720; bool event + root find + terminate on first step
patrick-kidger Jan 30, 2026
550c202
Added Advanced SDE example and a table of SRKs
andyElking Jul 29, 2024
c4f6c7d
remove spurious type in union
patrick-kidger Jan 31, 2026
e7d78bf
0.7.1 version bump
patrick-kidger Jan 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Release
uses: patrick-kidger/action_update_python_project@v6
uses: patrick-kidger/action_update_python_project@v8
with:
python-version: "3.11"
test-script: |
cp -r ${{ github.workspace }}/test ./test
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
python -m pip install -r ./test/requirements.txt
python -m test
uv sync --extra tests --no-install-project --inexact
uv run --no-sync pytest
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
github-token: ${{ github.token }}
13 changes: 8 additions & 5 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-tests:
strategy:
matrix:
python-version: [ "3.10", "3.12" ]
python-version: [ "3.11", "3.13" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand All @@ -23,13 +23,16 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r ./test/requirements.txt

python -m pip install '.[dev,docs,tests]'

- name: Checks with pre-commit
uses: pre-commit/action@v3.0.1
run: |
pre-commit run --all-files

- name: Test with pytest
run: |
python -m pip install .
python -m test

- name: Check that documentation can be built.
run: |
mkdocs build
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ site/
.pymon
.idea/
.venv/
uv.lock
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ repos:
files: ^pyproject\.toml$
additional_dependencies: ["toml-sort==0.23.1"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.13.0
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
types_or: [ python, pyi, jupyter, toml ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
types_or: [ python, pyi, jupyter, toml ]
args: [ --fix ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
rev: v1.1.405
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]
18 changes: 5 additions & 13 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started.

First fork the library on GitHub.

Then clone and install the library in development mode:
Then clone and install the library:

```bash
git clone https://github.com/your-username-here/diffrax.git
cd diffrax
pip install -e .
pip install -e '.[dev]'
pre-commit install # `pre-commit` is installed by `pip` on the previous line
```

Then install the pre-commit hook:

```bash
pip install pre-commit
pre-commit install
```

These hooks use ruff to lint and format the code, and pyright to type-check it.

---

**If you're making changes to the code:**
Expand All @@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary.
Next verify the tests all pass:

```bash
pip install -r test/requirements.txt
pytest
pip install -e '.[tests]'
pytest # `pytest` is installed by `pip` on the previous line.
```

Then push your changes back to your fork of the repository:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/against_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def speedtest(fn, name):
# INTEGRATE WITH scan


@jax.checkpoint # pyright: ignore
@jax.checkpoint
def body(carry, t):
u, v, dt = carry
u = u + du(t, v, None) * dt
Expand Down
3 changes: 1 addition & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def loop(
if is_unsafe_sde(terms):
kind = "lax"
msg = (
"Cannot reverse-mode autodifferentiate when using "
"`UnsafeBrownianPath`."
"Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`."
)
elif max_steps is None:
kind = "lax"
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: type[
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
] = eqx.field(static=True)
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
eqx.field(static=True)
)
key: PRNGKeyArray

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ class VirtualBrownianTree(AbstractBrownianPath):
t1: RealScalarLike
tol: RealScalarLike
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: type[
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
] = eqx.field(static=True)
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
eqx.field(static=True)
)
key: PyTree[PRNGKeyArray]
_spline: _Spline = eqx.field(static=True)

Expand Down
2 changes: 1 addition & 1 deletion diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
BoolScalarLike = bool | Array | np.ndarray
FloatScalarLike = float | Array | np.ndarray
IntScalarLike = int | Array | np.ndarray
RealScalarLike = bool | int | float | Array | np.ndarray
RealScalarLike = int | float | Array | np.ndarray
else:
BoolScalarLike = Bool[ArrayLike, ""]
FloatScalarLike = Float[ArrayLike, ""]
Expand Down
35 changes: 34 additions & 1 deletion diffrax/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable

import equinox as eqx
import jax.tree_util as jtu
import optimistix as optx
from jaxtyping import Array, PyTree

Expand All @@ -20,7 +21,33 @@ class Event(eqx.Module):
"""

cond_fn: PyTree[Callable[..., BoolScalarLike | RealScalarLike]]
root_finder: optx.AbstractRootFinder | None = None
direction: PyTree[None | bool]
root_finder: optx.AbstractRootFinder | None

def __init__(
self,
cond_fn,
root_finder: optx.AbstractRootFinder | None = None,
direction: None | bool | PyTree[None | bool] = None,
):
if direction in (None, False, True):
direction = jtu.tree_map(lambda _: direction, cond_fn, is_leaf=callable)

direction_leaves, direction_structure = jtu.tree_flatten(
direction, is_leaf=lambda x: x is None
)
if direction_structure != jtu.tree_structure(cond_fn, is_leaf=callable):
raise ValueError("Missmatch in the structure of `cond_fn` and `direction`.")

if any(x not in (None, False, True) for x in direction_leaves):
raise ValueError(
"`trig_dir` must be a `None`, `bool`, or a PyTree of `None | bool`s "
"with the same structure as `cond_fn`."
)

self.cond_fn = cond_fn
self.root_finder = root_finder
self.direction = direction


Event.__init__.__doc__ = """**Arguments:**
Expand All @@ -39,6 +66,12 @@ class Event(eqx.Module):
[`optimistix.Newton`](https://docs.kidger.site/optimistix/api/root_find/#optimistix.Newton)
would be a typical choice here.

- `direction`: `None` or `bool` or PyTree of `None | bool` of the same shape as
`cond_fn`, that decides for each `cond_fn` if it triggers an event from a
zero-cossing in both directions (`None`), from an upcrossing (`True`) or from a
downcrossing (`False`). Only needed for those `cond_fn` which return floating point
numbers; ignored for those `cond_fn` which return booleans.

!!! Example

Consider a bouncing ball dropped from some intial height $x_0$. We can model
Expand Down
Loading