36-651/751: Automatic Differentiation

Spring 2019, mini 3 (last updated February 12, 2019)

As statisticians, we often have to compute derivatives and gradients:

Optimization
Many optimization algorithms either estimate gradients or use exact gradients when available, to better maximize the target function.
Neural networks
Fitting a neural network is, of course, an optimization problem – but a particularly big and unpleasant nonlinear one. It’s usually done via gradient descent, and that requires a gradient.
MCMC
Hamiltonian Monte Carlo in particular involves numerically integrating differential equations written in terms of the posterior of interest, and hence requires its derivatives.

There are simple ways to get derivatives, of course. If we have a function f(x), we can approximate its derivative with a simple finite difference, like

\displaystyle \frac{df(x)}{dx} \approx \frac{f(x + \Delta x) - f(x)}{\Delta x},

where we pick a “sufficiently small” \Delta x, being careful to avoid floating-point error. There are higher-order finite difference methods with smaller errors that require evaluating the function at more points.

But these methods are all approximate, and getting better accuracy requires evaluating f(x) at more points to do a higher-order approximation. And for a multivariate function g(x, y, z, ...) we need to select \Delta x, \Delta y, \Delta z, and so on, being careful to avoid problems if the variables have very different scales, and must evaluate g with perturbations along every dimension to approximate its gradient. This O(d) scaling in the number of dimensions is not pleasant when you’re fitting an enormous neural network with millions of parameters.
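
As a concrete (if naive) sketch in R – the helper finite_diff_grad here is purely illustrative – each coordinate of the gradient costs one extra evaluation of g:

finite_diff_grad <- function(g, x, h = 1e-6) {
    g0 <- g(x)
    vapply(seq_along(x), function(i) {
        xp <- x
        xp[i] <- xp[i] + h      # perturb one coordinate at a time
        (g(xp) - g0) / h
    }, numeric(1))
}

g <- function(x) x[1] * x[2] + sin(x[1])
finite_diff_grad(g, c(2, 3))    # approximately c(3 + cos(2), 2)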

We could avoid approximate derivatives by analytically calculating the exact derivatives and writing code to calculate those as well, either by hand or using a symbolic math tool like Mathematica or SymPy. Symbolic math tools (also called Computer Algebra Systems) take in the structure of a mathematical expression as data, and can rewrite the expression into a different expression and then evaluate it; hence they can apply replacement rules like

\begin{aligned} \frac{d}{dx} \left( f(x) + g(x) \right) &\to \frac{d}{dx} f(x) + \frac{d}{dx} g(x) \\ \frac{d}{dx} \left( ax^b \right) &\to ba x^{b - 1}, \end{aligned}

and so on. Often, mechanically taking derivatives like this can lead to very large mathematical expressions, and turning these back into code is tedious and error-prone. It also doesn’t scale well when you need gradients of huge multivariate functions. And it doesn’t always work: what if my answers are calculated with a bunch of for loops and if statements and recursive functions? How do I write out the derivatives of those?
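
Base R’s D() function, for example, does exactly this kind of rule-based rewriting for simple expressions, but it only knows a fixed vocabulary of elementary functions and cannot see inside loops or conditionals:

expr <- quote(x * y + sin(x))
D(expr, "x")                              # y + cos(x)
eval(D(expr, "x"), list(x = 2, y = 3))    # 3 + cos(2)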

This is the problem automatic differentiation attempts to solve. In a good automatic differentiation system, you write code to calculate f(x), and that code is augmented (by a library or special compiler) to also calculate any derivative you want, exactly, without using either numerical approximations or symbolic manipulation.

There are two main ways to do automatic differentiation, “forward mode” and “reverse mode”. Let’s start with forward-mode automatic differentiation, since it’s the easier of the two to understand.

(As a reference, I recommend Automatic Differentiation in Machine Learning: a Survey.)

Forward mode

(I’ll be following the steps of this tutorial.)

Suppose you want to calculate some function f(x, y):

f(x, y) = xy + \sin(x)

If you’re writing Python or R, you’d probably just write something like

f <- function(x, y) {
    x * y + sin(x)
}

and bam! Done. But if you want derivatives, you have to do some work. The work is easy in this case, but suppose f were some arbitrarily complicated function, with loops, conditionals, function calls, and so on. Do I have to manually write out the derivatives of every function I write and call?

We can make progress by realizing that every calculation is composed of many simpler calculations: a whole bunch of addition, subtraction, multiplication, trigonometry, special functions, and so on. The derivatives of those calculations are easy, so let’s break f down into its pieces: a series of equations, each involving a single elementary operation:

\begin{aligned} x &= \text{something}\\ y &= \text{something}\\ a &= xy\\ b &= \sin(x)\\ f(x, y) &= a + b \end{aligned}

This is sometimes called an evaluation trace or a Wengert list, and expands out the evaluation into many steps with intermediate variables.

Suppose we differentiate with respect to t, some arbitrary variable:

\newcommand\pt[1]{\frac{\partial #1}{\partial t}} \begin{aligned} \pt{x} &= \text{something} \\ \pt{y} &= \text{something} \\ \pt{a} &= x \pt{y} + y \pt{x} \\ \pt{b} &= \cos(x) \pt{x} \\ \pt{f(x, y)} &= \pt{a} + \pt{b} \end{aligned}

We can do this with merely the product rule, quotient rule, and chain rule – the chain rule most crucially of all.

(You might be thinking: hang on, you said there’s no symbolic differentiation here! I lied. There is symbolic differentiation using these rules, as well as hard-coded knowledge that \frac{d}{dx}\sin(x) = \cos(x), but by breaking the calculation into the evaluation trace, we need only very simple symbolic rules.)

Suppose we let t = x. What do we get?

\newcommand\pt[1]{\frac{\partial #1}{\partial x}} \begin{aligned} \pt{x} &= 1 \\ \pt{y} &= 0 \\ \pt{a} &= y \\ \pt{b} &= \cos(x) \\ \pt{f(x, y)} &= y + \cos(x) \end{aligned}

Bam! A derivative of f(x, y). We can do the same for y by letting t = y.
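
Carrying this out by hand in R – a minimal sketch of the trace, with dx and dy as the seeds – looks like:

x <- 2;      dx <- 1                  # seed: differentiate with respect to x
y <- 3;      dy <- 0
a <- x * y;  da <- x * dy + y * dx    # product rule
b <- sin(x); db <- cos(x) * dx        # chain rule
f <- a + b;  df <- da + db            # df is y + cos(x) = 3 + cos(2)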

You’re probably objecting now: but this is just a particularly laborious way of writing out the chain rule! Yes, but it also suggests how we may write a program to do this automatically. Specifically:

  1. Break down every calculation into the evaluation trace. This is no problem – the program is calculating one step at a time anyway.
  2. Every time we calculate an intermediate value like a, we also calculate \frac{\partial a}{\partial t}.
  3. Run the program with \frac{\partial x}{\partial t} = 1 if you want to differentiate with respect to x, or do the same with y to get its derivative instead. Run the program once for each input variable to get the gradient.

The simplest way to achieve this is to say that every variable x in your code is actually a dual number: a value x + \Delta \dot x, where \Delta here is a special number such that \Delta^2 = 0 but \Delta \neq 0. This implies simple rules of arithmetic:

\begin{aligned} (x + \Delta \dot x) + (y + \Delta \dot y) &= (x + y) + \Delta (\dot x + \dot y) \\ (x + \Delta \dot x) (y + \Delta \dot y) &= xy + \Delta (x \dot y + y \dot x), \end{aligned}

and so on. Observe that these rules of arithmetic are just the sum and product rules for derivatives in disguise; the chain rule enters when we define the elementary functions on dual numbers below.

We then write our code so that every variable x is actually a tuple (x, \dot x) representing x + \Delta \dot x, and every arithmetic operation is one that applies the appropriate arithmetic rules for dual numbers. We also replace elementary mathematical functions like \sin(x) with ones that operate on tuples and return the right value:

\begin{aligned} \sin\left((x, \dot x)\right) &= (\sin(x), \cos(x) \dot x) \\ \cos\left((x, \dot x)\right) &= (\cos(x), -\sin(x) \dot x) \\ \sqrt{(x, \dot x)} &= \left(\sqrt{x}, \dot x / (2 \sqrt{x})\right) \\ \dots &= \dots \end{aligned}

Once we’ve established this foundation – and that requires either a language supporting function and operator overloading, or writing our code using special new function calls instead of ordinary arithmetic operators and math functions – we can write all our code in terms of dual numbers and functions operating on dual numbers. This code can look just like what we would have written to evaluate the function alone, but it gives us exact derivatives via the chain rule.

To get derivatives with respect to x, set (x, \dot x) = (x, 1) and (y, \dot y) = (y, 0). For y, swap the 0 and 1. If all we need is the dot product \nabla f(x, y) \cdot r, for some vector r, set (\dot x, \dot y) = r.
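
Here is a minimal sketch of this in R, using an S3 class to represent the (x, \dot x) tuple and overloading just enough operations for our example (the toy implementations mentioned below are more complete):

dual <- function(x, dx) structure(list(x = x, dx = dx), class = "dual")

`+.dual` <- function(a, b) dual(a$x + b$x, a$dx + b$dx)
`*.dual` <- function(a, b) dual(a$x * b$x, a$x * b$dx + b$x * a$dx)
sin.dual <- function(a) dual(sin(a$x), cos(a$x) * a$dx)

f <- function(x, y) x * y + sin(x)

# Seed (x, 1) and (y, 0) to differentiate with respect to x:
f(dual(2, 1), dual(3, 0))$dx      # y + cos(x) = 3 + cos(2)

Note that f itself is exactly the code we wrote before; only its inputs have changed.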

See hyperreal.R and autodiff.scm for toy implementations of this idea in both R and Scheme.

Forward-mode automatic differentiation has been implemented for many languages, like Julia, Python, and C#/F#. For R, the madness package uses S4 classes and overloads – like my toy hyperreal.R – to do automatic differentiation of multivariate functions, with many overloaded methods defined for things like eigen() and det() and all common arithmetic operations. The intended use case is applying the delta method to arbitrarily weird estimators.

But forward-mode automatic differentiation is not the method of choice for many statistical and machine learning problems. For a function f : \mathbb{R}^d \to \mathbb{R}^m, forward mode requires O(d) passes through f to obtain the full gradient (or Jacobian), and d is often quite large (e.g. millions of parameters in a neural net). That’s where reverse mode comes in: it requires only O(m) passes to get the same derivatives, which is great when d \gg m.

Reverse mode

In reverse mode, the function is evaluated normally, but its derivative is evaluated in reverse.

Specifically: First we proceed through the evaluation trace as normal, calculating the intermediate values and the final answer, but not using dual numbers. We store all the intermediate values and record which values depended on which inputs. Then, we run in reverse, using the outputs to calculate the derivatives.

Let’s again consider

f(x, y) = xy + \sin(x),

and suppose we want to know its gradient \nabla f(x, y). We have the evaluation trace

\begin{aligned} x &= \text{something}\\ y &= \text{something}\\ a &= xy\\ b &= \sin(x)\\ f(x, y) &= a + b \end{aligned}

and we have the values for f(x, y), a, b, and so on. Let z = f(x, y) for the given values of x and y we want the gradient for. Instead of differentiating each intermediate value with respect to the inputs, we calculate the derivative of the output z with respect to each intermediate value, which means working through the evaluation list in reverse:

\newcommand\pz[1]{\frac{\partial z}{\partial #1}} \begin{aligned} \pz{f(x, y)} &= 1 \\ \pz{b} &= \frac{\partial f(x, y)}{\partial b} \pz{f(x, y)} &&= \pz{f(x, y)} \\ \pz{a} &= \frac{\partial f(x, y)}{\partial a} \pz{f(x, y)} &&= \pz{f(x, y)} \\ \pz{y} &= \frac{\partial a}{\partial y} \pz{a} &&= x \pz{a} \\ \pz{x} &= \frac{\partial a}{\partial x} \pz{a} + \frac{\partial b}{\partial x} \pz{b} &&= y \pz{a} + \cos(x) \pz{b} \end{aligned}

Notice how, in this order, each line uses only adjoints computed on earlier lines plus intermediate values stored during the forward pass – and a single reverse sweep gives us the derivative of z with respect to every input at once.
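
Here is the same calculation carried out by hand in R – a sketch of what a reverse-mode system automates, writing dz_a for \partial z / \partial a and so on:

# Forward pass: compute and store every intermediate value.
x <- 2; y <- 3
a <- x * y
b <- sin(x)
z <- a + b

# Reverse pass: adjoints in the reverse order of the trace.
dz_z <- 1
dz_b <- dz_z                        # z = a + b
dz_a <- dz_z
dz_y <- x * dz_a                    # a = x * y
dz_x <- y * dz_a + cos(x) * dz_b    # a = x * y and b = sin(x)

c(dz_x, dz_y)                       # the gradient: c(y + cos(x), x)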

Implementing reverse-mode automatic differentiation is more difficult than implementing forward mode. Dual numbers are not enough, because we need to somehow record the operations performed and then run through them in reverse.

A typical strategy is to use operator and function overloading to make arithmetic operations store a trace (often called a tape) of the operations rather than actually calculating the operation. That is, write something like

`+` <- function(x, y) {
    # don't actually add x and y; record an "add" node referring to them
    list(op = "add",
         arg1 = x,
         arg2 = y)
}

The result of a program that calculates a value is hence not the value but a data structure of all the operations that must be done to calculate that value; you can process the data structure to calculate the value, storing the intermediate results, and then process the data structure in reverse to calculate the derivatives. The data structure is an expression graph, since it connects each intermediate value to its predecessor values.
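
As a rough sketch of what that processing might look like – the constructors and helper names here are invented for illustration, not any package’s API – we can walk the graph once forward for values and once backward to accumulate adjoints:

leaf <- function(name, value) list(op = "leaf", name = name, value = value)
node <- function(op, ...) list(op = op, args = list(...))

# Forward sweep: compute the value of a node from its arguments.
evaluate <- function(n) {
    if (n$op == "leaf") return(n$value)
    vals <- lapply(n$args, evaluate)
    switch(n$op,
           add = vals[[1]] + vals[[2]],
           mul = vals[[1]] * vals[[2]],
           sin = sin(vals[[1]]))
}

# Reverse sweep: push the adjoint dz/d(node) down to the leaves,
# accumulating the results in an environment keyed by variable name.
backward <- function(n, adjoint, grads) {
    if (n$op == "leaf") {
        old <- if (exists(n$name, envir = grads, inherits = FALSE))
            get(n$name, envir = grads) else 0
        assign(n$name, old + adjoint, envir = grads)
        return(invisible(NULL))
    }
    vals <- lapply(n$args, evaluate)   # a real tape would reuse stored values
    switch(n$op,
           add = { backward(n$args[[1]], adjoint, grads)
                   backward(n$args[[2]], adjoint, grads) },
           mul = { backward(n$args[[1]], adjoint * vals[[2]], grads)
                   backward(n$args[[2]], adjoint * vals[[1]], grads) },
           sin = backward(n$args[[1]], adjoint * cos(vals[[1]]), grads))
    invisible(NULL)
}

# f(x, y) = xy + sin(x) at (2, 3):
xn <- leaf("x", 2); yn <- leaf("y", 3)
graph <- node("add", node("mul", xn, yn), node("sin", xn))
evaluate(graph)               # 6 + sin(2)
grads <- new.env()
backward(graph, 1, grads)
get("x", envir = grads)       # y + cos(x)
get("y", envir = grads)       # x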

If we don’t have operator overloading, we can use special functions and classes replacing ordinary numbers and arithmetic operations. Theano took this approach; one advantage was that, since you use these special functions to build the expression graph in advance, Theano could apply graph algorithms to optimize the graph and speed up the calculation. On the other hand, you were limited to the operations Theano provided, and couldn’t use arbitrary control flow.

A common problem is large expression graphs. If you write a calculation over many values (say, a log-likelihood involving sums and products over all the data), there can be many intermediate values and a huge expression graph; this can take up a lot of memory and make evaluating the gradients slow. Reverse-mode packages often provide shortcut functions like mean and sum whose derivatives are hard-coded, so instead of having to calculate the derivative of mean by breaking it into its individual operations, the package knows how to calculate the derivative directly in one step.
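
For example, the reverse rules for sum and mean can be hard-coded in one line each (illustrative names, not any particular package’s API), instead of being derived by unrolling the sum into thousands of pairwise additions:

backward_sum  <- function(adjoint, x) rep(adjoint, length(x))              # d sum(x) / dx_i = 1
backward_mean <- function(adjoint, x) rep(adjoint / length(x), length(x))  # d mean(x) / dx_i = 1/n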

Reverse mode is widely used. In C++, the Stan Math library provides templates and classes; Python’s autograd supports reverse mode; and most deep learning packages like TensorFlow and PyTorch use reverse-mode automatic differentiation.

Backpropagation is automatic differentiation

When we build a feed-forward neural network, we often train the weights and biases via gradient descent. We take our training data, feed it into the network, and get output; the output is used to calculate a loss function, and we take the gradient of the loss with respect to the weights and biases.

Reverse-mode automatic differentiation makes sense here to get the gradient: there are many parameters but one loss, so we can calculate the gradient in just one reverse pass through the network. Backpropagation is just the special case of reverse-mode automatic differentiation applied to a feed-forward neural network.

Most deep learning frameworks now use full reverse mode, not just backpropagation, so you can write arbitrary activation functions and weird network architectures and always get correct gradients.

Programs are differentiable

In statistics and machine learning, we’re used to solving problems by defining some function to optimize. We define a model and then find parameters that maximize its likelihood; we choose a clustering method and find parameters that maximize the separation of classes; we choose a prediction algorithm and find parameters that maximize its prediction accuracy. We usually do this mathematically, by writing the model or algorithm mathematically, deriving the appropriate function to maximize, and then finding the appropriate method to maximize it – maybe gradient ascent or Newton’s method or something else.

Step 1 of Automatic Differentiation Enlightenment is to realize that you can maximize more difficult functions, because you can calculate their gradients very easily and plug them into your maximization algorithm.

Step 2 of Automatic Differentiation Enlightenment is realizing that your method no longer needs to be expressible as a mathematical function. Your method can be a program. You can write some goofy Python script that uses some parameters to make predictions using whatever rules you can write in Python code, and then differentiate it so you can fit the parameters with gradient descent.

Neural networks and statistical models are just special cases of differentiable programs. We could write any kind of program and differentiate it (though, of course, the derivatives may be arbitrarily unpleasant and there may be numerous local maxima and singularities and so on).

Some examples:

Inverse graphics
Consider interpreting an image or video. Suppose it’s a video of a person walking and you want to detect their limbs and track their gait. Write a generative program that has a model of a person, with parameters for their leg positions and lengths and angles and so on, plus models for lighting and shading and the colors in the scene, and generates a picture – basically, use a video game graphics engine to draw the person. Now differentiate that graphics engine to fit the model to the real image you have as data. OpenDR, for example, can render a scene and give the derivatives of all pixels with respect to model parameters. Differentiable Halide can be used to do image processing in networks or to solve inverse problems.
Generative Bayesian models
Stan does MCMC by letting users write a program – in special Stan syntax – that specifies the prior and calculates the likelihood. The program can involve loops and control flow, and it builds up a log posterior density that Stan can differentiate to run Hamiltonian Monte Carlo and efficiently draw from the posterior.
Inferring programs
One group has developed a differentiable interpreter for Forth, a stack-based programming language. The source code can have “sketches”, essentially pieces of code that say “I dunno what code goes here, but presumably it reads from these values and uses them to decide which of these instructions to evaluate”. If you have examples of inputs and outputs, you can differentiate the program and maximize to fill in the sketches with real code.