Amateur Hour

Automatic Differentiation

Or mathemagically finding derivatives


TLDR: I talk about a technique called automatic differentiation, going through a mathematical derivation before examining two different implementations: one in Rust and one in Python.

About a year ago, I read a blog post on automatic differentiation, a cute technique which automatically computes derivatives (along with generalizations like gradients and Jacobians). But that’s not actually that interesting — after all, we could just use finite differences to calculate derivatives:

\[\frac{df}{dx} \approx \frac{f(x + h) - f(x)}{h}\]

choosing a small \(h\). Unfortunately, this kind of numerical differentiation usually doesn’t work too well in practice. If you make \(h\) too small, then your accuracy gets killed by floating point roundoff, and if \(h\) is too big, then approximation errors start ballooning.[1]
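A quick experiment shows both failure modes. This sketch (the helper name forward_difference is just illustrative) estimates the derivative of \(\sin\) at \(x = 1\) with the formula above and compares it to the exact value \(\cos 1\). With \(h = 10^{-1}\) the approximation error dominates. Around \(h = 10^{-8}\) the error is small. At \(h = 10^{-15}\) roundoff takes over and the error is large again.

```python
import math


def forward_difference(f, x, h):
    """One-sided finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x)) / h


exact = math.cos(1.0)  # the true derivative of sin at x = 1
for h in (1e-1, 1e-8, 1e-15):
    error = abs(forward_difference(math.sin, 1.0, h) - exact)
    print(f"h = {h:.0e}: error = {error:.1e}")
```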

Automatic differentiation avoids these problems entirely: it calculates exact derivatives, so your accuracy is only limited by floating point error.

The applications of automatic differentiation should be pretty obvious. But just in case they aren’t, I’ll point out that Google’s new machine learning framework TensorFlow, along with its competitor (and inspiration?) Theano, leverages automatic differentiation heavily.

What is Automatic Differentiation?

Automatic differentiation is really just a jumped-up chain rule. When you implement a function on a computer, you only have a small number of primitive operations available (e.g. addition, multiplication, logarithm). Any complicated function, like \(\frac{\log 2x}{x ^ x}\), is just a combination of these simple functions.

In other words, any complicated function \(f\) can be rewritten as the composition of a sequence of primitive functions \(f_k\):

\[f = f_0 \circ f_1 \circ f_2 \circ \ldots \circ f_n\]

Because each primitive function \(f_k\) has a simple derivative, we can use the chain rule to find \(\frac{df}{dx}\) pretty easily.[2]
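Written out, the chain rule for such a composition multiplies the derivatives of the primitives, each evaluated at the output of the next function in the chain. As a small illustrative case (not one used elsewhere in this post), take \(f(x) = \sin(x^2)\), so \(f_0(u) = \sin u\) and \(f_1(x) = x^2\):

\[\begin{align} \frac{df}{dx} &= f_0'(f_1(f_2(\cdots f_n(x)))) \cdot f_1'(f_2(\cdots f_n(x))) \cdots f_n'(x) \\ \frac{d}{dx} \sin(x^2) &= f_0'(f_1(x)) \cdot f_1'(x) = \cos(x^2) \cdot 2x \end{align}\]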

Although I’ve used a single-variable function \(f: \mathbb{R} \rightarrow \mathbb{R}\) as my example here, it’s straightforward to extend this idea to multivariate functions \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\).

Forward Mode

There are actually two different modes of automatic differentiation, based on how you apply the chain rule. We’ll start with forward mode automatic differentiation, which I find a little more intuitive.

Basics: Partial Derivatives

Given a function \(f\), we can construct a computation graph (a directed acyclic graph) representing it. For example, given the function \(f(x, y) = \cos x \sin y + \frac{x}{y}\), we can construct the graph:

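[Figure: computation graph for \(f(x, y) = \cos x \sin y + \frac{x}{y}\), with input nodes \(w_1 = x\) and \(w_2 = y\), intermediate nodes \(w_3 = \cos w_1\), \(w_4 = \sin w_2\), \(w_5 = w_3 w_4\), \(w_6 = \frac{w_1}{w_2}\), and output \(w_7 = w_5 + w_6\)]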

Each node of the graph represents a primitive function, while the edges represent the flow of information. In this example, the top-most node \(w_7\) represents the value of \(f(x, y)\) while the bottom-most nodes \(w_1\) and \(w_2\) represent our input variables.

Forward differentiation works by recursively defining the derivatives of nodes in terms of the derivatives of their children (the nodes below them that feed into them). For example, suppose we want to calculate the partial \(\pder{f}{x}\). For reasons that’ll become clear in two paragraphs, let’s denote \(\pder{f}{x}\) using a derivative operator (i.e. \(\pder{f}{x} = D f\)). Then \(D f = D w_7\), and:

\[\begin{align} D w_7 &= D (w_5 + w_6) = D w_5 + D w_6 \\ D w_6 &= D \frac{w_1}{w_2} = \frac{w_2 D w_1 - w_1 D w_2}{w_2 ^ 2} \\ D w_5 &= D (w_3 w_4) = w_3 D w_4 + w_4 D w_3 \\ D w_4 &= D \sin w_2 = \cos w_2 \cdot D w_2 \\ D w_3 &= D \cos w_1 = -\sin w_1 \cdot D w_1\\ D w_2 &= D y \\ D w_1 &= D x \end{align}\]

The final value of \(D f\) depends only on \(x\), \(y\), \(D w_1\) and \(D w_2\). In this case, we’ve let \(D = \frac{\partial}{\partial x}\), so \(D x = 1\) and \(D y = 0\). But if we let \(D = \frac{\partial}{\partial y}\), all we have to change is \(D x\) (from 1 to 0) and \(D y\) (from 0 to 1). Everything else stays the same.
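To make the recursion concrete, here is a hand-unrolled sketch for this particular graph (the function name forward_mode is illustrative). Each line computes a node \(w_k\) together with \(D w_k\), following the equations above, so switching from \(\pder{f}{x}\) to \(\pder{f}{y}\) really does only change the two seeds.

```python
import math


def forward_mode(x, y, dx, dy):
    """Return (f, D f) for f(x, y) = cos(x) sin(y) + x / y.

    The seeds (dx, dy) = (D x, D y) pick the derivative: (1, 0) gives df/dx and
    (0, 1) gives df/dy. Each line computes one node w_k and its derivative D w_k.
    """
    w1, d1 = x, dx                                   # w1 = x
    w2, d2 = y, dy                                   # w2 = y
    w3, d3 = math.cos(w1), -math.sin(w1) * d1        # w3 = cos w1
    w4, d4 = math.sin(w2), math.cos(w2) * d2         # w4 = sin w2
    w5, d5 = w3 * w4, w3 * d4 + w4 * d3              # w5 = w3 w4
    w6, d6 = w1 / w2, (w2 * d1 - w1 * d2) / w2 ** 2  # w6 = w1 / w2
    w7, d7 = w5 + w6, d5 + d6                        # w7 = w5 + w6 = f
    return w7, d7


x, y = 1.0, 2.0
_, df_dx = forward_mode(x, y, 1.0, 0.0)  # D = d/dx
_, df_dy = forward_mode(x, y, 0.0, 1.0)  # D = d/dy: only the seeds change
# Compare against partials worked out by hand.
print(math.isclose(df_dx, -math.sin(x) * math.sin(y) + 1 / y))     # True
print(math.isclose(df_dy, math.cos(x) * math.cos(y) - x / y ** 2))  # True
```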

This neatly extends to calculating arbitrary directional derivatives by defining \(\langle D x, D y \rangle\) to be the unit vector in the direction of our derivative. We can also set \(D = \nabla\), so that \(D x = \langle 1, 0 \rangle\) and \(D y = \langle 0, 1 \rangle\), and use vector arithmetic (vector addition and scalar-vector multiplication) instead of scalar arithmetic to calculate gradients.
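A minimal sketch of both ideas, assuming derivatives are stored as plain tuples (the helpers vadd and vscale and the name forward_mode_vec are illustrative). The same code handles both cases: two-component seeds give the gradient in one pass, and one-component seeds holding a unit vector give a directional derivative.

```python
import math


def vadd(u, v):
    """Componentwise sum of two derivative vectors."""
    return tuple(a + b for a, b in zip(u, v))


def vscale(c, u):
    """Scalar times derivative vector."""
    return tuple(c * a for a in u)


def forward_mode_vec(x, y, dx, dy):
    """Forward mode for f(x, y) = cos(x) sin(y) + x / y where each D w_k is a tuple."""
    w1, d1 = x, dx
    w2, d2 = y, dy
    w3, d3 = math.cos(w1), vscale(-math.sin(w1), d1)
    w4, d4 = math.sin(w2), vscale(math.cos(w2), d2)
    w5, d5 = w3 * w4, vadd(vscale(w3, d4), vscale(w4, d3))
    w6, d6 = w1 / w2, vscale(1 / w2 ** 2, vadd(vscale(w2, d1), vscale(-w1, d2)))
    w7, d7 = w5 + w6, vadd(d5, d6)
    return w7, d7


x, y = 1.0, 2.0
# D = nabla: seed D x = <1, 0> and D y = <0, 1>; one pass yields the whole gradient.
_, grad = forward_mode_vec(x, y, (1.0, 0.0), (0.0, 1.0))
# Directional derivative along the unit vector u = <0.6, 0.8>: one-component seeds.
_, (d_u,) = forward_mode_vec(x, y, (0.6,), (0.8,))
print(grad, d_u)
```

Every operation on a derivative now touches all \(n\) components, which is exactly where the extra factor of \(n\) in the gradient cost comes from.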

Why is this called forward mode? Derivative information starts at the bottom of our graph (at \(D x\) and \(D y\)) and flows to the top, in the same direction (bottom-up) that values flow when we evaluate the expression. That’s why we call this "forward mode". Predictably, information flows top-down in "reverse mode".

Runtime Complexity

Calculating derivatives of certain primitive functions requires both the values and the derivatives of their component parts. For example:

\[D w_i w_j = w_i D w_j + w_j D w_i\]

requires both the values of \(w_i\) and \(w_j\) along with the derivatives \(D w_i\) and \(D w_j\). If \(CD(w)\) is the cost of computing the derivative of node \(w\) and \(CV(w)\) is the cost of computing the value of node \(w\), then:

\[CD(w_i w_j) = CD(w_i) + CD(w_j) + CV(w_i) + CV(w_j) + 3\]

where the 3 is for the 3 extra arithmetic operations (two multiplications and one addition). More generally:

\[CD(w) \le \sum_{w_k \in \text{children}(w)} CD(w_k) + \sum_{w_k \in \text{children}(w)} CV(w_k) + c\]

for some constant \(c\).

Consider an expression like \(x x x x x x x x x x x x x x x x x\). Under a naive differentiation scheme, we might have to recompute the value for node \(xxx\) 8 or 9 times, leading to a quadratic blow-up. Under a smarter scheme, we calculate the value and the derivative of each node at the same time, which gives:

\[CD(w) + CV(w) \le \sum_{w_k \in \text{children}(w)} (CD(w_k) + CV(w_k) ) + c + 1\]

If our function is \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), composed of \(P(f)\) primitive operations, then \(CD(f) + CV(f)\) is clearly linear in \(P(f)\). Calculating directional derivatives, then, is linear in the number of primitive operations.
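A small experiment makes the blow-up concrete. This sketch uses illustrative names, and it represents expressions as nested tuples rather than a real graph type. It differentiates a left-nested product of 17 copies of \(x\) in two ways and counts node visits. The naive product rule re-evaluates child values from scratch, while value_and_deriv returns the value and the derivative together.

```python
# An expression is either the string "x" or a tuple ("mul", left, right).

def value(e, x, visits):
    visits[0] += 1
    if e == "x":
        return x
    _, left, right = e
    return value(left, x, visits) * value(right, x, visits)


def naive_deriv(e, x, visits):
    """Product rule that recomputes child values from scratch every time."""
    visits[0] += 1
    if e == "x":
        return 1.0
    _, left, right = e
    return (value(left, x, visits) * naive_deriv(right, x, visits)
            + value(right, x, visits) * naive_deriv(left, x, visits))


def value_and_deriv(e, x, visits):
    """Return (value, derivative) together so each node is visited once."""
    visits[0] += 1
    if e == "x":
        return x, 1.0
    _, left, right = e
    lv, ld = value_and_deriv(left, x, visits)
    rv, rd = value_and_deriv(right, x, visits)
    return lv * rv, lv * rd + rv * ld


def product_of_xs(n):
    """Left-nested product ((x * x) * x) * ... with n copies of x."""
    e = "x"
    for _ in range(n - 1):
        e = ("mul", e, "x")
    return e


e = product_of_xs(17)
naive_visits, smart_visits = [0], [0]
d_naive = naive_deriv(e, 2.0, naive_visits)
_, d_smart = value_and_deriv(e, 2.0, smart_visits)
print(d_naive, d_smart)                  # 1114112.0 1114112.0 (= 17 * 2**16)
print(naive_visits[0], smart_visits[0])  # 305 33
```

For \(k\) factors in this left-nested shape, the naive version makes \(k^2 + k - 1\) visits, while the combined version visits each of the \(2k - 1\) nodes exactly once.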

For gradients, all our scalar operations become vector operations, so \(c\) becomes \(cn\) (vector operations are linear in the vector’s dimensionality). The cost of computing gradients is thus linear in \(n P(f)\).

Reverse Mode

Reverse mode automatic differentiation lets you calculate gradients much more efficiently than forward mode automatic differentiation when a function has many input variables.

Consider our old function \(f(x, y) = \cos x \sin y + \frac{x}{y}\) and its computation graph:

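[Figure: the same computation graph for \(f(x, y) = \cos x \sin y + \frac{x}{y}\), with nodes \(w_1\) through \(w_7\)]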

Suppose we want to calculate the gradient \(\nabla f = \langle \pder{f}{x}, \pder{f}{y} \rangle = \langle \pder{w_7}{w_1}, \pder{w_7}{w_2} \rangle\). Then:

\[\begin{align} \pder{w_7}{w_1} &= \pder{w_3}{w_1} \pder{w_7}{w_3} + \pder {w_6}{w_1} \pder{w_7}{w_6} = - \sin{w_1} \pder{w_7}{w_3} + \frac{1}{w_2} \pder{w_7}{w_6} \\ \pder{w_7}{w_2} &= \pder{w_4}{w_2} \pder{w_7}{w_4} + \pder {w_6}{w_2} \pder{w_7}{w_6} = \cos{w_2} \pder{w_7}{w_4} - \frac{w_1}{w_2 ^ 2} \pder{w_7}{w_6} \\ \pder{w_7}{w_3} &= \pder{w_5}{w_3} \pder{w_7}{w_5} = w_4 \pder{w_7}{w_5} \\ \pder{w_7}{w_4} &= \pder{w_5}{w_4} \pder{w_7}{w_5} = w_3 \pder{w_7}{w_5} \\ \pder{w_7}{w_5} &= 1 \\ \pder{w_7}{w_6} &= 1 \end{align}\]

For some reason that I couldn’t find, the intermediate values \(\pder{w_7}{w_k}\) are called adjoints.

As the name "reverse mode" suggests, things are reversed here: information flows top-down, the opposite of normal evaluation. Instead of using the chain rule to find derivatives of parents in terms of derivatives of their children, we find the adjoints of child nodes in terms of the adjoints of their parents.
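Here is a hand-unrolled sketch of the same computation (the function name reverse_mode is illustrative). A forward sweep stores every node’s value. A backward sweep then computes the adjoints \(\pder{w_7}{w_k}\) from the top down, using only scalar operations, and returns both partials at once.

```python
import math


def reverse_mode(x, y):
    """Value and gradient of f(x, y) = cos(x) sin(y) + x / y via adjoints."""
    # Forward sweep: evaluate and remember every node's value.
    w1, w2 = x, y
    w3, w4 = math.cos(w1), math.sin(w2)
    w5, w6 = w3 * w4, w1 / w2
    w7 = w5 + w6

    # Backward sweep: adjoint a_k stands for dw7/dw_k, computed top-down.
    a7 = 1.0
    a5 = a7 * 1.0  # w7 = w5 + w6, so dw7/dw5 = 1
    a6 = a7 * 1.0  # ... and dw7/dw6 = 1
    a3 = a5 * w4   # w5 = w3 w4, so dw5/dw3 = w4
    a4 = a5 * w3   # ... and dw5/dw4 = w3
    # w1 and w2 each feed two parents, so their adjoints sum two terms.
    a1 = a3 * -math.sin(w1) + a6 * (1 / w2)
    a2 = a4 * math.cos(w2) + a6 * (-w1 / w2 ** 2)
    return w7, (a1, a2)


f_val, (df_dx, df_dy) = reverse_mode(1.0, 2.0)
print(df_dx, df_dy)
```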

Runtime Complexity

If we reuse our notation of \(CD\) for the cost of calculating a gradient and \(CV\) for the cost of calculating a value, we see:

\[CD(w) + CV(w) \le \sum_{w_k \in \text{children}(w)} (CD(w_k) + CV(w_k) ) + c\]

Notice that all of our operations in reverse mode differentiation are scalar operations, even though we’re calculating the (vector) gradient. Thus, the last term is a \(+ c\) and not a \(+ c n\). Computing gradients via reverse mode is thus linear in \(P(f)\), and not in \(nP(f)\), which can be a big speed-up if \(f\) takes lots of input variables.

The only way I can see to calculate directional derivatives via reverse mode differentiation is to compute the gradient and take its dot product with the direction vector. In that case, the runtime cost of finding a directional derivative is linear in \(P(f) + n\): \(P(f)\) for the gradient and \(n\) for the dot product.

Thoughts on Implementation

I went ahead and implemented some basic automatic differentiation. It only supports arithmetic operators (addition, subtraction, multiplication, division, and exponentiation), although it should be pretty trivial to add support for unary functions like \(\sin\) or \(\log\).

I went with the easy approach of creating an Expr class and forcing the user to build their computation graph manually (via operator overloading). While that’s certainly feasible for something where people are already building these computation graphs (e.g. TensorFlow or Theano), it isn’t ideal. After all, why should someone have to completely replace their compute engine with yours just to do automatic differentiation?

A better solution would be something that parses the source code of a function and uses its abstract syntax tree to build the computation graph automatically. Another solution would be to use dual numbers. But both of these are significantly more challenging to implement, so I didn’t bother.
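For comparison, here is a minimal sketch of the dual-number approach. It is not part of the code discussed in this post, and the class Dual and the helpers _lift, sin, and cos are illustrative names. A dual number \(a + b\varepsilon\) with \(\varepsilon^2 = 0\) carries a value and a derivative together, so ordinary-looking arithmetic computes forward-mode derivatives as a side effect. One caveat is that math.sin and friends don’t understand Dual, so elementary functions still need dual-aware versions.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """The dual number val + der * eps, where eps ** 2 == 0."""
    val: float
    der: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __mul__(self, other):
        other = _lift(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps ** 2 == 0
        return Dual(self.val * other.val, self.val * other.der + self.der * other.val)

    def __truediv__(self, other):
        other = _lift(other)
        return Dual(self.val / other.val,
                    (self.der * other.val - self.val * other.der) / other.val ** 2)


def _lift(x):
    """Treat plain numbers as constants (derivative 0)."""
    return x if isinstance(x, Dual) else Dual(float(x))


def sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.der)


def cos(x):
    return Dual(math.cos(x.val), -math.sin(x.val) * x.der)


def f(x, y):
    # Ordinary-looking code: no explicit computation graph is built.
    return cos(x) * sin(y) + x / y


# Seed the input we differentiate with respect to with derivative 1.
df_dx = f(Dual(1.0, 1.0), Dual(2.0, 0.0)).der
df_dy = f(Dual(1.0, 0.0), Dual(2.0, 1.0)).der
print(df_dx, df_dy)
```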

Rust Implementation

I originally started implementing things in Rust. Algebraic data types seemed like the perfect way to represent computation graphs, which narrowed my initial choices to Rust or Haskell. I’m a little too rusty with Haskell to be really productive, and I wanted to play around with Rust some more anyways, so I chose Rust.

You can see the code for this on GitHub.

Design

Ideally, I’d want Expr to look something like:

// Does not compile: a recursive type like this has infinite size.
enum Expr {
    Add(Expr, Expr),
    Sub(Expr, Expr),
    Mul(Expr, Expr),
    // ...
}

Unfortunately for us, Rust is a systems language that doesn’t support this kind of recursive structure. After all, how many bytes should Rust allocate to Expr in memory? There’s no finite answer: an Expr would have to be big enough to hold two more Exprs inline, each of which would have to hold two more, and so on. That’s why we have pointers: a pointer has a fixed size no matter how big the thing it points to is.

It took a fair amount of struggling before I realized that Rust’s reference-counted pointers (std::rc::Rc) were perfect for the job: they provide shared immutable ownership of data (i.e. a node in a graph can have multiple parents). Luckily, computation graphs are acyclic, so I didn’t have to mess around with weak references or any other cycle-breaking mechanism.

To make sure that the user doesn’t need to worry about any lifetime stuff, I actually made a private enum InnerExpr to store the computation graph and made the public-facing Expr struct a thin wrapper around a std::rc::Rc<InnerExpr>.

I also decided to split off the Add, Sub, Mul, Div, and Pow variants into their own Arithmetic sub-enum. I thought that this would make the code a little more modular by separating the computation out from the plumbing, so-to-speak. In retrospect, the sub-enum was way more trouble than it was worth.

I represented points as hashmaps mapping the variable names (strings) to their values. It’s a little more verbose than I’d like, but I couldn’t think of any better alternatives.

I only ended up implementing forward mode directional differentiation in Rust before moving on to Python (Rust is great, but developing in it is definitely slower than developing in Python, and I don’t really care about performance or memory efficiency here). To prevent the quadratic blow-up I mentioned earlier, I calculated the value and the derivative of each node together and used a lightweight struct to bubble them up. Otherwise, this was a pretty straightforward recursive implementation.

Thoughts on Rust

I should start by saying that for a systems language, Rust was surprisingly nice to develop in. Much nicer than C (shudder). The type system was a giant plus, and the memory management was surprisingly easy once I got the hang of it. Still, it’s definitely more verbose than Python is, and there were a bunch of papercuts that really irritated me.

Operator Overloading

My biggest irritation with Rust was its operator overloading. In Rust, the operator traits take their arguments by value, so unless a type is Copy, using an operator moves (takes ownership of) its operands. For example:

let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = a + b;                  // c now owns a and b
println!("{}", a.eval());       // so this fails

The only solution I could think of was to implement operator overloading for references (they’re kind of like const pointers), so things like:

let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = &a + &b;
println!("{}", a.eval());

work. It’s a little annoying to have to write & everywhere, but that’s not a huge deal. Infinitely more annoying is the fact that because operators return an Expr and not an &Expr, proper operator chaining forces me to implement every operator four times:

let a = Expr::constant(1.0);
let b = Expr::constant(1.3);
let c = &a + &b;                // Ref + Ref
let d = &a * &a + &b;           // Value + Ref
let e = &a + &b * &b;           // Ref + Value
let f = &a * &a + &b * &b;      // Value + Value

That’s really annoying. I eventually wrote a macro to implement the overloading, which cut the verbosity by a lot. The flip side is that I suck at reading macros, so it takes me a couple extra seconds to figure out what the hell my macro is doing every time I read it.

Strings

This complaint’s probably more because of my lack of knowledge than any shortcoming with Rust, but I’ll make it anyways. The interaction between String and &str and &String can get very annoying very quickly.

I don’t remember the exact details, but when I made InnerExpr use &str to store variable names, I ran into all kinds of irritating lifetime issues, so I had to make InnerExpr use String. But when I made my hashmaps use keys of type &str, I ran into all kinds of ownership problems trying to look up my String names. In the end, I just used String everywhere, which meant that I needed to call to_owned all over the place. Not the end of the world, but definitely another papercut.

Testing

This is more of a complaint about Rust’s immaturity than about the language itself. Testing in Rust is much more annoying than what I’m used to from Python.

Part of it is that I’m spoiled from py.test --pdb, which opens up an interactive debugger when your Python tests fail. Not being able to just open up an interpreter made the information Rust gives you with assert! seem painfully limited.

I also couldn’t find anything providing setup/teardown-style testing, let alone fixtures like py.test’s. It didn’t make a huge difference for this project, but it’s still kind of irritating.

There also doesn’t seem to be a built-in helper for comparing floats. I understand that technically nan != nan, but for testing purposes, it’d be nice if there were an equality function that handles edge cases like nan, inf, and +0.0 vs. -0.0 correctly, along with having some tolerance for floating-point roundoff errors.[3]
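For comparison, here is a minimal sketch of such a helper in Python, built on the standard library’s math.isclose. That function already handles roundoff tolerance and treats inf and -inf as close only to themselves, but like == it never treats nan as equal to anything. The name floats_match, and the choice to treat +0.0 and -0.0 as different by default, are my own assumptions.

```python
import math


def floats_match(a, b, rel_tol=1e-9, abs_tol=0.0, check_zero_sign=True):
    """Hypothetical test helper: float equality that copes with edge cases."""
    if math.isnan(a) or math.isnan(b):
        # nan != nan under ==, but two nans should "match" in a test.
        return math.isnan(a) and math.isnan(b)
    if check_zero_sign and a == 0.0 and b == 0.0:
        # +0.0 == -0.0 under ==, so compare the signs explicitly.
        return math.copysign(1.0, a) == math.copysign(1.0, b)
    # isclose only treats inf and -inf as close to themselves, and
    # allows a relative (and optional absolute) tolerance for roundoff.
    return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


print(floats_match(0.1 + 0.2, 0.3))               # True
print(floats_match(float("nan"), float("nan")))   # True
print(floats_match(0.0, -0.0))                    # False
print(floats_match(float("inf"), float("-inf")))  # False
```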

I also looked into a port of Quickcheck for Rust. It’s nice, but the test data it generated for floats didn’t include any of the wonky edge cases like +0.0 versus -0.0 or nan. I’m not sure why, especially because those are the things which usually make everything go to hell.

Python Implementation

Eventually, I got frustrated with Rust’s verbosity, so I switched over to Python, my go-to language. My Python implementation is pretty different from my Rust implementation: I used inheritance to control method dispatch instead of pattern matching and algebraic data types, and I definitely took advantage of Python’s dynamic capabilities.

You can look at my implementation here.

The high-level design of the Python implementation was pretty similar to my Rust implementation. We have an Expr class that represents a given node and exposes the user-facing interface (operator overloading, plus eval, forward_diff, and reverse_diff). The actual implementation is handled by the subclasses (e.g. Add). I used dictionaries mapping names (strings) to values (floats) to store points.

The subclasses all implement a private _eval method which recursively populates a dictionary mapping node ids (ints) to their evaluated values (floats). eval just uses _eval to fill out this cache and then looks up the current node’s value. Normally, I would’ve just used the built-in functools.lru_cache for caching, but the inputs (points, which are dictionaries) aren’t hashable, so lru_cache wouldn’t have worked.
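A quick illustration of why lru_cache doesn’t fit here. The function x_squared is a hypothetical stand-in for evaluating a node at a point. functools.lru_cache hashes its arguments to build a cache key, so passing a dictionary raises TypeError before the function body even runs.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def x_squared(point):
    """Hypothetical stand-in for evaluating an expression at a point."""
    return point["x"] ** 2


try:
    x_squared({"x": 3.0})  # lru_cache must hash the dict to build its key
    raised = False
except TypeError:
    raised = True  # dicts are unhashable, so the cache lookup fails

print(raised)  # True
```

Node ids, on the other hand, are plain ints, which is why a per-call dictionary keyed on node ids works as a cache.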

forward_diff first calls _eval to populate the cache, which it passes to _forward_diff to do the actual computation. _forward_diff is a pretty standard recursive implementation.

For reverse_diff, we again use _eval to populate the value cache before dispatching to _reverse_diff for the actual work. _reverse_diff takes advantage of the fact that Python dictionaries are mutable: we pass the dictionary that reverse_diff will return into _reverse_diff, which modifies it in place. Otherwise, it’s a pretty standard recursive algorithm, where each node merely calculates its adjoint and passes it to its children. Only variable nodes actually modify the output dictionary.

There is one subtlety regarding the chain rule. Imagine a situation where a node \(w_i\) has two parents, \(w_j\) and \(w_k\). Then:

\[\pder{f}{w_i} = \pder{f}{w_j} \pder{w_j}{w_i} + \pder{f}{w_k} \pder{w_k}{w_i}\]

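Our implementation calculates \(\pder{f}{w_j} \pder{w_j}{w_i}\) and \(\pder{f}{w_k} \pder{w_k}{w_i}\) separately, in different recursive calls. Luckily, the variable nodes below \(w_i\) add every contribution they receive into the output dictionary, and addition is associative and commutative, so the order doesn’t matter and we compute the right answer anyway.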
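Here is a minimal sketch of this design. It is not the actual code: it assumes only variable, constant, addition, and multiplication nodes, and the names Var, Const, Mul, _wrap, and _value, as well as the forward_diff(point, direction) signature, are illustrative guesses. Values are cached by id(node). _reverse_diff pushes each path’s contribution down separately, and variable nodes add into the output dictionary, exactly the accumulation described above.

```python
class Expr:
    """User-facing API and operator overloading (illustrative sketch)."""

    def __add__(self, other):
        return Add(self, _wrap(other))

    def __mul__(self, other):
        return Mul(self, _wrap(other))

    def eval(self, point):
        return self._eval(point, {})

    def forward_diff(self, point, direction):
        cache = {}
        self._eval(point, cache)  # fill the value cache first
        return self._forward_diff(direction, cache)

    def reverse_diff(self, point):
        cache = {}
        self._eval(point, cache)
        grad = {name: 0.0 for name in point}  # mutated in place by _reverse_diff
        self._reverse_diff(1.0, cache, grad)  # the root's adjoint is df/df = 1
        return grad

    def _eval(self, point, cache):
        # Memoize on node identity so shared subexpressions are evaluated once.
        if id(self) not in cache:
            cache[id(self)] = self._value(point, cache)
        return cache[id(self)]


def _wrap(value):
    return value if isinstance(value, Expr) else Const(float(value))


class Var(Expr):
    def __init__(self, name):
        self.name = name

    def _value(self, point, cache):
        return point[self.name]

    def _forward_diff(self, direction, cache):
        return direction.get(self.name, 0.0)

    def _reverse_diff(self, adjoint, cache, grad):
        grad[self.name] += adjoint  # add, don't overwrite: one contribution per path


class Const(Expr):
    def __init__(self, value):
        self.value = value

    def _value(self, point, cache):
        return self.value

    def _forward_diff(self, direction, cache):
        return 0.0

    def _reverse_diff(self, adjoint, cache, grad):
        pass  # no variables below a constant


class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _value(self, point, cache):
        return self.left._eval(point, cache) + self.right._eval(point, cache)

    def _forward_diff(self, direction, cache):
        return (self.left._forward_diff(direction, cache)
                + self.right._forward_diff(direction, cache))

    def _reverse_diff(self, adjoint, cache, grad):
        # d(l + r)/dl = d(l + r)/dr = 1, so both children receive the same adjoint
        self.left._reverse_diff(adjoint, cache, grad)
        self.right._reverse_diff(adjoint, cache, grad)


class Mul(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _value(self, point, cache):
        return self.left._eval(point, cache) * self.right._eval(point, cache)

    def _forward_diff(self, direction, cache):
        lv, rv = cache[id(self.left)], cache[id(self.right)]
        return (lv * self.right._forward_diff(direction, cache)
                + rv * self.left._forward_diff(direction, cache))

    def _reverse_diff(self, adjoint, cache, grad):
        lv, rv = cache[id(self.left)], cache[id(self.right)]
        self.left._reverse_diff(adjoint * rv, cache, grad)   # d(l r)/dl = r
        self.right._reverse_diff(adjoint * lv, cache, grad)  # d(l r)/dr = l


x, y = Var("x"), Var("y")
f = x * x + x * y  # x has several parents, so its adjoint gets summed
point = {"x": 3.0, "y": 2.0}
print(f.eval(point))                      # 15.0
print(f.forward_diff(point, {"x": 1.0}))  # 8.0, i.e. df/dx = 2x + y
print(f.reverse_diff(point))              # {'x': 8.0, 'y': 3.0}
```

One design note: this per-path recursion revisits a shared subexpression once for every path that reaches it, so heavily shared graphs can get expensive. Accumulating each node’s adjoint first and visiting nodes in reverse topological order avoids that.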

1. Admittedly, I’m no expert on numerical differentiation, so it’s entirely possible that these problems have been solved through more complicated formulas. On the other hand, numerical differentiation packages never really worked for me, which makes me suspect that this is a problem inherent in numerical differentiation.
2. This idea seems suspiciously similar to backpropagation, a method to efficiently train neural networks. I’d bet money that there’s some kind of historical connection between them, but I don’t know enough to be certain.
3. To be fair, I’m not sure Python has this built-in either. I usually use testing functions from numpy for this.