The Chain Rule and Computational Graphs: The Engine Behind Backpropagation

How the chain rule powers backpropagation — from single-variable compositions to computational graphs and automatic differentiation.

Calculus & Optimization, March 7, 2026

Why the Chain Rule is Everything

A neural network is a composition of functions — layer after layer of linear transformations and nonlinear activations. To train it, we need the derivative of the loss with respect to every parameter, which means differentiating through the entire composition.

The chain rule tells us exactly how to do this. It is the mathematical principle behind backpropagation, the algorithm that makes training deep networks practical. Without it, there would be no efficient way to compute gradients for networks with more than one layer.

This article builds on partial derivatives and gradients and connects directly to how frameworks like PyTorch and TensorFlow compute gradients.

The Single-Variable Chain Rule

If $y = f(g(x))$, the derivative of the composition is:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

where $u = g(x)$. In function notation:

$$[f(g(x))]' = f'(g(x)) \cdot g'(x)$$

Intuition: If $u$ changes at rate $g'(x)$ with respect to $x$, and $y$ changes at rate $f'(u)$ with respect to $u$, then $y$ changes at rate $f'(u) \cdot g'(x)$ with respect to $x$. Rates of change multiply through compositions.

Worked Examples

Example 1: $y = e^{x^2}$. Let $u = x^2$, so $y = e^u$.

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = e^u \cdot 2x = 2x \, e^{x^2}$$

Example 2: $y = \ln(\sin x)$. Let $u = \sin x$, so $y = \ln u$.

$$\frac{dy}{dx} = \frac{1}{u} \cdot \cos x = \frac{\cos x}{\sin x} = \cot x$$
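Both results can be sanity-checked numerically. The sketch below compares each chain-rule derivative against a central finite difference; the helper name `central_difference`, the step size, and the test point `x0 = 0.7` are illustrative assumptions, not part of the original examples.

```python
import math

def central_difference(f, x, h=1e-6):
    """Approximate f'(x) with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)

# Example 1: y = exp(x^2); the chain rule gives 2x * exp(x^2).
def y1(x):
    return math.exp(x ** 2)

def dy1_dx(x):
    return 2 * x * math.exp(x ** 2)

# Example 2: y = ln(sin x); the chain rule gives cot x (needs sin x > 0).
def y2(x):
    return math.log(math.sin(x))

def dy2_dx(x):
    return math.cos(x) / math.sin(x)

x0 = 0.7  # arbitrary test point where sin(x0) > 0
print("Example 1:", dy1_dx(x0), central_difference(y1, x0))
print("Example 2:", dy2_dx(x0), central_difference(y2, x0))
```

The two numbers on each printed line should agree to several decimal places; a mismatch would point to an error in the hand derivation.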

Extended Chain Rule

For longer compositions $y = f(g(h(x)))$, the chain rule extends naturally:

$$\frac{dy}{dx} = f'(g(h(x))) \cdot g'(h(x)) \cdot h'(x)$$

Each link in the chain contributes a multiplicative factor. A neural network with $L$ layers is exactly this kind of long composition, and its gradient is a product of $L$ terms, which is why very deep networks suffer from vanishing or exploding gradients.
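To make the "one factor per link" idea concrete, here is a minimal sketch that differentiates a composition of any length by multiplying local derivatives as it evaluates each link. The function `chain_derivative` and the sample composition $\sin(e^{x^2})$ are hypothetical choices for illustration.

```python
import math

def chain_derivative(layers, x):
    """Differentiate a composition given (function, derivative) pairs.

    `layers` is ordered innermost first, e.g. [h, g, f] for f(g(h(x))).
    Returns (value, derivative) evaluated at x.
    """
    value, derivative = x, 1.0
    for func, dfunc in layers:
        derivative *= dfunc(value)  # local derivative at this link's input
        value = func(value)         # then move on to the next link
    return value, derivative

layers = [
    (lambda t: t ** 2, lambda t: 2 * t),  # h(x) = x^2
    (math.exp, math.exp),                 # g(u) = e^u
    (math.sin, math.cos),                 # f(v) = sin v
]
value, derivative = chain_derivative(layers, 0.5)
print(value, derivative)
```

The loop accumulates the product from the innermost function outward. The same product could be accumulated from the outermost function inward, which is the order backpropagation uses.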

The Multivariable Chain Rule

When intermediate quantities depend on multiple inputs, the chain rule involves summation. If $f$ depends on $z_1, z_2, \ldots, z_k$, each of which depends on $x$, then:

$$\frac{\partial f}{\partial x} = \sum_{i=1}^{k} \frac{\partial f}{\partial z_i} \cdot \frac{\partial z_i}{\partial x}$$

This summation is the mathematical reason why gradients from different paths through a network add up at shared parameters.

Worked Example

Let $f(u, v) = u^2 + uv$ where $u = 2x + y$ and $v = x - 3y$. Find $\frac{\partial f}{\partial x}$:

$$\begin{aligned} \frac{\partial f}{\partial x} &= \frac{\partial f}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial f}{\partial v} \cdot \frac{\partial v}{\partial x} \\[6pt] &= (2u + v) \cdot 2 + u \cdot 1 \\[6pt] &= 2(2u + v) + u \\[6pt] &= 5u + 2v \end{aligned}$$

Substituting back: $5(2x + y) + 2(x - 3y) = 12x - y$.

Key insight: When a variable influences the output through multiple paths, the total derivative is the sum of contributions from each path. This is the fundamental reason backpropagation works — it systematically accounts for all paths through the computational graph.
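The path-sum computation can be written out directly. The sketch below reproduces $\frac{\partial f}{\partial x}$ from the worked example and, as an addition not in the original text, applies the same rule to $\frac{\partial f}{\partial y}$, which works out to $-x - 4y$. The function name `grad_by_paths` and the evaluation point are illustrative assumptions.

```python
def f_of_xy(x, y):
    """f(u, v) = u^2 + u*v with u = 2x + y and v = x - 3y."""
    u = 2 * x + y
    v = x - 3 * y
    return u ** 2 + u * v

def grad_by_paths(x, y):
    """Sum the contribution of each path (through u and through v)."""
    u = 2 * x + y
    v = x - 3 * y
    df_du, df_dv = 2 * u + v, u   # local derivatives of f
    du_dx, dv_dx = 2.0, 1.0       # how u and v respond to x
    du_dy, dv_dy = 1.0, -3.0      # how u and v respond to y
    df_dx = df_du * du_dx + df_dv * dv_dx
    df_dy = df_du * du_dy + df_dv * dv_dy
    return df_dx, df_dy

x0, y0 = 1.0, 2.0
print(grad_by_paths(x0, y0))         # path sums
print(12 * x0 - y0, -x0 - 4 * y0)    # closed forms after substitution
```

At $(x, y) = (1, 2)$ both lines give $10$ and $-9$, matching $12x - y$ and $-x - 4y$.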

Computational Graphs

A computational graph represents a composite function as a directed acyclic graph (DAG). Each node performs a single operation, and edges carry values between operations.

Building a Graph

Consider the loss for a single training example with a linear model:

$$L = (wx + b - y)^2$$

The computational graph breaks this into elementary operations:

x, w --> [*] --> p = wx
p, b --> [+] --> q = wx + b
q, y --> [-] --> r = wx + b - y
r    --> [^2] --> L = r^2

Each node stores its output during the forward pass. During the backward pass it uses those stored values to evaluate its local derivative.

Forward Pass

The forward pass evaluates the function from inputs to output. With $w = 2$, $x = 3$, $b = 1$, $y = 5$:

$$\begin{aligned} p &= wx = 6 \\[6pt] q &= p + b = 7 \\[6pt] r &= q - y = 2 \\[6pt] L &= r^2 = 4 \end{aligned}$$

Backward Pass (Backpropagation)

The backward pass applies the chain rule in reverse, starting from the output and propagating gradients back to each input.

Each node computes: (incoming gradient) $\times$ (local derivative).

Step 1. Start at the output:

$$\frac{\partial L}{\partial L} = 1$$

Step 2. Through the squaring node ($L = r^2$, local derivative $2r$):

$$\frac{\partial L}{\partial r} = 1 \cdot 2r = 2(2) = 4$$

Step 3. Through the subtraction node ($r = q - y$):

$$\frac{\partial L}{\partial q} = 4 \cdot 1 = 4 \qquad \frac{\partial L}{\partial y} = 4 \cdot (-1) = -4$$

Step 4. Through the addition node ($q = p + b$):

$$\frac{\partial L}{\partial p} = 4 \cdot 1 = 4 \qquad \frac{\partial L}{\partial b} = 4 \cdot 1 = 4$$

Step 5. Through the multiplication node ($p = wx$):

$$\frac{\partial L}{\partial w} = 4 \cdot x = 4(3) = 12 \qquad \frac{\partial L}{\partial x} = 4 \cdot w = 4(2) = 8$$

We can verify: $L = (wx + b - y)^2$, so $\frac{\partial L}{\partial w} = 2(wx + b - y) \cdot x = 2(2)(3) = 12$. It matches.
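The five steps translate almost line for line into code. This is a hand-written sketch for this one graph, not how a framework implements it; the function name `forward_backward` and the variable names are assumptions chosen to mirror the symbols above.

```python
def forward_backward(w, x, b, y):
    """Forward and backward pass for L = (w*x + b - y)^2."""
    # Forward pass: evaluate and keep every intermediate value.
    p = w * x
    q = p + b
    r = q - y
    L = r ** 2

    # Backward pass: (incoming gradient) * (local derivative) at each node.
    dL_dL = 1.0
    dL_dr = dL_dL * 2 * r    # squaring node
    dL_dq = dL_dr * 1.0      # subtraction node, first input
    dL_dy = dL_dr * -1.0     # subtraction node, second input
    dL_dp = dL_dq * 1.0      # addition node
    dL_db = dL_dq * 1.0
    dL_dw = dL_dp * x        # multiplication node: each input gets the other
    dL_dx = dL_dp * w
    return L, {"w": dL_dw, "x": dL_dx, "b": dL_db, "y": dL_dy}

loss, grads = forward_backward(w=2.0, x=3.0, b=1.0, y=5.0)
print(loss, grads)
```

With the values from the text this reproduces $L = 4$ and the gradients $12$, $8$, $4$, and $-4$ for $w$, $x$, $b$, and $y$.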

Backpropagation Through a Neuron

A single neuron with sigmoid activation computes:

$$z = \mathbf{w}^T\mathbf{x} + b, \qquad a = \sigma(z), \qquad L = (a - y)^2$$

Forward pass (with $\mathbf{w} = [0.5, -0.3]^T$, $\mathbf{x} = [1, 2]^T$, $b = 0.1$, $y = 1$):

$$\begin{aligned} z &= 0.5(1) + (-0.3)(2) + 0.1 = 0.0 \\[6pt] a &= \sigma(0) = 0.5 \\[6pt] L &= (0.5 - 1)^2 = 0.25 \end{aligned}$$

Backward pass, using $\sigma'(z) = \sigma(z)(1 - \sigma(z)) = a(1 - a) = 0.25$:

$$\begin{aligned} \frac{\partial L}{\partial a} &= 2(a - y) = 2(0.5 - 1) = -1.0 \\[6pt] \frac{\partial L}{\partial z} &= \frac{\partial L}{\partial a} \cdot \sigma'(z) = (-1.0)(0.25) = -0.25 \\[6pt] \frac{\partial L}{\partial w_1} &= \frac{\partial L}{\partial z} \cdot x_1 = (-0.25)(1) = -0.25 \\[6pt] \frac{\partial L}{\partial w_2} &= \frac{\partial L}{\partial z} \cdot x_2 = (-0.25)(2) = -0.50 \\[6pt] \frac{\partial L}{\partial b} &= \frac{\partial L}{\partial z} \cdot 1 = -0.25 \end{aligned}$$

The negative gradients tell us to increase all three parameters to reduce the loss, which makes sense because the prediction $a = 0.5$ is below the target $y = 1$.
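The same neuron can be checked in a few lines of plain Python. This is a minimal sketch; `neuron_gradients` is a hypothetical helper name, and the printed values are floating-point approximations of the hand-computed ones.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def neuron_gradients(w, x, b, y):
    """Forward and backward pass for one sigmoid neuron with squared error."""
    # Forward pass
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    a = sigmoid(z)
    loss = (a - y) ** 2

    # Backward pass
    dL_da = 2 * (a - y)
    dL_dz = dL_da * a * (1 - a)        # sigmoid'(z) = a * (1 - a)
    dL_dw = [dL_dz * xi for xi in x]   # each weight sees its own input
    dL_db = dL_dz
    return loss, dL_dw, dL_db

loss, dL_dw, dL_db = neuron_gradients([0.5, -0.3], [1.0, 2.0], 0.1, 1.0)
print(loss, dL_dw, dL_db)
```

Up to rounding, the output matches the values derived by hand: a loss of $0.25$ and gradients of $-0.25$, $-0.50$, and $-0.25$.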

Automatic Differentiation

Modern deep learning frameworks do not require manual derivation of gradients. They use automatic differentiation (autodiff), which mechanically applies the chain rule to computational graphs.

Forward Mode vs Reverse Mode

There are two ways to propagate derivatives through a graph:

| | Forward mode | Reverse mode |
|---|---|---|
| Direction | Input $\to$ output | Output $\to$ input |
| Computes | Jacobian-vector product $\mathbf{J}\mathbf{v}$ | Vector-Jacobian product $\mathbf{v}^T\mathbf{J}$ |
| Passes for a full Jacobian | One pass per input | One pass per output |
| Efficient when | Few inputs, many outputs | Many inputs, few outputs |

Key insight: Neural network training has many inputs (millions of parameters) and one output (scalar loss). This makes reverse mode autodiff — which is exactly backpropagation — the efficient choice. One backward pass gives the gradient with respect to all parameters simultaneously.

Forward mode would require a separate pass for each parameter — millions of passes versus one. This asymmetry is why backpropagation was such a breakthrough.
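The cost of forward mode shows up clearly in a small dual-number sketch: each pass seeds one input with tangent 1 and yields the derivative with respect to that input only. The `Dual` class and `forward_mode_gradient` helper below are illustrative assumptions, applied to the earlier loss $L = (wx + b - y)^2$.

```python
class Dual:
    """A value paired with its derivative along one seeded input direction."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __add__(self, other):
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __sub__(self, other):
        return Dual(self.value - other.value, self.tangent - other.tangent)

    def __mul__(self, other):
        # Product rule carried alongside the value.
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

def loss(w, x, b, y):
    r = w * x + b - y
    return r * r

def forward_mode_gradient(inputs):
    """One full forward pass per input: seed tangent 1 on that input only."""
    grads = {}
    for name in inputs:
        seeded = {k: Dual(v, 1.0 if k == name else 0.0)
                  for k, v in inputs.items()}
        grads[name] = loss(**seeded).tangent
    return grads

grads = forward_mode_gradient({"w": 2.0, "x": 3.0, "b": 1.0, "y": 5.0})
print(grads)
```

Four inputs require four evaluations of `loss` here, whereas a single backward pass produced all four gradients in the worked example. With millions of parameters the gap grows accordingly.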

How Frameworks Implement Autodiff

Frameworks record the computational graph either dynamically, as operations execute (PyTorch, and TensorFlow 2.x in its default eager mode), or statically, before any computation runs (TensorFlow 1.x). Each operation registers:

  1. Its output tensor
  2. A reference to the backward function
  3. References to input tensors

In PyTorch, calling .backward() on the loss makes the framework traverse the graph in reverse topological order, calling each backward function and accumulating gradients. TensorFlow 2.x exposes the same mechanism through tf.GradientTape.
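A toy scalar version of this bookkeeping fits in a few dozen lines. It is only a sketch of the idea, assuming a hypothetical `Value` class that supports `+`, `-`, and `*`; real frameworks operate on tensors and differ in many details.

```python
class Value:
    """Scalar graph node: stores its output, its inputs, and a backward function."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents
        self.backward_fn = lambda: None  # leaves have nothing to propagate

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad
        out.backward_fn = backward_fn
        return out

    def __sub__(self, other):
        out = Value(self.data - other.data, (self, other))
        def backward_fn():
            self.grad += out.grad
            other.grad -= out.grad
        out.backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out.backward_fn = backward_fn
        return out

    def backward(self):
        # Build a topological order, then walk it in reverse.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0  # dL/dL
        for node in reversed(order):
            node.backward_fn()

w, x, b, y = Value(2.0), Value(3.0), Value(1.0), Value(5.0)
r = w * x + b - y
loss = r * r
loss.backward()
print(w.grad, x.grad, b.grad, y.grad)
```

Gradients are accumulated with `+=` so that a node used more than once, such as `r` in `r * r`, receives the sum of the contributions from every path.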

Common Gradient Patterns

Certain operations appear so frequently that their local gradients are worth memorizing:

| Operation | Forward | Local gradient (backward) |
|---|---|---|
| Addition | $c = a + b$ | $\frac{\partial c}{\partial a} = 1, \; \frac{\partial c}{\partial b} = 1$ |
| Multiplication | $c = ab$ | $\frac{\partial c}{\partial a} = b, \; \frac{\partial c}{\partial b} = a$ |
| ReLU | $c = \max(0, a)$ | $\frac{\partial c}{\partial a} = \mathbb{1}[a > 0]$ |
| Sigmoid | $c = \sigma(a)$ | $\frac{\partial c}{\partial a} = c(1-c)$ |
| Matrix multiply | $\mathbf{C} = \mathbf{A}\mathbf{B}$ | $\frac{\partial L}{\partial \mathbf{A}} = \frac{\partial L}{\partial \mathbf{C}}\mathbf{B}^T$ |

Notice that addition passes the incoming gradient unchanged to both inputs, multiplication hands each input the gradient scaled by the other input's value, and ReLU acts as a gate (passes or blocks the gradient). See matrix calculus for more on matrix-level derivatives.
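The matrix-multiply rule is the least obvious of these patterns, so a small numeric check may help. The sketch below uses plain nested lists; the matrices `A`, `B` and the upstream gradient `G` (standing in for $\frac{\partial L}{\partial \mathbf{C}}$) are made-up values, and the loss is assumed to be $L = \sum_{ij} G_{ij} C_{ij}$ so that its gradient with respect to $\mathbf{C}$ is exactly `G`.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def weighted_sum_loss(A, B, G):
    """L = sum_ij G_ij * C_ij with C = A B, so dL/dC equals G."""
    C = matmul(A, B)
    return sum(g * c for g_row, c_row in zip(G, C)
               for g, c in zip(g_row, c_row))

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5, -1.0], [2.0, 0.0]]
G = [[1.0, 0.0], [0.0, 2.0]]  # hypothetical upstream gradient dL/dC

dL_dA = matmul(G, transpose(B))  # the rule from the table: dL/dC times B^T
print(dL_dA)
```

Perturbing one entry of `A` and re-evaluating `weighted_sum_loss` gives a finite-difference estimate that should agree with the corresponding entry of `dL_dA`.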

Vanishing and Exploding Gradients

The chain rule multiplies local gradients together. For a network with $L$ layers:

$$\frac{\partial L}{\partial \mathbf{W}_1} = \frac{\partial L}{\partial \mathbf{a}_L} \cdot \prod_{\ell=2}^{L} \frac{\partial \mathbf{a}_\ell}{\partial \mathbf{a}_{\ell-1}} \cdot \frac{\partial \mathbf{a}_1}{\partial \mathbf{W}_1}$$

If each factor satisfies $\left\|\frac{\partial \mathbf{a}_\ell}{\partial \mathbf{a}_{\ell-1}}\right\| < 1$, the product shrinks exponentially with depth (vanishing gradients). If the factors have norm $> 1$, the product can grow exponentially (exploding gradients).
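A scalar chain is enough to see both effects. The sketch below assumes a toy network $a_\ell = \text{act}(w \, a_{\ell-1})$ with a single shared weight, which is a simplification chosen for illustration; `input_gradient` is a hypothetical helper that multiplies the local derivatives layer by layer.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def input_gradient(weight, depth, x=0.5, activation="sigmoid"):
    """d a_depth / d x for a scalar chain a_l = act(weight * a_{l-1})."""
    a, grad = x, 1.0
    for _ in range(depth):
        z = weight * a
        if activation == "sigmoid":
            a = sigmoid(z)
            local = a * (1 - a) * weight  # sigmoid'(z) is at most 0.25
        else:
            a = z                         # identity activation (linear chain)
            local = weight
        grad *= local
    return grad

for depth in (5, 10, 20):
    print(depth,
          input_gradient(1.0, depth),                       # shrinks with depth
          input_gradient(1.5, depth, activation="linear"))  # grows with depth
```

With sigmoid activations and weight 1, every factor is at most $0.25$, so the gradient is bounded by $0.25^{\text{depth}}$. With a linear chain and weight 1.5, the gradient is $1.5^{\text{depth}}$.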

Mitigations include:

  • ReLU (gradient is exactly 0 or 1, no shrinkage for active neurons)
  • Residual connections ($\mathbf{a}_\ell = \mathbf{a}_{\ell-1} + f(\mathbf{a}_{\ell-1})$, which adds an identity term to the gradient product)
  • Gradient clipping (cap the gradient norm to prevent explosions)
  • Careful initialization (Xavier/He initialization scales weights to preserve gradient magnitude)
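Of these mitigations, gradient clipping is the simplest to show in code. The following is a minimal sketch of clipping by global norm on a flat list of gradient values; the function name `clip_by_global_norm` and the threshold are assumptions for illustration, and frameworks ship their own utilities for this.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their Euclidean norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)        # already small enough, leave unchanged
    scale = max_norm / norm
    return [g * scale for g in grads]

print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))  # norm 5 is scaled down
print(clip_by_global_norm([0.3, 0.4], max_norm=1.0))  # norm 0.5 is untouched
```

Rescaling the whole vector preserves the gradient's direction and only limits the step size, which is why clipping by norm is often preferred over clipping each component independently.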

Why This Matters for ML

The chain rule and computational graphs are the foundation of modern deep learning:

  • Backpropagation is the chain rule applied to computational graphs — nothing more, nothing less
  • Autodiff in PyTorch/TensorFlow mechanizes this, freeing practitioners from manual gradient derivation
  • Reverse mode is efficient because neural networks map many parameters to one scalar loss
  • Vanishing/exploding gradients are direct consequences of the multiplicative chain rule — understanding this guides architectural choices (ReLU, residual connections, normalization)
  • These gradients feed into gradient descent to update parameters
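As a small illustration of the last point, the gradients from the linear-model example can drive a plain gradient descent loop. This is a sketch under assumed settings: the learning rate, step count, and the name `train_linear` are not from the original text.

```python
def train_linear(w, b, x, y, learning_rate=0.01, steps=50):
    """Gradient descent on L = (w*x + b - y)^2 for a single example."""
    losses = []
    for _ in range(steps):
        r = w * x + b - y            # forward pass: residual
        losses.append(r ** 2)
        dL_dw = 2 * r * x            # backward pass: chain rule through r
        dL_db = 2 * r
        w -= learning_rate * dL_dw   # step against the gradient
        b -= learning_rate * dL_db
    return w, b, losses

w, b, losses = train_linear(w=2.0, b=1.0, x=3.0, y=5.0)
print(losses[0], losses[-1])
print(w * 3.0 + b)  # prediction after training, close to the target 5
```

The loss starts at $4$, as in the forward-pass example, and shrinks at every step because each update moves the parameters in the direction opposite to the gradient.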

Summary

  • The chain rule for compositions: $\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$, so rates of change multiply
  • The multivariable chain rule sums contributions from all paths: $\frac{\partial f}{\partial x} = \sum_i \frac{\partial f}{\partial z_i} \frac{\partial z_i}{\partial x}$
  • Computational graphs decompose functions into elementary operations, enabling systematic gradient computation
  • Backpropagation = reverse-mode autodiff = chain rule applied backwards through the graph
  • Reverse mode is efficient for many-inputs-to-one-output (the neural network training setting)
  • Vanishing/exploding gradients arise from multiplying many factors in the chain rule
  • Next: Taylor series and approximation show how derivatives yield local function models

