
Chain Rule and Backpropagation


"Backpropagation is an algorithm for computing gradients efficiently in a computational graph. At its heart, it is nothing more than the chain rule of calculus applied repeatedly."

  • Goodfellow, Bengio & Courville, Deep Learning (2016)

Overview

The chain rule is the single most important theorem in applied mathematics for machine learning. Every time a neural network is trained - whether a two-layer MLP or a 70-billion parameter language model - the update step depends on computing gradients of a scalar loss with respect to millions or billions of parameters. That computation is backpropagation, and backpropagation is nothing more than the multivariate chain rule applied systematically to a computational graph.

This section develops the chain rule from first principles and then derives backpropagation rigorously. We prove every formula from scratch: the general Jacobian chain rule, the VJP (vector-Jacobian product) form that makes backprop efficient, the recurrence relation for the error signal $\boldsymbol{\delta}^{[l]}$, and the gradient formulae for every layer type that appears in modern transformers. We also analyse the pathologies - vanishing and exploding gradients - with mathematical precision, and derive the interventions that cure them: careful initialisation, residual connections, and normalisation layers.

The connection to automatic differentiation (05) is previewed but not developed here: this section establishes what is computed (the chain rule gradient) while 05 establishes how it is computed mechanically by an AD engine.

Prerequisites

Companion Notebooks

| Notebook | Description |
|---|---|
| theory.ipynb | Chain rule verification, backprop from scratch, all layer gradients, vanishing/exploding gradient simulation, checkpointing |
| exercises.ipynb | 10 graded exercises: chain rule through LoRA backward pass and BPTT |

Learning Objectives

After completing this section, you will be able to:

  1. State and prove the general chain rule $J_{f\circ g}(\mathbf{x}) = J_f(g(\mathbf{x}))\cdot J_g(\mathbf{x})$ via the Fréchet derivative
  2. Explain why the VJP form $\nabla_\mathbf{x}\mathcal{L} = J_g(\mathbf{x})^\top\nabla_{g(\mathbf{x})}\mathcal{L}$ is the fundamental equation of backpropagation
  3. Define a computational graph, perform forward and backward passes, handle gradient accumulation at branches
  4. Derive the backprop recurrence $\boldsymbol{\delta}^{[l]} = (W^{[l+1]})^\top\boldsymbol{\delta}^{[l+1]}\odot\sigma'(\mathbf{z}^{[l]})$ from first principles
  5. Derive gradients for linear layers, activations, softmax+CE (fused), LayerNorm, and scaled dot-product attention
  6. Analyse vanishing/exploding gradients and derive Xavier/He initialisation from variance propagation
  7. Explain residual connections as gradient highways via $J_{\text{block}} = I + J_F$
  8. Implement gradient checkpointing and explain the $O(\sqrt{L})$ memory / $O(L)$ extra compute tradeoff
  9. Differentiate through discrete operations using the straight-through estimator and REINFORCE
  10. Trace the full backward pass through a transformer layer (attention + MLP + LayerNorm + residuals)

Table of Contents

  1. Intuition
  2. The Multivariate Chain Rule - Full Theory
  3. Computation Graphs
  4. Backpropagation
  5. Gradient Derivations for Standard Layers
  6. Vanishing and Exploding Gradients
  7. Memory-Efficient Backpropagation
  8. Advanced Differentiation Topics
  9. Transformer Backpropagation
  10. Common Mistakes

1. Intuition

1.1 From Single-Variable to Multivariate Chain Rule

The single-variable chain rule says: if $y = f(u)$ and $u = g(x)$, then

$$\frac{dy}{dx} = \frac{dy}{du}\cdot\frac{du}{dx}$$

The intuition is that rates of change compose multiplicatively. If $g$ triples its input and $f$ doubles its input, then $f \circ g$ multiplies by six.

The multivariate generalisation replaces scalars with vectors and scalar derivatives with Jacobian matrices. If $\mathbf{y} = f(\mathbf{u})$ and $\mathbf{u} = g(\mathbf{x})$, then

$$J_{f\circ g}(\mathbf{x}) = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x})$$

The product is now matrix multiplication. This is not a different rule - it is the same rule, stated in the correct language for vector-valued functions. The single-variable rule is the special case $n = m = p = 1$ where Jacobians degenerate to scalars.

SCALAR CHAIN RULE vs JACOBIAN CHAIN RULE


  Scalar:    x -> g -> u -> f -> y        (every value lives in R)
             dy/dx = (dy/du)(du/dx)       [scalar multiplication]

  Vector:    x -> g -> u -> f -> y
             R^p       R^n       R^m
             J_{fog} = J_f * J_g          [matrix multiplication]
             (m x p) = (m x n) * (n x p)

  The dimensions work out exactly like matrix multiplication.
  The chain rule IS matrix multiplication for Jacobians.


What makes the multivariate version non-trivial is that $J_f$ must be evaluated at $g(\mathbf{x})$ - the output of the inner function - not at $\mathbf{x}$ itself. This point-dependence is where the local linear approximation lives: the Jacobian $J_f(g(\mathbf{x}))$ is the best linear approximation to $f$ at the specific point $g(\mathbf{x})$, and $J_g(\mathbf{x})$ is the best linear approximation to $g$ at $\mathbf{x}$.

1.2 Backpropagation as Iterated Chain Rule

A deep neural network is a long composition of functions:

$$\mathcal{L} = \ell\bigl(f_L(\cdots f_2(f_1(\mathbf{x}))\cdots)\bigr)$$

where each $f_l$ is a layer (linear + activation), $\ell$ is the loss function, and $\mathcal{L}$ is a scalar. Computing $\nabla_{\mathbf{w}^{[l]}}\mathcal{L}$ - the gradient of the loss with respect to layer $l$'s parameters - requires applying the chain rule through every layer from $l$ to $L$.

The chain rule gives:

$$\nabla_{\mathbf{w}^{[l]}}\mathcal{L} = J_{f_l,\mathbf{w}}(\mathbf{a}^{[l-1]})^\top \cdot \boldsymbol{\delta}^{[l]}$$

where $\boldsymbol{\delta}^{[l]} = \nabla_{\mathbf{z}^{[l]}}\mathcal{L}$ is the error signal at layer $l$, and it satisfies the backpropagation recurrence:

$$\boldsymbol{\delta}^{[l]} = J_{f_{l+1},\mathbf{a}}(\mathbf{z}^{[l]})^\top \cdot \boldsymbol{\delta}^{[l+1]}$$

This recurrence propagates the error signal backward from layer $L$ to layer 1 - hence "backpropagation." At each step, we multiply by the transposed Jacobian of the next layer. The entire algorithm is:

  1. Forward pass: compute and store $\mathbf{z}^{[l]}$, $\mathbf{a}^{[l]}$ for $l = 1, \ldots, L$
  2. Initialise: $\boldsymbol{\delta}^{[L]} = \nabla_{\mathbf{z}^{[L]}}\mathcal{L}$
  3. Backward pass: compute $\boldsymbol{\delta}^{[l]}$ for $l = L-1, \ldots, 1$ using the recurrence
  4. Gradients: extract $\nabla_{W^{[l]}}\mathcal{L}$ from $\boldsymbol{\delta}^{[l]}$ and $\mathbf{a}^{[l-1]}$

Backpropagation is not a fundamentally different concept from the chain rule. It is the chain rule, applied efficiently by sharing intermediate computations (the error signals $\boldsymbol{\delta}^{[l]}$) across all parameters in a layer.

1.3 Historical Context

| Year | Contributor | Development |
|---|---|---|
| 1676 | Leibniz | Differential calculus; first statement of the single-variable chain rule |
| 1755 | Euler | Extended to multiple variables |
| 1960 | Kelley | Gradient computation for optimal control (independent discovery of the backprop concept) |
| 1970 | Linnainmaa | First complete description of reverse-mode automatic differentiation for computing gradients |
| 1974 | Werbos | First application to neural networks in his PhD thesis |
| 1986 | Rumelhart, Hinton, Williams | Popularised backpropagation in "Learning representations by back-propagating errors" - the paper that launched the neural network revolution |
| 1989 | LeCun | Applied backprop to convolutional networks for handwritten digit recognition |
| 2012 | Krizhevsky, Sutskever, Hinton | AlexNet demonstrated GPU-accelerated backprop at scale - kicked off the deep learning era |
| 2015 | Google Brain, Facebook AI | PyTorch/TensorFlow: automatic differentiation engines that compute backprop automatically |
| 2017 | Vaswani et al. | Transformer: backprop through multi-head attention; the architecture underlying GPT, BERT, LLaMA |
| 2021 | Hu et al. (LoRA) | Parameter-efficient fine-tuning by limiting gradient flow to low-rank subspaces |
| 2022 | Dao et al. (FlashAttention) | Recompute activations during backward to avoid materialising the $N\times N$ attention matrix |

1.4 Why Backprop Defines Modern AI

Every large language model, image classifier, and diffusion model trained today relies on backpropagation for every gradient update. The scale is staggering: training GPT-4 reportedly required $\sim 10^{25}$ floating-point operations, the vast majority of which are forward and backward passes through the transformer network.

Backprop enables gradient-based learning at scale because its cost is a small constant multiple of the forward pass, $O(\text{FLOPs}(f))$, independent of the number of parameters $n$. Alternative approaches (finite differences, evolution strategies, zeroth-order methods) scale like $O(n \cdot \text{FLOPs}(f))$ and are orders of magnitude more expensive.

Three properties make backprop indispensable:

  1. Efficiency: One backward pass computes $\nabla_{\boldsymbol{\theta}}\mathcal{L}$ for all $|\boldsymbol{\theta}| \sim 10^{10}$ parameters simultaneously. Finite differences would need $10^{10}$ forward passes.

  2. Exactness: Unlike finite differences, backprop computes the exact gradient (up to floating-point precision), not an approximation.

  3. Composability: Any differentiable function composed of differentiable primitives has an automatically computable gradient. This is why PyTorch/JAX can differentiate arbitrary Python code that uses differentiable operations.

For AI in 2026: The gradient is the workhorse of every training algorithm: SGD, Adam, AdaGrad, Muon, SOAP - all are gradient-based. Fine-tuning (LoRA, QLoRA, DoRA), RLHF (PPO, DPO, GRPO), distillation, and continual learning all depend on backprop. Even methods that appear gradient-free (evolutionary strategies, black-box optimisation) are often used because they approximate the gradient in settings where backprop is unavailable (non-differentiable objectives, external APIs).


2. The Multivariate Chain Rule - Full Theory

2.1 The General Chain Rule - Proof

We prove the chain rule using the Fréchet derivative from 02. Recall:

Definition. $f: U \subseteq \mathbb{R}^n \to \mathbb{R}^m$ is Fréchet differentiable at $\mathbf{x}$ if there exists a linear map $L_\mathbf{x}: \mathbb{R}^n \to \mathbb{R}^m$ such that

$$\lim_{\|\boldsymbol{\delta}\|\to 0}\frac{\|f(\mathbf{x}+\boldsymbol{\delta}) - f(\mathbf{x}) - L_\mathbf{x}\boldsymbol{\delta}\|}{\|\boldsymbol{\delta}\|} = 0$$

The matrix of $L_\mathbf{x}$ is the Jacobian $J_f(\mathbf{x})$.

Theorem (Chain Rule). Let $g: \mathbb{R}^p \to \mathbb{R}^n$ be Fréchet differentiable at $\mathbf{x}$, and $f: \mathbb{R}^n \to \mathbb{R}^m$ be Fréchet differentiable at $g(\mathbf{x})$. Then $h = f \circ g: \mathbb{R}^p \to \mathbb{R}^m$ is Fréchet differentiable at $\mathbf{x}$ and

$$J_h(\mathbf{x}) = J_f(g(\mathbf{x}))\cdot J_g(\mathbf{x})$$

Proof. Let $\mathbf{u}_0 = g(\mathbf{x})$. We need to show that $J_f(\mathbf{u}_0)J_g(\mathbf{x})$ is the Fréchet derivative of $h$ at $\mathbf{x}$. Write:

$$h(\mathbf{x}+\boldsymbol{\delta}) - h(\mathbf{x}) = f(g(\mathbf{x}+\boldsymbol{\delta})) - f(g(\mathbf{x}))$$

Let $\boldsymbol{\eta} = g(\mathbf{x}+\boldsymbol{\delta}) - g(\mathbf{x})$. Since $g$ is Fréchet differentiable:

$$\boldsymbol{\eta} = J_g(\mathbf{x})\boldsymbol{\delta} + \mathbf{r}_g(\boldsymbol{\delta}), \quad \frac{\|\mathbf{r}_g(\boldsymbol{\delta})\|}{\|\boldsymbol{\delta}\|} \to 0 \text{ as } \boldsymbol{\delta} \to \mathbf{0}$$

Now apply Fréchet differentiability of $f$ at $\mathbf{u}_0$:

$$f(\mathbf{u}_0 + \boldsymbol{\eta}) - f(\mathbf{u}_0) = J_f(\mathbf{u}_0)\boldsymbol{\eta} + \mathbf{r}_f(\boldsymbol{\eta}), \quad \frac{\|\mathbf{r}_f(\boldsymbol{\eta})\|}{\|\boldsymbol{\eta}\|} \to 0 \text{ as } \boldsymbol{\eta} \to \mathbf{0}$$

Substituting:

$$h(\mathbf{x}+\boldsymbol{\delta}) - h(\mathbf{x}) = J_f(\mathbf{u}_0)\bigl(J_g(\mathbf{x})\boldsymbol{\delta} + \mathbf{r}_g(\boldsymbol{\delta})\bigr) + \mathbf{r}_f(\boldsymbol{\eta}) = J_f(\mathbf{u}_0)J_g(\mathbf{x})\boldsymbol{\delta} + \underbrace{J_f(\mathbf{u}_0)\mathbf{r}_g(\boldsymbol{\delta}) + \mathbf{r}_f(\boldsymbol{\eta})}_{\text{remainder}}$$

We show the remainder is $o(\|\boldsymbol{\delta}\|)$:

  • First term: $\|J_f(\mathbf{u}_0)\mathbf{r}_g(\boldsymbol{\delta})\| \leq \|J_f(\mathbf{u}_0)\|\cdot\|\mathbf{r}_g(\boldsymbol{\delta})\| = o(\|\boldsymbol{\delta}\|)$.
  • Second term: Since $\|\boldsymbol{\eta}\| \leq \|J_g\|\|\boldsymbol{\delta}\| + \|\mathbf{r}_g\| = O(\|\boldsymbol{\delta}\|)$, we have $\|\mathbf{r}_f(\boldsymbol{\eta})\| = o(\|\boldsymbol{\eta}\|) = o(\|\boldsymbol{\delta}\|)$.

Therefore $J_h(\mathbf{x}) = J_f(g(\mathbf{x}))J_g(\mathbf{x})$. $\square$

When the chain rule fails. The chain rule requires both $g$ at $\mathbf{x}$ and $f$ at $g(\mathbf{x})$ to be Fréchet differentiable. If either fails - for example at a ReLU kink where $\mathbf{z}^{[l]} = 0$ - the classical chain rule does not apply. In practice, these measure-zero sets are handled by choosing a subgradient (any element of the Clarke subdifferential), which is what deep learning frameworks do automatically.

2.2 Three Cases in Increasing Generality

Case 1: Scalar composition $h: \mathbb{R} \to \mathbb{R}$. $h(x) = f(g(x))$ where $f, g: \mathbb{R} \to \mathbb{R}$. Jacobians are $1\times 1$ = scalars, so $h'(x) = f'(g(x)) \cdot g'(x)$.

Case 2: Scalar loss of a vector function. $h: \mathbb{R}^n \to \mathbb{R}$, $h(\mathbf{x}) = f(g(\mathbf{x}))$ where $g: \mathbb{R}^n \to \mathbb{R}^m$ and $f: \mathbb{R}^m \to \mathbb{R}$. Jacobians: $J_g \in \mathbb{R}^{m\times n}$ and $J_f \in \mathbb{R}^{1\times m}$ (a row vector). So:

$$J_h = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x}) \in \mathbb{R}^{1\times n}$$

Taking the transpose: $\nabla_\mathbf{x} h = J_g(\mathbf{x})^\top \nabla_{g(\mathbf{x})} f$ - the gradient of $h$ with respect to $\mathbf{x}$ is the transposed Jacobian of $g$ times the gradient of $f$. This is the VJP equation, the core of backprop.

Case 3: Vector composition $h: \mathbb{R}^p \to \mathbb{R}^m$. The most general case; Jacobians are full matrices and the chain rule is full matrix multiplication:

$$J_h(\mathbf{x}) = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x}) \in \mathbb{R}^{m\times p}$$

The dimensions verify: $(m\times p) = (m\times n)(n\times p)$. The "inner dimension" $n$ (the dimension of the intermediate space $\mathbb{R}^n$) cancels in the product, exactly as in matrix multiplication.
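This identity is easy to check numerically. Below is a minimal NumPy sketch (the particular maps $g: \mathbb{R}^3 \to \mathbb{R}^4$ and $f: \mathbb{R}^4 \to \mathbb{R}^2$ are arbitrary illustrations, not taken from the notebooks): it builds finite-difference Jacobians of $f$, $g$, and $f \circ g$ and confirms that the product of the first two matches the third.

```python
import numpy as np

def g(x):            # g: R^3 -> R^4
    return np.array([x[0] * x[1], np.sin(x[2]), x[0] + x[2] ** 2, x[1] * x[2]])

def f(u):            # f: R^4 -> R^2
    return np.array([u[0] ** 2 + u[3], np.tanh(u[1]) * u[2]])

def jacobian(fn, x, eps=1e-6):
    """Numerical Jacobian of fn at x via central differences."""
    y = fn(x)
    J = np.zeros((y.size, x.size))
    for j in range(x.size):
        d = np.zeros_like(x); d[j] = eps
        J[:, j] = (fn(x + d) - fn(x - d)) / (2 * eps)
    return J

x = np.array([0.3, -1.2, 0.7])
J_chain  = jacobian(f, g(x)) @ jacobian(g, x)        # J_f(g(x)) . J_g(x), shape (2, 3)
J_direct = jacobian(lambda t: f(g(t)), x)            # Jacobian of the composition
print(np.max(np.abs(J_chain - J_direct)))            # ~1e-9: the chain rule holds
```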

2.3 The VJP Form - Foundation of Backprop

Definition (VJP). For $g: \mathbb{R}^n \to \mathbb{R}^m$ and a "cotangent" vector $\mathbf{u} \in \mathbb{R}^m$, the vector-Jacobian product is:

$$\text{VJP}(g, \mathbf{x}, \mathbf{u}) = J_g(\mathbf{x})^\top \mathbf{u} \in \mathbb{R}^n$$

Why this is the right primitive for backprop. For a scalar loss $\mathcal{L}: \mathbb{R}^m \to \mathbb{R}$ composed with $g: \mathbb{R}^n \to \mathbb{R}^m$:

$$\nabla_\mathbf{x}(\mathcal{L} \circ g) = J_g(\mathbf{x})^\top \nabla_{g(\mathbf{x})}\mathcal{L} = \text{VJP}(g, \mathbf{x}, \nabla_{g(\mathbf{x})}\mathcal{L})$$

The gradient of the composed function with respect to the input is the VJP of the inner function, with the cotangent being the gradient of the outer function.

The backprop recursion is a chain of VJPs. For $\mathcal{L} = \ell \circ f_L \circ \cdots \circ f_1$:

$$\nabla_{\mathbf{a}^{[l]}}\mathcal{L} = \text{VJP}(f_{l+1}, \mathbf{a}^{[l]}, \nabla_{\mathbf{a}^{[l+1]}}\mathcal{L}) = J_{f_{l+1}}(\mathbf{a}^{[l]})^\top \nabla_{\mathbf{a}^{[l+1]}}\mathcal{L}$$

Starting from $\nabla_{\mathbf{a}^{[L]}}\mathcal{L}$ and applying VJPs from right to left computes all intermediate gradients.

Cost comparison. Computing $\nabla_\mathbf{x}\mathcal{L}$ for $\mathcal{L}: \mathbb{R}^n \to \mathbb{R}$:

  • JVP (forward mode): requires $n$ passes (one per input dimension). Cost: $n \times O(\text{FLOPs}(g))$.
  • VJP (reverse mode): requires $1$ pass. Cost: $O(\text{FLOPs}(g))$.

For $n \sim 10^{10}$ parameters, reverse mode is $10^{10}$ times cheaper. This asymmetry is why all gradient-based deep learning uses reverse mode (backprop).
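The correspondence is easy to see in PyTorch, where a single backward call is exactly one VJP with the cotangent $\nabla_{g(\mathbf{x})}\mathcal{L}$. A minimal sketch (the map $g$ and the dimensions are arbitrary illustrations):

```python
import torch

n = 1000
x = torch.randn(n, requires_grad=True)
W = torch.randn(50, n)

def g(x):                                  # g: R^n -> R^50
    return torch.tanh(W @ x)

loss = g(x).pow(2).sum()                   # scalar loss L = ||g(x)||^2
(grad_x,) = torch.autograd.grad(loss, x)   # reverse mode: ONE pass for all n entries

# The same gradient written explicitly as a VJP with cotangent u = dL/dg = 2 g(x)
out = g(x)
(vjp,) = torch.autograd.grad(out, x, grad_outputs=2 * out.detach())
print(torch.allclose(grad_x, vjp, atol=1e-5))   # True
```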

2.4 Long Chains and Telescoping Products

For a depth-$L$ network $h = f_L \circ \cdots \circ f_1$, the Jacobian of $h$ with respect to $\mathbf{x}$ is:

$$J_h(\mathbf{x}) = J_{f_L}(\mathbf{a}^{[L-1]}) \cdot J_{f_{L-1}}(\mathbf{a}^{[L-2]}) \cdots J_{f_1}(\mathbf{x})$$

This is a product of $L$ matrices. The spectral norm of the product satisfies:

$$\|J_h\|_2 \leq \prod_{l=1}^L \|J_{f_l}\|_2$$

If each $\|J_{f_l}\|_2 = \rho$, then $\|J_h\|_2 \leq \rho^L$. For $\rho < 1$, the gradient vanishes exponentially; for $\rho > 1$, it explodes. This is the mathematical source of the vanishing/exploding gradient problem (6).

Efficient computation: reverse order. In the forward direction, we compute $\mathbf{a}^{[1]}, \ldots, \mathbf{a}^{[L]}$ left to right. In the backward direction, we compute the error signals $\boldsymbol{\delta}^{[L]}, \boldsymbol{\delta}^{[L-1]}, \ldots, \boldsymbol{\delta}^{[1]}$ right to left, reusing stored activations. The key observation: at step $l$, we only need $\boldsymbol{\delta}^{[l+1]}$ and the stored activation $\mathbf{a}^{[l]}$ (or $\mathbf{z}^{[l]}$) - we do not need to recompute from scratch.

2.5 Differentiating Through Discrete Operations

Some operations in neural networks are discontinuous or discrete: argmax (in beam search), rounding/quantisation (in QAT), sampling (in VAEs and RL). The chain rule does not directly apply.

Straight-Through Estimator (STE). For a quantisation function $q(x) = \lfloor x \rceil$ (round to nearest integer), the derivative is $q'(x) = 0$ almost everywhere, giving zero gradient. The STE replaces the "true" zero gradient with 1 during the backward pass:

$$\frac{\partial \mathcal{L}}{\partial x} \approx \frac{\partial \mathcal{L}}{\partial q(x)} \quad \text{(pass gradient through as if $q$ were the identity)}$$

In code: y = round(x).detach() + x - x.detach() - adds $x$ in the forward pass (where it cancels) but contributes its gradient in the backward pass. STE is used in VQ-VAE, binary neural networks, and quantisation-aware training.
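A minimal PyTorch sketch of the detach trick (the values are arbitrary): the forward output is the rounded tensor, but the gradient that reaches x is that of the identity.

```python
import torch

x = torch.tensor([0.2, 0.7, -1.4], requires_grad=True)

# Straight-through rounding: forward value is round(x), backward gradient is 1
y = torch.round(x).detach() + x - x.detach()

print(y)                 # tensor([ 0.,  1., -1.], ...) - the quantised values
y.sum().backward()
print(x.grad)            # tensor([1., 1., 1.]) - gradients pass straight through
```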

REINFORCE (score function estimator). For a stochastic node $y \sim p_\theta(y|x)$ and loss $\mathcal{L}(y)$, the gradient of $\mathbb{E}[\mathcal{L}(y)]$ with respect to $\theta$ is:

$$\nabla_\theta \mathbb{E}_{y \sim p_\theta}[\mathcal{L}(y)] = \mathbb{E}_{y \sim p_\theta}[\mathcal{L}(y) \nabla_\theta \log p_\theta(y|x)]$$

This allows gradient estimation without differentiating through the sampling step. Used in RLHF (PPO, GRPO) and variational inference. High variance; mitigated by baselines.


3. Computation Graphs

3.1 Formal DAG Definition

A computation graph is a directed acyclic graph $G = (V, E)$ encoding how scalar or tensor quantities depend on one another.

Nodes $V$ partition into three types:

| Type | Symbol | Role |
|---|---|---|
| Input nodes | $v_1,\ldots,v_n$ | Hold model inputs and parameters; no incoming edges |
| Intermediate nodes | $v_{n+1},\ldots,v_{N-1}$ | Hold computed activations; receive edges from their operands |
| Output node | $v_N$ | Holds the scalar loss $\mathcal{L}$; required to be scalar for standard backprop |

Edges $(u, v) \in E$ encode data dependency: $v = \phi(u_1, \ldots, u_k)$ for some primitive $\phi$. Each edge carries an implicit local Jacobian $\partial v / \partial u_i$.

Primitive operations are the atomic building blocks with known local gradients:

PRIMITIVE OPERATIONS AND THEIR LOCAL GRADIENTS


  Operation          Forward           Local gradient (wrt input i)
  
  z = x + y          z = x + y         partialz/partialx = 1,  partialz/partialy = 1
  z = x * y          z = xy            partialz/partialx = y,  partialz/partialy = x
  z = exp(x)         z = e^x           partialz/partialx = e^x
  z = log(x)         z = ln x          partialz/partialx = 1/x
  z = relu(x)        z = max(0,x)      partialz/partialx = 1[x>0]  (a.e.)
  z = W x + b        Wx+b              partialz/partialW = x (as outer),  partialz/partialx = W
  z = softmax(x)     e^{x_i}/Sigma_j e^{x_j}       diag(p) - p p^T  (see 02)
  

  Every deep learning framework maintains a lookup table of these
  primitives together with their vjp implementations.


Topological ordering - a linear ordering $\pi$ of $V$ such that for every edge $(u,v) \in E$, $u$ appears before $v$ in $\pi$. A topological order exists iff $G$ is acyclic (Kahn's algorithm, 1962). Both the forward pass and the backward pass respect topological order (the latter in reverse).

For AI: Every modern deep learning framework (PyTorch, JAX, TensorFlow) represents a neural network as a computation graph. PyTorch builds the graph dynamically during the forward pass via the autograd tape; JAX traces the graph statically via XLA compilation.

3.2 Forward Pass - Value Propagation

The forward pass evaluates all node values in topological order, caching intermediates required by the backward pass.

Algorithm (Forward Pass):

Input:  graph G = (V, E),  input values {x_1,...,x_n}
Output: loss value v_N,    cache of intermediates

  For v in topological_order(G):
    if v is an input node:
      cache[v] = x_v  (given)
    else:
      cache[v] = phi_v(cache[u_1], ..., cache[u_k])
                 where u_1,...,u_k = parents(v)

  return cache[v_N]

Memory cost of a naive forward pass: Caching all intermediates for backprop costs $O(N)$ memory where $N$ is the number of nodes. For a transformer with $L=96$ layers and activations of size $(B, T, d)$, this is approximately:

$$\text{Memory} = L \cdot B \cdot T \cdot d \cdot \text{sizeof(float16)} \approx 96 \times 32 \times 4096 \times 8192 \times 2 \approx 200\,\text{GB}$$

This is why gradient checkpointing (7) is essential for large models.

What gets cached? A memory-optimal forward pass only caches values that appear in at least one local gradient formula. For a linear layer $\mathbf{z} = W\mathbf{x} + \mathbf{b}$, the backward needs $\mathbf{x}$ (to compute $\nabla_W$) but not $\mathbf{z}$ (already accumulated into the output).

3.3 Backward Pass - Gradient Accumulation

The backward pass evaluates adjoint values $\bar{v} = \partial \mathcal{L}/\partial v$ for every node, in reverse topological order.

Define the adjoint of node $v$ as:

$$\bar{v} \;:=\; \frac{\partial \mathcal{L}}{\partial v}$$

where we treat $v$ as a scalar intermediate (extending to tensors componentwise).

Initialisation: $\bar{v}_N = 1$ (the loss node).

Backward recurrence: For a node $v$ with children (successors) $c_1, \ldots, c_m$ - nodes that depend on $v$:

$$\bar{v} = \sum_{j=1}^{m} \bar{c}_j \cdot \frac{\partial c_j}{\partial v}$$

This is exactly the chain rule applied in reverse.

Algorithm (Backward Pass):

Input:  graph G,  cache from forward pass
Output: adjoint[v] = partialL/partialv for all v in V

  adjoint[v_N] <- 1
  For v in reverse_topological_order(G):
    For each parent u of v:
      adjoint[u] += adjoint[v] * partialv/partialu(cache)
                    i.e.  adjoint[u] += local_vjp(v, u, adjoint[v])

  return {adjoint[u] : u is a parameter node}

The key observation: each edge $(u, v)$ requires only:

  1. The cached forward value at $u$ (for the local gradient formula)
  2. The downstream adjoint $\bar{v}$ (for the VJP multiplication)

3.4 Gradient Accumulation at Branching Nodes

A fan-out node $u$ has multiple children $c_1, \ldots, c_m$. The correct gradient is the sum of contributions:

$$\bar{u} = \sum_{j=1}^{m} \bar{c}_j \cdot \frac{\partial c_j}{\partial u}$$

Proof: By the total derivative,

$$\frac{\partial \mathcal{L}}{\partial u} = \sum_{j=1}^{m} \frac{\partial \mathcal{L}}{\partial c_j} \cdot \frac{\partial c_j}{\partial u} = \sum_{j=1}^{m} \bar{c}_j \cdot \frac{\partial c_j}{\partial u} \qquad \square$$

Example - residual connection:

RESIDUAL BRANCH: u feeds into both F(u) and the skip path


          u
         / \
        /   \
      F(u)   \  <- identity skip
        \   /
         \ /
    z = F(u) + u

  Forward:    z = F(u) + u
  Backward:   u_bar = z_bar * partialF(u)/partialu  +  z_bar * 1
                    = J_F(u)^T z_bar               +  z_bar

  The identity skip guarantees a gradient highway:
  even if J_F(u) ~= 0 (saturated layer), z_bar flows back unchanged.


This is the deep reason residual networks (He et al., 2016) solved the vanishing gradient problem: the skip connection contributes a constant identity term to the backward accumulation, so $\bar{u}$ always receives the full upstream gradient $\bar{z}$ no matter how small $J_F$ becomes.
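The accumulation rule can be verified directly with autograd; the sketch below (an arbitrary tanh branch $F$ standing in for a residual sublayer) checks that the gradient reaching a fan-out node is $J_F(\mathbf{u})^\top\bar{\mathbf{z}} + \bar{\mathbf{z}}$.

```python
import torch

u = torch.randn(4, requires_grad=True)
A = torch.randn(4, 4)

def F(u):
    return torch.tanh(A @ u)               # the residual branch

z = F(u) + u                               # u fans out into F(u) and the skip path
z_bar = torch.randn(4)                     # arbitrary upstream gradient
z.backward(z_bar)

# Autograd sums the two contributions: u_bar = J_F(u)^T z_bar + z_bar
J_F = torch.autograd.functional.jacobian(F, u)
u_bar_manual = J_F.T @ z_bar + z_bar
print(torch.allclose(u.grad, u_bar_manual, atol=1e-5))   # True
```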

3.5 Dynamic vs Static Graphs

Two design philosophies produce different tradeoffs:

DYNAMIC GRAPHS (PyTorch eager mode)    STATIC GRAPHS (JAX jit / TF graph)
-----------------------------------    ----------------------------------

Graph built anew each forward pass     Graph compiled once, reused

+ Natural Python control flow          + XLA/CUDA fusion, kernel merging
+ Easy debugging (print anywhere)      + Memory-optimal buffer allocation
+ Variable-length sequences trivial    + Can export/serve without Python
- Graph construction overhead          - Tracing must handle all branches
- Less compiler optimisation           - Python side-effects invisible

Examples: PyTorch, early Chainer       Examples: JAX jit, TF2 tf.function,
                                        ONNX Runtime, TensorRT

For transformers: Most production LLM training uses torch.compile (PyTorch 2.0+) which bridges the two: eager-mode graph construction with TorchDynamo tracing and inductor backend compilation, recovering ~30-50% throughput from kernel fusion.


4. Backpropagation

4.1 Network Notation

Consider a feedforward neural network with LL layers. Define:

| Symbol | Meaning |
|---|---|
| $\mathbf{a}^{[0]} = \mathbf{x}$ | Input vector, dimension $n_0$ |
| $W^{[l]} \in \mathbb{R}^{n_l \times n_{l-1}}$ | Weight matrix for layer $l$ |
| $\mathbf{b}^{[l]} \in \mathbb{R}^{n_l}$ | Bias vector for layer $l$ |
| $\mathbf{z}^{[l]} = W^{[l]}\mathbf{a}^{[l-1]} + \mathbf{b}^{[l]}$ | Pre-activation (linear combination) |
| $\mathbf{a}^{[l]} = \sigma(\mathbf{z}^{[l]})$ | Post-activation (elementwise) |
| $\hat{\mathbf{y}} = \mathbf{a}^{[L]}$ | Network output |
| $\mathcal{L}(\hat{\mathbf{y}}, \mathbf{y})$ | Scalar loss |

The forward pass computes $\mathbf{z}^{[l]}$ and $\mathbf{a}^{[l]}$ for $l = 1, \ldots, L$.

4.2 Forward Equations

$$\mathbf{z}^{[l]} = W^{[l]}\mathbf{a}^{[l-1]} + \mathbf{b}^{[l]}, \qquad l = 1,\ldots,L$$
$$\mathbf{a}^{[l]} = \sigma^{[l]}(\mathbf{z}^{[l]}), \qquad l = 1,\ldots,L$$
$$\hat{\mathbf{y}} = \mathbf{a}^{[L]}$$

Cache for backward: $\{\mathbf{z}^{[l]}, \mathbf{a}^{[l-1]}\}_{l=1}^{L}$.

4.3 Output Layer Gradient

For cross-entropy loss with softmax output, the output gradient has the celebrated clean form (derived in 5.3):

$$\boldsymbol{\delta}^{[L]} := \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[L]}} = \hat{\mathbf{y}} - \mathbf{y}$$

where $\mathbf{y}$ is the one-hot label. This combines the softmax Jacobian with the cross-entropy gradient into a single elegant expression.

For MSE loss ($\mathcal{L} = \tfrac{1}{2}\|\hat{\mathbf{y}} - \mathbf{y}\|^2$) with linear output:

$$\boldsymbol{\delta}^{[L]} = \hat{\mathbf{y}} - \mathbf{y}$$

(Same form, different derivation - a useful coincidence that makes implementation uniform.)

4.4 Backpropagation Recurrence - Proof

Define the error signal:

$$\boldsymbol{\delta}^{[l]} \;:=\; \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[l]}} \in \mathbb{R}^{n_l}$$

Theorem (Backpropagation Recurrence):

$$\boldsymbol{\delta}^{[l]} = \left(W^{[l+1]}\right)^\top \boldsymbol{\delta}^{[l+1]} \odot \sigma'^{[l]}(\mathbf{z}^{[l]})$$

Proof: Apply the chain rule from $\mathcal{L}$ to $\mathbf{z}^{[l]}$ via $\mathbf{a}^{[l]}$ and $\mathbf{z}^{[l+1]}$:

$$\boldsymbol{\delta}^{[l]} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[l]}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[l+1]}} \cdot \frac{\partial \mathbf{z}^{[l+1]}}{\partial \mathbf{z}^{[l]}}$$

Step 1: $\partial \mathcal{L}/\partial \mathbf{z}^{[l+1]} = (\boldsymbol{\delta}^{[l+1]})^\top$.

Step 2: $\mathbf{z}^{[l+1]} = W^{[l+1]}\mathbf{a}^{[l]} + \mathbf{b}^{[l+1]} = W^{[l+1]}\sigma(\mathbf{z}^{[l]}) + \mathbf{b}^{[l+1]}$.

$$\frac{\partial z^{[l+1]}_i}{\partial z^{[l]}_j} = W^{[l+1]}_{ij} \cdot \sigma'(z^{[l]}_j)$$

In matrix form: $J = W^{[l+1]} \operatorname{diag}(\sigma'(\mathbf{z}^{[l]}))$.

Step 3: Multiply by $(\boldsymbol{\delta}^{[l+1]})^\top$ and transpose to get the column vector $\boldsymbol{\delta}^{[l]}$:

$$\boldsymbol{\delta}^{[l]} = J^\top \boldsymbol{\delta}^{[l+1]} = \operatorname{diag}(\sigma'(\mathbf{z}^{[l]})) \left(W^{[l+1]}\right)^\top \boldsymbol{\delta}^{[l+1]} = \left(W^{[l+1]}\right)^\top \boldsymbol{\delta}^{[l+1]} \odot \sigma'(\mathbf{z}^{[l]}) \qquad \square$$

The $\odot$ (elementwise) product arises because $\sigma$ is applied elementwise - its Jacobian is diagonal.

4.5 Weight and Bias Gradients

Once error signals $\boldsymbol{\delta}^{[l]}$ are computed, parameter gradients follow immediately:

$$\frac{\partial \mathcal{L}}{\partial W^{[l]}} = \boldsymbol{\delta}^{[l]} (\mathbf{a}^{[l-1]})^\top \in \mathbb{R}^{n_l \times n_{l-1}}$$
$$\frac{\partial \mathcal{L}}{\partial \mathbf{b}^{[l]}} = \boldsymbol{\delta}^{[l]} \in \mathbb{R}^{n_l}$$

Derivation of weight gradient:

$$\frac{\partial \mathcal{L}}{\partial W^{[l]}_{ij}} = \frac{\partial \mathcal{L}}{\partial z^{[l]}_i} \cdot \frac{\partial z^{[l]}_i}{\partial W^{[l]}_{ij}} = \delta^{[l]}_i \cdot a^{[l-1]}_j$$

Collecting over all $i,j$: $\nabla_{W^{[l]}} \mathcal{L} = \boldsymbol{\delta}^{[l]} (\mathbf{a}^{[l-1]})^\top$.

This is an outer product - the gradient is rank-1 for a single sample. For a batch of $B$ samples it averages to higher rank.

4.6 Batched Backpropagation

With a mini-batch $\{(\mathbf{x}^{(b)}, \mathbf{y}^{(b)})\}_{b=1}^{B}$, stack inputs into a matrix $X \in \mathbb{R}^{n_0 \times B}$.

The forward pass becomes:

$$Z^{[l]} = W^{[l]} A^{[l-1]} + \mathbf{b}^{[l]} \mathbf{1}^\top, \quad A^{[l]} = \sigma(Z^{[l]})$$

where $A^{[l]} \in \mathbb{R}^{n_l \times B}$.

The backward pass produces $\Delta^{[l]} \in \mathbb{R}^{n_l \times B}$ (error signals for all samples simultaneously).

Weight gradient for the batch:

$$\frac{\partial \mathcal{L}_\text{batch}}{\partial W^{[l]}} = \frac{1}{B} \Delta^{[l]} (A^{[l-1]})^\top$$

This is a single matrix multiplication, making batched backprop efficient on GPUs which excel at large GEMM (general matrix multiplication) operations.
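The equations of 4.2-4.6 fit in a few lines of NumPy. The sketch below (a two-layer ReLU network with MSE loss and arbitrary dimensions, samples stored as columns) runs the forward pass, the $\boldsymbol{\delta}$ recurrence, and the outer-product weight gradients, then checks one entry of $\nabla_{W^{[1]}}\mathcal{L}$ against a finite difference.

```python
import numpy as np
rng = np.random.default_rng(0)

n0, n1, n2, B = 5, 8, 3, 16
W1, b1 = rng.normal(0, 0.5, (n1, n0)), np.zeros((n1, 1))
W2, b2 = rng.normal(0, 0.5, (n2, n1)), np.zeros((n2, 1))
X, Y = rng.normal(size=(n0, B)), rng.normal(size=(n2, B))

# Forward pass (cache Z1, A1 for the backward pass)
Z1 = W1 @ X + b1
A1 = np.maximum(Z1, 0)                          # ReLU
Z2 = W2 @ A1 + b2                               # linear output
loss = 0.5 * np.mean(np.sum((Z2 - Y) ** 2, axis=0))

# Backward pass: delta recurrence, then outer-product weight gradients
Delta2 = (Z2 - Y) / B                           # dL/dZ2 for the batch-mean loss
dW2, db2 = Delta2 @ A1.T, Delta2.sum(axis=1, keepdims=True)
Delta1 = (W2.T @ Delta2) * (Z1 > 0)             # (W2^T delta2) ⊙ relu'(Z1)
dW1, db1 = Delta1 @ X.T, Delta1.sum(axis=1, keepdims=True)

# Finite-difference check of one weight gradient entry
eps, (i, j) = 1e-5, (2, 3)
W1p = W1.copy(); W1p[i, j] += eps
loss_p = 0.5 * np.mean(np.sum((W2 @ np.maximum(W1p @ X + b1, 0) + b2 - Y) ** 2, axis=0))
print(abs((loss_p - loss) / eps - dW1[i, j]))   # ~1e-5: analytic and numeric agree
```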


5. Gradient Derivations for Standard Layers

5.1 Linear Layer

Forward: $\mathbf{z} = W\mathbf{x} + \mathbf{b}$, where $W \in \mathbb{R}^{m \times n}$, $\mathbf{x} \in \mathbb{R}^n$.

Upstream gradient: $\bar{\mathbf{z}} = \partial \mathcal{L}/\partial \mathbf{z} \in \mathbb{R}^m$.

VJP (backward):

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}} = W^\top \bar{\mathbf{z}} \in \mathbb{R}^n, \qquad \frac{\partial \mathcal{L}}{\partial W} = \bar{\mathbf{z}} \mathbf{x}^\top \in \mathbb{R}^{m \times n}, \qquad \frac{\partial \mathcal{L}}{\partial \mathbf{b}} = \bar{\mathbf{z}} \in \mathbb{R}^m$$

Derivation of $\partial \mathcal{L}/\partial \mathbf{x}$: Each $z_i = \sum_k W_{ik} x_k + b_i$, so $\partial z_i / \partial x_j = W_{ij}$. By VJP: $\partial \mathcal{L}/\partial x_j = \sum_i \bar{z}_i W_{ij} = (W^\top \bar{\mathbf{z}})_j$.

For AI: In a transformer with hidden dim $d$ and MLP expansion $4d$: the two linear layers in the FFN pass gradients back with $W^\top$ operations - same cost as the forward GEMM. Gradient computation for $W$ is also a GEMM.

5.2 Activation Functions

For elementwise $\mathbf{a} = \sigma(\mathbf{z})$:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{z}} = \bar{\mathbf{a}} \odot \sigma'(\mathbf{z})$$

Gradient formulas for common activations:

| Activation | $\sigma(z)$ | $\sigma'(z)$ | Notes |
|---|---|---|---|
| ReLU | $\max(0,z)$ | $\mathbf{1}[z>0]$ | Sparse gradient; "dead neurons" if $z<0$ always |
| Sigmoid | $1/(1+e^{-z})$ | $\sigma(z)(1-\sigma(z))$ | Saturates; max gradient 0.25 at $z=0$ |
| Tanh | $\tanh(z)$ | $1-\tanh^2(z)$ | Saturates; max gradient 1 at $z=0$ |
| GELU | $z\Phi(z)$ | $\Phi(z) + z\phi(z)$ | $\Phi$ = Gaussian CDF; smooth at 0 |
| SiLU/Swish | $z\sigma(z)$ | $\sigma(z)(1+z(1-\sigma(z)))$ | Used in LLaMA, Mistral |
| SoftPlus | $\log(1+e^z)$ | $\sigma(z)$ | Smooth ReLU; gradient never zero |

GELU (Hendrycks & Gimpel, 2016) is the standard activation in GPT-2/3, BERT, and most modern LLMs. It gates the input by its own probability under a Gaussian, producing richer gradient structure than ReLU.

5.3 Fused Softmax + Cross-Entropy Gradient

Setup: Output logits $\mathbf{z} \in \mathbb{R}^K$, softmax probabilities $\mathbf{p} = \text{softmax}(\mathbf{z})$, true label $y \in \{1,\ldots,K\}$, loss $\mathcal{L} = -\log p_y$.

Claim:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{z}} = \mathbf{p} - \mathbf{e}_y$$

where $\mathbf{e}_y$ is the $y$-th standard basis vector.

Proof: Write $\mathcal{L} = -\log p_y = -z_y + \log \sum_k e^{z_k}$.

$$\frac{\partial \mathcal{L}}{\partial z_j} = -\mathbf{1}[j=y] + \frac{e^{z_j}}{\sum_k e^{z_k}} = p_j - \mathbf{1}[j=y] = (\mathbf{p} - \mathbf{e}_y)_j \qquad \square$$

This direct derivation bypasses the softmax Jacobian computation entirely, which is why modern frameworks implement cross-entropy as a fused operation. For numerical stability, the $\log\sum\exp$ is computed with the log-sum-exp trick: $\log\sum_k e^{z_k} = m + \log\sum_k e^{z_k - m}$ where $m = \max_k z_k$.
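A minimal NumPy implementation of the fused forward + backward (shapes and values are illustrative), with a finite-difference check of the $\mathbf{p} - \mathbf{e}_y$ gradient:

```python
import numpy as np

def cross_entropy_with_logits(z, y):
    """Stable fused softmax + cross-entropy. z: (K,) logits, y: integer label.
    Returns (loss, dloss/dz) with the fused gradient p - e_y."""
    m = z.max()
    logsumexp = m + np.log(np.exp(z - m).sum())    # log-sum-exp trick
    loss = logsumexp - z[y]                        # -log p_y
    grad = np.exp(z - logsumexp)                   # softmax probabilities p
    grad[y] -= 1.0                                 # p - e_y
    return loss, grad

z = np.array([2.0, -1.0, 0.5, 3.0])
loss, grad = cross_entropy_with_logits(z, y=0)

num = np.zeros_like(z)                             # finite-difference reference
for j in range(z.size):
    d = np.zeros_like(z); d[j] = 1e-6
    num[j] = (cross_entropy_with_logits(z + d, 0)[0]
              - cross_entropy_with_logits(z - d, 0)[0]) / 2e-6
print(np.max(np.abs(num - grad)))                  # ~1e-10
```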

5.4 LayerNorm Gradient

Forward: LayerNorm normalises each token independently:

$$\hat{\mathbf{x}} = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad \mathbf{y} = \boldsymbol{\gamma} \odot \hat{\mathbf{x}} + \boldsymbol{\beta}$$

where $\mu = \tfrac{1}{d}\sum_i x_i$, $\sigma^2 = \tfrac{1}{d}\sum_i (x_i - \mu)^2$.

Backward: Let $\bar{\mathbf{y}} = \partial \mathcal{L}/\partial \mathbf{y}$ be the upstream gradient.

$$\frac{\partial \mathcal{L}}{\partial \boldsymbol{\gamma}} = \bar{\mathbf{y}} \odot \hat{\mathbf{x}}, \qquad \frac{\partial \mathcal{L}}{\partial \boldsymbol{\beta}} = \bar{\mathbf{y}}$$

For the input gradient, define $\bar{\mathbf{x}}_\text{norm} = \bar{\mathbf{y}} \odot \boldsymbol{\gamma}$. The full gradient through the normalisation is:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{1}{\sqrt{\sigma^2+\epsilon}} \left( \bar{\mathbf{x}}_\text{norm} - \frac{1}{d}\bigl(\mathbf{1}^\top \bar{\mathbf{x}}_\text{norm}\bigr)\mathbf{1} - \frac{1}{d}\bigl(\hat{\mathbf{x}}^\top \bar{\mathbf{x}}_\text{norm}\bigr)\hat{\mathbf{x}} \right)$$

This expression subtracts mean and mean-of-Hadamard terms, reflecting that LayerNorm's Jacobian projects out two degrees of freedom (02 exercises).

For AI: LayerNorm appears in every transformer layer (pre-norm placement in modern architectures like GPT-NeoX, LLaMA). The gradient through LayerNorm is never zero - it always passes signal, unlike BatchNorm which can become degenerate at small batch sizes.
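The input-gradient formula can be checked against autograd on a single token - a minimal sketch with an arbitrary dimension $d$, in double precision for a tight comparison:

```python
import torch

d, eps = 16, 1e-5
x = torch.randn(d, dtype=torch.double, requires_grad=True)
gamma, beta = torch.randn(d, dtype=torch.double), torch.randn(d, dtype=torch.double)

# Forward
mu, var = x.mean(), x.var(unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + eps)
y = gamma * x_hat + beta

# Manual backward for an arbitrary upstream gradient y_bar
y_bar = torch.randn(d, dtype=torch.double)
x_norm_bar = y_bar * gamma
dx_manual = (x_norm_bar
             - x_norm_bar.mean()
             - x_hat * (x_hat * x_norm_bar).mean()) / torch.sqrt(var + eps)

# Autograd reference
(dx_auto,) = torch.autograd.grad(y, x, grad_outputs=y_bar)
print(torch.allclose(dx_manual, dx_auto, atol=1e-10))   # True
```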

5.5 Dot-Product Attention Gradient

Forward (simplified single-head):

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$
$$S = QK^\top / \sqrt{d_k}, \quad P = \text{softmax}(S), \quad O = PV$$

Backward: Given upstream $\bar{O} \in \mathbb{R}^{T \times d_v}$:

$$\bar{V} = P^\top \bar{O}, \qquad \bar{P} = \bar{O} V^\top$$
$$\bar{S} = \text{softmax\_backward}(P, \bar{P}) = P \odot \bigl(\bar{P} - (P \odot \bar{P})\mathbf{1}\mathbf{1}^\top\bigr) \cdot \frac{1}{\sqrt{d_k}}$$
$$\bar{Q} = \bar{S} K, \qquad \bar{K} = \bar{S}^\top Q$$

Then $\nabla_{W_Q} = X^\top \bar{Q}$, and similarly for $W_K$, $W_V$.

Critical memory issue: Storing $P \in \mathbb{R}^{T \times T}$ for the backward costs $O(T^2)$ - this is what FlashAttention avoids by recomputing $P$ from $Q, K$ during the backward pass (see 7.3).
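The backward formulas can be verified against autograd on a toy single-head example (sequence length and dimensions are arbitrary; double precision keeps the comparison tight):

```python
import torch

T, d_k, d_v = 6, 4, 5
Q = torch.randn(T, d_k, dtype=torch.double, requires_grad=True)
K = torch.randn(T, d_k, dtype=torch.double, requires_grad=True)
V = torch.randn(T, d_v, dtype=torch.double, requires_grad=True)

S = Q @ K.T / d_k ** 0.5
P = torch.softmax(S, dim=-1)
O = P @ V

O_bar = torch.randn(T, d_v, dtype=torch.double)            # upstream gradient

# Manual backward following the formulas above
V_bar = P.T @ O_bar
P_bar = O_bar @ V.T
S_bar = P * (P_bar - (P * P_bar).sum(dim=-1, keepdim=True)) / d_k ** 0.5
Q_bar, K_bar = S_bar @ K, S_bar.T @ Q

dQ, dK, dV = torch.autograd.grad(O, (Q, K, V), grad_outputs=O_bar)
print(torch.allclose(Q_bar, dQ), torch.allclose(K_bar, dK), torch.allclose(V_bar, dV))
```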

5.6 Embedding Layer Gradient

Forward: $\mathbf{h}_t = E[\text{token}_t]$, where $E \in \mathbb{R}^{V \times d}$ is the embedding table.

Backward: Given upstream $\bar{\mathbf{h}}_t$ for all tokens $t = 1,\ldots,T$:

$$\frac{\partial \mathcal{L}}{\partial E[i]} = \sum_{t : \text{token}_t = i} \bar{\mathbf{h}}_t$$

This is a sparse gradient - only rows corresponding to tokens in the sequence receive nonzero updates. For vocabulary size $V = 128{,}000$ (LLaMA-3), the embedding matrix is $128{,}000 \times 4{,}096$, but only a tiny fraction of rows are updated per batch. Distributed training with embedding sharding exploits this sparsity.
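The scatter-add structure of this gradient is easy to see on a toy example (tiny vocabulary and dimensions; NumPy's np.add.at plays the role of the framework's scatter-add kernel):

```python
import numpy as np

V, d, T = 10, 4, 6
E = np.random.randn(V, d)               # embedding table
tokens = np.array([3, 7, 3, 1, 7, 3])   # token ids for one sequence (note the repeats)
H_bar = np.random.randn(T, d)           # upstream gradients h_bar_t, one per position

dE = np.zeros_like(E)
np.add.at(dE, tokens, H_bar)            # scatter-add: only used rows accumulate gradient

print(np.allclose(dE[3], H_bar[[0, 2, 5]].sum(axis=0)))   # True: repeated tokens sum
print(np.count_nonzero(dE.any(axis=1)))                    # 3 rows touched out of 10
```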


6. Vanishing and Exploding Gradients

6.1 Magnitude Analysis - The Core Problem

Consider an $L$-layer network with no activation functions (to isolate the linear case). The gradient of the loss with respect to layer $l$'s pre-activation $\mathbf{z}^{[l]}$ (from which the parameter gradients follow) involves the product:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[l]}} = \bigl(W^{[l+1]}\bigr)^\top \bigl(W^{[l+2]}\bigr)^\top \cdots \bigl(W^{[L]}\bigr)^\top \boldsymbol{\delta}^{[L]}$$

This is a product of $L - l$ matrices. By the submultiplicativity of the spectral norm:

$$\left\| \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[l]}} \right\|_2 \leq \prod_{k=l+1}^{L} \|W^{[k]}\|_2 \cdot \|\boldsymbol{\delta}^{[L]}\|_2$$

If $\|W^{[k]}\|_2 = \rho < 1$ for all layers:

$$\text{gradient norm} \leq \rho^{L-l} \cdot \|\boldsymbol{\delta}^{[L]}\|_2 \xrightarrow{L \to \infty} 0 \quad \text{(vanishing)}$$

If $\|W^{[k]}\|_2 = \rho > 1$, the product typically grows at the same exponential rate:

$$\text{gradient norm} \sim \rho^{L-l} \cdot \|\boldsymbol{\delta}^{[L]}\|_2 \xrightarrow{L \to \infty} \infty \quad \text{(exploding)}$$

GRADIENT MAGNITUDE ACROSS LAYERS (schematic)

  gradient norm at layer l, moving backward from the output (l = L) to the input (l = 0):

    rho > 1  ->  grows exponentially toward the input   (exploding)
    rho = 1  ->  stays roughly constant                  (ideal)
    rho < 1  ->  decays exponentially toward the input   (vanishing)

  With activations, the product includes sigma'(z) terms (< 1 for sigmoid),
  compounding the vanishing problem.

This was identified by Hochreiter (1991) as the fundamental obstacle to training deep networks with gradient descent.
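The exponential behaviour is easy to reproduce. The sketch below (arbitrary width and depth; Gaussian weights rescaled so the spectral norm is roughly 2 times the scale factor) backpropagates a unit vector through stacks of random linear layers and prints the resulting norm.

```python
import numpy as np
rng = np.random.default_rng(0)

def backprop_norm(L, scale, d=64):
    """Norm of delta after backpropagating through L random linear layers."""
    delta = rng.normal(size=d)
    delta /= np.linalg.norm(delta)
    for _ in range(L):
        W = scale * rng.normal(0, 1 / np.sqrt(d), (d, d))   # ||W||_2 roughly ~ 2*scale
        delta = W.T @ delta
    return np.linalg.norm(delta)

for scale in (0.3, 0.5, 0.7):
    print(f"scale={scale}:", [f"{backprop_norm(L, scale):.1e}" for L in (10, 30, 50)])
# Scales below the critical value shrink the gradient exponentially with depth;
# scales above it blow the gradient up instead.
```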

6.2 Activations and Saturation

For sigmoid $\sigma$: $\sigma'(z) = \sigma(z)(1-\sigma(z)) \leq 0.25$ for all $z$, with equality only at $z=0$. In the tails ($|z| \gg 0$), $\sigma'(z) \approx 0$.

For tanh: $\tanh'(z) = 1 - \tanh^2(z) \leq 1$, saturating similarly.

In a network with $L$ sigmoid layers and all activations near saturation, the gradient at layer 1 is suppressed by approximately $0.25^L$. For $L = 20$: $0.25^{20} \approx 10^{-12}$ - numerically zero.

ReLU resolves saturation: $\text{relu}'(z) = \mathbf{1}[z > 0]$, which is either 0 or 1. For active neurons, it passes gradients unchanged. However, "dying ReLU" (neurons with $z < 0$ always) creates a different problem - those neurons receive zero gradient and never recover.

GELU and SiLU (used in LLaMA) are smooth approximations that avoid hard zeros, maintaining nonzero gradients everywhere.

6.3 Xavier and He Initialisation

Goal: Choose initial weights so that gradient (and activation) variance is preserved across layers - avoiding exponential growth or decay from the start of training.

Xavier Initialisation (Glorot & Bengio, 2010) - for symmetric activations (tanh, linear):

Assumption: Weights $W_{ij} \sim \mathcal{N}(0, \sigma^2)$ i.i.d., inputs $x_j$ with variance $\text{Var}(x_j) = v$.

Forward variance preservation: $\text{Var}(z_i) = n_\text{in} \sigma^2 v \Rightarrow \sigma^2 = 1/n_\text{in}$.

Backward variance preservation: $\text{Var}(\bar{x}_j) = n_\text{out} \sigma^2 \cdot \text{Var}(\bar{z}_i) \Rightarrow \sigma^2 = 1/n_\text{out}$.

Compromise:

$$\sigma^2 = \frac{2}{n_\text{in} + n_\text{out}}, \qquad \text{or uniform } W_{ij} \sim U\!\left[-\sqrt{\frac{6}{n_\text{in}+n_\text{out}}},\; \sqrt{\frac{6}{n_\text{in}+n_\text{out}}}\right]$$

He Initialisation (He et al., 2015) - for ReLU activations:

ReLU zeroes half the distribution, so the effective variance is halved: $\text{Var}(\text{relu}(z)) = \tfrac{1}{2}\text{Var}(z)$. To compensate:

$$\sigma^2 = \frac{2}{n_\text{in}}$$

For AI: GPT-2 uses a scaled version: weight initialisation $\mathcal{N}(0, 0.02)$ with the residual projection layers further scaled by $1/\sqrt{2L}$ where $L$ is the number of transformer layers, to control the variance accumulation in the residual stream.
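A minimal NumPy sketch of the two schemes (dimensions and depth are arbitrary), with a quick check that He initialisation keeps the activation scale roughly constant through a deep ReLU stack:

```python
import numpy as np
rng = np.random.default_rng(0)

def xavier_normal(n_in, n_out):
    """Glorot/Xavier: variance 2/(n_in + n_out), for tanh / linear layers."""
    return rng.normal(0.0, np.sqrt(2.0 / (n_in + n_out)), (n_out, n_in))

def he_normal(n_in, n_out):
    """He/Kaiming: variance 2/n_in, compensating for ReLU halving the variance."""
    return rng.normal(0.0, np.sqrt(2.0 / n_in), (n_out, n_in))

d, depth = 512, 50
a = rng.normal(size=(d, 1000))
for _ in range(depth):
    a = np.maximum(he_normal(d, d) @ a, 0)
print(f"activation RMS after {depth} ReLU layers: {np.sqrt((a ** 2).mean()):.3f}")
# Stays O(1); with sigma^2 = 1/n_in instead, the same stack collapses toward zero.
```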

6.4 Residual Connections as Gradient Highways

Theorem: In a residual network $F^{[l+1]}(\mathbf{x}) = \mathbf{x} + G^{[l]}(\mathbf{x})$, the gradient satisfies:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[0]}} = \prod_{l=1}^{L}\left(I + J_{G^{[l]}}\right) \cdot \frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[L]}}$$

Key insight: Expanding the product, we get:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[0]}} = \left(I + \sum_l J_{G^{[l]}} + \sum_{l < l'} J_{G^{[l']}} J_{G^{[l]}} + \cdots \right) \frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[L]}}$$

The identity term guarantees that even if all $J_{G^{[l]}} \approx 0$ (at initialisation), the gradient $\partial \mathcal{L}/\partial \mathbf{x}^{[0]}$ receives the full upstream signal unchanged. This is the theoretical explanation for why ResNets (He et al., 2016) can be trained with hundreds of layers.

In modern transformers, the pre-norm architecture (LayerNorm before the sublayer, not after) further improves gradient flow by ensuring that the residual path carries a pure copy of the signal.

6.5 Gradient Clipping

Gradient explosion is addressed pragmatically by global gradient norm clipping:

$$\mathbf{g} \leftarrow \mathbf{g} \cdot \min\!\left(1,\; \frac{\tau}{\|\mathbf{g}\|_2}\right)$$

where $\mathbf{g}$ is the concatenated parameter gradient vector and $\tau$ is the clip threshold.

Typical values: $\tau = 1.0$ for transformers (used in GPT-3, PaLM, LLaMA).

Why global (not per-layer)? Clipping each layer's gradient independently destroys the relative proportions of updates across layers, disrupting the Adam momentum states. Global clipping preserves direction, only reducing magnitude.

Relationship to RNNs: Gradient clipping was originally introduced for RNNs (Mikolov, 2012; Pascanu et al., 2013), where the vanishing/exploding problem is especially severe due to the long chain of time steps.
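A minimal implementation of global-norm clipping (PyTorch's built-in torch.nn.utils.clip_grad_norm_ does the same thing and is what one would normally call):

```python
import torch

def clip_global_grad_norm(parameters, tau=1.0):
    """g <- g * min(1, tau / ||g||_2), with ||g||_2 taken over ALL parameters jointly."""
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    scale = min(1.0, tau / (total_norm.item() + 1e-12))
    for g in grads:
        g.mul_(scale)                 # in-place rescale preserves the gradient direction
    return total_norm

# Typical placement in a training step:
#   loss.backward()
#   clip_global_grad_norm(model.parameters(), tau=1.0)
#   optimizer.step()
```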

6.6 Batch Normalisation and Layer Normalisation

BatchNorm (Ioffe & Szegedy, 2015) normalises each feature across the batch, stabilising the distribution of pre-activations. Its gradient has a complex form involving batch statistics, but crucially it prevents activations from saturating on average.

LayerNorm (Ba et al., 2016) normalises each sample across features - preferred in transformers because:

  1. Behaviour is independent of batch size (critical for small-batch inference)
  2. Gradient analysis shows it damps large pre-activation magnitudes
  3. Pre-norm placement ensures the residual stream grows in a controlled manner

Empirical gradient norm tracking is standard practice in LLM training: the gradient norm is logged at every step, and sudden spikes indicate loss spikes or numerical issues. The Chinchilla and GPT-4 training runs used gradient norm monitoring as a primary signal for training health.


7. Memory-Efficient Backpropagation

7.1 Memory Cost of Standard Backprop

Standard backpropagation caches all intermediate activations for use in the backward pass. For a transformer with $L$ layers, batch size $B$, sequence length $T$, and hidden dimension $d$:

| Component cached | Size (elements) | At FP16 |
|---|---|---|
| Attention QKV projections | $3 \times L \times B \times T \times d$ | $6LBTd$ bytes |
| Attention scores (pre-softmax) | $L \times B \times H \times T^2$ | $2LBHT^2$ bytes |
| MLP intermediate | $L \times B \times T \times 4d$ | $8LBTd$ bytes |
| LayerNorm stats | $2 \times 2L \times B \times T$ | negligible |

For GPT-3 ($L=96$, $B=512$, $T=2048$, $d=12288$, $H=96$): the attention scores alone require $96 \times 512 \times 96 \times 2048^2 \times 2 \approx 40\,\text{TB}$ - clearly infeasible without optimisation.

7.2 Gradient Checkpointing

Idea: Trade compute for memory. Instead of caching all activations, cache only a subset of "checkpoint" activations and recompute the rest during the backward pass.

Algorithm (checkpointing at every $k$-th layer):

GRADIENT CHECKPOINTING


  Forward pass:
    Compute all layers normally
    Save activations only at layers 0, k, 2k, 3k, ...
    Discard all other intermediate activations

  Backward pass:
    For each segment [lk, (l+1)k]:
      Re-run the forward pass from checkpoint lk
      Now have all intermediates for this segment
      Compute gradients for layers lk+1 to (l+1)k-1
      Discard intermediates (no longer needed)


Memory-compute tradeoff:

  • Memory: $O(\sqrt{L})$ checkpoints (optimal with $k = \sqrt{L}$) instead of $O(L)$
  • Compute: Each layer's forward pass is run twice (once in the original forward, once in recomputation) -> approximately $+33\%$ compute overhead

For AI: torch.utils.checkpoint.checkpoint() implements this in PyTorch with a single function call. LLaMA, Mistral, and most OSS LLM trainers enable activation checkpointing by default for sequences longer than ~2048 tokens.

Selective recomputation: FlashAttention (see 7.3) takes a more targeted approach - instead of checkpointing by layer, it recomputes only the attention scores (the $T^2$ term) during the backward pass, since those are the dominant memory consumer.
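A sketch of per-segment checkpointing with PyTorch's checkpoint_sequential utility (layer sizes, depth, and the use_reentrant flag are illustrative; the flag is accepted by recent PyTorch versions):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# 64 identical blocks; checkpoint_sequential stores activations only at the
# boundaries of the segments and recomputes inside each segment during backward.
layers = torch.nn.Sequential(*[
    torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU())
    for _ in range(64)
])

x = torch.randn(8, 512, requires_grad=True)
out = checkpoint_sequential(layers, 8, x, use_reentrant=False)   # ~sqrt(64) checkpoints
out.sum().backward()        # recomputation happens here, transparently
```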

7.3 FlashAttention: Fused Backward Pass

The $O(T^2)$ problem: Standard attention stores $P = \text{softmax}(QK^\top/\sqrt{d}) \in \mathbb{R}^{T \times T}$ for the backward pass. For $T = 32{,}768$ (long-context models), this is $32768^2 \times 2 \approx 2$ GB per head, per layer, per batch element.

FlashAttention solution (Dao et al., 2022): Compute attention in tiles that fit in SRAM (GPU on-chip cache), using the online softmax algorithm (Milakov & Gimelshein, 2018) to avoid materialising the full $T \times T$ matrix.

Backward pass in FlashAttention: The backward pass needs $P$ but doesn't store it. Instead:

  1. Store only the softmax normalisation statistics $m_i, \ell_i$ (scalars per row) - $O(T)$ memory
  2. During the backward pass, recompute $P$ tile by tile from $Q, K$ and the stored statistics
  3. Accumulate gradients $\bar{Q}, \bar{K}, \bar{V}$ tile by tile without ever forming the full $P$

Complexity:

  • Memory: $O(T)$ instead of $O(T^2)$
  • FLOPs: $4\times$ the forward FLOPs (small constant factor)
  • Wall-clock speedup: 2-4x over standard PyTorch attention on A100

For AI: FlashAttention is the default attention implementation in modern LLM training (vLLM, HuggingFace Transformers, NanoGPT). FlashAttention-3 (2024) further optimises for H100 tensor core and async operations.

7.4 Mixed Precision Training

Observation: FP32 (32-bit float) is unnecessarily precise for gradients. FP16 (16-bit float) has $\sim 3\times$ higher effective memory bandwidth on modern GPUs, but overflow/underflow is common for large/small gradient values.

AMP (Automatic Mixed Precision) strategy:

| Component | Precision | Reason |
|---|---|---|
| Forward activations | FP16 | Fast compute, lower memory |
| Backward gradients | FP16 | Fast compute |
| Weight updates | FP32 | Avoid precision loss |
| Master weights | FP32 | Preserve small updates ($\Delta w \ll w$) |
| Loss scaling | Dynamic | Prevent FP16 underflow for small gradients |

Loss scaling: Multiply the loss by a large scale factor $S$ (typically $2^{12}$ to $2^{16}$) before backward, then divide gradients by $S$ before the weight update. This shifts gradient values into the representable FP16 range. The scale factor is increased or decreased based on whether overflow (inf/nan) occurred.

BF16 (Brain Float 16, used in TPUs and H100): same 16-bit width but with 8 exponent bits (same as FP32) and only 7 mantissa bits. Eliminates overflow issues while retaining dynamic range - now the preferred format for LLM training.
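A sketch of an AMP training step wired together with loss scaling and gradient clipping (a toy model and random data on CUDA; the model, optimiser, and hyperparameters are placeholders, not a reference recipe):

```python
import torch

device = "cuda"
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU(),
                            torch.nn.Linear(512, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(64, 512, device=device)
y = torch.randint(0, 10, (64,), device=device)

for step in range(10):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.cross_entropy(model(x), y)   # FP16 forward
    scaler.scale(loss).backward()       # backward on the scaled loss (avoids FP16 underflow)
    scaler.unscale_(optimizer)          # back to true gradient scale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)              # FP32 master update; skipped if inf/nan gradients found
    scaler.update()                     # grow/shrink the loss scale dynamically
```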


8. Advanced Differentiation Topics

8.1 Backpropagation Through Time (BPTT)

A recurrent neural network (RNN) with hidden state $\mathbf{h}_t = \sigma(W_h \mathbf{h}_{t-1} + W_x \mathbf{x}_t + \mathbf{b})$ can be viewed as a feedforward network unrolled through time:

UNROLLED RNN - BPTT VIEW


  x_1 -> [cell] -> h_1 -> [cell] -> h_2 -> [cell] -> h_3 -> ... -> h_T -> loss
            |                |                |
           W_h              W_h              W_h        (shared weights)

  BPTT = backprop through the unrolled graph.
  Gradient of loss w.r.t. W = sum of gradients from all time steps.


The gradient with respect to $W_h$ at time step $t$ involves the product:

$$\frac{\partial \mathcal{L}}{\partial W_h} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \left( \prod_{k=1}^{t-1} \frac{\partial \mathbf{h}_{t-k+1}}{\partial \mathbf{h}_{t-k}} \right) \frac{\partial \mathbf{h}_1}{\partial W_h}$$

Each factor $\partial \mathbf{h}_{t+1}/\partial \mathbf{h}_t = W_h \cdot \text{diag}(\sigma'(\mathbf{z}_t))$. When $\|W_h\|_2 \cdot \|\sigma'\|_\infty < 1$, the product of $T$ such factors vanishes exponentially. This is the core failure mode of vanilla RNNs on long sequences (Hochreiter, 1991; Bengio et al., 1994).

Truncated BPTT: In practice, gradients are truncated to a window of $k$ steps to reduce memory and compute costs, at the cost of ignoring long-range dependencies beyond step $k$.

LSTM/GRU solution: Long Short-Term Memory networks (Hochreiter & Schmidhuber, 1997) use gating mechanisms to maintain a cell state $\mathbf{c}_t$ with additive updates - replacing multiplicative products of weight matrices with additive accumulation, similar to residual connections.

8.2 Implicit Differentiation Preview

For optimisation problems or fixed-point iterations, we sometimes need gradients of implicit functions.

Example: Consider $\mathbf{y}^* = \arg\min_\mathbf{y} \mathcal{L}(\mathbf{y}, \theta)$ where the optimum satisfies $\nabla_\mathbf{y} \mathcal{L}(\mathbf{y}^*, \theta) = 0$.

By the implicit function theorem:

$$\frac{d\mathbf{y}^*}{d\theta} = -\left[\nabla^2_{\mathbf{y}\mathbf{y}} \mathcal{L}\right]^{-1} \nabla^2_{\mathbf{y}\theta} \mathcal{L}$$

This allows differentiating through optimisation steps without unrolling them - the basis of MAML (Model-Agnostic Meta-Learning, Finn et al., 2017) and DEQs (Deep Equilibrium Models, Bai et al., 2019).

Full treatment: Implicit differentiation and differentiable optimisation are covered in depth in 05/05-Automatic-Differentiation.

8.3 Straight-Through Estimator and REINFORCE

The discrete problem: When a node in the computation graph applies a discrete operation (argmax, sampling, rounding), the gradient is zero almost everywhere. The chain rule breaks - the graph is not differentiable at these nodes.

Straight-Through Estimator (STE) (Hinton, 2012; Bengio et al., 2013):

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}} := \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{x}}} \cdot \mathbf{1} \qquad \text{(treat discretisation as identity in backward)}$$

Applications:

  • Quantisation-aware training (QAT): Simulate INT8 forward, use STE backward. Used in GPTQ, AWQ, and quantised LLM training.
  • VQ-VAE: Vector quantisation in the encoder uses STE so gradients flow from decoder back to encoder.
  • Binary neural networks: Forward uses sign(x), backward uses STE with gradient identity.

REINFORCE (Williams, 1992): For stochastic nodes, use the log-derivative trick:

$$\nabla_\theta \mathbb{E}_{z \sim p_\theta}[\mathcal{L}(z)] = \mathbb{E}_{z \sim p_\theta}[\mathcal{L}(z) \nabla_\theta \log p_\theta(z)]$$

This produces an unbiased gradient estimate but with high variance (addressed by baseline subtraction: $(\mathcal{L}(z) - b)\nabla_\theta \log p_\theta(z)$). REINFORCE is the foundation of policy gradient methods in RL and is used in RLHF's PPO step.
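A small self-contained sketch (a categorical "policy" over four actions with an arbitrary reward vector standing in for a non-differentiable loss) comparing the REINFORCE estimate, with a mean baseline, to the exact gradient of the expectation:

```python
import torch

theta = torch.zeros(4, requires_grad=True)
R = torch.tensor([1.0, 0.0, 0.5, 2.0])               # reward per action; no gradient flows through it

dist = torch.distributions.Categorical(torch.softmax(theta, dim=0))
samples = dist.sample((100_000,))                     # sampling is the non-differentiable step
rewards = R[samples]
baseline = rewards.mean()                             # variance-reduction baseline b

# Surrogate loss whose gradient equals -E[(R - b) * grad log p_theta(z)] = -grad E[R]
surrogate = -((rewards - baseline) * dist.log_prob(samples)).mean()
surrogate.backward()

exact = torch.autograd.grad((torch.softmax(theta, dim=0) * R).sum(), theta)[0]
print("REINFORCE estimate of grad E[R]:", -theta.grad)
print("exact grad E[R]:                ", exact)      # agree up to Monte Carlo noise
```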

8.4 Higher-Order Gradients

Second-order gradients arise in:

  1. Newton's method: requires the Hessian $H = \nabla^2 \mathcal{L}$ (see 02-Jacobians-and-Hessians)
  2. Meta-learning (MAML): gradient of a gradient w.r.t. outer parameters
  3. Gradient penalty in GAN training: $\|\nabla_x D(x)\|^2$

In PyTorch: Higher-order gradients are computed by running autograd through itself:

```python
import torch

# Second derivative of the loss w.r.t. the input
# (x must be a leaf tensor with requires_grad=True)
loss = model(x).sum()
g = torch.autograd.grad(loss, x, create_graph=True)[0]   # first derivative, graph retained
g2 = torch.autograd.grad(g.sum(), x)[0]                  # differentiate the gradient again
```

create_graph=True tells autograd to build a graph for the gradient computation itself, enabling differentiation through it.

Hessian-vector products (HVPs): As shown in 02, the HVP $Hv$ can be computed in $O(n)$ time without forming $H$:

$$Hv = \nabla_\mathbf{x}\bigl[(\nabla_\mathbf{x} \mathcal{L})^\top v\bigr]$$

This is the primitive operation behind conjugate gradient and Lanczos methods for curvature estimation.
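A minimal double-backward HVP, checked against the explicit Hessian on a tiny problem (the loss function is an arbitrary smooth example):

```python
import torch

x = torch.randn(5, requires_grad=True)
v = torch.randn(5)

def f(t):
    return (t.sin() * t).sum() + (t @ t) ** 2            # arbitrary smooth scalar loss

loss = f(x)
(g,) = torch.autograd.grad(loss, x, create_graph=True)   # grad L, graph retained
(hvp,) = torch.autograd.grad(g @ v, x)                    # grad of (g^T v) = H v

H = torch.autograd.functional.hessian(f, x.detach())      # explicit Hessian (tiny n only)
print(torch.allclose(hvp, H @ v, atol=1e-5))               # True
```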


9. Transformer Backpropagation

9.1 Full Transformer Layer Gradient Flow

A pre-norm transformer layer processes the residual stream $\mathbf{x} \in \mathbb{R}^d$ as:

$$\mathbf{x}' = \mathbf{x} + \text{Attn}(\text{LN}_1(\mathbf{x}))$$
$$\mathbf{x}'' = \mathbf{x}' + \text{MLP}(\text{LN}_2(\mathbf{x}'))$$

Backward through one transformer layer (given $\bar{\mathbf{x}}''$):

GRADIENT FLOW - ONE TRANSFORMER LAYER


  FORWARD                              BACKWARD

  x                                    x_bar'' flows in
    |                                    |
  LN_1(x)                              x_bar' = x_bar'' + MLP_backward(x_bar'')
    |                                    |
  Attn(.)                              x_bar  = x_bar'  + Attn_backward(x_bar')
    |
  x' = x + Attn(LN_1(x))               The two residual additions
    |                                  split the gradient stream
  LN_2(x')                             into parallel paths -
    |                                  the identity skip carries
  MLP(.)                               the full upstream signal
    |                                  unchanged.
  x'' = x' + MLP(LN_2(x'))


The critical observation: both residual additions in the transformer layer act as gradient splitters. The skip path carries a copy of $\bar{\mathbf{x}}''$ directly back to $\bar{\mathbf{x}}'$ without passing through the MLP Jacobian. This gives transformers well-behaved gradients even at $L = 96$ layers (GPT-3) or $L = 126$ layers (Grok-1).

9.2 LoRA Backward Pass

Low-Rank Adaptation (Hu et al., 2022) reparametrises a weight matrix:

W=W0+BA,W0Rm×n (frozen),BRm×r,ARr×nW = W_0 + BA, \quad W_0 \in \mathbb{R}^{m \times n} \text{ (frozen)}, \quad B \in \mathbb{R}^{m \times r}, A \in \mathbb{R}^{r \times n}

Forward: y=Wx=W0x+BAx\mathbf{y} = W\mathbf{x} = W_0\mathbf{x} + BA\mathbf{x}.

Backward (given yˉ\bar{\mathbf{y}}):

Aˉ=ByˉxRr×n\bar{A} = B^\top \bar{\mathbf{y}} \mathbf{x}^\top \in \mathbb{R}^{r \times n} Bˉ=yˉxARm×r\bar{B} = \bar{\mathbf{y}} \mathbf{x}^\top A^\top \in \mathbb{R}^{m \times r} xˉfrom LoRA=AByˉ=(BA)yˉ\bar{\mathbf{x}}_\text{from LoRA} = A^\top B^\top \bar{\mathbf{y}} = (BA)^\top \bar{\mathbf{y}}

Note: W0W_0 is frozen, so Wˉ0=0\bar{W}_0 = 0 - no gradient is computed or stored for W0W_0. The backward pass only updates AA and BB.

Memory savings: For W0R4096×4096W_0 \in \mathbb{R}^{4096 \times 4096} with r=16r = 16: gradient storage reduces from 40962=16.7M4096^2 = 16.7\text{M} to r(m+n)=16×8192=131Kr(m + n) = 16 \times 8192 = 131\text{K} parameters - a 128×128\times reduction in gradient memory for that layer.
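A minimal numpy sketch of these backward formulas (toy sizes m = 8, n = 6, r = 2, matching Exercise 7; variable names are illustrative):

import numpy as np

rng = np.random.default_rng(0)
m, n, r = 8, 6, 2
W0 = rng.standard_normal((m, n))        # frozen base weight
B = rng.standard_normal((m, r))
A = rng.standard_normal((r, n))
x = rng.standard_normal(n)

y = W0 @ x + B @ (A @ x)                # forward
y_bar = rng.standard_normal(m)          # upstream gradient

A_bar = np.outer(B.T @ y_bar, x)        # B^T ybar x^T, shape (r, n)
B_bar = np.outer(y_bar, A @ x)          # ybar (Ax)^T = ybar x^T A^T, shape (m, r)
x_bar_lora = (B @ A).T @ y_bar          # (BA)^T ybar; the frozen path adds W0^T ybar to the full x_bar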

DoRA (Liu et al., 2024) further decomposes LoRA into magnitude + direction components, improving fine-tuning quality while preserving the low-rank backward structure.

9.3 Gradient Accumulation

Problem: Large effective batch sizes (B=4MB = 4\text{M} tokens, as in GPT-4 training) don't fit in GPU memory for a single forward-backward pass.

Solution - gradient accumulation:

optimizer.zero_grad()
for b in range(G):                                # micro-batches 1, ..., G
    loss_b = forward(micro_batch[b]) / G          # scaled loss
    loss_b.backward()                             # accumulates gradients in .grad
    # gradients are NOT zeroed between micro-batches

optimizer.step()        # update once after G micro-batches
optimizer.zero_grad()

The division by GG ensures the accumulated gradient is mathematically identical to what a single pass with the full batch would produce.

For AI: GPT-3 used gradient accumulation to achieve an effective batch of 3M\sim 3\text{M} tokens with hardware that could only process 500K\sim 500\text{K} tokens per step.

9.4 Distributed Gradient Synchronisation

In data parallelism, each GPU processes a different micro-batch but shares the same model weights. After the backward pass, gradients must be synchronised:

All-Reduce: Sum gradients across all NN GPUs and divide by NN. Implemented via ring all-reduce (NCCL): each GPU communicates roughly 2θ2|\theta| gradient elements regardless of NN, which is bandwidth-optimal.
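A minimal sketch of the synchronisation step (assumes a torch.distributed process group is already initialised, e.g. via torchrun; in practice DistributedDataParallel performs this automatically and overlaps it with the backward pass):

import torch
import torch.distributed as dist

def sync_gradients(model: torch.nn.Module) -> None:
    """After loss.backward() on each rank, average gradients across all ranks."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)   # sum over ranks (ring all-reduce under NCCL)
            p.grad /= world_size                            # average so the update matches single-GPU training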

Gradient sharding (ZeRO): DeepSpeed's ZeRO (Zero Redundancy Optimizer) partitions gradient storage across GPUs:

  • ZeRO Stage 1: Shard optimiser states -> 4×4\times memory reduction
  • ZeRO Stage 2: Shard gradients additionally -> 8×8\times reduction
  • ZeRO Stage 3: Shard parameters too -> N×N\times reduction (linear in GPU count)

For LLaMA-3 70B training: ZeRO Stage 3 across 1024 H100 GPUs allows storing only 70B/102468M\sim 70\text{B}/1024 \approx 68\text{M} parameters per GPU - fitting the model in memory.


10. Common Mistakes

| # | Mistake | Why It's Wrong | Fix |
|---|---------|----------------|-----|
| 1 | Applying the scalar chain rule $\frac{dy}{dx} = \frac{dy}{du}\frac{du}{dx}$ to vector functions | The scalar chain rule multiplies numbers; the multivariate chain rule composes Jacobians, and order matters: $J_{f\circ g} = J_f \cdot J_g$, not $J_g \cdot J_f$ | Write the Jacobians explicitly and multiply left-to-right in the order of composition |
| 2 | Forgetting to sum gradients at fan-out (shared weight) nodes | Each use of a weight contributes a gradient; missing uses means undercounting | Accumulate gradients with `+=` in the backward loop over all uses |
| 3 | Treating $\nabla_W \mathcal{L} = \boldsymbol{\delta} \mathbf{x}^\top$ as shape-correct without checking | The outer product $\boldsymbol{\delta} \mathbf{x}^\top$ has shape $(n_\text{out}, n_\text{in})$ matching $W$; transposing either vector gives the wrong shape | Always verify gradient shapes match parameter shapes before implementation |
| 4 | Using sigmoid/tanh in deep networks and expecting no vanishing gradients | Their derivatives are bounded by $0.25$ / $1.0$ - products over many layers vanish exponentially | Use ReLU, GELU, or SiLU with proper initialisation; add residual connections |
| 5 | Initialising all weights to zero (or the same value) | Symmetry breaking fails: every neuron in a layer computes the same gradient, so they all update identically and remain identical forever | Use Xavier or He initialisation with random values |
| 6 | Computing softmax and cross-entropy separately instead of fused | Intermediate probabilities $p_i = e^{z_i}/\sum_j e^{z_j}$ overflow/underflow for large logits | Always use the log-sum-exp trick or a library's CrossEntropyLoss (which applies it internally) |
| 7 | Confusing JVP and VJP - using JVP for all gradient computations | JVP costs $O(n)$ passes for a scalar output; VJP costs $O(1)$ per output dimension, so for a scalar loss VJP (backprop) is the right choice | Use VJP (backward) for scalar losses; reserve JVP for computing Jacobian columns or directional derivatives |
| 8 | Clipping per-layer gradients independently instead of by global norm | Destroys the relative scale of gradients across layers; disrupts Adam's per-parameter adaptive scaling | Clip the global gradient norm: compute $\lVert\mathbf{g}\rVert$ across all parameters and scale down if above threshold |
| 9 | Applying STE to continuous operations in quantisation-aware training | STE should only be applied at the discrete rounding step, not to subsequent continuous operations | Apply STE only at the round() or sign() node; propagate real gradients elsewhere |
| 10 | Forgetting to scale the loss under gradient accumulation | Accumulating $G$ micro-batch gradients without dividing by $G$ produces an effective gradient $G\times$ too large | Divide the loss by $G$ before backward, or divide the accumulated gradients by $G$ before the optimiser step |
| 11 | Not using create_graph=True when computing higher-order gradients in PyTorch | Without create_graph=True the gradient computation is not tracked, so differentiating through it returns None or wrong values | Pass create_graph=True to the first torch.autograd.grad() call when second derivatives are needed |
| 12 | Confusing BPTT truncation with sequence truncation | Truncated BPTT still runs the full forward sequence; it only truncates the backward window. Sequence truncation shortens both | These are different operations - read the framework docs to confirm which is applied |

11. Exercises

Exercise 1 - Scalar Chain Rule Verification

Let f(t)=sin(t2)f(t) = \sin(t^2) and g(t)=e3tg(t) = e^{3t}.

(a) Compute ddt[f(g(t))]\frac{d}{dt}[f(g(t))] using the chain rule analytically.

(b) Evaluate the derivative at t=0t = 0 and t=1t = 1.

(c) Verify numerically using centred finite differences.

(d) Compute ddt[g(f(t))]\frac{d}{dt}[g(f(t))] - explain why the order of composition matters.


Exercise 2 - Jacobian Composition

Let f:R3R2f: \mathbb{R}^3 \to \mathbb{R}^2 and g:R2R3g: \mathbb{R}^2 \to \mathbb{R}^3 be defined by:

f(u)=(u1u2, u2+u32),g(x)=(x12, x1x2, ex2)f(\mathbf{u}) = (u_1 u_2,\ u_2 + u_3^2), \quad g(\mathbf{x}) = (x_1^2,\ x_1 x_2,\ e^{x_2})

(a) Compute Jf(u)J_f(\mathbf{u}) and Jg(x)J_g(\mathbf{x}) analytically.

(b) Compute the Jacobian of h=fg:R2R2h = f \circ g: \mathbb{R}^2 \to \mathbb{R}^2 using the chain rule Jh=Jf(g(x))Jg(x)J_h = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x}).

(c) Verify using finite differences at x0=(1,0)\mathbf{x}_0 = (1, 0)^\top.

(d) Compute JhJ_h directly and confirm it equals part (b).


Exercise 3 - Backprop Through a 2-Layer Network

Two-layer network: z[1]=W[1]x+b[1]\mathbf{z}^{[1]} = W^{[1]}\mathbf{x} + \mathbf{b}^{[1]}, a[1]=relu(z[1])\mathbf{a}^{[1]} = \text{relu}(\mathbf{z}^{[1]}), y^=w[2]a[1]+b[2]\hat{y} = \mathbf{w}^{[2]} \cdot \mathbf{a}^{[1]} + b^{[2]}, L=12(y^y)2\mathcal{L} = \tfrac{1}{2}(\hat{y} - y)^2.

With W[1]R3×2W^{[1]} \in \mathbb{R}^{3 \times 2}, xR2\mathbf{x} \in \mathbb{R}^2, w[2]R3\mathbf{w}^{[2]} \in \mathbb{R}^3:

(a) Implement forward pass. Compute L\mathcal{L} for given values.

(b) Implement backward pass manually using the backpropagation recurrence.

(c) Verify your gradients using numpy finite differences.

(d) Implement gradient descent for 100 steps with learning rate 0.010.01 and verify loss decreases.


Exercise 4 - Vanishing Gradients Analysis

(a) Construct a 20-layer sigmoid network with all weights =0.3= 0.3. Compute the gradient at layer 1 symbolically and numerically.

(b) Repeat with ReLU activation. Compare gradient magnitudes at layers 1, 5, 10, 20.

(c) Apply Xavier initialisation to the sigmoid network and compare gradient flow.

(d) Add residual connections to the 20-layer sigmoid network. Quantify the improvement.

(e) Plot gradient norm vs. layer depth for all four cases.


Exercise 5 - Gradient Checkpointing

(a) Implement a 10-layer feedforward network with explicit intermediate caching. Measure peak memory usage.

(b) Implement the same network with gradient checkpointing at every 3rd layer. Measure memory.

(c) Verify that both implementations produce identical gradients.

(d) Measure the compute overhead of recomputation. How does it compare to the theoretical +33%+33\%?

(e) Find the optimal checkpoint interval kk that minimises total memory x compute cost.


Exercise 6 - Attention Gradient

Single-head attention: O=softmax(QK/d)VO = \text{softmax}(QK^\top/\sqrt{d})V with Q,K,VRT×dQ, K, V \in \mathbb{R}^{T \times d} for T=4T = 4, d=3d = 3.

(a) Implement forward pass.

(b) Implement backward pass computing Qˉ,Kˉ,Vˉ\bar{Q}, \bar{K}, \bar{V} given Oˉ\bar{O}.

(c) Verify all three gradients using finite differences.

(d) For causal masking (set Sij=S_{ij} = -\infty for j>ij > i), show that the backward pass is unchanged except at masked positions.


Exercise 7 - LoRA Gradient Analysis

(a) Implement a linear layer y=(W0+BA)x\mathbf{y} = (W_0 + BA)\mathbf{x} with LoRA adaptation. Set m=8,n=6,r=2m=8, n=6, r=2.

(b) Compute gradients AL\nabla_A \mathcal{L} and BL\nabla_B \mathcal{L} analytically and verify numerically.

(c) Confirm that W0L=yˉx\nabla_{W_0} \mathcal{L} = \bar{\mathbf{y}} \mathbf{x}^\top but is not used (frozen).

(d) Compare the number of gradient parameters for full fine-tuning vs. LoRA.

(e) Implement LoRA training for 200 steps on a toy task and compare convergence with full fine-tuning.


Exercise 8 - REINFORCE and STE

(a) Implement a stochastic computational graph: zBernoulli(σ(θ))z \sim \text{Bernoulli}(\sigma(\theta)), L=z2\mathcal{L} = z^2.

(b) Compute the REINFORCE gradient θE[L]\nabla_\theta \mathbb{E}[\mathcal{L}] analytically.

(c) Estimate the REINFORCE gradient with 10000 samples. Verify against the analytical value.

(d) Implement STE for the rounding operation: w^=round(w)\hat{w} = \text{round}(w), L=(w^wtarget)2\mathcal{L} = (\hat{w} - w_\text{target})^2. Compute the STE gradient and update ww.

(e) Compare STE-based quantisation-aware training on a toy example: train for 50 steps and measure quantisation error vs. a post-training quantised model.


12. Why This Matters for AI (2026 Perspective)

| Concept | Concrete AI Impact |
|---------|--------------------|
| Multivariate chain rule | The mathematical foundation of every gradient-based learning algorithm - without it, backprop cannot be defined |
| VJP as backprop primitive | Modern autodiff systems (JAX, PyTorch) are built around VJP primitives; the $O(1)$ cost of reverse mode is what makes training billion-parameter models tractable |
| Computation graphs | torch.compile (PyTorch 2.0), XLA (JAX/TensorFlow), and TensorRT all operate by analysing the computation graph to fuse kernels and optimise memory layout |
| Fused softmax + CE gradient | The clean gradient $\mathbf{p} - \mathbf{e}_y$ makes language model training numerically stable; FlashAttention's backward uses the same softmax log-sum-exp statistics |
| Xavier/He initialisation | Ensures $O(1)$ gradient scale at depth 1 or depth 96 - a critical practical enabler for deep network training |
| Residual connections | The "gradient highway" identity term in ResNets/transformers is why 100-layer networks train at all; this was the key insight enabling GPT-3's 96 layers |
| Gradient checkpointing | Enables training LLMs with 128K context lengths; without it, the $O(T^2)$ activation memory would be prohibitive |
| FlashAttention backward | IO-aware backward pass reduces memory from $O(T^2)$ to $O(T)$ while maintaining numerical equivalence; standard in all production LLM training as of 2024 |
| LoRA backward | Only $r(m+n) \ll mn$ parameters accumulate gradients; enables fine-tuning 70B models on a single H100 via the low-rank backward structure |
| STE / REINFORCE | STE enables quantisation-aware training (GPTQ, AWQ, QLoRA); REINFORCE enables RLHF's policy gradient step in PPO-based alignment training |
| BPTT | The failure of vanilla BPTT for long sequences motivated LSTMs, GRUs, and ultimately the attention mechanism, which replaces recurrence with direct pairwise interactions |
| ZeRO gradient sharding | Partitions gradient storage across GPUs linearly in GPU count; enables training models that would otherwise require $8\times$ more memory per GPU |
| Mixed precision backward | BF16 backward passes achieve $2\text{-}3\times$ the memory bandwidth of FP32 on H100, with dynamic loss scaling preventing underflow; standard in all LLM training since GPT-3 |
| Higher-order gradients | Gradient penalties in GANs, MAML's meta-gradient, and Hessian-vector products for learning-rate scheduling all require differentiating through the backward pass |

Conceptual Bridge

Where we came from: 01 (Partial Derivatives) gave us tools to differentiate multivariate functions component by component. 02 (Jacobians and Hessians) assembled those into matrix objects capturing full first- and second-order sensitivity. We now know what a derivative is for a function f:RnRmf: \mathbb{R}^n \to \mathbb{R}^m.

What this section added: The chain rule tells us how derivatives compose - allowing us to differentiate functions built from primitives. Backpropagation is the algorithmic instantiation of this composition for computation graphs, and the VJP (reverse mode) makes the cost of differentiating a scalar loss with respect to millions of parameters equal to the cost of a single forward pass. This is not an approximation - it is exact and provably optimal.

What this enables: Every gradient-based learning algorithm - SGD, Adam, RMSprop, LARS, Shampoo - requires only the gradient θL\nabla_\theta \mathcal{L}, which backprop provides. The advanced sections of this chapter (04 Optimisation, 05 Automatic Differentiation) build directly on the VJP abstraction established here.

Connection to transformer training: Modern LLM training is essentially an exercise in efficient backpropagation at scale. Every engineering decision - Flash Attention's tiled backward, ZeRO's gradient sharding, gradient checkpointing, LoRA's low-rank backward, mixed precision loss scaling - is a response to the memory and compute constraints of the backward pass. Understanding backpropagation is therefore prerequisite to understanding why LLM training systems are designed the way they are.

POSITION IN THE CURRICULUM


  PREREQUISITES (must know):
    01 Partial Derivatives - ∂f/∂x, gradient, directional derivative
    02 Jacobians & Hessians - J_f, Fréchet derivative, VJP/JVP

  THIS SECTION (03):
  
    Chain Rule & Backpropagation                                   
    - Multivariate chain rule (J_{f∘g} = J_f · J_g)
    - Computation graphs (DAG, topological order)
    - Backprop recurrence (δ^[l] = (W^[l+1])^T δ^[l+1] ⊙ σ'(z^[l]))
    - Gradient derivations (linear, softmax+CE, LN, attention)     
    - Vanishing/exploding gradients + solutions                    
    - Memory-efficient backprop (checkpointing, Flash Attention)   
    - Advanced: BPTT, STE, REINFORCE, higher-order gradients       
  

  WHAT THIS ENABLES:
    04 Optimisation - gradient descent, Adam, second-order methods
    05 Automatic Differentiation - AD systems, tape, jit compilation
    07 Neural Networks - full training loop built on backprop
    08 Transformer Architecture - FlashAttention, LoRA, gradient flow

  CROSS-CHAPTER CONNECTIONS:
    03-Advanced-LA/02-SVD - gradient low-rank structure
    04-Calculus/02-Derivatives - scalar chain rule (special case)
    06-Probability/03-MLE - loss functions that backprop optimises



For automatic differentiation systems that implement these ideas at scale, see 05 Automatic Differentiation.

For the optimisation algorithms that consume backprop's output, see 04 Multivariate Optimisation.


Appendix A: Worked Backpropagation Example

A.1 Complete Worked Example - 3-Layer Network

To make the backpropagation formulas concrete, we trace through a minimal example end-to-end.

Network architecture:

  • Input: x=(1,2)\mathbf{x} = (1, 2)^\top
  • Layer 1: W[1]=(0.50.10.20.3)W^{[1]} = \begin{pmatrix}0.5 & 0.1 \\ 0.2 & 0.3\end{pmatrix}, b[1]=(0,0)\mathbf{b}^{[1]} = (0, 0)^\top, activation: ReLU
  • Layer 2: w[2]=(0.4,0.6)\mathbf{w}^{[2]} = (0.4, 0.6)^\top, b[2]=0b^{[2]} = 0, activation: none (scalar output)
  • Loss: L=12(y^1)2\mathcal{L} = \tfrac{1}{2}(\hat{y} - 1)^2 (MSE with target y=1y = 1)

Forward pass:

z[1]=W[1]x=(0.51+0.120.21+0.32)=(0.70.8)\mathbf{z}^{[1]} = W^{[1]}\mathbf{x} = \begin{pmatrix}0.5 \cdot 1 + 0.1 \cdot 2 \\ 0.2 \cdot 1 + 0.3 \cdot 2\end{pmatrix} = \begin{pmatrix}0.7 \\ 0.8\end{pmatrix} a[1]=relu(z[1])=(0.70.8)(both positive, no change)\mathbf{a}^{[1]} = \text{relu}(\mathbf{z}^{[1]}) = \begin{pmatrix}0.7 \\ 0.8\end{pmatrix} \quad \text{(both positive, no change)} y^=w[2]a[1]=0.4×0.7+0.6×0.8=0.28+0.48=0.76\hat{y} = \mathbf{w}^{[2]} \cdot \mathbf{a}^{[1]} = 0.4 \times 0.7 + 0.6 \times 0.8 = 0.28 + 0.48 = 0.76 L=12(0.761)2=12(0.0576)=0.0288\mathcal{L} = \tfrac{1}{2}(0.76 - 1)^2 = \tfrac{1}{2}(0.0576) = 0.0288

Backward pass:

Output layer gradient:

Ly^=y^y=0.761=0.24\frac{\partial \mathcal{L}}{\partial \hat{y}} = \hat{y} - y = 0.76 - 1 = -0.24

Layer 2 gradients (scalar output, linear):

Lw[2]=Ly^a[1]=0.24×(0.70.8)=(0.1680.192)\frac{\partial \mathcal{L}}{\partial \mathbf{w}^{[2]}} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \mathbf{a}^{[1]} = -0.24 \times \begin{pmatrix}0.7 \\ 0.8\end{pmatrix} = \begin{pmatrix}-0.168 \\ -0.192\end{pmatrix} Lb[2]=0.24\frac{\partial \mathcal{L}}{\partial b^{[2]}} = -0.24

Error signal propagated to layer 1:

La[1]=Ly^w[2]=0.24×(0.40.6)=(0.0960.144)\frac{\partial \mathcal{L}}{\partial \mathbf{a}^{[1]}} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \mathbf{w}^{[2]} = -0.24 \times \begin{pmatrix}0.4 \\ 0.6\end{pmatrix} = \begin{pmatrix}-0.096 \\ -0.144\end{pmatrix}

Through ReLU:

δ[1]=Lz[1]=La[1]relu(z[1])=(0.0960.144)(11)=(0.0960.144)\boldsymbol{\delta}^{[1]} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{[1]}} = \frac{\partial \mathcal{L}}{\partial \mathbf{a}^{[1]}} \odot \text{relu}'(\mathbf{z}^{[1]}) = \begin{pmatrix}-0.096 \\ -0.144\end{pmatrix} \odot \begin{pmatrix}1 \\ 1\end{pmatrix} = \begin{pmatrix}-0.096 \\ -0.144\end{pmatrix}

Layer 1 weight gradients:

LW[1]=δ[1]x=(0.0960.144)(1,2)=(0.0960.1920.1440.288)\frac{\partial \mathcal{L}}{\partial W^{[1]}} = \boldsymbol{\delta}^{[1]} \mathbf{x}^\top = \begin{pmatrix}-0.096 \\ -0.144\end{pmatrix}(1, 2) = \begin{pmatrix}-0.096 & -0.192 \\ -0.144 & -0.288\end{pmatrix}

Verification (finite difference for W11[1]W^{[1]}_{11}): Perturb W11[1]W^{[1]}_{11} by h=105h = 10^{-5}:

L(W11[1]+h)L(W11[1]h)2h(0.096)\mathcal{L}(W^{[1]}_{11} + h) - \mathcal{L}(W^{[1]}_{11} - h) \approx 2h \cdot (-0.096)

Numerically: $\mathcal{L}(0.50001) \approx 0.02879904$, $\mathcal{L}(0.49999) \approx 0.02880096$. Difference $/2h = -0.096$, matching the backprop value.
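The same numbers can be reproduced in a few lines of numpy (a direct transcription of the computation above):

import numpy as np

W1 = np.array([[0.5, 0.1], [0.2, 0.3]]); b1 = np.zeros(2)
w2 = np.array([0.4, 0.6]); b2 = 0.0
x, y = np.array([1.0, 2.0]), 1.0

# Forward
z1 = W1 @ x + b1
a1 = np.maximum(z1, 0.0)                  # ReLU
y_hat = w2 @ a1 + b2
loss = 0.5 * (y_hat - y) ** 2             # 0.0288

# Backward
d_yhat = y_hat - y                        # -0.24
grad_w2 = d_yhat * a1                     # [-0.168, -0.192]
delta1 = d_yhat * w2 * (z1 > 0)           # [-0.096, -0.144]
grad_W1 = np.outer(delta1, x)             # [[-0.096, -0.192], [-0.144, -0.288]]
print(loss, grad_w2, delta1, grad_W1)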

A.2 Computational Cost Comparison

Forward pass: O(n0n1+n1n2++nL1nL)O(n_0 n_1 + n_1 n_2 + \cdots + n_{L-1} n_L) - one GEMM per layer.

Backward pass: Also O(lnl1nl)O(\sum_l n_{l-1} n_l) - same asymptotic cost, with constant factor 23\approx 2-3.

Memory: Cache all z[l]\mathbf{z}^{[l]} and a[l]\mathbf{a}^{[l]}: O(lnl)O(\sum_l n_l) scalars - linear in total neuron count.

The fundamental theorem of backpropagation: Computing θL\nabla_\theta \mathcal{L} for all θ|\theta| parameters costs only a constant factor more than computing L\mathcal{L} itself. This is the miracle that makes gradient-based learning tractable.

Formal statement: Let TfT_f be the time to evaluate L\mathcal{L} in the forward pass. Then the time to compute all partial derivatives L/θi\partial \mathcal{L}/\partial \theta_i via backprop is at most cTfc \cdot T_f where c5c \leq 5 in practice.

This contrasts with finite differences: computing L/θi\partial \mathcal{L}/\partial \theta_i for each of θ|\theta| parameters via finite differences costs 2θ2|\theta| forward passes - for GPT-3 with θ=175B|\theta| = 175\text{B}, this would be 350350 billion forward passes, or approximately the heat death of the universe in compute time.


Appendix B: JVP vs VJP - Mode Selection and Complexity

B.1 Forward Mode vs Reverse Mode

Given f:RnRmf: \mathbb{R}^n \to \mathbb{R}^m, both modes compute the same gradient information but with different costs:

| Mode | Computes | Cost per pass | Total cost for full Jacobian |
|------|----------|---------------|------------------------------|
| Forward (JVP) | One column of $J_f$ | $O(T_f)$ | $O(n \cdot T_f)$ |
| Reverse (VJP) | One row of $J_f$ | $O(T_f)$ | $O(m \cdot T_f)$ |
COST MATRIX: WHICH MODE WINS?


  Goal: compute ∂ℒ/∂θ for ℒ: R^n -> R (scalar loss)

  n = |θ| = 175,000,000,000   (GPT-3 parameter count)
  m = 1                       (scalar loss)

  Forward mode: n x Tf = 175B x Tf   <- 175 BILLION PASSES
  Reverse mode: m x Tf = 1 x Tf      <- ONE PASS

  Reverse mode (backprop) needs a single pass because m = 1: the
  gradient is the one and only row of J_f. Forward mode would need
  n = 175B passes to fill in all the columns.

    RULE: Use reverse mode (backprop) when m << n
    RULE: Use forward mode (JVP) when n << m

  Most ML: n >> m = 1  ->  backprop is optimal


When forward mode wins: Computing the sensitivity of all mm outputs to one input parameter - e.g., computing how the entire model output changes as a single hyperparameter varies. Also: Jacobian-vector products in conjugate gradient (no need for the full Jacobian).

Mixed strategies: For functions f:RnRmf: \mathbb{R}^n \to \mathbb{R}^m with nmn \approx m, the optimal choice is to split the Jacobian into row/column blocks and use each mode for the appropriate blocks - the basis of adjoint methods in numerical PDE solvers.

B.2 Tangent Mode for Hessian-Vector Products

As shown in 02, the Hessian-vector product HvH\mathbf{v} can be computed by composing forward and reverse modes:

Hv=Jgvwhereg=θLH\mathbf{v} = J_\mathbf{g}^\top \mathbf{v} \quad \text{where} \quad \mathbf{g} = \nabla_\theta \mathcal{L}

Algorithm (Pearlmutter's R-operator trick, 1994):

  1. Forward pass (JVP with direction v\mathbf{v}): compute g=L\mathbf{g} = \nabla \mathcal{L} and g˙=Jgv\dot{\mathbf{g}} = J_\mathbf{g} \mathbf{v} simultaneously
  2. Cost: same as backprop (O(Tf)O(T_f)) - one pass suffices

Implementation in PyTorch:

import torch

# loss: scalar tensor; params: list of tensors with requires_grad=True; v: flat vector of matching length
g = torch.autograd.grad(loss, params, create_graph=True)
flat_g = torch.cat([gi.view(-1) for gi in g])
hvp = torch.autograd.grad(flat_g @ v, params)   # Hessian-vector product, one tensor per parameter

Cost: 2 backprop passes, no n×nn \times n matrix formed. This is the primitive for:

  • Conjugate gradient for Newton steps (K-FAC-style)
  • Lanczos iteration for λmax\lambda_\text{max} of Hessian
  • Eigenvalue monitoring during training (Cohen et al., 2022 - edge of stability)

Appendix C: Automatic Differentiation Preview

C.1 The AD Abstraction

Automatic differentiation (AD) is a mechanical procedure for transforming any program that computes f(x)f(\mathbf{x}) into a program that also computes f(x)\nabla f(\mathbf{x}) (or JVPs/VJPs). This section previews the idea; the full treatment is in 05.

AD is neither symbolic differentiation (too slow, exponentially large expressions) nor numerical differentiation (finite differences - too imprecise, costs O(n)O(n) evaluations). AD exploits the fact that every program is a composition of primitives, and the chain rule tells us exactly how to compose their derivatives.

Two flavours:

SYMBOLIC DIFF           NUMERICAL DIFF          AUTO DIFF


  f(x) = x^2 + sin(x)    Compute f(x+h)          Track ops in
                          and f(x-h)              computation tape

  -> d/dx = 2x+cos(x)    -> (f(x+h)-f(x-h))/2h    -> Exact as FP allows

  Exact but expression   Approximate; costs       Exact, costs O(1)
  size can explode       O(n) evaluations          evaluations


C.2 The Tape (Wengert List)

The Wengert list (1964) records, during the forward pass, every primitive operation applied and its operands. The backward pass replays this tape in reverse, accumulating adjoints.

FORWARD TAPE EXAMPLE: f(x) = exp(x) * (x + 1)


  Tape (built during forward):
    v_1 = x              (input)
    v_2 = exp(v_1)       (op: exp,  operand: v_1)
    v_3 = v_1 + 1.0      (op: add,  operands: v_1, 1.0)
    v_4 = v_2 * v_3      (op: mul,  operands: v_2, v_3)

  Backward (replay in reverse; v̄ denotes the adjoint of v):
    v̄_4 = 1.0                               (seed)
    v̄_2 += v̄_4 * v_3 = 1.0 * (x+1)          (mul backward)
    v̄_3 += v̄_4 * v_2 = 1.0 * exp(x)         (mul backward)
    v̄_1 += v̄_3 * 1.0 = exp(x)               (add backward)
    v̄_1 += v̄_2 * exp(v_1) = (x+1)exp(x)     (exp backward)

  Total: v̄_1 = exp(x) + (x+1)exp(x) = (x+2)exp(x)   (matches the product rule)


PyTorch's Tensor stores a grad_fn attribute at each node - this is the tape in disguise. Calling .backward() replays the tape in reverse.

For more: See 05 Automatic Differentiation for the complete treatment of forward/reverse mode AD, source transformation, operator overloading, and the design of JAX vs PyTorch autograd.


Appendix D: Numerical Gradient Verification

In practice, every backpropagation implementation should be verified against finite differences. This appendix presents the standard toolkit.

D.1 Centred Finite Differences

For a scalar loss L:RnR\mathcal{L}: \mathbb{R}^n \to \mathbb{R} and parameter θi\theta_i:

LθiL(θ+hei)L(θhei)2h\frac{\partial \mathcal{L}}{\partial \theta_i} \approx \frac{\mathcal{L}(\boldsymbol{\theta} + h\mathbf{e}_i) - \mathcal{L}(\boldsymbol{\theta} - h\mathbf{e}_i)}{2h}

Error analysis: Centred differences have O(h2)O(h^2) error (vs O(h)O(h) for forward differences). Optimal step size hh balances truncation error (O(h2)O(h^2)) against floating-point cancellation error (O(ϵ/h)O(\epsilon/h) where ϵ\epsilon is machine epsilon):

hoptϵ1/3(1016)1/3105(for float64)h_\text{opt} \approx \epsilon^{1/3} \approx (10^{-16})^{1/3} \approx 10^{-5} \quad \text{(for float64)}

Use h=105h = 10^{-5} for float64 and h=103h = 10^{-3} for float32.

Relative error check: Accept the gradient check if:

backpropFDbackprop+FD<107(float64)\frac{\|\nabla_\text{backprop} - \nabla_\text{FD}\|}{\|\nabla_\text{backprop}\| + \|\nabla_\text{FD}\|} < 10^{-7} \quad \text{(float64)}
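A minimal numpy helper implementing this check (a sketch; function and parameter names are illustrative):

import numpy as np

def grad_check(f, grad_f, theta, h=1e-5, tol=1e-7):
    """Compare centred finite differences against an analytic gradient (float64 assumed)."""
    fd = np.zeros_like(theta)
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e[i] = h
        fd[i] = (f(theta + e) - f(theta - e)) / (2 * h)
    g = grad_f(theta)
    rel = np.linalg.norm(g - fd) / (np.linalg.norm(g) + np.linalg.norm(fd) + 1e-30)
    return rel < tol, rel

# Example: L(theta) = sum(theta^2), analytic gradient 2*theta
ok, err = grad_check(lambda t: (t ** 2).sum(), lambda t: 2 * t, np.random.randn(5))
print(ok, err)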

D.2 When Gradient Checks Fail

Common failure modes:

| Symptom | Likely cause |
|---------|--------------|
| Relative error $\sim 10^{-3}$ throughout | $h$ too large (truncation) or float32 precision |
| Relative error $\sim 1$ for specific parameters | Bug in the backward for that parameter type |
| Relative error $\sim 0$ for all gradients | Loss is approximately linear in those parameters at the test point |
| Fails at a kink (ReLU/max) | Gradient not defined at $z = 0$; test point near the kink; choose $\mathbf{x}$ away from kinks |
| Fails only for batch size 1 | BatchNorm statistics degenerate; use batch size $\geq 2$ for BN checks |

D.3 Gradient Check in PyTorch

import torch
from torch.autograd import gradcheck

def f(x):
    return (x ** 2).sum()

x = torch.randn(5, requires_grad=True, dtype=torch.float64)
gradcheck(f, (x,), eps=1e-6, atol=1e-4, rtol=1e-4)

gradcheck automates centred finite differences for all inputs with requires_grad=True. Always use dtype=torch.float64 for gradient checking - float32 precision is insufficient for reliable checks.


Appendix E: Key Formulas Reference

E.1 Chain Rule Summary

| Setting | Formula |
|---------|---------|
| Scalar composition | $(f \circ g)'(x) = f'(g(x)) \cdot g'(x)$ |
| Vector composition | $J_{f \circ g}(\mathbf{x}) = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x})$ |
| VJP (backprop step) | $\bar{\mathbf{x}} = J_g(\mathbf{x})^\top \bar{\mathbf{y}}$ |
| JVP (forward step) | $\dot{\mathbf{y}} = J_g(\mathbf{x}) \dot{\mathbf{x}}$ |

E.2 Backpropagation Formulas

| Layer | Forward | Backward (given the upstream gradient) |
|-------|---------|----------------------------------------|
| Linear | $\mathbf{z} = W\mathbf{x} + \mathbf{b}$ | $\bar{\mathbf{x}} = W^\top \bar{\mathbf{z}}$, $\bar{W} = \bar{\mathbf{z}}\mathbf{x}^\top$, $\bar{\mathbf{b}} = \bar{\mathbf{z}}$ |
| Elementwise | $\mathbf{a} = \sigma(\mathbf{z})$ | $\bar{\mathbf{z}} = \bar{\mathbf{a}} \odot \sigma'(\mathbf{z})$ |
| Softmax+CE | $p_i = e^{z_i}/\sum_j e^{z_j}$ | $\partial \mathcal{L}/\partial \mathbf{z} = \mathbf{p} - \mathbf{e}_y$ |
| Residual | $\mathbf{y} = \mathbf{x} + F(\mathbf{x})$ | $\bar{\mathbf{x}} = \bar{\mathbf{y}} + J_F^\top \bar{\mathbf{y}}$ |
| LayerNorm | $\hat{x}_i = (x_i - \mu)/\sigma$ | Complex (see 5.4); passes signal |

E.3 Activation Derivatives

| Name | $\sigma(z)$ | $\sigma'(z)$ |
|------|-------------|--------------|
| ReLU | $\max(0, z)$ | $\mathbf{1}[z > 0]$ |
| Sigmoid | $1/(1 + e^{-z})$ | $\sigma(1 - \sigma)$ |
| Tanh | $(e^z - e^{-z})/(e^z + e^{-z})$ | $1 - \tanh^2$ |
| GELU | $z\Phi(z)$ | $\Phi(z) + z\phi(z)$ |
| SiLU | $z/(1 + e^{-z})$ | $\sigma(z)(1 + z(1 - \sigma(z)))$ |
| Softplus | $\log(1 + e^z)$ | $\sigma(z)$ |

E.4 Initialisation Standards

| Method | Distribution | Variance | When |
|--------|--------------|----------|------|
| Xavier uniform | $U[-a, a]$ | $2/(n_\text{in} + n_\text{out})$ | Sigmoid, tanh |
| Xavier normal | $\mathcal{N}(0, \sigma^2)$ | $2/(n_\text{in} + n_\text{out})$ | Sigmoid, tanh |
| He uniform | $U[-a, a]$ | $2/n_\text{in}$ | ReLU |
| He normal | $\mathcal{N}(0, \sigma^2)$ | $2/n_\text{in}$ | ReLU |
| GPT-2 residual | $\mathcal{N}(0, (0.02/\sqrt{2L})^2)$ | - | Transformer residual projections |

Appendix F: Deep Dive - Vanishing Gradients in Transformers

F.1 Why Transformers Don't Vanish

A naive reading of the vanishing gradient analysis (6.1) suggests that 96-layer transformers should suffer catastrophic vanishing. They don't. Here is why.

The residual stream analysis: In a pre-norm transformer, the residual stream after layer ll is:

x[l]=x[0]+k=1lF[k](x[k1])\mathbf{x}^{[l]} = \mathbf{x}^{[0]} + \sum_{k=1}^{l} F^{[k]}(\mathbf{x}^{[k-1]})

where F[k]F^{[k]} is the kk-th sublayer (attention or MLP, wrapped in LayerNorm).

The gradient of the loss with respect to the input is:

Lx[0]=Lx[L]l=1L(I+JF[l])\frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[0]}} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}^{[L]}} \cdot \prod_{l=1}^{L}\left(I + J_{F^{[l]}}\right)

At initialisation, the transformer weights are small, so JF[l]0J_{F^{[l]}} \approx 0 and the product (I+JF[l])I\prod(I + J_{F^{[l]}}) \approx I. The gradient flows back unchanged through all LL layers. This is categorically different from a plain deep network where the product of small Jacobians vanishes.
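A small numpy simulation of this effect (illustrative: depth 96, width 64, random "sublayer" Jacobians with small scale standing in for the initialisation regime):

import numpy as np

rng = np.random.default_rng(0)
d, L = 64, 96
g = np.ones(d)                                   # stand-in upstream gradient
plain, resid = g.copy(), g.copy()

for _ in range(L):
    J = 0.1 * rng.standard_normal((d, d)) / np.sqrt(d)   # small sublayer Jacobian at init
    plain = J.T @ plain                          # plain deep net: product of small Jacobians
    resid = resid + J.T @ resid                  # residual block: (I + J)^T preserves the signal

print(np.linalg.norm(plain), np.linalg.norm(resid))   # plain collapses toward 0, resid stays O(||g||)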

Gradient norm growth: As training progresses and weights grow, JF[l]J_{F^{[l]}} becomes nontrivial. The gradient norm may grow with depth, but this is controlled by:

  1. LayerNorm dampening (see 6.6)
  2. GPT-2's 1/2L1/\sqrt{2L} scaling of residual projections
  3. Gradient clipping (τ=1.0\tau = 1.0)

The "edge of stability" phenomenon (Cohen et al., 2022): In practice, the maximum Hessian eigenvalue λmax\lambda_\text{max} often approaches 2/η2/\eta (twice the inverse learning rate) and oscillates there. This is a gradient flow regime where the training dynamics are neither fully stable nor unstable, and gradients are large enough to cause oscillation but not divergence.

F.2 Gradient Norm as Training Signal

Modern LLM training monitors gradient norm at every step. Typical patterns:

GRADIENT NORM DURING LLM TRAINING

[Figure: ‖∇_θ ℒ‖₂ plotted against training steps - mostly flat below the clip threshold of 1,
with occasional sharp spikes that coincide with loss spikes before recovering.]

  Patterns:
  - Steady ‖∇_θ ℒ‖ < 1: healthy training, clipping inactive
  - Sudden spike -> loss spike -> recovery: numerical event
    (often a "bad" batch; LLM training sees roughly 1-3 such events
     per trillion tokens at scale)
  - Slow upward drift: learning rate may be too high

Loss spike mitigation: When the gradient norm exceeds the clip threshold, the entire gradient update is scaled down. If the spike is from a corrupted batch, this prevents permanent damage to the model weights.

Gradient accumulation and norm: When using GG accumulation steps, each micro-batch contributes 1/G1/G of the gradient. The global norm is computed across the accumulated gradient (after summation, before the optimiser step) - not across individual micro-batches.
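A short sketch of where the global-norm clip sits in an accumulation loop (assumes a standard PyTorch setup with `model` and `optimizer` already defined):

import torch

# ... after the G-th micro-batch's backward(), gradients are fully accumulated ...
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # one global clip, returns the pre-clip norm
optimizer.step()
optimizer.zero_grad()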

F.3 Per-Layer Gradient Norm Analysis

For diagnostic purposes, logging the gradient norm per layer reveals:

  • Embedding gradients: Often the largest, due to sparse updates (5.6)
  • Early layers: Smallest (furthest from loss); potential vanishing
  • Late layers: Largest; potential exploding
  • LayerNorm parameters: Very small - γ\boldsymbol{\gamma} and β\boldsymbol{\beta} converge quickly

This per-layer analysis guided the design of:

  • LARS/LAMB optimisers (You et al., 2017; 2019): layer-wise adaptive learning rates based on the weight-to-gradient norm ratio
  • Muon (2024): orthogonalises the Nesterov-momentum update for 2-D weight matrices via Newton-Schulz iterations; used for hidden layers while AdamW handles the embedding and output layers

Appendix G: Historical Development

G.1 Timeline of Backpropagation

The development of backpropagation spans three centuries and multiple independent discoveries:

| Year | Event | Significance |
|------|-------|--------------|
| 1676 | Leibniz publishes differential calculus (chain rule for a single variable) | Mathematical foundation |
| 1744 | Euler uses variational methods (antecedent of reverse mode) | First "adjoint" idea |
| 1847 | Cauchy introduces gradient descent | The algorithm backprop serves |
| 1960 | Kalman filter (reverse-mode ideas for linear dynamical systems) | AD in engineering |
| 1964 | Wengert publishes the evaluation-trace ("Wengert list") formulation of AD | First explicit (forward-mode) AD |
| 1970 | Linnainmaa's thesis: general reverse accumulation | Full theoretical framework |
| 1974 | Werbos's PhD thesis: backprop for neural networks | Connection to ML |
| 1982 | Hopfield networks (energy-based models with gradient dynamics) | Alternative to backprop |
| 1986 | Rumelhart, Hinton & Williams, "Learning representations by back-propagating errors" | Popularised backprop for NNs |
| 1991 | Hochreiter analyses the vanishing gradient problem | Identified the depth barrier |
| 1997 | LSTM: gating to address vanishing gradients in RNNs | First scalable deep sequence model |
| 2012 | AlexNet: backprop on GPUs at scale | Practical deep learning |
| 2015 | ResNets: residual connections for gradient flow | Enabled 100+ layer networks |
| 2016 | PyTorch / TensorFlow 1.0: autodiff frameworks | Democratised backprop |
| 2017 | Transformers: attention replaces recurrence and BPTT | Solved long-range vanishing |
| 2018 | JAX: functional autodiff, JIT compilation | Research-grade AD |
| 2022 | FlashAttention: IO-aware backward pass | Efficient $O(T^2)$ attention backward |
| 2022 | PyTorch 2.0 torch.compile | Graph-based kernel fusion |
| 2023 | FlashAttention-2: improved GPU utilisation | Standard for production |
| 2024 | FlashAttention-3: H100-optimised with async execution | State-of-the-art attention backward |

G.2 The Independent Discoveries

Backpropagation was independently discovered at least four times before becoming widely known:

  1. Linnainmaa (1970): In his master's thesis, presented the general algorithm for computing exact partial derivatives of any function composed of elementary operations - precisely what we today call reverse-mode AD.

  2. Werbos (1974): Applied the same idea to multi-layer neural networks in his PhD thesis, but the work was largely ignored for over a decade.

  3. Parker (1985): Independently rediscovered backpropagation for neural networks.

  4. Rumelhart, Hinton & Williams (1986): Published the algorithm in Nature and produced the critical experimental demonstrations that convinced the community it could work. Their paper is the one most often cited today.

This pattern of independent rediscovery is common in mathematics - the ideas are "in the air" once the prerequisites are established. The chain rule (1676) + computation graphs (1960s) + gradient descent (1847) = backpropagation (inevitable).

G.3 The Hardware-Algorithm Co-evolution

The practical impact of backpropagation depends critically on hardware:

  • CPU era (1986-2011): Backprop is theoretically valid but computationally slow. Networks with more than 3-4 layers were impractical.
  • GPU era (2012-present): NVIDIA's CUDA (2007) enables massively parallel GEMM operations. The bottleneck shifts from FLOPS to memory bandwidth.
  • Tensor core era (2017-present): NVIDIA Volta/Ampere/Hopper GPUs have dedicated matrix multiply accelerators. FP16/BF16 tensor cores achieve 10x the throughput of FP32.
  • Memory wall: As models scale, the backward pass's memory requirements dominate. FlashAttention, ZeRO, gradient checkpointing all address the memory wall.

The 2024 FLOP/memory ratio in H100 GPUs (2000\sim 2000 TFLOPS vs 3.35\sim 3.35 TB/s bandwidth) means that memory access, not computation, is the primary bottleneck for backprop at scale. This fundamental constraint is why FlashAttention's IO-aware design is so impactful.


Appendix H: Connections to Optimisation and Learning Theory

H.1 What the Gradient Tells Us

The gradient θL\nabla_\theta \mathcal{L} computed by backpropagation is the direction of steepest ascent in parameter space (by the first-order Taylor expansion). Gradient descent moves in the opposite direction:

θθηθL(θ)\theta \leftarrow \theta - \eta \nabla_\theta \mathcal{L}(\theta)

What the gradient does NOT tell us:

  • The curvature of the loss landscape (need Hessian for that)
  • The optimal step size η\eta
  • Whether we are near a local minimum, saddle point, or maximum
  • Whether the gradient is statistically well-estimated (needs large enough batch)

What the gradient DOES tell us:

  • The direction of maximal increase (used negated for descent)
  • The sensitivity of the loss to each parameter
  • Which parameters are "active" (nonzero gradient) vs. saturated (near-zero gradient)

H.2 Gradient Stochasticity

In practice, the true gradient θE[L]\nabla_\theta \mathbb{E}[\mathcal{L}] over the full data distribution is approximated by the stochastic gradient over a mini-batch:

g^B=1Bb=1BθL(x(b),y(b))\hat{g}_B = \frac{1}{B}\sum_{b=1}^B \nabla_\theta \mathcal{L}(\mathbf{x}^{(b)}, \mathbf{y}^{(b)})

This is an unbiased estimator: E[g^B]=θE[L]\mathbb{E}[\hat{g}_B] = \nabla_\theta \mathbb{E}[\mathcal{L}].

Variance: Var(g^B)=Var(θL)/B\text{Var}(\hat{g}_B) = \text{Var}(\nabla_\theta \mathcal{L})/B. Larger batches have lower gradient variance (more accurate gradient estimate) but provide diminishing returns beyond the "critical batch size" (McCandlish et al., 2018).

For LLMs: The critical batch size for GPT-3-scale models is approximately B14B^* \approx 1-4 million tokens. Training at this batch size achieves the best loss-per-FLOP tradeoff. Using larger batches wastes compute; using smaller batches wastes gradient estimation quality.

H.3 The Gradient as a Sufficient Statistic

For first-order optimisers (SGD, Adam, AdaGrad, RMSprop), the gradient is the only information extracted from the forward-backward pass. Second-order information (Hessian curvature) is either ignored or approximated.

Why not use the full Hessian? For $|\theta| = 70\text{B}$ parameters, the Hessian is a $70\text{B} \times 70\text{B}$ matrix - roughly $5 \times 10^{21}$ entries. Storing it is impossible ($5 \times 10^{21}$ FP32 values $\approx 2 \times 10^{22}$ bytes, about 20 zettabytes). Inverting it is even more impossible.

Practical second-order methods use approximations:

  • Diagonal: AdaGrad/Adam maintain diagonal Hessian approximations (O(θ)O(|\theta|) memory)
  • Kronecker factored: K-FAC (see 02) uses AGA \otimes G per layer (O(n2)O(n^2) per layer)
  • Low-rank: PSGD, Shampoo maintain low-rank or block-diagonal approximations
  • Newton-Schulz: Muon (2024) uses Newton-Schulz iterations to orthogonalise the update matrix (an approximate polar factor), avoiding any explicit curvature matrix

H.4 Generalisation and the Implicit Gradient Bias

Gradient descent with small learning rate and large mini-batches does not merely find any minimum - it has an implicit bias toward flat minima (large regions with low loss) over sharp minima (narrow valleys).

Conjecture (Keskar et al., 2017): Flat minima generalise better because small perturbations to the parameters don't change the loss much - robust to noise in the data.

Heuristic foundation: The SGD noise g^Bg\hat{g}_B - g effectively adds a regularisation term on the order of ηB1tr(H)\eta B^{-1} \text{tr}(H) - the trace of the Hessian - biasing the trajectory toward flat (low-trace-Hessian) minima.

This connects gradient computation (the topic of this section) to generalisation theory (a major open question in deep learning theory) - a reminder that understanding backpropagation fully requires understanding not just the mechanics, but the geometry of the loss landscape it navigates.


Appendix I: Practical Implementation Guide

I.1 Implementing Backprop from Scratch

When building a neural network framework from scratch, implement these components in order:

1. Primitive registry:

primitives = {}

def register_primitive(name, forward_fn, backward_fn):
    """Register a primitive op with its VJP."""
    primitives[name] = (forward_fn, backward_fn)

# Example: multiplication primitive
def mul_forward(x, y): return x * y
def mul_backward(x, y, g_out): return g_out * y, g_out * x  # (g_x, g_y)
register_primitive('mul', mul_forward, mul_backward)

2. Value class with gradient tracking:

class Value:
    def __init__(self, data, parents=(), op=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None  # closure capturing parents
        self._parents = parents
        self._op = op
    
    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), 'mul')
        def _backward():
            self.grad += other.data * out.grad   # VJP for self
            other.grad += self.data * out.grad   # VJP for other
        out._backward = _backward
        return out
    
    def backward(self):
        # Topological sort, then reverse
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for p in v._parents: build_topo(p)
                topo.append(v)
        build_topo(self)
        self.grad = 1.0
        for v in reversed(topo): v._backward()

This is essentially the complete autograd engine from Karpathy's micrograd (2020) - approximately 100 lines implement a working backprop engine.

3. Building blocks: Extend Value with __add__, __pow__, exp, log, relu, softmax - each with its VJP closure.
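A quick sanity check of the Value class above (f(x) = x · x, so df/dx = 2x; this also exercises gradient accumulation at a fan-out, since `a` is used twice):

a = Value(3.0)
b = a * a          # fan-out: `a` appears twice, so its adjoint accumulates via +=
b.backward()
print(b.data, a.grad)   # 9.0 6.0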

I.2 Common Implementation Bugs

Bug 1: Overwriting instead of accumulating gradients

# Wrong:
self.grad = other.data * out.grad   # erases previous contributions!

# Correct:
self.grad += other.data * out.grad  # accumulates (fan-out nodes)

Bug 2: Forgetting to zero gradients between batches

# Wrong: gradient accumulates across batches
loss = model(x)
loss.backward()
optimizer.step()

# Correct:
optimizer.zero_grad()   # <- must come before backward
loss = model(x)
loss.backward()
optimizer.step()

Bug 3: Not disabling graph construction for inference

# Wrong: builds the autograd graph unnecessarily during inference
prediction = model(x)

# Correct: disable graph construction
with torch.no_grad():
    prediction = model(x)

Bug 4: Shape mismatch in weight gradient

# Wrong: for 1-D vectors, delta @ x is a dot product (or a shape error), not an outer product
grad_W = delta @ x

# Correct: outer product for single sample
grad_W = np.outer(delta, x)  # (n_out, n_in)

# Correct: batched
grad_W = (1/B) * Delta @ X.T  # (n_out, B) @ (B, n_in) = (n_out, n_in)

I.3 Testing Checklist

Before deploying any backprop implementation:

  • Gradient check passes for all primitive operations (relative error <106< 10^{-6})
  • Loss decreases monotonically for small enough learning rate (verify on toy problem)
  • Gradients are zero for frozen parameters
  • Gradient accumulation at fan-out nodes verified (shared weight receives sum)
  • Shape of each gradient matches shape of corresponding parameter
  • Memory usage is O(num_layers)O(\text{num\_layers}) not O(num_layers2)O(\text{num\_layers}^2)
  • Higher-order gradients work if needed (use create_graph=True in PyTorch)
  • Mixed precision: FP16 forward, FP32 gradient accumulation, loss scaling in place

Appendix J: Connections to Information Theory and Statistics

J.1 Fisher Information and the Natural Gradient

The ordinary gradient θL\nabla_\theta \mathcal{L} measures the steepest direction in parameter space with respect to the Euclidean metric. But parameter space has a natural metric induced by the probability distribution pθp_\theta - the Fisher information metric.

Fisher information matrix:

F(θ)=Expθ[θlogpθ(x)(θlogpθ(x))]F(\theta) = \mathbb{E}_{x \sim p_\theta}\left[\nabla_\theta \log p_\theta(x) \, (\nabla_\theta \log p_\theta(x))^\top\right]

Natural gradient (Amari, 1998):

~θL=F(θ)1θL\tilde{\nabla}_\theta \mathcal{L} = F(\theta)^{-1} \nabla_\theta \mathcal{L}

The natural gradient is the steepest direction in the distributional geometry of the model - invariant to reparametrisation. Computing it exactly requires inverting FF, which costs O(θ3)O(|\theta|^3).

K-FAC (02) approximates F1F^{-1} as a Kronecker product, making the natural gradient step tractable. It remains the most principled second-order optimiser for neural networks.

For LLMs: The approximation used in practice is Adam's diagonal F1F^{-1} (second moment of gradient as proxy for diagonal Fisher). This is crude but sufficient - Adam is a diagonal natural gradient step.
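A minimal numpy sketch of the diagonal idea (illustrative only - not Adam's exact update, which also keeps a first-moment estimate and bias corrections):

import numpy as np

def diag_fisher_step(theta, grad, v, lr=1e-3, beta2=0.999, eps=1e-8):
    # v tracks a running mean of grad**2 - a diagonal proxy for the Fisher / curvature
    v = beta2 * v + (1 - beta2) * grad ** 2
    theta = theta - lr * grad / (np.sqrt(v) + eps)   # precondition by the diagonal inverse square root
    return theta, v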

J.2 Gradient as Score Function

For a probabilistic model pθ(x)p_\theta(\mathbf{x}), the gradient of the log-likelihood is the score function:

s(x;θ)=θlogpθ(x)s(\mathbf{x}; \theta) = \nabla_\theta \log p_\theta(\mathbf{x})

The score function is the quantity computed by backpropagation during maximum likelihood estimation. Properties:

  • Expθ[s(x;θ)]=0\mathbb{E}_{x \sim p_\theta}[s(\mathbf{x};\theta)] = 0 (score has zero mean)
  • Varxpθ[s(x;θ)]=F(θ)\text{Var}_{x \sim p_\theta}[s(\mathbf{x};\theta)] = F(\theta) (Fisher information = variance of score)

For language models: The negative log-likelihood $\mathcal{L} = -\log p_\theta(\mathbf{y}|\mathbf{x})$ has logit gradient $\partial \mathcal{L}/\partial \mathbf{z} = \mathbf{p} - \mathbf{e}_y$ - the same formula from 5.3, now understood as the negative of the score $\mathbf{e}_y - \mathbf{p}$.
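A quick numpy check of this identity (toy logits; finite differences stand in for the analytic gradient):

import numpy as np

def nll(z, y):
    z = z - z.max()                       # log-sum-exp stabilisation
    return -(z[y] - np.log(np.exp(z).sum()))

z, y, h = np.array([2.0, -1.0, 0.5]), 0, 1e-6
p = np.exp(z - z.max()); p /= p.sum()
e_y = np.eye(3)[y]

fd = np.array([(nll(z + h * np.eye(3)[i], y) - nll(z - h * np.eye(3)[i], y)) / (2 * h) for i in range(3)])
print(np.allclose(fd, p - e_y, atol=1e-6))   # True: d(-log p_y)/dz = p - e_y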

J.3 KL Divergence and the Gradient of ELBO

In variational inference and RL (RLHF), we often need gradients of KL divergences. For discrete distributions:

KL(pq)=xp(x)logp(x)q(x)\text{KL}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)} ϕKL(pθqϕ)=Expθ[ϕlogqϕ(x)]\nabla_\phi \text{KL}(p_\theta \| q_\phi) = -\mathbb{E}_{x \sim p_\theta}\left[\nabla_\phi \log q_\phi(x)\right]

This is computed via backprop through the log-probability of the policy under KL regularisation - the precise form used in RLHF's PPO loss, which includes a KL penalty between the fine-tuned policy πϕ\pi_\phi and the reference model πref\pi_\text{ref}.


References

  1. Rumelhart, Hinton & Williams (1986) - "Learning representations by back-propagating errors." Nature, 323, 533-536. The canonical backpropagation paper.

  2. Linnainmaa, S. (1970) - "The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors." Master's thesis, University of Helsinki. First general reverse-mode AD.

  3. Hochreiter, S. (1991) - "Untersuchungen zu dynamischen neuronalen Netzen." Diploma thesis, TU Munich. First analysis of vanishing gradients.

  4. Glorot, X. & Bengio, Y. (2010) - "Understanding the difficulty of training deep feedforward neural networks." AISTATS. Xavier initialisation.

  5. He, K. et al. (2015) - "Delving Deep into Rectifiers." ICCV. He initialisation for ReLU networks.

  6. He, K. et al. (2016) - "Deep Residual Learning for Image Recognition." CVPR. ResNets and gradient highways.

  7. Ba, J. et al. (2016) - "Layer Normalization." arXiv:1607.06450. LayerNorm for transformers.

  8. Vaswani, A. et al. (2017) - "Attention Is All You Need." NeurIPS. Transformer architecture with attention backward pass.

  9. Amari, S. (1998) - "Natural Gradient Works Efficiently in Learning." Neural Computation. Natural gradient and Fisher information.

  10. Martens, J. & Grosse, R. (2015) - "Optimizing Neural Networks with Kronecker-factored Approximate Curvature." ICML. K-FAC.

  11. Hu, E. et al. (2022) - "LoRA: Low-Rank Adaptation of Large Language Models." ICLR. LoRA backward pass.

  12. Dao, T. et al. (2022) - "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS. IO-aware backward for attention.

  13. Cohen, J. et al. (2022) - "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability." ICLR. Edge of stability phenomenon.

  14. Dao, T. (2023) - "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. FlashAttention-2.

  15. Liu, S. et al. (2024) - "DoRA: Weight-Decomposed Low-Rank Adaptation." DoRA backward analysis.


Appendix K: Summary Tables

K.1 Backpropagation Algorithm Summary

COMPLETE BACKPROPAGATION ALGORITHM


  INPUT: network weights θ, training pair (x, y)

  PHASE 1 - FORWARD PASS

  a^0 = x
  For l = 1, 2, ..., L:
    z^l = W^l a^{l-1} + b^l            (cache z^l and a^{l-1})
    a^l = σ^l(z^l)                     (cache a^l)
  ŷ = a^L
  ℒ = loss(ŷ, y)

  PHASE 2 - BACKWARD PASS

  δ^L = ∂ℒ/∂z^L                        (output layer gradient, layer-specific)
  For l = L-1, L-2, ..., 1:
    δ^l = (W^{l+1})^T δ^{l+1} ⊙ σ'^l(z^l)

  PHASE 3 - GRADIENT ASSEMBLY

  For l = 1, 2, ..., L:
    ∇_{W^l} ℒ = δ^l (a^{l-1})^T
    ∇_{b^l} ℒ = δ^l

  PHASE 4 - PARAMETER UPDATE

  θ <- θ - η ∇_θ ℒ                     (or Adam/RMSprop update)


K.2 Complexity Summary

| Operation | Time | Memory |
|-----------|------|--------|
| Forward pass ($L$ layers, width $n$) | $O(Ln^2)$ | $O(Ln)$ cached activations |
| Backward pass | $O(Ln^2)$ | $O(Ln)$ error signals |
| Full Jacobian via finite differences | $O(\lvert\theta\rvert \cdot T_f)$ | $O(1)$ |
| Full Jacobian via backprop | $O(m \cdot T_f)$ | $O(L)$ |
| Hessian-vector product | $O(T_f)$ | $O(L)$ |
| Gradient checkpointing | $O(1.33\, T_f)$ | $O(\sqrt{L})$ |
| FlashAttention forward | $O(T^2 d)$ | $O(T)$ |
| FlashAttention backward | $O(T^2 d)$ | $O(T)$ |

K.3 Gradient Flow Interventions

| Problem | Diagnosis | Intervention |
|---------|-----------|--------------|
| Vanishing gradients | $\lVert\boldsymbol{\delta}^{[1]}\rVert \ll 1$ | ReLU/GELU, He init, residual connections |
| Exploding gradients | $\lVert\boldsymbol{\delta}^{[1]}\rVert \gg 1$ | Gradient clipping, LR warmup |
| Dead neurons | $\lVert\boldsymbol{\delta}^{[l]}\rVert = 0$ for a layer | Leaky ReLU, better init, BN |
| Slow convergence | $\lVert\nabla_\theta \mathcal{L}\rVert \approx 0$ at a saddle | Momentum, Adam, noise injection |
| Oscillating loss | $\lVert\nabla_\theta \mathcal{L}\rVert$ spikes | Reduce LR, increase batch size |
| NaN gradients | $\lVert\nabla_\theta \mathcal{L}\rVert = \infty$ | Loss scaling; check log/softmax inputs |