
Quantization and Distillation


Concept Lesson · Advanced · 13 min

Learning Objective

Understand quantization and distillation well enough to explain them, recognize them in Math for LLMs, and apply them in a small task.

Why It Matters

Quantization and distillation give you the math vocabulary of model compression: how low-precision weights and soft-target training cut memory and compute while preserving model behavior.


Theory Lab

Runnable lab version for web reading.

Quantization and Distillation: Theory Notebook

This notebook makes compression math concrete: uniform quantization, clipping, per-channel and group-wise scales, error versus bits, distillation temperature, KL loss, and memory accounting.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

1. Affine quantization

Code cell 4

def quantize_affine(x, qmin, qmax):
    """Uniform affine (asymmetric) quantization to the integer range [qmin, qmax]."""
    x = np.asarray(x, dtype=float)
    scale = (x.max() - x.min()) / (qmax - qmin)  # real-valued step between grid points
    zero = qmin - np.round(x.min() / scale)      # integer zero-point: maps x.min() near qmin
    q = np.round(x / scale + zero)               # quantize: scale, shift, round
    q = np.clip(q, qmin, qmax)
    x_hat = scale * (q - zero)                   # dequantize back to floats
    return q, x_hat, scale, zero

x = np.array([-1.0, -0.3, 0.2, 0.9])
q, x_hat, s, z = quantize_affine(x, 0, 15)
print("q:", q.astype(int))
print("x_hat:", np.round(x_hat, 4))
print("scale:", s, "zero:", z)
print("max error:", np.max(np.abs(x - x_hat)))

2. Symmetric INT4 quantization

Code cell 6

def quantize_symmetric(x, bits):
    """Symmetric quantization: zero-point fixed at 0, signed integer grid."""
    qmax = 2 ** (bits - 1) - 1        # e.g. +7 for INT4
    qmin = -2 ** (bits - 1)           # e.g. -8 for INT4
    scale = np.max(np.abs(x)) / qmax  # one scale covers the largest magnitude
    q = np.clip(np.round(x / scale), qmin, qmax)
    return q, q * scale, scale

rng = np.random.default_rng(2)
w = rng.normal(size=16)
q, w_hat, scale = quantize_symmetric(w, 4)
print("scale:", scale)
print("MSE:", np.mean((w - w_hat) ** 2))
print("unique integer values:", np.unique(q).astype(int))

3. Error versus bit width

Code cell 8

bits_list = np.arange(2, 9)
errors = []
for bits in bits_list:
    _, w_hat, _ = quantize_symmetric(w, bits)
    errors.append(np.mean((w - w_hat) ** 2))
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(bits_list, errors, marker="o")
ax.set_title("Quantization error decreases with bit width")
ax.set_xlabel("bits")
ax.set_ylabel("MSE")
fig.tight_layout()
plt.show()
print("errors:", np.round(errors, 6))

4. Per-channel quantization

Code cell 10

W = np.vstack([
    np.random.normal(scale=0.1, size=8),
    np.random.normal(scale=1.0, size=8),
    np.random.normal(scale=3.0, size=8),
])
_, global_hat, _ = quantize_symmetric(W, 4)
per_hat = np.zeros_like(W)
for i in range(W.shape[0]):
    _, per_hat[i], _ = quantize_symmetric(W[i], 4)
print("global MSE:", np.mean((W - global_hat) ** 2))
print("per-channel MSE:", np.mean((W - per_hat) ** 2))

5. Clipping tradeoff

Code cell 12

x = np.r_[np.random.normal(scale=0.5, size=1000), np.array([5.0, -4.5])]
clips = np.linspace(0.5, 5.0, 20)
mse = []
for c in clips:
    clipped = np.clip(x, -c, c)
    _, x_hat, _ = quantize_symmetric(clipped, 4)
    mse.append(np.mean((x - x_hat) ** 2))
