Training at Scale: Math for LLMs

This page covers Part 1 (Scale as a Constraint Problem) through Part 5 (Parallelism Strategies).

1. Scale as a Constraint Problem

This part treats scale itself as a practical mathematical constraint in LLM training. The goal is not to memorize infrastructure names, but to understand the formulas that determine whether a run fits, learns, communicates, and resumes.

| Subtopic | Operational question | Formula |
| --- | --- | --- |
| The same loss at larger cost | Training at scale still minimizes next-token cross-entropy | $L(\theta)=-\sum_i \log p_\theta(t_i\mid t_{<i})$ |
| Four limiting resources | Memory, compute, bandwidth, and data quality each become a bottleneck | $\mathrm{time}\approx\max(T_\mathrm{compute},T_\mathrm{comm},T_\mathrm{input})$ |
| Parameter, optimizer, and activation memory | Weights are only one part of training memory | $M=M_\mathrm{params}+M_\mathrm{grads}+M_\mathrm{opt}+M_\mathrm{act}$ |
| Throughput versus convergence | Fast tokens per second are useful only if loss improves | $\mathrm{tokens/sec}$ must be read with $L(\mathrm{tokens})$ |
| Failure modes | Large training fails by divergence, stalls, bad data, communication bottlenecks, or checkpoint loss | $\Delta L>0$ for many steps is a symptom, not a diagnosis |

1.1 The same loss at larger cost

Main idea. Training at scale still minimizes next-token cross-entropy.

Core relation:

$L(\theta)=-\sum_i \log p_\theta(t_i\mid t_{<i})$
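A minimal sketch of this relation in plain Python, with a made-up four-token vocabulary and logits (all numbers are illustrative):

```python
import math

# Toy next-token setup: vocabulary of 4 tokens, a 3-position target sequence.
# logits[i] are the model's scores at position i; targets[i] is the true token id.
logits = [[2.0, 0.5, -1.0, 0.1],
          [0.2, 1.5, 0.3, -0.5],
          [-0.3, 0.0, 2.2, 0.4]]
targets = [0, 1, 2]

def log_softmax(row, k):
    """log p(token k) at one position, computed stably."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return row[k] - log_z

# L(theta) = -sum_i log p_theta(t_i | t_<i)
loss = -sum(log_softmax(row, t) for row, t in zip(logits, targets))
print(f"summed cross-entropy: {loss:.4f}   per-token: {loss / len(targets):.4f}")
```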

At small scale, this relation may feel like bookkeeping. At LLM scale, it becomes a hard constraint. A missing factor of two in a memory estimate can decide whether the job starts. A wrong batch-size convention can change the optimization regime. A poor communication plan can leave expensive accelerators idle.

Worked micro-example. Suppose a dense model has P = 7 billion parameters. bf16 weights alone require about 2P bytes, or roughly 14 GB. Training with Adam usually also needs gradients and two optimizer moment tensors. If the moments are fp32, the optimizer state adds about 8P bytes, before activations. That is why "weights fit" is not the same as "training fits."
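The arithmetic in the micro-example is easy to sanity-check; a two-line sketch, ignoring activations and framework overhead:

```python
P = 7e9   # parameters in the 7B example
GB = 1e9
print(f"bf16 weights:    {2 * P / GB:.0f} GB")   # ~14 GB
print(f"fp32 Adam state: {8 * P / GB:.0f} GB")   # ~56 GB, before activations
```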

Implementation check. Write down the unit. Is the number per parameter, per token, per device, per data-parallel rank, per step, or per full run? Most scale-training bugs are not exotic math errors; they are unit and axis errors.

AI connection. This formula is part of the control surface for a large training run.

Common mistake. Do not optimize one metric in isolation. More tokens per second can be bad if validation loss stops improving, and lower memory can be bad if recomputation makes the step too slow.

1.2 Four limiting resources

Main idea. Memory, compute, bandwidth, and data quality each become a bottleneck.

Core relation:

$\mathrm{time}\approx\max(T_\mathrm{compute},T_\mathrm{comm},T_\mathrm{input})$
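A sketch of reading step time as the slowest of three paths; all numbers are made-up placeholders, not measurements:

```python
# Hypothetical per-step estimates, in seconds.
t_compute = 0.80   # forward + backward math on the accelerators
t_comm    = 0.95   # gradient all-reduce, assuming it overlaps poorly here
t_input   = 0.30   # tokenization and host-to-device copy for the next batch

t_step = max(t_compute, t_comm, t_input)
bottleneck = max(("compute", t_compute), ("communication", t_comm),
                 ("input", t_input), key=lambda kv: kv[1])[0]
print(f"step time ~ {t_step:.2f}s, limited by {bottleneck}")
```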


1.3 Parameter, optimizer, and activation memory

Main idea. Weights are only one part of training memory.

Core relation:

$M=M_\mathrm{params}+M_\mathrm{grads}+M_\mathrm{opt}+M_\mathrm{act}$


1.4 Throughput versus convergence

Main idea. Fast tokens per second are useful only if loss improves.

Core relation:

$\mathrm{tokens/sec}$ must be read with $L(\mathrm{tokens})$


1.5 Failure modes

Main idea. Large training fails by divergence, stalls, bad data, communication bottlenecks, or checkpoint loss.

Core relation:

$\Delta L>0$ for many steps is a symptom, not a diagnosis


2. Optimization Core

This part focuses on the optimization core as a practical mathematical constraint in LLM training. The goal is not to memorize infrastructure names, but to understand the formulas that determine whether a run fits, learns, communicates, and resumes.

| Subtopic | Operational question | Formula |
| --- | --- | --- |
| Mini-batch gradient | Distributed workers estimate the same gradient with different data shards | $g_t=\frac{1}{B}\sum_{i=1}^{B}\nabla_\theta \ell_i$ |
| Adam moments | First and second moment estimates adapt update scale | $m_t=\beta_1 m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2$ |
| Bias correction | Early moments are corrected because they start at zero | $\hat m_t=m_t/(1-\beta_1^t),\quad \hat v_t=v_t/(1-\beta_2^t)$ |
| AdamW | Weight decay is applied outside the adaptive gradient ratio | $\theta_{t+1}=\theta_t-\eta\hat m_t/(\sqrt{\hat v_t}+\epsilon)-\eta\lambda\theta_t$ |
| Gradient clipping | Cap the update norm when rare batches produce spikes | $g\leftarrow g\min(1,c/\Vert g\Vert_2)$ |

2.1 Mini-batch gradient

Main idea. Distributed workers estimate the same gradient with different data shards.

Core relation:

$g_t=\frac{1}{B}\sum_{i=1}^{B}\nabla_\theta \ell_i$
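A small numpy sketch showing that equal-sized shards recombine into the same full-batch mean (random toy data; a real data-parallel job would do the second averaging with an all-reduce):

```python
import numpy as np

rng = np.random.default_rng(0)
B, D = 8, 4                          # global batch of 8 examples, 4 "parameters"
per_example_grads = rng.normal(size=(B, D))

# Full-batch estimate: g_t = (1/B) * sum_i grad_i
g_full = per_example_grads.mean(axis=0)

# Two workers each average their own shard, then the shard means are averaged.
shards = np.split(per_example_grads, 2)
g_workers = np.mean([shard.mean(axis=0) for shard in shards], axis=0)

print(np.allclose(g_full, g_workers))   # True: same estimator, different layout
```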


2.2 Adam moments

Main idea. First and second moment estimates adapt update scale.

Core relation:

$m_t=\beta_1 m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2$


2.3 Bias correction

Main idea. Early moments are corrected because they start at zero.

Core relation:

$\hat m_t=m_t/(1-\beta_1^t),\quad \hat v_t=v_t/(1-\beta_2^t)$


2.4 AdamW

Main idea. Weight decay is applied outside the adaptive gradient ratio.

Core relation:

$\theta_{t+1}=\theta_t-\eta\hat m_t/(\sqrt{\hat v_t}+\epsilon)-\eta\lambda\theta_t$
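A compact numpy sketch of one AdamW step, combining the moment updates from 2.2, the bias correction from 2.3, and the decoupled decay above; the hyperparameters are typical but illustrative:

```python
import numpy as np

def adamw_step(theta, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, weight_decay=0.1):
    """One AdamW update; returns new (theta, m, v). t is the 1-based step index."""
    m = beta1 * m + (1 - beta1) * g            # first moment
    v = beta2 * v + (1 - beta2) * g * g        # second moment
    m_hat = m / (1 - beta1 ** t)               # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * weight_decay * theta
    return theta, m, v

theta, m, v = np.zeros(4), np.zeros(4), np.zeros(4)
g = np.array([0.1, -0.2, 0.05, 0.0])
theta, m, v = adamw_step(theta, g, m, v, t=1)
print(theta)
```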


2.5 Gradient clipping

Main idea. Cap update norm when rare batches produce spikes.

Core relation:

$g\leftarrow g\min(1,c/\Vert g\Vert_2)$
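A numpy sketch of clipping by global norm; the threshold c = 1.0 is a common but illustrative choice:

```python
import numpy as np

def clip_by_global_norm(g, c=1.0):
    """g <- g * min(1, c / ||g||_2), leaving small gradients untouched."""
    norm = np.linalg.norm(g)
    scale = min(1.0, c / norm) if norm > 0 else 1.0
    return g * scale

spike = np.array([3.0, -4.0])                      # ||g|| = 5, above the threshold
print(clip_by_global_norm(spike))                  # rescaled to norm 1.0
print(clip_by_global_norm(np.array([0.1, 0.2])))   # small gradient, unchanged
```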


AI connection. Clipping is not a cure for a bad run, but it can prevent one rare batch from destroying useful optimizer state.


3. Batching and Schedules

This part focuses on batching and schedules as a practical mathematical constraint in LLM training. The goal is not to memorize infrastructure names, but to understand the formulas that determine whether a run fits, learns, communicates, and resumes.

| Subtopic | Operational question | Formula |
| --- | --- | --- |
| Effective batch size | The global batch combines devices and accumulation steps | $B_\mathrm{eff}=B_\mathrm{device}\,G_\mathrm{accum}\,N_\mathrm{dp}$ |
| Gradient accumulation | Several micro-batches approximate one larger batch | $g=\frac{1}{K}\sum_{k=1}^{K}g_k$ |
| Linear warmup | The learning rate starts small to avoid early instability | $\eta_t=\eta_{\max}\,t/T_\mathrm{warmup}$ |
| Cosine decay | The learning rate anneals smoothly after warmup | $\eta_t=\eta_{\min}+\frac{1}{2}(\eta_{\max}-\eta_{\min})(1+\cos(\pi s))$ |
| Critical batch intuition | Past a point, larger batches waste compute rather than reducing noise usefully | $\mathrm{noise}\propto 1/B$ only in the useful regime |

3.1 Effective batch size

Main idea. Global batch combines devices and accumulation steps.

Core relation:

$B_\mathrm{eff}=B_\mathrm{device}\,G_\mathrm{accum}\,N_\mathrm{dp}$
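A quick arithmetic sketch with made-up shapes, including the tokens-per-step figure that usually matters more than the raw sequence count:

```python
B_device = 4      # sequences per device per micro-batch
G_accum  = 8      # gradient accumulation steps
N_dp     = 64     # data-parallel ranks
seq_len  = 4096   # tokens per sequence

B_eff = B_device * G_accum * N_dp
print(f"effective batch: {B_eff} sequences "
      f"= {B_eff * seq_len / 1e6:.1f}M tokens per optimizer step")
```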


3.2 Gradient accumulation

Main idea. Several micro-batches approximate one larger batch.

Core relation:

$g=\frac{1}{K}\sum_{k=1}^{K}g_k$
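A numpy sketch showing that averaging K equal-sized micro-batch gradients reproduces the large-batch gradient, here for a toy linear model with a mean-squared-error loss:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(32, 3))            # 32 examples, 3 features
y = rng.normal(size=32)
w = np.zeros(3)

def grad(Xb, yb, w):
    """Mean-squared-error gradient for a linear model on one (micro-)batch."""
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

# One big batch vs. K = 4 accumulated micro-batches of 8 examples each.
g_big = grad(X, y, w)
K = 4
g_accum = np.zeros_like(w)
for Xb, yb in zip(np.split(X, K), np.split(y, K)):
    g_accum += grad(Xb, yb, w) / K      # g = (1/K) * sum_k g_k
print(np.allclose(g_big, g_accum))      # True
```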


3.3 Linear warmup

Main idea. The learning rate starts small to avoid early instability.

Core relation:

$\eta_t=\eta_{\max}\,t/T_\mathrm{warmup}$


3.4 Cosine decay

Main idea. The learning rate anneals smoothly after warmup.

Core relation:

$\eta_t=\eta_{\min}+\frac{1}{2}(\eta_{\max}-\eta_{\min})(1+\cos(\pi s))$, where $s\in[0,1]$ is the fraction of the post-warmup schedule completed
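A sketch that stitches the linear warmup of 3.3 onto this cosine decay; the step counts and learning rates are illustrative:

```python
import math

def lr_at(step, total_steps, warmup_steps, lr_max=3e-4, lr_min=3e-5):
    """Linear warmup to lr_max, then cosine decay to lr_min."""
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    s = (step - warmup_steps) / max(1, total_steps - warmup_steps)  # s in [0, 1]
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * s))

for step in (0, 500, 1000, 50_000, 100_000):
    print(step, f"{lr_at(step, total_steps=100_000, warmup_steps=1000):.2e}")
```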


3.5 Critical batch intuition

Main idea. Past a point, larger batches waste compute rather than reducing noise usefully.

Core relation:

$\mathrm{noise}\propto 1/B$ only in the useful regime


4. Memory Accounting

This part focuses on memory accounting as a practical mathematical constraint in LLM training. The goal is not to memorize infrastructure names, but to understand the formulas that determine whether a run fits, learns, communicates, and resumes.

| Subtopic | Operational question | Formula |
| --- | --- | --- |
| Bytes per parameter | bf16 weights use 2 bytes, but Adam states are often fp32 | $M_\mathrm{Adam}\approx 2P+2P+8P$ bytes |
| Activation memory | Stored forward activations can dominate at long context | $M_\mathrm{act}\propto BTLd$ |
| Activation checkpointing | Save memory by recomputing intermediate activations | $M_\mathrm{act}\downarrow,\quad T_\mathrm{compute}\uparrow$ |
| Optimizer state sharding | ZeRO/FSDP shard model states across data-parallel ranks | $M_\mathrm{per\ rank}\approx M/N$ for fully sharded states |
| Offload boundary | CPU or NVMe offload trades memory for bandwidth and latency | $T_\mathrm{step}$ can become transfer-bound |

4.1 Bytes per parameter

Main idea. bf16 weights use 2 bytes per parameter, but Adam states are often fp32.

Core relation:

$M_\mathrm{Adam}\approx 2P+2P+8P$ bytes (bf16 weights, bf16 gradients, and two fp32 moment tensors)
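A small estimator of persistent bytes per parameter under one common mixed-precision setup; the dtype choices (and the optional fp32 master copy) are assumptions that vary by framework and config:

```python
def training_bytes_per_param(weight_bytes=2, grad_bytes=2, moment_bytes=4,
                             master_weight_bytes=0):
    """Bytes of persistent state per parameter: weights + gradients + two Adam
    moments, plus an optional fp32 master copy used by some mixed-precision setups."""
    return weight_bytes + grad_bytes + 2 * moment_bytes + master_weight_bytes

P = 7e9
for label, extra in (("bf16 weights, fp32 moments", 0),
                     ("... plus fp32 master weights", 4)):
    total_gb = training_bytes_per_param(master_weight_bytes=extra) * P / 1e9
    print(f"{label}: {total_gb:.0f} GB before activations")
```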


4.2 Activation memory

Main idea. Stored forward activations can dominate at long context.

Core relation:

$M_\mathrm{act}\propto B\,T\,L\,d$ (batch size, sequence length, layer count, hidden size)
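A rough proportionality sketch; the bytes per element and the multiplier for attention/MLP intermediates are assumed constants that depend heavily on architecture and kernels:

```python
def activation_gb(B, T, L, d, bytes_per_elem=2, widen_factor=10):
    """Very rough M_act ~ B*T*L*d, scaled by an assumed factor for the several
    intermediate tensors each layer keeps for the backward pass."""
    return B * T * L * d * bytes_per_elem * widen_factor / 1e9

# Illustrative 7B-ish shapes: 32 layers, hidden size 4096.
print(f"{activation_gb(B=1, T=4096, L=32, d=4096):.0f} GB per sequence")
print(f"{activation_gb(B=4, T=8192, L=32, d=4096):.0f} GB at B=4, T=8192")
```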


4.3 Activation checkpointing

Main idea. Save memory by recomputing intermediate activations.

Core relation:

$M_\mathrm{act}\downarrow,\quad T_\mathrm{compute}\uparrow$
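A back-of-the-envelope sketch of the trade, assuming full checkpointing adds roughly one extra forward pass and that backward costs about twice a forward; both are rules of thumb, not measurements:

```python
# Relative FLOP cost of one step, in "forward pass" units.
forward, backward = 1.0, 2.0
baseline  = forward + backward              # keep all activations
with_ckpt = forward + backward + forward    # recompute activations during backward

compute_overhead = with_ckpt / baseline - 1
print(f"compute overhead ~ {compute_overhead:.0%}")   # roughly a third more FLOPs

# Memory side: storing activations only at layer boundaries instead of every
# intermediate tensor can cut M_act by a large factor, at this compute cost.
```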


4.4 Optimizer state sharding

Main idea. ZeRO/FSDP shard model states across data-parallel ranks.

Core relation:

$M_\mathrm{per\ rank}\approx M/N$ for fully sharded states
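A small sketch of the per-rank bookkeeping when weights, gradients, and optimizer state are all sharded across N ranks (fully sharded, ZeRO-3/FSDP-style, is the assumption here):

```python
P = 7e9
bytes_per_param = 2 + 2 + 8        # bf16 weights + bf16 grads + fp32 moments
M = P * bytes_per_param            # total model state, roughly 84 GB

for N in (1, 8, 64):
    print(f"N={N:>3} ranks: ~{M / N / 1e9:.1f} GB of model state per rank")
```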


AI connection. This is why a model that cannot fit on one accelerator can still be trained across many accelerators.


4.5 Offload boundary

Main idea. CPU or NVMe offload trades memory for bandwidth and latency.

Core relation:

$T_\mathrm{step}$ can become transfer-bound
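A sketch of when a step becomes transfer-bound; the bandwidth, state size, and compute time are placeholder values, not measurements:

```python
# Hypothetical numbers for one optimizer step with offloaded optimizer state.
offload_bytes = 56e9    # assumed bytes that must cross the host link each step
link_bw       = 25e9    # assumed usable host<->device bandwidth, bytes/sec
t_compute     = 1.0     # seconds of forward+backward on the accelerator

t_transfer = offload_bytes / link_bw      # time to move the offloaded state
t_step = max(t_compute, t_transfer)       # T_step becomes transfer-bound here
print(f"transfer {t_transfer:.1f}s vs compute {t_compute:.1f}s "
      f"-> step ~ {t_step:.1f}s")
```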


5. Parallelism Strategies

This part focuses on parallelism strategies as a practical mathematical constraint in LLM training. The goal is not to memorize infrastructure names, but to understand the formulas that determine whether a run fits, learns, communicates, and resumes.

| Subtopic | Operational question | Formula |
| --- | --- | --- |
| Data parallelism | Replicate the model and all-reduce gradients | $g=\frac{1}{N}\sum_{r=1}^{N}g_r$ |
| Tensor parallelism | Split matrix multiplications across devices | $Y=X[W_1\ W_2]$ (column-parallel) or $Y=X_1W_1+X_2W_2$ (row-parallel), depending on layout |
| Pipeline parallelism | Place layers on stages and stream micro-batches | $\mathrm{bubble}\approx(P-1)/(M+P-1)$ |
| Sequence parallelism | Split sequence-length work when activations are too large | $T$ is partitioned across ranks |
| Parallelism product | Large jobs combine data, tensor, pipeline, and sometimes sequence parallelism | $N_\mathrm{total}=N_\mathrm{dp}N_\mathrm{tp}N_\mathrm{pp}N_\mathrm{sp}$ |

5.1 Data parallelism

Main idea. Replicate the model and all-reduce gradients.

Core relation:

$g=\frac{1}{N}\sum_{r=1}^{N}g_r$


5.2 Tensor parallelism

Main idea. Split matrix multiplications across devices.

Core relation:

$Y=X[W_1\ W_2]$ (column-parallel) or $Y=X_1W_1+X_2W_2$ (row-parallel), depending on layout
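A numpy sketch of both layouts on toy shapes: splitting W by columns concatenates partial outputs, while splitting W by rows (and X by columns) sums partial products:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(3, 4))        # activations: batch 3, hidden 4
W = rng.normal(size=(4, 6))        # weight: hidden 4 -> output 6
Y = X @ W

# Column-parallel: each rank holds half of W's columns; outputs are concatenated.
W1, W2 = W[:, :3], W[:, 3:]
Y_col = np.concatenate([X @ W1, X @ W2], axis=1)

# Row-parallel: each rank holds half of W's rows and the matching slice of X;
# partial outputs are summed (an all-reduce in a real implementation).
Wa, Wb = W[:2, :], W[2:, :]
Xa, Xb = X[:, :2], X[:, 2:]
Y_row = Xa @ Wa + Xb @ Wb

print(np.allclose(Y, Y_col), np.allclose(Y, Y_row))   # True True
```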


5.3 Pipeline parallelism

Main idea. Place layers on stages and stream micro-batches.

Core relation:

$\mathrm{bubble}\approx(P-1)/(M+P-1)$, with $P$ pipeline stages and $M$ micro-batches
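The bubble fraction is easy to tabulate; a quick sketch for a simple synchronous schedule with illustrative stage and micro-batch counts:

```python
def bubble_fraction(P, M):
    """Idle fraction of a simple synchronous pipeline schedule:
    (P - 1) / (M + P - 1) with P stages and M micro-batches."""
    return (P - 1) / (M + P - 1)

for P, M in ((4, 4), (4, 16), (8, 64)):
    print(f"{P} stages, {M:>2} micro-batches -> bubble ~ {bubble_fraction(P, M):.0%}")
```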


5.4 Sequence parallelism

Main idea. Split sequence-length work when activations are too large.

Core relation:

$T$ is partitioned across ranks


5.5 Parallelism product

Main idea. Large jobs combine data, tensor, pipeline, and sometimes sequence parallelism.

Core relation:

$N_\mathrm{total}=N_\mathrm{dp}\,N_\mathrm{tp}\,N_\mathrm{pp}\,N_\mathrm{sp}$
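A quick consistency check that a proposed layout actually matches the devices available; the counts are illustrative:

```python
N_dp, N_tp, N_pp, N_sp = 16, 8, 4, 1     # hypothetical layout
N_total = N_dp * N_tp * N_pp * N_sp
print(f"devices required: {N_total}")     # 512

available = 512
assert N_total == available, "parallelism product must match the device count"
```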

