Numerical Stability in Optimization: Making Training Work in Practice

Master the log-sum-exp trick, gradient clipping, mixed precision, and other techniques that prevent numerical disasters during model training.

Calculus & Optimization March 7, 2026 9 min read

The Gap Between Theory and Practice

The optimization algorithms we have covered — gradient descent, SGD, Adam — are mathematically elegant. But computers use finite-precision arithmetic, and this introduces errors that can silently corrupt training or cause spectacular failures.

Understanding numerical stability is the difference between a model that trains and one that produces NaN losses.

Floating-Point Arithmetic

Computers represent real numbers in floating-point format:

$$x = \pm m \times 2^e$$

where $m$ is the mantissa (significand) and $e$ is the exponent. The precision is finite:

| Format | Bits | Mantissa bits | Range | Precision |
|---|---|---|---|---|
| FP64 (double) | 64 | 52 | $\pm 10^{308}$ | ~16 decimal digits |
| FP32 (float) | 32 | 23 | $\pm 10^{38}$ | ~7 decimal digits |
| FP16 (half) | 16 | 10 | $\pm 65504$ | ~3 decimal digits |
| BF16 (bfloat16) | 16 | 7 | $\pm 10^{38}$ | ~2 decimal digits |
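As a quick sanity check on the table, NumPy's np.finfo reports these properties for the IEEE formats. This sketch uses only NumPy, which has no native bfloat16 type, so the BF16 row is not covered; nmant counts the stored mantissa bits (the leading 1 is implicit).

```python
import numpy as np

# Inspect the IEEE formats from the table (NumPy has no built-in bfloat16).
for dtype in (np.float64, np.float32, np.float16):
    info = np.finfo(dtype)
    # bits: total width; nmant: stored mantissa bits; max: largest finite value;
    # eps: gap between 1.0 and the next representable number.
    print(f"{info.dtype}: bits={info.bits}, mantissa={info.nmant}, "
          f"max={info.max:.4g}, eps={info.eps:.3g}")
```

The precision column is roughly -log10(eps): about 15.7 digits for FP64, 6.9 for FP32, and 3.0 for FP16.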

Sources of Error

  • Overflow: Result exceeds the maximum representable value ($\to \pm\infty$). Example: $e^{1000}$ in FP32.
  • Underflow: Result is smaller in magnitude than the smallest representable positive value ($\to 0$). Example: $e^{-1000}$ in FP32.
  • Catastrophic cancellation: Subtracting nearly equal numbers destroys significant digits. Example: $(1 + 10^{-15}) - 1$ in FP64 evaluates to about $1.11 \times 10^{-15}$ instead of $10^{-15}$, so roughly 15 of the ~16 significant digits are lost.
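Each failure mode above can be reproduced in a few lines of NumPy. This is a minimal sketch; np.errstate only silences the runtime warning NumPy would otherwise print for the overflow.

```python
import numpy as np

# Overflow: e^1000 exceeds float32's max (~3.4e38) and becomes inf.
with np.errstate(over="ignore"):
    overflow = np.exp(np.float32(1000.0))

# Underflow: e^-1000 is far below float32's smallest positive value, so it flushes to 0.
with np.errstate(under="ignore"):
    underflow = np.exp(np.float32(-1000.0))

# Catastrophic cancellation: 1 + 1e-15 is rounded to the nearest double,
# so subtracting 1 exposes that rounding error instead of recovering 1e-15.
cancelled = (1.0 + 1e-15) - 1.0
rel_error = abs(cancelled - 1e-15) / 1e-15

print(overflow, underflow, cancelled, rel_error)
```

The cancelled value comes out as about 1.11e-15, an 11% relative error from a single subtraction.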

The Log-Sum-Exp Trick

Computing softmax probabilities $p_k = \frac{e^{z_k}}{\sum_j e^{z_j}}$ directly is numerically dangerous. If any $z_k$ is large (e.g., 1000), $e^{z_k}$ overflows. If all $z_k$ are very negative, the denominator underflows to zero.

The log-sum-exp (LSE) trick subtracts the maximum value before exponentiating:

$$\log \sum_j e^{z_j} = c + \log \sum_j e^{z_j - c}$$

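where $c = \max_j z_j$. The identity holds for any constant $c$, since $\sum_j e^{z_j} = e^c \sum_j e^{z_j - c}$. With this choice, $z_j - c \leq 0$, so no exponential overflows. And the largest term is $e^{z_{\max} - c} = e^0 = 1$, so the sum is at least 1 and cannot underflow to zero.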

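import numpy as np

def log_sum_exp(z):
    # Shift by the max so every exponent is <= 0 (no overflow)
    c = z.max()
    return c + np.log(np.sum(np.exp(z - c)))

def softmax(z):
    # Same shift; the factor e^{-c} cancels in the ratio, so the result is unchanged
    c = z.max()
    exp_z = np.exp(z - c)
    return exp_z / exp_z.sum()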
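To see why the shift matters, here is a small comparison on logits around 1000 (values chosen for illustration). The stable function repeats the softmax above so the snippet runs on its own; naive_softmax is a hypothetical name for the unshifted formula.

```python
import numpy as np

def naive_softmax(z):
    # Direct formula: exponentiates the raw logits.
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

def softmax(z):
    # Stable version: shift by the max so every exponent is <= 0.
    exp_z = np.exp(z - z.max())
    return exp_z / exp_z.sum()

z = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(z))   # exp(1000) overflows, and inf / inf gives nan
print(softmax(z))             # identical to softmax([0, 1, 2]) after the shift
```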

Key insight: The log-sum-exp trick is not optional; it is required for numerical correctness. Major deep learning frameworks implement softmax this way internally. Whenever you see $\log \sum \exp$ in a derivation, apply this trick in the implementation.

Cross-Entropy with LogSoftmax

Computing cross-entropy loss from probabilities introduces another instability: $\log(p_k)$ where $p_k$ might be very close to zero, or might already have underflowed to exactly zero, giving $-\infty$.

The solution is to compute log-softmax directly:

$$\log \text{softmax}(z_k) = z_k - \log \sum_j e^{z_j} = z_k - c - \log \sum_j e^{z_j - c}$$

This avoids ever computing $p_k$ explicitly. PyTorch’s F.cross_entropy takes raw logits (not probabilities) precisely for this reason.
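A NumPy sketch of cross-entropy computed from logits through log-softmax, next to the naive probability-first route. The function names are made up for this example, and this is an illustration of the idea, not PyTorch's code.

```python
import numpy as np

def log_softmax(z):
    # log softmax(z_k) = z_k - c - log(sum_j exp(z_j - c)), with c = max_j z_j
    c = z.max()
    return z - c - np.log(np.sum(np.exp(z - c)))

def cross_entropy_from_logits(z, target):
    # Negative log-likelihood of the target class, never forming p_k.
    return -log_softmax(z)[target]

def cross_entropy_from_probs(z, target):
    # Naive route: softmax first, then the log of a possibly 0 (or nan) probability.
    p = np.exp(z) / np.exp(z).sum()
    return -np.log(p[target])

z = np.array([0.0, 1000.0])  # the model is very confident in class 1
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    print(cross_entropy_from_probs(z, target=0))   # inf: p_0 came out as exactly 0
print(cross_entropy_from_logits(z, target=0))      # the finite loss, 1000.0
```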

Gradient Clipping

When gradients become very large (exploding gradients), the parameter update overshoots catastrophically. Gradient clipping caps the gradient magnitude.

Gradient Norm Clipping

The most common approach clips the global gradient norm:

$$\hat{\mathbf{g}} = \begin{cases} \mathbf{g} & \text{if } \|\mathbf{g}\| \leq \tau \\ \frac{\tau}{\|\mathbf{g}\|} \mathbf{g} & \text{if } \|\mathbf{g}\| > \tau \end{cases}$$

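This preserves the gradient direction but limits its magnitude to $\tau$. The norm is global: $\|\mathbf{g}\|$ is computed over the gradients of all parameters concatenated into a single vector.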

# Call after loss.backward() and before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
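For intuition, the piecewise rule above can be written out directly. This is a simplified NumPy sketch with a hypothetical function name (clip_grad_norm), not PyTorch's implementation; the real torch.nn.utils.clip_grad_norm_ modifies gradients in place and also returns the total norm, which is worth logging.

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Scale a list of gradient arrays so their global L2 norm is at most max_norm.

    Follows the piecewise rule: unchanged if ||g|| <= tau, otherwise
    multiplied by tau / ||g||. Returns (clipped grads, original norm).
    """
    # Global norm: treat all gradient arrays as one concatenated vector.
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm <= max_norm:
        return [g.copy() for g in grads], total_norm
    return [g * max_norm / total_norm for g in grads], total_norm

# Two "layers" whose concatenated gradient is [3, 4] (norm 5).
grads = [np.array([3.0]), np.array([4.0])]
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, clipped)  # norm 5.0; clipped to roughly [0.6] and [0.8], same direction
```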

Gradient Value Clipping

An alternative clips each gradient element independently to $[-\tau, \tau]$. This is simpler, but it changes the gradient direction whenever only some components are clipped, which can be problematic.
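A small NumPy comparison (hypothetical helper names) shows the direction change. PyTorch exposes element-wise clipping as torch.nn.utils.clip_grad_value_; here np.clip plays that role.

```python
import numpy as np

def clip_grad_value(g, clip_value):
    # Clamp every component independently into [-clip_value, clip_value].
    return np.clip(g, -clip_value, clip_value)

def cosine(a, b):
    # Cosine similarity: 1.0 means the two vectors point the same way.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

g = np.array([3.0, 0.5])
by_value = clip_grad_value(g, 1.0)        # only the first entry is clamped
by_norm = g * 1.0 / np.linalg.norm(g)     # norm clipping to tau = 1.0
print(cosine(g, by_value), cosine(g, by_norm))  # about 0.956 versus 1.0
```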

When to Clip

  • RNNs/LSTMs: Long sequences amplify gradients through time. Clipping at $\tau = 1.0$ to $5.0$ is standard.
  • Transformers: Gradient clipping at $\tau = 1.0$ is nearly universal in LLM training.
  • GAN training: Clipping can help stabilize the discriminator/generator interplay.
  • Very deep networks: Without residual connections, gradients can explode through many layers.

Key insight: Gradient clipping is a safety net, not a solution. If you need aggressive clipping (small $\tau$), something else is likely wrong: initialization, learning rate, or architecture. But mild clipping ($\tau = 1.0$) is a good default that prevents rare catastrophic updates with little effect on normal training.

Mixed Precision Training

Mixed precision uses FP16 (or BF16) for most computations and keeps FP32 for the master copy of the weights and for critical accumulations, often achieving a 2-3x speedup on GPUs with dedicated half-precision hardware.

The Challenge

FP16 has only about 3 decimal digits of precision and a maximum value of 65504. Small gradients can underflow to zero, large activations or loss values can overflow to infinity, and small weight updates can be rounded away entirely.
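Each of these FP16 failure modes shows up with plain NumPy scalars. A sketch; the particular values are arbitrary illustrations.

```python
import numpy as np

half = np.float16

with np.errstate(over="ignore", under="ignore"):
    tiny_grad = half(1e-8)      # below the smallest FP16 subnormal (about 6e-8): becomes 0
    big_value = half(70000.0)   # above the FP16 max of 65504: becomes inf

# Limited precision: the FP16 spacing near 1.0 is 2**-10 (about 1e-3),
# so adding 1e-4 to a weight of 1.0 rounds straight back to 1.0.
weight = half(1.0)
updated = weight + half(1e-4)

print(tiny_grad, big_value, updated == weight)
```

The last case is why mixed precision keeps an FP32 master copy of the weights and applies updates there.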

Loss Scaling

Loss scaling multiplies the loss by a large factor $S$ before the backward pass. By the chain rule, this scales all gradients by $S$, lifting small gradients out of the underflow zone. After gradient computation, gradients are divided by $S$ before the optimizer step.

$$\begin{aligned}
&\text{Forward: } \mathcal{L}_\text{scaled} = S \cdot \mathcal{L} \\[4pt]
&\text{Backward: } \nabla_\text{scaled} = S \cdot \nabla \mathcal{L} \quad \text{(computed in FP16)} \\[4pt]
&\text{Update: } \boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \alpha \cdot \frac{\nabla_\text{scaled}}{S} \quad \text{(in FP32)}
\end{aligned}$$
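The effect of loss scaling on a single small gradient can be simulated by rounding to FP16 directly. This sketch skips the actual backward pass and only models the FP16 storage step; S = 2^16 is an illustrative choice (a power of two, so multiplying and dividing by it are exact).

```python
import numpy as np

S = np.float32(2.0 ** 16)          # loss-scale factor (illustrative value)
true_grad = np.float32(1e-8)       # a small gradient, as FP32 would represent it

# Without scaling: the gradient lands in FP16 directly and underflows to 0.
unscaled_fp16 = np.float16(true_grad)

# With scaling: FP16 holds S * grad (about 6.6e-4), which is comfortably representable.
scaled_fp16 = np.float16(S * true_grad)

# Unscale in FP32 before the optimizer step.
recovered = np.float32(scaled_fp16) / S
rel_error = abs(recovered - true_grad) / true_grad

print(unscaled_fp16, recovered, rel_error)   # 0.0, roughly 1e-8, a tiny relative error
```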

Dynamic loss scaling adjusts $S$ automatically: after a run of steps with no overflow, increase $S$; when a gradient overflows to Inf/NaN, skip that update and decrease $S$.
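The dynamic policy can be sketched as a small class. DynamicLossScaler and its parameter names are hypothetical, and its default values are illustrative; in PyTorch this role is played by GradScaler (torch.amp, or torch.cuda.amp in older releases), which also handles scaling the loss and skipping the optimizer step.

```python
import numpy as np

class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical helper, not a library class).

    Shrinks the scale and reports a skipped step when gradients contain
    inf/NaN; grows it after `growth_interval` consecutive clean steps.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def unscale_and_check(self, scaled_grads):
        """Return FP32 unscaled gradients, or None if an overflow was found."""
        grads = [g.astype(np.float32) / self.scale for g in scaled_grads]
        if not all(np.all(np.isfinite(g)) for g in grads):
            # Overflow: shrink the scale; the caller should skip this update.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return None
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            # A long run without overflow: try a larger scale.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return grads
```

In a training loop, the caller would multiply the loss by scaler.scale, backpropagate, call unscale_and_check, and skip the optimizer step whenever it returns None.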

BFloat16 vs Float16

| Property | FP16 | BF16 |
|---|---|---|
| Exponent bits | 5 | 8 |
| Mantissa bits | 10 | 7 |
| Range | $\pm 65504$ | $\pm 3.4 \times 10^{38}$ |
| Precision | Higher | Lower |
| Loss scaling needed? | Yes | Usually not |

BF16 has the same range as FP32 (8 exponent bits), which eliminates most overflow/underflow issues. It is becoming the preferred format for LLM training.
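Because BF16 is essentially FP32 with the mantissa cut from 23 to 7 bits, it can be emulated in NumPy by rounding away the low 16 bits of each float32. to_bfloat16 is a hypothetical helper and only a sketch: it ignores NaN/Inf inputs, and real frameworks provide a native bfloat16 dtype instead.

```python
import numpy as np

def to_bfloat16(x):
    """Round float32 values to the nearest bfloat16, returned as float32.

    Emulation only: bfloat16 keeps the sign, all 8 exponent bits, and the top
    7 mantissa bits of a float32, so the low 16 bits are rounded away
    (round-half-to-even). NaN/Inf inputs are not handled.
    """
    bits = np.asarray(x, dtype=np.float32).reshape(-1).view(np.uint32)
    lsb = (bits >> 16) & 1                      # lowest bit that survives
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return rounded.astype(np.uint32).view(np.float32)

values = np.array([1.001, 70000.0, 1e-8], dtype=np.float32)
print(to_bfloat16(values))
```

The three sample values show the trade-off from the table: 1.001 collapses to 1.0 (only about 2-3 significant digits), 70000 becomes 70144 instead of overflowing as it would in FP16, and 1e-8 stays nonzero.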

Numerical Issues in Specific Operations

Log of Small Probabilities

Computing $\log(1 - p)$ when $p$ is tiny ($p \approx 0$) loses precision: $1 - p$ rounds to a value at or near 1, discarding most or all of the digits of $p$. Use log1p:

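import numpy as np

p = 1e-20  # a tiny probability

# Bad: 1 - p rounds to 1.0 when p is tiny, so this returns 0.0 instead of about -p
result = np.log(1 - p)

# Good: log1p(x) computes log(1 + x) accurately for small |x|
result = np.log1p(-p)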

Similarly, expm1(x) computes $e^x - 1$ accurately for small $x$.
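A quick NumPy check of expm1 against the naive exp(x) - 1. The helper name is hypothetical, and the reference value uses the first two Taylor terms, which are accurate to far below double precision for these tiny inputs.

```python
import numpy as np

def relative_errors(x):
    # Compare naive exp(x) - 1 and expm1(x) against the Taylor value x + x^2/2.
    naive = np.exp(x) - 1       # exp(x) is rounded to a double near 1, then 1 is subtracted
    stable = np.expm1(x)        # computes e^x - 1 without forming e^x first
    reference = x + x * x / 2
    return abs(naive - reference) / reference, abs(stable - reference) / reference

for x in (1e-10, 1e-20):
    print(x, relative_errors(x))
```

At x = 1e-10 the naive version already loses about half its digits (relative error near 1e-7); at x = 1e-20 it returns exactly 0, a 100% error, while expm1 stays accurate to machine precision.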

Sigmoid and Tanh Saturation

The sigmoid $\sigma(x) = 1/(1 + e^{-x})$ saturates for large $|x|$:

  • $x \gg 0$: $\sigma(x) \to 1$, gradient $\sigma(1-\sigma) \to 0$
  • $x \ll 0$: $\sigma(x) \to 0$, gradient $\to 0$

This is the vanishing gradient problem for sigmoid/tanh networks. The numerical manifestation is that FP16 rounds $\sigma(x)$ to exactly 0 or 1 for moderate $|x|$, making the gradient exactly zero.
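To see the exact-zero gradient, evaluate the sigmoid entirely in FP16 and in FP32. This is a NumPy sketch with hypothetical helper names; frameworks may evaluate sigmoid differently (for example in higher precision internally), so the exact saturation point varies.

```python
import numpy as np

def sigmoid(x, dtype):
    # Evaluate 1 / (1 + e^{-x}) with every operation in the given type.
    x = dtype(x)
    one = dtype(1.0)
    return one / (one + np.exp(-x))

def sigmoid_grad(x, dtype):
    # d sigma / dx = sigma * (1 - sigma), also in the given type.
    s = sigmoid(x, dtype)
    return s * (dtype(1.0) - s)

for dtype in (np.float16, np.float32):
    print(dtype.__name__, sigmoid(8.0, dtype), sigmoid_grad(8.0, dtype))
```

In FP16, 1 + e^{-8} rounds to exactly 1.0 (the spacing near 1 is about 1e-3), so the sigmoid returns 1.0 and the gradient is exactly 0.0; in FP32 the gradient is about 3.35e-4.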

Batch Normalization Numerical Issues

Batch normalization divides by $\sqrt{\text{Var} + \epsilon}$. If the batch size is very small, the variance estimate can be near zero, causing division instability. The $\epsilon$ term (typically $10^{-5}$) prevents this.
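A minimal NumPy sketch of the training-mode normalization step (hypothetical function name; the learnable scale and shift are omitted), applied to a batch whose single feature is constant, so the variance is exactly zero.

```python
import numpy as np

def batch_norm_train(x, eps=1e-5):
    # Normalize each feature (column) with the batch mean and biased variance.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps)

# A batch of two examples with an identical feature value: variance is 0.
x = np.array([[2.0], [2.0]])

with np.errstate(invalid="ignore"):
    no_eps = batch_norm_train(x, eps=0.0)   # 0 / 0 gives nan
with_eps = batch_norm_train(x)              # 0 / sqrt(1e-5) gives 0
print(no_eps.ravel(), with_eps.ravel())
```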

Attention Score Overflow

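In transformers, attention scores $\mathbf{Q}\mathbf{K}^T / \sqrt{d_k}$ can overflow in FP16 before the softmax. If query and key entries are roughly independent with zero mean and unit variance, each raw dot product has variance about $d_k$; the $1/\sqrt{d_k}$ scaling brings it back to about 1, which keeps scores in range for typical dimensions. Very large $d_k$ or unnormalized queries/keys can still cause issues.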
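Where the scaling is applied also matters in FP16. In this NumPy sketch the query and key vectors are hypothetical, with every entry set to 40 to make them deliberately large: scaling after the FP16 dot product is too late, while folding $1/\sqrt{d_k}$ into the query first keeps the score finite. This is one reason implementations may scale the queries before the product or accumulate scores in FP32.

```python
import numpy as np

d_k = 64
# Hypothetical large query/key vectors, stored in FP16.
q = np.full(d_k, 40.0, dtype=np.float16)
k = np.full(d_k, 40.0, dtype=np.float16)
scale = np.float16(1.0 / np.sqrt(d_k))   # 1/8, exact in FP16

with np.errstate(over="ignore"):
    # Raw score is 64 * 1600 = 102400 > 65504, so it is already inf before scaling.
    score_late = np.sum(q * k) * scale

# Scaling the query first: 64 * (5 * 40) = 12800, well within FP16 range.
score_early = np.sum((q * scale) * k)

print(score_late, score_early)
```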

Gradient Checkpointing

Gradient checkpointing (activation recomputation) is not strictly a stability technique but addresses the memory bottleneck of storing all activations for backpropagation.

Instead of storing all intermediate activations during the forward pass, only store activations at selected checkpoints. During the backward pass, recompute intermediate activations from the nearest checkpoint.

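Trade-off: reduces activation memory from $O(L)$ to $O(\sqrt{L})$ for $L$ layers, at the cost of roughly one extra forward pass per training step. That is a $\sqrt{L}$-fold memory reduction for about twice the forward compute, or roughly a third more total compute if the backward pass costs about twice the forward pass.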

This enables training much deeper or larger models on limited GPU memory.
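The recomputation scheme can be sketched on a toy chain of scalar tanh layers. Everything here is hypothetical and illustrative: the layers, the choice of one checkpoint every isqrt(L) layers, and counting stored values as a stand-in for memory. Real frameworks apply the same idea to tensors; PyTorch, for example, offers torch.utils.checkpoint.

```python
import math
import numpy as np

# A toy chain of L scalar "layers": x_{i+1} = tanh(w_i * x_i).
def layer(w, x):
    return np.tanh(w * x)

def layer_backward(w, x, upstream):
    # Chain rule through one layer: d/dx tanh(w x) = w * (1 - tanh(w x)^2).
    return upstream * w * (1.0 - np.tanh(w * x) ** 2)

def grad_store_all(ws, x0):
    """Standard backprop: keep every layer input. Returns (dy/dx0, values held)."""
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        g = layer_backward(ws[i], acts[i], g)
    return g, len(acts)

def grad_checkpointed(ws, x0):
    """Keep every k-th layer input (k ~ sqrt(L)); recompute each segment on the way back."""
    L = len(ws)
    k = max(1, math.isqrt(L))
    checkpoints = {}
    x = x0
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = x              # forward pass stores checkpoints only
        x = layer(w, x)
    g, peak = 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, L)
        seg = [checkpoints[start]]          # recompute inputs of layers start..end-1
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))
        for i in reversed(range(start, end)):
            g = layer_backward(ws[i], seg[i - start], g)
    return g, peak

ws = np.linspace(0.5, 1.5, 16)
g_full, mem_full = grad_store_all(ws, 0.3)
g_ckpt, mem_ckpt = grad_checkpointed(ws, 0.3)
print(g_full, g_ckpt, mem_full, mem_ckpt)
```

Both functions return the same gradient, since recomputation repeats exactly the same arithmetic; with L = 16 the checkpointed version holds at most 8 values (4 checkpoints plus one 4-layer segment) instead of 17, at the price of recomputing most of the forward pass.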

Debugging Numerical Issues

Common Symptoms and Causes

| Symptom | Likely cause | Fix |
|---|---|---|
| NaN loss | Gradient explosion, log(0), 0/0 | Gradient clipping, log-sum-exp, check data |
| Loss stuck at a constant | Saturated activations, dead ReLU | Check initialization, use LeakyReLU |
| Loss oscillates wildly | Learning rate too high | Reduce LR, add warmup |
| Loss decreases then explodes | Gradient explosion at specific input | Gradient clipping, check for outliers |
| Very slow convergence | Vanishing gradients, tiny LR | Residual connections, increase LR |
| Inf values in weights | Unbounded optimization, no regularization | Add weight decay, clip gradients |

Debugging Checklist

  1. Check for NaN/Inf: Add assertions early, e.g. on torch.isfinite(loss).all(); torch.isnan(loss).any() alone misses Inf
  2. Monitor gradient norms: Log $\|\nabla \mathcal{L}\|$ per layer to detect explosion/vanishing
  3. Histogram weights and gradients: Look for distribution collapse (all near zero) or explosion
  4. Test with FP64: If training works in FP64 but fails in FP16/FP32, the problem is likely precision-related
  5. Reduce learning rate: Many “mysterious” training failures are simply learning rate too high
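Items 1 and 2 of the checklist can be combined into a small per-layer report. This is a NumPy sketch with a hypothetical helper, hypothetical layer names, and illustrative thresholds that would need tuning for a real model.

```python
import numpy as np

def gradient_report(named_grads, explode_threshold=1e3, vanish_threshold=1e-8):
    """Summarize per-layer gradient health.

    named_grads: dict mapping layer name -> gradient array.
    Returns a dict name -> (L2 norm, status), where status is
    'non-finite', 'exploding', 'vanishing', or 'ok'.
    """
    report = {}
    for name, g in named_grads.items():
        if not np.all(np.isfinite(g)):
            report[name] = (float("nan"), "non-finite")   # NaN or Inf present
            continue
        norm = float(np.linalg.norm(g))
        if norm > explode_threshold:
            status = "exploding"
        elif norm < vanish_threshold:
            status = "vanishing"
        else:
            status = "ok"
        report[name] = (norm, status)
    return report

grads = {
    "embed": np.array([0.1, -0.2]),
    "layer1": np.array([5e3, 1e2]),
    "layer2": np.array([1e-12, 0.0]),
    "head": np.array([np.nan, 1.0]),
}
for name, (norm, status) in gradient_report(grads).items():
    print(f"{name:>7}: norm={norm:.3g} [{status}]")
```

With PyTorch, the input dictionary could be built after loss.backward() as {name: p.grad.detach().cpu().numpy() for name, p in model.named_parameters() if p.grad is not None}.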

Key insight: Most numerical issues in practice are caused by one of three things: (1) overflow in exponentials (fix with log-sum-exp), (2) gradient explosion (fix with clipping), or (3) loss of precision in FP16 (fix with loss scaling or BF16). Recognizing these three patterns resolves a large share of training failures.

Why This Matters for ML

Numerical stability is what separates a paper’s algorithm from a working implementation:

  • Log-sum-exp is required for any computation involving softmax or log-probabilities — it is non-negotiable
  • Gradient clipping is standard practice for RNNs, transformers, and GANs — without it, training is fragile
  • Mixed precision (FP16/BF16) training can roughly double throughput and halve activation memory, which is essential for modern large models
  • Gradient checkpointing enables training models that would not fit in GPU memory otherwise
  • Understanding floating-point limitations explains many “mysterious” training failures and helps debug them systematically

Summary

  • Computers use finite-precision arithmetic — overflow, underflow, and cancellation are real threats
  • The log-sum-exp trick prevents overflow in softmax and cross-entropy — always use it
  • Gradient clipping (norm or value) prevents catastrophic updates from exploding gradients
  • Mixed precision (FP16/BF16 + FP32) can roughly double speed, with loss scaling keeping FP16 gradients from underflowing
  • BFloat16 matches FP32 range, reducing the need for loss scaling
  • Use log1p, expm1, and fused operations (F.cross_entropy with logits) to avoid cancellation
  • Gradient checkpointing trades compute for memory, enabling larger models
  • Monitor gradient norms and check for NaN/Inf early to catch problems before they cascade
  • Next: non-smooth optimization handles functions that are not differentiable everywhere

