Hi @martinResearch and thanks for creating this package, which I stumbled upon while reading jax-ml/jax#1032. This issue is about understanding it a little better.
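From what I can tell, known sparsity patterns are most useful in conjunction with coloring algorithms, because they reduce the number of forward (or reverse) passes needed to compute a Jacobian. Typically, the Jacobian of a function $f: \mathbb{R}^n \to \mathbb{R}^m$ requires $n$ JVPs (or $m$ VJPs), one per basis vector. Coloring the sparsity pattern reduces this to one pass per color, and the number of colors is often much smaller than $n$.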
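To make the coloring idea concrete, here is a minimal sketch in plain Python rather than JAX. A hand-rolled `Dual` class stands in for a real forward-mode AD system, and the tridiagonal function `f` and the helpers `greedy_column_coloring` and `sparse_jacobian` are hypothetical, written only for illustration. Columns that never share a nonzero row get the same color, so a single JVP seeded with the sum of their basis vectors recovers all of them at once.

```python
class Dual:
    """Minimal forward-mode dual number: a value plus one tangent component."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.val * other.tan + self.tan * other.val)

    __rmul__ = __mul__


def f(x):
    # Hypothetical example: f_i depends only on x[i-1], x[i], x[i+1],
    # so its Jacobian is tridiagonal.
    n = len(x)
    out = []
    for i in range(n):
        y = x[i] * x[i]
        if i > 0:
            y = y + 2 * x[i - 1]
        if i < n - 1:
            y = y + 3 * x[i + 1]
        out.append(y)
    return out


def jvp(fun, x, v):
    # One forward pass: input j carries tangent v[j]; returns J @ v.
    outs = fun([Dual(xj, vj) for xj, vj in zip(x, v)])
    return [o.tan for o in outs]


def greedy_column_coloring(pattern, n):
    # pattern: set of (row, col) nonzeros. Two columns conflict if they share a row.
    rows_of = [set() for _ in range(n)]
    for r, c in pattern:
        rows_of[c].add(r)
    colors = []
    for j in range(n):
        used = {colors[k] for k in range(j) if rows_of[k] & rows_of[j]}
        color = 0
        while color in used:
            color += 1
        colors.append(color)
    return colors


def sparse_jacobian(fun, x, pattern):
    n = len(x)
    colors = greedy_column_coloring(pattern, n)
    num_colors = max(colors) + 1
    J = {}
    for color in range(num_colors):
        # Seed = sum of the basis vectors of all columns with this color.
        seed = [1.0 if colors[j] == color else 0.0 for j in range(n)]
        compressed = jvp(fun, x, seed)
        # Same-colored columns share no row, so entry (r, j) is read off directly.
        for r, j in pattern:
            if colors[j] == color:
                J[(r, j)] = compressed[r]
    return J, num_colors
```

For a tridiagonal pattern the greedy coloring uses three colors, so three JVPs replace the $n$ that a column-by-column approach would need, independently of $n$.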
As you stated in jax-ml/jax#1032 (comment), your library does not rely on this paradigm. When you talk of a "single forward pass using matrix-matrix products at each step of the forward computation", is it correct that you still end up computing the JVP with every single basis vector? In other words, while the runtime may be low in practice thanks to vectorization and efficient sparse operations, does the theoretical complexity remain proportional to the input dimension $n$ rather than to the number of colors?
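To pin the question down, here is a sketch of the general technique as I understand it. This is an assumption for illustration, not a description of this library's internals. Each intermediate carries a sparse row of partial derivatives, and seeding input $j$ with `{j: 1.0}` amounts to propagating the identity matrix, i.e. all $n$ basis-vector JVPs in one pass. `SparseDual`, `f` and `jacobian_one_pass` are hypothetical names, and the `flops` counter tallies only the sparse tangent updates.

```python
class SparseDual:
    """Value plus a sparse gradient row {input index: partial derivative}."""

    flops = 0  # counts scalar tangent updates across all instances

    def __init__(self, val, grad=None):
        self.val = val
        self.grad = grad if grad is not None else {}

    @staticmethod
    def _lift(other):
        return other if isinstance(other, SparseDual) else SparseDual(other)

    @staticmethod
    def _axpy(a, g, b, h):
        # Sparse linear combination a*g + b*h; cost ~ nnz(g) + nnz(h), not n.
        out = {}
        for k, v in g.items():
            out[k] = out.get(k, 0.0) + a * v
        for k, v in h.items():
            out[k] = out.get(k, 0.0) + b * v
        SparseDual.flops += len(g) + len(h)
        return out

    def __add__(self, other):
        other = self._lift(other)
        return SparseDual(self.val + other.val,
                          self._axpy(1.0, self.grad, 1.0, other.grad))

    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        # Product rule: d(uv) = v*du + u*dv
        return SparseDual(self.val * other.val,
                          self._axpy(other.val, self.grad, self.val, other.grad))

    __rmul__ = __mul__


def f(x):
    # Hypothetical example with a tridiagonal Jacobian.
    n = len(x)
    out = []
    for i in range(n):
        y = x[i] * x[i]
        if i > 0:
            y = y + 2 * x[i - 1]
        if i < n - 1:
            y = y + 3 * x[i + 1]
        out.append(y)
    return out


def jacobian_one_pass(fun, x):
    # Seed with the identity: input j gets the gradient row {j: 1.0}.
    SparseDual.flops = 0
    seeded = [SparseDual(xj, {j: 1.0}) for j, xj in enumerate(x)]
    outs = fun(seeded)
    return [o.grad for o in outs], SparseDual.flops
```

For this `f` the tangent work comes to $9n - 8$ updates, linear in $n$, because no intermediate row ever holds more than three nonzeros; running $n$ separate dense JVPs would instead cost $n$ passes over $O(n)$ operations each.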
Thanks in advance for your response!
ping @adrhill