Scaling Laws: Theory Notebook

Converted from theory.ipynb for web reading.

This notebook builds scaling-law intuition with small numerical experiments: power-law fitting, IsoFLOP search, compute-optimal allocation, data quality, serving cost, and forecast residuals.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

1. Generate a toy power law

Code cell 4

rng = np.random.default_rng(1)
X = np.logspace(2, 6, 16)
L_inf = 1.5
A = 2.0
alpha = 0.12
loss_clean = L_inf + A * X ** (-alpha)
loss = loss_clean + rng.normal(0, 0.01, size=X.shape)

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(X, loss, "o-", label="observed")
ax.plot(X, loss_clean, "--", label="true")
ax.set_xscale("log")
ax.set_title("Toy power-law loss")
ax.set_xlabel("resource X")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()
print("first and last loss:", loss[0], loss[-1])

2. Fit the exponent when the floor is known

Code cell 6

excess = loss - L_inf
coef = np.polyfit(np.log(X), np.log(excess), deg=1)
alpha_hat = -coef[0]
A_hat = np.exp(coef[1])
pred = L_inf + A_hat * X ** (-alpha_hat)
print("alpha_hat:", alpha_hat)
print("A_hat:", A_hat)
print("mean abs residual:", np.mean(np.abs(loss - pred)))

3. Dense training FLOPs

Code cell 8

N = 7e9
D = 300e9
C = 6 * N * D
print("parameters:", f"{N:.2e}")
print("tokens:", f"{D:.2e}")
print("approx FLOPs:", f"{C:.3e}")

4. IsoFLOP curve

Code cell 10

C_budget = 1e23
N_grid = np.logspace(8, 11, 100)
D_grid = C_budget / (6 * N_grid)
loss_proxy = 1.6 + 1.2 * (N_grid / 1e9) ** -0.08 + 0.9 * (D_grid / 1e10) ** -0.10
best = np.argmin(loss_proxy)

fig, ax1 = plt.subplots(figsize=(8, 4))
ax1.plot(N_grid, D_grid, label="tokens from C=6ND")
ax1.set_xscale("log")
ax1.set_yscale("log")
ax1.set_xlabel("parameters N")
ax1.set_ylabel("tokens D")
ax1.set_title("IsoFLOP tradeoff")
ax1.scatter([N_grid[best]], [D_grid[best]], color="red", label="toy optimum")
ax1.legend()
fig.tight_layout()
plt.show()
print("best N:", f"{N_grid[best]:.3e}")
print("best D:", f"{D_grid[best]:.3e}")
print("tokens per parameter:", D_grid[best] / N_grid[best])

5. Compute-optimal allocation across budgets

Code cell 12

def toy_loss(N, D):
    return 1.55 + 0.7 * (N / 1e9) ** -0.09 + 0.8 * (D / 1e10) ** -0.11

budgets = np.logspace(21, 24, 8)
best_rows = []
for Cb in budgets:
    N_try = np.logspace(7, 12, 400)
    D_try = Cb / (6 * N_try)
    losses = toy_loss(N_try, D_try)
    idx = np.argmin(losses)
    best_rows.append((Cb, N_try[idx], D_try[idx], losses[idx]))
    print(f"C={Cb:.1e} N*={N_try[idx]:.2e} D*={D_try[idx]:.2e} loss={losses[idx]:.3f}")

6. Undertraining diagnostic

Code cell 14

models = {
    "small_long": (7e9, 300e9),
    "big_short": (70e9, 300e9),
    "big_long": (70e9, 1400e9),
}
target_ratio = 20
for name, (N, D) in models.items():
    ratio = D / N
    print(f"{name:>10s}: tokens/param={ratio:.1f}, under target={ratio < target_ratio}")

7. Sensitivity to exponent estimates

Code cell 16

C = np.logspace(20, 25, 80)
alphas = [0.03, 0.05, 0.08]
fig, ax = plt.subplots(figsize=(7, 4))
for a in alphas:
    L = 1.5 + 1.0 * C ** (-a)
    ax.plot(C, L, label=f"alpha={a}")
ax.set_xscale("log")
ax.set_title("Small exponent changes affect forecasts")
ax.set_xlabel("compute")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()
print("loss spread at largest C:", max(1.5 + C[-1]**(-a) for a in alphas) - min(1.5 + C[-1]**(-a) for a in alphas))

8. Effective data tokens

Code cell 18

sources = {
    "curated": (100e9, 1.4),
    "web": (500e9, 0.7),
    "code": (80e9, 1.1),
    "duplicates": (120e9, 0.2),
}
raw = sum(tokens for tokens, quality in sources.values())
effective = sum(tokens * quality for tokens, quality in sources.values())
for name, (tokens, quality) in sources.items():
    print(f"{name:>10s}: raw={tokens/1e9:6.1f}B quality={quality:.1f} eff={tokens*quality/1e9:6.1f}B")
print("raw total B:", raw / 1e9)
print("effective total B:", effective / 1e9)

9. Diminishing returns from repeated data

Code cell 20

passes = np.arange(1, 9)
unique_tokens = 100e9
repeat_efficiency = 1 / np.sqrt(passes)
effective_tokens = unique_tokens * np.cumsum(repeat_efficiency)
for p, eff in zip(passes, effective_tokens):
    print(f"passes={p}: effective tokens={eff/1e9:.1f}B")

10. Inference-aware choice

Code cell 22

choices = {
    "small_overtrained": {"loss": 1.82, "train_cost": 1.0, "serve_cost_per_m": 0.02},
    "large_short": {"loss": 1.76, "train_cost": 3.0, "serve_cost_per_m": 0.12},
}
queries_m = 200
quality_weight = 20
for name, v in choices.items():
    total = quality_weight * v["loss"] + v["train_cost"] + queries_m * v["serve_cost_per_m"]
    print(f"{name:>17s}: objective={total:.2f}")

11. Forecast residuals

Code cell 24

actual = np.array([2.10, 1.98, 1.90, 1.84, 1.80])
pred = np.array([2.11, 1.99, 1.89, 1.82, 1.76])
resid = actual - pred
for i, r in enumerate(resid):
    print(f"run {i}: residual={r:+.3f}")
print("mean residual:", resid.mean())
print("max abs residual:", np.max(np.abs(resid)))

12. Threshold metrics can look sudden

Code cell 26

loss_values = np.linspace(2.0, 1.0, 50)
score = 1 / (1 + np.exp(8 * (loss_values - 1.45)))
thresholded = (score > 0.5).astype(float)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_values, score, label="smooth score")
ax.step(loss_values, thresholded, where="mid", label="thresholded metric")
ax.invert_xaxis()
ax.set_title("Smooth loss can create abrupt threshold metrics")
ax.set_xlabel("loss")
ax.set_ylabel("metric")
ax.legend()
fig.tight_layout()
plt.show()
print("threshold crossing near loss:", loss_values[np.argmin(np.abs(score - 0.5))])

13. Forecast checklist

Code cell 28

checks = [
    "same tokenizer, architecture, optimizer, and data mixture",
    "held-out loss measured on a fixed evaluation set",
    "fit uses multiple scales and reserves validation runs",
    "loss floor assumption is stated",
    "residuals are plotted",
    "forecast includes uncertainty and a stop condition",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")
