Backpropagation — the Chain Rule Unwrapped
Backpropagation is the algorithm that makes training deep networks practical. It is nothing more than the chain rule of calculus applied systematically to a computational graph, yet when Rumelhart, Hinton and Williams popularised it in 1986 it transformed the field. We trace the math from scalar chain rules through Jacobians and matrix calculus, and build a minimal autograd engine that automatically differentiates any expression built from its supported operations.
1. The Chain Rule
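Suppose f(x) = g(h(x)). The chain rule says the derivative of f with respect to x is:
$$\frac{df}{dx} = \frac{dg}{dh}\,\frac{dh}{dx} = g'(h(x))\,h'(x)$$
For a chain of n composed functions, f = fₙ ∘ ⋯ ∘ f₁, this generalises to a product of n local derivatives. Each factor is evaluated at the output of the previous stage (u₀ = x, uₖ = fₖ(uₖ₋₁)):
$$\frac{df}{dx} = f_n'(u_{n-1})\,f_{n-1}'(u_{n-2})\cdots f_1'(u_0)$$
Backprop evaluates exactly this product layer by layer, starting from the loss at the output.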
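As a minimal sketch, here is the n-function chain rule for the illustrative choice f(x) = exp(sin(x²)). It is written as three stages whose local derivatives are multiplied from the output end backwards. The example function and the helper name `valueAndDerivative` are our own choices, not from the text.

```javascript
// Each stage pairs a function f_k with its derivative f_k'.
// Illustrative chain: f(x) = exp(sin(x²)) = f3(f2(f1(x))).
const chain = [
  { f: (x) => x * x, df: (x) => 2 * x },             // f1(x) = x²
  { f: (u) => Math.sin(u), df: (u) => Math.cos(u) }, // f2(u) = sin u
  { f: (v) => Math.exp(v), df: (v) => Math.exp(v) }, // f3(v) = exp v
];

// Hypothetical helper: the forward pass stores each stage's input,
// the backward pass multiplies local derivatives starting from the output.
function valueAndDerivative(stages, x) {
  const inputs = [];
  let u = x;
  for (const { f } of stages) {
    inputs.push(u); // u_{k-1}, needed later to evaluate f_k'(u_{k-1})
    u = f(u);
  }
  let grad = 1; // seed df/df = 1 at the output
  for (let k = stages.length - 1; k >= 0; k--) {
    grad *= stages[k].df(inputs[k]);
  }
  return { value: u, grad };
}

const x0 = 0.7;
const { value, grad } = valueAndDerivative(chain, x0);
// Closed form from the chain rule: exp(sin(x²)) · cos(x²) · 2x
const closedForm = Math.exp(Math.sin(x0 * x0)) * Math.cos(x0 * x0) * 2 * x0;
console.log(value, grad, closedForm);
```

The stored `inputs` array plays the role of the cached forward activations that backprop needs when it evaluates each local derivative.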
2. Computational Graphs
A computational graph is a directed acyclic graph (DAG) whose nodes are operations (+, ·, exp, …) and whose edges carry values. Any expression built from such elementary operations can be written as a graph of this kind.
During the forward pass we evaluate the nodes in topological order, from the inputs to the output, storing each intermediate value. During the backward pass we visit them in reverse order. At each node we multiply by its local partial derivatives, accumulating that node's contribution to the total gradient.
3. Forward and Backward Passes
For such a graph, the backward pass seeds the output with dL/dL = 1 and propagates gradients toward the inputs, one node at a time.
Each node accumulates gradient from every node that consumes its value (the multivariate chain rule). When a value u feeds several downstream nodes v₁, …, vₙ, their contributions are summed: ∂L/∂u = Σᵢ (∂L/∂vᵢ)(∂vᵢ/∂u).
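To make the summation concrete, here is a hand-written forward and backward pass for the small expression e = x·y + x, in which x feeds two nodes. The expression and the function name `forwardBackward` are illustrative choices.

```javascript
// Hand-unrolled forward/backward pass for e = x·y + x.
function forwardBackward(x, y) {
  // Forward pass: inputs to output, keeping intermediates.
  const p = x * y; // node p = x·y
  const e = p + x; // node e = p + x (x is used a second time here)

  // Backward pass: output to inputs, seeded with de/de = 1.
  const dE = 1;
  const dP = dE * 1; // ∂e/∂p = 1
  let dX = dE * 1;   // path through e = p + x contributes ∂e/∂x = 1
  dX += dP * y;      // path through p = x·y contributes y; the two are summed
  const dY = dP * x; // ∂p/∂y = x
  return { e, dX, dY };
}

const r = forwardBackward(3, 4);
console.log(r); // analytically: de/dx = y + 1 = 5, de/dy = x = 3
```

Overwriting `dX` instead of adding to it would silently drop one path. This is the classic bug the `+=` accumulation guards against.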
4. Jacobians and Matrix Gradients
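When inputs and outputs are vectors, the "derivative" is the Jacobian matrix J, where Jᵢⱼ = ∂yᵢ/∂xⱼ. For a linear layer y = Wx + b, the Jacobian with respect to x is simply W. Writing δ = ∂L/∂y for the upstream gradient gives:
$$\frac{\partial L}{\partial W} = \delta\,x^{\top}, \qquad \frac{\partial L}{\partial b} = \delta, \qquad \frac{\partial L}{\partial x} = W^{\top}\delta$$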
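Below is a minimal sketch of the linear-layer backward pass using plain nested arrays. It applies the standard results ∂L/∂W = δxᵀ, ∂L/∂b = δ and ∂L/∂x = Wᵀδ, with δ = ∂L/∂y. The function names and numbers are illustrative assumptions. Choosing the loss L = δᵀy makes the upstream gradient exactly δ, so every result can be checked by hand.

```javascript
// y = W x + b for a single sample; W is out×in, stored as an array of rows.
function linearForward(W, x, b) {
  return W.map((row, i) => row.reduce((s, wij, j) => s + wij * x[j], b[i]));
}

// Given the upstream gradient delta = ∂L/∂y, return ∂L/∂W, ∂L/∂b and ∂L/∂x.
function linearBackward(W, x, delta) {
  const dW = delta.map((di) => x.map((xj) => di * xj)); // δ xᵀ (outer product)
  const db = delta.slice();                             // δ
  const dx = x.map((_, j) =>                            // Wᵀ δ
    W.reduce((s, row, i) => s + row[j] * delta[i], 0));
  return { dW, db, dx };
}

const W = [[1, 2, 3], [4, 5, 6]];
const x = [1, 0, -1];
const b = [0.5, -0.5];
const delta = [1, 2]; // upstream gradient for the illustrative loss L = δᵀy

const yOut = linearForward(W, x, b);
const { dW, db, dx } = linearBackward(W, x, delta);
console.log(yOut, dW, db, dx);
// yOut = [-1.5, -2.5], dW = [[1, 0, -1], [2, 0, -2]], db = [1, 2], dx = [9, 12, 15]
```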
Element-wise ops
For y = σ(z) applied element-wise, the Jacobian is diagonal, so ∂L/∂z = ∂L/∂y ⊙ σ'(z). Backprop therefore reduces to a pointwise product, and the full matrix never needs to be formed.
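To illustrate why the diagonal Jacobian never needs to be built, this sketch computes ∂L/∂z for an element-wise sigmoid in two ways. The first multiplies by the full n×n diagonal matrix; the second takes the pointwise product. The function names and inputs are illustrative.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));
const dSigmoid = (z) => sigmoid(z) * (1 - sigmoid(z)); // σ'(z) = σ(z)(1 − σ(z))

// Long way: build J[i][j] = ∂y_i/∂z_j (zero off the diagonal)
// and compute (∂L/∂z)_j = Σ_i J[i][j] · (∂L/∂y)_i.
function viaJacobian(z, dLdy) {
  const J = z.map((zi, i) => z.map((_, j) => (i === j ? dSigmoid(zi) : 0)));
  return z.map((_, j) => J.reduce((s, row, i) => s + row[j] * dLdy[i], 0));
}

// Short way: ∂L/∂z = ∂L/∂y ⊙ σ'(z), O(n) work instead of O(n²).
function viaPointwise(z, dLdy) {
  return z.map((zi, i) => dLdy[i] * dSigmoid(zi));
}

const z = [-2, 0, 3];
const dLdy = [0.5, -1, 2];
console.log(viaJacobian(z, dLdy), viaPointwise(z, dLdy));
```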
Softmax Jacobian
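For p = softmax(z), the Jacobian Jᵢⱼ = ∂pᵢ/∂zⱼ = pᵢ(δᵢⱼ − pⱼ), with δᵢⱼ the Kronecker delta, is a full dense matrix. Combined with the cross-entropy loss L = −Σᵢ yᵢ log pᵢ for a one-hot target y, the gradient simplifies to ∂L/∂z = p − y.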
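The softmax/cross-entropy simplification can be checked numerically. This sketch computes ∂L/∂z twice: once by pushing ∂L/∂p through the dense softmax Jacobian, and once via the shortcut p − y, where p = softmax(z) and y is a one-hot target. The helper names and inputs are illustrative. The two results agree up to rounding error.

```javascript
function softmax(z) {
  const m = Math.max(...z); // subtract the max for numerical stability
  const e = z.map((v) => Math.exp(v - m));
  const s = e.reduce((a, c) => a + c, 0);
  return e.map((v) => v / s);
}

// J[i][j] = ∂p_i/∂z_j = p_i (δ_ij − p_j)
function softmaxJacobian(p) {
  return p.map((pi, i) => p.map((pj, j) => pi * ((i === j ? 1 : 0) - pj)));
}

// Long way for L = −Σ y_i log p_i: (∂L/∂z)_j = Σ_i J[i][j] · ∂L/∂p_i, with ∂L/∂p_i = −y_i / p_i
function gradViaJacobian(z, y) {
  const p = softmax(z);
  const J = softmaxJacobian(p);
  const dLdp = p.map((pi, i) => -y[i] / pi);
  return z.map((_, j) => p.reduce((acc, _pi, i) => acc + J[i][j] * dLdp[i], 0));
}

// Shortcut: ∂L/∂z = p − y
function gradShortcut(z, y) {
  const p = softmax(z);
  return p.map((pi, i) => pi - y[i]);
}

const logits = [2.0, 1.0, -0.5];
const target = [0, 1, 0];
console.log(gradViaJacobian(logits, target), gradShortcut(logits, target));
```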
Batch dimension
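With a mini-batch of m samples, stack the upstream gradients δ and the layer inputs A as matrices with one column per sample. When the loss is the mean over the batch, the weight gradient is the average of the per-sample gradients, dW = (δ · Aᵀ) / m, so its scale does not depend on the batch size.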
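A short sketch of the batch formula follows. It assumes one array per sample for both the upstream gradients and the layer inputs, and the function name is illustrative. Averaging the per-sample outer products δₖaₖᵀ equals (δ · Aᵀ)/m. Duplicating every sample leaves the result unchanged, which is the batch-size independence described above.

```javascript
// deltas[k] = upstream gradient for sample k (length out);
// acts[k]   = layer input a_k for sample k (length in).
// Returns dW = (1/m) Σ_k δ_k a_kᵀ, i.e. (δ · Aᵀ)/m with samples as columns.
function batchWeightGrad(deltas, acts) {
  const m = deltas.length;
  const dW = deltas[0].map(() => acts[0].map(() => 0));
  for (let k = 0; k < m; k++) {
    deltas[k].forEach((d, i) => {
      acts[k].forEach((a, j) => {
        dW[i][j] += d * a; // accumulate the outer product δ_k a_kᵀ
      });
    });
  }
  return dW.map((row) => row.map((v) => v / m)); // average over the batch
}

const deltas = [[1], [3]];
const acts = [[2, 0], [0, 4]];
console.log(batchWeightGrad(deltas, acts));                                // [[1, 6]]
console.log(batchWeightGrad([...deltas, ...deltas], [...acts, ...acts]));  // still [[1, 6]]
```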
Gradient check
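Verify analytical gradients numerically by comparing each ∂L/∂w with the central difference [L(w+ε) − L(w−ε)]/(2ε), for a small ε (e.g. 1e-5 in double precision). If they agree to roughly 6 significant digits (relative error around 1e-6 or smaller), your backprop is very likely correct.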
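Here is a minimal gradient-check sketch. It assumes the parameters are a flat array and uses a relative error, so the tolerance does not depend on gradient magnitude. The function names and the test loss are illustrative. A deliberately buggy gradient shows what a failed check looks like.

```javascript
// Central-difference estimate of ∂L/∂w_i for every parameter.
function numericalGrad(loss, w, eps = 1e-5) {
  return w.map((_, i) => {
    const plus = w.slice();
    const minus = w.slice();
    plus[i] += eps;
    minus[i] -= eps;
    return (loss(plus) - loss(minus)) / (2 * eps);
  });
}

// ‖a − b‖ / (‖a‖ + ‖b‖): close to 0 when the two gradients agree.
function relativeError(a, b) {
  const norm = (v) => Math.sqrt(v.reduce((s, x) => s + x * x, 0));
  const diff = norm(a.map((x, i) => x - b[i]));
  return diff / Math.max(norm(a) + norm(b), 1e-12);
}

// Illustrative loss L(w) = Σ w_i·sin(w_i), with ∂L/∂w_i = sin(w_i) + w_i·cos(w_i).
const loss = (w) => w.reduce((s, wi) => s + wi * Math.sin(wi), 0);
const analytic = (w) => w.map((wi) => Math.sin(wi) + wi * Math.cos(wi));
const buggy = (w) => w.map((wi) => wi * Math.cos(wi)); // forgot the sin(w_i) term

const w = [0.3, -1.2, 2.0];
const numeric = numericalGrad(loss, w);
console.log(relativeError(analytic(w), numeric)); // tiny: the gradients agree
console.log(relativeError(buggy(w), numeric));    // large: the bug is caught
```

Central differences are preferred over the one-sided [L(w+ε) − L(w)]/ε because their truncation error shrinks as ε² rather than ε.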
5. Vanishing and Exploding Gradients
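In a 20-layer network, the gradient with respect to the first layer's weights involves a product of roughly 20 weight matrices and 20 activation derivatives. If these factors have magnitude below 1, the product shrinks exponentially with depth, giving a vanishing gradient. The sigmoid is especially prone to this, because σ'(z) is at most 0.25 and approaches 0 when the unit saturates. If the factors have magnitude above 1, the product grows exponentially, giving an exploding gradient.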
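The exponential effect is easy to see in a scalar caricature. It ignores the weight matrices, which amounts to the illustrative assumption that every weight factor has magnitude 1, and multiplies one activation-derivative factor per layer over 20 layers.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));
const dSigmoid = (z) => sigmoid(z) * (1 - sigmoid(z));

// Product of the same per-layer factor over `depth` layers.
function productOverLayers(factor, depth) {
  let p = 1;
  for (let layer = 0; layer < depth; layer++) p *= factor;
  return p;
}

// Best case for sigmoid: σ'(0) = 0.25, its maximum value.
const vanishing = productOverLayers(dSigmoid(0), 20); // 0.25^20 = 2^-40 ≈ 9.1e-13
// A factor of 1.5 per layer blows up instead.
const exploding = productOverLayers(1.5, 20);         // 1.5^20 ≈ 3325
// ReLU in its active region has derivative exactly 1, so the product stays 1.
const reluPath = productOverLayers(1, 20);
console.log(vanishing, exploding, reluPath);
```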
Solutions:
- ReLU activations: ReLU'(z) = 1 for z > 0, so there is no saturation in the positive half.
- Residual connections (ResNet): a shortcut addition y = x + F(x) gives ∂y/∂x = I + ∂F/∂x, so the identity term carries the gradient past the layers like a highway.
- Batch normalisation: normalises pre-activations over each mini-batch, keeping them well-scaled rather than drifting into saturating regions.
- Careful weight init: He init for ReLU (var = 2/fan_in), Xavier/Glorot for tanh (var = 2/(fan_in + fan_out)).
- Gradient clipping: rescale the gradient whenever its norm exceeds a threshold; a standard remedy for exploding gradients in RNNs.
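Here is a sketch of two of these remedies: He/Xavier initialisation and gradient-norm clipping. It assumes weights are drawn from a zero-mean Gaussian and gradients are a flat array. The function names are illustrative, not a particular library's API.

```javascript
// Standard deviations for zero-mean Gaussian init.
// He: var = 2/fan_in. Xavier/Glorot: var = 2/(fan_in + fan_out).
const heStd = (fanIn) => Math.sqrt(2 / fanIn);
const xavierStd = (fanIn, fanOut) => Math.sqrt(2 / (fanIn + fanOut));

// Standard normal sample via the Box-Muller transform.
function gaussian() {
  const u = 1 - Math.random(); // in (0, 1], so Math.log(u) is finite
  const v = Math.random();
  return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
}

// fanOut × fanIn weight matrix with entries ~ N(0, std²).
function initWeights(fanOut, fanIn, std) {
  return Array.from({ length: fanOut }, () =>
    Array.from({ length: fanIn }, () => std * gaussian()));
}

// Rescale the whole gradient vector if its L2 norm exceeds maxNorm;
// the direction is preserved and only the length is capped.
function clipByNorm(grads, maxNorm) {
  const norm = Math.sqrt(grads.reduce((s, g) => s + g * g, 0));
  if (norm <= maxNorm) return grads.slice();
  const scale = maxNorm / norm;
  return grads.map((g) => g * scale);
}

const W1 = initWeights(64, 128, heStd(128)); // a ReLU layer with fan_in = 128
console.log(W1.length, W1[0].length);        // 64 128
console.log(clipByNorm([3, 4], 1));          // ≈ [0.6, 0.8]
```

Clipping by the global norm, rather than clipping each component separately, keeps the update pointing in the same direction as the original gradient.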
6. Beyond SGD — Adam and Friends
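Plain SGD applies the same global learning rate to every weight, and each step depends only on the current gradient. Adaptive optimisers instead track per-parameter gradient statistics. Adam, for example, keeps exponential moving averages of the gradient gₜ and of its square, corrects their initial bias toward zero, and scales each step by the result:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2$$
$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$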
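Adam often converges faster than plain SGD and is typically less sensitive to the choice of learning rate, although well-tuned SGD with momentum remains competitive in some settings. AdamW, widely used for training and fine-tuning language models, applies weight decay directly to the weights, decoupled from the adaptive gradient step. Adam with an L2 penalty instead adds the decay term to the gradient, where the per-parameter scaling distorts it.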
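Here is a minimal sketch of Adam with optional decoupled weight decay, written as a closure over the optimiser state. It uses the standard first and second moment estimates with bias correction. The decay form θ ← θ·(1 − η·λ), applied before the Adam step, is an assumption modelled on common AdamW implementations. `makeAdamW` and its defaults are illustrative.

```javascript
// Returns a step function mapping (theta, grad) to the updated parameters.
function makeAdamW({ lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, weightDecay = 0 } = {}) {
  let m = null; // first-moment estimate (EMA of gradients)
  let v = null; // second-moment estimate (EMA of squared gradients)
  let t = 0;    // step counter for bias correction
  return function step(theta, grad) {
    if (m === null) {
      m = grad.map(() => 0);
      v = grad.map(() => 0);
    }
    t += 1;
    return theta.map((th, i) => {
      m[i] = beta1 * m[i] + (1 - beta1) * grad[i];
      v[i] = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i];
      const mHat = m[i] / (1 - beta1 ** t); // bias-corrected moments
      const vHat = v[i] / (1 - beta2 ** t);
      // Decoupled weight decay: shrink the weight directly, not via the gradient.
      const decayed = th * (1 - lr * weightDecay);
      return decayed - (lr * mHat) / (Math.sqrt(vHat) + eps);
    });
  };
}

// Usage: minimise f(θ) = (θ − 3)², whose gradient is 2(θ − 3).
const adam = makeAdamW({ lr: 0.1 });
let theta = [0];
for (let i = 0; i < 1000; i++) theta = adam(theta, [2 * (theta[0] - 3)]);
console.log(theta[0]); // close to 3

// With a zero gradient only the decay acts: θ shrinks by (1 − lr·λ) per step.
const adamw = makeAdamW({ lr: 0.1, weightDecay: 0.5 });
let phi = [1];
for (let i = 0; i < 10; i++) phi = adamw(phi, [0]);
console.log(phi[0]); // 0.95^10 ≈ 0.599
```

The second usage shows what "decoupled" means in practice: the decay shrinks the weights even when the loss gradient is zero, independently of the adaptive scaling.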
7. Minimal Autograd Engine
// Scalar autograd: reverse-mode automatic differentiation
class Value {
  constructor(data, _children = [], _op = '') {
    this.data = data;                // scalar result of the forward pass
    this.grad = 0;                   // ∂(root)/∂(this node), filled in by backward()
    this._backward = () => {};       // pushes this node's grad into its inputs' grads
    this._prev = new Set(_children); // input nodes that produced this value
    this._op = _op;                  // op label, handy for debugging
  }
  add(other) {
    other = other instanceof Value ? other : new Value(other);
    const out = new Value(this.data + other.data, [this, other], '+');
    // ∂(a + b)/∂a = ∂(a + b)/∂b = 1: pass the upstream gradient through unchanged
    out._backward = () => { this.grad += out.grad; other.grad += out.grad; };
    return out;
  }
  mul(other) {
    other = other instanceof Value ? other : new Value(other);
    const out = new Value(this.data * other.data, [this, other], '*');
    // ∂(a·b)/∂a = b and ∂(a·b)/∂b = a
    out._backward = () => {
      this.grad += other.data * out.grad;
      other.grad += this.data * out.grad;
    };
    return out;
  }
  pow(n) {
    // n must be a plain number (constant exponent): ∂(aⁿ)/∂a = n·aⁿ⁻¹
    const out = new Value(this.data ** n, [this], `**${n}`);
    out._backward = () => { this.grad += n * (this.data ** (n - 1)) * out.grad; };
    return out;
  }
  relu() {
    const out = new Value(this.data > 0 ? this.data : 0, [this], 'relu');
    // ReLU'(a) = 1 if a > 0, else 0 (the derivative at exactly 0 is taken as 0)
    out._backward = () => { this.grad += (out.data > 0 ? 1 : 0) * out.grad; };
    return out;
  }
  backward() {
    // Topological sort, then call _backward in reverse order so that each
    // node's grad is complete before it is propagated to its inputs
    const topo = []; const visited = new Set();
    const build = v => {
      if (!visited.has(v)) {
        visited.add(v);
        for (const child of v._prev) build(child);
        topo.push(v);
      }
    };
    build(this);
    // Grads accumulate with +=, so reset them to 0 before calling backward() again
    this.grad = 1; // seed: dL/dL = 1
    for (const v of topo.reverse()) v._backward();
  }
}
// Example: L = (x·w + b − y)²
const x = new Value(2.0);
const w = new Value(-3.0);
const b = new Value(6.88);
const y = new Value(1.0);
const L = x.mul(w).add(b).add(y.mul(-1)).pow(2);
L.backward();
console.log(w.grad); // ∂L/∂w = 2·(x·w + b − y)·x ≈ -0.48
8. Automatic Differentiation Modes
Reverse mode (backprop)
One backward pass computes ∂L/∂θ for ALL parameters simultaneously. It is ideal when outputs ≪ inputs, which is exactly the neural-net case: a single scalar loss and many parameters.
Forward mode
Computes a Jacobian-vector product Jv in a single forward pass, so recovering a full gradient takes one pass per input dimension. Efficient when inputs ≪ outputs (rare in ML).
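Here is a minimal forward-mode sketch using dual numbers, which is one standard way to implement forward mode. Each value carries a tangent `dot`. Seeding one input's tangent with 1 (and the others with 0) yields the derivative along that input in a single forward pass. The class and function names are illustrative.

```javascript
// Dual number: val is the primal value, dot is the directional derivative.
class Dual {
  constructor(val, dot = 0) {
    this.val = val;
    this.dot = dot;
  }
  static lift(o) {
    return o instanceof Dual ? o : new Dual(o); // constants have zero tangent
  }
  add(o) {
    o = Dual.lift(o);
    return new Dual(this.val + o.val, this.dot + o.dot);
  }
  mul(o) {
    o = Dual.lift(o);
    // Product rule: (a·b)' = a'·b + a·b'
    return new Dual(this.val * o.val, this.dot * o.val + this.val * o.dot);
  }
}

// f(x) = x³ + 2x, so f'(x) = 3x² + 2
const f = (x) => x.mul(x).mul(x).add(x.mul(2));
const out = f(new Dual(2, 1)); // seed dx = 1
console.log(out.val, out.dot); // 12 14

// g(x, y) = x·y + x needs one forward pass per input direction
const g = (x, y) => x.mul(y).add(x);
const dgdx = g(new Dual(3, 1), new Dual(4, 0)).dot; // y + 1 = 5
const dgdy = g(new Dual(3, 0), new Dual(4, 1)).dot; // x = 3
console.log(dgdx, dgdy);
```

The two passes for g make the cost asymmetry concrete: forward mode scales with the number of inputs, while the reverse-mode engine above gets both partials from one backward pass.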
Symbolic diff
Derives closed-form expressions (Mathematica, SymPy). Exact but can produce exponentially large expressions ("expression swell").
Numerical diff
Finite differences [f(x+h)−f(x)]/h. Simple but slow (one pass per parameter) and subject to floating-point cancellation errors.
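PyTorch builds a dynamic computation graph as the code runs (define-by-run) and differentiates it with reverse-mode AD. JAX instead traces Python functions and exposes differentiation as composable function transformations (grad, jvp, vjp) alongside jit and vmap. It supports both reverse- and forward-mode AD, and its transformations compose cleanly because they operate on pure functions.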