Converted from theory.ipynb for web reading.
Scaling Laws: Theory Notebook
This notebook builds scaling-law intuition with small numerical experiments: power-law fitting, IsoFLOP search, compute-optimal allocation, data quality, serving cost, and forecast residuals.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Generate a toy power law
Code cell 4
rng = np.random.default_rng(1)
X = np.logspace(2, 6, 16)
L_inf = 1.5
A = 2.0
alpha = 0.12
loss_clean = L_inf + A * X ** (-alpha)
loss = loss_clean + rng.normal(0, 0.01, size=X.shape)
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(X, loss, "o-", label="observed")
ax.plot(X, loss_clean, "--", label="true")
ax.set_xscale("log")
ax.set_title("Toy power-law loss")
ax.set_xlabel("resource X")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()
print("first and last loss:", loss[0], loss[-1])
2. Fit the exponent when the floor is known
Code cell 6
excess = loss - L_inf
coef = np.polyfit(np.log(X), np.log(excess), deg=1)
alpha_hat = -coef[0]
A_hat = np.exp(coef[1])
pred = L_inf + A_hat * X ** (-alpha_hat)
print("alpha_hat:", alpha_hat)
print("A_hat:", A_hat)
print("mean abs residual:", np.mean(np.abs(loss - pred)))
3. Dense training FLOPs
Code cell 8
N = 7e9
D = 300e9
C = 6 * N * D
print("parameters:", f"{N:.2e}")
print("tokens:", f"{D:.2e}")
print("approx FLOPs:", f"{C:.3e}")
4. IsoFLOP curve
Code cell 10
C_budget = 1e23
N_grid = np.logspace(8, 11, 100)
D_grid = C_budget / (6 * N_grid)
loss_proxy = 1.6 + 1.2 * (N_grid / 1e9) ** -0.08 + 0.9 * (D_grid / 1e10) ** -0.10
best = np.argmin(loss_proxy)
fig, ax1 = plt.subplots(figsize=(8, 4))
ax1.plot(N_grid, D_grid, label="tokens from C=6ND")
ax1.set_xscale("log")
ax1.set_yscale("log")
ax1.set_xlabel("parameters N")
ax1.set_ylabel("tokens D")
ax1.set_title("IsoFLOP tradeoff")
ax1.scatter([N_grid[best]], [D_grid[best]], color="red", label="toy optimum")
ax1.legend()
fig.tight_layout()
plt.show()
print("best N:", f"{N_grid[best]:.3e}")
print("best D:", f"{D_grid[best]:.3e}")
print("tokens per parameter:", D_grid[best] / N_grid[best])
5. Two-term compute-optimal grid search
Code cell 12
def toy_loss(N, D):
    return 1.55 + 0.7 * (N / 1e9) ** -0.09 + 0.8 * (D / 1e10) ** -0.11

budgets = np.logspace(21, 24, 8)
best_rows = []
for Cb in budgets:
    N_try = np.logspace(7, 12, 400)
    D_try = Cb / (6 * N_try)
    losses = toy_loss(N_try, D_try)
    idx = np.argmin(losses)
    best_rows.append((Cb, N_try[idx], D_try[idx], losses[idx]))
    print(f"C={Cb:.1e} N*={N_try[idx]:.2e} D*={D_try[idx]:.2e} loss={losses[idx]:.3f}")
6. Undertraining diagnostic
Code cell 14
models = {
    "small_long": (7e9, 300e9),
    "big_short": (70e9, 300e9),
    "big_long": (70e9, 1400e9),
}
target_ratio = 20
for name, (N, D) in models.items():
    ratio = D / N
    print(f"{name:>10s}: tokens/param={ratio:.1f}, under target={ratio < target_ratio}")
7. Sensitivity to exponent estimates
Code cell 16
C = np.logspace(20, 25, 80)
alphas = [0.03, 0.05, 0.08]
fig, ax = plt.subplots(figsize=(7, 4))
for a in alphas:
    L = 1.5 + 1.0 * C ** (-a)
    ax.plot(C, L, label=f"alpha={a}")
ax.set_xscale("log")
ax.set_title("Small exponent changes affect forecasts")
ax.set_xlabel("compute")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()
print("loss spread at largest C:", max(1.5 + C[-1]**(-a) for a in alphas) - min(1.5 + C[-1]**(-a) for a in alphas))
8. Effective data tokens
Code cell 18
sources = {
    "curated": (100e9, 1.4),
    "web": (500e9, 0.7),
    "code": (80e9, 1.1),
    "duplicates": (120e9, 0.2),
}
raw = sum(tokens for tokens, quality in sources.values())
effective = sum(tokens * quality for tokens, quality in sources.values())
for name, (tokens, quality) in sources.items():
    print(f"{name:>10s}: raw={tokens/1e9:6.1f}B quality={quality:.1f} eff={tokens*quality/1e9:6.1f}B")
print("raw total B:", raw / 1e9)
print("effective total B:", effective / 1e9)
9. Repeated data diminishing returns
Code cell 20
passes = np.arange(1, 9)
unique_tokens = 100e9
repeat_efficiency = 1 / np.sqrt(passes)
effective_tokens = unique_tokens * np.cumsum(repeat_efficiency)
for p, eff in zip(passes, effective_tokens):
    print(f"passes={p}: effective tokens={eff/1e9:.1f}B")
10. Inference-aware choice
Code cell 22
choices = {
    "small_overtrained": {"loss": 1.82, "train_cost": 1.0, "serve_cost_per_m": 0.02},
    "large_short": {"loss": 1.76, "train_cost": 3.0, "serve_cost_per_m": 0.12},
}
queries_m = 200
quality_weight = 20
for name, v in choices.items():
    total = quality_weight * v["loss"] + v["train_cost"] + queries_m * v["serve_cost_per_m"]
    print(f"{name:>17s}: objective={total:.2f}")
11. Forecast residuals
Code cell 24
actual = np.array([2.10, 1.98, 1.90, 1.84, 1.80])
pred = np.array([2.11, 1.99, 1.89, 1.82, 1.76])
resid = actual - pred
for i, r in enumerate(resid):
    print(f"run {i}: residual={r:+.3f}")
print("mean residual:", resid.mean())
print("max abs residual:", np.max(np.abs(resid)))
12. Threshold metrics can look sudden
Code cell 26
loss_values = np.linspace(2.0, 1.0, 50)
score = 1 / (1 + np.exp(8 * (loss_values - 1.45)))
thresholded = (score > 0.5).astype(float)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_values, score, label="smooth score")
ax.step(loss_values, thresholded, where="mid", label="thresholded metric")
ax.invert_xaxis()
ax.set_title("Smooth loss can create abrupt threshold metrics")
ax.set_xlabel("loss")
ax.set_ylabel("metric")
ax.legend()
fig.tight_layout()
plt.show()
print("threshold crossing near loss:", loss_values[np.argmin(np.abs(score - 0.5))])
13. Forecast checklist
Code cell 28
checks = [
    "same tokenizer, architecture, optimizer, and data mixture",
    "held-out loss measured on a fixed evaluation set",
    "fit uses multiple scales and reserves validation runs",
    "loss floor assumption is stated",
    "residuals are plotted",
    "forecast includes uncertainty and a stop condition",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")