best = int(np.argmin(mse))
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(clips, mse, marker="o")
ax.scatter([clips[best]], [mse[best]], color="red")
ax.set_title("Clipping range tradeoff")
ax.set_xlabel("clip range")
ax.set_ylabel("MSE to original")
fig.tight_layout()
plt.show()
print("best clip:", clips[best], "MSE:", mse[best])

6. Weight memory by precision

Code cell 14

params = 7e9
for bits in [16, 8, 4, 3, 2]:
    gb = params * bits / 8 / 1e9
    print(f"{bits:>2}-bit weights: {gb:.2f} GB")

7. SmoothQuant-style scale shifting intuition

Code cell 16

X = np.array([[10.0, 0.2, 0.1], [8.0, -0.1, 0.2]])
W = np.array([[0.1, 2.0, -1.5], [0.2, -1.0, 1.0]])
scale = np.array([4.0, 1.0, 1.0])
Y_original = X @ W.T
X_scaled = X / scale
W_scaled = W * scale
Y_scaled = X_scaled @ W_scaled.T
print("max matmul difference after exact scale shift:", np.max(np.abs(Y_original - Y_scaled)))
print("activation max before:", np.max(np.abs(X), axis=0))
print("activation max after: ", np.max(np.abs(X_scaled), axis=0))

8. Temperature softens teacher probabilities

Code cell 18

def softmax(z):
    z = np.asarray(z, dtype=float)
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

teacher_logits = np.array([6.0, 3.0, 1.0, -1.0])
for tau in [1.0, 2.0, 4.0]:
    p = softmax(teacher_logits / tau)
    print(f"tau={tau}: probs={np.round(p, 3)}, entropy={-np.sum(p*np.log(p+1e-12)):.3f}")

9. KL distillation loss

Code cell 20

teacher = softmax(np.array([3.0, 1.0, 0.0]) / 2.0)
student = softmax(np.array([2.2, 1.4, -0.1]) / 2.0)
tau = 2.0
kl = np.sum(teacher * (np.log(teacher + 1e-12) - np.log(student + 1e-12)))
kd_loss = tau**2 * kl  # tau^2 offsets the ~1/tau^2 shrinkage of softened-KL gradients
print("teacher:", np.round(teacher, 3))
print("student:", np.round(student, 3))
print("KD loss:", kd_loss)

10. Combine hard and soft losses

Code cell 22

hard_ce = 0.8   # example cross-entropy against ground-truth labels
kd = 0.35       # example tau^2-scaled soft (KL) loss
alpha = 0.4     # weight on the hard loss
combined = alpha * hard_ce + (1 - alpha) * kd
print("combined distillation objective:", combined)

11. QLoRA memory intuition

Code cell 24

base_params = 7e9           # frozen base model parameters
lora_params = 40e6          # trainable LoRA adapter parameters
base_bits = 4               # base weights stored in 4-bit (NF4-style)
lora_bits = 16              # adapters kept in 16-bit
adam_bits_per_param = 96    # ~12 bytes/param: two fp32 moments plus an fp32 master weight
base_gb = base_params * base_bits / 8 / 1e9
lora_gb = lora_params * lora_bits / 8 / 1e9
adam_gb = lora_params * adam_bits_per_param / 8 / 1e9
print("quantized base GB:", base_gb)
print("LoRA weights GB:", lora_gb)
print("LoRA Adam states GB:", adam_gb)

12. Compression checklist

Code cell 26

checks = [
    "state what is quantized: weights, activations, or KV cache",
    "state granularity: tensor, channel, or group",
    "calibration data matches deployment prompts",
    "compare logits, held-out loss, task score, calibration, memory, and latency",
    "verify serving kernels support the selected format",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")

Skill Check


Answer 4 quick questions to lock in the lesson and feed your adaptive practice queue.

1. Which module does this lesson belong to?
2. Which section is covered in this lesson content?
3. Which term is most central to this lesson?
4. What is the best way to use this lesson for real learning?
