Description
Hi Diffrax team,
I’ve noticed a subtle edge case in the current implementation of events: if two `cond_fn`s both change sign within the same integrator step, the one that appears first in the `cond_fn` PyTree is treated as if it definitely occurred first, even when another event actually happened earlier. To handle this correctly, we would need to solve for the precise root time of each zero crossing, use those times to determine which event actually happens first, and return the system state at that moment. In a case like mine, ignoring this can break the system completely. For example, take the bouncing ball example with two balls that hit the floor at nearly the same time: if the wrong event is selected, one ball passes through the floor. Once it is through, there is no coming back, because the sign of its `cond_fn` never changes again, so no further event is triggered for it.
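A minimal plain-Python sketch of the proposed selection rule, under stated assumptions: for every `cond_fn` whose sign differs between the start and end of a step, locate its crossing time by bisection on an interpolant of the step, then pick the earliest crossing. The helpers `find_crossing` and `earliest_event`, the simplified `cond_fn(t, y)` signature, and the `interp(t)` callable (standing in for the solver's dense output) are all hypothetical and not part of the Diffrax API. The usage part models the two-ball scenario with exact free-fall trajectories.

```python
import math
from typing import Any, Callable, Optional, Sequence, Tuple

# Hypothetical signatures for this sketch (not the Diffrax API):
#   cond_fn(t, y) -> float, and interp(t) -> y, which stands in for the
#   solver's dense output over the current step.
CondFn = Callable[[float, Any], float]
Interp = Callable[[float], Any]


def find_crossing(cond_fn: CondFn, interp: Interp, t0: float, t1: float,
                  tol: float = 1e-10) -> float:
    """Bisection for the time in [t0, t1] at which cond_fn changes sign."""
    lo, hi = t0, t1
    lo_positive = cond_fn(lo, interp(lo)) > 0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if (cond_fn(mid, interp(mid)) > 0) == lo_positive:
            lo = mid  # no sign change on [lo, mid]: crossing lies in [mid, hi]
        else:
            hi = mid  # sign change on [lo, mid]
    return 0.5 * (lo + hi)


def earliest_event(cond_fns: Sequence[CondFn], interp: Interp, t0: float,
                   t1: float) -> Optional[Tuple[int, float]]:
    """Return (index, time) of the earliest zero crossing in [t0, t1], or None."""
    best: Optional[Tuple[int, float]] = None
    for i, g in enumerate(cond_fns):
        # Only cond_fns whose sign differs between the step endpoints are candidates.
        if (g(t0, interp(t0)) > 0) != (g(t1, interp(t1)) > 0):
            t_root = find_crossing(g, interp, t0, t1)
            # Keep the earliest crossing, regardless of position in cond_fns.
            if best is None or t_root < best[1]:
                best = (i, t_root)
    return best


# Usage: two balls falling onto the floor (y = 0) within a single step [0, 1].
# The exact trajectories y_i(t) = h_i - g * t**2 / 2 stand in for dense output.
GRAVITY = 9.81
heights = (1.1, 1.0)  # ball 0 starts higher, so ball 1 lands first


def interp(t: float) -> Tuple[float, ...]:
    return tuple(h - 0.5 * GRAVITY * t**2 for h in heights)


cond_fns = (lambda t, y: y[0], lambda t, y: y[1])

index, t_event = earliest_event(cond_fns, interp, 0.0, 1.0)
print(index, t_event, math.sqrt(2 * heights[1] / GRAVITY))
```

Bisection is used here only because a sign change on the step is already known, which makes it simple and robust. Inside Diffrax, the event's own `root_finder` (for example the `optx.Newton` used in the MWE) would presumably take its place, applied to each triggered `cond_fn` before comparing the resulting times.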
Here is an MWE:
```python
import diffrax
import optimistix as optx

term = diffrax.ODETerm(lambda t, y, args: 1.0)
solver = diffrax.Euler()


def g1(t, y, args, **kwargs):
    return t - 1.0


def g2(t, y, args, **kwargs):
    return t - 0.5


event = diffrax.Event(
    cond_fn=(g1, g2),
    root_finder=optx.Newton(rtol=1e-4, atol=1e-4),
)
sol1 = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=10.0,
    dt0=2.0,
    y0=0.0,
    event=event,
)
print("ts: ", sol1.ts[-1])
print("event_mask: ", sol1.event_mask)
print("\n")

event_swapped = diffrax.Event(
    cond_fn=(g2, g1),
    root_finder=optx.Newton(rtol=1e-4, atol=1e-4),
)
sol2 = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=10.0,
    dt0=2.0,
    y0=0.0,
    event=event_swapped,
)
print("ts: ", sol2.ts[-1])
print("event_mask: ", sol2.event_mask)
```

prints:
```
ts: 1.0
event_mask: (Array(True, dtype=bool), Array(False, dtype=bool))

ts: 0.5
event_mask: (Array(True, dtype=bool), Array(False, dtype=bool))
```
As you can see, the event at t = 0.5 is skipped in sol1, even though it occurs before the event at t = 1.0.
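To make the order dependence explicit, here is a small plain-Python model of the single Euler step [0, 2] from the MWE (where y(t) = t), comparing the selection rule described in this issue (first triggered `cond_fn` in PyTree order) with an earliest-crossing rule. `crossing_time`, `select_event`, and the rule names are hypothetical; this models the reported behavior and is not Diffrax's actual code. Because both `cond_fn`s are linear in t here, linear interpolation of their endpoint values gives the exact crossing times.

```python
from typing import Callable, Optional, Sequence, Tuple


def crossing_time(f0: float, f1: float, t0: float, t1: float) -> float:
    """Linear-interpolation estimate of where f crosses zero on [t0, t1]."""
    return t0 - f0 * (t1 - t0) / (f1 - f0)


def select_event(cond_fns: Sequence[Callable[[float, float], float]],
                 t0: float, t1: float, rule: str) -> Optional[Tuple[int, float]]:
    triggered = []
    for i, g in enumerate(cond_fns):
        # In the MWE, dy/dt = 1 and y0 = 0, so y = t on this step.
        f0, f1 = g(t0, t0), g(t1, t1)
        if (f0 > 0) != (f1 > 0):
            triggered.append((i, crossing_time(f0, f1, t0, t1)))
    if not triggered:
        return None
    if rule == "first_in_order":
        return triggered[0]  # models the behavior reported in this issue
    return min(triggered, key=lambda item: item[1])  # earliest crossing wins


def g1(t, y):
    return t - 1.0


def g2(t, y):
    return t - 0.5


for name, fns in [("(g1, g2)", (g1, g2)), ("(g2, g1)", (g2, g1))]:
    for rule in ("first_in_order", "earliest"):
        print(name, rule, select_event(fns, 0.0, 2.0, rule))
```

With the first-in-order rule the result depends on how the PyTree is ordered (t = 1.0 for `(g1, g2)` versus t = 0.5 for `(g2, g1)`, matching sol1 and sol2), whereas the earliest-crossing rule returns t = 0.5 for both orderings.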
If you think this is reasonably doable, I could try to look into it.
Thanks for the help!