From 69661fc83a2cc04ce511aa5c1191706eb8ce3425 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:31:34 +0100 Subject: [PATCH] Fix #720; bool event + root find + terminate on first step --- diffrax/_integrate.py | 11 ++++++++--- test/test_event.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6fc38ce3..3017e30d 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -761,7 +761,11 @@ def _call_real_impl(): _tfinal = _event_root_find.value # TODO: we might need to change the way we evaluate `_yfinal` in order to # get more accurate derivatives? - _yfinal = _interpolator.evaluate(_tfinal) + _yfinal = lax.cond( + final_state.num_steps == 0, + lambda: final_state.y, + lambda: _interpolator.evaluate(_tfinal), + ) _result = RESULTS.where( _event_root_find.result == optx.RESULTS.successful, result, @@ -1323,7 +1327,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: event_mask = None else: event_tprev = tprev - event_tnext = tnext + event_tnext = tprev # Fill the dense-info with dummy values on the first step, when we haven't yet # made any steps. # Note that we're threading a needle here! What if we terminate on the very @@ -1334,8 +1338,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: # to the end of the interval). # - A floating event can't terminate on the first step (it requires a sign # change). + # c.f. https://github.com/patrick-kidger/diffrax/issues/720 event_dense_info = jtu.tree_map( - lambda x: jnp.empty(x.shape, x.dtype), + lambda x: jnp.zeros(x.shape, x.dtype), dense_info_struct, # pyright: ignore[reportPossiblyUnboundVariable] ) diff --git a/test/test_event.py b/test/test_event.py index 5de0ded6..6a6b97e1 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -833,3 +833,30 @@ def run(event): t_final, y_final = run(event) assert jnp.allclose(t_final, 10.0) assert jnp.allclose(y_final, jnp.array([10.0, 11.0])) + + +# https://github.com/patrick-kidger/diffrax/issues/720 +def test_boolean_with_root_find_terminating_on_first_step(): + controller = diffrax.PIDController(rtol=1e-6, atol=1e-6) + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) + root_finder = optx.Newton(atol=1e-4, rtol=1e-4) + + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)), + diffrax.Kvaerno5(), + t0=0.0, + t1=1.2, + dt0=None, + y0=jnp.array([10.0]), + stepsize_controller=controller, + event=diffrax.Event( + cond_fn=steady_state_event, + root_finder=root_finder, + ), + saveat=diffrax.SaveAt(t1=True), + max_steps=100, + ) + assert sol.ts is not None + assert sol.ys is not None + assert jnp.allclose(sol.ts, jnp.array([0.0])) + assert jnp.allclose(sol.ys, jnp.array([[10.0]]))