Beyond Standard Backpropagation
Standard backpropagation differentiates through a fixed computational graph. But what if part of your computation is an optimization procedure, a fixed-point iteration, or a differential equation?
- Meta-learning (MAML): Differentiate through an inner training loop to learn good initializations
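- Hyperparameter optimization: Compute $\partial L_{\text{val}} / \partial \lambda$, where $L_{\text{val}}$ is the validation loss and $\lambda$ is a regularization coefficient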
- Neural ODEs: Define a neural network as a continuous differential equation and differentiate through the ODE solver
- Differentiable simulation and rendering: Backpropagate through a physics simulation or a renderer
All of these benefit from implicit differentiation: computing gradients of a result without unrolling the full computation that produced it.
Implicit Functions and the Implicit Function Theorem
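Recall from derivatives: an implicit function is defined by an equation $F(x, y) = 0$ rather than an explicit formula $y = f(x)$.
The Implicit Function Theorem states: if $F$ is continuously differentiable, $F(x_0, y_0) = 0$, and $\partial F / \partial y$ is invertible at $(x_0, y_0)$, then there exists a function $y = g(x)$ near $x_0$ satisfying $F(x, g(x)) = 0$, with derivative:
$$\frac{dg}{dx} = -\left(\frac{\partial F}{\partial y}\right)^{-1} \frac{\partial F}{\partial x}$$
This formula gives us $dg/dx$ without ever computing $g$ explicitly, hence “implicit” differentiation.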
Key insight: Implicit differentiation computes the gradient of an output defined by a condition (like “the solution to this optimization problem”) without needing to differentiate through the procedure that found the solution. This decouples the forward computation from the backward pass.
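As a quick sanity check, the sketch below applies the Implicit Function Theorem derivative $-(\partial F / \partial y)^{-1} \, \partial F / \partial x$ to the unit circle $x^2 + y^2 = 1$ and compares it with the derivative of the explicit upper-branch solution $y = \sqrt{1 - x^2}$. The constraint `F`, the point `(0.6, 0.8)`, and the helper name `implicit_dydx` are illustrative choices, not something the text prescribes.

```python
import jax
import jax.numpy as jnp

def F(x, y):
    # Unit circle written as a constraint F(x, y) = 0.
    return x ** 2 + y ** 2 - 1.0

def implicit_dydx(x, y):
    # dy/dx = -(dF/dy)^{-1} dF/dx, evaluated at a point on the curve.
    dF_dx = jax.grad(F, argnums=0)(x, y)
    dF_dy = jax.grad(F, argnums=1)(x, y)
    return -dF_dx / dF_dy

# (0.6, 0.8) lies on the upper half of the circle, where dF/dy = 2y != 0.
x0, y0 = 0.6, 0.8
implicit = implicit_dydx(x0, y0)

# Explicit solution on the upper branch, used only as a cross-check.
explicit = jax.grad(lambda x: jnp.sqrt(1.0 - x ** 2))(x0)
```

Both routes give the slope $-x_0 / y_0$ at this point. The implicit route never needed a formula for $y$ in terms of $x$, only the constraint and a point that satisfies it.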
Differentiating Through Optimization
The Setup
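Consider an inner optimization problem parameterized by $\lambda$:
$$\theta^*(\lambda) = \arg\min_\theta L_{\text{train}}(\theta, \lambda)$$
We want the gradient of some outer objective with respect to $\lambda$:
$$\frac{d}{d\lambda} L_{\text{val}}\big(\theta^*(\lambda)\big)$$
This arises in hyperparameter optimization (where $\lambda$ controls regularization, learning rate, or architecture) and meta-learning (where $\lambda$ is the initialization).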
Approach 1: Unrolling
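Unrolled differentiation treats the $T$-step optimization as a computational graph and backpropagates through all $T$ steps. With inner step size $\alpha$, training loss $L_{\text{train}}$, and hyperparameter $\lambda$:
$$\theta_{t+1} = \theta_t - \alpha \nabla_\theta L_{\text{train}}(\theta_t, \lambda), \qquad t = 0, \dots, T-1$$
Each step is differentiable, so the chain rule gives:
$$\frac{dL_{\text{val}}(\theta_T)}{d\lambda} = \nabla_\theta L_{\text{val}}(\theta_T)^\top \frac{d\theta_T}{d\lambda}, \qquad \frac{d\theta_{t+1}}{d\lambda} = \frac{\partial \theta_{t+1}}{\partial \theta_t} \frac{d\theta_t}{d\lambda} + \frac{\partial \theta_{t+1}}{\partial \lambda}$$
Pros: Simple, exact gradient for $T$ steps. Cons: Memory scales as $O(T)$, gradients can vanish or explode for large $T$, and the result is the gradient of the $T$-step iterate $\theta_T$, not of the true optimum $\theta^*$.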
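A minimal sketch of unrolling, assuming a tiny ridge-regression problem in which the hyperparameter is the L2 coefficient `lam`. The dataset values, step size, step count, and function names are all hypothetical and chosen only so the example runs quickly.

```python
import jax
import jax.numpy as jnp

# Tiny hand-written dataset (hypothetical values, for illustration only).
X_tr = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_tr = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, 2.0], [2.0, 1.0]])
y_val = jnp.array([2.0, 1.0])

def train_loss(theta, lam):
    # Inner objective: mean squared error plus an L2 penalty weighted by lam.
    return 0.5 * jnp.mean((X_tr @ theta - y_tr) ** 2) + 0.5 * lam * jnp.sum(theta ** 2)

def val_loss(theta):
    # Outer objective: validation error of the trained parameters.
    return 0.5 * jnp.mean((X_val @ theta - y_val) ** 2)

def unrolled_val_loss(lam, num_steps=100, lr=0.5):
    # Run T gradient-descent steps; every step stays on the autodiff tape,
    # which is why memory grows with the number of steps.
    theta = jnp.zeros(X_tr.shape[1])
    for _ in range(num_steps):
        theta = theta - lr * jax.grad(train_loss)(theta, lam)
    return val_loss(theta)

# Hypergradient d L_val(theta_T) / d lam via backpropagation through all steps.
hypergrad = jax.grad(unrolled_val_loss)(0.1)
```

Because this inner problem is well conditioned and the loop runs to convergence, the unrolled hypergradient here should agree closely with the gradient at the exact ridge solution. With fewer steps it would instead be the gradient of a partially trained model.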
Approach 2: Implicit Differentiation
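At the optimum $\theta^*$ of the inner problem, the gradient of the training loss is zero: $\nabla_\theta L_{\text{train}}(\theta^*, \lambda) = 0$.
Define $F(\theta, \lambda) = \nabla_\theta L_{\text{train}}(\theta, \lambda)$. By the Implicit Function Theorem:
$$\frac{d\theta^*}{d\lambda} = -\left[\nabla^2_{\theta\theta} L_{\text{train}}(\theta^*, \lambda)\right]^{-1} \nabla^2_{\theta\lambda} L_{\text{train}}(\theta^*, \lambda)$$
The Hessian inverse is expensive to form, but we only need the product $H^{-1} v$ with $H = \nabla^2_{\theta\theta} L_{\text{train}}$ and $v = \nabla_\theta L_{\text{val}}(\theta^*)$. It can be obtained by solving $H q = v$ with conjugate gradient, which needs only Hessian-vector products (as in Hessian-free methods).
Pros: Constant memory (does not depend on the number of inner steps $T$), gives the true optimum’s gradient, numerically stable. Cons: Requires convergence of the inner optimization and an invertible Hessian at the solution, needs Hessian-vector products.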
Key insight: Implicit differentiation is more memory-efficient and numerically stable than unrolling for long inner optimization loops. It computes the gradient of the solution rather than the gradient of the optimization trajectory, which is what we actually want.
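The recipe can be sketched on a small ridge-regression problem in three steps: take the validation gradient at the inner solution, solve one linear system in the training Hessian with conjugate gradient, and contract the result with the mixed derivative. The dataset and the names `solve_inner` and `implicit_hypergrad` are assumptions made for illustration. The inner problem is deliberately solved in closed form to stress that the backward pass does not care how the solution was found.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Tiny hand-written dataset (hypothetical values, for illustration only).
X_tr = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_tr = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, 2.0], [2.0, 1.0]])
y_val = jnp.array([2.0, 1.0])

def train_loss(theta, lam):
    return 0.5 * jnp.mean((X_tr @ theta - y_tr) ** 2) + 0.5 * lam * jnp.sum(theta ** 2)

def val_loss(theta):
    return 0.5 * jnp.mean((X_val @ theta - y_val) ** 2)

def solve_inner(lam):
    # Any solver works here; ridge regression happens to have a closed form.
    n, d = X_tr.shape
    A = X_tr.T @ X_tr / n + lam * jnp.eye(d)
    return jnp.linalg.solve(A, X_tr.T @ y_tr / n)

def implicit_hypergrad(lam):
    theta_star = solve_inner(lam)
    grad_theta = jax.grad(train_loss, argnums=0)
    # v = gradient of the outer loss at the inner solution.
    v = jax.grad(val_loss)(theta_star)
    # Hessian-vector product of the training loss (forward-over-reverse).
    hvp = lambda u: jax.jvp(lambda t: grad_theta(t, lam), (theta_star,), (u,))[1]
    # Solve H q = v with conjugate gradient; H is never materialized.
    q, _ = cg(hvp, v)
    # Mixed term: d/dlam of q^T grad_theta L_train, with q and theta_star held fixed.
    mixed = jax.grad(lambda l: jnp.vdot(q, grad_theta(theta_star, l)))(lam)
    return -mixed

hypergrad = implicit_hypergrad(0.1)
```

For this quadratic problem the result can be checked against differentiating `val_loss(solve_inner(lam))` directly. In a real setting the inner solution would come from an iterative optimizer, and the accuracy of the hypergradient would depend on how well that optimizer converged.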
MAML: Model-Agnostic Meta-Learning
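MAML learns an initialization $\theta_0$ such that a few gradient steps on a new task produce a good model. With task losses $L_i$, inner step size $\alpha$, and a single inner step:
$$\min_{\theta_0} \sum_i L_i\big(\theta_0 - \alpha \nabla_\theta L_i(\theta_0)\big)$$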
The outer gradient requires differentiating through the inner gradient step — a second-order derivative (gradient of a gradient).
First-Order MAML (FOMAML)
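To avoid computing second derivatives, FOMAML drops the second-order terms, using only:
$$g_{\text{FOMAML}} = \sum_i \nabla_\theta L_i(\theta_i'), \qquad \theta_i' = \theta_0 - \alpha \nabla_\theta L_i(\theta_0)$$
That is, the task gradient is evaluated at the adapted parameters $\theta_i'$, and the Jacobian $\partial \theta_i' / \partial \theta_0 = I - \alpha \nabla^2_\theta L_i(\theta_0)$ is replaced by the identity.
This is often surprisingly effective in practice: the first-order approximation tends to work well when the inner loop is short (1 to 5 steps).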
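To make the difference concrete, the sketch below computes both the full MAML gradient (differentiating through one inner step) and the first-order approximation on two toy quadratic tasks. The tasks, the inner learning rate, and the helper names are hypothetical. For brevity the same loss is used for adaptation and evaluation, whereas MAML in practice typically evaluates on separate data.

```python
import jax
import jax.numpy as jnp

def task_loss(theta, task):
    # Quadratic task: 0.5 * (theta - c)^T A (theta - c) with symmetric A.
    A, c = task
    diff = theta - c
    return 0.5 * diff @ A @ diff

def adapt(theta0, task, inner_lr=0.1):
    # One inner gradient step starting from the shared initialization.
    return theta0 - inner_lr * jax.grad(task_loss)(theta0, task)

def maml_objective(theta0, tasks):
    # Post-adaptation loss summed over tasks.
    return sum(task_loss(adapt(theta0, t), t) for t in tasks)

def fomaml_grad(theta0, tasks):
    # First-order version: task gradient at the adapted parameters,
    # ignoring how the adapted parameters depend on theta0.
    return sum(jax.grad(task_loss)(adapt(theta0, t), t) for t in tasks)

tasks = [
    (jnp.array([[2.0, 0.0], [0.0, 1.0]]), jnp.array([1.0, -1.0])),
    (jnp.array([[1.0, 0.5], [0.5, 3.0]]), jnp.array([-2.0, 0.5])),
]
theta0 = jnp.zeros(2)

# Full MAML gradient: autodiff through the inner step (needs second derivatives).
full_grad = jax.grad(maml_objective)(theta0, tasks)
first_order_grad = fomaml_grad(theta0, tasks)
```

The two gradients differ by the factor $(I - \alpha A_i)$ applied per task, which is exactly the second-order term that the first-order variant drops.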
iMAML: Implicit MAML
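iMAML uses implicit differentiation instead of unrolling. The inner problem is regularized toward the initialization, $\theta_i^* = \arg\min_\theta L_i(\theta) + \frac{\lambda}{2} \|\theta - \theta_0\|^2$, where $\lambda$ here denotes the regularization strength, and the Jacobian of the solution is computed as:
$$\frac{d\theta_i^*}{d\theta_0} = \left(I + \frac{1}{\lambda} \nabla^2_\theta L_i(\theta_i^*)\right)^{-1}$$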
This is more accurate for long inner loops and more memory-efficient.
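A small sketch of an implicit meta-gradient in the style of Rajeswaran et al. (2019), assuming the inner problem adds a proximal term $\frac{\lambda}{2}\|\theta - \theta_0\|^2$ to the task loss. In that setting the meta-gradient is $(I + \frac{1}{\lambda} H)^{-1} v$, where $H$ is the task-loss Hessian at the inner solution and $v$ is the outer-loss gradient there. The quadratic task, the outer loss, the value of `reg` (playing the role of $\lambda$), and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, -1.0])
reg = 2.0  # strength of the proximal term (lambda in the formula above)

def task_loss(theta):
    diff = theta - c
    return 0.5 * diff @ A @ diff

def outer_loss(theta):
    # Hypothetical outer (validation) loss for the adapted parameters.
    return 0.5 * jnp.sum((theta - jnp.array([0.5, 0.0])) ** 2)

def inner_solution(theta0):
    # argmin_theta task_loss(theta) + reg/2 * ||theta - theta0||^2.
    # Closed form for this quadratic; any inner solver could be used instead.
    return jnp.linalg.solve(A + reg * jnp.eye(2), A @ c + reg * theta0)

def imaml_meta_grad(theta0):
    theta_star = inner_solution(theta0)
    v = jax.grad(outer_loss)(theta_star)
    # Hessian-vector product of the task loss at the inner solution.
    hvp = lambda u: jax.jvp(jax.grad(task_loss), (theta_star,), (u,))[1]
    # Solve (I + H / reg) g = v with conjugate gradient.
    g, _ = cg(lambda u: u + hvp(u) / reg, v)
    return g

theta0 = jnp.zeros(2)
meta_grad = imaml_meta_grad(theta0)
```

The memory cost is independent of how many steps the inner solver took, since only the final `theta_star` enters the backward computation.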
Fixed-Point Differentiation
Many neural network components compute a fixed point $z^* = f_\theta(z^*, x)$:
- Equilibrium models (DEQ): Run a transformation to convergence rather than stacking layers
- Iterative algorithms: Power iteration, message passing, belief propagation
- Implicit layers: Define the output as a fixed point rather than a feed-forward computation
Differentiating Through Fixed Points
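At the fixed point, $z^* - f_\theta(z^*, x) = 0$. By implicit differentiation:
$$\frac{\partial z^*}{\partial \theta} = \left(I - \frac{\partial f_\theta}{\partial z}\bigg|_{z^*}\right)^{-1} \frac{\partial f_\theta}{\partial \theta}\bigg|_{z^*}$$
The inverse never needs to be formed explicitly: the corresponding linear system can be solved iteratively. The Deep Equilibrium Model (DEQ) uses this to define effectively infinitely deep networks with constant memory. The “depth” is the number of iterations to convergence, but backpropagation only requires solving one linear system.
Key insight: Fixed-point differentiation decouples the forward computation (iterating to convergence) from the backward computation (solving a linear system). This means you can use any iterative method for the forward pass, even a black-box solver, and still get gradients that are exact up to the tolerance of the forward solve and the linear solve.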
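A minimal sketch of this decoupling, assuming a toy layer $f(z, x, W) = \tanh(Wz + x)$ with a small weight matrix so that the iteration contracts. The backward pass solves $(I - J^\top) u = v$ by a simple fixed-point iteration on vector-Jacobian products and then pulls $u$ back to the weights. The layer, the loss, the input values, and the iteration counts are all hypothetical.

```python
import jax
import jax.numpy as jnp

def f(z, x, W):
    # Toy equilibrium layer; a contraction when W has small norm.
    return jnp.tanh(W @ z + x)

def solve_fixed_point(x, W, num_iters=100):
    # Forward pass: plain iteration, standing in for any root-finding solver.
    z = jnp.zeros_like(x)
    for _ in range(num_iters):
        z = f(z, x, W)
    return z

def loss_at(z):
    return 0.5 * jnp.sum(z ** 2)

def implicit_grad_W(x, W, num_iters=100):
    z_star = solve_fixed_point(x, W)
    v = jax.grad(loss_at)(z_star)
    # Vector-Jacobian products of f at the fixed point, w.r.t. z and W.
    _, vjp_fn = jax.vjp(lambda z, W_: f(z, x, W_), z_star, W)
    # Solve u = v + J^T u, i.e. (I - J^T) u = v, without forming J.
    u = v
    for _ in range(num_iters):
        u = v + vjp_fn(u)[0]
    # Pull u back through the dependence of f on W.
    return vjp_fn(u)[1]

x = jnp.array([0.5, -0.3])
W = 0.3 * jnp.array([[1.0, -0.5], [0.5, 1.0]])
grad_W = implicit_grad_W(x, W)
```

Only `z_star` is needed for the backward pass, so none of the forward iterates have to be stored. For this small example the result can be compared with autodiff through the unrolled forward loop, which should agree closely because the iteration has converged.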
Neural ODEs
Neural Ordinary Differential Equations define a continuous-depth network as the solution to an ODE:
$$\frac{dh(t)}{dt} = f_\theta(h(t), t)$$
where $f_\theta$ is a neural network parameterizing the dynamics. The output is $h(T)$, obtained by integrating from $t = 0$ to $t = T$ using an ODE solver.
The Adjoint Method
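Backpropagating through ODE solver steps (unrolling) is memory-intensive. The adjoint method computes gradients by solving a second ODE backward in time:
$$\frac{da(t)}{dt} = -\left(\frac{\partial f_\theta(h(t), t)}{\partial h}\right)^\top a(t)$$
where $a(t) = \partial L / \partial h(t)$ is the adjoint state, initialized at $a(T) = \partial L / \partial h(T)$. This has $O(1)$ memory cost regardless of the number of solver steps, because $h(t)$ is recomputed backward in time instead of being stored. It is implicit differentiation applied to ODEs.
The parameter gradient is:
$$\frac{dL}{d\theta} = \int_0^T \left(\frac{\partial f_\theta(h(t), t)}{\partial \theta}\right)^\top a(t)\, dt$$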
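The adjoint computation can be sketched with a fixed-step RK4 integrator: integrate the state forward, then integrate the state, the adjoint, and the accumulated parameter gradient backward together. The dynamics `f`, the loss, the step count, and the helper `rk4` are hypothetical stand-ins. A production implementation would normally use an adaptive solver.

```python
import jax
import jax.numpy as jnp

def f(h, t, theta):
    # Toy dynamics standing in for a neural network; t is unused here.
    return jnp.tanh(theta @ h)

def loss_fn(h_final):
    return 0.5 * jnp.sum(h_final ** 2)

def rk4(func, state, t0, t1, num_steps):
    # Fixed-step RK4 on a pytree state; t1 < t0 integrates backward in time.
    dt = (t1 - t0) / num_steps
    axpy = lambda s, k, c: jax.tree_util.tree_map(lambda a, b: a + c * b, s, k)
    t = t0
    for _ in range(num_steps):
        k1 = func(state, t)
        k2 = func(axpy(state, k1, dt / 2), t + dt / 2)
        k3 = func(axpy(state, k2, dt / 2), t + dt / 2)
        k4 = func(axpy(state, k3, dt), t + dt)
        state = jax.tree_util.tree_map(
            lambda s, a, b, c, d: s + dt / 6 * (a + 2 * b + 2 * c + d),
            state, k1, k2, k3, k4)
        t = t + dt
    return state

def adjoint_grad(theta, h0, T, num_steps):
    # Forward solve: only the final state is kept.
    hT = rk4(lambda h, t: f(h, t, theta), h0, 0.0, T, num_steps)
    aT = jax.grad(loss_fn)(hT)

    def aug_dynamics(state, t):
        # Joint dynamics of (h, a, g): dh/dt = f, da/dt = -(df/dh)^T a,
        # dg/dt = -(df/dtheta)^T a, so that g(0) is the parameter gradient.
        h, a, _ = state
        dh, vjp_fn = jax.vjp(lambda h_, th_: f(h_, t, th_), h, theta)
        a_dh, a_dtheta = vjp_fn(a)
        return dh, -a_dh, -a_dtheta

    init = (hT, aT, jnp.zeros_like(theta))
    _, _, g = rk4(aug_dynamics, init, T, 0.0, num_steps)
    return g

theta = jnp.array([[0.0, 1.0], [-1.0, -0.1]])
h0 = jnp.array([1.0, 0.0])
T, NUM_STEPS = 1.0, 20
grad_theta = adjoint_grad(theta, h0, T, NUM_STEPS)
```

The adjoint gradient and the gradient obtained by backpropagating through the discrete solver steps agree only up to discretization error, since one differentiates the continuous problem and the other the discretized one. For this smooth toy problem the gap should be small.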
Applications of Neural ODEs
- Continuous normalizing flows: Model complex distributions with continuous transformations
- Time series: Naturally handle irregularly-sampled data
- Generative models: Continuous-time generative processes
- Physics-informed learning: Embed physical laws as ODE/PDE constraints
Differentiable Programming
Differentiable programming is the broader paradigm: make entire programs differentiable so that gradients can flow through any computation.
Examples Beyond Neural Networks
- Differentiable rendering: Compute $\partial(\text{image}) / \partial(\text{scene parameters})$ for inverse graphics
- Differentiable physics simulation: Learn physical parameters from observed trajectories
- Differentiable sorting: Soft approximations to sorting enable gradient-based ranking losses
- Differentiable architecture search (DARTS): Relax discrete architecture choices to continuous and optimize with gradients
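As one concrete instance from the list above, a soft rank can be built from pairwise sigmoids. This is one simple relaxation among several and is not taken from any particular paper or library. The function name `soft_rank`, the temperature values, the scores, and the target ranking are hypothetical.

```python
import jax
import jax.numpy as jnp

def soft_rank(x, temperature=1.0):
    # diff[i, j] = x[j] - x[i]
    diff = x[None, :] - x[:, None]
    # Each sigmoid is a smooth stand-in for the indicator "x[j] > x[i]".
    # The j == i term contributes sigmoid(0) = 0.5, so adding 0.5 makes the
    # largest entry's rank approach 1 as the temperature goes to 0.
    return 0.5 + jnp.sum(jax.nn.sigmoid(diff / temperature), axis=1)

scores = jnp.array([3.0, 1.0, 2.0])

# At low temperature the soft ranks approach the hard (descending) ranks.
nearly_hard = soft_rank(scores, temperature=0.01)

def ranking_loss(s):
    # Hypothetical target ranking; squared error between soft and target ranks.
    target = jnp.array([2.0, 1.0, 3.0])
    return jnp.sum((soft_rank(s, temperature=1.0) - target) ** 2)

# Unlike a hard sort, the relaxed ranks give a useful gradient w.r.t. the scores.
grad_scores = jax.grad(ranking_loss)(scores)
```

The temperature trades off fidelity against gradient signal: lower values track the true ranks more closely but push the sigmoids into saturation, where gradients shrink toward zero.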
The JAX Ecosystem
JAX makes differentiable programming practical with composable transformations:
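import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Small illustrative inputs (hypothetical values) so the snippet runs as-is
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = jnp.zeros(2)

# Gradient of loss w.r.t. params (the first argument by default)
grad_fn = jax.grad(loss)

# Hessian-vector product (forward-over-reverse): push the tangent v through
# the gradient function while holding x and y fixed
hvp_fn = lambda v: jax.jvp(lambda p: grad_fn(p, x, y), (params,), (v,))[1]

# jax.grad can also differentiate through a for loop, scan, or other Python control flow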
JAX’s jax.grad, jax.jvp (forward-mode), and jax.vjp (reverse-mode) can be composed arbitrarily, enabling gradients of gradients, implicit differentiation, and more.
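A short sketch of both points, differentiating through a loop and composing `jax.grad` with itself, using Newton's iteration for the square root written with `jax.lax.scan`. The function name `newton_sqrt` and the iteration count are illustrative choices.

```python
import jax
import jax.numpy as jnp

def newton_sqrt(a, num_iters=20):
    # Newton's iteration x <- 0.5 * (x + a / x), expressed as a scan.
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        return 0.5 * (x + a / x), None

    x, _ = jax.lax.scan(step, a, xs=None, length=num_iters)
    return x

# First derivative of the iterative procedure w.r.t. its input.
d_sqrt = jax.grad(newton_sqrt)
# Gradient of a gradient: the second derivative, through the same loop.
d2_sqrt = jax.grad(jax.grad(newton_sqrt))

value = newton_sqrt(4.0)
first = d_sqrt(4.0)
second = d2_sqrt(4.0)
```

Once the iteration has converged, these derivatives should match those of the analytic square root, $\frac{1}{2\sqrt{a}}$ and $-\frac{1}{4} a^{-3/2}$, even though no derivative rule for the square root was ever written down.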
Why This Matters for ML
Implicit differentiation and differentiable programming expand what can be trained with gradients:
- Meta-learning (MAML, iMAML) learns to learn by differentiating through training
- Hyperparameter optimization tunes regularization, learning rates, and data augmentation via gradients
- Neural ODEs define continuous-depth networks with constant memory backpropagation
- Deep Equilibrium Models achieve infinite depth with fixed-point differentiation
- Differentiable rendering/physics enables learning from visual and physical observations
- These techniques extend the reach of gradient-based optimization to problems that were previously intractable
Summary
- The Implicit Function Theorem gives gradients through implicitly defined functions: $\frac{dy}{dx} = -\left(\frac{\partial F}{\partial y}\right)^{-1} \frac{\partial F}{\partial x}$
- Differentiating through optimization: implicit differentiation is more stable and memory-efficient than unrolling
- MAML learns initializations by differentiating through inner gradient steps; iMAML uses implicit differentiation for scalability
- Fixed-point differentiation enables infinitely deep equilibrium models with constant memory
- Neural ODEs define continuous-depth networks; the adjoint method computes gradients with $O(1)$ memory
- Differentiable programming makes arbitrary computations (rendering, physics, sorting) gradient-compatible
- Next: min-max optimization tackles the two-player games behind GANs and adversarial training
References
- Krantz, S. G., & Parks, H. R. (2013). The Implicit Function Theorem: History, Theory, and Applications. Springer.
- Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML. arXiv:1703.03400
- Rajeswaran, A., et al. (2019). Meta-Learning with Implicit Gradients. NeurIPS. arXiv:1909.04630
- Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural Ordinary Differential Equations. NeurIPS. arXiv:1806.07366
- Bai, S., Kolter, J. Z., & Koltun, V. (2019). Deep Equilibrium Models. NeurIPS. arXiv:1909.01377
- Blondel, M., et al. (2022). Efficient and Modular Implicit Differentiation. NeurIPS.