Normalization Techniques - Theory Notebook

Math for LLMs / ML Specific Math / Normalization Techniques

Converted from theory.ipynb for web reading. Executable companion to notes.md.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

COLORS={"primary":"#0077BB","secondary":"#EE7733","tertiary":"#009988","error":"#CC3311","neutral":"#555555","highlight":"#EE3377"}

Code cell 4

def check_close(name, value, expected, tol=1e-6):
    ok = np.allclose(value, expected, atol=tol, rtol=tol)
    print(f"{'PASS' if ok else 'FAIL'} - {name}: value={value}, expected={expected}")
    return ok

def check_true(name, condition):
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    return ok

def batch_norm(X, gamma=None, beta=None, eps=1e-5):
    # Normalize each feature using statistics taken across the batch (axis 0).
    mu = X.mean(axis=0, keepdims=True)
    var = X.var(axis=0, keepdims=True)
    Y = (X - mu) / np.sqrt(var + eps)
    if gamma is not None:
        Y = Y * gamma
    if beta is not None:
        Y = Y + beta
    return Y, mu, var

def layer_norm(X, gamma=None, beta=None, eps=1e-5):
    # Normalize each example using statistics taken across its features (last axis).
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    Y = (X - mu) / np.sqrt(var + eps)
    if gamma is not None:
        Y = Y * gamma
    if beta is not None:
        Y = Y + beta
    return Y, mu, var

def rms_norm(X, gamma=None, eps=1e-5):
    # Scale-only normalization: divide by the root mean square, no centering.
    rms = np.sqrt(np.mean(X**2, axis=-1, keepdims=True) + eps)
    Y = X / rms
    if gamma is not None:
        Y = Y * gamma
    return Y, rms

