Does this approach save forward passes? #18

@gdalle


Hi @martinResearch and thanks for creating this package, which I stumbled upon while reading jax-ml/jax#1032. This issue is about understanding it a little better.

From what I can tell, known sparsity patterns are most useful in conjunction with coloring algorithms, because they reduce the number of forward (or reverse) passes needed to compute a Jacobian. Typically, the Jacobian of a function $f : \mathbb{R}^n \to \mathbb{R}^m$ requires $n$ forward passes, one for each of the Jacobian-vector products associated with the basis vectors of $\mathbb{R}^n$. Grouping structurally orthogonal basis vectors together is what the coloring step is all about, and it can reduce the number of forward passes from $O(n)$ to $O(1)$ in the best cases. See https://epubs.siam.org/doi/10.1137/S0036144504444711 for more details, or Example 5.7 from https://tomopt.com/docs/TOMLAB_MAD.pdf#page26.
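To make the coloring idea concrete, here is a toy sketch (hypothetical code, not taken from this package): for a function with a tridiagonal Jacobian, columns $j$ and $j+3$ never share a nonzero row, so three finite-difference passes recover the whole matrix regardless of $n$. The function `f`, the `j % 3` coloring, and the step size `h` are all illustrative choices.

```python
def f(x):
    # f_i depends only on x_{i-1}, x_i, x_{i+1}, so the Jacobian is tridiagonal
    n = len(x)
    y = []
    for i in range(n):
        lo = x[i - 1] if i > 0 else 0.0
        hi = x[i + 1] if i < n - 1 else 0.0
        y.append(lo + 2.0 * x[i] * x[i] + hi)
    return y

def jacobian_colored(f, x, h=1e-7):
    """Recover a tridiagonal Jacobian with 3 forward-difference passes, not n."""
    n = len(x)
    fx = f(x)
    J = [[0.0] * n for _ in range(n)]
    for color in range(3):  # 3 passes, independent of n
        # seed = sum of all basis vectors e_j with j in this color group
        seed = [1.0 if j % 3 == color else 0.0 for j in range(n)]
        xp = [x[j] + h * seed[j] for j in range(n)]
        # one pass gives the sum of all columns in the group: J @ seed
        diff = [(a - b) / h for a, b in zip(f(xp), fx)]
        # decompress: row i has nonzeros only in columns i-1, i, i+1,
        # and at most one of those belongs to the current color group
        for i in range(n):
            for j in (i - 1, i, i + 1):
                if 0 <= j < n and j % 3 == color:
                    J[i][j] = diff[i]
    return J
```

Here the analytic Jacobian has `J[i][i] = 4*x[i]` and ones on the off-diagonals, so the decompressed finite-difference result can be checked directly against it.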

As you stated in jax-ml/jax#1032 (comment), your library does not rely on this paradigm. When you talk of a "single forward pass using matrix-matrix products at each step of the forward computation", is it correct that you still end up computing the JVP with every single basis vector? In other words, while the runtime may be low in practice thanks to vectorization and efficient sparse operations, the theoretical complexity remains $O(n)$?
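For contrast, here is my mental model of the "single forward pass with matrix tangents" (again a hypothetical toy, not your implementation): each intermediate value carries a dense tangent vector with one slot per input, so the computation graph is swept only once, but every operation still updates all $n$ directional derivatives at once.

```python
class Dual:
    """Scalar value plus a tangent vector with one entry per input variable."""
    def __init__(self, val, tan):
        self.val, self.tan = val, tan

    def __add__(self, other):
        # sum rule, applied to the whole tangent vector at once
        return Dual(self.val + other.val,
                    [a + b for a, b in zip(self.tan, other.tan)])

    def __mul__(self, other):
        # product rule, again over all n tangent slots per operation
        return Dual(self.val * other.val,
                    [self.val * b + other.val * a
                     for a, b in zip(self.tan, other.tan)])

def jacobian_one_sweep(f, x):
    """One sweep through f, seeding the full identity matrix as tangents."""
    n = len(x)
    duals = [Dual(x[i], [1.0 if j == i else 0.0 for j in range(n)])
             for i in range(n)]
    return [out.tan for out in f(duals)]  # rows of the Jacobian
```

Even though `f` is evaluated once, each arithmetic operation does $O(n)$ tangent work, which is the $O(n)$ factor I am asking about; sparse tangent storage would lower the constants but not, as far as I can see, the asymptotics.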

Thanks in advance for your response!

ping @adrhill
