Exercises Notebook

Converted from exercises.ipynb for web reading.

Training at Scale: Exercises

These ten exercises practice the back-of-the-envelope accounting that matters before a large LLM training run: optimizer steps, gradient clipping, learning-rate schedules, memory, parallelism, FLOPs, MFU, and launch checks.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Exercise 1: AdamW scalar update

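Compute the first AdamW update (step $t = 1$, moments initialized to zero) for a scalar parameter $\theta = 2.0$ with gradient $0.5$, using $\beta_1 = 0.9$, $\beta_2 = 0.999$, learning rate $10^{-3}$, $\epsilon = 10^{-8}$, and weight decay $0.01$.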

Code cell 4

# Your Solution
theta = 2.0
grad = 0.5
print("Starter: update m, v, bias-correct them, then apply AdamW.")

Code cell 5

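# Solution
theta = 2.0
grad = 0.5
beta1, beta2 = 0.9, 0.999
lr, eps, wd = 1e-3, 1e-8, 0.01
t = 1              # first optimizer step
m0, v0 = 0.0, 0.0  # Adam moments start at zero
m = beta1 * m0 + (1 - beta1) * grad
v = beta2 * v0 + (1 - beta2) * grad**2
m_hat = m / (1 - beta1**t)  # bias correction undoes the zero initialization
v_hat = v / (1 - beta2**t)
# Decoupled weight decay: shrink theta directly instead of adding wd * theta to the gradient
theta_next = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * wd * theta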
print("theta_next:", theta_next)
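
The solution hard-codes the first step. Below is a minimal sketch of a reusable step function. It assumes the same hyperparameters and zero-initialized moments, and `adamw_step` is a hypothetical helper name, not a library API. The sketch makes the general bias correction explicit and shows a useful property: at $t = 1$ the Adam term is $g/|g|$, so the parameter moves by about `lr` regardless of the gradient's scale.

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW step with decoupled weight decay; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment EMA
    v = beta2 * v + (1 - beta2) * grad**2       # second-moment EMA
    m_hat = m / (1 - beta1**t)                  # bias correction for zero init
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v

# First step from theta = 2.0 with gradients spanning four orders of magnitude.
first_steps = {}
for g in (0.005, 0.5, 50.0):
    theta_1, _, _ = adamw_step(2.0, g, m=0.0, v=0.0, t=1)
    first_steps[g] = theta_1
    print(f"grad={g:>6}: theta_1={theta_1:.6f}")
```

All three first steps land at about 1.99898: roughly `lr` from the Adam term plus `lr * wd * theta` from weight decay.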

Exercise 2: Clip a gradient

Clip the vector $[6, 8]$ to a maximum L2 norm of 5.

Code cell 7

# Your Solution
g = np.array([6.0, 8.0])
print("Starter: multiply by min(1, 5 / norm(g)).")

Code cell 8

# Solution
g = np.array([6.0, 8.0])
scale = min(1.0, 5.0 / np.linalg.norm(g))
clipped = g * scale
print("clipped:", clipped)
print("norm:", np.linalg.norm(clipped))
assert np.isclose(np.linalg.norm(clipped), 5.0)
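
In a real model, the clipping norm is usually computed over all parameter gradients together as one global L2 norm, and every tensor is scaled by the same factor. Below is a minimal NumPy sketch of that idea, similar in spirit to PyTorch's `torch.nn.utils.clip_grad_norm_`. The function name `clip_by_global_norm` and the small `eps` guard are illustrative assumptions.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradient arrays by one common factor so their joint L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g**2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# Two "parameter" gradients whose joint norm is sqrt(9 + 16 + 144) = 13.
grads = [np.array([3.0, 4.0]), np.array([[0.0, 12.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = float(np.sqrt(sum(np.sum(g**2) for g in clipped)))
print("norm before:", norm_before, "after:", norm_after)
```

Clipping each tensor separately would change the direction of the overall update. The shared scale factor preserves that direction.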

Exercise 3: Warmup schedule value

Find the learning rate at step 50 (0-indexed) of a linear warmup that reaches its peak of $3 \times 10^{-4}$ after 100 steps.

Code cell 10

# Your Solution
step = 50
warmup = 100
peak = 3e-4
print("Starter: during warmup, lr = peak * (step + 1) / warmup.")

Code cell 11

# Solution
step = 50
warmup = 100
peak = 3e-4
lr = peak * (step + 1) / warmup
print("lr:", lr)
assert lr < peak
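
Warmup is usually only the first phase of a schedule. Below is a minimal sketch that continues with cosine decay. It assumes a hypothetical 1,000-step run and a floor of `min_lr = 3e-5`; both values are illustrative, not part of the exercise. The warmup branch uses the same `(step + 1) / warmup` convention as the solution.

```python
import math

def lr_at(step, warmup=100, total=1000, peak=3e-4, min_lr=3e-5):
    """Linear warmup to `peak`, then cosine decay to `min_lr` at step `total`."""
    if step < warmup:
        return peak * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (peak - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 50, 99, 100, 550, 1000):
    print(f"step {s:>4}: lr = {lr_at(s):.3e}")
```

Step 99 ends warmup at the peak, and step 100 starts the decay from the same value, so the curve has no jump at the boundary.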

Exercise 4: Effective token batch

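Compute the global number of tokens per optimizer step for micro-batch size 4, sequence length 2048, 32 data-parallel ranks, and 8 gradient-accumulation steps.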

Code cell 13

# Your Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
print("Starter: multiply micro_batch, seq_len, dp, and accum.")

Code cell 14

# Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
tokens = micro_batch * seq_len * dp * accum
print("tokens per optimizer step:", tokens)
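
In practice the global token batch is often the fixed target, and the number of gradient-accumulation steps is what you solve for. The small sketch below assumes a hypothetical target of $2^{22}$ (about 4.2M) tokens per step, with the same micro-batch, sequence length, and data-parallel size as the exercise. `accum_for_target` is an illustrative helper name.

```python
def accum_for_target(target_tokens, micro_batch, seq_len, dp):
    """Return the gradient-accumulation steps needed to hit target_tokens per optimizer step."""
    tokens_per_micro_step = micro_batch * seq_len * dp  # one forward/backward on every DP rank
    if target_tokens % tokens_per_micro_step != 0:
        raise ValueError(f"{target_tokens} is not a multiple of {tokens_per_micro_step}")
    return target_tokens // tokens_per_micro_step

accum_needed = accum_for_target(2**22, micro_batch=4, seq_len=2048, dp=32)
print("accumulation steps:", accum_needed)
# Halving dp at a fixed token target doubles the accumulation steps.
print("with dp=16:", accum_for_target(2**22, micro_batch=4, seq_len=2048, dp=16))
```

The divisibility check is deliberate. Silently rounding the accumulation count would change the documented global batch.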

Exercise 5: Memory estimate

Estimate the replicated (unsharded) Adam training state for 1B parameters with bf16 weights and gradients and fp32 first and second moments. Activations are excluded.

Code cell 16

# Your Solution
P = 1e9
print("Starter: weights=2P, grads=2P, moments=8P bytes.")

Code cell 17

# Solution
P = 1e9
total_bytes = 2*P + 2*P + 8*P
print("GB:", total_bytes / 1e9)
assert np.isclose(total_bytes / 1e9, 12.0)
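
The 12-bytes-per-parameter figure leaves out two things that often dominate real planning:

- Many mixed-precision setups also keep an fp32 master copy of the weights. That adds 4 bytes per parameter, for 16P in total.
- ZeRO-style sharding divides this state across data-parallel ranks.

Below is a rough per-rank sketch. It assumes ZeRO's usual stage definitions: stage 1 shards optimizer state (including the master copy), stage 2 also shards gradients, and stage 3 also shards weights. `train_state_gb` is a hypothetical helper, and activations are still excluded.

```python
def train_state_gb(P, dp=1, zero_stage=0, weight_bytes=2, grad_bytes=2,
                   moment_bytes=8, master_bytes=0):
    """Per-rank bytes for weights, gradients, and optimizer state, in decimal GB."""
    weights = weight_bytes * P
    grads = grad_bytes * P
    optim = (moment_bytes + master_bytes) * P  # Adam moments (+ optional fp32 master weights)
    if zero_stage >= 1:
        optim /= dp
    if zero_stage >= 2:
        grads /= dp
    if zero_stage >= 3:
        weights /= dp
    return (weights + grads + optim) / 1e9

P = 1e9
for stage in (0, 1, 2, 3):
    gb = train_state_gb(P, dp=8, zero_stage=stage, master_bytes=4)
    print(f"ZeRO-{stage}, dp=8, with fp32 master copy: {gb:.2f} GB per rank")
```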

Exercise 6: Pipeline bubble

Compute the pipeline bubble fraction for $P = 4$ stages and $M = 12$ micro-batches.

Code cell 19

# Your Solution
P = 4
M = 12
print("Starter: bubble = (P - 1) / (M + P - 1).")

Code cell 20

# Solution
P = 4
M = 12
bubble = (P - 1) / (M + P - 1)
print("bubble:", bubble)
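
The formula says the bubble shrinks as micro-batches are added. Below is a short sweep over $M$, plus a search for the smallest $M$ that keeps the bubble at or below an assumed 5% target for 4 stages. The 5% target is illustrative only.

```python
def bubble_fraction(stages, micro_batches):
    """Idle fraction of a synchronous pipeline schedule: (P - 1) / (M + P - 1)."""
    return (stages - 1) / (micro_batches + stages - 1)

for m in (4, 12, 32, 64):
    print(f"M={m:>2}: bubble={bubble_fraction(4, m):.3f}")

target = 0.05
min_m = 1
while bubble_fraction(4, min_m) > target:
    min_m += 1
print("smallest M with bubble <= 5%:", min_m)
```

At a fixed global batch, more micro-batches also means smaller micro-batches, which can lower per-kernel efficiency. That trade-off is why $M$ is not simply pushed as high as possible.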

Exercise 7: Tensor-parallel shard

Split a 4096 × 16384 weight matrix (in_dim × out_dim) across 4 column-parallel ranks.

Code cell 22

# Your Solution
in_dim = 4096
out_dim = 16384
tp = 4
print("Starter: each rank owns out_dim / tp columns.")

Code cell 23

# Solution
in_dim = 4096
out_dim = 16384
tp = 4
shape = (in_dim, out_dim // tp)
print("per-rank shard shape:", shape)
assert shape == (4096, 4096)
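
The shard shape matters only because the math still works after splitting. The small NumPy check below uses reduced sizes (8 × 16 instead of 4096 × 16384, so it runs instantly) and the `y = x @ W` convention from the exercise. It covers two layouts:

- **Column-parallel.** Each shard produces a slice of the output, and the slices are concatenated.
- **Row-parallel.** This complementary split is what Megatron-LM-style MLPs pair with the column-parallel layer. It produces partial sums that must be added, which is an all-reduce on real hardware.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, in_dim, out_dim, tp = 3, 8, 16, 4
x = rng.standard_normal((batch, in_dim))
W = rng.standard_normal((in_dim, out_dim))
y_full = x @ W

# Column-parallel: each rank owns out_dim / tp columns of W and computes that slice of y.
col_shards = np.split(W, tp, axis=1)
y_col = np.concatenate([x @ Wi for Wi in col_shards], axis=1)  # what an all-gather assembles

# Row-parallel: each rank owns in_dim / tp rows of W and the matching columns of x.
row_shards = np.split(W, tp, axis=0)
x_shards = np.split(x, tp, axis=1)
y_row = sum(xi @ Wi for xi, Wi in zip(x_shards, row_shards))  # what an all-reduce (sum) computes

print("column shard:", col_shards[0].shape, "row shard:", row_shards[0].shape)
print("column-parallel matches:", np.allclose(y_col, y_full))
print("row-parallel matches:", np.allclose(y_row, y_full))
```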

Exercise 8: Training FLOPs

Use $C = 6ND$ to estimate training FLOPs for a 7B-parameter model trained on 300B tokens.

Code cell 25

# Your Solution
N = 7e9
D = 300e9
print("Starter: C = 6 * N * D.")

Code cell 26

# Solution
N = 7e9
D = 300e9
C = 6 * N * D
print("FLOPs:", f"{C:.3e}")
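
The factor 6 counts roughly 2 FLOPs per parameter per token in the forward pass and 4 in the backward pass. To turn $C$ into a schedule, divide by the sustained cluster throughput. The sketch below rests on loudly hypothetical assumptions: 256 GPUs, a per-GPU dense bf16 peak of 312 TFLOP/s (the A100 datasheet figure), and 40% MFU.

```python
N = 7e9            # parameters
D = 300e9          # training tokens
C = 6 * N * D      # ~2 FLOPs/param/token forward + ~4 backward

# Hypothetical cluster; replace with your own hardware numbers.
n_gpus = 256
peak_per_gpu = 312e12   # FLOP/s, assumed dense bf16 peak
mfu = 0.40              # assumed sustained model FLOPs utilization

seconds = C / (n_gpus * peak_per_gpu * mfu)
days = seconds / 86_400
print(f"C = {C:.3e} FLOPs, about {days:.1f} days on {n_gpus} GPUs at {mfu:.0%} MFU")
```

Wall-clock time scales as 1/MFU, so halving utilization doubles the run length. That is why MFU is worth measuring before launch.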

Exercise 9: MFU

Compute model FLOPs utilization (MFU) from the achieved useful FLOPs/s and the hardware peak FLOPs/s.

Code cell 28

# Your Solution
useful = 60e15
peak = 160e15
print("Starter: MFU = useful / peak.")

Code cell 29

# Solution
useful = 60e15
peak = 160e15
mfu = useful / peak
print("MFU:", mfu)
assert 0 <= mfu <= 1
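
Training logs usually report tokens per second rather than FLOPs per second. The sketch below converts throughput to MFU using the same 6N FLOPs-per-token approximation as $C = 6ND$. It assumes a hypothetical 7B model processing 1.0M tokens/s on hardware with the same 160 PFLOP/s aggregate peak as the solution; `mfu_from_throughput` is an illustrative name. The 6N count ignores attention-score FLOPs, which grow with sequence length, so it understates useful work at long context.

```python
def mfu_from_throughput(tokens_per_sec, n_params, peak_flops_per_sec):
    """Model FLOPs utilization using ~6 FLOPs per parameter per training token."""
    useful_flops_per_sec = 6 * n_params * tokens_per_sec
    return useful_flops_per_sec / peak_flops_per_sec

mfu_est = mfu_from_throughput(tokens_per_sec=1.0e6, n_params=7e9, peak_flops_per_sec=160e15)
print(f"MFU: {mfu_est:.1%}")

# Inverse view: tokens/s needed to reach 50% MFU on the same hardware.
needed = 0.5 * 160e15 / (6 * 7e9)
print(f"tokens/s for 50% MFU: {needed:,.0f}")
```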

Exercise 10: Launch checklist

Write four checks before scaling a training run.

Code cell 31

# Your Solution
print("Starter: include loss, resume, memory, and batch checks.")

Code cell 32

# Solution
checks = [
    "small run reduces validation loss",
    "resume restores optimizer, scheduler, RNG, and dataloader state",
    "memory estimate includes activations and optimizer states",
    "effective global token batch is documented",
]
for check in checks:
    print("-", check)
assert len(checks) == 4
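
The resume check is the easiest to get subtly wrong, and it can be tested on a toy loop before a real run. The minimal sketch below assumes a hypothetical noisy momentum-SGD loop on one scalar. An uninterrupted 20-step run must match a 10-step run that is checkpointed and then resumed for 10 more steps. The checkpoint stores the parameter, the momentum buffer, the step counter, and the NumPy RNG state via `bit_generator.state`. Re-seeding instead of restoring the RNG state is the classic bug, and this check catches it.

```python
import numpy as np

def train(state, n_steps, lr=1e-2, momentum=0.9):
    """Toy training loop: momentum SGD on f(theta) = theta**2 with noisy gradients."""
    for _ in range(n_steps):
        grad = 2.0 * state["theta"] + 0.1 * state["rng"].normal()
        state["m"] = momentum * state["m"] + grad
        state["theta"] -= lr * state["m"]
        state["step"] += 1
    return state

def fresh_state(seed=0):
    return {"theta": 2.0, "m": 0.0, "step": 0, "rng": np.random.default_rng(seed)}

def checkpoint(state):
    return {"theta": state["theta"], "m": state["m"], "step": state["step"],
            "rng_state": state["rng"].bit_generator.state}

def restore(ckpt, restore_rng=True):
    rng = np.random.default_rng(0)  # re-seeded generator, as a naive resume would create
    if restore_rng:
        rng.bit_generator.state = ckpt["rng_state"]
    return {"theta": ckpt["theta"], "m": ckpt["m"], "step": ckpt["step"], "rng": rng}

reference = train(fresh_state(), 20)
ckpt = checkpoint(train(fresh_state(), 10))
good = train(restore(ckpt), 10)
bad = train(restore(ckpt, restore_rng=False), 10)

print("resume with restored RNG matches:", good["theta"] == reference["theta"])
print("resume with re-seeded RNG matches:", bad["theta"] == reference["theta"])
```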

Closing Reflection

At scale, arithmetic mistakes become infrastructure failures. Keep units explicit, test small, and make the loss curve prove that the system is learning.
