Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,11 @@ def _call_real_impl():
_tfinal = _event_root_find.value
# TODO: we might need to change the way we evaluate `_yfinal` in order to
# get more accurate derivatives?
_yfinal = _interpolator.evaluate(_tfinal)
_yfinal = lax.cond(
final_state.num_steps == 0,
lambda: final_state.y,
lambda: _interpolator.evaluate(_tfinal),
)
_result = RESULTS.where(
_event_root_find.result == optx.RESULTS.successful,
result,
Expand Down Expand Up @@ -1323,7 +1327,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
event_mask = None
else:
event_tprev = tprev
event_tnext = tnext
event_tnext = tprev
# Fill the dense-info with dummy values on the first step, when we haven't yet
# made any steps.
# Note that we're threading a needle here! What if we terminate on the very
Expand All @@ -1334,8 +1338,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
# to the end of the interval).
# - A floating event can't terminate on the first step (it requires a sign
# change).
# c.f. https://github.com/patrick-kidger/diffrax/issues/720
event_dense_info = jtu.tree_map(
lambda x: jnp.empty(x.shape, x.dtype),
lambda x: jnp.zeros(x.shape, x.dtype),
dense_info_struct, # pyright: ignore[reportPossiblyUnboundVariable]
)

Expand Down
27 changes: 27 additions & 0 deletions test/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,30 @@ def run(event):
t_final, y_final = run(event)
assert jnp.allclose(t_final, 10.0)
assert jnp.allclose(y_final, jnp.array([10.0, 11.0]))


# https://github.com/patrick-kidger/diffrax/issues/720
def test_boolean_with_root_find_terminating_on_first_step():
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
root_finder = optx.Newton(atol=1e-4, rtol=1e-4)

sol = diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
diffrax.Kvaerno5(),
t0=0.0,
t1=1.2,
dt0=None,
y0=jnp.array([10.0]),
stepsize_controller=controller,
event=diffrax.Event(
cond_fn=steady_state_event,
root_finder=root_finder,
),
saveat=diffrax.SaveAt(t1=True),
max_steps=100,
)
assert sol.ts is not None
assert sol.ys is not None
assert jnp.allclose(sol.ts, jnp.array([0.0]))
assert jnp.allclose(sol.ys, jnp.array([[10.0]]))