About a year ago, I read a blog post on automatic differentiation, a cute technique which automatically computes derivatives (along with generalizations like gradients and Jacobians). But that’s not actually that interesting. After all, we could just use finite differences to calculate derivatives:
\[\frac{df}{dx} \approx \frac{f(x + h) - f(x)}{h},\]
choosing a small \(h\). Unfortunately, this kind of numerical differentiation usually doesn’t work too well in practice. If you make \(h\) too small, then your accuracy gets killed by floating point roundoff, and if \(h\) is too big, then approximation errors start ballooning.^{[1]}
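A quick way to see both failure modes is to compare the one-sided difference \((f(x + h) - f(x)) / h\) against a function whose derivative is known exactly. This is a minimal sketch; the test function \(e^x\) at \(x = 1\) and the particular step sizes are arbitrary choices for illustration.

```python
import math

def forward_difference(f, x, h):
    """Approximate f'(x) with the one-sided difference (f(x + h) - f(x)) / h."""
    return (f(x + h) - f(x)) / h

x = 1.0
exact = math.exp(x)  # d/dx e^x = e^x

for h in (1e-1, 1e-4, 1e-8, 1e-12, 1e-17):
    approx = forward_difference(math.exp, x, h)
    print(f"h = {h:.0e}  error = {abs(approx - exact):.2e}")
```

Large \(h\) is dominated by truncation error, while at \(h = 10^{-17}\) the sum \(1 + h\) rounds back to exactly \(1\), so the estimate collapses to zero.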
Automatic differentiation avoids these problems entirely: it calculates exact derivatives, so your accuracy is only limited by floating point error.
The applications of automatic differentiation should be pretty obvious. But just in case they aren’t, I’ll point out that Google’s new machine learning framework TensorFlow, along with its competitor (and inspiration?) Theano, leverages automatic differentiation heavily.
What is Automatic Differentiation?
Automatic differentiation is really just a jumped-up chain rule. When you implement a function on a computer, you only have a small number of primitive operations available (e.g. addition, multiplication, logarithm). Any complicated function, like \(\frac{\log 2x}{x ^ x}\), is just a combination of these simple functions.
In other words, any complicated function \(f\) can be rewritten as the composition of a sequence of primitive functions \(f_k\):
\[f = f_n \circ f_{n-1} \circ \cdots \circ f_1\]
Because each primitive function \(f_k\) has a simple derivative, we can use the chain rule to find \(\frac{df}{dx}\) pretty easily.^{[2]}
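To make this concrete, here is the example function \(\frac{\log 2x}{x^x}\) broken into primitive steps, carrying each intermediate value together with its derivative so the chain rule is applied one primitive at a time. Rewriting \(x^x\) as \(e^{x \log x}\) and the step names `w1` through `w6` are my own choices for this sketch.

```python
import math

def f_and_derivative(x):
    """Evaluate log(2x) / x**x and its derivative, one primitive at a time."""
    # Each step pairs a value with its derivative d(value)/dx (chain rule).
    w1, dw1 = 2.0 * x, 2.0                      # w1 = 2x
    w2, dw2 = math.log(w1), dw1 / w1            # w2 = log(w1)
    w3, dw3 = math.log(x), 1.0 / x              # w3 = log(x)
    w4, dw4 = x * w3, w3 + x * dw3              # w4 = x log x (product rule)
    w5, dw5 = math.exp(w4), math.exp(w4) * dw4  # w5 = exp(w4) = x**x
    w6 = w2 / w5                                # w6 = w2 / w5
    dw6 = (dw2 * w5 - w2 * dw5) / (w5 * w5)     # quotient rule
    return w6, dw6

value, deriv = f_and_derivative(1.5)
```

No step needs anything beyond the derivative of its own primitive and the values and derivatives of its inputs.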
Although I’ve used a single-variable function \(f: \mathbb{R} \rightarrow \mathbb{R}\) as my example here, it’s straightforward to extend this idea to multivariate functions \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\).
Forward Mode
There are actually two different modes of automatic differentiation, based on how you apply the chain rule. We’ll start with forward mode automatic differentiation, which I find a little more intuitive.
Basics: Partial Derivatives
Given a function \(f\), we can construct a computational graph (a directed acyclic one) representing our function. For example, given the function \(f(x, y) = \cos x \sin y + \frac{x}{y}\), we can construct the graph:
Each node of the graph represents a primitive function, while the edges represent the flow of information. In this example, the topmost node \(w_7\) represents the value of \(f(x, y)\) while the bottommost nodes \(w_1\) and \(w_2\) represent our input variables.
Forward differentiation works by recursively defining derivatives of nodes in terms of their children. For example, suppose we want to calculate the partial \(\pder{f}{x}\). For reasons that’ll become clear in two paragraphs, let’s denote \(\pder{f}{x}\) using a derivative operator (i.e. \(\pder{f}{x} = D f\)). Then \(D f = D w_7\), and:
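Assuming the graph labels the intermediate nodes as \(w_3 = \cos w_1\), \(w_4 = \sin w_2\), \(w_5 = w_3 w_4\), \(w_6 = w_1 / w_2\) and \(w_7 = w_5 + w_6\) (the numbering of \(w_3\) through \(w_6\) is an assumption; only \(w_1\), \(w_2\) and \(w_7\) are fixed by the text), the recursion unrolls as:

\[
\begin{aligned}
D w_7 &= D w_5 + D w_6 \\
D w_6 &= \frac{w_2 \, D w_1 - w_1 \, D w_2}{w_2^2} \\
D w_5 &= w_4 \, D w_3 + w_3 \, D w_4 \\
D w_4 &= \cos(w_2) \, D w_2 \\
D w_3 &= -\sin(w_1) \, D w_1 \\
D w_2 &= D y, \qquad D w_1 = D x
\end{aligned}
\]

Plugging in \(D x = 1\) and \(D y = 0\) gives \(\pder{f}{x} = -\sin x \sin y + \frac{1}{y}\).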
The final value of \(D f\) depends only on \(x\), \(y\), \(D w_1\) and \(D w_2\). In this case, we’ve let \(D = \frac{\partial}{\partial x}\), so \(D x = 1\) and \(D y = 0\). But if we let \(D = \frac{\partial}{\partial y}\), all we have to change is \(D x\) (from 1 to 0) and \(D y\) (from 0 to 1). Everything else stays the same.
This neatly extends to calculating arbitrary directional derivatives by defining \(\langle D x, D y \rangle\) to be the unit vector in the direction of our derivative. We can also set \(D = \nabla\) and use vector addition/subtraction instead of scalar addition to calculate gradients.
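Here is a minimal hand-unrolled sketch of that recursion for \(f(x, y) = \cos x \sin y + \frac{x}{y}\). The function name `forward_mode` and the node numbering (\(w_3 = \cos w_1\), \(w_4 = \sin w_2\), \(w_5 = w_3 w_4\), \(w_6 = w_1 / w_2\), \(w_7 = w_5 + w_6\)) are assumptions for illustration; the seed `(dx, dy)` plays the role of \(\langle D x, D y \rangle\).

```python
import math

def forward_mode(x, y, dx, dy):
    """Return (f(x, y), D f) for f = cos(x) sin(y) + x / y, given seeds D x and D y."""
    w1, dw1 = x, dx
    w2, dw2 = y, dy
    w3, dw3 = math.cos(w1), -math.sin(w1) * dw1
    w4, dw4 = math.sin(w2), math.cos(w2) * dw2
    w5, dw5 = w3 * w4, dw3 * w4 + w3 * dw4
    w6, dw6 = w1 / w2, (dw1 * w2 - w1 * dw2) / (w2 * w2)
    w7, dw7 = w5 + w6, dw5 + dw6
    return w7, dw7

x, y = 0.5, 2.0
_, df_dx = forward_mode(x, y, 1.0, 0.0)  # D = d/dx
_, df_dy = forward_mode(x, y, 0.0, 1.0)  # D = d/dy
u = (0.6, 0.8)                           # a unit vector
_, df_du = forward_mode(x, y, *u)        # directional derivative along u
```

Only the seed changes between calls; every other line stays the same, and each call is one pass over the graph.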
Why is this called forward mode? Well, our derivative values start at the bottom of our graph and flow to the top, just like values do when we evaluate the expression. Because information flows in the same direction as evaluation (bottom-up), we call this "forward mode". Predictably, the information flows top-down in "reverse mode".
Runtime Complexity
Calculating derivatives of certain primitive functions requires both the values and the derivatives of their component parts. For example:
\[D(w_i w_j) = w_j \, D w_i + w_i \, D w_j\]
requires both the values of \(w_i\) and \(w_j\) along with the derivatives \(D w_i\) and \(D w_j\). If \(CD(w)\) is the cost of computing the derivative of node \(w\) and \(CV(w)\) is the cost of computing the value of node \(w\), then for \(w = w_i w_j\):
\[CD(w) = CD(w_i) + CD(w_j) + CV(w_i) + CV(w_j) + 3,\]
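where the 3 is for the 3 extra arithmetic operations (two multiplications and one addition). More generally, summing over the children \(w_i\) of \(w\):
\[CD(w) = \sum_{i} \left( CD(w_i) + CV(w_i) \right) + c\]
for some constant \(c\).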
Consider an expression like \(x x x x x x x x x x x x x x x x x\). Under a naive differentiation scheme, we might have to recompute the value for node \(xxx\) 8 or 9 times, leading to a quadratic blowup. Under a smarter scheme, we calculate the value and the derivative of each node at the same time, so that:
\[CD(w) + CV(w) = \sum_{i} \left( CD(w_i) + CV(w_i) \right) + c\]
If our function is \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), composed of \(P(f)\) primitive operations, then \(CD(f) + CV(f)\) is clearly linear in \(P(f)\). Calculating directional derivatives, then, is linear in the number of primitive operations.
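The quadratic blowup is easy to reproduce by counting operations. This sketch assumes the expression is a left-nested product \(((x \cdot x) \cdot x) \cdots\); the tuple encoding and the helper names `chain`, `naive_derivative` and `value_and_derivative` are hypothetical, chosen only to keep the example short.

```python
def chain(n):
    """Left-nested product of n copies of x: (((x * x) * x) * ...)."""
    expr = "x"
    for _ in range(n - 1):
        expr = ("mul", expr, "x")
    return expr

def value(expr, x, ops):
    if expr == "x":
        return x
    _, a, b = expr
    ops[0] += 1  # one multiplication
    return value(a, x, ops) * value(b, x, ops)

def naive_derivative(expr, x, ops):
    """Product rule that recomputes child values from scratch every time."""
    if expr == "x":
        return 1.0
    _, a, b = expr
    ops[0] += 3  # two multiplications and one addition
    return (naive_derivative(a, x, ops) * value(b, x, ops)
            + value(a, x, ops) * naive_derivative(b, x, ops))

def value_and_derivative(expr, x, ops):
    """Compute the value and derivative of each node together, in one pass."""
    if expr == "x":
        return x, 1.0
    _, a, b = expr
    va, da = value_and_derivative(a, x, ops)
    vb, db = value_and_derivative(b, x, ops)
    ops[0] += 4  # one op for the value, three for the derivative
    return va * vb, da * vb + va * db

for n in (4, 16, 64):
    naive_ops, smart_ops = [0], [0]
    naive_derivative(chain(n), 1.1, naive_ops)
    value_and_derivative(chain(n), 1.1, smart_ops)
    print(n, naive_ops[0], smart_ops[0])
```

With \(k = n - 1\) multiplication nodes, the naive count works out to \(k(k+1)/2 + 2k\) while the combined pass uses \(4k\).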
For gradients, all our scalar operations become vector operations, so \(c\) becomes \(cn\) (vector operations are linear in the vector’s dimensionality). The cost of computing gradients is thus linear in \(n P(f)\).
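Setting \(D = \nabla\) means every tangent becomes a length-\(n\) vector, so each primitive now does \(O(n)\) work. A minimal sketch of that idea, using a hypothetical `VecDual` class that stores tangents as plain lists and supports only addition and multiplication:

```python
class VecDual:
    """A value paired with its gradient (a list of n partial derivatives)."""

    def __init__(self, value, grad):
        self.value = value
        self.grad = grad

    def __add__(self, other):
        # n additions for the tangent: the scalar-mode c becomes c * n.
        return VecDual(self.value + other.value,
                       [a + b for a, b in zip(self.grad, other.grad)])

    def __mul__(self, other):
        # Product rule applied componentwise: about 3n scalar operations.
        return VecDual(self.value * other.value,
                       [da * other.value + self.value * db
                        for da, db in zip(self.grad, other.grad)])

def variables(values):
    """Seed input i with the i-th standard basis vector."""
    n = len(values)
    return [VecDual(v, [1.0 if j == i else 0.0 for j in range(n)])
            for i, v in enumerate(values)]

x1, x2, x3, x4 = variables([1.0, 2.0, 3.0, 4.0])
f = x1 * x2 * x3 * x4 + x1
print(f.value, f.grad)
```

One pass produces the whole gradient, but every node along the way pays for an \(n\)-element tangent.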
Reverse Mode
Reverse mode automatic differentiation lets you calculate gradients much more efficiently than forward mode automatic differentiation, especially when the function has many input variables.
Consider our old function \(f(x, y) = \cos x \sin y + \frac{x}{y}\) and its computation graph:
Suppose we want to calculate the gradient \(\nabla f = \langle \pder{f}{x}, \pder{f}{y} \rangle = \langle \pder{w_7}{w_1}, \pder{w_7}{w_2} \rangle\). Then:
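Assuming the intermediate nodes are numbered \(w_3 = \cos w_1\), \(w_4 = \sin w_2\), \(w_5 = w_3 w_4\), \(w_6 = w_1 / w_2\) and \(w_7 = w_5 + w_6\) (an assumed numbering; only \(w_1\), \(w_2\) and \(w_7\) are fixed by the text), the chain rule runs from the top of the graph down:

\[
\begin{aligned}
\pder{w_7}{w_7} &= 1 \\
\pder{w_7}{w_6} &= \pder{w_7}{w_7} \cdot 1 = 1 \\
\pder{w_7}{w_5} &= \pder{w_7}{w_7} \cdot 1 = 1 \\
\pder{w_7}{w_4} &= \pder{w_7}{w_5} \, w_3 \\
\pder{w_7}{w_3} &= \pder{w_7}{w_5} \, w_4 \\
\pder{w_7}{w_2} &= \pder{w_7}{w_4} \cos w_2 - \pder{w_7}{w_6} \, \frac{w_1}{w_2^2} \\
\pder{w_7}{w_1} &= -\pder{w_7}{w_3} \sin w_1 + \pder{w_7}{w_6} \, \frac{1}{w_2}
\end{aligned}
\]

Substituting back gives \(\pder{f}{x} = -\sin x \sin y + \frac{1}{y}\) and \(\pder{f}{y} = \cos x \cos y - \frac{x}{y^2}\). Since \(w_1\) and \(w_2\) each have two parents, their adjoints are sums of two terms.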
Note

For some reason that I couldn’t find, the intermediate values \(\pder{w_7}{w_k}\) are called adjoints. 
As the name "reverse mode" suggests, things are reversed. Information flows top-down: instead of using the chain rule to find derivatives of parents in terms of derivatives of their children, we find derivatives of children in terms of their parents. Because information flows top-down, the reverse of normal evaluation, we call this "reverse mode" automatic differentiation.
Runtime Complexity
If we reuse our notation of \(CD\) for the cost of calculating a gradient and \(CV\) for the cost of calculating a value, we see:
Notice that all of our operations in reverse mode differentiation are scalar operations, even though we’re calculating the (vector) gradient. Thus, the last term is a \(+ c\) and not a \(+ c n\). Computing gradients via reverse mode is thus linear in \(P(f)\), and not in \(nP(f)\), which can be a big speedup if \(f\) takes lots of input variables.
The only way I can see to calculate directional derivatives via reverse mode differentiation is to take a dot product of the gradient. In that case, the runtime cost of finding a directional derivative is linear in \(P(f) + n\): \(P(f)\) for the gradient and \(n\) for the dot product.
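For comparison, here is a minimal sketch of reverse mode that records each node's children with their local partial derivatives, then sweeps the graph once in reverse topological order. The `Var` class, the `backward` function and the topological-sort approach are my own choices for this sketch, not necessarily how any particular library does it. One backward pass yields every partial derivative using only scalar work, and the directional derivative is a dot product on top.

```python
import math

class Var:
    """A graph node: its value plus (child, local partial derivative) pairs."""

    def __init__(self, value, children=()):
        self.value = value
        self.children = children
        self.adjoint = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def __truediv__(self, other):
        return Var(self.value / other.value,
                   ((self, 1.0 / other.value),
                    (other, -self.value / other.value ** 2)))

def cos(v):
    return Var(math.cos(v.value), ((v, -math.sin(v.value)),))

def sin(v):
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))

def backward(output):
    """Propagate adjoints d(output)/d(node) from the output down to the inputs."""
    order, seen = [], set()

    def visit(node):  # depth-first post-order: children before parents
        if id(node) not in seen:
            seen.add(id(node))
            for child, _ in node.children:
                visit(child)
            order.append(node)

    visit(output)
    output.adjoint = 1.0
    for node in reversed(order):  # every parent is handled before its children
        for child, local in node.children:
            child.adjoint += node.adjoint * local  # scalar work only

x, y = Var(0.5), Var(2.0)
f = cos(x) * sin(y) + x / y
backward(f)
gradient = (x.adjoint, y.adjoint)

# Directional derivative along a unit vector u: one dot product, O(n) extra work.
u = (0.6, 0.8)
directional = sum(g * ui for g, ui in zip(gradient, u))
```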
Thoughts on Implementation
I went ahead and implemented some basic automatic differentiation. It only supports arithmetic operators (addition, subtraction, multiplication, division, and exponentiation), although it should be pretty trivial to add support for unary functions like \(\sin\) or \(\log\).
I went with the easy approach of creating an `Expr` class and forcing the user to manually build their computation graph (via operator overloading). While that’s certainly feasible for something where people are already building these computation graphs (e.g. TensorFlow or Theano), this isn’t ideal. After all, why should someone have to completely replace their compute engine with yours just to do automatic differentiation?
A better solution would be something that parses the source code of a function and uses the abstract syntax tree to build the computation graph automatically. Another solution would be to use dual numbers. But both of these are significantly more challenging to implement, so I didn’t bother.
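For reference, a dual number carries a value together with a derivative (written \(a + b\varepsilon\) with \(\varepsilon^2 = 0\)), so ordinary arithmetic on dual numbers performs forward mode without building an explicit graph. This is a minimal sketch supporting only the operations used here; the `Dual` class and the `derivative` helper are hypothetical names.

```python
import math

class Dual:
    """a + b*eps with eps**2 == 0: `real` is the value, `eps` the derivative."""

    def __init__(self, real, eps=0.0):
        self.real, self.eps = real, eps

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.real + other.real, self.eps + other.eps)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps**2 == 0
        return Dual(self.real * other.real,
                    self.real * other.eps + self.eps * other.real)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = Dual.lift(other)
        return Dual(self.real / other.real,
                    (self.eps * other.real - self.real * other.eps) / other.real ** 2)

def sin(x):
    return Dual(math.sin(x.real), math.cos(x.real) * x.eps)

def cos(x):
    return Dual(math.cos(x.real), -math.sin(x.real) * x.eps)

def derivative(f, x):
    """df/dx at x, from evaluating f once on the dual number x + 1*eps."""
    return f(Dual(x, 1.0)).eps

def g(x):
    return cos(x) * sin(x) + x / 3.0
```

The catch is that `g` must call dual-aware versions of `sin` and `cos`, so user code still has to be written against these types.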
Rust
I originally started implementing things in Rust. Algebraic data types seemed like the perfect way to represent computation graphs, which narrowed my initial choices to Rust or Haskell. I’m a little too rusty with Haskell to be really productive, and I wanted to play around with Rust some more anyways, so I chose Rust.
You can see the code for this on Github.
Design
Ideally, I’d want `Expr` to look something like:

```rust
enum Expr {
    Constant(f64),
    Variable(String),
    Add(Expr, Expr),
    Sub(Expr, Expr),
    Mul(Expr, Expr),
    // ...
}
```
Unfortunately for us, Rust is a systems language that doesn’t support this kind of recursive structure. After all, how many bytes should Rust allocate to `Expr` in memory? There’s no good answer, because the memory usage of `Expr` must be enough to allocate an `Expr` along with other stuff. That’s why we have pointers.
It took a fair amount of struggling before I realized that Rust’s reference-counted pointers (`std::rc::Rc`) were perfect for the job: shared immutable ownership of data (i.e. a node in a graph can have multiple parents). Luckily, computation graphs are acyclic, so I didn’t have to mess around with weak references or any other cycle-breaking mechanism.
To make sure that the user doesn’t need to worry about any lifetime stuff, I actually made a private enum `InnerExpr` to store the computation graph and made the public-facing `Expr` struct a thin wrapper around a `std::rc::Rc<InnerExpr>`.
I also decided to split off the `Add`, `Sub`, `Mul`, `Div`, and `Pow` variants into their own `Arithmetic` sub-enum. I thought that this would make the code a little more modular by separating the computation out from the plumbing, so to speak. In retrospect, the sub-enum was way more trouble than it was worth.
I represented points as hashmaps mapping the variable names (strings) to their values. It’s a little more verbose than I’d like, but I couldn’t think of any better alternatives.
I only ended up implementing forward mode directional differentiation in Rust before moving on to Python (Rust is great, but developing in it is definitely slower than developing in Python, and I don’t really care about performance or memory efficiency here). To prevent the quadratic blowup I mentioned earlier, I calculated both the value and the derivative of a node and used a lightweight struct to bubble them up together. Otherwise, this was a pretty straightforward recursive implementation.
Reflections
I should start by saying that for a systems language, Rust was surprisingly nice to develop in. Much nicer than C (shudder). The type system was a giant plus, and the memory management was surprisingly easy once I got the hang of it. Still, it’s definitely more verbose than Python, and there were a bunch of paper cuts that really irritated me.
My biggest irritation with Rust was its operator overloading. In Rust, operators take their arguments by value, and thus claim ownership of their arguments. For example:

```rust
let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = a + b; // c now owns a and b
println!("{}", a.eval()); // so this fails to compile: `a` has been moved
```
The only solution I could think of was to implement operator overloading for references (they’re kind of like const pointers), so things like:

```rust
let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = &a + &b;
println!("{}", a.eval());
```

work. It’s a little annoying to have to write `&` everywhere, but that’s not a huge deal. Infinitely more annoying is the fact that, because operators return an `Expr` and not an `&Expr`, proper operator chaining forces me to implement everything four times:
```rust
let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = &a + &b;            // Ref + Ref
let d = &a * &a + &b;       // Value + Ref
let e = &a + &b * &b;       // Ref + Value
let f = &a * &a + &b * &b;  // Value + Value
```
That’s really annoying. I eventually wrote a macro to implement the overloads, which cut the verbosity by a lot. The flip side is that I suck at reading macros, so it takes me a couple extra seconds to figure out what the hell my macro is doing every time I read it.
This complaint’s probably more because of my lack of knowledge than any shortcoming with Rust, but I’ll make it anyways. The interaction between `String` and `&str` and `&String` can get very annoying very quickly.
I don’t remember the exact details, but when I made `InnerExpr` use `&str` to store variable names, I ran into all kinds of irritating lifetime issues, so I had to make `InnerExpr` use `String`. But when I made my hashmaps use keys of type `&str`, I ran into all kinds of ownership problems trying to look up my `String` names. In the end, I just used `String` everywhere, which meant that I needed to call `to_owned` all over the place. Not the end of the world, but definitely another paper cut.
This is more of a complaint about Rust’s immaturity than about the language itself: testing in Rust is much more annoying than what I’m used to from Python.
Part of it is that I’m spoiled by `py.test --pdb`, which opens up an interactive debugger when your Python tests fail. Not being able to just open up an interpreter made the information Rust gives you with `assert!` seem painfully limited.
I also couldn’t find anything providing setup/teardown-style testing, let alone fixtures like those in `py.test`. It didn’t make a huge difference for this project, but it’s still kind of irritating.
There also doesn’t seem to be a built-in float comparer. I understand that technically `nan != nan`, but for testing purposes, it’d be nice if there were an equality function that handles edge cases like `nan` and `inf` and `+0.0` vs `-0.0` correctly, along with having some tolerance for floating-point roundoff errors.^{[3]}
I also looked into a port of Quickcheck for Rust. It’s nice, but the test data it generated for floats didn’t include any of the wonky edge cases like `+0.0` versus `-0.0` or `nan`. I’m not sure why, especially because those are the things which usually make everything go to hell.
Python
Eventually, I got frustrated with Rust’s verbosity, so I switched over to my go-to language, Python. My Python implementation is pretty different from my Rust implementation: I used inheritance to control method dispatch instead of pattern matching and algebraic data types, and I definitely took advantage of Python’s dynamic capabilities.
You can look at my implementation here.
The high-level design of the Python implementation was pretty similar to my Rust implementation. We have an `Expr` class that represents a given node and exposes the user-facing interface (the operator overloading and `eval`, `forward_diff`, and `reverse_diff`). The actual implementation is handled by the subclasses (e.g. `Add`). I used dictionaries of names (strings) to values (floats) to store points.
The subclasses all implemented a private `_eval` method which recursively populated a dictionary mapping node ids (ints) to their evaluated values (floats). `eval` just used `_eval` to fill out this cache and then looked up the current node’s value. Normally, I would’ve just used the built-in `functools.lru_cache` for caching, but the inputs (points/dictionaries) aren’t hashable, so `lru_cache` wouldn’t have worked.
`forward_diff` first calls `_eval` to populate the cache, which it passes to `_forward_diff` to do the actual computation. `_forward_diff` is a pretty standard recursive implementation.
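Here is a minimal sketch of that structure, assuming only `Var`, `Const`, `Add` and `Mul` nodes. The class names, the `direction` argument (a dictionary giving \(D\) for each variable name, which is my assumption about how a direction is specified) and the exact method signatures are illustrative rather than the author’s actual code.

```python
class Expr:
    """User-facing base class: operator overloading plus the public API."""

    def __add__(self, other):
        return Add(self, other)

    def __mul__(self, other):
        return Mul(self, other)

    def eval(self, point):
        cache = {}
        self._eval(point, cache)  # fills cache: node id -> value
        return cache[id(self)]

    def forward_diff(self, direction, point):
        cache = {}
        self._eval(point, cache)
        return self._forward_diff(direction, cache)

class Var(Expr):
    def __init__(self, name):
        self.name = name

    def _eval(self, point, cache):
        cache[id(self)] = point[self.name]

    def _forward_diff(self, direction, cache):
        return direction.get(self.name, 0.0)

class Const(Expr):
    def __init__(self, value):
        self.value = value

    def _eval(self, point, cache):
        cache[id(self)] = self.value

    def _forward_diff(self, direction, cache):
        return 0.0

class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _eval(self, point, cache):
        self.left._eval(point, cache)
        self.right._eval(point, cache)
        cache[id(self)] = cache[id(self.left)] + cache[id(self.right)]

    def _forward_diff(self, direction, cache):
        return (self.left._forward_diff(direction, cache)
                + self.right._forward_diff(direction, cache))

class Mul(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _eval(self, point, cache):
        self.left._eval(point, cache)
        self.right._eval(point, cache)
        cache[id(self)] = cache[id(self.left)] * cache[id(self.right)]

    def _forward_diff(self, direction, cache):
        # Product rule, reading child values from the cache instead of recomputing.
        l, r = id(self.left), id(self.right)
        return (self.left._forward_diff(direction, cache) * cache[r]
                + cache[l] * self.right._forward_diff(direction, cache))

x, y = Var("x"), Var("y")
f = x * y + x * Const(3.0)
point = {"x": 2.0, "y": 5.0}
```

For example, `f.forward_diff({"x": 1.0}, point)` gives \(\pder{f}{x} = y + 3\) at that point.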
For `reverse_diff`, we again use `_eval` to populate the value cache before dispatching to `_reverse_diff` for the actual work. `_reverse_diff` takes advantage of the fact that Python dictionaries are mutable; we pass the dictionary that `reverse_diff` will return to `_reverse_diff`, and `_reverse_diff` modifies it in-place. Otherwise, it’s a pretty standard recursive algorithm, where each node merely calculates its adjoint and passes it to its children. Only variable nodes actually modify the output dictionary.
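There is one subtlety regarding the chain rule. Imagine a situation where a node \(w_i\) has two parents, \(w_j\) and \(w_k\). Then:
\[\pder{f}{w_i} = \pder{f}{w_j} \pder{w_j}{w_i} + \pder{f}{w_k} \pder{w_k}{w_i}\]
Our implementation calculates \(\pder{f}{w_j} \pder{w_j}{w_i}\) and \(\pder{f}{w_k} \pder{w_k}{w_i}\) separately, propagating each term down to the variable nodes on its own. Luckily, addition is associative and commutative, so it doesn’t matter: the terms are summed when the variable nodes add them into the output dictionary, and we compute the right answer anyways.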
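A minimal, self-contained sketch of that recursive reverse pass, assuming only variable, addition and multiplication nodes (the class and method names are illustrative, not the author’s code). Each node forwards `adjoint * local derivative` to its children, and only `Var` nodes write into the shared output dictionary. Because a variable reached along several paths adds each contribution with `+`, the sum from the chain rule above comes out right.

```python
class Expr:
    def __add__(self, other):
        return Add(self, other)

    def __mul__(self, other):
        return Mul(self, other)

    def reverse_diff(self, point):
        cache, grad = {}, {}
        self._eval(point, cache)
        self._reverse_diff(1.0, cache, grad)  # adjoint of the output is 1
        return grad

class Var(Expr):
    def __init__(self, name):
        self.name = name

    def _eval(self, point, cache):
        cache[id(self)] = point[self.name]

    def _reverse_diff(self, adjoint, cache, grad):
        # Called once per path from the output; summing applies the chain rule.
        grad[self.name] = grad.get(self.name, 0.0) + adjoint

class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _eval(self, point, cache):
        self.left._eval(point, cache)
        self.right._eval(point, cache)
        cache[id(self)] = cache[id(self.left)] + cache[id(self.right)]

    def _reverse_diff(self, adjoint, cache, grad):
        self.left._reverse_diff(adjoint, cache, grad)
        self.right._reverse_diff(adjoint, cache, grad)

class Mul(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _eval(self, point, cache):
        self.left._eval(point, cache)
        self.right._eval(point, cache)
        cache[id(self)] = cache[id(self.left)] * cache[id(self.right)]

    def _reverse_diff(self, adjoint, cache, grad):
        # d(l*r)/dl = r and d(l*r)/dr = l, read from the value cache.
        self.left._reverse_diff(adjoint * cache[id(self.right)], cache, grad)
        self.right._reverse_diff(adjoint * cache[id(self.left)], cache, grad)

x, y = Var("x"), Var("y")
shared = x * y                   # this node feeds into two parent nodes
f = shared * x + shared * y      # f = x^2 y + x y^2
grad = f.reverse_diff({"x": 3.0, "y": 4.0})
```

Here \(\pder{f}{x} = 2xy + y^2 = 40\) and \(\pder{f}{y} = x^2 + 2xy = 33\) at \((3, 4)\), even though `shared` is visited along two separate paths.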
numpy
for this.