Skip to content

Conversation

@EiffL
Copy link

@EiffL EiffL commented Nov 22, 2022

This draft PR is in response to #9 and presents a prototype implementation of the leap-frog solver that uses a lax.scan instead of a for loop in the nbody function.

Here are the results on the baseline default configuration.

Current master

  • No extra jitting (so, normal pmwd)
First result obtained in  16.21842312812805
Second result obtained in  9.312845468521118
  • jitted pmwd (what I'm told not to do ^^)
First result obtained in  184.43911576271057
Second result obtained in  9.710259437561035

This PR (using scan, and I actually removed all lower level jit)

  • No jitting at all (I removed all the jit in pmwd.nbody)
First result obtained in  19.26646661758423
Second result obtained in  9.970827102661133
  • jitted pmwd
First result obtained in  13.62941026687622
Second result obtained in  9.098160743713379

And here the notebook to reproduce this test (working off my fork):
https://gist.github.com/EiffL/aa6a651141f694ca257fb5ff83e829d6

So I would advocate using lax.scan.

In this draft implementation, I chose not to output intermediate ptcl and obsvl, exactly like what is done on master, but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)

If you look at the implementation of odeint in jax, you can also have a slightly more complicated logic that exports the state of the system only at some desired pre-defined steps, and not necessarily at all time steps:
https://github.com/google/jax/blob/518fe6656ca2aab66dcfc8cd7866c10f476a17b1/jax/experimental/ode.py#L189

And finally, if you want to save the sims to disk, then nothing prevents you from using the nbody step function directly/manually in a for loop.

@EiffL
Copy link
Author

EiffL commented Nov 22, 2022

And actually ^^ it's generally not a good idea ^^' but if we want to, we can definitely write a custom CPU op that will dump the simulation in hdf5 from within jitted code, and from within the lax.scan.

In this particular instance, I think it would be pretty cool

@EiffL
Copy link
Author

EiffL commented Nov 22, 2022

So yeah I don't see any drawbacks of using a scan :-)

@eelregit
Copy link
Owner

Thanks! This was very much how it was done here. Also here for the adjoint.

So it's good to know that XLA or JAX has gotten better on this.

but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)

I guess you meant nested scan's. We want interpolation between two steps. It looks like odeint is extrapolating from the last step? But interpolation should also be okay with nested scan's.

@EiffL
Copy link
Author

EiffL commented Nov 23, 2022

In odeint they use a while inside the scan function yes.

Would you be ok with an API with an argument which would be the array a to use in the solver, and maybe another optional array save_at which would contain the indices of the snapshots to export. By default it would be [-1]. If so I'm happy to implement it :-)

And then, I think it would be very cool to have the ability to do IO directly from jitted code :-) And I think I know how to do it, but probably that's for a different PR.

@eelregit eelregit force-pushed the master branch 2 times, most recently from 418337c to a5329ae Compare November 26, 2022 06:26
@eelregit
Copy link
Owner

eelregit commented Nov 29, 2022

Let's try switching to scan following the odeint way, once the checkpoint (exactly at a time step, directly copying disp and vel) and snapshot (interpolation between 2 steps) observables are implemented. @Yucheng-Zhang is working on those observables.

Yes, it'd be super cool to have a custom IO op ^^

@eelregit
Copy link
Owner

id_tap seems to be useful in writing snapshots

@eelregit eelregit force-pushed the master branch 2 times, most recently from be9f8a4 to 52ec0b4 Compare March 8, 2023 15:33
@eelregit eelregit force-pushed the master branch 4 times, most recently from 3e1b213 to c2f0c24 Compare March 31, 2023 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants