Implicit Differentiation and Differentiable Programming

Backpropagate through optimization, fixed points, and ODEs — learn implicit differentiation for meta-learning, hyperparameter tuning, and Neural ODEs.

Calculus & Optimization March 7, 2026 8 min read

Beyond Standard Backpropagation

Standard backpropagation differentiates through a fixed computational graph. But what if part of your computation is an optimization procedure, a fixed-point iteration, or a differential equation?

  • Meta-learning (MAML): Differentiate through an inner training loop to learn good initializations
  • Hyperparameter optimization: Compute $\frac{\partial \text{val\_loss}}{\partial \lambda}$ where $\lambda$ is a regularization coefficient
  • Neural ODEs: Define a neural network as a continuous differential equation and differentiate through the ODE solver
  • Differentiable rendering and simulation: Backpropagate through a renderer or a physics simulation

All of these can be handled with implicit differentiation: computing the gradient of a result from the condition it satisfies, without unrolling the full computation that produced it.

Implicit Functions and the Implicit Function Theorem

Recall from derivatives: an implicit function is defined by an equation $F(\mathbf{x}, \mathbf{y}) = \mathbf{0}$ rather than an explicit formula $\mathbf{y} = f(\mathbf{x})$.

The Implicit Function Theorem states: if $F$ is continuously differentiable, $F(\mathbf{x}^*, \mathbf{y}^*) = \mathbf{0}$, and $\frac{\partial F}{\partial \mathbf{y}}$ is invertible at $(\mathbf{x}^*, \mathbf{y}^*)$, then there exists a function $\mathbf{y}(\mathbf{x})$ near $\mathbf{x}^*$ satisfying $F(\mathbf{x}, \mathbf{y}(\mathbf{x})) = \mathbf{0}$, with derivative:

$$\frac{d\mathbf{y}}{d\mathbf{x}} = -\left(\frac{\partial F}{\partial \mathbf{y}}\right)^{-1} \frac{\partial F}{\partial \mathbf{x}}$$

This formula gives us $\frac{d\mathbf{y}}{d\mathbf{x}}$ from the partial derivatives of $F$ at the solution point alone, without a closed-form expression for $\mathbf{y}(\mathbf{x})$, hence “implicit” differentiation.
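As a small illustration of the formula, the sketch below applies it to the unit circle $x^2 + y^2 = 1$ and cross-checks the slope against the explicit upper-branch solution. The function names (`F`, `implicit_dydx`) and the evaluation point are chosen for this example only.

```python
import jax
import jax.numpy as jnp

# Constraint F(x, y) = 0 that defines y implicitly as a function of x (unit circle).
def F(x, y):
    return x**2 + y**2 - 1.0

def implicit_dydx(x, y):
    # dy/dx = -(dF/dy)^{-1} dF/dx, valid at a point with F(x, y) = 0 and dF/dy != 0.
    dF_dx = jax.grad(F, argnums=0)(x, y)
    dF_dy = jax.grad(F, argnums=1)(x, y)
    return -dF_dx / dF_dy

# A point on the upper half of the circle: 0.6**2 + 0.8**2 = 1.
x0, y0 = 0.6, 0.8
ift_slope = implicit_dydx(x0, y0)

# Cross-check against the explicit upper-branch solution y = sqrt(1 - x^2).
explicit_slope = jax.grad(lambda x: jnp.sqrt(1.0 - x**2))(x0)
```

Both values equal $-x/y = -0.75$ at this point; the implicit route never needed the square-root formula.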

Key insight: Implicit differentiation computes the gradient of an output defined by a condition (like “the solution to this optimization problem”) without needing to differentiate through the procedure that found the solution. This decouples the forward computation from the backward pass.

Differentiating Through Optimization

The Setup

Consider an inner optimization problem parameterized by $\boldsymbol{\lambda}$:

$$\boldsymbol{\theta}^*(\boldsymbol{\lambda}) = \arg\min_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}(\boldsymbol{\theta}, \boldsymbol{\lambda})$$

We want the gradient of some outer objective with respect to $\boldsymbol{\lambda}$:

$$\frac{d}{d\boldsymbol{\lambda}} \mathcal{L}_{\text{val}}(\boldsymbol{\theta}^*(\boldsymbol{\lambda}), \boldsymbol{\lambda})$$

This arises in hyperparameter optimization (where $\boldsymbol{\lambda}$ controls regularization, learning rate, or architecture) and meta-learning (where $\boldsymbol{\lambda}$ is the initialization).

Approach 1: Unrolling

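Unrolled differentiation treats the $T$-step optimization as a computational graph and backpropagates through all $T$ steps:

$$\boldsymbol{\theta}_0 \to \boldsymbol{\theta}_1 \to \cdots \to \boldsymbol{\theta}_T \approx \boldsymbol{\theta}^*$$

Each step $\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}_t, \boldsymbol{\lambda})$ is differentiable, so the chain rule gives a recursion for the sensitivity of each iterate:

$$\frac{d\boldsymbol{\theta}_{t+1}}{d\boldsymbol{\lambda}} = \left(\mathbf{I} - \alpha \nabla^2_{\boldsymbol{\theta}} \mathcal{L}_t\right) \frac{d\boldsymbol{\theta}_t}{d\boldsymbol{\lambda}} - \alpha \frac{\partial^2 \mathcal{L}_t}{\partial \boldsymbol{\theta} \partial \boldsymbol{\lambda}}$$

where $\mathcal{L}_t$ denotes the loss evaluated at $\boldsymbol{\theta}_t$. If $\boldsymbol{\theta}_0$ does not depend on $\boldsymbol{\lambda}$, expanding the recursion gives:

$$\frac{d\boldsymbol{\theta}_T}{d\boldsymbol{\lambda}} = \sum_{t=0}^{T-1} \left[\prod_{s=t+1}^{T-1} \left(\mathbf{I} - \alpha \nabla^2_{\boldsymbol{\theta}} \mathcal{L}_s\right)\right] \left(-\alpha \frac{\partial^2 \mathcal{L}_t}{\partial \boldsymbol{\theta} \partial \boldsymbol{\lambda}}\right)$$

with the factors of each product ordered so that later steps multiply from the left.

Pros: Simple, exact gradient for $T$ steps. Cons: Memory scales as $O(T)$, suffers from vanishing/exploding gradients for large $T$, and the gradient is for $T$ steps of optimization, not the true optimum.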
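The unrolled approach can be sketched in JAX by writing the inner loop as ordinary Python and differentiating the whole function. The example below is a minimal illustration, assuming a tiny ridge-regression problem; the data, step size, step count, and function names are made up for this sketch.

```python
import jax
import jax.numpy as jnp

# Tiny illustrative dataset (hypothetical values).
X_train = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_train = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, -1.0], [2.0, 1.0]])
y_val = jnp.array([0.0, 1.0])

def train_loss(theta, lam):
    # Least squares plus an L2 penalty whose strength lam is the hyperparameter.
    resid = X_train @ theta - y_train
    return 0.5 * jnp.mean(resid**2) + 0.5 * lam * jnp.sum(theta**2)

def val_loss(theta):
    resid = X_val @ theta - y_val
    return 0.5 * jnp.mean(resid**2)

def unrolled_val_loss(lam, num_steps=100, alpha=0.5):
    # Inner loop: plain gradient descent on the training loss.
    # Every iterate stays in the graph, so reverse-mode memory grows with num_steps.
    theta = jnp.zeros(2)
    for _ in range(num_steps):
        theta = theta - alpha * jax.grad(train_loss)(theta, lam)
    return val_loss(theta)

# Hypergradient d val_loss / d lam, obtained by backpropagating through all steps.
hypergrad_unrolled = jax.grad(unrolled_val_loss)(0.1)
```

Because this inner problem converges quickly, the unrolled value is close to the gradient at the true optimum; with fewer steps it would instead be the gradient of the truncated procedure.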

Approach 2: Implicit Differentiation

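At the optimum, the gradient is zero: $\nabla_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}(\boldsymbol{\theta}^*, \boldsymbol{\lambda}) = \mathbf{0}$.

Define $F(\boldsymbol{\theta}, \boldsymbol{\lambda}) = \nabla_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}(\boldsymbol{\theta}, \boldsymbol{\lambda})$. If the Hessian $\nabla^2_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}$ is invertible at $\boldsymbol{\theta}^*$, the Implicit Function Theorem gives:

$$\frac{d\boldsymbol{\theta}^*}{d\boldsymbol{\lambda}} = -\left(\nabla^2_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}\right)^{-1} \frac{\partial^2 \mathcal{L}_{\text{train}}}{\partial \boldsymbol{\theta} \partial \boldsymbol{\lambda}}$$

By the chain rule, the outer gradient is $\frac{\partial \mathcal{L}_{\text{val}}}{\partial \boldsymbol{\lambda}} + \left(\frac{d\boldsymbol{\theta}^*}{d\boldsymbol{\lambda}}\right)^T \nabla_{\boldsymbol{\theta}} \mathcal{L}_{\text{val}}$. The Hessian inverse is expensive to form, but this expression only needs the matrix-vector product $\mathbf{H}^{-1}\mathbf{v}$ with $\mathbf{H} = \nabla^2_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}$ and $\mathbf{v} = \nabla_{\boldsymbol{\theta}} \mathcal{L}_{\text{val}}$, which can be computed using conjugate gradient on Hessian-vector products (as in Hessian-free methods).

Pros: Constant memory (does not depend on $T$), gives the gradient at the true optimum, numerically stable. Cons: Requires convergence of the inner optimization, needs Hessian-vector products.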
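A minimal sketch of the implicit route, assuming a tiny ridge-regression inner problem with a scalar hyperparameter `lam` (all data and names here are illustrative). The forward solve is not differentiated; the backward pass solves $\mathbf{H}\mathbf{q} = \mathbf{v}$ with `jax.scipy.sparse.linalg.cg`, using only Hessian-vector products.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Tiny illustrative dataset (hypothetical values).
X_train = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_train = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, -1.0], [2.0, 1.0]])
y_val = jnp.array([0.0, 1.0])

def train_loss(theta, lam):
    resid = X_train @ theta - y_train
    return 0.5 * jnp.mean(resid**2) + 0.5 * lam * jnp.sum(theta**2)

def val_loss(theta):
    resid = X_val @ theta - y_val
    return 0.5 * jnp.mean(resid**2)

lam = 0.1
grad_train = jax.grad(train_loss)  # gradient w.r.t. theta

# 1. Forward: solve the inner problem with any method; nothing is recorded for autodiff.
theta_star = jnp.zeros(2)
for _ in range(100):
    theta_star = theta_star - 0.5 * grad_train(theta_star, lam)

# 2. Gradient of the outer loss at the solution.
v = jax.grad(val_loss)(theta_star)

# 3. Solve H q = v with conjugate gradient; H is only touched through HVPs.
def hvp(u):
    return jax.jvp(lambda th: grad_train(th, lam), (theta_star,), (u,))[1]

q, _ = cg(hvp, v, tol=1e-6, maxiter=100)

# 4. Hypergradient = -q^T (d/d lam of grad_theta L_train), with q and theta_star held fixed.
#    val_loss has no direct dependence on lam here, so there is no extra partial term.
mixed_term = jax.grad(lambda l: jnp.vdot(q, grad_train(theta_star, l)))(lam)
hypergrad_implicit = -mixed_term
```

The memory used by steps 2 to 4 does not depend on how many iterations step 1 took, which is the point of the approach.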

Key insight: Implicit differentiation is more memory-efficient and numerically stable than unrolling for long inner optimization loops. It computes the gradient of the solution rather than the gradient of the optimization trajectory, which is what we actually want.

MAML: Model-Agnostic Meta-Learning

MAML learns an initialization $\boldsymbol{\theta}_0$ such that a few gradient steps on a new task produce a good model:

$$\boldsymbol{\theta}_0^* = \arg\min_{\boldsymbol{\theta}_0} \sum_{\text{task } i} \mathcal{L}_i^{\text{val}}\left(\boldsymbol{\theta}_0 - \alpha \nabla \mathcal{L}_i^{\text{train}}(\boldsymbol{\theta}_0)\right)$$

The outer gradient requires differentiating through the inner gradient step — a second-order derivative (gradient of a gradient).

First-Order MAML (FOMAML)

To avoid computing second derivatives, FOMAML drops the second-order terms, using only:

$$\frac{d\boldsymbol{\theta}_1}{d\boldsymbol{\theta}_0} \approx \mathbf{I} \quad \text{(ignoring } {-\alpha \nabla^2 \mathcal{L}} \text{)}$$

This is often surprisingly effective in practice. With a short inner loop (1 to 5 steps) and a small step size $\alpha$, the dropped term can be a small perturbation of the identity.
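The difference between the full MAML gradient and the first-order approximation can be seen on a single task with one inner step. This is a minimal sketch with a hypothetical least-squares task; the data and the names `maml_objective` and `fomaml_grad` are illustrative.

```python
import jax
import jax.numpy as jnp

def task_loss(theta, X, y):
    return 0.5 * jnp.mean((X @ theta - y) ** 2)

# One hypothetical task: a train split for adaptation, a val split for the outer loss.
X_tr = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_tr = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, -1.0], [2.0, 1.0]])
y_val = jnp.array([0.0, 1.0])
alpha = 0.1

def adapt(theta0):
    # One inner gradient step on the task's training loss.
    return theta0 - alpha * jax.grad(task_loss)(theta0, X_tr, y_tr)

def maml_objective(theta0):
    # Validation loss after adaptation, as a function of the initialization.
    return task_loss(adapt(theta0), X_val, y_val)

def fomaml_grad(theta0):
    # First-order MAML: gradient of the val loss at the adapted parameters,
    # treating d(theta1)/d(theta0) as the identity.
    return jax.grad(task_loss)(adapt(theta0), X_val, y_val)

theta0 = jnp.array([0.5, -0.5])
full_grad = jax.grad(maml_objective)(theta0)   # differentiates through the inner step
first_order_grad = fomaml_grad(theta0)         # no second derivatives
```

For this quadratic loss the two are related exactly by `full_grad = (I - alpha * H) @ first_order_grad`, where `H` is the training Hessian, so the gap shrinks as `alpha` or the curvature gets smaller.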

iMAML: Implicit MAML

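iMAML replaces the few-step inner loop with a regularized inner problem, $\boldsymbol{\theta}^*(\boldsymbol{\theta}_0) = \arg\min_{\boldsymbol{\theta}} \mathcal{L}_{\text{train}}(\boldsymbol{\theta}) + \frac{\lambda}{2}\|\boldsymbol{\theta} - \boldsymbol{\theta}_0\|^2$, and applies implicit differentiation to its optimality condition instead of unrolling, computing:

$$\frac{d\boldsymbol{\theta}^*}{d\boldsymbol{\theta}_0} = \left(\mathbf{I} + \frac{1}{\lambda}\nabla^2 \mathcal{L}_{\text{train}}\right)^{-1}$$

Here $\lambda$ is the scalar regularization strength and the Hessian is evaluated at $\boldsymbol{\theta}^*$. Because the result depends only on the solution and not on the path taken to reach it, this is more accurate than the first-order approximation for long inner loops and more memory-efficient than unrolling them.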
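The sketch below illustrates the iMAML meta-gradient on a small quadratic task. It assumes the inner problem is the training loss plus a proximal term $\frac{\lambda}{2}\|\boldsymbol{\theta} - \boldsymbol{\theta}_0\|^2$, as in the iMAML paper; the data, the value of `lam`, and the function names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# One hypothetical task (illustrative values).
X_tr = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_tr = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[1.0, -1.0], [2.0, 1.0]])
y_val = jnp.array([0.0, 1.0])

def train_loss(theta):
    return 0.5 * jnp.mean((X_tr @ theta - y_tr) ** 2)

def val_loss(theta):
    return 0.5 * jnp.mean((X_val @ theta - y_val) ** 2)

lam = 2.0                       # strength of the proximal regularizer (assumed value)
theta0 = jnp.array([0.5, -0.5])

def inner_objective(theta, theta0):
    return train_loss(theta) + 0.5 * lam * jnp.sum((theta - theta0) ** 2)

# Forward: run the inner optimizer to convergence; no graph is kept.
theta_star = theta0
for _ in range(300):
    theta_star = theta_star - 0.2 * jax.grad(inner_objective)(theta_star, theta0)

# Backward: meta-gradient = (I + H / lam)^{-1} grad L_val(theta*), solved with CG.
v = jax.grad(val_loss)(theta_star)

def matvec(u):
    hvp = jax.jvp(jax.grad(train_loss), (theta_star,), (u,))[1]
    return u + hvp / lam

meta_grad, _ = cg(matvec, v, tol=1e-6, maxiter=100)
```

The number of inner iterations (300 here) affects only how accurately `theta_star` is found, not the memory or structure of the backward pass.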

Fixed-Point Differentiation

Many neural network components compute a fixed point $\mathbf{z}^* = f(\mathbf{z}^*, \boldsymbol{\theta})$:

  • Equilibrium models (DEQ): Run a transformation to convergence rather than stacking layers
  • Iterative algorithms: Power iteration, message passing, belief propagation
  • Implicit layers: Define the output as a fixed point rather than a feed-forward computation

Differentiating Through Fixed Points

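At the fixed point, $\mathbf{z}^* = f(\mathbf{z}^*, \boldsymbol{\theta})$. Differentiating both sides with respect to $\boldsymbol{\theta}$ gives $\frac{d\mathbf{z}^*}{d\boldsymbol{\theta}} = \frac{\partial f}{\partial \mathbf{z}}\frac{d\mathbf{z}^*}{d\boldsymbol{\theta}} + \frac{\partial f}{\partial \boldsymbol{\theta}}$. Solving for $\frac{d\mathbf{z}^*}{d\boldsymbol{\theta}}$, assuming $\mathbf{I} - \frac{\partial f}{\partial \mathbf{z}}$ is invertible at $\mathbf{z}^*$:

$$\frac{d\mathbf{z}^*}{d\boldsymbol{\theta}} = \left(\mathbf{I} - \frac{\partial f}{\partial \mathbf{z}}\bigg|_{\mathbf{z}^*}\right)^{-1} \frac{\partial f}{\partial \boldsymbol{\theta}}\bigg|_{\mathbf{z}^*}$$

The inverse never needs to be formed explicitly; the backward pass solves the corresponding linear system, typically with an iterative method. The Deep Equilibrium Model (DEQ) uses this to define infinitely deep networks with constant memory: the “depth” is the number of iterations to convergence, but backpropagation only requires solving one linear system.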

Key insight: Fixed-point differentiation decouples the forward computation (iterating to convergence) from the backward computation (solving a linear system). This means you can use any iterative method for the forward pass, even a black-box solver, and still get gradients that are exact up to the tolerance of the forward and backward solves.
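A minimal sketch of this decoupling, assuming a small contractive layer $f(\mathbf{z}, \mathbf{W}, \mathbf{x}) = \tanh(\mathbf{W}\mathbf{z} + \mathbf{x})$ with made-up weights. The backward pass solves $(\mathbf{I} - \mathbf{J}^T)\mathbf{u} = \mathbf{g}$ by a simple fixed-point iteration on vector-Jacobian products, then pushes $\mathbf{u}$ through $\partial f / \partial \mathbf{W}$. The unrolled gradient is computed only as a cross-check.

```python
import jax
import jax.numpy as jnp

def f(z, W, x):
    return jnp.tanh(W @ z + x)

def solve_fixed_point(W, x, num_iters=100):
    # Naive forward solver; any other solver could be substituted here.
    z = jnp.zeros_like(x)
    for _ in range(num_iters):
        z = f(z, W, x)
    return z

def outer_loss(z):
    return jnp.sum(z**2)

# Illustrative parameters, scaled so that f is a contraction in z.
W = 0.3 * jnp.array([[0.5, -1.0, 0.2], [0.1, 0.4, -0.6], [-0.7, 0.3, 0.2]])
x = jnp.array([0.5, -0.2, 0.1])

# Forward pass: only the converged z_star is kept.
z_star = solve_fixed_point(W, x)

# Backward pass: solve u = g + J^T u, i.e. (I - J^T) u = g, with J = df/dz at z_star.
g = jax.grad(outer_loss)(z_star)
_, vjp_z = jax.vjp(lambda z: f(z, W, x), z_star)
u = jnp.zeros_like(g)
for _ in range(100):
    u = g + vjp_z(u)[0]

# Gradient w.r.t. W: u^T (df/dW), evaluated at the fixed point.
_, vjp_W = jax.vjp(lambda W_: f(z_star, W_, x), W)
grad_W_implicit = vjp_W(u)[0]

# Cross-check: backpropagate through every forward iteration.
grad_W_unrolled = jax.grad(lambda W_: outer_loss(solve_fixed_point(W_, x)))(W)
```

In a real implicit layer the backward solve would usually be wrapped in a custom differentiation rule so that `jax.grad` uses it automatically; that wiring is omitted here to keep the sketch short.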

Neural ODEs

Neural Ordinary Differential Equations define a continuous-depth network as the solution to an ODE:

$$\frac{d\mathbf{h}(t)}{dt} = f_{\boldsymbol{\theta}}(\mathbf{h}(t), t)$$

where $f_{\boldsymbol{\theta}}$ is a neural network parameterizing the dynamics. The output is $\mathbf{h}(T)$, obtained by integrating from $t = 0$ to $t = T$ using an ODE solver.

The Adjoint Method

Backpropagating through ODE solver steps (unrolling) is memory-intensive. The adjoint method computes gradients by solving a second ODE backward in time:

$$\frac{d\mathbf{a}(t)}{dt} = -\mathbf{a}(t)^T \frac{\partial f}{\partial \mathbf{h}}$$

where $\mathbf{a}(t) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(t)}$ is the adjoint state, initialized at $t = T$ with $\mathbf{a}(T) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(T)}$. This has $O(1)$ memory cost regardless of the number of solver steps, because $\mathbf{h}(t)$ is recomputed by integrating the original ODE backward alongside the adjoint instead of being stored; the price is extra computation and some numerical error in the reconstructed trajectory. In this sense it is implicit differentiation applied to ODEs.

The parameter gradient is:

$$\frac{d\mathcal{L}}{d\boldsymbol{\theta}} = -\int_T^0 \mathbf{a}(t)^T \frac{\partial f}{\partial \boldsymbol{\theta}} \, dt$$
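The adjoint equations can be sketched with a hand-written fixed-step RK4 integrator. This is a minimal illustration under stated assumptions: the dynamics $f_{\boldsymbol{\theta}}(\mathbf{h}, t) = \tanh(\boldsymbol{\theta}\mathbf{h})$, the loss, the parameter values, and the helper names (`rk4_step`, `odeint`, `augmented_dynamics`) are all chosen for this example and are not taken from any library.

```python
import jax
import jax.numpy as jnp

def dynamics(h, t, theta):
    return jnp.tanh(theta @ h)

def rk4_step(func, state, t, dt):
    # One classical Runge-Kutta step; `state` may be any pytree of arrays.
    tmap = jax.tree_util.tree_map
    shift = lambda s, k, c: tmap(lambda si, ki: si + c * ki, s, k)
    k1 = func(state, t)
    k2 = func(shift(state, k1, 0.5 * dt), t + 0.5 * dt)
    k3 = func(shift(state, k2, 0.5 * dt), t + 0.5 * dt)
    k4 = func(shift(state, k3, dt), t + dt)
    return tmap(lambda s, a, b, c, d: s + dt / 6.0 * (a + 2 * b + 2 * c + d),
                state, k1, k2, k3, k4)

def odeint(func, state, t0, t1, num_steps):
    # Fixed-step integration from t0 to t1; dt is negative when t1 < t0.
    dt = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        state = rk4_step(func, state, t, dt)
        t = t + dt
    return state

T, NUM_STEPS = 1.0, 100
theta = jnp.array([[0.0, 1.0], [-1.0, -0.5]])
h0 = jnp.array([1.0, 0.5])

def forward(theta_, h_init):
    return odeint(lambda h, t: dynamics(h, t, theta_), h_init, 0.0, T, NUM_STEPS)

def loss_fn(h_final):
    return 0.5 * jnp.sum(h_final**2)

h_T = forward(theta, h0)

def augmented_dynamics(state, t):
    # State = (h, a, accumulated parameter gradient); all three are integrated together.
    h, a, _ = state
    f_val, vjp_fn = jax.vjp(lambda h_, th_: dynamics(h_, t, th_), h, theta)
    a_dfdh, a_dfdtheta = vjp_fn(a)      # a^T df/dh and a^T df/dtheta
    return (f_val, -a_dfdh, -a_dfdtheta)

# Integrate backward from T to 0, starting from a(T) = dL/dh(T) and a zero gradient.
a_T = jax.grad(loss_fn)(h_T)
h0_rec, a_0, grad_theta_adjoint = odeint(
    augmented_dynamics, (h_T, a_T, jnp.zeros_like(theta)), T, 0.0, NUM_STEPS)

# Cross-check: backpropagate through the forward solver steps (stores the trajectory).
grad_theta_autodiff = jax.grad(lambda th: loss_fn(forward(th, h0)))(theta)
```

The two gradients agree up to discretization and floating-point error, and `h0_rec` recovers the initial state, which shows the trajectory being rebuilt during the backward pass.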

Applications of Neural ODEs

  • Continuous normalizing flows: Model complex distributions with continuous transformations
  • Time series: Naturally handle irregularly-sampled data
  • Generative models: Continuous-time generative processes
  • Physics-informed learning: Embed physical laws as ODE/PDE constraints

Differentiable Programming

Differentiable programming is the broader paradigm: make entire programs differentiable so that gradients can flow through any computation.

Examples Beyond Neural Networks

  • Differentiable rendering: Compute $\frac{\partial \text{image}}{\partial \text{3D geometry}}$ for inverse graphics
  • Differentiable physics simulation: Learn physical parameters from observed trajectories
  • Differentiable sorting: Soft approximations to sorting enable gradient-based ranking losses
  • Differentiable architecture search (DARTS): Relax discrete architecture choices to continuous and optimize with gradients
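To make the idea of a soft relaxation concrete, the sketch below replaces the hard comparison inside a rank computation with a sigmoid. This is a simple pairwise relaxation chosen for illustration, not the method of any particular differentiable-sorting paper; `soft_rank` and `temperature` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def soft_rank(scores, temperature=0.1):
    # Hard ascending rank: rank_i = 1 + #{j != i : scores[j] < scores[i]}.
    # Soft version: replace each indicator with sigmoid((s_i - s_j) / temperature).
    diff = (scores[:, None] - scores[None, :]) / temperature
    # The diagonal contributes sigmoid(0) = 0.5, so adding 0.5 makes it count as 1.
    return 0.5 + jnp.sum(jax.nn.sigmoid(diff), axis=1)

scores = jnp.array([0.1, 2.0, 1.0])
ranks = soft_rank(scores)  # close to the hard ranks [1, 3, 2]

# Unlike a hard rank, the soft rank has a usable gradient w.r.t. the scores.
rank0_grad = jax.grad(lambda s: soft_rank(s)[0])(scores)
```

Lower temperatures bring the soft ranks closer to the hard ones but make the gradients more sharply peaked, which is the usual trade-off in this kind of relaxation.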

The JAX Ecosystem

JAX makes differentiable programming practical with composable transformations:

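import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Mean squared error of a linear model
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Example inputs (illustrative values)
params = jnp.array([0.5, -0.5])
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

# Gradient of loss w.r.t. params (the first argument by default)
grad_fn = jax.grad(loss)

# Hessian-vector product: forward-mode JVP of the gradient, holding x and y fixed
def hvp_fn(v):
    return jax.jvp(lambda p: grad_fn(p, x, y), (params,), (v,))[1]

# Differentiate through ordinary Python control flow, such as a for loop
def loss_after_steps(params, num_steps=3, lr=0.01):
    for _ in range(num_steps):
        params = params - lr * grad_fn(params, x, y)
    return loss(params, x, y)

grad_through_loop = jax.grad(loss_after_steps)(params)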

JAX’s jax.grad, jax.jvp (forward-mode), and jax.vjp (reverse-mode) can be composed arbitrarily, enabling gradients of gradients, implicit differentiation, and more.
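As a small illustration of this composability, the sketch below computes the same Hessian-vector product in two ways and compares both against an explicit Hessian. The test function `f` and the vectors are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x**2)

x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

# Forward-over-reverse: push the tangent v through the gradient function.
hvp_fwd_over_rev = jax.jvp(jax.grad(f), (x,), (v,))[1]

# Reverse-over-reverse: gradient of the scalar <grad f(x), v>.
hvp_rev_over_rev = jax.grad(lambda x_: jnp.vdot(jax.grad(f)(x_), v))(x)

# Full Hessian via forward-mode Jacobian of the gradient (fine at this tiny size).
hessian = jax.jacfwd(jax.grad(f))(x)
```

Both products match `hessian @ v`, but neither of the first two ever materializes the Hessian, which is what makes conjugate-gradient solves practical for large models.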

Why This Matters for ML

Implicit differentiation and differentiable programming expand what can be trained with gradients:

  • Meta-learning (MAML, iMAML) learns to learn by differentiating through training
  • Hyperparameter optimization tunes regularization, learning rates, and data augmentation via gradients
  • Neural ODEs define continuous-depth networks with constant memory backpropagation
  • Deep Equilibrium Models achieve infinite depth with fixed-point differentiation
  • Differentiable rendering/physics enables learning from visual and physical observations
  • These techniques extend the reach of gradient-based optimization to problems that were previously intractable

Summary

  • The Implicit Function Theorem gives gradients through implicitly defined functions: $\frac{d\mathbf{y}}{d\mathbf{x}} = -\left(\frac{\partial F}{\partial \mathbf{y}}\right)^{-1}\frac{\partial F}{\partial \mathbf{x}}$
  • Differentiating through optimization: implicit differentiation is more stable and memory-efficient than unrolling
  • MAML learns initializations by differentiating through inner gradient steps; iMAML uses implicit differentiation for scalability
  • Fixed-point differentiation enables infinitely deep equilibrium models with constant memory
  • Neural ODEs define continuous-depth networks; the adjoint method computes gradients with $O(1)$ memory
  • Differentiable programming makes arbitrary computations (rendering, physics, sorting) gradient-compatible
  • Next: min-max optimization tackles the two-player games behind GANs and adversarial training

References

  • Krantz, S. G., & Parks, H. R. (2013). The Implicit Function Theorem: History, Theory, and Applications. Springer.
  • Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML. arXiv:1703.03400
  • Rajeswaran, A., et al. (2019). Meta-Learning with Implicit Gradients. NeurIPS. arXiv:1909.04630
  • Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural Ordinary Differential Equations. NeurIPS. arXiv:1806.07366
  • Bai, S., Kolter, J. Z., & Koltun, V. (2019). Deep Equilibrium Models. NeurIPS. arXiv:1909.01377
  • Blondel, M., et al. (2022). Efficient and Modular Implicit Differentiation. NeurIPS.
