Theory Notebook

Converted from theory.ipynb for web reading.

Fisher Information

Fisher information measures local statistical distinguishability. This notebook is the interactive companion to notes.md and builds the chapter through concrete score computations, KL-curvature checks, Fisher geometry, and ML-facing approximations.

Coverage: score functions, scalar and matrix Fisher information, observed vs expected vs empirical Fisher, reparameterization, local KL curvature, Jeffreys priors, natural gradient, empirical-Fisher pitfalls, and Fisher-based ML applications.

Code cell 2

import math  # used later for math.factorial in the Poisson cell
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 4

COLORS = {
    "primary":   "#0077BB",
    "secondary": "#EE7733",
    "tertiary":  "#009988",
    "error":     "#CC3311",
    "neutral":   "#555555",
    "highlight": "#EE3377",
}

def header(title):
    print("\n" + "=" * 78)
    print(title)
    print("=" * 78)

def check_close(name, value, target, tol=1e-8):
    ok = np.allclose(value, target, atol=tol, rtol=tol)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    if not ok:
        print("  value :", value)
        print("  target:", target)
    return ok

def check_true(name, condition):
    print(f"{'PASS' if condition else 'FAIL'} - {name}")
    return condition

def bernoulli_score(x, p):
    return x / p - (1 - x) / (1 - p)

def bernoulli_fisher(p):
    return 1.0 / (p * (1 - p))

def poisson_score(x, lam):
    return -1.0 + x / lam

def poisson_fisher(lam):
    return 1.0 / lam

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

print("Helpers ready.")

1. Score Functions and Scalar Fisher Information

Code cell 6

header("Bernoulli score and Fisher information")
p = 0.3
x_vals = np.array([0, 1])
probs = np.array([1 - p, p])
scores = np.array([bernoulli_score(x, p) for x in x_vals])
fisher_emp = np.sum(probs * scores**2)
fisher_formula = bernoulli_fisher(p)
print("scores:", dict(zip(x_vals, scores)))
print("E[score^2]      =", fisher_emp)
print("closed form     =", fisher_formula)
check_close("Bernoulli Fisher matches 1/(p(1-p))", fisher_emp, fisher_formula)

Code cell 7

header("Bernoulli Fisher curve")
ps = np.linspace(0.02, 0.98, 300)
vals = bernoulli_fisher(ps)
fig, ax = plt.subplots()
ax.plot(ps, vals, color=COLORS["primary"])
ax.set_title("Bernoulli Fisher information")
ax.set_xlabel("Probability p")
ax.set_ylabel("Fisher information I(p)")
fig.tight_layout()
plt.show()
check_true("Bernoulli Fisher is smallest near p=0.5", vals[np.argmin(np.abs(ps - 0.5))] < vals[0])

What to notice: Bernoulli Fisher is smallest near p = 1/2 and grows toward the boundaries. This is a local distinguishability statement, not a statement about global uncertainty.
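A two-line supplement (added here, anticipating the KL link in Section 4): the same shift of 0.01 in p produces a much larger KL gap near the boundary than near p = 1/2, which is what the Fisher curve above is quantifying.

# Supplementary illustration: equal-size parameter shifts are easier to detect near the boundary.
def bern_kl(a, b):
    return a * np.log(a / b) + (1 - a) * np.log((1 - a) / (1 - b))

print("KL(Bern(0.01) || Bern(0.02)) =", bern_kl(0.01, 0.02))
print("KL(Bern(0.50) || Bern(0.51)) =", bern_kl(0.50, 0.51))
check_true("A 0.01 shift is more detectable near the boundary than near 0.5",
           bern_kl(0.01, 0.02) > bern_kl(0.50, 0.51))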

Code cell 9

header("Poisson score and Fisher information")
lam = 4.0
xs = np.arange(0, 30)
pmf = np.exp(-lam) * lam**xs / np.array([math.factorial(int(x)) for x in xs])
scores = poisson_score(xs, lam)
fisher_emp = np.sum(pmf * scores**2)
fisher_formula = poisson_fisher(lam)
print("E[score^2]  =", fisher_emp)
print("closed form =", fisher_formula)
check_close("Poisson Fisher matches 1/lambda", fisher_emp, fisher_formula, tol=1e-6)

Code cell 10

header("Gaussian mean Fisher information")
sigma2 = 2.5
fisher_formula = 1.0 / sigma2
rng = np.random.default_rng(42)
x = rng.normal(loc=1.0, scale=np.sqrt(sigma2), size=200000)
score = (x - 1.0) / sigma2
fisher_emp = np.mean(score**2)
print("Monte Carlo E[score^2] =", fisher_emp)
print("closed form            =", fisher_formula)
check_close("Gaussian mean Fisher matches 1/sigma^2", fisher_emp, fisher_formula, tol=1e-2)

2. Matrix Fisher Information and Observed vs Expected Curvature

The matrix view packages directional sensitivity. Large eigenvalues correspond to sharp information-rich directions; small eigenvalues correspond to flat or weakly identifiable directions.

Code cell 13

header("Matrix Fisher for a Gaussian mean model")
Sigma = np.array([[2.0, 0.8], [0.8, 1.5]])
Sigma_inv = np.linalg.inv(Sigma)
print("Sigma^{-1} =")
print(Sigma_inv)
eigvals = np.linalg.eigvalsh(Sigma_inv)
print("eigenvalues:", eigvals)
check_true("Matrix Fisher is positive definite", np.all(eigvals > 0))

Code cell 14

header("Observed information versus expected Fisher")
rng = np.random.default_rng(1)
sigma2 = 1.7
mu_true = 0.4
n = 50
sample = rng.normal(mu_true, np.sqrt(sigma2), size=n)
# For the Gaussian mean with known variance, the second derivative of the
# log-likelihood is -n/sigma2 no matter what the data are, so the observed
# information does not depend on `sample` here.
observed_info = n / sigma2
expected_fisher = n / sigma2
print("Observed info for Gaussian mean with known variance:", observed_info)
print("Expected Fisher:", expected_fisher)
check_close("Observed equals expected in this model", observed_info, expected_fisher)

Code cell 15

header("Empirical Fisher for logistic regression on a toy dataset")
rng = np.random.default_rng(3)
X = rng.normal(size=(200, 2))
w = np.array([0.8, -0.5])
logits = X @ w
p_hat = sigmoid(logits)
y = rng.binomial(1, p_hat)
grads = ((y - p_hat)[:, None]) * X
empirical_fisher = grads.T @ grads / len(X)
print("empirical Fisher =")
print(empirical_fisher)
eigvals = np.linalg.eigvalsh(empirical_fisher)
print("eigenvalues:", eigvals)
check_true("Empirical Fisher is PSD", np.all(eigvals >= -1e-10))

3. Structural Properties: Additivity and Reparameterization

Additivity is why information scales linearly with sample size in regular iid models. Reparameterization shows why Fisher information behaves like a metric tensor: its numerical value depends on the chosen coordinates rather than being an invariant scalar.

Code cell 18

header("Additivity over independent observations")
p = 0.37
single = bernoulli_fisher(p)
n = 25
total = n * single
print("single-observation Fisher:", single)
print("n-sample Fisher:", total)
check_close("IID additivity gives n * I(theta)", total, n * single)

Code cell 19

header("Reparameterization: Bernoulli p versus logit phi")
p = 0.2
fisher_p = bernoulli_fisher(p)
dp_dphi = p * (1 - p)
fisher_phi = fisher_p * dp_dphi**2
print("Fisher in p-coordinates     =", fisher_p)
print("Fisher in logit coordinates =", fisher_phi)
check_close("Logit-coordinate Fisher equals p(1-p)", fisher_phi, p * (1 - p))

Code cell 20

header("Coordinate change plot")
ps = np.linspace(0.02, 0.98, 300)
fisher_p = bernoulli_fisher(ps)
fisher_phi = fisher_p * (ps * (1 - ps))**2
fig, ax = plt.subplots()
ax.plot(ps, fisher_p, color=COLORS["primary"], label="I(p)")
ax.plot(ps, fisher_phi, color=COLORS["secondary"], label="I(logit(p))")
ax.set_title("Coordinate values change under reparameterization")
ax.set_xlabel("Probability p")
ax.set_ylabel("Information value")
ax.legend()
fig.tight_layout()
plt.show()
check_true("Coordinate values differ substantially", np.max(np.abs(fisher_p - fisher_phi)) > 1.0)

Code cell 21

header("Sufficiency example: Gaussian sample mean preserves Fisher")
sigma2 = 3.0
n = 40
fisher_full = n / sigma2
fisher_mean = 1.0 / (sigma2 / n)
print("Fisher in full sample:", fisher_full)
print("Fisher in sample mean:", fisher_mean)
check_close("Sufficient statistic preserves Fisher", fisher_full, fisher_mean)

4. Fisher as Local KL Curvature

The KL expansion is the chapter's geometric hinge: Fisher information is the quadratic form inside local KL divergence.

Code cell 24

header("Bernoulli local KL is quadratic with Fisher coefficient")
p = 0.4
deltas = np.array([-0.02, -0.01, 0.01, 0.02])
fisher = bernoulli_fisher(p)
for delta in deltas:
    q = np.clip(p + delta, 1e-9, 1 - 1e-9)
    kl = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
    quad = 0.5 * fisher * delta**2
    print(f"delta={delta:+.3f}  KL={kl:.8f}  quadratic={quad:.8f}")
small = 0.01
q = p + small
kl = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
quad = 0.5 * fisher * small**2
check_close("Local KL matches quadratic approximation", kl, quad, tol=5e-6)

Code cell 25

header("Gaussian mean KL curvature")
sigma2 = 1.4
delta = 0.07
kl = 0.5 * delta**2 / sigma2
fisher_quad = 0.5 * (1.0 / sigma2) * delta**2
print("exact KL between nearby Gaussian means:", kl)
print("quadratic Fisher term             :", fisher_quad)
check_close("Gaussian local KL equals Fisher quadratic exactly", kl, fisher_quad)

Code cell 26

header("Quadratic KL approximation across delta")
p = 0.35
fisher = bernoulli_fisher(p)
deltas = np.linspace(-0.15, 0.15, 300)
kl_vals = []
quad_vals = 0.5 * fisher * deltas**2
for delta in deltas:
    q = p + delta
    if q <= 0 or q >= 1:
        kl_vals.append(np.nan)
    else:
        kl_vals.append(p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q)))
fig, ax = plt.subplots()
ax.plot(deltas, kl_vals, color=COLORS["primary"], label="Exact KL")
ax.plot(deltas, quad_vals, color=COLORS["secondary"], linestyle="--", label="Quadratic Fisher approx")
ax.set_title("Local KL curvature for Bernoulli models")
ax.set_xlabel("Parameter displacement delta")
ax.set_ylabel("D_KL")
ax.legend()
fig.tight_layout()
plt.show()
check_true("Near delta=0, curves are close", np.nanmax(np.abs(np.array(kl_vals)[140:160] - quad_vals[140:160])) < 1e-4)

5. Jeffreys Prior and Fisher Geometry

Jeffreys priors come from the Fisher volume element, not from arbitrary coordinate choices. This is one reason they are important in invariant Bayesian analysis.

Code cell 29

header("Jeffreys prior for Bernoulli")
p = np.linspace(0.01, 0.99, 400)
jeff = np.sqrt(bernoulli_fisher(p))
jeff /= np.trapezoid(jeff, p)
print("Approximate integral of Jeffreys density:", np.trapezoid(jeff, p))
check_close("Jeffreys Bernoulli prior numerically normalizes", np.trapezoid(jeff, p), 1.0, tol=5e-3)

Code cell 30

header("Jeffreys prior shape for Bernoulli")
p = np.linspace(0.01, 0.99, 400)
jeff = 1.0 / np.sqrt(p * (1 - p))
jeff /= np.trapezoid(jeff, p)
fig, ax = plt.subplots()
ax.plot(p, jeff, color=COLORS["highlight"])
ax.set_title("Jeffreys prior for Bernoulli(p)")
ax.set_xlabel("Probability p")
ax.set_ylabel("Density")
fig.tight_layout()
plt.show()
check_true("Jeffreys density is larger near the boundaries", jeff[0] > jeff[len(jeff)//2])

Code cell 31

header("Improper Jeffreys prior for an exponential rate")
lam = np.linspace(0.1, 10.0, 1000)
density = 1.0 / lam
partial_mass = np.trapezoid(density, lam)
print("Integral of 1/lambda over [0.1, 10] =", partial_mass)
check_true("Finite-window integral is large and grows with the window", partial_mass > 4.0)

6. Natural Gradient as KL-Constrained Steepest Descent

Natural gradient is steepest descent measured in local KL geometry, which is why Fisher information appears as the metric tensor.

Code cell 34

header("Natural gradient on a 2-parameter logistic model")
rng = np.random.default_rng(7)
X = rng.normal(size=(300, 2))
w = np.array([0.4, -0.7])
logits = X @ w
probs = sigmoid(logits)
y = rng.binomial(1, probs)
p_hat = sigmoid(X @ w)
grad = X.T @ (p_hat - y) / len(X)
fisher = np.zeros((2, 2))
for i in range(len(X)):
    fisher += p_hat[i] * (1 - p_hat[i]) * np.outer(X[i], X[i])
fisher /= len(X)
nat_grad = np.linalg.solve(fisher + 1e-6 * np.eye(2), grad)
print("ordinary gradient:", grad)
print("natural gradient:", nat_grad)
check_true("Natural gradient is not identical to the raw gradient", np.linalg.norm(nat_grad - grad) > 1e-4)

Code cell 35

header("Fisher metric ellipse")
fisher = np.array([[3.0, 1.0], [1.0, 1.5]])
vals, vecs = np.linalg.eigh(fisher)
t = np.linspace(0, 2 * np.pi, 200)
circle = np.vstack([np.cos(t), np.sin(t)])
ellipse = vecs @ np.diag(1.0 / np.sqrt(vals)) @ circle
fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(circle[0], circle[1], color=COLORS["neutral"], linestyle="--", label="Euclidean unit circle")
ax.plot(ellipse[0], ellipse[1], color=COLORS["primary"], label="Fisher unit ball")
ax.set_title("Euclidean and Fisher geometry differ")
ax.set_xlabel("Direction 1")
ax.set_ylabel("Direction 2")
ax.legend()
ax.set_aspect('equal')
fig.tight_layout()
plt.show()
check_true("Ellipse is anisotropic", np.ptp(ellipse[0]) != np.ptp(ellipse[1]))

7. Entropy-Fisher Connections and Score Fields

Entropy and Fisher move in opposite directions under Gaussian smoothing: smoothing spreads probability mass while reducing local sharpness.

Code cell 38

header("Fisher divergence between two Gaussian score fields")
xs = np.linspace(-6, 6, 2000)
p_mu, p_var = 0.0, 1.0
q_mu, q_var = 1.0, 1.5
p_pdf = np.exp(-(xs - p_mu)**2 / (2 * p_var)) / np.sqrt(2 * np.pi * p_var)
score_p = -(xs - p_mu) / p_var
score_q = -(xs - q_mu) / q_var
fisher_div = np.trapezoid(p_pdf * (score_p - score_q)**2, xs)
print("Approximate Fisher divergence:", fisher_div)
check_true("Fisher divergence is nonnegative", fisher_div >= 0)

Code cell 39

header("Entropy rises and Fisher falls under Gaussian smoothing")
xs = np.linspace(-8, 8, 4000)
dx = xs[1] - xs[0]
base = 0.5 * np.exp(-(xs + 2)**2 / (2 * 0.6**2)) / (np.sqrt(2 * np.pi) * 0.6)
base += 0.5 * np.exp(-(xs - 2)**2 / (2 * 0.9**2)) / (np.sqrt(2 * np.pi) * 0.9)
base /= np.trapezoid(base, xs)
ts = np.linspace(0.0, 2.0, 11)
entropies = []
fishers = []
for t in ts:
    var = t + 1e-6
    kernel = np.exp(-xs**2 / (2 * var)) / np.sqrt(2 * np.pi * var)
    kernel /= np.trapezoid(kernel, xs)
    smooth = np.convolve(base, kernel, mode='same') * dx
    smooth = np.clip(smooth, 1e-12, None)
    smooth /= np.trapezoid(smooth, xs)
    log_smooth = np.log(smooth)
    grad = np.gradient(log_smooth, dx)
    h = -np.trapezoid(smooth * log_smooth, xs)
    j = np.trapezoid(smooth * grad**2, xs)
    entropies.append(h)
    fishers.append(j)
print("entropy trend:", np.round(entropies, 4))
print("fisher trend :", np.round(fishers, 4))
check_true("Entropy increases under smoothing", entropies[-1] > entropies[0])
check_true("Fisher decreases under smoothing", fishers[-1] < fishers[0])

Code cell 40

header("Entropy/Fisher trend plot")
ts = np.linspace(0.0, 2.0, 11)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(ts, entropies, color=COLORS["secondary"])
axes[0].set_title("Differential entropy under Gaussian smoothing")
axes[0].set_xlabel("Noise variance t")
axes[0].set_ylabel("Entropy h")
axes[1].plot(ts, fishers, color=COLORS["error"])
axes[1].set_title("Fisher information under Gaussian smoothing")
axes[1].set_xlabel("Noise variance t")
axes[1].set_ylabel("Fisher information J")
fig.tight_layout()
plt.show()
check_true("Both curves share the same time grid", len(entropies) == len(fishers) == len(ts))

8. Fisher, Hessian, and Empirical Fisher in a Toy Classifier

The deep-learning caution is essential: the empirical Fisher is convenient, but it is not the same object as the expected Fisher or the Hessian.

Code cell 43

header("Observed Hessian versus true-Fisher approximation in logistic regression")
rng = np.random.default_rng(123)
X = rng.normal(size=(400, 2))
w = np.array([1.0, -0.7])
logits = X @ w
probs = sigmoid(logits)
y = rng.binomial(1, probs)
W = probs * (1 - probs)
hessian = (X.T * W) @ X / len(X)
# In logistic regression the Hessian of the average NLL does not involve y,
# so it already equals the expected (conditional) Fisher; record that fact.
true_fisher = hessian.copy()
grads = ((y - probs)[:, None]) * X
empirical_fisher = grads.T @ grads / len(X)
print("Observed Hessian / expected Fisher =")
print(hessian)
print("Empirical Fisher =")
print(empirical_fisher)
check_true("Expected Fisher matches logistic Hessian in this model", np.linalg.norm(hessian - true_fisher) < 1e-12)
check_true("Empirical Fisher differs from Hessian on finite data", np.linalg.norm(empirical_fisher - hessian) > 1e-4)

Code cell 44

header("Norm comparison")
fro_h = np.linalg.norm(hessian, ord='fro')
fro_e = np.linalg.norm(empirical_fisher, ord='fro')
diff = np.linalg.norm(empirical_fisher - hessian, ord='fro')
print("||H||_F        =", fro_h)
print("||EmpF||_F     =", fro_e)
print("||EmpF - H||_F =", diff)
check_true("Finite-sample empirical Fisher error is visible", diff / fro_h > 0.01)

Code cell 45

header("Redundant-feature singularity example")
rng = np.random.default_rng(11)
x1 = rng.normal(size=500)
X = np.column_stack([x1, x1])
w = np.array([1.0, -1.0])
probs = sigmoid(X @ w)
W = probs * (1 - probs)
fisher = (X.T * W) @ X / len(X)
eigvals = np.linalg.eigvalsh(fisher)
print("Fisher matrix =")
print(fisher)
print("eigenvalues   =", eigvals)
check_true("One eigenvalue is numerically tiny due to redundancy", eigvals[0] < 1e-8)

Code cell 46

header("Eigenvalue plot for redundant-feature Fisher")
fig, ax = plt.subplots()
ax.bar([0, 1], eigvals, color=[COLORS["error"], COLORS["primary"]])
ax.set_title("Redundant features create a nearly singular Fisher matrix")
ax.set_xlabel("Eigenvalue index")
ax.set_ylabel("Eigenvalue")
fig.tight_layout()
plt.show()
check_true("Largest eigenvalue dominates the redundant direction", eigvals[1] > 1000 * max(eigvals[0], 1e-12))

9. ML Applications: Softmax, K-FAC Intuition, and EWC

Softmax blocks, Kronecker factors, and diagonal Fisher penalties show how one geometric object gets reused across very different ML subsystems.

Code cell 49

header("Single-example softmax Fisher block")
z = np.array([1.2, -0.4, 0.8])
p = np.exp(z - np.max(z))
p /= p.sum()
fisher_softmax = np.diag(p) - np.outer(p, p)
print("softmax probabilities:", p)
print("Fisher block =")
print(fisher_softmax)
eigvals = np.linalg.eigvalsh(fisher_softmax)
print("eigenvalues:", eigvals)
check_true("Softmax Fisher block is PSD", np.all(eigvals >= -1e-10))

Code cell 50

header("K-FAC-style Kronecker factor intuition for a linear layer")
rng = np.random.default_rng(21)
A = rng.normal(size=(300, 4))
G = rng.normal(size=(300, 3))
A_cov = A.T @ A / len(A)
G_cov = G.T @ G / len(G)
kron_shape = (A_cov.shape[0] * G_cov.shape[0], A_cov.shape[1] * G_cov.shape[1])
print("Activation covariance shape:", A_cov.shape)
print("Gradient covariance shape  :", G_cov.shape)
print("Kronecker block shape      :", kron_shape)
check_true("K-FAC factors are much smaller than the full parameter block", A_cov.size + G_cov.size < kron_shape[0] * kron_shape[1])

Code cell 51

header("Diagonal Fisher importance for EWC")
rng = np.random.default_rng(31)
X1 = rng.normal(size=(300, 2))
w_star = np.array([1.0, -0.8])
p1 = sigmoid(X1 @ w_star)
y1 = rng.binomial(1, p1)
diag_fisher = np.mean((((y1 - p1)[:, None]) * X1) ** 2, axis=0)
print("Diagonal Fisher importance:", diag_fisher)
check_true("Importance weights are nonnegative", np.all(diag_fisher >= 0))

Code cell 52

header("EWC-style penalty grows with movement in important directions")
theta_old = np.array([1.0, -0.8])
theta_new = np.array([1.2, -0.1])
penalty = 0.5 * np.sum(diag_fisher * (theta_new - theta_old) ** 2)
theta_new2 = np.array([1.02, -0.75])
penalty2 = 0.5 * np.sum(diag_fisher * (theta_new2 - theta_old) ** 2)
print("large-move penalty =", penalty)
print("small-move penalty =", penalty2)
check_true("Penalty increases with movement away from old parameters", penalty > penalty2)

Code cell 54

header("Takeaway summary")
print("1. Fisher information is local KL curvature.")
print("2. It accumulates additively under independence.")
print("3. It transforms as a metric under reparameterization.")
print("4. The true Fisher, empirical Fisher, and Hessian are related but distinct.")
print("5. Modern ML uses Fisher geometry in natural gradient, K-FAC, EWC, and score-based modeling.")
check_true("Notebook reached the final recap cell", True)
