From 9a163bb4d8320f1842e9de436b719a4cc7c42bff Mon Sep 17 00:00:00 2001 From: Alessandro Fasse Date: Fri, 2 May 2025 14:46:33 +0200 Subject: [PATCH 01/24] added steps=n logic --- diffrax/_integrate.py | 82 ++++++++++++++++++------- diffrax/_saveat.py | 45 ++++++++------ test/test_event.py | 4 +- test/test_integrate.py | 1 + test/test_saveat_solution.py | 115 +++++++++++++++++++++++++++++++++-- 5 files changed, 203 insertions(+), 44 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 17461ba4..029c07ac 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -236,6 +236,7 @@ def _save( fn: Callable, save_state: SaveState, repeat: int, + pred=True, ) -> SaveState: ts = save_state.ts ys = save_state.ys @@ -244,12 +245,17 @@ def _save( ts = lax.dynamic_update_slice_in_dim( ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0 ) + y_to_save = lax.cond( + pred, + lambda: fn(t, y, args), + lambda: jtu.tree_map(lambda ys_: ys_[save_index], ys), + ) ys = jtu.tree_map( lambda ys_, y_: lax.dynamic_update_slice_in_dim( ys_, jnp.broadcast_to(y_, (repeat, *y_.shape)), save_index, axis=0 ), ys, - fn(t, y, args), + y_to_save, ) save_index = save_index + repeat @@ -482,14 +488,32 @@ def maybe_inplace(i, u, x): return eqxi.buffer_at_set(x, i, u, pred=keep_step) def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: - if subsaveat.steps: + if subsaveat.steps != 0: + save_step = (state.num_accepted_steps % subsaveat.steps) == 0 + should_save = keep_step & save_step + + def save_fn(tprev, y, args): + return subsaveat.fn(tprev, y, args) + # TODO: Enable this, but I am not sure if possible? How do we know + # the output shape of `.fn`? We should do a dummy call to it? 
+ if subsaveat.steps == 1: + return subsaveat.fn(tprev, y, args) + else: + return lax.cond( + should_save, + lambda: subsaveat.fn(tprev, y, args), + lambda: jtu.tree_map( + lambda y: jnp.zeros(y.shape[1:], y.dtype), save_state.ys + ), + ) + ts = maybe_inplace(save_state.save_index, tprev, save_state.ts) ys = jtu.tree_map( ft.partial(maybe_inplace, save_state.save_index), - subsaveat.fn(tprev, y, args), + save_fn(tprev, y, args), save_state.ys, ) - save_index = save_state.save_index + jnp.where(keep_step, 1, 0) + save_index = save_state.save_index + jnp.where(should_save, 1, 0) save_state = eqx.tree_at( lambda s: [s.ts, s.ys, s.save_index], save_state, @@ -500,7 +524,6 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: save_state = jtu.tree_map( save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat ) - if saveat.dense: dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) dense_infos = jtu.tree_map( @@ -800,20 +823,33 @@ def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveSt ) def _save_t1(subsaveat, save_state): + if subsaveat.steps == 0: + # We're not saving the final value via `steps`, + # so we might need to save it via `t1`. + t1_saved_via_steps = False + elif subsaveat.steps == 1: + # We're definitely saving the final value via `steps`, + # so we can skip saving it via `t1`. + t1_saved_via_steps = True + else: + # We might be saving the final value via `steps`, + # so we might need to save it via `t1`. + t1_saved_via_steps = final_state.num_accepted_steps % subsaveat.steps == 0 if event is None or event.root_finder is None: - if subsaveat.t1 and not subsaveat.steps: - # If subsaveat.steps then the final value is already saved. 
- save_state = _save( - tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1 - ) + if type(t1_saved_via_steps) is bool: + t1_not_saved_via_steps = not t1_saved_via_steps + else: + t1_not_saved_via_steps = jnp.logical_not(t1_saved_via_steps) + pred = subsaveat.t1 & t1_not_saved_via_steps else: - if subsaveat.t1 or subsaveat.steps: - # In this branch we need to replace the last value with tfinal - # and yfinal returned by the root finder also if subsaveat.steps - # because we deleted the last value after the event time above. - save_state = _save( - tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1 - ) + # If we're using an event with a root finder, and are saving steps, + # then we need to write the final value here because we deleted the + # last value after the event time above. + pred = subsaveat.t1 | t1_saved_via_steps + if pred is not False: + save_state = _save( + tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1, pred=pred + ) return save_state save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat) @@ -1215,16 +1251,20 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: out_size += 1 if subsaveat.ts is not None: out_size += len(subsaveat.ts) - if subsaveat.steps: + if subsaveat.steps != 0: # We have no way of knowing how many steps we'll actually end up taking, and # XLA doesn't support dynamic shapes. So we just have to allocate the # maximum amount of steps we can possibly take. 
if max_steps is None: raise ValueError( - "`max_steps=None` is incompatible with saving at `steps=True`" + "`max_steps=None` is incompatible with saving at `steps=n`" ) - out_size += max_steps - if subsaveat.t1 and not subsaveat.steps: + out_size += max_steps // subsaveat.steps + if subsaveat.t1 and ( + (max_steps is None) + or (subsaveat.steps == 0) + or (max_steps % subsaveat.steps != 0) + ): out_size += 1 saveat_ts_index = 0 save_index = 0 diff --git a/diffrax/_saveat.py b/diffrax/_saveat.py index e0786593..d32e2b33 100644 --- a/diffrax/_saveat.py +++ b/diffrax/_saveat.py @@ -11,15 +11,6 @@ def save_y(t, y, args): return y -def _convert_ts( - ts: None | Sequence[RealScalarLike] | Real[Array, " times"], -) -> Real[Array, " times"] | None: - if ts is None or len(ts) == 0: - return None - else: - return jnp.asarray(ts) - - class SubSaveAt(eqx.Module): """Used for finer-grained control over what is saved. A PyTree of these should be passed to `SaveAt(subs=...)`. @@ -28,11 +19,29 @@ class SubSaveAt(eqx.Module): relatively niche feature and most users will probably not need to use `SubSaveAt`.) """ - t0: bool = False - t1: bool = False - ts: Real[Array, " times"] | None = eqx.field(default=None, converter=_convert_ts) - steps: bool = False - fn: Callable = save_y + t0: bool + t1: bool + ts: Real[Array, " times"] | None + steps: int + fn: Callable + + def __init__( + self, + *, + t0: bool = False, + t1: bool = False, + ts: None | Sequence[RealScalarLike] | Real[Array, " times"] = None, + steps: bool | int = 0, + fn: Callable = save_y, + ): + self.t0 = t0 + self.t1 = t1 + self.ts = jnp.asarray(ts) if ts is not None and len(ts) > 0 else None + if isinstance(steps, bool): + self.steps = 1 if steps else 0 + else: + self.steps = steps + self.fn = fn def __check_init__(self): if not self.t0 and not self.t1 and self.ts is None and not self.steps: @@ -44,7 +53,8 @@ def __check_init__(self): - `t0`: If `True`, save the initial input `y0`. 
- `t1`: If `True`, save the output at `t1`. - `ts`: Some array of times at which to save the output. -- `steps`: If `True`, save the output at every step of the numerical solver. +- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver. + `0` means no saving. - `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the evolving solution is saved. This can be useful to save only statistics of your @@ -71,7 +81,7 @@ def __init__( t0: bool = False, t1: bool = False, ts: None | Sequence[RealScalarLike] | Real[Array, " times"] = None, - steps: bool = False, + steps: bool | int = False, fn: Callable = save_y, subs: PyTree[SubSaveAt] = None, dense: bool = False, @@ -100,7 +110,8 @@ def __init__( - `t0`: If `True`, save the initial input `y0`. - `t1`: If `True`, save the output at `t1`. - `ts`: Some array of times at which to save the output. -- `steps`: If `True`, save the output at every step of the numerical solver. +- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver. + `0` means no saving. - `dense`: If `True`, save dense output, that can later be evaluated at any part of the interval $[t_0, t_1]$ via `sol = diffeqsolve(...); sol.evaluate(...)`. 
diff --git a/test/test_event.py b/test/test_event.py index 80f0102c..12581c53 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -564,7 +564,7 @@ def cond_fn_2(t, y, args, **kwargs): @pytest.mark.parametrize("steps", (1, 2, 3, 4, 5)) -def test_event_save_steps(steps): +def test_event_save_all_steps(steps): term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0)) solver = diffrax.Tsit5() t0 = 0 @@ -604,7 +604,7 @@ def run(saveat): num_steps = [steps, steps, steps + 1, steps] yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr)] - for saveat, n, yevent in zip(saveats, num_steps, yevents): + for saveat, n, yevent in zip(saveats, num_steps, yevents, strict=True): ts, ys = run(saveat) xs, zs = ys xevent, zevent = yevent diff --git a/test/test_integrate.py b/test/test_integrate.py index 15d83f3e..cfcaadfd 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -334,6 +334,7 @@ def get_dt_and_controller(level): diffrax.SaveAt(t1=True), diffrax.SaveAt(ts=[3.5, 0.7]), diffrax.SaveAt(steps=True), + diffrax.SaveAt(steps=2), diffrax.SaveAt(dense=True), ), ) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 8ddca38d..0f87b090 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -111,7 +111,7 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(steps=True) + saveat = diffrax.SaveAt(steps=1) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 @@ -131,6 +131,49 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful + saveat = diffrax.SaveAt(steps=2) + sol = _integrate(saveat) + assert sol.t0 == _t0 + assert sol.t1 == _t1 + n = (4096 - 1) // 2 + 1 + assert sol.ts.shape == (n,) # pyright: ignore + assert sol.ys.shape == (n, 1) # pyright: ignore + _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + with jax.numpy_rank_promotion("allow"): + _ys = _y0 * 
jnp.exp(-0.5 * (_ts - _t0))[:, None] + _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) + assert tree_allclose(sol.ys, _ys) + assert sol.controller_state is None + assert sol.solver_state is None + with pytest.raises(ValueError): + sol.evaluate(0.2, 0.8) + with pytest.raises(ValueError): + sol.derivative(0.2) + assert sol.stats["num_steps"] > 0 + assert sol.result == diffrax.RESULTS.successful + + saveat = diffrax.SaveAt(steps=2, t1=True) + sol = _integrate(saveat) + assert sol.t0 == _t0 + assert sol.t1 == _t1 + n = (4096 - 1) // 2 + 1 + assert sol.ts.shape == (n,) # pyright: ignore + assert sol.ys.shape == (n, 1) # pyright: ignore + _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + with jax.numpy_rank_promotion("allow"): + _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] + _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) + print(_ys) + assert tree_allclose(sol.ys, _ys) + assert sol.controller_state is None + assert sol.solver_state is None + with pytest.raises(ValueError): + sol.evaluate(0.2, 0.8) + with pytest.raises(ValueError): + sol.derivative(0.2) + assert sol.stats["num_steps"] > 0 + assert sol.result == diffrax.RESULTS.successful + saveat = diffrax.SaveAt(dense=True) sol = _integrate(saveat) assert sol.t0 == _t0 @@ -147,6 +190,70 @@ def test_saveat_solution(): assert sol.result == diffrax.RESULTS.successful +def test_saveat_solution_skip_steps(): + def _step_integrate(saveat: diffrax.SaveAt): + term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) + ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + sol_ts = diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=jnp.array([1.0]), + dt0=None, + solver=diffrax.Euler(), + saveat=saveat, + stepsize_controller=diffrax.StepTo(ts=ts), + max_steps=10, + ).ts + assert sol_ts is not None + return sol_ts[jnp.isfinite(sol_ts)] + + saveat = diffrax.SaveAt(steps=2) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0])) + saveat = diffrax.SaveAt(steps=2, t1=True) + ts = 
_step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0, 6.0])) + saveat = diffrax.SaveAt(steps=2, t1=True, t0=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([0.0, 1.0, 3.0, 5.0, 6.0])) + saveat = diffrax.SaveAt(steps=3) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 4.0])) + saveat = diffrax.SaveAt(steps=3, t1=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 4.0, 6.0])) + saveat = diffrax.SaveAt(steps=3, t1=True, t0=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([0.0, 1.0, 4.0, 6.0])) + + +def test_saveat_solution_skip_vs_saveat(): + ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + n = 2 + saveat_skip = diffrax.SaveAt(steps=n) + saveat = diffrax.SaveAt(ts=ts[::n]) + term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) + + def solve(saveat): + return diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=jnp.array([1.0]), + dt0=None, + solver=diffrax.Euler(), + saveat=saveat, + stepsize_controller=diffrax.StepTo(ts=ts), + max_steps=10, + ) + + sol_skip = solve(saveat_skip) + sol = solve(saveat) + assert sol_skip.ts == sol.ts + assert sol_skip.ys == sol.ys + + @pytest.mark.parametrize("subs", [True, False]) def test_t0_eq_t1(subs): y0 = jnp.array([2.0]) @@ -164,7 +271,7 @@ def test_t0_eq_t1(subs): get2 = diffrax.SubSaveAt( t0=True, ts=ts, - steps=True, + steps=1, ) subs = (get0, get1, get2) saveat = diffrax.SaveAt(subs=subs) @@ -220,7 +327,7 @@ def _solve(tf): get2 = diffrax.SubSaveAt( t0=True, ts=ts, - steps=True, + steps=1, fn=lambda t, y, args: jnp.where(jnp.isinf(y), 3.0, 4.0), ) subs = (get0, get1, get2) @@ -294,7 +401,7 @@ def test_subsaveat(adjoint, multi_subs, with_fn, getkey): subsaveat_kwargs: dict = dict() get2 = diffrax.SubSaveAt(t0=True, ts=jnp.linspace(0.5, 1.5, 3), **subsaveat_kwargs) if multi_subs: - get0 = diffrax.SubSaveAt(steps=True, fn=lambda _, y, __: y[0]) + get0 = diffrax.SubSaveAt(steps=1, fn=lambda _, y, __: 
y[0]) get1 = diffrax.SubSaveAt( ts=jnp.linspace(0, 1, 5), t1=True, fn=lambda _, y, __: y[1] ) From 9873e5c57424ea099af2ca3bc4a429c8454244e3 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 7 Jun 2025 14:46:15 +0200 Subject: [PATCH 02/24] Adjust save-every-step logic. --- diffrax/_integrate.py | 51 +++++++++++++++++----------------- test/test_saveat_solution.py | 54 +++++++++++++++++++++++------------- 2 files changed, 59 insertions(+), 46 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 029c07ac..e78c0004 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -242,9 +242,8 @@ def _save( ys = save_state.ys save_index = save_state.save_index - ts = lax.dynamic_update_slice_in_dim( - ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0 - ) + t_to_save = jnp.broadcast_to(static_select(pred, t, ts[save_index]), (repeat,)) + ts = lax.dynamic_update_slice_in_dim(ts, t_to_save, save_index, axis=0) y_to_save = lax.cond( pred, lambda: fn(t, y, args), @@ -484,33 +483,29 @@ def _body_fun(_save_state): save_ts, saveat.subs, save_state, is_leaf=_is_subsaveat ) - def maybe_inplace(i, u, x): - return eqxi.buffer_at_set(x, i, u, pred=keep_step) - def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.steps != 0: - save_step = (state.num_accepted_steps % subsaveat.steps) == 0 + save_step = (num_accepted_steps % subsaveat.steps) == 0 should_save = keep_step & save_step - def save_fn(tprev, y, args): - return subsaveat.fn(tprev, y, args) - # TODO: Enable this, but I am not sure if possible? How do we know - # the output shape of `.fn`? We should do a dummy call to it? 
- if subsaveat.steps == 1: - return subsaveat.fn(tprev, y, args) - else: - return lax.cond( - should_save, - lambda: subsaveat.fn(tprev, y, args), - lambda: jtu.tree_map( - lambda y: jnp.zeros(y.shape[1:], y.dtype), save_state.ys - ), - ) + if subsaveat.steps == 1: + y_to_save = subsaveat.fn(tprev, y, args) + else: + struct = eqx.filter_eval_shape(subsaveat.fn, tprev, y, args) + y_to_save = lax.cond( + eqxi.unvmap_any(should_save), + lambda: subsaveat.fn(tprev, y, args), + lambda: jtu.tree_map(jnp.zeros_like, struct), + ) - ts = maybe_inplace(save_state.save_index, tprev, save_state.ts) + ts = eqxi.buffer_at_set( + save_state.ts, save_state.save_index, tprev, pred=should_save + ) ys = jtu.tree_map( - ft.partial(maybe_inplace, save_state.save_index), - save_fn(tprev, y, args), + lambda _y, _ys: eqxi.buffer_at_set( + _ys, save_state.save_index, _y, pred=should_save + ), + y_to_save, save_state.ys, ) save_index = save_state.save_index + jnp.where(should_save, 1, 0) @@ -525,9 +520,13 @@ def save_fn(tprev, y, args): save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat ) if saveat.dense: - dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) + dense_ts = eqxi.buffer_at_set( + dense_ts, dense_save_index + 1, tprev, pred=keep_step + ) dense_infos = jtu.tree_map( - ft.partial(maybe_inplace, dense_save_index), + lambda _i, _is: eqxi.buffer_at_set( + _is, dense_save_index, _i, pred=keep_step + ), dense_info, dense_infos, ) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 0f87b090..0d8ad4c1 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -191,9 +191,12 @@ def test_saveat_solution(): def test_saveat_solution_skip_steps(): - def _step_integrate(saveat: diffrax.SaveAt): + def _step_integrate(saveat: diffrax.SaveAt, with_7: bool): term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) - ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + if with_7: + ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 
7.0]) + else: + ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) sol_ts = diffrax.diffeqsolve( term, t0=ts[0], @@ -208,24 +211,35 @@ def _step_integrate(saveat: diffrax.SaveAt): assert sol_ts is not None return sol_ts[jnp.isfinite(sol_ts)] - saveat = diffrax.SaveAt(steps=2) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0])) - saveat = diffrax.SaveAt(steps=2, t1=True) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0, 6.0])) - saveat = diffrax.SaveAt(steps=2, t1=True, t0=True) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([0.0, 1.0, 3.0, 5.0, 6.0])) - saveat = diffrax.SaveAt(steps=3) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([1.0, 4.0])) - saveat = diffrax.SaveAt(steps=3, t1=True) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([1.0, 4.0, 6.0])) - saveat = diffrax.SaveAt(steps=3, t1=True, t0=True) - ts = _step_integrate(saveat) - assert jnp.allclose(ts, jnp.array([0.0, 1.0, 4.0, 6.0])) + ts = _step_integrate(diffrax.SaveAt(steps=2), with_7=True) + assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0])) + ts = _step_integrate(diffrax.SaveAt(steps=2), with_7=False) + assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0])) + + ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True), with_7=True) + assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0, 7.0])) + ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True), with_7=False) + assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0])) + + ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True, t0=True), with_7=True) + assert jnp.allclose(ts, jnp.array([0.0, 2.0, 4.0, 6.0, 7.0])) + ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True, t0=True), with_7=False) + assert jnp.allclose(ts, jnp.array([0.0, 2.0, 4.0, 6.0])) + + ts = _step_integrate(diffrax.SaveAt(steps=3), with_7=True) + assert jnp.allclose(ts, jnp.array([3.0, 6.0])) + ts = _step_integrate(diffrax.SaveAt(steps=3), with_7=False) + assert 
jnp.allclose(ts, jnp.array([3.0, 6.0])) + + ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True), with_7=True) + assert jnp.allclose(ts, jnp.array([3.0, 6.0, 7.0])) + ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True), with_7=False) + assert jnp.allclose(ts, jnp.array([3.0, 6.0])) + + ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True, t0=True), with_7=True) + assert jnp.allclose(ts, jnp.array([0.0, 3.0, 6.0, 7.0])) + ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True, t0=True), with_7=False) + assert jnp.allclose(ts, jnp.array([0.0, 3.0, 6.0])) def test_saveat_solution_skip_vs_saveat(): From 6739e19121da58d05c833d4a4810583200e86fdc Mon Sep 17 00:00:00 2001 From: Alessandro Fasse Date: Tue, 10 Jun 2025 10:51:39 +0200 Subject: [PATCH 03/24] fixed a comparison test between ::n and skip save at --- test/test_saveat_solution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 0d8ad4c1..ea6f4d64 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -245,7 +245,7 @@ def _step_integrate(saveat: diffrax.SaveAt, with_7: bool): def test_saveat_solution_skip_vs_saveat(): ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) n = 2 - saveat_skip = diffrax.SaveAt(steps=n) + saveat_skip = diffrax.SaveAt(steps=n, t0=True) saveat = diffrax.SaveAt(ts=ts[::n]) term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) @@ -259,13 +259,13 @@ def solve(saveat): solver=diffrax.Euler(), saveat=saveat, stepsize_controller=diffrax.StepTo(ts=ts), - max_steps=10, + max_steps=6, ) sol_skip = solve(saveat_skip) sol = solve(saveat) - assert sol_skip.ts == sol.ts - assert sol_skip.ys == sol.ys + assert jnp.allclose(sol_skip.ts, sol.ts) + assert jnp.allclose(sol_skip.ys, sol.ys) @pytest.mark.parametrize("subs", [True, False]) From 073071749a0f08e72aaff9a77a9bea1865d711a1 Mon Sep 17 00:00:00 2001 From: Alessandro Fasse Date: Fri, 13 Jun 2025 11:45:15 +0200 Subject: [PATCH 04/24] 
fixing linting --- test/test_saveat_solution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index ea6f4d64..3b3c06c2 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -264,6 +264,12 @@ def solve(saveat): sol_skip = solve(saveat_skip) sol = solve(saveat) + assert sol is not None + assert sol.ts is not None + assert sol.ys is not None + assert sol_skip is not None + assert sol_skip.ts is not None + assert sol_skip.ys is not None assert jnp.allclose(sol_skip.ts, sol.ts) assert jnp.allclose(sol_skip.ys, sol.ys) From 5f89ab4a9ecd39ea6f36bd344ea5ee9add05e574 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 18 Jun 2025 22:24:35 +0200 Subject: [PATCH 05/24] Added saveat-steps+event test --- test/test_saveat_solution.py | 42 ++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 3b3c06c2..3db4e7ba 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -5,6 +5,7 @@ import equinox as eqx import jax import jax.numpy as jnp +import optimistix as optx import pytest from .helpers import tree_allclose @@ -274,6 +275,47 @@ def solve(saveat): assert jnp.allclose(sol_skip.ys, sol.ys) +def test_saveat_steps_with_event(): + def solve(saveat): + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: -0.5 * y), + t0=0, + t1=5, + y0=1.0, + dt0=1, + solver=diffrax.Euler(), + saveat=saveat, + event=diffrax.Event( + cond_fn=lambda t, y, args, **k: t - 3.5, + root_finder=optx.Newton(rtol=1e-5, atol=1e-5), + ), + max_steps=6, + ) + assert sol.result == diffrax.RESULTS.event_occurred + assert sol.ts is not None + assert sol.ys is not None + return sol.ts, sol.ys + + ts1, ys1 = solve(diffrax.SaveAt(steps=2)) + assert jnp.allclose(ts1, jnp.array([2.0, 3.5, jnp.inf])) + # Computed using Euler + # y(1) = 0.5 + # 
y(2) = 0.25 + # y(3) = 0.125 + # y(4) = 0.0625 + # linearly interpolate => y(3.5) = 0.09375 + assert jnp.allclose(ys1, jnp.array([0.25, 0.09375, jnp.inf])) + ts2, ys2 = solve(diffrax.SaveAt(steps=2, t1=True)) + assert jnp.allclose(ts2, jnp.array([2.0, 3.5, jnp.inf])) + assert jnp.allclose(ys2, jnp.array([0.25, 0.09375, jnp.inf])) + ts3, ys3 = solve(diffrax.SaveAt(steps=3)) + assert jnp.allclose(ts3, jnp.array([3.0, jnp.inf])) + assert jnp.allclose(ys3, jnp.array([0.125, jnp.inf])) + ts4, ys4 = solve(diffrax.SaveAt(steps=3, t1=True)) + assert jnp.allclose(ts4, jnp.array([3.0, 3.5])) + assert jnp.allclose(ys4, jnp.array([0.125, 0.09375])) + + @pytest.mark.parametrize("subs", [True, False]) def test_t0_eq_t1(subs): y0 = jnp.array([2.0]) From 012994121ec95933fd4a0388205f336315ccc705 Mon Sep 17 00:00:00 2001 From: LuggiStruggi Date: Wed, 18 Jun 2025 22:42:13 +0200 Subject: [PATCH 06/24] Introduction of bidirectional vs. unidirectional triggering of events --- diffrax/_event.py | 36 ++++++++++++++++++++- diffrax/_integrate.py | 20 +++++++++--- test/test_event.py | 74 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 5 deletions(-) diff --git a/diffrax/_event.py b/diffrax/_event.py index 12d11114..eff74e24 100644 --- a/diffrax/_event.py +++ b/diffrax/_event.py @@ -3,6 +3,7 @@ import equinox as eqx import optimistix as optx +from jax.tree import flatten, unflatten from jaxtyping import Array, PyTree from ._custom_types import BoolScalarLike, FloatScalarLike, RealScalarLike @@ -20,7 +21,36 @@ class Event(eqx.Module): """ cond_fn: PyTree[Callable[..., BoolScalarLike | RealScalarLike]] - root_finder: optx.AbstractRootFinder | None = None + trig_dir: PyTree[None | bool] + root_finder: optx.AbstractRootFinder | None + + def __init__( + self, + cond_fn, + root_finder: optx.AbstractRootFinder | None = None, + trig_dir: None | bool | PyTree[None | bool] = None, + ): + vals_cond, treedef_cond = flatten(cond_fn) + + if isinstance(trig_dir, bool) or 
trig_dir is None: + vals_trig = [trig_dir] * len(vals_cond) + treedef_trig = treedef_cond + else: + vals_trig, treedef_trig = flatten(trig_dir, is_leaf=lambda x: x is None) + + if treedef_cond != treedef_trig: + raise ValueError("Missmatch in the structure of cond_fn and trigger_dir") + + if not all(x is None or isinstance(x, bool) for x in vals_trig): + raise ValueError( + "`trig_dir` must be a None, bool or a PyTree of None | bools" + + " with the same structure as cond_fn" + ) + + trig_tree = unflatten(treedef_cond, vals_trig) + self.cond_fn = cond_fn + self.root_finder = root_finder + self.trig_dir = trig_tree Event.__init__.__doc__ = """**Arguments:** @@ -39,6 +69,10 @@ class Event(eqx.Module): [`optimistix.Newton`](https://docs.kidger.site/optimistix/api/root_find/#optimistix.Newton) would be a typical choice here. +- `trig_dir`: None or bool or PyTree of None or bool of the same shape as cond_fn, + that decides for each cond_fn if it triggers an event from a zero-cossing in both + directions (None), from an upcrossing (True) or from a downcrossing (False). + !!! Example Consider a bouncing ball dropped from some intial height $x_0$. We can model diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index e78c0004..545e6dd7 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -543,7 +543,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: event_tnext = state.tnext event_dense_info = dense_info - def _outer_cond_fn(cond_fn_i, old_event_value_i): + def _outer_cond_fn(cond_fn_i, old_event_value_i, trig_dir_i): new_event_value_i = cond_fn_i( tprev, y, @@ -577,9 +577,19 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): f"{new_dtype}." 
) if jnp.issubdtype(new_dtype, jnp.floating): - event_mask_i = jnp.sign(old_event_value_i) != jnp.sign( - new_event_value_i - ) + if trig_dir_i is None: + event_mask_i = jnp.sign(old_event_value_i) != jnp.sign( + new_event_value_i + ) + elif trig_dir_i: + event_mask_i = (jnp.sign(old_event_value_i) <= 0) & ( + jnp.sign(new_event_value_i) > 0 + ) + else: + event_mask_i = (jnp.sign(old_event_value_i) > 0) & ( + jnp.sign(new_event_value_i) <= 0 + ) + elif jnp.issubdtype(new_dtype, jnp.bool_): event_mask_i = new_event_value_i else: @@ -593,8 +603,10 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): _outer_cond_fn, event.cond_fn, state.event_values, + event.trig_dir, is_leaf=callable, ) + event_structure = jtu.tree_structure(event.cond_fn, is_leaf=callable) event_values, event_mask = jtu.tree_transpose( event_structure, diff --git a/test/test_event.py b/test/test_event.py index 12581c53..d4e37850 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -709,3 +709,77 @@ def save_fn(t, y, args): assert jnp.sum(jnp.isfinite(ts_2)) == steps assert jnp.all(jnp.isclose(ys_2[steps - 1], jnp.array([thr, 0]), atol=1e-5)) assert jnp.all(jnp.isclose(ys_1.y[ts_event - 1], last_save, atol=1e-5)) + + +def test_event_trig_dir(): + term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) + solver = diffrax.Tsit5() + t0 = 0.0 + t1 = 10.0 + dt0 = 1.0 + y0 = jnp.array([0, 1]) + + def cond_fn0(t, y, args, **kwargs): + return y[0] - 5.0 + + def cond_fn1(t, y, args, **kwargs): + return y[1] - 5.0 + + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + event = diffrax.Event((cond_fn0, cond_fn1), root_finder, (True, False)) + sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) + + assert jnp.isclose(cast(Array, sol.ts)[-1], 5.0) + assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([5.0, 6.0]))) + + +def test_event_trig_dir_single_true(): + term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) + solver = diffrax.Tsit5() + t0 = 0.0 + t1 = 
10.0 + dt0 = 1.0 + y0 = jnp.array([0, 1]) + + def cond_fn0(t, y, args, **kwargs): + return y[0] - 5.0 + + def cond_fn1(t, y, args, **kwargs): + return -(y[1] - 5.0) + + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + event = diffrax.Event((cond_fn0, cond_fn1), root_finder, True) + sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) + + assert jnp.isclose(cast(Array, sol.ts)[-1], 5.0) + assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([5.0, 6.0]))) + + +def test_event_trig_dir_single_none(): + term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) + solver = diffrax.Tsit5() + t0 = 0.0 + t1 = 10.0 + dt0 = 1.0 + y0 = jnp.array([0, 1]) + + def cond_fn0(t, y, args, **kwargs): + return y[0] - 5.0 + + def cond_fn1(t, y, args, **kwargs): + return -(y[1] - 5.0) + + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + event = diffrax.Event((cond_fn0, cond_fn1), root_finder, None) + sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) + + assert jnp.isclose(cast(Array, sol.ts)[-1], 4.0) + assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([4.0, 5.0]))) + + +def test_event_trig_dir_pytree_structure(): + f = lambda x: x + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + diffrax.Event([f, f, [f, (f)]], root_finder, [True, None, [False, (True)]]) + with pytest.raises(ValueError): + diffrax.Event([f, f, [f, f, f]], root_finder, [True, None, [False, (True)]]) From 1b3daf0fe1cf2b6c25f773a28672ca4ef658e5b0 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 18 Jun 2025 22:52:34 +0200 Subject: [PATCH 07/24] Extend event crossing tests --- test/test_event.py | 101 +++++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 54 deletions(-) diff --git a/test/test_event.py b/test/test_event.py index d4e37850..ea368141 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -719,67 +719,60 @@ def test_event_trig_dir(): dt0 = 1.0 
y0 = jnp.array([0, 1]) - def cond_fn0(t, y, args, **kwargs): - return y[0] - 5.0 - - def cond_fn1(t, y, args, **kwargs): - return y[1] - 5.0 - - root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) - event = diffrax.Event((cond_fn0, cond_fn1), root_finder, (True, False)) - sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) - - assert jnp.isclose(cast(Array, sol.ts)[-1], 5.0) - assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([5.0, 6.0]))) - - -def test_event_trig_dir_single_true(): - term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) - solver = diffrax.Tsit5() - t0 = 0.0 - t1 = 10.0 - dt0 = 1.0 - y0 = jnp.array([0, 1]) - - def cond_fn0(t, y, args, **kwargs): - return y[0] - 5.0 - - def cond_fn1(t, y, args, **kwargs): - return -(y[1] - 5.0) - - root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) - event = diffrax.Event((cond_fn0, cond_fn1), root_finder, True) - sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) - - assert jnp.isclose(cast(Array, sol.ts)[-1], 5.0) - assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([5.0, 6.0]))) - - -def test_event_trig_dir_single_none(): - term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) - solver = diffrax.Tsit5() - t0 = 0.0 - t1 = 10.0 - dt0 = 1.0 - y0 = jnp.array([0, 1]) - - def cond_fn0(t, y, args, **kwargs): - return y[0] - 5.0 + def up_cond(t, y, args, **kwargs): + del t, args, kwargs + y0, _ = y + return y0 - 5.0 - def cond_fn1(t, y, args, **kwargs): - return -(y[1] - 5.0) + def down_cond(t, y, args, **kwargs): + del t, args, kwargs + _, y1 = y + return -(y1 - 5.0) root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) - event = diffrax.Event((cond_fn0, cond_fn1), root_finder, None) - sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) - assert jnp.isclose(cast(Array, sol.ts)[-1], 4.0) - assert jnp.all(jnp.isclose(cast(Array, sol.ys)[-1], jnp.array([4.0, 5.0]))) + def run(event): + sol = diffrax.diffeqsolve(term, 
solver, t0, t1, dt0, y0, event=event) + assert sol.ts is not None + assert sol.ys is not None + [t_final] = sol.ts + [y_final] = sol.ys + return t_final, y_final + + event = diffrax.Event((up_cond, down_cond), root_finder, True) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 5.0) + assert jnp.allclose(y_final, jnp.array([5.0, 6.0])) + + event = diffrax.Event((up_cond, down_cond), root_finder, False) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 4.0) + assert jnp.allclose(y_final, jnp.array([4.0, 5.0])) + + event = diffrax.Event((up_cond, down_cond), root_finder, (True, True)) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 5.0) + assert jnp.allclose(y_final, jnp.array([5.0, 6.0])) + + event = diffrax.Event((up_cond, down_cond), root_finder, (True, False)) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 4.0) + assert jnp.allclose(y_final, jnp.array([4.0, 5.0])) + + event = diffrax.Event((up_cond, down_cond), root_finder, (False, True)) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 10.0) + assert jnp.allclose(y_final, jnp.array([10.0, 11.0])) + + event = diffrax.Event((up_cond, down_cond), root_finder, (False, None)) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 4.0) + assert jnp.allclose(y_final, jnp.array([4.0, 5.0])) def test_event_trig_dir_pytree_structure(): f = lambda x: x root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) - diffrax.Event([f, f, [f, (f)]], root_finder, [True, None, [False, (True)]]) + diffrax.Event([f, f, [f, f]], root_finder, [True, None, [False, True]]) with pytest.raises(ValueError): - diffrax.Event([f, f, [f, f, f]], root_finder, [True, None, [False, (True)]]) + diffrax.Event([f, f, [f, f, f]], root_finder, [True, None, [False, True]]) From 49752b84cc43ac46a642f1aac037ece2d95179e6 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 18 Jun 2025 23:05:24 +0200 Subject: [PATCH 08/24] Fixes 
for pytree-valued condition functions. --- diffrax/_event.py | 39 +++++++++++++++-------------- diffrax/_integrate.py | 8 +++--- test/test_event.py | 57 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 24 deletions(-) diff --git a/diffrax/_event.py b/diffrax/_event.py index eff74e24..f1d1ea7e 100644 --- a/diffrax/_event.py +++ b/diffrax/_event.py @@ -2,8 +2,8 @@ from collections.abc import Callable import equinox as eqx +import jax.tree_util as jtu import optimistix as optx -from jax.tree import flatten, unflatten from jaxtyping import Array, PyTree from ._custom_types import BoolScalarLike, FloatScalarLike, RealScalarLike @@ -21,36 +21,33 @@ class Event(eqx.Module): """ cond_fn: PyTree[Callable[..., BoolScalarLike | RealScalarLike]] - trig_dir: PyTree[None | bool] + direction: PyTree[None | bool] root_finder: optx.AbstractRootFinder | None def __init__( self, cond_fn, root_finder: optx.AbstractRootFinder | None = None, - trig_dir: None | bool | PyTree[None | bool] = None, + direction: None | bool | PyTree[None | bool] = None, ): - vals_cond, treedef_cond = flatten(cond_fn) + if direction in (None, False, True): + direction = jtu.tree_map(lambda _: direction, cond_fn, is_leaf=callable) - if isinstance(trig_dir, bool) or trig_dir is None: - vals_trig = [trig_dir] * len(vals_cond) - treedef_trig = treedef_cond - else: - vals_trig, treedef_trig = flatten(trig_dir, is_leaf=lambda x: x is None) - - if treedef_cond != treedef_trig: - raise ValueError("Missmatch in the structure of cond_fn and trigger_dir") + direction_leaves, direction_structure = jtu.tree_flatten( + direction, is_leaf=lambda x: x is None + ) + if direction_structure != jtu.tree_structure(cond_fn, is_leaf=callable): + raise ValueError("Missmatch in the structure of `cond_fn` and `direction`.") - if not all(x is None or isinstance(x, bool) for x in vals_trig): + if any(x not in (None, False, True) for x in direction_leaves): raise ValueError( - "`trig_dir` must be a None, 
bool or a PyTree of None | bools" - + " with the same structure as cond_fn" + "`trig_dir` must be a `None`, `bool`, or a PyTree of `None | bool`s " + "with the same structure as `cond_fn`." ) - trig_tree = unflatten(treedef_cond, vals_trig) self.cond_fn = cond_fn self.root_finder = root_finder - self.trig_dir = trig_tree + self.direction = direction Event.__init__.__doc__ = """**Arguments:** @@ -69,9 +66,11 @@ def __init__( [`optimistix.Newton`](https://docs.kidger.site/optimistix/api/root_find/#optimistix.Newton) would be a typical choice here. -- `trig_dir`: None or bool or PyTree of None or bool of the same shape as cond_fn, - that decides for each cond_fn if it triggers an event from a zero-cossing in both - directions (None), from an upcrossing (True) or from a downcrossing (False). +- `direction`: `None` or `bool` or PyTree of `None | bool` of the same shape as + `cond_fn`, that decides for each `cond_fn` if it triggers an event from a + zero-cossing in both directions (`None`), from an upcrossing (`True`) or from a + downcrossing (`False`). Only needed for those `cond_fn` which return floating point + numbers; ignored for those `cond_fn` which return booleans. !!! Example diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 545e6dd7..8d1c66b4 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -543,7 +543,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: event_tnext = state.tnext event_dense_info = dense_info - def _outer_cond_fn(cond_fn_i, old_event_value_i, trig_dir_i): + def _outer_cond_fn(cond_fn_i, old_event_value_i, direction_i): new_event_value_i = cond_fn_i( tprev, y, @@ -577,11 +577,11 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i, trig_dir_i): f"{new_dtype}." 
) if jnp.issubdtype(new_dtype, jnp.floating): - if trig_dir_i is None: + if direction_i is None: event_mask_i = jnp.sign(old_event_value_i) != jnp.sign( new_event_value_i ) - elif trig_dir_i: + elif direction_i: event_mask_i = (jnp.sign(old_event_value_i) <= 0) & ( jnp.sign(new_event_value_i) > 0 ) @@ -603,7 +603,7 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i, trig_dir_i): _outer_cond_fn, event.cond_fn, state.event_values, - event.trig_dir, + event.direction, is_leaf=callable, ) diff --git a/test/test_event.py b/test/test_event.py index ea368141..5de0ded6 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -776,3 +776,60 @@ def test_event_trig_dir_pytree_structure(): diffrax.Event([f, f, [f, f]], root_finder, [True, None, [False, True]]) with pytest.raises(ValueError): diffrax.Event([f, f, [f, f, f]], root_finder, [True, None, [False, True]]) + + +def test_event_with_pytree_valued_condition_function(): + term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0])) + solver = diffrax.Tsit5() + t0 = 0.0 + t1 = 10.0 + dt0 = 1.0 + y0 = jnp.array([0, 1]) + + class CondFn(eqx.Module): + crossing: tuple[tuple[float]] + downcrossing: bool + + def __call__(self, t, y, args, **kwargs): + del t, args, kwargs + y0, _ = y + [[crossing]] = self.crossing + out = y0 - crossing + if self.downcrossing: + out = -out + return out + + def another_cond_fn(t, y, args, **kwargs): + del t, y, args, kwargs + return 5.0 + + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + + def run(event): + sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) + assert sol.ts is not None + assert sol.ys is not None + [t_final] = sol.ts + [y_final] = sol.ys + return t_final, y_final + + event = diffrax.Event( + (CondFn(((3.0,),), False), another_cond_fn), root_finder, (True, None) + ) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 3.0) + assert jnp.allclose(y_final, jnp.array([3.0, 4.0])) + + event = diffrax.Event( + (CondFn(((3.0,),), False), 
another_cond_fn), root_finder, (None, False) + ) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 3.0) + assert jnp.allclose(y_final, jnp.array([3.0, 4.0])) + + event = diffrax.Event( + (CondFn(((3.0,),), False), another_cond_fn), root_finder, (False, False) + ) + t_final, y_final = run(event) + assert jnp.allclose(t_final, 10.0) + assert jnp.allclose(y_final, jnp.array([10.0, 11.0])) From 97b656e2386152c8cdacaa30b47287cb17fd8135 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 16 Jul 2025 02:16:45 +0100 Subject: [PATCH 09/24] Improve ConstantStepSize incrementation --- diffrax/_step_size_controller/constant.py | 63 ++++++++++++++++++++--- test/test_progress_meter.py | 10 ++-- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/diffrax/_step_size_controller/constant.py b/diffrax/_step_size_controller/constant.py index e4100f83..a7e5c879 100644 --- a/diffrax/_step_size_controller/constant.py +++ b/diffrax/_step_size_controller/constant.py @@ -11,7 +11,15 @@ from .base import AbstractStepSizeController -class ConstantStepSize(AbstractStepSizeController[RealScalarLike, RealScalarLike]): +# ConstantStepSizeState = (steps_completed, num_steps, t0_sim, t1_sim_or_dt0) +_ConstantStepSizeState = tuple[ + IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike +] + + +class ConstantStepSize( + AbstractStepSizeController[_ConstantStepSizeState, RealScalarLike] +): """Use a constant step size, equal to the `dt0` argument of [`diffrax.diffeqsolve`][]. """ @@ -29,14 +37,23 @@ def init( args: Args, func: Callable[[PyTree[AbstractTerm], RealScalarLike, Y, Args], VF], error_order: RealScalarLike | None, - ) -> tuple[RealScalarLike, RealScalarLike]: - del terms, t1, y0, args, func, error_order + ) -> tuple[RealScalarLike, _ConstantStepSizeState]: + del terms, y0, args, func, error_order if dt0 is None: raise ValueError( "Constant step size solvers cannot select step size automatically; " "please pass a value for `dt0`." 
) - return t0 + dt0, dt0 + steps_completed = jnp.asarray(1, dtype=jnp.int32) + # Special case for infinite t1, allow termination based on conditional tests + # Use num_steps=-1 to ensure finite int + num_steps = jnp.where( + jnp.isfinite(t1), + jnp.astype(jnp.ceil((t1 - t0) / eqxi.nextafter(dt0)), jnp.int32), + -1, + ) + t1_sim_or_dt0 = jnp.where(jnp.isfinite(t1), t1, dt0) + return t0 + dt0, (steps_completed, num_steps, t0, t1_sim_or_dt0) def adapt_step_size( self, @@ -47,15 +64,45 @@ def adapt_step_size( args: Args, y_error: Y | None, error_order: RealScalarLike | None, - controller_state: RealScalarLike, - ) -> tuple[bool, RealScalarLike, RealScalarLike, bool, RealScalarLike, RESULTS]: + controller_state: _ConstantStepSizeState, + ) -> tuple[ + bool, + RealScalarLike, + RealScalarLike, + bool, + _ConstantStepSizeState, + RESULTS, + ]: del t0, y0, y1_candidate, args, y_error, error_order + steps_already_completed, num_steps, t0_sim, t1_sim_or_dt0 = controller_state + # Number of steps that will be completed when this function returns. 
+ steps_completed = steps_already_completed + 1 + + time_dtype = jnp.result_type(t0_sim, t1_sim_or_dt0) + + # Calculate step size by calculating fraction of `t1 - t0` to avoid compounding + # of truncation/rounding errors + t1_next = jnp.where( + num_steps >= 0, + jnp.where( + steps_completed == num_steps, + t1_sim_or_dt0, + t0_sim + + (t1_sim_or_dt0 - t0_sim) + * jnp.astype(steps_completed, time_dtype) + / jnp.astype(num_steps, time_dtype), + ), + # Special case for non-finite t1_sim + # in this t1_sim_or_dt0 is dt0 + t1 + t1_sim_or_dt0, + ) + return ( True, t1, - t1 + controller_state, + t1_next, False, - controller_state, + (steps_completed, num_steps, t0_sim, t1_sim_or_dt0), RESULTS.successful, ) diff --git a/test/test_progress_meter.py b/test/test_progress_meter.py index a9613c9e..07db1277 100644 --- a/test/test_progress_meter.py +++ b/test/test_progress_meter.py @@ -53,31 +53,31 @@ def solve(t0): t1=5, dt0=0.01, y0=1.0, - progress_meter=diffrax.TextProgressMeter(minimum_increase=0.1), + progress_meter=diffrax.TextProgressMeter(minimum_increase=0.0999), ) solve(2.0) jax.effects_barrier() captured = capfd.readouterr() - expected = "0.00%\n10.33%\n20.67%\n31.00%\n41.33%\n51.67%\n62.00%\n72.33%\n82.67%\n93.00%\n100.00%\n" # noqa: E501 + expected = "%\n".join(f"{x:.2f}" for x in jnp.linspace(0, 100, num=11)) + "%\n" assert captured.out == expected jax.vmap(solve)(jnp.arange(3.0)) jax.effects_barrier() captured = capfd.readouterr() - expected = "0.00%\n10.00%\n20.00%\n30.00%\n40.00%\n50.20%\n60.40%\n70.60%\n80.80%\n91.00%\n100.00%\n" # noqa: E501 + expected = "%\n".join(f"{x:.2f}" for x in jnp.linspace(0, 100, num=11)) + "%\n" assert captured.out == expected jax.jit(solve)(2.0) jax.effects_barrier() captured = capfd.readouterr() - expected = "0.00%\n10.33%\n20.67%\n31.00%\n41.33%\n51.67%\n62.00%\n72.33%\n82.67%\n93.00%\n100.00%\n" # noqa: E501 + expected = "%\n".join(f"{x:.2f}" for x in jnp.linspace(0, 100, num=11)) + "%\n" assert captured.out == expected 
jax.jit(jax.vmap(solve))(jnp.arange(3.0)) jax.effects_barrier() captured = capfd.readouterr() - expected = "0.00%\n10.00%\n20.00%\n30.00%\n40.00%\n50.20%\n60.40%\n70.60%\n80.80%\n91.00%\n100.00%\n" # noqa: E501 + expected = "%\n".join(f"{x:.2f}" for x in jnp.linspace(0, 100, num=11)) + "%\n" assert captured.out == expected From a1305ea4f689b4e8db187c118a0eb1909ba9966b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:32:23 +0200 Subject: [PATCH 10/24] Tweaked layout of ConstantStepSize code --- diffrax/_step_size_controller/constant.py | 43 +++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/diffrax/_step_size_controller/constant.py b/diffrax/_step_size_controller/constant.py index a7e5c879..02252869 100644 --- a/diffrax/_step_size_controller/constant.py +++ b/diffrax/_step_size_controller/constant.py @@ -45,13 +45,12 @@ def init( "please pass a value for `dt0`." ) steps_completed = jnp.asarray(1, dtype=jnp.int32) - # Special case for infinite t1, allow termination based on conditional tests - # Use num_steps=-1 to ensure finite int - num_steps = jnp.where( - jnp.isfinite(t1), - jnp.astype(jnp.ceil((t1 - t0) / eqxi.nextafter(dt0)), jnp.int32), - -1, - ) + # `eqxi.nextafter` to handle floating point error, see + # https://github.com/patrick-kidger/diffrax/pull/666#discussion_r2215868590 + num_steps = jnp.astype(jnp.ceil((t1 - t0) / eqxi.nextafter(dt0)), jnp.int32) + # Use `num_steps=-1` as a marker to indicate that `diffeqsolve(..., t1=...)` is + # infinite. + num_steps = jnp.where(jnp.isfinite(t1), num_steps, -1) t1_sim_or_dt0 = jnp.where(jnp.isfinite(t1), t1, dt0) return t0 + dt0, (steps_completed, num_steps, t0, t1_sim_or_dt0) @@ -78,25 +77,23 @@ def adapt_step_size( # Number of steps that will be completed when this function returns. 
steps_completed = steps_already_completed + 1 + # Calculate step size by calculating fraction of `t1 - t0` -- rather than just + # adding up `dt0` multiple times -- to avoid compounding of truncation/rounding + # errors. time_dtype = jnp.result_type(t0_sim, t1_sim_or_dt0) - - # Calculate step size by calculating fraction of `t1 - t0` to avoid compounding - # of truncation/rounding errors - t1_next = jnp.where( - num_steps >= 0, - jnp.where( - steps_completed == num_steps, - t1_sim_or_dt0, - t0_sim - + (t1_sim_or_dt0 - t0_sim) - * jnp.astype(steps_completed, time_dtype) - / jnp.astype(num_steps, time_dtype), - ), - # Special case for non-finite t1_sim - # in this t1_sim_or_dt0 is dt0 - t1 + t1_sim_or_dt0, + t1_next = t0_sim + (t1_sim_or_dt0 - t0_sim) * ( + jnp.astype(steps_completed, time_dtype) / jnp.astype(num_steps, time_dtype) ) + # If we're on the final step then use `t1` directly, this time to avoid + # floating-point weirdness in the above. (Not sure if necessary?) + t1_next = jnp.where(steps_completed == num_steps, t1_sim_or_dt0, t1_next) + + # If `num_steps == -1` then we use that as a marker to indicate that we have an + # infinite `diffeqsolve(..., t1=...)`. In this case then never mind everything + # above, we really do just want to keep adding on `dt0` multiple times. 
+ t1_next = jnp.where(num_steps >= 0, t1_next, t1 + t1_sim_or_dt0) + return ( True, t1, From 5d9f6b99af175ac3c0631b81a5af4d3de71b49ad Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 2 Aug 2025 20:01:04 +0100 Subject: [PATCH 11/24] Use 100 ULP's to clip timesteps close to t1 (#660) * Use 100 ULP's to clip timesteps close to t1 * test that t1-t0 > 100 ULP's * revert testing as t1 is traced * remove unnecessary pyright ignores --- diffrax/_integrate.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 8d1c66b4..bc319d40 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -263,15 +263,11 @@ def _save( ) -def _clip_to_end(tprev, tnext, t1, keep_step): - # The tolerance means that we don't end up with too-small intervals for - # dense output, which then gives numerically unstable answers due to floating +def _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step): + # The tolerance of ~100 ULP's means that we don't end up with too-small intervals + # for dense output, which then gives numerically unstable answers due to floating # point errors. 
- if tnext.dtype == jnp.dtype("float64"): - tol = 1e-10 - else: - tol = 1e-6 - clip = tnext > t1 - tol + clip = tnext > t1_clip_floor tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev)) return jnp.where(clip, tclip, tnext) @@ -308,6 +304,11 @@ def loop( outer_while_loop, progress_meter, ): + # Calculate in advance t1 - 100 ULP's: the threshold at which to round tnext to t1 + t1_clip_floor = t1 + for _ in range(100): + t1_clip_floor = eqxi.prevbefore(t1_clip_floor) + if saveat.dense: dense_ts = init_state.dense_ts dense_ts = dense_ts.at[0].set(t0) @@ -397,7 +398,7 @@ def body_fun_aux(state): # tprev = jnp.minimum(tprev, t1) - tnext = _clip_to_end(tprev, tnext, t1, keep_step) + tnext = _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step) progress_meter_state = progress_meter.step( state.progress_meter_state, linear_rescale(t0, tprev, t1) From 1ac005dde8862087a05c83b82cdcb1ad608b34da Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 13 Jul 2025 12:07:23 +0200 Subject: [PATCH 12/24] Tests that a jump at t1 is saved. --- test/test_adaptive_stepsize_controller.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 8161b9d8..21c24e4a 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -336,3 +336,28 @@ def test_implicit_solver_with_clip_controller(new: bool): max_steps=16384, saveat=diffrax.SaveAt(t1=True), ) + + +# https://github.com/patrick-kidger/diffrax/issues/663 +# `jump_ts` sets the time we step to as `prevbefore` the time provided. +# Clipping at t1 saves us! We need to clip at at least 1 ULP. 
+def test_jump_at_t1_with_large_t1_in_float32(): + t0 = jnp.array(0.0, dtype=jnp.float32) + t1 = jnp.array(1e3, dtype=jnp.float32) + dt0 = jnp.array(0.01, dtype=jnp.float32) + y0 = jnp.array(1, dtype=jnp.float32) + saveat = diffrax.SaveAt(ts=t1[None]) + ssc = diffrax.ClipStepSizeController( + diffrax.PIDController(atol=1e-6, rtol=1e-6), jump_ts=t1[None] + ) + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: -y), + diffrax.Heun(), + t0=t0, + t1=t1, + dt0=dt0, + y0=y0, + stepsize_controller=ssc, + saveat=saveat, + ) + assert sol.ts == jnp.array([t1]) From b7dc392f8d8748cdf3fcee9b962ca597de760be5 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Sun, 3 Aug 2025 17:54:38 -0400 Subject: [PATCH 13/24] adapt --- diffrax/_solver/align.py | 5 ++++- diffrax/_solver/spark.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index 806a8a8f..750469b6 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -15,6 +15,7 @@ UnderdampedLangevinTuple, UnderdampedLangevinX, ) +from .base import AbstractAdaptiveSolver from .foster_langevin_srk import ( AbstractCoeffs, AbstractFosterLangevinSRK, @@ -44,7 +45,9 @@ def __init__(self, beta, a1, b1, aa, chh): _ErrorEstimate = UnderdampedLangevinTuple -class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): +class ALIGN( + AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate], AbstractAdaptiveSolver +): r"""The Adaptive Langevin via Interpolated Gradients and Noise method designed by James Foster. 
This is a second order solver for the Underdamped Langevin Diffusion, and accepts terms of the form diff --git a/diffrax/_solver/spark.py b/diffrax/_solver/spark.py index bfaa0323..dda948b9 100644 --- a/diffrax/_solver/spark.py +++ b/diffrax/_solver/spark.py @@ -3,7 +3,7 @@ import equinox.internal as eqxi import numpy as np -from .base import AbstractStratonovichSolver +from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau @@ -35,7 +35,7 @@ ) -class SPaRK(AbstractSRK, AbstractStratonovichSolver): +class SPaRK(AbstractSRK, AbstractStratonovichSolver, AbstractAdaptiveSolver): r"""The Splitting Path Runge-Kutta method. It uses three evaluations of the drift and diffusion per step, and has the following From 8a72ee5cc662af11402100737b254e37002037dd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 31 Aug 2025 13:47:19 +0200 Subject: [PATCH 14/24] Fixes 681 --- diffrax/_integrate.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index bc319d40..1855d4ae 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -161,10 +161,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): pass elif n_term_args == 2: vf_type_expected, control_type_expected = term_args - try: - vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) - except Exception as e: - raise ValueError(f"Error while tracing {term}.vf: " + str(e)) + vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) vf_type_compatible = eqx.filter_eval_shape( better_isinstance, vf_type, vf_type_expected ) @@ -173,10 +170,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 - try: - control_type = eqx.filter_eval_shape(contr, t, t) - except Exception as e: - raise ValueError(f"Error while 
tracing {term}.contr: " + str(e)) + control_type = eqx.filter_eval_shape(contr, t, t) control_type_compatible = eqx.filter_eval_shape( better_isinstance, control_type, control_type_expected ) From 2fd3ef34ef4267425b28b4c1338803dd1ec22e9a Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:32 +0200 Subject: [PATCH 15/24] Added benchmarking FAQ --- docs/further_details/faq.md | 62 +++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md index fe30a402..79d44d6c 100644 --- a/docs/further_details/faq.md +++ b/docs/further_details/faq.md @@ -1,15 +1,65 @@ # FAQ -### Compilation is taking a long time. - -- Set `dt0=`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time. -- Prefer `SaveAt(t0=True, t1=True)` over `SaveAt(ts=[t0, t1])`, if possible. -- It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s), e.g. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. - ### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour. Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). +### Diffrax seem to be slower than ? + +Questions of this form are a fairly common source of issues in the Diffrax issue tracker! In practice, Diffrax is consistently amongst the fastest ODE solvers, and these usually stem from incorrect usage (e.g. recompiling your JAX program on each invocation) or comparisons (e.g. using different solvers/tolerances in each implementation). 
+ +Here's a list of some of the things to keep in mind when performing such comparisons: + +1. First of all, the usual list of JAX profiling concerns: + + a. Make sure that your JAX program is compiled only once, and not repeatedly on each invocation (for example by passing in different raw Python floats each time). Use [`equinox.debug.assert_max_traces(max_traces=1)`](https://docs.kidger.site/equinox/api/debug/#equinox.debug.assert_max_traces) to debug this. + + b. Your entire computation should be wrapped in a single `jax.jit`'d function (or equivalently `equinox.filter_jit`). + + c. Run this function in advance (to JIT-compile it), before running it again to measure its speed. + + d. Make sure not to include any code that is run outside of the JIT'd function in your timings. + + e. Make sure to call `jax.block_until_ready` on the output of the function. + + Typically your code should follow this template: + ```python + import equinox as eqx + import jax + import timeit + + @jax.jit + @eqx.debug.assert_max_traces(max_traces=1) + def run(x): + ... + + x = ... + run(x) # compile + execution_time = min(timeit.repeat(lambda: jax.block_until_ready(run(x)), number=1, repeat=20)) + ``` + +2. Use the same ODE solver in both implementations to get an apples-to-apples comparison. It's not surprising that different solvers give different performance characteristics. (And if one implementation does not provide a solver that the other does, then no comparison can be made.) + +3. Use the same step size control in both implementations. + + a. If using adaptive step sizes then note that tolerances (the `rtol`, `atol` in `diffeqsolve(..., stepsize_controller=PIDController(rtol=..., atol=...))`) have solver- and implementation-specific meanings, so having these be the same is not enough. Aim to have roughly the same number of steps instead. You can check the number of steps taken in Diffrax via `diffeqsolve(...).stats['num_steps']`. + + b. 
If using an automatic initial step size (`diffeqsolve(..., dt0=None)`) then use this (or disable this) in both implementations. + +4. If comparing to other JAX implementations, then make sure to set `import os; os.environ["EQX_ON_ERROR"] = "nan"` at the top of your script (before you import Diffrax or Equinox). This will disable various runtime correctness checks performed by Diffrax that are typically not performed by other JAX frameworks. These add a few milliseconds of overhead that typically does not matter in real-world usage but may be large enough to appear in microbenchmarks. + + a. If comparing to a loop-over-steps using `jax.lax.scan`, then the equivalent step size control in Diffrax is `diffeqsolve(..., stepsize_controller=StepTo(...))`. + +5. If you'd like to be really precise, then the best way to benchmark competing implementations is with a work-precision diagram: solve your ODE once with very tight tolerances and a very accurate solver (in any implementation). Then for each implementation: vary the tolerances or step sizes, and plot the time for the solve against the numerical difference between the solution and the very accurate solution. This isn't required but is the gold-standard for benchmark comparisons. + +6. Both implementations should use the same precision (`float32` vs `float64`). Note that JAX defaults to 32-bit precision and requires a flag to enable 64-bit precision. + +7. The problem being solved should be large enough (ideally at least 100 milliseconds to solve) that you are not simply measuring various small overheads in different frameworks. + +Take a look at [Diffrax issue #82](https://github.com/patrick-kidger/diffrax/issues/82) for a good example of how seemingly-reasonable benchmarks can hide a few pitfalls! + +If you think you have a performance issue – after checking all of the above! – then feel free to open an issue on the Diffrax issue page. 
You should include a code snippet that demonstrates the issue; typically this should not be more than around 50 lines long if we are going to be able to volunteer to help you debug it :-). + ### How does this compare to `jax.experimental.ode.odeint`? The equivalent solver in Diffrax is: From 43f82dc5aa6db586e96349c057890ff8604bc333 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 3 Oct 2025 14:21:36 +0200 Subject: [PATCH 16/24] Standardised infra --- .github/workflows/release.yml | 6 +++--- .github/workflows/run_tests.yml | 11 +++++++---- .gitignore | 1 + .pre-commit-config.yaml | 8 ++++---- CONTRIBUTING.md | 18 +++++------------- diffrax/_adjoint.py | 3 +-- diffrax/_brownian/path.py | 6 +++--- diffrax/_brownian/tree.py | 6 +++--- diffrax/_integrate.py | 6 +++++- diffrax/_solution.py | 10 +++++----- diffrax/_solver/dopri5.py | 6 +++--- diffrax/_solver/dopri8.py | 6 +++--- diffrax/_solver/euler.py | 6 +++--- diffrax/_solver/euler_heun.py | 6 +++--- diffrax/_solver/foster_langevin_srk.py | 6 +++--- diffrax/_solver/implicit_euler.py | 6 +++--- diffrax/_solver/kencarp3.py | 6 +++--- diffrax/_solver/kencarp4.py | 6 +++--- diffrax/_solver/kencarp5.py | 6 +++--- diffrax/_solver/leapfrog_midpoint.py | 6 +++--- diffrax/_solver/milstein.py | 12 ++++++------ diffrax/_solver/reversible_heun.py | 6 +++--- diffrax/_solver/runge_kutta.py | 7 ++++++- diffrax/_solver/semi_implicit_euler.py | 10 +++++----- diffrax/_solver/srk.py | 3 +-- diffrax/_solver/tsit5.py | 6 +++--- diffrax/_step_size_controller/clip.py | 3 +-- diffrax/_step_size_controller/pid.py | 2 +- diffrax/_term.py | 8 ++++---- diffrax/_typing.py | 2 +- docs/devdocs/srk_example.ipynb | 10 +++++----- pyproject.toml | 20 +++++++++++++++++--- test/helpers.py | 4 ++-- test/requirements.txt | 6 ------ test/test_brownian.py | 1 + test/test_saveat_solution.py | 4 ++++ test/test_solver.py | 6 +++--- test/test_underdamped_langevin.py | 12 ++++++------ 38 files changed, 
136 insertions(+), 121 deletions(-) delete mode 100644 test/requirements.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 861187e0..82cc5c2f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,14 +10,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Release - uses: patrick-kidger/action_update_python_project@v6 + uses: patrick-kidger/action_update_python_project@v8 with: python-version: "3.11" test-script: | cp -r ${{ github.workspace }}/test ./test cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml - python -m pip install -r ./test/requirements.txt - python -m test + uv sync --extra tests --no-install-project --inexact + uv run --no-sync pytest pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger github-token: ${{ github.token }} diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index b209bb3d..0137d757 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -23,13 +23,16 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r ./test/requirements.txt - + python -m pip install '.[dev,docs,tests]' - name: Checks with pre-commit - uses: pre-commit/action@v3.0.1 + run: | + pre-commit run --all-files - name: Test with pytest run: | - python -m pip install . python -m test + + - name: Check that documentation can be built. 
+ run: | + mkdocs build diff --git a/.gitignore b/.gitignore index daf54d8c..176aa2af 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ site/ .pymon .idea/ .venv/ +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f8df8ce..29b072c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,15 +8,15 @@ repos: files: ^pyproject\.toml$ additional_dependencies: ["toml-sort==0.23.1"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.2 + rev: v0.13.0 hooks: - id: ruff-format # formatter - types_or: [ python, pyi, jupyter ] + types_or: [ python, pyi, jupyter, toml ] - id: ruff # linter - types_or: [ python, pyi, jupyter ] + types_or: [ python, pyi, jupyter, toml ] args: [ --fix ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.350 + rev: v1.1.405 hooks: - id: pyright additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1c9b3ced..78f188a3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started. First fork the library on GitHub. -Then clone and install the library in development mode: +Then clone and install the library: ```bash git clone https://github.com/your-username-here/diffrax.git cd diffrax -pip install -e . +pip install -e '.[dev]' +pre-commit install # `pre-commit` is installed by `pip` on the previous line ``` -Then install the pre-commit hook: - -```bash -pip install pre-commit -pre-commit install -``` - -These hooks use ruff to lint and format the code, and pyright to type-check it. - --- **If you're making changes to the code:** @@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary. 
Next verify the tests all pass: ```bash -pip install -r test/requirements.txt -pytest +pip install -e '.[tests]' +pytest # `pytest` is installed by `pip` on the previous line. ``` Then push your changes back to your fork of the repository: diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 69a59b7a..7bc081b9 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -362,8 +362,7 @@ def loop( if is_unsafe_sde(terms): kind = "lax" msg = ( - "Cannot reverse-mode autodifferentiate when using " - "`UnsafeBrownianPath`." + "Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`." ) elif max_steps is None: kind = "lax" diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index f97eebf1..61f49644 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -62,9 +62,9 @@ class UnsafeBrownianPath(AbstractBrownianPath): """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea - ] = eqx.field(static=True) + levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = ( + eqx.field(static=True) + ) key: PRNGKeyArray def __init__( diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index fd0ede84..8a430668 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -235,9 +235,9 @@ class VirtualBrownianTree(AbstractBrownianPath): t1: RealScalarLike tol: RealScalarLike shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea - ] = eqx.field(static=True) + levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = ( + eqx.field(static=True) + ) key: PyTree[PRNGKeyArray] _spline: _Spline = eqx.field(static=True) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 1855d4ae..6fc38ce3 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -3,6 +3,7 @@ from 
collections.abc import Callable from typing import ( # noqa: UP035 Any, + cast, get_args, get_origin, Tuple, @@ -1164,7 +1165,10 @@ def _wrap(term): def _get_tols(x): outs = [] for attr in ("rtol", "atol", "norm"): - if getattr(solver.root_finder, attr) is use_stepsize_tol: + if ( + getattr(cast(AbstractImplicitSolver, solver).root_finder, attr) + is use_stepsize_tol + ): outs.append(getattr(x, attr)) return tuple(outs) diff --git a/diffrax/_solution.py b/diffrax/_solution.py index 3abe2725..392c447e 100644 --- a/diffrax/_solution.py +++ b/diffrax/_solution.py @@ -10,7 +10,7 @@ from ._path import AbstractPath -class RESULTS(optx.RESULTS): # pyright: ignore +class RESULTS(optx.RESULTS): # pyright: ignore[reportGeneralTypeIssues] successful = "" max_steps_reached = ( "The maximum number of solver steps was reached. Try increasing `max_steps`." @@ -121,8 +121,8 @@ class Solution(AbstractPath): # the structure of `subs`. # SaveAt(fn=...) means that `ys` will then follow with arbitrary sub-dependent # PyTree structures. - ts: PyTree[Real[Array, " ?times"], " S"] | None - ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None + ts: PyTree[Real[Array, " ?times"], " S"] | None # pyright: ignore[reportUndefinedVariable] + ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None # pyright: ignore interpolation: DenseInterpolation | None stats: dict[str, Any] result: RESULTS @@ -133,7 +133,7 @@ class Solution(AbstractPath): def evaluate( self, t0: RealScalarLike, t1: RealScalarLike | None = None, left: bool = True - ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: + ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable] """If dense output was saved, then evaluate the solution at any point in the region of integration `self.t0` to `self.t1`. 
@@ -153,7 +153,7 @@ def evaluate( def derivative( self, t: RealScalarLike, left: bool = True - ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: + ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable] r"""If dense output was saved, then calculate an **approximation** to the derivative of the solution at any point in the region of integration `self.t0` to `self.t1`. diff --git a/diffrax/_solver/dopri5.py b/diffrax/_solver/dopri5.py index 4a3cedfe..325f717f 100644 --- a/diffrax/_solver/dopri5.py +++ b/diffrax/_solver/dopri5.py @@ -91,9 +91,9 @@ class Dopri5(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _dopri5_tableau - interpolation_cls: ClassVar[ - Callable[..., _Dopri5Interpolation] - ] = _Dopri5Interpolation + interpolation_cls: ClassVar[Callable[..., _Dopri5Interpolation]] = ( + _Dopri5Interpolation + ) def order(self, terms): del terms diff --git a/diffrax/_solver/dopri8.py b/diffrax/_solver/dopri8.py index 4801eccc..958d8819 100644 --- a/diffrax/_solver/dopri8.py +++ b/diffrax/_solver/dopri8.py @@ -340,9 +340,9 @@ class Dopri8(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _dopri8_tableau - interpolation_cls: ClassVar[ - Callable[..., _Dopri8Interpolation] - ] = _Dopri8Interpolation + interpolation_cls: ClassVar[Callable[..., _Dopri8Interpolation]] = ( + _Dopri8Interpolation + ) def order(self, terms): del terms diff --git a/diffrax/_solver/euler.py b/diffrax/_solver/euler.py index 7ed11381..b1a323f7 100644 --- a/diffrax/_solver/euler.py +++ b/diffrax/_solver/euler.py @@ -24,9 +24,9 @@ class Euler(AbstractItoSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/euler_heun.py b/diffrax/_solver/euler_heun.py index 4940cfa5..97fbaee2 100644 --- 
a/diffrax/_solver/euler_heun.py +++ b/diffrax/_solver/euler_heun.py @@ -29,9 +29,9 @@ class EulerHeun(AbstractStratonovichSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index 717477d5..ea1026fa 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -231,9 +231,9 @@ def _choose(tay_leaf, direct_leaf): if inner is sentinel: inner = jtu.tree_structure(out) else: - assert ( - jtu.tree_structure(out) == inner - ), f"Expected {inner}, got {jtu.tree_structure(out)}" + assert jtu.tree_structure(out) == inner, ( + f"Expected {inner}, got {jtu.tree_structure(out)}" + ) return out diff --git a/diffrax/_solver/implicit_euler.py b/diffrax/_solver/implicit_euler.py index 064209da..68477c78 100644 --- a/diffrax/_solver/implicit_euler.py +++ b/diffrax/_solver/implicit_euler.py @@ -35,9 +35,9 @@ class ImplicitEuler(AbstractImplicitSolver, AbstractAdaptiveSolver): # # We don't use it as this seems to be quite a bad choice for low-order solvers: it # produces very oscillatory interpolations. 
- interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp3.py b/diffrax/_solver/kencarp3.py index f15c7c03..cda04489 100644 --- a/diffrax/_solver/kencarp3.py +++ b/diffrax/_solver/kencarp3.py @@ -163,9 +163,9 @@ class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp3Interpolation] - ] = _KenCarp3Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp3Interpolation]] = ( + _KenCarp3Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp4.py b/diffrax/_solver/kencarp4.py index 4764900f..752cf38c 100644 --- a/diffrax/_solver/kencarp4.py +++ b/diffrax/_solver/kencarp4.py @@ -166,9 +166,9 @@ class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp4Interpolation] - ] = _KenCarp4Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp4Interpolation]] = ( + _KenCarp4Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp5.py b/diffrax/_solver/kencarp5.py index ba9af78c..b5b0f213 100644 --- a/diffrax/_solver/kencarp5.py +++ b/diffrax/_solver/kencarp5.py @@ -233,9 +233,9 @@ class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) 
calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp5Interpolation] - ] = _KenCarp5Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp5Interpolation]] = ( + _KenCarp5Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 76487dfd..ddcaa12e 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -44,9 +44,9 @@ class LeapfrogMidpoint(AbstractSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 2 diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 69893716..897bc9c1 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -44,9 +44,9 @@ class StratonovichMilstein(AbstractStratonovichSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): raise ValueError("`StratonovichMilstein` should not be used to solve ODEs.") @@ -123,9 +123,9 @@ class ItoMilstein(AbstractItoSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): raise 
ValueError("`ItoMilstein` should not be used to solve ODEs.") diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 9288d00d..91617d4f 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -36,9 +36,9 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation # TODO use something better than this? + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation # TODO use something better than this? + ) def order(self, terms): return 2 diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index f46d2045..9473ab44 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -358,6 +358,11 @@ class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): tableau: AbstractClassVar[ButcherTableau | MultiButcherTableau] calculate_jacobian: AbstractClassVar[CalculateJacobian] + if TYPE_CHECKING: + # Pretend that we're implicit + root_finder: ClassVar[optx.AbstractRootFinder] + root_find_max_steps: ClassVar[int] + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau @@ -804,7 +809,7 @@ def embed_c(tab): ) implicit_predictor = np.zeros( (num_stages, num_stages), - dtype=np.result_type(*implicit_tableau.a_predictor), + dtype=np.result_type(*cast(tuple, implicit_tableau.a_predictor)), ) for i, a_predictor_i in enumerate(implicit_tableau.a_predictor): # pyright: ignore implicit_predictor[i + 1, : i + 1] = a_predictor_i diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 9e8a92ed..34122fe3 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -14,8 +14,8 @@ _ErrorEstimate: 
TypeAlias = None _SolverState: TypeAlias = None -Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] -Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] +Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] # pyright: ignore[reportUndefinedVariable] +Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] # pyright: ignore[reportUndefinedVariable] class SemiImplicitEuler(AbstractSolver): @@ -26,9 +26,9 @@ class SemiImplicitEuler(AbstractSolver): """ term_structure: ClassVar = (AbstractTerm, AbstractTerm) - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 1af9579f..ddb03095 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -54,8 +54,7 @@ class AbstractStochasticCoeffs(eqx.Module): b_error: eqx.AbstractVar[Float[np.ndarray, " s"] | None] @abc.abstractmethod - def check(self) -> int: - ... + def check(self) -> int: ... 
class AdditiveCoeffs(AbstractStochasticCoeffs): diff --git a/diffrax/_solver/tsit5.py b/diffrax/_solver/tsit5.py index 3060088a..7dc7a14f 100644 --- a/diffrax/_solver/tsit5.py +++ b/diffrax/_solver/tsit5.py @@ -181,9 +181,9 @@ class Tsit5(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _tsit5_tableau - interpolation_cls: ClassVar[ - Callable[..., _Tsit5Interpolation] - ] = _Tsit5Interpolation + interpolation_cls: ClassVar[Callable[..., _Tsit5Interpolation]] = ( + _Tsit5Interpolation + ) def order(self, terms): return 5 diff --git a/diffrax/_step_size_controller/clip.py b/diffrax/_step_size_controller/clip.py index c8a5cb76..0a642d6e 100644 --- a/diffrax/_step_size_controller/clip.py +++ b/diffrax/_step_size_controller/clip.py @@ -220,8 +220,7 @@ def __init__( self.jump_ts = _none_or_sorted_array(jump_ts) if (store_rejected_steps is not None) and (store_rejected_steps <= 0): raise ValueError( - "`store_rejected_steps must either be `None`" - " or a non-negative integer." + "`store_rejected_steps must either be `None` or a non-negative integer." ) self.store_rejected_steps = store_rejected_steps self.callback_on_reject = _callback_on_reject diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 7fb034fb..1092f184 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -89,7 +89,7 @@ def intermediate(carry): # PIDController(... step_ts=s, jump_ts=j) this should return a # ClipStepSizeController(PIDController(...), s, j). 
class _MetaPID(type(eqx.Module)): - def __call__(cls, *args, **kwargs): # pyright: ignore[reportSelfClsParameterName] + def __call__(cls, *args, **kwargs): step_ts = kwargs.pop("step_ts", None) jump_ts = kwargs.pop("jump_ts", None) if step_ts is not None or jump_ts is not None: diff --git a/diffrax/_term.py b/diffrax/_term.py index 41f7af09..9b0c0314 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -3,7 +3,7 @@ import typing import warnings from collections.abc import Callable -from typing import cast, Generic, TypeAlias, TypeVar +from typing import Any, cast, Generic, TypeAlias, TypeVar import equinox as eqx import jax @@ -12,7 +12,7 @@ import lineax as lx import numpy as np from equinox.internal import ω -from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef, Shaped +from jaxtyping import Array, ArrayLike, PyTree, Shaped from ._brownian import AbstractBrownianPath from ._custom_types import ( @@ -835,7 +835,7 @@ def _fn(_control): jac = make_jac(_fn)(control) assert vf_prod_tree is not sentinel - vf_prod_tree = cast(PyTreeDef, vf_prod_tree) + vf_prod_tree = cast(Any, vf_prod_tree) if jtu.tree_structure(None) in (vf_prod_tree, control_tree): # An unusual/not-useful edge case to handle. raise NotImplementedError( @@ -868,7 +868,7 @@ def _get_vf_tree(_, tree): jtu.tree_map(_get_vf_tree, control, vf) assert vf_prod_tree is not sentinel - vf_prod_tree = cast(PyTreeDef, vf_prod_tree) + vf_prod_tree = cast(Any, vf_prod_tree) vf = jtu.tree_transpose(control_tree, vf_prod_tree, vf) diff --git a/diffrax/_typing.py b/diffrax/_typing.py index 74627847..90a0b7b8 100644 --- a/diffrax/_typing.py +++ b/diffrax/_typing.py @@ -184,7 +184,7 @@ def _get_args_of_impl( if len(params) == 0: error_cls = cls else: - error_cls = cls[params] + error_cls = cls[params] # pyright: ignore[reportIndexIssue] raise TypeError( f"{error_cls} inherits from {base_cls} in multiple incompatible ways." 
) diff --git a/docs/devdocs/srk_example.ipynb b/docs/devdocs/srk_example.ipynb index 39364def..319beee2 100644 --- a/docs/devdocs/srk_example.ipynb +++ b/docs/devdocs/srk_example.ipynb @@ -55,11 +55,6 @@ "source": [ "%env JAX_PLATFORM_NAME=cuda\n", "\n", - "from test.helpers import (\n", - " get_mlp_sde,\n", - " get_time_sde,\n", - " simple_sde_order,\n", - ")\n", "from warnings import simplefilter\n", "\n", "import diffrax\n", @@ -76,6 +71,11 @@ " SRA1,\n", ")\n", "from jax import config\n", + "from test.helpers import (\n", + " get_mlp_sde,\n", + " get_time_sde,\n", + " simple_sde_order,\n", + ")\n", "\n", "\n", "simplefilter(\"ignore\", category=FutureWarning)\n", diff --git a/pyproject.toml b/pyproject.toml index e7326e29..c7ec6d27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,15 @@ keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning license = {file = "LICENSE"} name = "diffrax" readme = "README.md" -requires-python = ">=3.10,<4.0" +requires-python = ">=3.10" urls = {repository = "https://github.com/patrick-kidger/diffrax"} version = "0.7.0" [project.optional-dependencies] +dev = ["pre-commit"] docs = [ "hippogriffe==0.2.2", + "griffe==1.7.3", "mkdocs==1.6.1", "mkdocs-include-exclude-files==0.1.0", "mkdocs-ipynb==0.1.1", @@ -40,6 +42,14 @@ docs = [ "mkdocstrings-python==1.16.8", "pymdown-extensions==10.14.3" ] +tests = [ + "beartype", + "jaxlib", + "optax", + "pytest", + "scipy", + "tqdm" +] [tool.hatch.build] include = ["diffrax/*"] @@ -60,10 +70,14 @@ src = [] [tool.ruff.lint] fixable = ["I001", "F401", "UP"] -ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] -ignore-init-module-imports = true +ignore = ["E402", "E721", "E731", "E741", "F722"] select = ["E", "F", "I001", "UP"] +[tool.ruff.lint.flake8-import-conventions.extend-aliases] +"collections" = "co" +"functools" = "ft" +"itertools" = "it" + [tool.ruff.lint.isort] combine-as-imports = true extra-standard-library = ["typing_extensions"] diff --git 
a/test/helpers.py b/test/helpers.py index b6311065..97b0f074 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -95,8 +95,8 @@ def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False): def path_l2_dist( - ys1: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], - ys2: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], + ys1: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], # pyright: ignore[reportUndefinedVariable] + ys2: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], # pyright: ignore[reportUndefinedVariable] ): # first compute the square of the difference and sum over # all but the first two axes (which represent the number of samples diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index 9de88eb6..00000000 --- a/test/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -beartype -jaxlib -optax -pytest -scipy -tqdm diff --git a/test/test_brownian.py b/test/test_brownian.py index 361c761d..d33bbda8 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -173,6 +173,7 @@ def _eval(key): else: w = values + assert isinstance(w, jax.Array) assert w.shape == (num_samples,) ref_dist = stats.norm(loc=0, scale=math.sqrt(dt)) _, pval = stats.kstest(w, ref_dist.cdf) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 3db4e7ba..7b8b9dfc 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -1,5 +1,6 @@ import contextlib import math +from typing import cast import diffrax import equinox as eqx @@ -119,6 +120,7 @@ def test_saveat_solution(): assert sol.ts.shape == (4096,) # pyright: ignore assert sol.ys.shape == (4096, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) @@ -140,6 +142,7 @@ def test_saveat_solution(): assert sol.ts.shape == (n,) # pyright: ignore 
assert sol.ys.shape == (n, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) @@ -161,6 +164,7 @@ def test_saveat_solution(): assert sol.ts.shape == (n,) # pyright: ignore assert sol.ys.shape == (n, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) diff --git a/test/test_solver.py b/test/test_solver.py index aa618712..a022f644 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -58,9 +58,9 @@ class _DoubleDopri5(diffrax.AbstractRungeKutta): tableau: ClassVar[diffrax.MultiButcherTableau] = diffrax.MultiButcherTableau( diffrax.Dopri5.tableau, diffrax.Dopri5.tableau ) - calculate_jacobian: ClassVar[ - diffrax.CalculateJacobian - ] = diffrax.CalculateJacobian.never + calculate_jacobian: ClassVar[diffrax.CalculateJacobian] = ( + diffrax.CalculateJacobian.never + ) @staticmethod def interpolation_cls(**kwargs): diff --git a/test/test_underdamped_langevin.py b/test/test_underdamped_langevin.py index 246506bb..53140fba 100644 --- a/test/test_underdamped_langevin.py +++ b/test/test_underdamped_langevin.py @@ -100,9 +100,9 @@ def test_shape(solver, dtype): # check that the output has the correct pytree structure and shape def check_shape(y0_leaf, sol_leaf): - assert ( - sol_leaf.shape == (7,) + y0_leaf.shape - ), f"shape={sol_leaf.shape}, expected={(7,) + y0_leaf.shape}" + assert sol_leaf.shape == (7,) + y0_leaf.shape, ( + f"shape={sol_leaf.shape}, expected={(7,) + y0_leaf.shape}" + ) assert sol_leaf.dtype == dtype, f"dtype={sol_leaf.dtype}, expected={dtype}" jtu.tree_map(check_shape, sde.y0, sol.ys) @@ -193,9 +193,9 @@ def get_dt_and_controller(level): ref_solution=true_sol, ) - assert ( - -0.2 < 
order - theoretical_order < 0.25 - ), f"order={order}, theoretical_order={theoretical_order}" + assert -0.2 < order - theoretical_order < 0.25, ( + f"order={order}, theoretical_order={theoretical_order}" + ) @pytest.mark.parametrize("solver_cls", _only_uld_solvers_cls()) From 6694c864ee59924305cc08a5a1d84bc7de57ed10 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:29:02 +0200 Subject: [PATCH 17/24] Fix failing test --- test/test_sde1.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_sde1.py b/test/test_sde1.py index b50d014f..ad7318e6 100644 --- a/test/test_sde1.py +++ b/test/test_sde1.py @@ -157,7 +157,8 @@ def test_sde_strong_limit( ts_coarse = jnp.linspace(t0, t1, 2**level_coarse + 1, endpoint=True) contr_fine = diffrax.StepTo(ts=ts_fine) contr_coarse = diffrax.StepTo(ts=ts_coarse) - save_ts = jnp.linspace(t0, t1, 2**5 + 1, endpoint=True) + save_ts = ts_coarse[:: 2 ** (level_coarse - 5)] + assert len(save_ts) == 2**5 + 1 assert len(jnp.intersect1d(ts_fine, save_ts)) == len(save_ts) assert len(jnp.intersect1d(ts_coarse, save_ts)) == len(save_ts) saveat = diffrax.SaveAt(ts=save_ts) From 02d6b8a0a1c838a6d3d109418eec19aaac07e57e Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Wed, 5 Nov 2025 11:46:58 +0000 Subject: [PATCH 18/24] fix float error in prev_dt step calculation that led to an infinite loop When: dt is clipped to dtmin, and we wish to continue solver (force_dtmin=True) Calculating if a step should be kept from: prev_dt = t1 - t0 (next_t1 = next_t0 + dt (in previous step)) keep_step = keep_step | (prev_dt <= self.dtmin) can result in float error for high t0 where prev_dt is never <= self.dtmin, and further steps are never accepted -> infinite loop Fix: add a keep_next_step: bool flag to controller_state, and track when we are, and continue to be, at dtmin --- diffrax/_step_size_controller/pid.py | 15 ++++++++++----- 1 file changed, 10 
insertions(+), 5 deletions(-) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 1092f184..56c19f61 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,8 +81,8 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error) -_PidState = tuple[RealScalarLike, RealScalarLike] +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] # We use a metaclass for backwards compatibility. When a user calls @@ -388,6 +388,7 @@ def init( return t1, ( jnp.array(1.0, dtype=real_dtype), jnp.array(1.0, dtype=real_dtype), + False, ) def adapt_step_size( @@ -469,6 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, + keep_next_step, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -489,9 +491,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 - # Automatically keep the step if we're at dtmin. + # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | (prev_dt <= self.dtmin) + keep_step = keep_step | keep_next_step # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. 
inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -545,6 +547,9 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) + # flag next step to be kept if dtmin is reached + # or if it was reached previously and dt is unchanged + keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -554,7 +559,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error + controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result From b91138f5f54880537ff036a612da97b5357330f1 Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Thu, 6 Nov 2025 09:08:57 +0000 Subject: [PATCH 19/24] avoids accumulation of float precision errors in dt this solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1) if we do not reset dt then the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors with potential to drift away from the desired dtmin until a step that warrants a relaxation of step size (factor>1) these errors are likely to be minor, but I believe this is the intended behaviour --- diffrax/_step_size_controller/pid.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 56c19f61..c0ccfbec 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,7 +81,7 @@ def intermediate(carry): return 
jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin) _PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] @@ -470,7 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, - keep_next_step, + at_dtmin, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -493,7 +493,7 @@ def _scale(_y0, _y1_candidate, _y_error): keep_step = scaled_error < 1 # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | keep_next_step + keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -547,9 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) - # flag next step to be kept if dtmin is reached - # or if it was reached previously and dt is unchanged - keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) + # if we are already at dtmin and dt is unchanged (factor == 1), + # reset dt to dtmin to avoid accumulating float precision errors + dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt) + # this flags the next loop to accept step + at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -559,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step + controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, 
next_t1, False, controller_state, result From 62bf87692e3b1a6e9b34b2c4398424e44d2cc8c7 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 22 Dec 2025 17:35:15 +0100 Subject: [PATCH 20/24] Fixed case in which t0 is prevbefore a jump time --- .github/workflows/run_tests.yml | 2 +- benchmarks/against_scan.py | 2 +- diffrax/_step_size_controller/clip.py | 20 ++++++++++++++++- diffrax/_step_size_controller/pid.py | 2 +- pyproject.toml | 2 +- test/test_adaptive_stepsize_controller.py | 27 +++++++++++++++++++++++ 6 files changed, 50 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0137d757..5268212b 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ "3.10", "3.12" ] + python-version: [ "3.11", "3.13" ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/benchmarks/against_scan.py b/benchmarks/against_scan.py index 6c67655f..a7771c4e 100644 --- a/benchmarks/against_scan.py +++ b/benchmarks/against_scan.py @@ -36,7 +36,7 @@ def speedtest(fn, name): # INTEGRATE WITH scan -@jax.checkpoint # pyright: ignore +@jax.checkpoint def body(carry, t): u, v, dt = carry u = u + du(t, v, None) * dt diff --git a/diffrax/_step_size_controller/clip.py b/diffrax/_step_size_controller/clip.py index 0a642d6e..cac6ec1a 100644 --- a/diffrax/_step_size_controller/clip.py +++ b/diffrax/_step_size_controller/clip.py @@ -356,7 +356,7 @@ def callback(_keep_step, _t1): step_info = None else: step_index, step_ts = controller_state.step_info - # We actaully bump `next_t0` past any `step_ts` whilst checking where to + # We actually bump `next_t0` past any `step_ts` whilst checking where to # clip `next_t1`. 
This is in case we have a set up like the following: # ```python # ClipStepSizeController( @@ -376,6 +376,24 @@ def callback(_keep_step, _t1): else: jump_index, jump_ts = controller_state.jump_info next_t0, made_jump2 = _bump_next_t0(next_t0, jump_ts) + # This next line is to fix + # https://github.com/patrick-kidger/diffrax/issues/713 + # TODO: should we add this to the `step_ts` branch as well? + # + # What's going on here is that we may have + # the `diffeqsolve(t0=...)` be prevbefore a jump time (for example due to a + # previous diffeqsolve targeting that time), in which case during `.init` + # we will obtain `t0 = t1 = prevbefore(jump_time)`. + # The `_bump_next_t0` will then move `next_t0` to after the `jump_time`... + # whilst leaving `next_t1` unchanged! We actually end up `next_t1 < next_t0` + # which is very not okay. + # + # The fix is to ensure that `next_t1` is itself bumped to at least this + # value. As a final detail, we need to make it `nextafter` so that we don't + # have a zero-length interval – in this case an underlying PID controller + # would just never change the interval size at all, since it acts + # multiplicatively. (And even just 1 ULP is enough to unstick it.) 
+ next_t1 = jnp.maximum(eqxi.nextafter(next_t0), next_t1) made_jump = made_jump | made_jump2 jump_index = _find_idx_with_hint(next_t0, jump_ts, jump_index) next_t1 = _clip_t(next_t1, jump_index, jump_ts, True) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index c0ccfbec..752b9bed 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -93,7 +93,7 @@ def __call__(cls, *args, **kwargs): step_ts = kwargs.pop("step_ts", None) jump_ts = kwargs.pop("jump_ts", None) if step_ts is not None or jump_ts is not None: - return ClipStepSizeController(cls(*args, **kwargs), step_ts, jump_ts) + return ClipStepSizeController(cls(*args, **kwargs), step_ts, jump_ts) # pyright: ignore return super().__call__(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index c7ec6d27..459deb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning license = {file = "LICENSE"} name = "diffrax" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" urls = {repository = "https://github.com/patrick-kidger/diffrax"} version = "0.7.0" diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 21c24e4a..d785d16e 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import optimistix as optx import pytest from diffrax._step_size_controller.clip import _find_idx_with_hint from jaxtyping import Array @@ -361,3 +362,29 @@ def test_jump_at_t1_with_large_t1_in_float32(): saveat=saveat, ) assert sol.ts == jnp.array([t1]) + + +# https://github.com/patrick-kidger/diffrax/issues/713 +def test_t0_at_jump_time(): + jump_time = 0.98 + controller = diffrax.PIDController(rtol=1e-6, atol=1e-6) + controller = 
diffrax.ClipStepSizeController(controller, jump_ts=[jump_time]) + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)), + diffrax.Heun(), + t0=eqxi.prevbefore(jnp.asarray(jump_time)), + t1=1.2, + dt0=None, + y0=jnp.array([0, 0, 0, 0.0]), + stepsize_controller=controller, + event=diffrax.Event( + cond_fn=lambda t, y, args, **kw: jump_time - t, + root_finder=optx.Newton(atol=1e-4, rtol=1e-4), + direction=True, + ), + max_steps=100, + ) + # And in particular not an event. + # What used to happen was something very weird where we'd oscillate across the + # jump time. + assert sol.result == diffrax.RESULTS.successful From fdfecc7a791c874dc627a520d1bedd8394969c72 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:31:34 +0100 Subject: [PATCH 21/24] Fix #720; bool event + root find + terminate on first step --- diffrax/_integrate.py | 11 ++++++++--- test/test_event.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6fc38ce3..3017e30d 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -761,7 +761,11 @@ def _call_real_impl(): _tfinal = _event_root_find.value # TODO: we might need to change the way we evaluate `_yfinal` in order to # get more accurate derivatives? - _yfinal = _interpolator.evaluate(_tfinal) + _yfinal = lax.cond( + final_state.num_steps == 0, + lambda: final_state.y, + lambda: _interpolator.evaluate(_tfinal), + ) _result = RESULTS.where( _event_root_find.result == optx.RESULTS.successful, result, @@ -1323,7 +1327,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: event_mask = None else: event_tprev = tprev - event_tnext = tnext + event_tnext = tprev # Fill the dense-info with dummy values on the first step, when we haven't yet # made any steps. # Note that we're threading a needle here! 
What if we terminate on the very @@ -1334,8 +1338,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: # to the end of the interval). # - A floating event can't terminate on the first step (it requires a sign # change). + # c.f. https://github.com/patrick-kidger/diffrax/issues/720 event_dense_info = jtu.tree_map( - lambda x: jnp.empty(x.shape, x.dtype), + lambda x: jnp.zeros(x.shape, x.dtype), dense_info_struct, # pyright: ignore[reportPossiblyUnboundVariable] ) diff --git a/test/test_event.py b/test/test_event.py index 5de0ded6..6a6b97e1 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -833,3 +833,30 @@ def run(event): t_final, y_final = run(event) assert jnp.allclose(t_final, 10.0) assert jnp.allclose(y_final, jnp.array([10.0, 11.0])) + + +# https://github.com/patrick-kidger/diffrax/issues/720 +def test_boolean_with_root_find_terminating_on_first_step(): + controller = diffrax.PIDController(rtol=1e-6, atol=1e-6) + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) + root_finder = optx.Newton(atol=1e-4, rtol=1e-4) + + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)), + diffrax.Kvaerno5(), + t0=0.0, + t1=1.2, + dt0=None, + y0=jnp.array([10.0]), + stepsize_controller=controller, + event=diffrax.Event( + cond_fn=steady_state_event, + root_finder=root_finder, + ), + saveat=diffrax.SaveAt(t1=True), + max_steps=100, + ) + assert sol.ts is not None + assert sol.ys is not None + assert jnp.allclose(sol.ts, jnp.array([0.0])) + assert jnp.allclose(sol.ys, jnp.array([[10.0]])) From 550c20272360e6bdec1be80b997135357d4dd000 Mon Sep 17 00:00:00 2001 From: andyElking Date: Mon, 29 Jul 2024 17:24:42 +0100 Subject: [PATCH 22/24] Added Advanced SDE example and a table of SRKs --- docs/api/solvers/sde_solvers.md | 5 +- docs/devdocs/SDE_solver_table.md | 46 +++ docs/examples/sde_example.ipynb | 578 +++++++++++++++++++++++++++++++ mkdocs.yml | 2 + 4 files changed, 630 insertions(+), 1 deletion(-) create mode 
100644 docs/devdocs/SDE_solver_table.md create mode 100644 docs/examples/sde_example.ipynb diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index e307c195..9ec5b025 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -1,6 +1,9 @@ # SDE solvers -See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochastic-differential-equations). +See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochastic-differential-equations) +and [Advanced SDE example](../../examples/sde_example.ipynb) which gives a walkthrough of how to simulate SDEs +and how to perform optimisation with respect to SDE parameters. +For a table of all SDE solvers and their properties see [SDE solver table](../../devdocs/SDE_solver_table.md). !!! info "Term structure" diff --git a/docs/devdocs/SDE_solver_table.md b/docs/devdocs/SDE_solver_table.md new file mode 100644 index 00000000..d1d390eb --- /dev/null +++ b/docs/devdocs/SDE_solver_table.md @@ -0,0 +1,46 @@ +# SDE solver table + +For an explanation of the terms in the table, see [how to choose a solver](../usage/how-to-choose-a-solver.md#stochastic-differential-equations). + +This table is included as a reference that we think *should* be correct, but if you're going to use any of this information in a load-bearing way (for example a publication) then please double-check the information! It's a big table, we might have gotten something wrong. 
+ +``` ++----------------+-------+------------+------------------------------------+-------------------+----------------+------------------------------------------+ +| | SDE | Lévy | Strong/weak order (per noise type) | VF evaluations | Embedded error | Recommended for | +| | type | area +----------+--------------+----------+-------+-----------+ estimation | (and other notes) | +| | | | General | Commutative | Additive | Drift | Diffusion | | | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Euler | Itô | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 1 | No | Itô SDEs when a cheap solver is needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Heun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Standard solver for Stratonovich SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| EulerHeun | Strat | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 2 | No | Stratonovich SDEs with expensive drift. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ItoMilstein | Itô | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | No | Better than Euler for Itô SDEs, but | +| | | | | | | | | | computes the derivative of diffusion VF. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Stratonovich | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | No | For commutative Stratonovich SDEs when | +| Milstein | | | | | | | | | space-time Lévy area is not available. 
| +| | | | | | | | | | Computes derivative of diffusion VF. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ReversibleHeun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | When a reversible solver is needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Midpoint | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Ralston | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ShARK | Strat | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Additive noise SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SRA1 | Strat | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Only slightly worse than ShARK. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SEA | Strat | space-time | / | / | 1.0/1.0 | 1 | 1 | No | Cheap solver for additive noise SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SPaRK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 3 | 3 | Yes | General SDEs when embedded error | +| | | | | | | | | | estimation is needed. 
| ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| GeneralShARK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 2 | 3 | No | General SDEs when embedded error | +| | | | | | | | | | estimaiton is not needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SlowRK | Strat | space-time | 0.5/1.0 | 1.5/2.0 | 1.5/2.0 | 2 | 5 | No | Commutative noise SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +``` \ No newline at end of file diff --git a/docs/examples/sde_example.ipynb b/docs/examples/sde_example.ipynb new file mode 100644 index 00000000..c00e56b2 --- /dev/null +++ b/docs/examples/sde_example.ipynb @@ -0,0 +1,578 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:28.847571Z", + "start_time": "2025-10-09T01:03:28.174400Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: JAX_PLATFORM_NAME=cuda\n" + ] + } + ], + "source": [ + "%env JAX_PLATFORM_NAME=cuda\n", + "\n", + "from warnings import simplefilter\n", + "\n", + "\n", + "simplefilter(\"ignore\", category=FutureWarning)\n", + "\n", + "from functools import partial\n", + "\n", + "import diffrax\n", + "import equinox as eqx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "import optax\n", + "from jaxtyping import Array\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "jnp.set_printoptions(precision=4, suppress=True)" + ] + }, + { + "cell_type": "markdown", + "id": "86d4e8b062a81d7e", + "metadata": {}, + 
"source": [ + "# Advanced SDE example\n", + "\n", + "We will be simulating a Stratonovich SDE of the form:\n", + "\n", + "$$\n", + " dy(t) = f(y(t), t) dt + g(y(t), t) \\circ dw(t), \n", + "$$\n", + "\n", + "where $t \\in [0, T]$, $y(t) \\in \\mathbb{R}^e$, and $w$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and to $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ as the diffusion matrix field. The Stratonovich integral is denoted by $\\circ$.\n", + "\n", + "Our SDE will have the following drift and diffusion terms:\n", + "\n", + "\\begin{align*}\n", + " f(y(t), t) &= \\alpha - \\beta y(t), \\\\\n", + " g(y(t), t) &= \\operatorname{diag}(\\gamma) \\begin{bmatrix} \\Vert y(t) \\Vert_2 & 0 \\\\ 0 & 3 y_1(t) \\\\ 0 & 20t \\end{bmatrix},\n", + "\\end{align*}\n", + "\n", + "where $\\alpha, \\gamma \\in \\mathbb{R}^3$ and $\\beta \\in \\mathbb{R}_{\\geq 0}$ are some parameters.\n", + "\n", + "Let's write the SDE in the form that Diffrax expects:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ba23e9cc0370fbac", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:29.331509Z", + "start_time": "2025-10-09T01:03:28.853859Z" + } + }, + "outputs": [], + "source": [ + "# Drift VF (e = 3)\n", + "def f(t, y, args):\n", + " alpha, beta, gamma = args\n", + " beta = jnp.abs(beta)\n", + " assert alpha.shape == (3,)\n", + " return jnp.array(alpha - beta * y, dtype=y.dtype)\n", + "\n", + "\n", + "# Diffusion matrix field (e = 3, d = 2)\n", + "def g(t, y, args):\n", + " alpha, beta, gamma = args\n", + " assert gamma.shape == y.shape == (3,)\n", + " gamma = jnp.reshape(gamma, (3, 1))\n", + " out = gamma * jnp.array(\n", + " [[jnp.sqrt(jnp.sum(y**2)), 0.0], [0.0, 3 * y[0]], [0.0, 20 * t]], dtype=y.dtype\n", + " )\n", + " return out\n", + "\n", + "\n", + "# Initial condition\n", + "y0 = jnp.array([1.0, 1.0, 1.0])\n", + "\n", + "# 
Args\n", + "alpha = 0.5 * 
jnp.ones((3,))\n", + "beta = 1.0\n", + "gamma = jnp.ones((3,))\n", + "args = (alpha, beta, gamma)\n", + "\n", + "# Time domain\n", + "t0 = 0.0\n", + "t1 = 2.0\n", + "dt0 = 2**-9" + ] + }, + { + "cell_type": "markdown", + "id": "ef2ff90865907b7d", + "metadata": {}, + "source": [ + "## Brownian motion and its Levy area\n", + "\n", + "Different solvers require different information about the Brownian motion. For example, the `SPaRK` solver requires access to the space-time Levy area of the Brownian motion. The required Levy area for each solver is documented in the table at the end of this notebook, or can be checked via `solver.minimal_levy_area`.\n", + " \n", + "We will use the `VirtualBrownianTree` class to generate the Brownian motion and its Levy area." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4110735158215acc", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:29.483519Z", + "start_time": "2025-10-09T01:03:29.337297Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Minimal levy area for SPaRK: .\n" + ] + } + ], + "source": [ + "# check minimal levy area\n", + "solver = diffrax.SPaRK()\n", + "print(f\"Minimal levy area for SPaRK: {solver.minimal_levy_area}.\")\n", + "\n", + "# Brownian motion\n", + "key = jr.key(0)\n", + "bm_tol = 2**-13\n", + "bm_shape = (2,)\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=diffrax.SpaceTimeLevyArea\n", + ")\n", + "\n", + "# Defining the terms of the SDE\n", + "ode_term = diffrax.ODETerm(f)\n", + "diffusion_term = diffrax.ControlTerm(g, bm) # Note that the BM is baked into the term\n", + "terms = diffrax.MultiTerm(ode_term, diffusion_term)" + ] + }, + { + "cell_type": "markdown", + "id": "e71db03c5257bd46", + "metadata": {}, + "source": [ + "### Using `diffrax.diffeqsolve` to solve the SDE\n", + "\n", + "We will first use constant steps of size $h = 2^{-9}$ to solve the SDE. 
It is very important to have $h > \\mathtt{bm\\_tol}$, where $\\mathtt{bm\\_tol}$ is the tolerance of the Brownian motion. This is important because the output distribution of the VirtualBrownianTree is precise as long as the times that we sample it at are at least $\\mathtt{bm\\_tol}$ apart. For more details see the [Single-seed Brownian Motion paper](https://arxiv.org/abs/2405.06464).\n", + "\n", + " We will use the SPaRK solver to solve the SDE. SPaRK is a stochastic Runge-Kutta method that requires access to space-time Levy area." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8a969e1b9bd9f09", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:35.618860Z", + "start_time": "2025-10-09T01:03:29.493121Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sol = diffrax.diffeqsolve(\n", + " terms, diffrax.SPaRK(), t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(steps=True)\n", + ")\n", + "\n", + "# Plotting the solution on ax1 and the BM on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol.ts, sol.ys[:, 0], label=\"y_1\")\n", + "ax1.plot(sol.ts, sol.ys[:, 1], label=\"y_2\")\n", + "ax1.plot(sol.ts, sol.ys[:, 2], label=\"y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "bm_vals = jax.vmap(lambda t: bm.evaluate(t0, t))(jnp.clip(sol.ts, t0, t1))\n", + "ax2.plot(sol.ts, bm_vals[:, 0], label=\"BM_1\")\n", + "ax2.plot(sol.ts, bm_vals[:, 1], label=\"BM_2\")\n", + "ax2.set_title(\"Brownian motion\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fd3251c814306cd", + "metadata": {}, + "source": [ + "## Using adaptive time-stepping via the PID-controller\n", + "\n", + "In order to use adaptive time stepping, the solver must produce an estimate of its error on each step. This is then used by the PID controller to adjust the step size.\n", + "To perform this error estimation the `SPaRK` solver uses an embedded method. For solvers like `GeneralShARK`, which do not have an embedded method, we'd instead need to use `HalfSolver(GeneralShARK())` as the solver in order to estimate the error." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "42ca5c5520079b5f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:43.605326Z", + "start_time": "2025-10-09T01:03:35.705678Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accepted steps: 2968, Rejected steps: 1637, total steps: 4605\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=0.005,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-12,\n", + " dtmax=0.25,\n", + ")\n", + "\n", + "solver = diffrax.SPaRK()\n", + "# solver = diffrax.HalfSolver(diffrax.GeneralShARK())\n", + "\n", + "sol_pid_spark = diffrax.diffeqsolve(\n", + " terms,\n", + " solver,\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_spark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_spark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 0], label=\"y_1\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 1], label=\"y_2\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 2], label=\"y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_spark.ts))\n", + "ts = sol_pid_spark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "344b5f07d5120128", + "metadata": {}, + "source": [ + "## Solving an SDE for a batch of Brownian motions\n", + "\n", + "When doing Monte Carlo simulations, we often need to solve the same SDE for multiple Brownian motions. We can do this via `jax.vmap`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ffe3ced461ebb823", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:43.764668Z", + "start_time": "2025-10-09T01:03:43.657163Z" + } + }, + "outputs": [], + "source": [ + "def get_terms(bm):\n", + " return diffrax.MultiTerm(ode_term, diffrax.ControlTerm(g, bm))\n", + "\n", + "\n", + "# Fix which times we step to (this is equivalent to a constant step size)\n", + "# We do this because the combination of using dt0 and SaveAt(steps=True) pads the\n", + "# output with inf values up to max_steps.\n", + "# Instead we specify exactly which times we want to save at, so Diffrax allocates\n", + "# the correct amount of memory at the outset.\n", + "num_steps = 2**8\n", + "step_times = jnp.linspace(t0, t1, num_steps + 1, endpoint=True)\n", + "constant_controller = diffrax.StepTo(ts=step_times)\n", + "saveat = diffrax.SaveAt(ts=step_times)\n", + "\n", + "\n", + "# We will vmap over keys\n", + "@eqx.filter_jit\n", + "@partial(jax.vmap, in_axes=(0, None, None))\n", + "def batch_sde_solve(key, saveat, args):\n", + " bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=diffrax.SpaceTimeLevyArea\n", + " )\n", + " terms = get_terms(bm)\n", + " return diffrax.diffeqsolve(\n", + " terms,\n", + " diffrax.SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + " )\n", + "\n", + "\n", + "# Split the keys and compute the batched solutions\n", + "num_samples = 100\n", + "keys = jr.split(jr.PRNGKey(0), num_samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3c1206025f30100d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:46.771758Z", + "start_time": "2025-10-09T01:03:43.769093Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of batch_sols: (100, 257, 3) == 100 x 257 x (dim of y)\n" + ] + } 
+ ], + "source": [ + "batch_sols = batch_sde_solve(keys, saveat, args)\n", + "print(\n", + " f\"Shape of batch_sols: \"\n", + " f\"{batch_sols.ys.shape} == {num_samples} x {num_steps + 1} x (dim of y)\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "71dda42d79d4c553", + "metadata": {}, + "source": [ + "## Optimizing wrt. SDE parameters\n", + "We will optimize the SDE parameters with the aim of achieving a mean of 0 and variance 4 at time `t1`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d278fc2d438ffc82", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:07:34.736540Z", + "start_time": "2025-10-09T01:03:46.832280Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(100, 1, 3)\n", + "Stats at t=t1: mean=[-1.0154 1.5479 2.0878], var=[329.4802 711.1078 424.3449]\n", + "Step 0, loss: 34.94212183883967\n", + "Step 10, loss: 4.874442625767792\n", + "Step 20, loss: 2.521634210842532\n", + "Step 30, loss: 1.4702096092338783\n", + "Step 40, loss: 0.7936488640119762\n", + "Step 50, loss: 0.20701309373712398\n", + "Step 60, loss: 0.43545896965573144\n", + "Step 70, loss: 0.48191871779789575\n", + "Step 80, loss: 0.14351136791805125\n", + "Step 90, loss: 0.42323194856385005\n", + "Step 100, loss: 0.7814543571174357\n", + "Step 110, loss: 0.5590729910392899\n", + "Step 120, loss: 0.09288914937239617\n", + "Step 130, loss: 0.1462945784163213\n", + "Step 140, loss: 0.29703455403048784\n", + "Step 150, loss: 0.06270444996116936\n", + "Step 160, loss: 0.01298645327270607\n", + "Step 170, loss: 0.08775177455266986\n", + "Step 180, loss: 0.016462953232162895\n", + "Step 190, loss: 0.018917675036979466\n", + "Optimal parameters:\n", + "alpha=[-0.1822 3.5395 -0.0834], beta=3.645413009852767, gamma=[-1.6817 -0.8223 0.149 ]\n" + ] + } + ], + "source": [ + "saveat_t1 = diffrax.SaveAt(t1=True)\n", + "batch_ys = batch_sde_solve(keys, saveat_t1, args).ys\n", + "print(batch_ys.shape)\n", + "ys_t1 = 
batch_ys[:, 0]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")\n", + "\n", + "\n", + "# We will optimize for achieving a mean of 0\n", + "def loss(args: tuple[Array, Array, Array]):\n", + " _batch_sols = batch_sde_solve(keys, saveat_t1, args)\n", + " batch_ys = _batch_sols.ys\n", + " assert batch_ys.shape == (num_samples, 1, 3)\n", + " mean = jnp.mean(batch_ys, axis=(0, 1))\n", + " std = jnp.sqrt(jnp.mean(batch_ys**2, axis=(0, 1)) - mean**2)\n", + " target_mean = jnp.array([0.0, 1.0, 0.0])\n", + " target_stds = 2 * jnp.ones((3,))\n", + " loss = jnp.sqrt(\n", + " jnp.sum((mean - target_mean) ** 2) + jnp.sum((std - target_stds) ** 2)\n", + " )\n", + " return loss\n", + "\n", + "\n", + "# Define the parameters to optimize\n", + "alpha_opt = 0.5 * jnp.ones((3,))\n", + "beta_opt = jnp.array(1.0)\n", + "gamma_opt = jnp.ones((3,))\n", + "args_opt = (alpha_opt, beta_opt, gamma_opt)\n", + "\n", + "# Define the optimizer\n", + "num_steps = 191\n", + "schedule = optax.cosine_decay_schedule(3e-1, num_steps, 1e-2)\n", + "opt = optax.chain(\n", + " optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-8),\n", + " optax.scale_by_schedule(schedule),\n", + " optax.scale(-1),\n", + ")\n", + "# opt = optax.adam(2e-1)\n", + "opt_state = opt.init(args_opt)\n", + "\n", + "\n", + "@jax.jit\n", + "def step(i, opt_state, args):\n", + " loss_val, grad = jax.value_and_grad(loss)(args)\n", + " updates, opt_state = opt.update(grad, opt_state)\n", + "\n", + " # One way to apply updates\n", + " # args = optax.apply_updates(args, updates)\n", + "\n", + " # Another way to apply updates\n", + " args = jax.tree_util.tree_map(lambda x, u: x + u, args, updates)\n", + "\n", + " return opt_state, args, loss_val\n", + "\n", + "\n", + "for i in range(num_steps):\n", + " opt_state, args_opt, loss_val = step(i, opt_state, args_opt)\n", + " alpha_opt, beta_opt, gamma_opt = args_opt\n", + " if i % 10 == 
0:\n", + " print(f\"Step {i}, loss: {loss_val}\")\n", + "\n", + "print(f\"Optimal parameters:\\nalpha={alpha_opt}, beta={beta_opt}, gamma={gamma_opt}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "834651877787c7e6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:07:38.550105Z", + "start_time": "2025-10-09T01:07:34.801808Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stats at t=t1: mean=[0.0001 1.0005 0.0002], var=[4.0193 4.0269 4.0155]\n" + ] + } + ], + "source": [ + "batch_ys_opt = batch_sde_solve(keys, saveat_t1, args_opt).ys\n", + "ys_t1 = batch_ys_opt[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d103fe1695cdd847", + "metadata": {}, + "source": "With the magic of JAX and Diffrax we were able to differentiate through the SDE solver and optimize the parameters of the SDE to achieve the desired mean and variance at time `t1`." 
+ } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index a493b353..25ce7aa8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -123,6 +123,7 @@ nav: - Second-order sensitivities: 'examples/hessian.ipynb' - Nonlinear heat PDE: 'examples/nonlinear_heat_pde.ipynb' - Underdamped Langevin diffusion: 'examples/underdamped_langevin_example.ipynb' + - Advanced SDE simulation example: 'examples/sde_example.ipynb' - Basic API: - 'api/diffeqsolve.md' - Solvers: @@ -150,3 +151,4 @@ nav: - 'devdocs/predictor_dirk.md' - 'devdocs/adjoint_commutative_noise.md' - Stochastic Runge-Kutta methods: 'devdocs/srk_example.ipynb' + - Table of SDE solvers: 'devdocs/SDE_solver_table.md' From c4f6c7d6abf193b7d0abb2ba308e844829162cab Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:24:43 +0100 Subject: [PATCH 23/24] remove spurious type in union --- diffrax/_custom_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 3bb17282..eef13b12 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -21,7 +21,7 @@ BoolScalarLike = bool | Array | np.ndarray FloatScalarLike = float | Array | np.ndarray IntScalarLike = int | Array | np.ndarray - RealScalarLike = bool | int | float | Array | np.ndarray + RealScalarLike = int | float | Array | np.ndarray else: BoolScalarLike = Bool[ArrayLike, ""] FloatScalarLike = Float[ArrayLike, ""] From e7d78bf22b57ce08b3e63c3b3ae492c9a4bbc8d8 Mon Sep 17 00:00:00 2001 From: Patrick Kidger 
<33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:26:02 +0100 Subject: [PATCH 24/24] 0.7.1 version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 459deb3c..880c50c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ name = "diffrax" readme = "README.md" requires-python = ">=3.11" urls = {repository = "https://github.com/patrick-kidger/diffrax"} -version = "0.7.0" +version = "0.7.1" [project.optional-dependencies] dev = ["pre-commit"]