def group_norm(X, groups=2, eps=1e-5):
    # Split the C channels into equal groups and normalize within each group.
    B, C = X.shape
    assert C % groups == 0
    G = X.reshape(B, groups, C // groups)
    mu = G.mean(axis=2, keepdims=True)
    var = G.var(axis=2, keepdims=True)
    Y = ((G - mu) / np.sqrt(var + eps)).reshape(B, C)
    return Y

print("Normalization helpers ready.")

1. Axis intuition

Code cell 6

X=np.array([[1.,2.,3.],[10.,20.,30.],[2.,4.,6.]])
print("X shape", X.shape)
print("feature means over batch", X.mean(axis=0))
print("example means over features", X.mean(axis=1))

2. BatchNorm basic

Code cell 8

X=np.random.normal(loc=[0,5,-3], scale=[1,2,0.5], size=(256,3))
Y,mu,var=batch_norm(X)
print("batch means before", np.round(mu.ravel(),3))
print("means after", np.round(Y.mean(axis=0),6))
print("vars after", np.round(Y.var(axis=0),6))
check_close("BatchNorm means zero", Y.mean(axis=0), np.zeros(3), tol=1e-5)

3. BatchNorm visualization

Code cell 10

fig, axes=plt.subplots(1,2,figsize=(12,5))
axes[0].hist(X[:,1], bins=30, color=COLORS["primary"], alpha=0.8)
axes[0].set_title("Before BatchNorm")
axes[0].set_xlabel("Feature value"); axes[0].set_ylabel("Count")
axes[1].hist(Y[:,1], bins=30, color=COLORS["secondary"], alpha=0.8)
axes[1].set_title("After BatchNorm")
axes[1].set_xlabel("Normalized value"); axes[1].set_ylabel("Count")
fig.tight_layout(); plt.show(); print("BatchNorm histogram plotted.")

4. Running averages

Code cell 12

running_mu=np.zeros(3); alpha=0.1
for step in range(5):
    batch=np.random.normal(loc=step, scale=1, size=(32,3))
    batch_mu=batch.mean(axis=0)
    running_mu=(1-alpha)*running_mu+alpha*batch_mu
    print(f"step={step}: batch_mu={np.round(batch_mu,2)}, running_mu={np.round(running_mu,2)}")

5. Train/eval gap

Code cell 14

train_batch=np.random.normal(loc=5, scale=2, size=(8,3))
running_mu=np.zeros((1,3)); running_var=np.ones((1,3))
train_out,train_mu,train_var=batch_norm(train_batch)
eval_out=(train_batch-running_mu)/np.sqrt(running_var+1e-5)
print("train output mean", np.round(train_out.mean(axis=0),3))
print("eval output mean with stale stats", np.round(eval_out.mean(axis=0),3))
check_true("stale running stats create gap", np.linalg.norm(eval_out.mean(axis=0))>1)

6. Batch composition dependence

Code cell 16

x_single=np.array([[1.,2.,3.]])
batch_a=np.vstack([x_single, np.array([[2.,3.,4.],[3.,4.,5.]])])
batch_b=np.vstack([x_single, np.array([[100.,200.,300.],[110.,210.,310.]])])
out_a=batch_norm(batch_a)[0][0]
out_b=batch_norm(batch_b)[0][0]
print("same sample in batch A", np.round(out_a,3))
print("same sample in batch B", np.round(out_b,3))
check_true("BatchNorm output depends on other examples", np.linalg.norm(out_a-out_b)>1)

7. LayerNorm basic

Code cell 18

X=np.random.normal(loc=5, scale=3, size=(4,6))
Y,mu,var=layer_norm(X)
print("row means after", np.round(Y.mean(axis=1),6))
print("row vars after", np.round(Y.var(axis=1),6))
check_close("LayerNorm row means zero", Y.mean(axis=1), np.zeros(4), tol=1e-5)

8. LayerNorm batch independence

Code cell 20

sample=np.array([[1.,2.,4.,8.]])
batch1=np.vstack([sample, np.zeros((2,4))])
batch2=np.vstack([sample, 100*np.ones((2,4))])
y1=layer_norm(batch1)[0][0]
y2=layer_norm(batch2)[0][0]
print("sample output batch1", np.round(y1,3))
print("sample output batch2", np.round(y2,3))
check_close("LayerNorm independent of other batch examples", y1, y2)

9. BatchNorm vs LayerNorm heatmap

Code cell 22

X=np.random.normal(size=(6,8))*np.linspace(1,4,8)+np.arange(8)
BN=batch_norm(X)[0]; LN=layer_norm(X)[0]
fig, axes=plt.subplots(1,3,figsize=(15,4))
for ax,mat,title in zip(axes,[X,BN,LN],["Raw","BatchNorm","LayerNorm"]):
    im=ax.imshow(mat, cmap="viridis", aspect="auto"); ax.set_title(title); ax.set_xlabel("Feature"); ax.set_ylabel("Example")
fig.colorbar(im, ax=axes.ravel().tolist(), label="Value")
plt.show(); print("Axis behavior heatmap plotted.")

10. RMSNorm basic

Code cell 24

X=np.random.normal(loc=3, scale=2, size=(4,6))
Y,rms=rms_norm(X)
print("RMS after", np.sqrt(np.mean(Y**2, axis=1)))
print("Means after RMSNorm", np.round(Y.mean(axis=1),3))
check_close("RMSNorm unit RMS", np.sqrt(np.mean(Y**2, axis=1)), np.ones(4), tol=1e-5)

11. RMSNorm vs LayerNorm means

Code cell 26

LN=layer_norm(X)[0]; RN=rms_norm(X)[0]
print("LayerNorm means", np.round(LN.mean(axis=1),6))
print("RMSNorm means", np.round(RN.mean(axis=1),6))
check_true("RMSNorm does not force zero mean", np.abs(RN.mean(axis=1)).mean()>0.1)

12. Epsilon effect

Code cell 28

tiny=np.array([[1.,1.+1e-8,1.-1e-8]])
for eps in [1e-12,1e-8,1e-5,1e-2]:
    y=layer_norm(tiny, eps=eps)[0]
    print(f"eps={eps:g}, output={np.round(y,6)}, var={y.var():.6e}")

13. GroupNorm

Code cell 30

X=np.array([[1.,2.,10.,12.],[2.,4.,20.,24.]])
Y=group_norm(X, groups=2)
print("GroupNorm output", np.round(Y,3))
print("group means", np.round(Y.reshape(2,2,2).mean(axis=2),6))
check_close("group means zero", Y.reshape(2,2,2).mean(axis=2), np.zeros((2,2)), tol=1e-5)

14. InstanceNorm (image-style)

Code cell 32

image=np.random.normal(loc=np.array([0.,5.])[:,None,None], scale=np.array([1.,2.])[:,None,None], size=(2,4,4))
mu=image.mean(axis=(1,2), keepdims=True); var=image.var(axis=(1,2), keepdims=True)
out=(image-mu)/np.sqrt(var+1e-5)
print("per-channel means", np.round(out.mean(axis=(1,2)),6))
check_close("InstanceNorm channel means zero", out.mean(axis=(1,2)), np.zeros(2), tol=1e-5)

15. WeightNorm

Code cell 34

v=np.array([3.,4.]); g=2.0
w=g*v/np.linalg.norm(v)
print("w", w, "norm", np.linalg.norm(w))
check_close("WeightNorm norm equals g", np.linalg.norm(w), g)

16. SpectralNorm exact

Code cell 36

W=np.array([[3.,0.],[0.,1.]])
sigma=np.linalg.svd(W, compute_uv=False)[0]
Wn=W/sigma
print("sigma", sigma, "spectral norm after", np.linalg.svd(Wn, compute_uv=False)[0])
check_close("spectral norm one", np.linalg.svd(Wn, compute_uv=False)[0], 1.0)

17. Power iteration

Code cell 38

W=np.random.normal(size=(5,5)); u=np.random.normal(size=5); u=u/np.linalg.norm(u)
for _ in range(20):
    v=W.T@u; v=v/np.linalg.norm(v)
    u=W@v; u=u/np.linalg.norm(u)
sigma_est=u@(W@v); sigma_true=np.linalg.svd(W, compute_uv=False)[0]
print("estimate", sigma_est, "true", sigma_true)
check_true("power iteration close", abs(sigma_est-sigma_true)<1e-3)

18. Pre-norm residual toy

Code cell 40

h=np.random.normal(size=(4,6))
F=lambda x: 0.1*np.tanh(x)
pre=h+F(layer_norm(h)[0])
post=layer_norm(h+F(h))[0]
print("pre-norm residual variance", pre.var())
print("post-norm output variance", post.var())
check_true("post norm forces unit-ish feature variance", abs(post.var()-1)<1e-3)

19. Broadcasting gamma and beta

Code cell 42

X=np.random.normal(size=(2,3,4))
gamma=np.arange(1,5).reshape(1,1,4); beta=np.zeros((1,1,4))
mu=X.mean(axis=-1, keepdims=True); var=X.var(axis=-1, keepdims=True)
Y=gamma*(X-mu)/np.sqrt(var+1e-5)+beta
print("Y shape", Y.shape)
check_true("gamma broadcasts over batch and time", Y.shape==X.shape)

20. Small batch BN failure

Code cell 44

one=np.array([[3.,3.,3.]])
y,mu,var=batch_norm(one)
print("variance", var.ravel(), "output", y.ravel())
check_close("single example BN centered zero", y, np.zeros_like(y))

21. Mixed precision intuition

Code cell 46

x64=np.array([1.0001,1.0002,1.0003], dtype=np.float64)
x16=x64.astype(np.float16)
print("float64 var", x64.var())
print("float16 values", x16, "var", x16.var())
check_true("low precision can erase tiny variance", x16.var() <= x64.var()*2)

22. Fused form, same math

Code cell 48

X=np.random.normal(size=(4,5)); gamma=np.ones((1,5))*2; beta=np.ones((1,5))*0.5
Y1=batch_norm(X, gamma=gamma, beta=beta)[0]
mu=X.mean(axis=0, keepdims=True); inv=1/np.sqrt(X.var(axis=0, keepdims=True)+1e-5)
Y2=(X-mu)*inv*gamma+beta
check_close("fused algebra equals staged algebra", Y1, Y2)

23. Gamma restores learnable scale after normalization

Code cell 50

X=np.random.normal(size=(100,4)); gamma=np.array([[1.,2.,3.,4.]])
Y=batch_norm(X, gamma=gamma)[0]
print("feature std after gamma", np.round(Y.std(axis=0),3))
check_true("gamma restores learnable scale", Y.std(axis=0)[-1]>Y.std(axis=0)[0])

24. Norm diagnostics table

Code cell 52

diagnostics={"mean_abs":abs(Y.mean()),"max_std":Y.std(axis=0).max(),"min_std":Y.std(axis=0).min()}
for k,v in diagnostics.items(): print(k, round(float(v),4))
check_true("diagnostics finite", all(np.isfinite(v) for v in diagnostics.values()))

25. Summary checks

Code cell 54

checks=[]
checks.append(check_true("BatchNorm batch-dependent", True))
checks.append(check_true("LayerNorm batch-independent", True))
checks.append(check_true("RMSNorm scale-only", True))
print(f"Passed {sum(checks)}/{len(checks)} summary checks.")
