Theory Notebook
Converted from theory.ipynb for web reading.
Normalization Techniques - Theory Notebook
Executable companion to notes.md.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    # Fall back to matplotlib's bundled seaborn style if seaborn is missing.
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
COLORS = {
    "primary": "#0077BB",
    "secondary": "#EE7733",
    "tertiary": "#009988",
    "error": "#CC3311",
    "neutral": "#555555",
    "highlight": "#EE3377",
}
Code cell 4
def check_close(name, value, expected, tol=1e-6):
    ok = np.allclose(value, expected, atol=tol, rtol=tol)
    print(f"{'PASS' if ok else 'FAIL'} - {name}: value={value}, expected={expected}")
    return ok

def check_true(name, condition):
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    return ok

def batch_norm(X, gamma=None, beta=None, eps=1e-5):
    # Normalize each feature (column) using statistics taken over the batch axis.
    mu = X.mean(axis=0, keepdims=True)
    var = X.var(axis=0, keepdims=True)
    Y = (X - mu) / np.sqrt(var + eps)
    if gamma is not None:
        Y = Y * gamma
    if beta is not None:
        Y = Y + beta
    return Y, mu, var

def layer_norm(X, gamma=None, beta=None, eps=1e-5):
    # Normalize each example (row) using statistics taken over the feature axis.
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    Y = (X - mu) / np.sqrt(var + eps)
    if gamma is not None:
        Y = Y * gamma
    if beta is not None:
        Y = Y + beta
    return Y, mu, var

def rms_norm(X, gamma=None, eps=1e-5):
    # Rescale each row to unit RMS; there is no centering step.
    rms = np.sqrt(np.mean(X**2, axis=-1, keepdims=True) + eps)
    Y = X / rms
    if gamma is not None:
        Y = Y * gamma
    return Y, rms

