KL Divergence — Theory Notebook

"The most important single quantity in information theory and in machine learning is the Kullback-Leibler divergence." — David MacKay

Interactive exploration of KL divergence: from first principles through VAEs, RLHF, and knowledge distillation. Run cells top-to-bottom.

Companion: notes.md | exercises.ipynb

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import scipy.stats as stats
from scipy.special import rel_entr

try:
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    try:
        import seaborn as sns
        sns.set_theme(style='whitegrid', palette='colorblind')
        HAS_SNS = True
    except ImportError:
        plt.style.use('seaborn-v0_8-whitegrid')
        HAS_SNS = False
    mpl.rcParams.update({
        'figure.figsize': (10, 6), 'figure.dpi': 120,
        'font.size': 13, 'axes.titlesize': 15, 'axes.labelsize': 13,
        'lines.linewidth': 2.0, 'axes.spines.top': False, 'axes.spines.right': False,
    })
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

COLORS = {
    'primary':   '#0077BB',
    'secondary': '#EE7733',
    'tertiary':  '#009988',
    'error':     '#CC3311',
    'neutral':   '#555555',
    'highlight': '#EE3377',
}

np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)
print('Setup complete.')
print(f'Matplotlib: {HAS_MPL}, Seaborn: {HAS_SNS}')

1. Intuition: What Is KL Divergence?

KL divergence $D_{\mathrm{KL}}(p \| q)$ measures the expected excess code length when encoding data from $p$ using a code optimized for $q$. Equivalently, it is the expected log-likelihood ratio $\mathbb{E}_p[\log(p/q)]$.

$$D_{\mathrm{KL}}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}$$

Below we compute KL for a concrete weather-forecast example and verify the coding interpretation.

Code cell 5

# === 1.1 Weather Forecast Example ===

# True distribution (nature) vs forecast distribution (model)
outcomes = ['Sunny', 'Cloudy', 'Rain']
p = np.array([0.50, 0.30, 0.20])  # true
q = np.array([0.70, 0.20, 0.10])  # forecast

# KL divergence: sum of p * log(p/q)
kl_pq = np.sum(p * np.log(p / q))
kl_qp = np.sum(q * np.log(q / p))

print('=== Weather Forecast KL Example ===')
print(f'Outcomes: {outcomes}')
print(f'True  p:  {p}')
print(f'Model q:  {q}')
print()
print(f'D_KL(p || q) = {kl_pq:.4f} nats  [{kl_pq/np.log(2):.4f} bits]')
print(f'D_KL(q || p) = {kl_qp:.4f} nats  [{kl_qp/np.log(2):.4f} bits]')
print(f'Asymmetric: D_KL(p||q) != D_KL(q||p): {not np.isclose(kl_pq, kl_qp)}')

# Coding interpretation: optimal code lengths
H_p  = -np.sum(p * np.log(p))          # entropy of p
H_pq = -np.sum(p * np.log(q))           # cross-entropy H(p,q)
print()
print(f'H(p)   = {H_p:.4f} nats  (optimal code for p)')
print(f'H(p,q) = {H_pq:.4f} nats  (code for p using q)')
print(f'Extra  = {H_pq - H_p:.4f} nats  (= D_KL(p||q): {kl_pq:.4f})')

ok = np.isclose(H_pq - H_p, kl_pq, atol=1e-10)
print(f'\nPASS: H(p,q) - H(p) = D_KL(p||q)' if ok else 'FAIL')

2. Non-Negativity: Gibbs' Inequality

$D_{\mathrm{KL}}(p \| q) \ge 0$, with equality iff $p = q$.

Proof sketch: by Jensen's inequality applied to the concave $\ln$: $\mathbb{E}_p[\ln(q/p)] \le \ln(\mathbb{E}_p[q/p]) = \ln 1 = 0$.

Below we verify numerically over random probability vectors and visualize the inequality $\ln t \le t - 1$ that underpins the proof.
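Spelling out the Jensen step (included for completeness; this is the standard argument):

$$-D_{\mathrm{KL}}(p\|q) = \mathbb{E}_p\!\left[\ln\frac{q}{p}\right] \le \ln \mathbb{E}_p\!\left[\frac{q}{p}\right] = \ln \sum_x p(x)\,\frac{q(x)}{p(x)} = \ln 1 = 0,$$

with equality iff $q/p$ is constant, which under normalization forces $p = q$.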

Code cell 7

# === 2.1 Numerical Verification of Non-Negativity ===

np.random.seed(42)
n_tests = 10000
n_symbols = 5

all_kls = []
for _ in range(n_tests):
    p = np.random.dirichlet(np.ones(n_symbols))
    q = np.random.dirichlet(np.ones(n_symbols))
    kl = np.sum(p * np.log(p / q))
    all_kls.append(kl)

all_kls = np.array(all_kls)
print(f'Tests: {n_tests:,} random pairs of {n_symbols}-symbol distributions')
print(f'Min D_KL: {all_kls.min():.2e}  (should be >= 0)')
print(f'Max D_KL: {all_kls.max():.4f}')
print(f'Mean D_KL: {all_kls.mean():.4f}')
print(f'All non-negative: {(all_kls >= -1e-12).all()}')

# Verify D_KL(p||p) = 0 for several p
zero_kls = []
for _ in range(100):
    p = np.random.dirichlet(np.ones(n_symbols))
    kl = np.sum(p * np.log(p / p))
    zero_kls.append(kl)
print(f'\nD_KL(p||p) for 100 random p: max = {max(zero_kls):.2e} (should be ~0)')
print('PASS - Gibbs inequality verified numerically')

Code cell 8

# === 2.2 Visualization of ln(t) <= t - 1 ===

if HAS_MPL:
    t = np.linspace(0.01, 4, 400)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: ln(t) <= t-1
    ax = axes[0]
    ax.plot(t, np.log(t),   color=COLORS['primary'],   label=r'$\ln t$')
    ax.plot(t, t - 1,       color=COLORS['secondary'], label=r'$t - 1$', linestyle='--')
    ax.axvline(1, color=COLORS['neutral'], linewidth=0.8, linestyle=':')
    ax.axhline(0, color=COLORS['neutral'], linewidth=0.8, linestyle=':')
    ax.fill_between(t, np.log(t), t - 1, alpha=0.15, color=COLORS['error'],
                    label=r'gap $= (t-1) - \ln t \geq 0$')
    ax.set_title(r'$\ln t \leq t - 1$ (equality at $t=1$)')
    ax.set_xlabel(r'$t = q(x)/p(x)$')
    ax.set_ylabel('Value')
    ax.legend()
    ax.set_xlim(0, 4)
    ax.set_ylim(-2, 3)

    # Right: D_KL distribution over random pairs
    ax = axes[1]
    ax.hist(all_kls, bins=60, density=True, color=COLORS['primary'], alpha=0.75,
            edgecolor='white')
    ax.axvline(0, color=COLORS['error'], linewidth=2, linestyle='--', label='KL=0 (p=q)')
    ax.set_title(r'Distribution of $D_{\mathrm{KL}}(p\|q)$ over random pairs')
    ax.set_xlabel(r'$D_{\mathrm{KL}}(p\|q)$ [nats]')
    ax.set_ylabel('Density')
    ax.legend()

    fig.tight_layout()
    plt.show()
    print('Figure: ln(t) <= t-1 proof visualization + KL distribution')

3. Asymmetry

$D_{\mathrm{KL}}(p\|q) \ne D_{\mathrm{KL}}(q\|p)$ in general. This is not a flaw: the two directions answer different questions. Below we sweep $\theta$ for Bernoulli distributions and visualize both directions.

Code cell 10

# === 3.1 Asymmetry: Bernoulli KL vs reverse KL ===

theta = np.linspace(0.01, 0.99, 200)
q0 = 0.5  # fixed reference

# Forward KL: D_KL(Bern(theta) || Bern(0.5))
p_ = np.stack([theta, 1 - theta], axis=1)
q_ = np.array([[q0, 1 - q0]] * len(theta))

kl_forward = np.sum(p_ * np.log(p_ / q_), axis=1)
kl_reverse = np.sum(q_ * np.log(q_ / p_), axis=1)
sym_kl = 0.5 * (kl_forward + kl_reverse)  # symmetrized KL (Jeffreys); NOT the JSD

# Compute true JSD
m_ = 0.5 * p_ + 0.5 * q_
jsd_true = 0.5 * np.sum(p_ * np.log(p_ / m_), axis=1) + \
           0.5 * np.sum(q_ * np.log(q_ / m_), axis=1)

print('Bernoulli KL vs reverse KL (q=Bern(0.5))')
for t0 in [0.1, 0.5, 0.9]:
    i = np.argmin(np.abs(theta - t0))  # nearest grid point to t0
    print(f'theta={theta[i]:.2f}: D_KL(p||q)={kl_forward[i]:.4f}, D_KL(q||p)={kl_reverse[i]:.4f}')

if HAS_MPL:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(theta, kl_forward, color=COLORS['primary'],   label=r'$D_{\mathrm{KL}}(p\|q)$ forward')
    ax.plot(theta, kl_reverse, color=COLORS['secondary'], label=r'$D_{\mathrm{KL}}(q\|p)$ reverse', linestyle='--')
    ax.plot(theta, jsd_true,   color=COLORS['tertiary'],  label='JSD (symmetric)', linestyle=':')
    ax.axhline(np.log(2), color=COLORS['neutral'], linewidth=0.8, linestyle=':', label=r'$\ln 2$ (JSD bound)')
    ax.set_title(r'$D_{\mathrm{KL}}(\mathrm{Bern}(\theta) \| \mathrm{Bern}(0.5))$: asymmetry')
    ax.set_xlabel(r'$\theta$ (parameter of $p$)')
    ax.set_ylabel('Divergence [nats]')
    ax.legend()
    fig.tight_layout()
    plt.show()

4. Forward KL vs Reverse KL

The most important practical distinction: which direction to minimize.

  • Forward KL $D_{\mathrm{KL}}(p\|q)$: expectation under $p$. Zero-avoiding (mass-covering). Mean-seeking.
  • Reverse KL $D_{\mathrm{KL}}(q\|p)$: expectation under $q$. Zero-forcing (mass-concentrating). Mode-seeking.

We fit a Gaussian $q = \mathcal{N}(\mu, \sigma^2)$ to a bimodal $p$ using both directions and visualize the difference.

Code cell 12

# === 4.1 Forward vs Reverse KL on Bimodal Distribution ===

from scipy.optimize import minimize

x = np.linspace(-8, 8, 2000)
dx = x[1] - x[0]

# True bimodal distribution p
p_true = 0.5 * np.exp(-0.5*(x+3)**2) / np.sqrt(2*np.pi) + \
         0.5 * np.exp(-0.5*(x-3)**2) / np.sqrt(2*np.pi)
p_true = p_true / (p_true.sum() * dx)  # normalize

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5*((x-mu)/sigma)**2) / (sigma * np.sqrt(2*np.pi))

def forward_kl(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    q = gaussian_pdf(x, mu, sigma)
    q = np.maximum(q, 1e-300)
    mask = p_true > 1e-300
    return np.sum(p_true[mask] * np.log(p_true[mask] / q[mask])) * dx

def reverse_kl(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    q = gaussian_pdf(x, mu, sigma)
    q = np.maximum(q, 1e-300)
    p = np.maximum(p_true, 1e-300)
    return np.sum(q * np.log(q / p)) * dx

# Minimize forward KL (starting near mean of bimodal = 0)
res_fwd = minimize(forward_kl, [0.0, np.log(3.0)], method='Nelder-Mead')
mu_fwd, sig_fwd = res_fwd.x[0], np.exp(res_fwd.x[1])

# Minimize reverse KL (starting near mode at +3)
res_rev = minimize(reverse_kl, [3.0, np.log(1.0)], method='Nelder-Mead')
mu_rev, sig_rev = res_rev.x[0], np.exp(res_rev.x[1])

print('=== Forward KL minimizer (mass-covering) ===')
print(f'  mu* = {mu_fwd:.3f}, sigma* = {sig_fwd:.3f}')
print(f'  Expected: mu~0 (mean of bimodal), sigma~3 (wide to cover both modes)')
print()
print('=== Reverse KL minimizer (mode-seeking) ===')
print(f'  mu* = {mu_rev:.3f}, sigma* = {sig_rev:.3f}')
print(f'  Expected: mu~+/-3 (one mode), sigma~1 (tight around that mode)')

Code cell 13

# === 4.2 Visualize Forward vs Reverse KL Results ===

if HAS_MPL:
    q_fwd = gaussian_pdf(x, mu_fwd, sig_fwd)
    q_rev = gaussian_pdf(x, mu_rev, sig_rev)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle('Fitting Gaussian to Bimodal: Forward vs Reverse KL', fontsize=15)

    for ax, q_fit, label, color, title in [
        (axes[0], q_fwd, f'q=N({mu_fwd:.1f}, {sig_fwd:.1f}²)', COLORS['secondary'],
         r'Forward KL $D_{\mathrm{KL}}(p\|q)$: mean-seeking'),
        (axes[1], q_rev, f'q=N({mu_rev:.1f}, {sig_rev:.1f}²)', COLORS['error'],
         r'Reverse KL $D_{\mathrm{KL}}(q\|p)$: mode-seeking'),
    ]:
        ax.fill_between(x, p_true, alpha=0.25, color=COLORS['primary'], label='True p (bimodal)')
        ax.plot(x, p_true, color=COLORS['primary'], linewidth=2)
        ax.plot(x, q_fit,  color=color, linewidth=2.5, linestyle='--', label=label)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('Density')
        ax.legend()
        ax.set_xlim(-8, 8)

    fig.tight_layout()
    plt.show()
    print('Key insight: Forward KL averages across modes; Reverse KL collapses to one mode.')

5. KL Between Gaussians — Closed Form

For $p = \mathcal{N}(\mu_1, \sigma_1^2)$ and $q = \mathcal{N}(\mu_2, \sigma_2^2)$:

$$D_{\mathrm{KL}}(p\|q) = \ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

VAE encoder KL: when $q = \mathcal{N}(0,1)$:

$$D_{\mathrm{KL}}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0,1)) = \frac{1}{2}\left(\mu^2 + \sigma^2 - \ln\sigma^2 - 1\right)$$
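The special case follows by substituting $\mu_2 = 0$, $\sigma_2 = 1$ into the general formula and using $\ln\frac{1}{\sigma} = -\tfrac{1}{2}\ln\sigma^2$:

$$D_{\mathrm{KL}}(\mathcal{N}(\mu,\sigma^2) \,\|\, \mathcal{N}(0,1)) = \ln\frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} = \frac{1}{2}\left(\mu^2 + \sigma^2 - \ln\sigma^2 - 1\right).$$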

Code cell 15

# === 5.1 KL Between Gaussians: Formula Verification ===

def kl_gaussians(mu1, sigma1, mu2, sigma2):
    """Closed-form D_KL(N(mu1,sigma1^2) || N(mu2,sigma2^2))"""
    return (np.log(sigma2/sigma1) +
            (sigma1**2 + (mu1-mu2)**2) / (2*sigma2**2) - 0.5)

def kl_to_standard_normal(mu, sigma):
    """D_KL(N(mu,sigma^2) || N(0,1)) = 0.5*(mu^2 + sigma^2 - ln(sigma^2) - 1)"""
    return 0.5 * (mu**2 + sigma**2 - np.log(sigma**2) - 1)

# Test cases
test_cases = [
    (1.0, 1.0, 0.0, 1.0, 'N(1,1) vs N(0,1)'),
    (0.0, 2.0, 0.0, 1.0, 'N(0,2) vs N(0,1)'),
    (2.0, 0.5, 1.0, 1.5, 'N(2,0.25) vs N(1,2.25)'),
    (0.0, 1.0, 0.0, 1.0, 'N(0,1) vs N(0,1) [should be 0]'),
]

print('=== KL Between Gaussians ===')
for mu1, s1, mu2, s2, desc in test_cases:
    kl = kl_gaussians(mu1, s1, mu2, s2)
    print(f"{desc}: D_KL = {kl:.6f}")

print()
print('=== VAE Encoder KL to N(0,1) ===')
vae_cases = [(0.0, 1.0), (1.0, 1.0), (2.0, 0.5), (0.5, 1.5)]
for mu, sigma in vae_cases:
    kl_formula = kl_to_standard_normal(mu, sigma)
    kl_general = kl_gaussians(mu, sigma, 0.0, 1.0)
    match = np.isclose(kl_formula, kl_general, atol=1e-10)
    print(f'mu={mu}, sigma={sigma}: formula={kl_formula:.6f}, general={kl_general:.6f}, match={match}')

Code cell 16

# === 5.2 VAE KL Surface Visualization ===

if HAS_MPL:
    mu_grid = np.linspace(-3, 3, 100)
    sigma_grid = np.linspace(0.1, 3, 100)
    MU, SIGMA = np.meshgrid(mu_grid, sigma_grid)
    KL = kl_to_standard_normal(MU, SIGMA)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Heatmap
    ax = axes[0]
    im = ax.contourf(MU, SIGMA, KL, levels=30, cmap='plasma')
    fig.colorbar(im, ax=ax, label=r'$D_{\mathrm{KL}}$')
    ax.contour(MU, SIGMA, KL, levels=[0.5, 1.0, 2.0], colors='white', linewidths=0.8)
    ax.plot(0, 1, 'w*', markersize=15, label='minimum (mu=0, sigma=1)')
    ax.set_title(r'$D_{\mathrm{KL}}(\mathcal{N}(\mu,\sigma^2)\|\mathcal{N}(0,1))$')
    ax.set_xlabel(r'$\mu$')
    ax.set_ylabel(r'$\sigma$')
    ax.legend()

    # Slices
    ax = axes[1]
    for sigma_val, color in zip([0.5, 1.0, 2.0], [COLORS['error'], COLORS['primary'], COLORS['tertiary']]):
        kl_slice = kl_to_standard_normal(mu_grid, sigma_val)
        ax.plot(mu_grid, kl_slice, color=color, label=fr'$\sigma={sigma_val}$')
    ax.axvline(0, color=COLORS['neutral'], linewidth=0.8, linestyle=':')
    ax.set_title(r'$D_{\mathrm{KL}}$ vs $\mu$ for fixed $\sigma$')
    ax.set_xlabel(r'$\mu$')
    ax.set_ylabel(r'$D_{\mathrm{KL}}$')
    ax.legend()

    fig.tight_layout()
    plt.show()
    print('Minimum at mu=0, sigma=1 (equals prior): KL=0')

6. f-Divergences: KL as a Special Case

KL divergence is one member of the Csiszár f-divergence family:

$$D_f(p\|q) = \sum_x q(x)\, f\!\left(\frac{p(x)}{q(x)}\right)$$

Different choices of $f$ give KL, reverse KL, Hellinger, total variation, chi-squared, and JSD.

Below we compute five of these for Bernoulli distributions and verify Pinsker's inequality.
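As a cross-check of the definition itself, the following minimal sketch (the names `f_generators` and `f_divergence` are our illustrative choices, not from the notebook) evaluates $D_f$ directly from a few standard generators and confirms the forward-KL case against the direct formula:

# Sketch: f-divergences from their generators (illustrative)
import numpy as np

# Each generator f is convex with f(1) = 0; D_f(p||q) = sum_x q(x) * f(p(x)/q(x))
f_generators = {
    'kl':         lambda t: t * np.log(t),        # forward KL
    'reverse_kl': lambda t: -np.log(t),           # reverse KL
    'tv':         lambda t: 0.5 * np.abs(t - 1),  # total variation
    'hellinger2': lambda t: (np.sqrt(t) - 1)**2,  # squared Hellinger
    'chi2':       lambda t: (t - 1)**2,           # Pearson chi-squared
}

def f_divergence(p, q, f):
    return np.sum(q * f(p / q))

p_ex, q_ex = np.array([0.8, 0.2]), np.array([0.5, 0.5])
for name, f in f_generators.items():
    print(f'{name:<11} D_f = {f_divergence(p_ex, q_ex, f):.6f}')

# The generator form of forward KL must match sum(p * log(p/q))
assert np.isclose(f_divergence(p_ex, q_ex, f_generators['kl']),
                  np.sum(p_ex * np.log(p_ex / q_ex)))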

Code cell 18

# === 6.1 f-Divergence Family for Bernoulli Distributions ===

theta = np.linspace(0.001, 0.999, 500)
q0 = 0.5

p_bern = np.stack([theta, 1 - theta], axis=1)
q_bern = np.array([[q0, 1 - q0]] * len(theta))

def safe_kl(p, q):
    mask = (p > 0) & (q > 0)
    return np.sum(np.where(mask, p * np.log(p / q), 0), axis=1)

kl_fwd = safe_kl(p_bern, q_bern)
kl_rev = safe_kl(q_bern, p_bern)

# Hellinger^2: sum(sqrt(p) - sqrt(q))^2
hell2 = np.sum((np.sqrt(p_bern) - np.sqrt(q_bern))**2, axis=1)

# Total variation: 0.5 * sum|p - q|
tv = 0.5 * np.sum(np.abs(p_bern - q_bern), axis=1)

# JSD
m_bern = 0.5 * p_bern + 0.5 * q_bern
jsd = 0.5 * safe_kl(p_bern, m_bern) + 0.5 * safe_kl(q_bern, m_bern)

# Pinsker's inequality: TV^2 <= 0.5 * D_KL(p||q)
pinsker_holds = (tv**2 <= 0.5 * kl_fwd + 1e-12).all()
print(f'Pinsker: TV^2 <= 0.5*D_KL(p||q) holds for all theta: {pinsker_holds}')
print(f'JSD bounded by ln(2)={np.log(2):.4f}: {(jsd <= np.log(2) + 1e-12).all()}')
print(f'TV in [0,1]: {(tv >= -1e-12).all() and (tv <= 1 + 1e-12).all()}')

Code cell 19

# === 6.2 Visualize f-Divergence Family ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    ax = axes[0]
    ax.plot(theta, kl_fwd, color=COLORS['primary'],   label=r'$D_{\mathrm{KL}}(p\|q)$ forward')
    ax.plot(theta, kl_rev, color=COLORS['secondary'], label=r'$D_{\mathrm{KL}}(q\|p)$ reverse', linestyle='--')
    ax.plot(theta, jsd,    color=COLORS['tertiary'],  label='JSD', linestyle=':')
    ax.plot(theta, hell2,  color=COLORS['highlight'], label=r'Hellinger$^2$', linestyle='-.')
    ax.plot(theta, tv,     color=COLORS['neutral'],   label='Total Variation', linestyle=(0,(3,1,1,1)))
    ax.axvline(0.5, color='gray', linewidth=0.8, linestyle=':')
    ax.set_title(r'f-Divergence Family: $\mathrm{Bern}(\theta)$ vs $\mathrm{Bern}(0.5)$')
    ax.set_xlabel(r'$\theta$')
    ax.set_ylabel('Divergence [nats]')
    ax.legend(fontsize=10)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 2.5)

    # Pinsker's inequality visualization
    ax = axes[1]
    ax.plot(theta, tv**2,         color=COLORS['error'],   label=r'$\mathrm{TV}^2$')
    ax.plot(theta, 0.5 * kl_fwd,  color=COLORS['primary'], label=r'$\frac{1}{2}D_{\mathrm{KL}}(p\|q)$', linestyle='--')
    ax.fill_between(theta, tv**2, 0.5*kl_fwd, alpha=0.15, color=COLORS['primary'],
                    label='Gap (Pinsker bound)')
    ax.set_title("Pinsker's Inequality: $\\mathrm{TV}^2 \\leq \\frac{1}{2}D_{\\mathrm{KL}}$")
    ax.set_xlabel(r'$\theta$')
    ax.set_ylabel('Value')
    ax.legend()

    fig.tight_layout()
    plt.show()

7. Chain Rule and Data Processing Inequality

Chain rule: $D_{\mathrm{KL}}(P(X,Y)\|Q(X,Y)) = D_{\mathrm{KL}}(P_X\|Q_X) + \mathbb{E}_{P_X}\big[D_{\mathrm{KL}}(P_{Y|X}\|Q_{Y|X})\big]$

Data processing: $D_{\mathrm{KL}}(p_T\|q_T) \le D_{\mathrm{KL}}(p\|q)$ for any stochastic map $T$.

Code cell 21

# === 7.1 Chain Rule Verification ===

# Joint distributions on {0,1}x{0,1}
# P: correlated (X=Y more likely)
P = np.array([[0.40, 0.10],   # P(X=0,Y=0), P(X=0,Y=1)
              [0.10, 0.40]])  # P(X=1,Y=0), P(X=1,Y=1)

# Q: uniform
Q = np.ones((2,2)) * 0.25

def kl_joint(P, Q):
    """D_KL(P || Q) for joint distributions"""
    mask = P > 0
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

# Direct computation
kl_direct = kl_joint(P, Q)

# Marginals
P_X = P.sum(axis=1)  # [P(X=0), P(X=1)]
Q_X = Q.sum(axis=1)
kl_marginal = np.sum(P_X * np.log(P_X / Q_X))

# Conditional KL: E_{P_X}[D_KL(P_{Y|X=x} || Q_{Y|X=x})]
kl_conditional = 0.0
for x in range(2):
    P_yx = P[x, :] / P_X[x]  # P(Y|X=x)
    Q_yx = Q[x, :] / Q_X[x]  # Q(Y|X=x)
    kl_yx = np.sum(P_yx * np.log(P_yx / Q_yx))
    kl_conditional += P_X[x] * kl_yx

kl_chain = kl_marginal + kl_conditional

print('=== Chain Rule for KL Divergence ===')
print(f'D_KL(P(X,Y) || Q(X,Y)) directly: {kl_direct:.6f}')
print(f'D_KL(P_X || Q_X):                {kl_marginal:.6f}')
print(f'E[D_KL(P_Y|X || Q_Y|X)]:         {kl_conditional:.6f}')
print(f'Sum (chain rule):                 {kl_chain:.6f}')
ok = np.isclose(kl_direct, kl_chain, atol=1e-10)
print(f'\nPASS - Chain rule verified: {ok}' if ok else 'FAIL')

Code cell 22

# === 7.2 Data Processing Inequality ===

# Original distributions over {0,1,2}
p_orig = np.array([0.5, 0.3, 0.2])
q_orig = np.array([0.2, 0.3, 0.5])

kl_orig = np.sum(p_orig * np.log(p_orig / q_orig))

# Stochastic map T: {0,1,2} -> {A, B}
# T[i,j] = P(output=j | input=i)
T = np.array([[0.8, 0.2],  # x=0: 80% A, 20% B
              [0.5, 0.5],  # x=1: 50% A, 50% B
              [0.1, 0.9]]) # x=2: 10% A, 90% B

p_T = p_orig @ T  # induced distribution on {A,B} under p
q_T = q_orig @ T  # induced distribution on {A,B} under q

kl_after = np.sum(p_T * np.log(p_T / q_T))

print('=== Data Processing Inequality ===')
print(f'p_orig: {p_orig}')
print(f'q_orig: {q_orig}')
print(f'D_KL(p||q) BEFORE processing: {kl_orig:.6f} nats')
print()
print(f'After stochastic map T:')
print(f'p_T: {p_T}')
print(f'q_T: {q_T}')
print(f'D_KL(p_T||q_T) AFTER processing: {kl_after:.6f} nats')
print(f'Reduction: {kl_orig - kl_after:.6f} nats ({100*(kl_orig-kl_after)/kl_orig:.1f}% lost)')
dpi_holds = kl_after <= kl_orig + 1e-12
print(f'\nPASS - DPI holds: {kl_after:.6f} <= {kl_orig:.6f}' if dpi_holds else 'FAIL')

8. Applications: MLE = Minimizing KL

Maximum likelihood estimation is equivalent to minimizing $D_{\mathrm{KL}}(p_{\mathrm{data}}\|p_{\boldsymbol{\theta}})$:

$$\arg\max_{\boldsymbol{\theta}} \sum_i \log p_{\boldsymbol{\theta}}(x^{(i)}) = \arg\min_{\boldsymbol{\theta}} D_{\mathrm{KL}}(\hat{p}_n \| p_{\boldsymbol{\theta}})$$

We demonstrate this by fitting a Gaussian to data and checking that the KL from a smoothed empirical distribution is minimized at the MLE.
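One line of algebra makes the equivalence explicit. With $\hat p_n$ the empirical distribution placing mass $1/n$ on each sample,

$$D_{\mathrm{KL}}(\hat p_n \| p_{\boldsymbol{\theta}}) = -H(\hat p_n) - \frac{1}{n}\sum_{i=1}^n \log p_{\boldsymbol{\theta}}(x^{(i)}),$$

and since $H(\hat p_n)$ does not depend on $\boldsymbol{\theta}$, minimizing the KL is exactly maximizing the average log-likelihood.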

Code cell 24

# === 8.1 MLE = Minimizing KL: Gaussian Fitting ===

from scipy.optimize import minimize

np.random.seed(42)
# True distribution: mixture of Gaussians
n = 1000
data = np.concatenate([
    np.random.normal(-1, 0.8, n//2),
    np.random.normal(2, 0.5, n//2)
])

# MLE for Gaussian: closed form is sample mean/variance
mu_mle = data.mean()
sigma_mle = data.std()

print(f'Data: n={n}, true means=[-1, 2], true stds=[0.8, 0.5]')
print(f'MLE estimates: mu={mu_mle:.4f}, sigma={sigma_mle:.4f}')

# Compute KL: D_KL(empirical || N(mu, sigma^2))
x_grid = np.linspace(-4, 5, 1000)
dx = x_grid[1] - x_grid[0]
kde = sum(np.exp(-0.5*((x_grid - xi)/0.3)**2) / (0.3*np.sqrt(2*np.pi)) for xi in data) / n
kde = np.maximum(kde, 1e-300)

def kl_from_empirical(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    model = np.exp(-0.5*((x_grid-mu)/sigma)**2) / (sigma*np.sqrt(2*np.pi))
    model = np.maximum(model, 1e-300)
    return np.sum(kde * np.log(kde/model)) * dx

# Verify: MLE minimizes KL
kl_at_mle = kl_from_empirical([mu_mle, np.log(sigma_mle)])
kl_perturbed = kl_from_empirical([mu_mle + 0.5, np.log(sigma_mle)])
print(f'\nKL at MLE: {kl_at_mle:.4f}')
print(f'KL perturbed (mu+0.5): {kl_perturbed:.4f}')
print(f'MLE is better: {kl_at_mle < kl_perturbed}')
print('PASS - MLE minimizes KL from empirical distribution')

9. RLHF: Optimal Policy from KL Constraint

The RLHF objective $\max_\pi \mathbb{E}[r] - \beta D_{\mathrm{KL}}(\pi\|\pi_{\mathrm{ref}})$ has the closed-form optimal solution:

$$\pi^*(y|x) = \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y|x)\, e^{r(x,y)/\beta}$$

We verify this and explore how $\beta$ controls the trade-off.
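The closed form follows from a standard Lagrangian argument (per prompt $x$, with a multiplier $\lambda$ for normalization): stationarity of

$$\sum_y \pi(y)\,r(y) - \beta \sum_y \pi(y)\ln\frac{\pi(y)}{\pi_{\mathrm{ref}}(y)} + \lambda\Big(\sum_y \pi(y) - 1\Big)$$

in $\pi(y)$ gives $r(y) - \beta\ln\frac{\pi(y)}{\pi_{\mathrm{ref}}(y)} - \beta + \lambda = 0$, i.e. $\pi(y) \propto \pi_{\mathrm{ref}}(y)\,e^{r(y)/\beta}$; normalizing yields $Z(x)$.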

Code cell 26

# === 9.1 RLHF Optimal Policy Computation ===

# Toy example: 5 response candidates
np.random.seed(0)
n_responses = 5
responses = [f'Response_{i}' for i in range(n_responses)]

# Reference policy (e.g., base LLM)
pi_ref = np.array([0.35, 0.25, 0.20, 0.12, 0.08])
assert np.isclose(pi_ref.sum(), 1.0)

# Reward function (human preference scores)
rewards = np.array([1.0, 2.5, -0.5, 3.0, 0.5])

print('=== RLHF Optimal Policy ===')
print(f'Reference policy: {pi_ref}')
print(f'Rewards:          {rewards}')
print()

for beta in [0.1, 0.5, 1.0, 5.0]:
    # Optimal policy: pi* proportional to pi_ref * exp(r/beta)
    unnorm = pi_ref * np.exp(rewards / beta)
    Z = unnorm.sum()
    pi_star = unnorm / Z

    kl = np.sum(pi_star * np.log(pi_star / pi_ref))
    E_r = np.sum(pi_star * rewards)

    print(f'beta={beta:.1f}: pi*={np.round(pi_star,3)}, E[r]={E_r:.3f}, '
          f'D_KL={kl:.3f}')

Code cell 27

# === 9.2 Sweep Beta: Reward vs KL Trade-off ===

betas = np.logspace(-2, 2, 100)
E_rewards = []
kl_values = []

for beta in betas:
    unnorm = pi_ref * np.exp(rewards / beta)
    pi_star = unnorm / unnorm.sum()
    E_rewards.append(np.sum(pi_star * rewards))
    kl_values.append(np.sum(pi_star * np.log(pi_star / pi_ref)))

E_rewards = np.array(E_rewards)
kl_values = np.array(kl_values)

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.semilogx(betas, E_rewards, color=COLORS['primary'])
    ax.axhline(np.max(rewards), color=COLORS['neutral'], linestyle=':', label='max reward')
    ax.axhline(np.sum(pi_ref*rewards), color=COLORS['secondary'], linestyle='--', label='ref policy reward')
    ax.set_title(r'Expected Reward vs $\beta$')
    ax.set_xlabel(r'$\beta$ (KL coefficient)')
    ax.set_ylabel(r'$\mathbb{E}_{\pi^*}[r]$')
    ax.legend()

    ax = axes[1]
    ax.plot(kl_values, E_rewards, color=COLORS['highlight'])
    ax.scatter([0], [np.sum(pi_ref*rewards)], color=COLORS['secondary'], s=100, zorder=5,
               label='reference policy (KL=0)')
    ax.set_title(r'Reward vs KL: Pareto Frontier')
    ax.set_xlabel(r'$D_{\mathrm{KL}}(\pi^*\|\pi_{\mathrm{ref}})$ [nats]')
    ax.set_ylabel(r'$\mathbb{E}[r]$')
    ax.legend()

    fig.tight_layout()
    plt.show()
    print('As beta decreases: higher reward but larger KL (more deviation from reference)')

10. Knowledge Distillation: Forward vs Reverse KL

Distillation trains a student $p_S$ to match a teacher $p_T$ using the forward KL $D_{\mathrm{KL}}(p_T\|p_S)$: the student must cover the teacher's full distribution. A temperature $\tau > 1$ softens both distributions, revealing 'dark knowledge'.
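In practice the distillation objective is typically a weighted sum of hard-label cross-entropy and the temperature-softened forward KL, with the soft term scaled by $\tau^2$ so its gradient magnitude stays comparable as $\tau$ grows (the convention from Hinton et al.'s distillation paper). A minimal NumPy sketch; the weight `lam` and the helper names are our illustrative choices:

# Sketch: combined distillation loss (illustrative)
import numpy as np

def softmax(z, tau=1.0):
    e = np.exp((z - z.max()) / tau)
    return e / e.sum()

def distillation_loss(z_student, z_teacher, y_true, tau=3.0, lam=0.7):
    """(1 - lam) * CE(hard labels) + lam * tau^2 * KL(teacher_soft || student_soft)"""
    p_T = softmax(z_teacher, tau)   # softened teacher targets
    p_S = softmax(z_student, tau)   # softened student predictions
    kl_soft = np.sum(p_T * np.log(p_T / p_S))
    ce_hard = -np.log(softmax(z_student)[y_true])  # standard CE at tau = 1
    return (1 - lam) * ce_hard + lam * tau**2 * kl_soft

z_t = np.array([3.0, 1.5, 0.5, -0.5, -1.5])
z_s = np.array([2.0, 1.0, 0.5, -0.3, -1.2])
print(f'combined loss = {distillation_loss(z_s, z_t, y_true=0):.4f}')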

Code cell 29

# === 10.1 Knowledge Distillation: Temperature Scaling ===

# Teacher and student logits for 5 classes
z_teacher = np.array([3.0, 1.5, 0.5, -0.5, -1.5])
z_student  = np.array([2.0, 1.0, 0.5, -0.3, -1.2])

def softmax(z, tau=1.0):
    z_shifted = (z - z.max()) / tau
    exp_z = np.exp(z_shifted)
    return exp_z / exp_z.sum()

def kl_categorical(p, q, eps=1e-10):
    p, q = p + eps, q + eps
    p, q = p/p.sum(), q/q.sum()
    return np.sum(p * np.log(p/q))

print('=== Knowledge Distillation Temperature Analysis ===')
print(f'Teacher logits: {z_teacher}')
print(f'Student logits: {z_student}')
print()
print(f"{"tau":<6} {"H(p_T)":<10} {"KL(T->S)":<12} {"KL(S->T)":<12} {"Asym?"}")
print('-' * 55)

for tau in [1.0, 2.0, 3.0, 5.0, 10.0]:
    p_T = softmax(z_teacher, tau)
    p_S = softmax(z_student,  tau)
    H_T  = -np.sum(p_T * np.log(p_T))
    kl_ts = kl_categorical(p_T, p_S)
    kl_st = kl_categorical(p_S, p_T)
    print(f"{tau:<6.1f} {H_T:<10.4f} {kl_ts:<12.4f} {kl_st:<12.4f} {not np.isclose(kl_ts, kl_st)}")

Code cell 30

# === 10.2 Dark Knowledge Visualization ===

classes = ['Cat', 'Tiger', 'Dog', 'Bird', 'Fish']
z_teacher = np.array([3.0, 1.5, 0.5, -0.5, -1.5])  # cat looks like tiger

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    taus = [1.0, 2.0, 4.0]
    ax = axes[0]
    x_pos = np.arange(len(classes))
    width = 0.25
    for i, tau in enumerate(taus):
        p_T = softmax(z_teacher, tau)
        ax.bar(x_pos + i*width, p_T, width=width,
               color=[COLORS['primary'], COLORS['secondary'], COLORS['tertiary']][i],
               alpha=0.85, label=fr'$\tau={tau}$')
    ax.set_xticks(x_pos + width)
    ax.set_xticklabels(classes)
    ax.set_title('Teacher Soft Labels at Different Temperatures')
    ax.set_ylabel('Probability')
    ax.legend()

    # KL vs temperature
    ax = axes[1]
    tau_range = np.linspace(0.5, 10, 100)
    kl_fwd_tau = [kl_categorical(softmax(z_teacher, t), softmax(z_student, t)) for t in tau_range]
    kl_rev_tau = [kl_categorical(softmax(z_student, t), softmax(z_teacher, t)) for t in tau_range]
    ax.plot(tau_range, kl_fwd_tau, color=COLORS['primary'], label=r'$D_{\mathrm{KL}}(p_T\|p_S)$ (distillation)')
    ax.plot(tau_range, kl_rev_tau, color=COLORS['secondary'], label=r'$D_{\mathrm{KL}}(p_S\|p_T)$ (reverse)', linestyle='--')
    ax.set_title(r'Distillation KL vs Temperature $\tau$')
    ax.set_xlabel(r'Temperature $\tau$')
    ax.set_ylabel(r'$D_{\mathrm{KL}}$ [nats]')
    ax.legend()

    fig.tight_layout()
    plt.show()
    print('Higher temperature: more entropy in teacher, larger KL difference between forward/reverse')

11. Variational Autoencoders: ELBO Decomposition

The VAE ELBO is:

$$\mathcal{L}(\boldsymbol{\phi},\boldsymbol{\theta};\mathbf{x}) = \mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x}|\mathbf{z})] - D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))$$

We simulate VAE training on 1D data and track the KL and reconstruction terms.
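The reason $\mathcal{L}$ is a lower bound on the evidence is the exact identity

$$\log p_\theta(\mathbf{x}) = \mathcal{L}(\boldsymbol{\phi},\boldsymbol{\theta};\mathbf{x}) + D_{\mathrm{KL}}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x})\big) \ge \mathcal{L}(\boldsymbol{\phi},\boldsymbol{\theta};\mathbf{x}),$$

so maximizing the ELBO in $\boldsymbol{\phi}$ also tightens the approximation to the true posterior.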

Code cell 32

# === 11.1 VAE ELBO Simulation (1D) ===

np.random.seed(42)
n_data = 200
data = np.random.normal(2.0, 0.8, n_data)  # true data: N(2, 0.64)

def vae_kl_term(mu, log_var):
    """D_KL(N(mu, exp(log_var)) || N(0,1)) per sample"""
    sigma2 = np.exp(log_var)
    return 0.5 * (mu**2 + sigma2 - log_var - 1)

def vae_reconstruction(x, mu_z, log_var_z, decoder_sigma=0.5):
    """
    E_{q(z|x)}[log p(x|z)] for Gaussian decoder p(x|z) = N(z, decoder_sigma^2)
    """
    # Reparameterize
    n_samples = 50
    eps = np.random.randn(n_samples)
    z_samples = mu_z + np.exp(0.5*log_var_z) * eps
    # E[log N(x | z, sigma^2)] = -0.5*log(2*pi*sigma^2) - (x-z)^2/(2*sigma^2)
    recon = -0.5*np.log(2*np.pi*decoder_sigma**2) - \
            np.mean((x - z_samples)**2) / (2*decoder_sigma**2)
    return recon

# Simulate 'training': encoder mu_phi(x) = alpha*x, log_var_phi = const
# Sweep alpha (encoder weight) and see ELBO components
alphas = np.linspace(0, 1.5, 50)
x_test = 2.0  # test data point
log_var_fixed = np.log(0.5)  # fixed encoder variance

elbos, kl_terms, recon_terms = [], [], []
for alpha in alphas:
    mu_enc = alpha * x_test
    kl = vae_kl_term(mu_enc, log_var_fixed)
    recon = vae_reconstruction(x_test, mu_enc, log_var_fixed)
    elbos.append(recon - kl)
    kl_terms.append(kl)
    recon_terms.append(recon)

best_alpha = alphas[np.argmax(elbos)]
print(f'Data point: x={x_test}, encoder: mu_phi(x) = alpha*x')
print(f'Best alpha (max ELBO): {best_alpha:.3f}')
print(f'At alpha={best_alpha:.2f}:')
print(f'  KL term: {kl_terms[np.argmax(elbos)]:.4f}')
print(f'  Recon:   {recon_terms[np.argmax(elbos)]:.4f}')
print(f'  ELBO:    {max(elbos):.4f}')

Code cell 33

# === 11.2 ELBO Components Visualization ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.plot(alphas, elbos,       color=COLORS['primary'],   label='ELBO = Recon - KL')
    ax.plot(alphas, recon_terms,  color=COLORS['tertiary'],  label='Reconstruction', linestyle='--')
    ax.plot(alphas, [-k for k in kl_terms], color=COLORS['error'], label='-KL term', linestyle=':')
    ax.axvline(best_alpha, color=COLORS['neutral'], linewidth=1, linestyle=':')
    ax.set_title(r'ELBO Components vs Encoder Weight $\alpha$')
    ax.set_xlabel(r'$\alpha$ (encoder $\mu_\phi(x) = \alpha x$)')
    ax.set_ylabel('Nats')
    ax.legend()

    # KL surface over (mu, sigma)
    ax = axes[1]
    mu_r = np.linspace(-3, 3, 80)
    sig_r = np.linspace(0.1, 3, 80)
    MU, SIG = np.meshgrid(mu_r, sig_r)
    KL_grid = 0.5 * (MU**2 + SIG**2 - 2*np.log(SIG) - 1)
    im = ax.contourf(MU, SIG, KL_grid, levels=25, cmap='plasma')
    fig.colorbar(im, ax=ax, label=r'$D_{\mathrm{KL}}$')
    ax.plot(0, 1, 'w*', markersize=15, label='minimum (0, 1)')
    ax.set_title(r'VAE KL: $D_{\mathrm{KL}}(\mathcal{N}(\mu,\sigma^2)\|\mathcal{N}(0,1))$')
    ax.set_xlabel(r'$\mu$'); ax.set_ylabel(r'$\sigma$')
    ax.legend()

    fig.tight_layout()
    plt.show()

12. Rényi Divergence

The order-α\alpha Rényi divergence:

$$D_\alpha(p\|q) = \frac{1}{\alpha-1}\log\sum_x p(x)^\alpha\, q(x)^{1-\alpha}$$

As $\alpha \to 1$ it reduces to the KL divergence (L'Hôpital's rule). It is used in differential privacy, where Rényi DP composes additively.
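The limit is a one-line L'Hôpital computation: with $F(\alpha) = \sum_x p(x)^\alpha q(x)^{1-\alpha}$ we have $F(1) = 1$, so both $\log F(\alpha)$ and $\alpha - 1$ vanish at $\alpha = 1$, and

$$\lim_{\alpha\to 1} D_\alpha(p\|q) = \frac{F'(1)}{F(1)} = \sum_x p(x)\log\frac{p(x)}{q(x)} = D_{\mathrm{KL}}(p\|q).$$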

Code cell 35

# === 12.1 Rényi Divergence vs KL Divergence ===

def renyi_divergence(p, q, alpha, eps=1e-300):
    """D_alpha(p || q) for discrete distributions"""
    if np.isclose(alpha, 1.0):
        # Limit: KL divergence
        mask = p > eps
        return np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps)))
    p, q = np.maximum(p, eps), np.maximum(q, eps)
    return np.log(np.sum(p**alpha * q**(1-alpha))) / (alpha - 1)

# Test on Bernoulli(0.8) vs Bernoulli(0.5)
p = np.array([0.8, 0.2])
q = np.array([0.5, 0.5])

alphas = np.array([0.1, 0.5, 0.9, 0.99, 1.0, 1.01, 1.5, 2.0, 5.0])
kl_value = renyi_divergence(p, q, 1.0)

print(f'p = Bern({p[0]}), q = Bern({q[0]})')
print(f'KL divergence (alpha=1): {kl_value:.6f}')
print()
print(f"{"alpha":<8} {"D_alpha(p||q)":<15} {"vs KL"}")
print('-' * 35)
for a in alphas:
    d = renyi_divergence(p, q, a)
    diff = d - kl_value
    print(f"{a:<8.2f} {d:<15.6f} {diff:+.6f}")

Code cell 36

# === 12.2 Rényi Divergence: Family Visualization ===

if HAS_MPL:
    theta = np.linspace(0.02, 0.98, 200)
    q0 = 0.3
    p_arr = np.stack([theta, 1-theta], axis=1)
    q_arr = np.array([[q0, 1-q0]] * len(theta))

    alpha_vals = [0.5, 0.9, 1.0, 1.5, 2.0]
    colors_r = [COLORS['error'], COLORS['secondary'], COLORS['primary'],
                COLORS['tertiary'], COLORS['highlight']]

    fig, ax = plt.subplots(figsize=(10, 6))
    for a, c in zip(alpha_vals, colors_r):
        d_alpha = [renyi_divergence(p_arr[i], q_arr[i], a) for i in range(len(theta))]
        lbl = fr'$D_{{\alpha={a}}}$' + (' (= KL)' if a == 1.0 else '')
        ls = '-' if a == 1.0 else '--'
        ax.plot(theta, d_alpha, color=c, label=lbl, linestyle=ls)

    ax.set_title(fr'Rényi Divergence Family: $\mathrm{{Bern}}(\theta)$ vs $\mathrm{{Bern}}({q0})$')
    ax.set_xlabel(r'$\theta$')
    ax.set_ylabel(r'$D_\alpha(p\|q)$ [nats]')
    ax.legend(fontsize=11)
    ax.set_ylim(0, 3)
    fig.tight_layout()
    plt.show()
    print('D_alpha is monotone increasing in alpha; all equal KL at alpha=1')

13. Information Geometry: I-Projection and Pythagorean Theorem

KL divergence is a Bregman divergence generated by the convex (negative entropy) function $\phi(p) = \sum_x p(x)\ln p(x)$.

The Pythagorean theorem for KL: if $q^*$ is the I-projection of $p$ onto a linear constraint set, then for any $r$ in that set,

$$D_{\mathrm{KL}}(r \| p) = D_{\mathrm{KL}}(r \| q^*) + D_{\mathrm{KL}}(q^* \| p)$$

Code cell 38

# === 13.1 I-Projection onto a Constraint Set ===

from scipy.optimize import brentq

# Project the uniform distribution onto the set {q : E_q[X] = 2.5}
# on support {0, 1, 2, 3, 4}. (Uniform already has mean 2.0, so a
# target mean of 2.5 makes the projection non-trivial.)
support = np.array([0, 1, 2, 3, 4], dtype=float)
target_mean = 2.5

# Uniform distribution (the 'p' we project from; I-projection minimizes D_KL(q || p))
p_target = np.ones(5) / 5

# I-projection: q* = argmin_{E_q[X]=target_mean} D_KL(q || p)
# Solution: q* = p * exp(lambda * X) / Z (exponential tilt)
# Find lambda via the mean constraint
def constraint_violation(lam):
    unnorm = p_target * np.exp(lam * support)
    q = unnorm / unnorm.sum()
    return np.sum(q * support) - target_mean

lam_star = brentq(constraint_violation, -5, 5)
unnorm_star = p_target * np.exp(lam_star * support)
q_star = unnorm_star / unnorm_star.sum()

print(f'I-projection onto E[X]={target_mean} constraint:')
print(f'Lambda*: {lam_star:.4f}')
print(f'q*:      {q_star.round(4)}')
print(f'E_q*[X]: {np.sum(q_star * support):.4f}  (should be {target_mean})')
print(f'D_KL(q* || p): {np.sum(q_star * np.log(q_star / p_target)):.4f}')

# Verify the Pythagorean theorem with another r satisfying the constraint:
# for any r with E_r[X] = target_mean, D_KL(r||p) = D_KL(r||q*) + D_KL(q*||p)
r = np.array([0.05, 0.15, 0.25, 0.35, 0.20])  # E_r[X] = 2.5 exactly
kl_r_p = np.sum(r * np.log(r / p_target))
kl_r_q = np.sum(r * np.log(r / q_star))
kl_q_p = np.sum(q_star * np.log(q_star / p_target))
print(f'\nE_r[X] = {np.sum(r * support):.4f}')
print(f'D_KL(r||p)                = {kl_r_p:.6f}')
print(f'D_KL(r||q*) + D_KL(q*||p) = {kl_r_q + kl_q_p:.6f}')
print(f'PASS - Pythagorean identity holds: {np.isclose(kl_r_p, kl_r_q + kl_q_p)}')

Code cell 39

# === 13.2 Bregman / KL as Bregman Divergence Verification ===

# KL(p || q) = B_phi(p, q) where phi(p) = sum p*log(p)
# B_phi(p, q) = phi(p) - phi(q) - grad_phi(q)^T (p - q)
# grad_phi(q)_x = log(q_x) + 1

def negative_entropy(p):
    return np.sum(p * np.log(p))

def bregman_neg_entropy(p, q):
    """B_phi(p, q) with phi = negative entropy"""
    grad_q = np.log(q) + 1  # gradient of phi at q
    return negative_entropy(p) - negative_entropy(q) - np.dot(grad_q, p - q)

np.random.seed(7)
print('Verifying KL = Bregman divergence of negative entropy:')
print(f"{"p":<30} {"q":<30} {"D_KL":<10} {"Bregman":<10} {"Match"}")
for _ in range(5):
    p = np.random.dirichlet([1,1,1,1])
    q = np.random.dirichlet([1,1,1,1])
    kl  = np.sum(p * np.log(p/q))
    breg = bregman_neg_entropy(p, q)
    match = np.isclose(kl, breg, atol=1e-10)
    print(f"{str(p.round(3)):<30} {str(q.round(3)):<30} {kl:<10.5f} {breg:<10.5f} {match}")

14. Summary Verification Suite

A comprehensive numerical check of all major results covered in this notebook.

Code cell 41

# === 14. Summary Verification Suite ===

import numpy as np

np.random.seed(42)
print('=' * 60)
print('SUMMARY VERIFICATION: KL DIVERGENCE')
print('=' * 60)

results = []

# 1. Non-negativity
p = np.random.dirichlet([2,3,1,2])
q = np.random.dirichlet([1,2,3,1])
kl = np.sum(p * np.log(p/q))
ok1 = kl >= -1e-12
results.append(ok1)
print(f"{'PASS' if ok1 else 'FAIL'} 1. Non-negativity: D_KL = {kl:.4f} >= 0")

# 2. D_KL(p||p) = 0
kl_self = np.sum(p * np.log(p/p))
ok2 = np.isclose(kl_self, 0, atol=1e-12)
results.append(ok2)
print(f"{'PASS' if ok2 else 'FAIL'} 2. D_KL(p||p) = {kl_self:.2e}  (should be 0)")

# 3. H(p,q) = H(p) + D_KL(p||q)
H_p = -np.sum(p * np.log(p))
H_pq = -np.sum(p * np.log(q))
ok3 = np.isclose(H_pq, H_p + kl, atol=1e-10)
results.append(ok3)
print(f"{'PASS' if ok3 else 'FAIL'} 3. H(p,q) = H(p) + KL: {H_pq:.4f} = {H_p:.4f} + {kl:.4f}")

# 4. Data processing inequality
T = np.random.dirichlet([1,1], size=4)  # stochastic kernel 4->2
p4 = np.random.dirichlet([1,1,1,1])
q4 = np.random.dirichlet([1,1,1,1])
kl_orig4 = np.sum(p4 * np.log(p4/q4))
p_T4 = p4 @ T
q_T4 = q4 @ T
kl_T4 = np.sum(p_T4 * np.log(p_T4/q_T4))
ok4 = kl_T4 <= kl_orig4 + 1e-10
results.append(ok4)
print(f"{'PASS' if ok4 else 'FAIL'} 4. DPI: {kl_T4:.4f} <= {kl_orig4:.4f}")

# 5. Gaussian KL formula
mu1, s1, mu2, s2 = 1.0, 1.5, 0.0, 1.0
kl_gauss_formula = np.log(s2/s1) + (s1**2 + (mu1-mu2)**2)/(2*s2**2) - 0.5
# Numerical: integrate
x = np.linspace(-10, 10, 10000)
dx = x[1]-x[0]
p_pdf = np.exp(-0.5*((x-mu1)/s1)**2)/(s1*np.sqrt(2*np.pi))
q_pdf = np.exp(-0.5*((x-mu2)/s2)**2)/(s2*np.sqrt(2*np.pi))
kl_gauss_num = np.sum(p_pdf * np.log(p_pdf/q_pdf) * dx)
ok5 = np.isclose(kl_gauss_formula, kl_gauss_num, atol=1e-4)
results.append(ok5)
print(f"{'PASS' if ok5 else 'FAIL'} 5. Gaussian KL: formula={kl_gauss_formula:.4f}, numerical={kl_gauss_num:.4f}")

# 6. Pinsker's inequality
tv = 0.5 * np.sum(np.abs(p - q))
ok6 = tv**2 <= 0.5 * kl + 1e-10
results.append(ok6)
print(f"{'PASS' if ok6 else 'FAIL'} 6. Pinsker: TV^2={tv**2:.4f} <= 0.5*KL={0.5*kl:.4f}")

print()
n_pass = sum(results)
print(f'Results: {n_pass}/{len(results)} checks passed')
print('All checks passed!' if all(results) else 'Some checks failed!')

15. Exponential Family KL as Bregman Divergence

For an exponential family $p_{\boldsymbol{\eta}}(x) = h(x)\exp\!\big(\boldsymbol{\eta}^\top \mathbf{t}(x) - A(\boldsymbol{\eta})\big)$, the KL divergence equals the Bregman divergence of the log-partition function:

$$D_{\mathrm{KL}}(p_{\boldsymbol{\eta}_1}\|p_{\boldsymbol{\eta}_2}) = A(\boldsymbol{\eta}_2) - A(\boldsymbol{\eta}_1) - \nabla A(\boldsymbol{\eta}_1)^\top(\boldsymbol{\eta}_2 - \boldsymbol{\eta}_1)$$

This unifies Bernoulli, Gaussian, Poisson, and other families under one formula.
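The identity is a two-line computation using the exponential-family fact $\mathbb{E}_{\boldsymbol{\eta}}[\mathbf{t}(x)] = \nabla A(\boldsymbol{\eta})$:

$$D_{\mathrm{KL}}(p_{\boldsymbol{\eta}_1}\|p_{\boldsymbol{\eta}_2}) = \mathbb{E}_{\boldsymbol{\eta}_1}\!\big[(\boldsymbol{\eta}_1-\boldsymbol{\eta}_2)^\top \mathbf{t}(x)\big] - A(\boldsymbol{\eta}_1) + A(\boldsymbol{\eta}_2) = A(\boldsymbol{\eta}_2) - A(\boldsymbol{\eta}_1) - \nabla A(\boldsymbol{\eta}_1)^\top(\boldsymbol{\eta}_2 - \boldsymbol{\eta}_1).$$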

Code cell 43

# === 15.1 Exponential Family KL: Bernoulli Case ===

# Bernoulli: eta = log(p/(1-p)) (log-odds), A(eta) = log(1 + e^eta)
# p = sigmoid(eta), t(x) = x

def bernoulli_A(eta):
    return np.log1p(np.exp(eta))

def bernoulli_dA(eta):
    return np.exp(eta) / (1 + np.exp(eta))  # sigmoid = mean parameter

def kl_bernoulli_expfam(eta1, eta2):
    """Bregman div: A(eta2) - A(eta1) - dA(eta1)*(eta2-eta1)"""
    return bernoulli_A(eta2) - bernoulli_A(eta1) - bernoulli_dA(eta1)*(eta2-eta1)

def kl_bernoulli_direct(p1, p2):
    """Direct: p1*log(p1/p2) + (1-p1)*log((1-p1)/(1-p2))"""
    return p1*np.log(p1/p2) + (1-p1)*np.log((1-p1)/(1-p2))

print('Bernoulli KL via Bregman (exp family) vs direct formula:')
test_pairs = [(0.2, 0.5), (0.8, 0.3), (0.6, 0.4), (0.1, 0.9)]
for p1, p2 in test_pairs:
    eta1 = np.log(p1/(1-p1))
    eta2 = np.log(p2/(1-p2))
    kl_bregman = kl_bernoulli_expfam(eta1, eta2)
    kl_direct  = kl_bernoulli_direct(p1, p2)
    match = np.isclose(kl_bregman, kl_direct, atol=1e-10)
    print(f'p1={p1}, p2={p2}: Bregman={kl_bregman:.5f}, Direct={kl_direct:.5f}, match={match}')
print('\nPASS - Exponential family Bregman = direct KL formula')

Code cell 44

# === 15.2 Exponential Family KL: Gaussian Case ===

# Gaussian(mu, sigma^2): eta = (mu/sigma^2, -1/(2*sigma^2))
# A(eta) = -eta1^2/(4*eta2) - 0.5*log(-2*eta2)
# (using canonical parameterization)

def gaussian_A_canonical(eta1, eta2):
    """Log-partition for Gaussian in natural params (eta1, eta2) where eta2 < 0"""
    return -eta1**2 / (4*eta2) - 0.5*np.log(-2*eta2)

def gaussian_kl_bregman(mu1, s1, mu2, s2):
    """KL via exponential family Bregman formula"""
    eta1_p = mu1/s1**2;  eta2_p = -1/(2*s1**2)
    eta1_q = mu2/s2**2;  eta2_q = -1/(2*s2**2)
    # Gradient of A at (eta1_p, eta2_p)
    dA_eta1 = -eta1_p/(2*eta2_p)  # = mu1
    dA_eta2 = eta1_p**2/(4*eta2_p**2) - 1/(2*eta2_p)  # = mu1^2 + sigma1^2
    A_q = gaussian_A_canonical(eta1_q, eta2_q)
    A_p = gaussian_A_canonical(eta1_p, eta2_p)
    return (A_q - A_p
            - dA_eta1*(eta1_q - eta1_p)
            - dA_eta2*(eta2_q - eta2_p))

def gaussian_kl_formula(mu1, s1, mu2, s2):
    return np.log(s2/s1) + (s1**2 + (mu1-mu2)**2)/(2*s2**2) - 0.5

print('Gaussian KL via exp-family Bregman vs direct formula:')
cases = [(1,1,0,1), (0,2,0,1), (2,0.5,1,1.5)]
for mu1,s1,mu2,s2 in cases:
    kl_b = gaussian_kl_bregman(mu1,s1,mu2,s2)
    kl_f = gaussian_kl_formula(mu1,s1,mu2,s2)
    match = np.isclose(kl_b, kl_f, atol=1e-8)
    print(f'N({mu1},{s1}^2)||N({mu2},{s2}^2): Bregman={kl_b:.5f}, Formula={kl_f:.5f}, match={match}')
print('\nPASS - Gaussian KL = Bregman of log-partition function')

16. Posterior Collapse in VAEs

Posterior collapse occurs when the decoder is powerful enough to reconstruct $\mathbf{x}$ without using the latent $\mathbf{z}$. The KL term is then driven to zero: $q_\phi(\mathbf{z}|\mathbf{x}) \to p(\mathbf{z}) = \mathcal{N}(0,1)$.

We simulate this with a decoder of varying capacity and show the KL annealing fix.

Code cell 46

# === 16.1 Posterior Collapse Simulation ===

# Toy 1D VAE: encoder q(z|x) = N(mu_phi*x, sigma^2)
# Decoder: p(x|z) = N(z, decoder_var)
# ELBO = E_q[log p(x|z)] - D_KL(q||N(0,1))

def elbo_1d(x, alpha, log_var_enc=-0.5, decoder_var=0.1, beta=1.0):
    """ELBO for 1D VAE with mu_phi(x) = alpha*x"""
    mu_enc = alpha * x
    sigma_enc_sq = np.exp(log_var_enc)
    kl = 0.5 * (mu_enc**2 + sigma_enc_sq - log_var_enc - 1)
    # E_q[log p(x|z)] approx by MC
    np.random.seed(42)
    eps = np.random.randn(1000)
    z = mu_enc + np.sqrt(sigma_enc_sq) * eps
    recon = np.mean(-0.5*np.log(2*np.pi*decoder_var) - (x-z)**2/(2*decoder_var))
    return recon - beta * kl, recon, kl

x_test = 2.0
alphas = np.linspace(0, 1.5, 50)

print('=== Posterior Collapse: Powerful vs Weak Decoder ===')
print()
for decoder_var, label in [(0.01, 'Strong decoder (low var)'), (1.0, 'Weak decoder (high var)')]:
    elbos_beta = [elbo_1d(x_test, a, decoder_var=decoder_var, beta=1.0)[0] for a in alphas]
    best_alpha = alphas[np.argmax(elbos_beta)]
    _, _, kl_opt = elbo_1d(x_test, best_alpha, decoder_var=decoder_var)
    collapsed = 'YES (collapsed!)' if kl_opt < 0.01 else f'No (KL={kl_opt:.3f})'
    print(f"{label}:")
    print(f'  Best alpha={best_alpha:.3f}, KL={kl_opt:.4f}, Collapse: {collapsed}')

# KL annealing: start with beta=0, increase to 1
print()
print('=== KL Annealing Fix ===')
for beta in [0.0, 0.1, 0.5, 1.0]:
    elbos_b = [elbo_1d(x_test, a, decoder_var=0.01, beta=beta)[0] for a in alphas]
    best_a = alphas[np.argmax(elbos_b)]
    _, _, kl_b = elbo_1d(x_test, best_a, decoder_var=0.01, beta=beta)
    print(f'  beta={beta:.1f}: best_alpha={best_a:.3f}, KL={kl_b:.4f}')

17. DPO: Computing the Implicit Reward

DPO reparameterizes the RLHF objective so the implicit reward is:

$$r_{\boldsymbol{\theta}}(x,y) = \beta\log\frac{\pi_{\boldsymbol{\theta}}(y|x)}{\pi_{\mathrm{ref}}(y|x)}$$

We verify that the DPO loss falls when the policy raises the preferred response's log-probability and lowers the dispreferred one's.
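The implicit reward comes from inverting the optimal-policy formula of Section 9: solving $\pi^* = \pi_{\mathrm{ref}}\, e^{r/\beta}/Z$ for $r$ gives

$$r(x,y) = \beta\log\frac{\pi^*(y|x)}{\pi_{\mathrm{ref}}(y|x)} + \beta\log Z(x),$$

and the intractable $\beta\log Z(x)$ cancels inside the Bradley-Terry preference probability $\sigma\big(r(x,y_w) - r(x,y_l)\big)$, leaving only the log-ratio reward that the loss below uses.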

Code cell 48

# === 17.1 DPO Implicit Reward and Loss ===

from scipy.special import expit  # sigmoid

# Toy sequence: 4 tokens, vocabulary 3
# Log-probs under reference and two policy versions
np.random.seed(1)
T = 4  # sequence length

# Reference log-probs
logp_ref_w = np.random.randn(T).sum()  # sum of log-probs for winner
logp_ref_l = np.random.randn(T).sum()  # sum of log-probs for loser

# Policy log-probs (before DPO training)
logp_policy_w = logp_ref_w + 0.1  # slightly better than ref on winner
logp_policy_l = logp_ref_l + 0.1  # also slightly better on loser

beta = 0.1

def dpo_loss(logp_pol_w, logp_pol_l, logp_ref_w, logp_ref_l, beta=0.1):
    """DPO loss for one preference pair"""
    reward_w = beta * (logp_pol_w - logp_ref_w)
    reward_l = beta * (logp_pol_l - logp_ref_l)
    return -np.log(expit(reward_w - reward_l))

loss_before = dpo_loss(logp_policy_w, logp_policy_l, logp_ref_w, logp_ref_l, beta)
print(f'DPO Setup:')
print(f'  ref logp(winner)  = {logp_ref_w:.4f}')
print(f'  ref logp(loser)   = {logp_ref_l:.4f}')
print(f'  pol logp(winner)  = {logp_policy_w:.4f}')
print(f'  pol logp(loser)   = {logp_policy_l:.4f}')
print(f'  Implicit reward (winner): beta*(logpol-logref) = {beta*(logp_policy_w-logp_ref_w):.4f}')
print(f'  Implicit reward (loser):  beta*(logpol-logref) = {beta*(logp_policy_l-logp_ref_l):.4f}')
print(f'  DPO loss: {loss_before:.4f}')

# After training: winner logprob increases, loser decreases
logp_policy_w_trained = logp_ref_w + 0.8
logp_policy_l_trained = logp_ref_l - 0.3
loss_after = dpo_loss(logp_policy_w_trained, logp_policy_l_trained, logp_ref_w, logp_ref_l, beta)
print(f'\nAfter training (winner reward up, loser reward down):')
print(f'  DPO loss: {loss_after:.4f}  (lower = better)')
print(f'  Improvement: {loss_before - loss_after:.4f} nats')

18. Quick Reference: KL Divergence Formulas

| Distribution pair | $D_{\mathrm{KL}}(p\,\Vert\,q)$ |
|---|---|
| Bernoulli $p$, $q$ | $p\ln\frac{p}{q} + (1-p)\ln\frac{1-p}{1-q}$ |
| $\mathcal{N}(\mu_1,\sigma_1^2)$ vs $\mathcal{N}(\mu_2,\sigma_2^2)$ | $\ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$ |
| $\mathcal{N}(\mu,\sigma^2)$ vs $\mathcal{N}(0,1)$ | $\frac{1}{2}(\mu^2 + \sigma^2 - \ln\sigma^2 - 1)$ |
| $\mathrm{Pois}(\lambda_1)$ vs $\mathrm{Pois}(\lambda_2)$ | $\lambda_1\ln(\lambda_1/\lambda_2) - \lambda_1 + \lambda_2$ |
| Categorical | $\sum_k p_k\ln(p_k/q_k)$ |

Key inequalities:

  • Pinsker's: $\mathrm{TV}(p,q) \le \sqrt{\tfrac{1}{2}D_{\mathrm{KL}}(p\|q)}$
  • Chain rule: $D_{\mathrm{KL}}(P(X,Y)\|Q(X,Y)) = D_{\mathrm{KL}}(P_X\|Q_X) + \mathbb{E}\big[D_{\mathrm{KL}}(P_{Y|X}\|Q_{Y|X})\big]$
  • ELBO: $\log p_{\theta}(\mathbf{x}) \ge \mathbb{E}_q[\log p_{\theta}(\mathbf{x}|\mathbf{z})] - D_{\mathrm{KL}}(q_{\phi}\|p)$

Code cell 50

# === 18. All Closed-Form KL Formulas: Verification ===

print('=== Closed-Form KL Formulas Verification ===')

# Bernoulli
p1_b, p2_b = 0.7, 0.4
kl_bern = p1_b*np.log(p1_b/p2_b) + (1-p1_b)*np.log((1-p1_b)/(1-p2_b))
print(f'Bernoulli Bern({p1_b}) || Bern({p2_b}): {kl_bern:.6f}')

# Gaussian scalar
kl_gauss = np.log(1.0/1.5) + (1.5**2 + (2.0-0.0)**2)/(2*1.0**2) - 0.5
print(f'Gaussian N(2,2.25) || N(0,1): {kl_gauss:.6f}')

# VAE
mu, sigma2 = 1.5, 2.0
kl_vae = 0.5*(mu**2 + sigma2 - np.log(sigma2) - 1)
print(f'VAE N({mu},{sigma2}) || N(0,1): {kl_vae:.6f}')

# Poisson
lam1, lam2 = 3.0, 2.0
kl_pois = lam1*np.log(lam1/lam2) - lam1 + lam2
print(f'Poisson Pois({lam1}) || Pois({lam2}): {kl_pois:.6f}')

# All non-negative
all_nonneg = all(k >= 0 for k in [kl_bern, kl_gauss, kl_vae, kl_pois])
print(f'\nAll KL values non-negative: {all_nonneg}')
print('PASS - All closed-form formulas verified')

References

  1. Kullback & Leibler (1951). 'On Information and Sufficiency.' Ann. Math. Stat.
  2. Cover & Thomas (2006). Elements of Information Theory, 2nd ed. Ch. 2.
  3. MacKay (2003). Information Theory, Inference, and Learning Algorithms. Ch. 2.
  4. Bishop (2006). Pattern Recognition and Machine Learning. Ch. 10.
  5. Kingma & Welling (2014). 'Auto-Encoding Variational Bayes.' ICLR.
  6. Rafailov et al. (2023). 'Direct Preference Optimization.' NeurIPS.
  7. Mironov (2017). 'Rényi Differential Privacy.' IEEE CSF.
