A minimal automatic differentiation library in Rust, inspired by the functional, composable architecture of JAX.
chainrule explores the core mechanics behind modern automatic differentiation frameworks by reconstructing them from first principles in a systems language.
A showcase demonstrating the training of a 3-layer MLP on the Fashion-MNIST dataset.
- Reverse-mode automatic differentiation.
- Support for higher-order gradients (grad of grad).
- Dynamic computation graph construction via function tracing.
- Operator overloading for a multi-dimensional Tensor type (powered by ndarray).
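As a rough illustration of what the reverse-mode and operator-overloading features involve, here is a minimal, self-contained sketch of tape-based reverse-mode differentiation on scalar f64 values. It is not chainrule's implementation: Tape, Var, Node, record and backward are hypothetical names, and the sketch records operations eagerly on a tape rather than tracing a function into IR or operating on ndarray tensors.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

// Hypothetical illustration of a reverse-mode tape; not chainrule's API.
// Each node stores up to two (parent index, local partial derivative) pairs.
#[derive(Clone, Copy)]
struct Node {
    parents: [(usize, f64); 2],
}

struct Tape {
    nodes: RefCell<Vec<Node>>,
}

impl Tape {
    fn new() -> Self { Tape { nodes: RefCell::new(Vec::new()) } }

    // Leaf inputs have no real parents; zero weights make both slots inert.
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.push([(0, 0.0), (0, 0.0)]);
        Var { tape: self, index, value }
    }

    fn push(&self, parents: [(usize, f64); 2]) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { parents });
        nodes.len() - 1
    }
}

// A scalar value plus its position on the tape.
#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl<'t> Var<'t> {
    // Append a node with the given local derivatives and return its handle.
    fn record(self, parents: [(usize, f64); 2], value: f64) -> Var<'t> {
        Var { tape: self.tape, index: self.tape.push(parents), value }
    }

    fn sin(self) -> Var<'t> {
        // d(sin a)/da = cos a
        self.record([(self.index, self.value.cos()), (self.index, 0.0)], self.value.sin())
    }

    // Reverse sweep: seed d(out)/d(out) = 1, then walk the tape backwards,
    // pushing each node's adjoint to its parents via the chain rule.
    fn backward(self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[self.index] = 1.0;
        for i in (0..=self.index).rev() {
            let upstream = adjoints[i];
            for &(parent, local) in &nodes[i].parents {
                adjoints[parent] += local * upstream;
            }
        }
        adjoints
    }
}

impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: Var<'t>) -> Var<'t> {
        // d(a + b)/da = 1, d(a + b)/db = 1
        self.record([(self.index, 1.0), (rhs.index, 1.0)], self.value + rhs.value)
    }
}

impl<'t> Add<f64> for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: f64) -> Var<'t> {
        // A constant contributes no gradient.
        self.record([(self.index, 1.0), (self.index, 0.0)], self.value + rhs)
    }
}

impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        // d(a * b)/da = b, d(a * b)/db = a
        self.record([(self.index, rhs.value), (rhs.index, self.value)], self.value * rhs.value)
    }
}

fn main() {
    let tape = Tape::new();
    let (x, y) = (tape.var(2.0), tape.var(3.0));
    let f = x * y + x.sin() + 1.0; // f(x, y) = x*y + sin(x) + 1
    let grads = f.backward();
    println!("f = {}", f.value);
    println!("df/dx = {} (expected y + cos x)", grads[x.index]);
    println!("df/dy = {} (expected x)", grads[y.index]);
}
```

Each overloaded operator stores its local partial derivatives as it runs; backward then applies the chain rule once, walking from the output back to the inputs, which is why reverse mode yields the gradient with respect to every input in a single sweep.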
The proposed API follows a two-stage process:
- #[trace]: a proc macro that rewrites ordinary Rust math (*, +, .sin(), etc.) into graph operations.
- trace_fn: takes the resulting graph-builder function, captures its operations as IR once, and returns a TraceableFn object exposing .eval(), .grad(), etc.
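The two-stage idea (trace once, then evaluate) can be sketched without a proc macro: run the function once on symbolic placeholders that record each operation into a flat instruction list, then interpret that list with concrete data. In this sketch Tracer, Sym, Op, Program and trace_binary are hypothetical names, inputs are plain Vec<f64> rather than ndarray arrays, and the hand-written foo only stands in for whatever code #[trace] would actually generate.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

// Hypothetical IR, not chainrule's actual format: each op defines one new slot.
#[derive(Debug, Clone, Copy)]
enum Op {
    Input(usize),      // position in the argument list
    Const(f64),
    Add(usize, usize), // operands are earlier slot indices
    Mul(usize, usize),
}

// Records ops while a function runs on symbolic values.
#[derive(Default)]
struct Tracer {
    ops: RefCell<Vec<Op>>,
}

impl Tracer {
    fn emit(&self, op: Op) -> usize {
        let mut ops = self.ops.borrow_mut();
        ops.push(op);
        ops.len() - 1
    }
}

// Symbolic stand-in for a tensor: no data, just a slot in the IR.
#[derive(Clone, Copy)]
struct Sym<'t> {
    tracer: &'t Tracer,
    slot: usize,
}

impl<'t> Sym<'t> {
    fn emit(self, op: Op) -> Sym<'t> {
        Sym { tracer: self.tracer, slot: self.tracer.emit(op) }
    }
}
impl<'t> Add for Sym<'t> {
    type Output = Sym<'t>;
    fn add(self, rhs: Sym<'t>) -> Sym<'t> { self.emit(Op::Add(self.slot, rhs.slot)) }
}
impl<'t> Add<f64> for Sym<'t> {
    type Output = Sym<'t>;
    fn add(self, rhs: f64) -> Sym<'t> {
        let c = self.emit(Op::Const(rhs)); // constants become IR nodes too
        self.emit(Op::Add(self.slot, c.slot))
    }
}
impl<'t> Mul for Sym<'t> {
    type Output = Sym<'t>;
    fn mul(self, rhs: Sym<'t>) -> Sym<'t> { self.emit(Op::Mul(self.slot, rhs.slot)) }
}

// The captured program: a flat op list plus the slot holding the result.
struct Program {
    ops: Vec<Op>,
    output: usize,
}

// Run a two-argument function once on symbolic inputs and keep the recorded IR.
fn trace_binary<F>(f: F) -> Program
where
    F: for<'t> Fn(Sym<'t>, Sym<'t>) -> Sym<'t>,
{
    let tracer = Tracer::default();
    let x = Sym { tracer: &tracer, slot: tracer.emit(Op::Input(0)) };
    let y = Sym { tracer: &tracer, slot: tracer.emit(Op::Input(1)) };
    let output = f(x, y).slot;
    Program { ops: tracer.ops.into_inner(), output }
}

impl Program {
    // Interpret the IR elementwise over concrete, equal-length inputs.
    fn eval(&self, inputs: &[&[f64]]) -> Vec<f64> {
        let n = inputs[0].len();
        let mut slots: Vec<Vec<f64>> = Vec::with_capacity(self.ops.len());
        for op in &self.ops {
            let value: Vec<f64> = match *op {
                Op::Input(i) => inputs[i].to_vec(),
                Op::Const(c) => vec![c; n],
                Op::Add(a, b) => slots[a].iter().zip(&slots[b]).map(|(p, q)| p + q).collect(),
                Op::Mul(a, b) => slots[a].iter().zip(&slots[b]).map(|(p, q)| p * q).collect(),
            };
            slots.push(value);
        }
        slots.swap_remove(self.output)
    }
}

// Roughly what a #[trace]-rewritten foo might look like (hypothetical).
fn foo<'t>(x: Sym<'t>, y: Sym<'t>) -> Sym<'t> {
    x * y + 1.0
}

fn main() {
    let program = trace_binary(foo); // foo runs exactly once, here
    println!("IR: {:?}", program.ops);
    let (a, b) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);
    println!("{:?}", program.eval(&[a.as_slice(), b.as_slice()])); // [5.0, 11.0, 19.0]
}
```

Because the IR is captured once, repeated calls to eval reuse it without re-running foo. In this sketch a gradient transform would be one more pass over the same Vec<Op>; that is only an illustration of why the two stages are separated, not a description of how chainrule implements .grad().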
use chainrule::{trace, trace_fn, Tensor};
use ndarray::array;
#[trace]
fn foo(x: Tensor, y: Tensor) -> Tensor {
x * y + 1.0
}
fn main() {
// 1. Convert the graph-builder function into a runnable object.
let f = trace_fn(foo);
// 2. Define input data.
let a = array![1., 2., 3.];
let b = array![4., 5., 6.];
// 3. Evaluate the function or its gradients.
let result = f.eval()((&a, &b));
let df = f.grad().eval()((&a, &b));
let ddf = f.grad().grad().eval()((&a, &b));
println!("result: {:?}", result);
println!("gradient w.r.t. inputs: {:?}", df);
println!("second-order gradient w.r.t. inputs: {:?}", ddf);
}
Note that functions are defined using the symbolic Tensor type to enable tracing, but the resulting TraceableFn is executed with concrete ndarray::Array types.
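For the example above, the expected gradients can be worked out by hand as a sanity check. This assumes grad differentiates the sum of foo's elementwise outputs (JAX's grad, for comparison, requires a scalar output); the README does not say how chainrule reduces non-scalar outputs, so the exact shape of df and ddf is unconfirmed.

```latex
f(x, y) = \sum_i \left( x_i y_i + 1 \right)

\frac{\partial f}{\partial x_i} = y_i, \qquad \frac{\partial f}{\partial y_i} = x_i

% With a = (1, 2, 3) and b = (4, 5, 6):
\nabla_x f = (4, 5, 6), \qquad \nabla_y f = (1, 2, 3)

\frac{\partial^2 f}{\partial x_i \partial x_j} = \frac{\partial^2 f}{\partial y_i \partial y_j} = 0, \qquad \frac{\partial^2 f}{\partial x_i \partial y_j} = \delta_{ij}
```

Every second derivative is therefore a constant, independent of a and b.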
Experimental:
- This project is currently in the design and early implementation phase, with incomplete documentation.
- This README describes the target architecture and API as envisioned upon completion.
- chainrule is built for learning and demonstration. The API is subject to breaking changes and performance is not a primary design goal.
Add this to your Cargo.toml:
[dependencies]
chainrule = { git = "https://github.com/rawcptr/chainrule.git" }