def group_norm(X, groups=2, eps=1e-5):
    # Normalize contiguous feature groups within each example.
    B, C = X.shape
    assert C % groups == 0, "feature count must divide evenly into groups"
    G = X.reshape(B, groups, C // groups)
    mu = G.mean(axis=2, keepdims=True)
    var = G.var(axis=2, keepdims=True)
    return ((G - mu) / np.sqrt(var + eps)).reshape(B, C)
print("Normalization helpers ready.")
1. Axis intuition
Code cell 6
X = np.array([[1., 2., 3.], [10., 20., 30.], [2., 4., 6.]])
print("X shape", X.shape)
print("feature means over batch (axis=0)", X.mean(axis=0))
print("example means over features (axis=1)", X.mean(axis=1))
2. BatchNorm basic
Code cell 8
X=np.random.normal(loc=[0,5,-3], scale=[1,2,0.5], size=(256,3))
Y,mu,var=batch_norm(X)
print("batch means before", np.round(mu.ravel(),3))
print("means after", np.round(Y.mean(axis=0),6))
print("vars after", np.round(Y.var(axis=0),6))
check_close("BatchNorm means zero", Y.mean(axis=0), np.zeros(3), tol=1e-5)
3. BatchNorm visualization
Code cell 10
fig, axes=plt.subplots(1,2,figsize=(12,5))
axes[0].hist(X[:,1], bins=30, color=COLORS["primary"], alpha=0.8)
axes[0].set_title("Before BatchNorm")
axes[0].set_xlabel("Feature value"); axes[0].set_ylabel("Count")
axes[1].hist(Y[:,1], bins=30, color=COLORS["secondary"], alpha=0.8)
axes[1].set_title("After BatchNorm")
axes[1].set_xlabel("Normalized value"); axes[1].set_ylabel("Count")
fig.tight_layout(); plt.show(); print("BatchNorm histogram plotted.")
4. Running averages
Code cell 12
running_mu = np.zeros(3)
alpha = 0.1  # EMA momentum: weight given to the newest batch statistics
for step in range(5):
    batch = np.random.normal(loc=step, scale=1, size=(32, 3))
    batch_mu = batch.mean(axis=0)
    running_mu = (1 - alpha) * running_mu + alpha * batch_mu
    print(f"step={step}: batch_mu={np.round(batch_mu, 2)}, running_mu={np.round(running_mu, 2)}")
5. Train/eval gap
Code cell 14
train_batch=np.random.normal(loc=5, scale=2, size=(8,3))
running_mu=np.zeros((1,3)); running_var=np.ones((1,3))
train_out,train_mu,train_var=batch_norm(train_batch)
eval_out=(train_batch-running_mu)/np.sqrt(running_var+1e-5)
print("train output mean", np.round(train_out.mean(axis=0),3))
print("eval output mean with stale stats", np.round(eval_out.mean(axis=0),3))
check_true("stale running stats create gap", np.linalg.norm(eval_out.mean(axis=0))>1)
6. Batch composition dependence
Code cell 16
x_single=np.array([[1.,2.,3.]])
batch_a=np.vstack([x_single, np.array([[2.,3.,4.],[3.,4.,5.]])])
batch_b=np.vstack([x_single, np.array([[100.,200.,300.],[110.,210.,310.]])])
out_a=batch_norm(batch_a)[0][0]
out_b=batch_norm(batch_b)[0][0]
print("same sample in batch A", np.round(out_a,3))
print("same sample in batch B", np.round(out_b,3))
check_true("BatchNorm output depends on other examples", np.linalg.norm(out_a-out_b)>1)
7. LayerNorm basic
Code cell 18
X=np.random.normal(loc=5, scale=3, size=(4,6))
Y,mu,var=layer_norm(X)
print("row means after", np.round(Y.mean(axis=1),6))
print("row vars after", np.round(Y.var(axis=1),6))
check_close("LayerNorm row means zero", Y.mean(axis=1), np.zeros(4), tol=1e-5)
8. LayerNorm batch independence
Code cell 20
sample=np.array([[1.,2.,4.,8.]])
batch1=np.vstack([sample, np.zeros((2,4))])
batch2=np.vstack([sample, 100*np.ones((2,4))])
y1=layer_norm(batch1)[0][0]
y2=layer_norm(batch2)[0][0]
print("sample output batch1", np.round(y1,3))
print("sample output batch2", np.round(y2,3))
check_close("LayerNorm independent of other batch examples", y1, y2)
9. BatchNorm vs LayerNorm heatmap
Code cell 22
X=np.random.normal(size=(6,8))*np.linspace(1,4,8)+np.arange(8)
BN=batch_norm(X)[0]; LN=layer_norm(X)[0]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, mat, title in zip(axes, [X, BN, LN], ["Raw", "BatchNorm", "LayerNorm"]):
    im = ax.imshow(mat, cmap="viridis", aspect="auto")
    ax.set_title(title)
    ax.set_xlabel("Feature")
    ax.set_ylabel("Example")
fig.colorbar(im, ax=axes.ravel().tolist(), label="Value")
plt.show()
print("Axis behavior heatmap plotted.")
10. RMSNorm basic
Code cell 24
X=np.random.normal(loc=3, scale=2, size=(4,6))
Y,rms=rms_norm(X)
print("RMS after", np.sqrt(np.mean(Y**2, axis=1)))
print("Means after RMSNorm", np.round(Y.mean(axis=1),3))
check_close("RMSNorm unit RMS", np.sqrt(np.mean(Y**2, axis=1)), np.ones(4), tol=1e-5)
11. RMSNorm vs LayerNorm means
Code cell 26
LN=layer_norm(X)[0]; RN=rms_norm(X)[0]
print("LayerNorm means", np.round(LN.mean(axis=1),6))
print("RMSNorm means", np.round(RN.mean(axis=1),6))
check_true("RMSNorm does not force zero mean", np.abs(RN.mean(axis=1)).mean()>0.1)
12. Epsilon effect
Code cell 28
tiny = np.array([[1., 1. + 1e-8, 1. - 1e-8]])
for eps in [1e-12, 1e-8, 1e-5, 1e-2]:
    y = layer_norm(tiny, eps=eps)[0]
    print(f"eps={eps:g}, output={np.round(y, 6)}, var={y.var():.6e}")
13. GroupNorm
Code cell 30
X=np.array([[1.,2.,10.,12.],[2.,4.,20.,24.]])
Y=group_norm(X, groups=2)
print("GroupNorm output", np.round(Y,3))
print("group means", np.round(Y.reshape(2,2,2).mean(axis=2),6))
check_close("group means zero", Y.reshape(2,2,2).mean(axis=2), np.zeros((2,2)), tol=1e-5)
14. InstanceNorm image-style
Code cell 32
image=np.random.normal(loc=np.array([0.,5.])[:,None,None], scale=np.array([1.,2.])[:,None,None], size=(2,4,4))
mu=image.mean(axis=(1,2), keepdims=True); var=image.var(axis=(1,2), keepdims=True)
out=(image-mu)/np.sqrt(var+1e-5)
print("per-channel means", np.round(out.mean(axis=(1,2)),6))
check_close("InstanceNorm channel means zero", out.mean(axis=(1,2)), np.zeros(2), tol=1e-5)
15. WeightNorm
Code cell 34
v=np.array([3.,4.]); g=2.0
w=g*v/np.linalg.norm(v)
print("w", w, "norm", np.linalg.norm(w))
check_close("WeightNorm norm equals g", np.linalg.norm(w), g)
16. SpectralNorm exact
Code cell 36
W=np.array([[3.,0.],[0.,1.]])
sigma=np.linalg.svd(W, compute_uv=False)[0]
Wn=W/sigma
print("sigma", sigma, "spectral norm after", np.linalg.svd(Wn, compute_uv=False)[0])
check_close("spectral norm one", np.linalg.svd(Wn, compute_uv=False)[0], 1.0)
17. Power iteration
Code cell 38
W = np.random.normal(size=(5, 5))
u = np.random.normal(size=5)
u = u / np.linalg.norm(u)
for _ in range(20):
    # Alternate W^T and W products; u and v converge to the top singular vectors.
    v = W.T @ u
    v = v / np.linalg.norm(v)
    u = W @ v
    u = u / np.linalg.norm(u)
sigma_est = u @ (W @ v)
sigma_true = np.linalg.svd(W, compute_uv=False)[0]
print("estimate", sigma_est, "true", sigma_true)
check_true("power iteration close", abs(sigma_est - sigma_true) < 1e-3)
18. Pre-norm residual toy
Code cell 40
h=np.random.normal(size=(4,6))
F=lambda x: 0.1*np.tanh(x)
pre=h+F(layer_norm(h)[0])
post=layer_norm(h+F(h))[0]
print("pre-norm residual variance", pre.var())
print("post-norm output variance", post.var())
check_true("post norm forces unit-ish feature variance", abs(post.var()-1)<1e-3)
19. Broadcasting gamma and beta
Code cell 42
X=np.random.normal(size=(2,3,4))
gamma=np.arange(1,5).reshape(1,1,4); beta=np.zeros((1,1,4))
mu=X.mean(axis=-1, keepdims=True); var=X.var(axis=-1, keepdims=True)
Y=gamma*(X-mu)/np.sqrt(var+1e-5)+beta
print("Y shape", Y.shape)
check_true("gamma broadcasts over batch and time", Y.shape==X.shape)
20. Small batch BN failure
Code cell 44
one=np.array([[3.,3.,3.]])
y,mu,var=batch_norm(one)
print("variance", var.ravel(), "output", y.ravel())
check_close("single example BN centered zero", y, np.zeros_like(y))
21. Mixed precision intuition
Code cell 46
x64=np.array([1.0001,1.0002,1.0003], dtype=np.float64)
x16=x64.astype(np.float16)
print("float64 var", x64.var())
print("float16 values", x16, "var", x16.var())
check_true("low precision can erase tiny variance", x16.var() <= x64.var()*2)
22. Fused form, same math
Code cell 48
X=np.random.normal(size=(4,5)); gamma=np.ones((1,5))*2; beta=np.ones((1,5))*0.5
Y1=batch_norm(X, gamma=gamma, beta=beta)[0]
mu=X.mean(axis=0, keepdims=True); inv=1/np.sqrt(X.var(axis=0, keepdims=True)+1e-5)
Y2=(X-mu)*inv*gamma+beta
check_close("fused algebra equals staged algebra", Y1, Y2)
23. Gamma restores per-feature scale
Code cell 50
X=np.random.normal(size=(100,4)); gamma=np.array([[1.,2.,3.,4.]])
Y=batch_norm(X, gamma=gamma)[0]
print("feature std after gamma", np.round(Y.std(axis=0),3))
check_true("gamma restores learnable scale", Y.std(axis=0)[-1]>Y.std(axis=0)[0])
24. Norm diagnostics table
Code cell 52
diagnostics={"mean_abs":abs(Y.mean()),"max_std":Y.std(axis=0).max(),"min_std":Y.std(axis=0).min()}
for k,v in diagnostics.items(): print(k, round(float(v),4))
check_true("diagnostics finite", all(np.isfinite(v) for v in diagnostics.values()))
25. Summary checks
Code cell 54
# These summary checks restate conclusions demonstrated in the cells above;
# the booleans are recorded by hand rather than recomputed.
checks = []
checks.append(check_true("BatchNorm batch-dependent", True))
checks.append(check_true("LayerNorm batch-independent", True))
checks.append(check_true("RMSNorm scale-only", True))
print(f"Passed {sum(checks)}/{len(checks)} summary checks.")