Theory Notebook
Converted from theory.ipynb for web reading.
Number Systems: From Bits to LLM Training
Every number in a computer is a compromise between range, precision, and cost. Understanding these tradeoffs is the foundation of efficient AI engineering.
This notebook is the interactive companion to notes.md. It demonstrates the key concepts from all 17 sections with runnable Python, NumPy, and PyTorch code.
| Section | Topic | What You'll Build |
|---|---|---|
| 1 | Positional Number Systems | Binary/hex converters, two's complement, fixed-point |
| 2 | IEEE 754 Deep Dive | Bit-level float decoder, special values, epsilon demo |
| 3 | Floating-Point Formats for AI | BF16/FP16/FP8/TF32 comparison, range/precision analysis |
| 4 | Integer Formats & Quantization | INT8/INT4 quantization, per-channel vs per-tensor |
| 5 | Non-Uniform Formats | NF4 quantile levels, ternary weight simulation |
| 6 | Floating-Point Arithmetic | Catastrophic cancellation, Kahan summation, FMA |
| 7 | Numerical Stability | Stable softmax, log-sum-exp, RMSNorm vs LayerNorm |
| 8 | Quantization Mathematics | SQNR, group quantization, Lloyd-Max, Hadamard transform |
| 9 | Mixed Precision Training | Full mixed-precision pipeline, BF16 precision limits |
| 10 | Hardware & Memory Analysis | Arithmetic intensity, memory budget calculator |
| 11 | Training Stability | Stochastic rounding, Adam errors, attention logit growth |
| 12 | Practical Guide | Format selector, per-layer sensitivity, error propagation |
Prerequisites: Python, NumPy. PyTorch optional but recommended.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
import seaborn as sns
sns.set_theme(style="whitegrid", palette="colorblind")
HAS_SNS = True
except ImportError:
plt.style.use("seaborn-v0_8-whitegrid")
HAS_SNS = False
mpl.rcParams.update({
"figure.figsize": (10, 6),
"figure.dpi": 120,
"font.size": 13,
"axes.titlesize": 15,
"axes.labelsize": 13,
"xtick.labelsize": 11,
"ytick.labelsize": 11,
"legend.fontsize": 11,
"legend.framealpha": 0.85,
"lines.linewidth": 2.0,
"axes.spines.top": False,
"axes.spines.right": False,
"savefig.bbox": "tight",
"savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
import numpy as np
import struct
import warnings
from typing import Tuple, List
np.random.seed(42)
np.set_printoptions(precision=8, suppress=True)
# Optional: PyTorch for real implementations
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)
HAS_TORCH = True
print(f'NumPy {np.__version__} | PyTorch {torch.__version__}')
if torch.cuda.is_available():
print(f'GPU: {torch.cuda.get_device_name()}')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
print('Device: Apple Silicon (MPS)')
else:
print('Device: CPU')
except ImportError:
HAS_TORCH = False
print(f'NumPy {np.__version__} | PyTorch not installed (NumPy demos still work)')
1. Positional Number Systems
Every number system uses position to determine value. A digit $d$ at position $i$ in base $b$ contributes $d \cdot b^i$:
Why this matters for AI
- Binary (base 2): how every value is stored in hardware
- Hexadecimal (base 16): how we inspect memory and weights
- Two's complement: how INT8/INT4 quantization encodes negative values
- Fixed-point: an alternative to floating-point used in some accelerators
Code cell 5
# ──────────────────────────────────────────────────────────────────
# 1.1 Base Conversion Engine
# ──────────────────────────────────────────────────────────────────
def decimal_to_base(n: int, base: int) -> str:
"""Convert a non-negative integer to any base (2-16)."""
if n == 0:
return '0'
digits = '0123456789ABCDEF'
result = []
while n > 0:
result.append(digits[n % base])
n //= base
return ''.join(reversed(result))
def show_positional_breakdown(n: int, base: int):
"""Show how positional notation builds a number."""
rep = decimal_to_base(n, base)
print(f' {n} in base {base}: {rep}')
terms = []
for i, d in enumerate(reversed(rep)):
val = int(d, 16) # handles hex digits
if val > 0:
            terms.append(f'{d}×{base}^{i}={val * base**i}')
print(f' = {" + ".join(terms)} = {n}')
# Demonstrate the universality of positional notation
print('Positional Number System Demonstrations:')
print('=' * 55)
for val in [42, 255, 1024]:
for base in [2, 10, 16]:
show_positional_breakdown(val, base)
print()
Code cell 6
# ──────────────────────────────────────────────────────────────────
# 1.2 Two's Complement - How INT8/INT4 Work
# ──────────────────────────────────────────────────────────────────
def twos_complement(n: int, bits: int = 8) -> str:
"""Show two's complement representation of a signed integer."""
if n >= 0:
return f'{n:0{bits}b}'
else:
# Two's complement: flip bits of |n|, add 1
pos_bits = f'{abs(n):0{bits}b}'
flipped = ''.join('1' if b == '0' else '0' for b in pos_bits)
result = int(flipped, 2) + 1
return f'{result:0{bits}b}'
print("Two's Complement β the encoding behind INT8 quantization")
print('=' * 60)
print(f'{"Decimal":>8} {"8-bit Binary":>12} {"Hex":>5} Step-by-Step')
print('-' * 60)
for val in [42, -42, 127, -128, 0, 1, -1]:
bits = twos_complement(val)
hex_val = f'{int(bits, 2):02X}'
if val < 0:
        step = f'flip {abs(val):08b} → {"".join("1" if b == "0" else "0" for b in f"{abs(val):08b}")} + 1'
else:
step = 'direct binary'
print(f'{val:>8} {bits:>12} 0x{hex_val:>3} {step}')
# Show the critical ranges for quantization
print(f'\nQuantization ranges:')
for bits in [8, 4, 2, 1]:
signed_min = -(2**(bits-1))
signed_max = 2**(bits-1) - 1
unsigned_max = 2**bits - 1
print(f' INT{bits}: [{signed_min:>5}, {signed_max:>4}] ({2**bits} levels)'
f' UINT{bits}: [0, {unsigned_max:>3}] ({2**bits} levels)')
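A quick cross-check of the same bit patterns with NumPy's native INT8 storage (a minimal sketch; `.view(np.uint8)` reinterprets the stored byte without any conversion):

```python
import numpy as np

vals = np.array([42, -42, 127, -128, 0, 1, -1], dtype=np.int8)
for signed, raw in zip(vals, vals.view(np.uint8)):  # same bytes, read back as unsigned
    print(f'{int(signed):>5} -> 0x{int(raw):02X} = {int(raw):08b}')
```

The hex/binary columns should agree with the string-based `twos_complement` output above.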
Code cell 7
# ──────────────────────────────────────────────────────────────────
# 1.3 Fixed-Point Representation (Q-format)
# ──────────────────────────────────────────────────────────────────
def float_to_fixed(value: float, int_bits: int = 3, frac_bits: int = 4) -> Tuple[int, str]:
"""Convert float to fixed-point Q{int_bits}.{frac_bits} format."""
scale = 2 ** frac_bits
total_bits = 1 + int_bits + frac_bits # sign + integer + fraction
q = int(round(value * scale))
# Clamp to representable range
max_val = 2**(total_bits - 1) - 1
min_val = -(2**(total_bits - 1))
q = max(min_val, min(max_val, q))
# Show binary
if q < 0:
binary = f'{(2**total_bits + q):0{total_bits}b}'
else:
binary = f'{q:0{total_bits}b}'
formatted = f'{binary[0]}|{binary[1:1+int_bits]}.{binary[1+int_bits:]}'
return q, formatted
def fixed_to_float(q: int, frac_bits: int = 4) -> float:
return q / (2 ** frac_bits)
print('Fixed-Point Q3.4 Format (1 sign + 3 int + 4 frac = 8 bits)')
print('=' * 60)
print(f'{"Value":>8} {"Q3.4 Binary":>14} {"Dequantized":>12} {"Error":>8}')
print('-' * 60)
for val in [2.75, -3.5, 0.0625, 7.9375, -8.0, 0.1, np.pi]:
q, binary = float_to_fixed(val)
recon = fixed_to_float(q)
error = abs(val - recon)
print(f'{val:>8.4f} {binary:>14} {recon:>12.4f} {error:>8.4f}')
print(f'\nQ3.4 properties:')
print(f' Range: [{fixed_to_float(-128)}, {fixed_to_float(127)}]')
print(f' Resolution: {1/16} = 2^(-4)')
print(f' Compare: FP32 has variable resolution (higher near 0, lower far from 0)')
2. IEEE 754 Floating-Point: Deep Dive
Every float is $(-1)^S \times 1.M \times 2^{E-\text{bias}}$, packed into sign, exponent, and mantissa fields:
FP32: [S|EEEEEEEE|MMMMMMMMMMMMMMMMMMMMMMM] 1+8+23 = 32 bits
FP16: [S|EEEEE|MMMMMMMMMM] 1+5+10 = 16 bits
BF16: [S|EEEEEEEE|MMMMMMM] 1+8+7 = 16 bits
FP8: [S|EEEE|MMM] (E4M3) or [S|EEEEE|MM] (E5M2) = 8 bits
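The next cell decodes FP32 with `struct`; the narrower layouts above can be decoded generically from their field widths. A minimal sketch (a hypothetical helper, not used by later cells; Inf/NaN encodings are ignored):

```python
def decode_bits(pattern: int, exp_bits: int, man_bits: int) -> float:
    """Decode a raw S|E|M bit pattern given the exponent/mantissa widths."""
    bias = 2**(exp_bits - 1) - 1
    sign = (pattern >> (exp_bits + man_bits)) & 1
    exp = (pattern >> man_bits) & (2**exp_bits - 1)
    man = pattern & (2**man_bits - 1)
    if exp == 0:  # subnormal: no implicit leading 1, exponent pinned at 1 - bias
        return (-1) ** sign * (man / 2**man_bits) * 2.0 ** (1 - bias)
    return (-1) ** sign * (1 + man / 2**man_bits) * 2.0 ** (exp - bias)

print(decode_bits(0b0_01111111_0000000, 8, 7))  # BF16 (E8M7): 1.0
print(decode_bits(0b0_10000000_1000000, 8, 7))  # BF16: 1.5 * 2^1 = 3.0
print(decode_bits(0b0_1111_110, 4, 3))          # E4M3: 1.75 * 2^8 = 448, the largest finite value
```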
Code cell 9
# ──────────────────────────────────────────────────────────────────
# 2.1 IEEE 754 Bit-Level Decoder
# ──────────────────────────────────────────────────────────────────
def decode_fp32(value: float) -> dict:
"""Fully decode an FP32 value into its IEEE 754 components."""
packed = struct.pack('>f', value)
bits = ''.join(f'{byte:08b}' for byte in packed)
sign_bit = int(bits[0])
exp_bits = bits[1:9]
man_bits = bits[9:]
biased_exp = int(exp_bits, 2)
true_exp = biased_exp - 127
# Compute mantissa value
mantissa_val = sum(int(b) * 2**(-i-1) for i, b in enumerate(man_bits))
# Determine special cases
if biased_exp == 0:
if mantissa_val == 0:
category = 'Zero'
else:
category = 'Subnormal'
elif biased_exp == 255:
if mantissa_val == 0:
category = 'Infinity'
else:
category = 'NaN'
else:
category = 'Normal'
return {
'value': value,
'bits': f'{bits[0]} | {exp_bits} | {man_bits}',
'hex': packed.hex(),
'sign': sign_bit,
'biased_exp': biased_exp,
'true_exp': true_exp,
'mantissa_val': mantissa_val,
'implicit_1_plus_m': 1.0 + mantissa_val,
'category': category,
        'formula': f'(-1)^{sign_bit} × {1+mantissa_val:.6f} × 2^{true_exp}'
}
# Decode a series of important values
print('IEEE 754 FP32 Decoder')
print('=' * 80)
test_values = [5.0, -13.625, 0.1, 1.0, 0.0, float('inf'), float('-inf'), float('nan')]
for val in test_values:
d = decode_fp32(val)
print(f'\n Value: {val}')
print(f' Bits: {d["bits"]}')
print(f' Hex: 0x{d["hex"]}')
print(f' Type: {d["category"]}')
if d['category'] == 'Normal':
print(f' Sign={d["sign"]}, Exp={d["biased_exp"]}-127={d["true_exp"]}, '
f'1+M={d["implicit_1_plus_m"]:.6f}')
print(f' = {d["formula"]} = {val}')
Code cell 10
# ──────────────────────────────────────────────────────────────────
# 2.2 Machine Epsilon & Precision Limits
# ──────────────────────────────────────────────────────────────────
print('Machine Epsilon - the fundamental precision limit')
print('=' * 70)
print(f'{"Format":<12} {"ε":>14} {"Decimal digits":>15} {"Relative precision":>20}')
print('-' * 70)
for dtype, name in [(np.float16, 'FP16'), (np.float32, 'FP32'), (np.float64, 'FP64')]:
eps = np.finfo(dtype).eps
digits = int(-np.log10(eps))
print(f'{name:<12} {eps:>14.2e} {digits:>15} {f"1 part in {int(1/eps):,}":>20}')
# BF16 epsilon (manually - not in NumPy)
bf16_eps = 2**-7 # 7 mantissa bits
print(f'{"BF16":<12} {bf16_eps:>14.2e} {"~2":>15} {f"1 part in {int(1/bf16_eps):,}":>20}')
fp8_eps = 2**-3 # E4M3: 3 mantissa bits
print(f'{"FP8 E4M3":<12} {fp8_eps:>14.2e} {"~1":>15} {f"1 part in {int(1/fp8_eps):,}":>20}')
# Demonstrate precision loss in gradient accumulation
print('\n--- Precision Loss Demo: Adding small to large ---')
print(f'{"":<14} {"1.0 + Ξ΅":>14} {"1.0 + Ξ΅/2":>14} {"Ξ΅/2 lost?":>10}')
print('-' * 56)
for dtype in [np.float16, np.float32, np.float64]:
eps = np.finfo(dtype).eps
one = dtype(1.0)
result1 = float(one + dtype(eps))
result2 = float(one + dtype(eps / 2))
lost = result2 == 1.0
print(f'{dtype.__name__:<14} {result1:>14.10f} {result2:>14.10f} {"YES":>10}' if lost
else f'{dtype.__name__:<14} {result1:>14.10f} {result2:>14.10f} {"no":>10}')
print('\n→ In BF16 training: any gradient < 0.78% of the running weight sum is silently lost!')
print('→ This is why FP32 master weights are mandatory for stable training.')
Code cell 11
# ──────────────────────────────────────────────────────────────────
# 2.3 Floating-Point Arithmetic is NOT Associative
# ──────────────────────────────────────────────────────────────────
# The classic proof from §3.4 of notes.md
a = np.float32(1e8)
b = np.float32(1.0)
c = np.float32(-1e8)
left = np.float32(np.float32(a + b) + c) # (a + b) + c
right = np.float32(a + np.float32(b + c)) # a + (b + c)
print('Floating-Point Associativity Failure')
print('=' * 55)
print(f' a = {a:.0f}, b = {b:.0f}, c = {c:.0f}')
print(f' (a + b) + c = ({a:.0f} + {b:.0f}) + {c:.0f}')
print(f' = {np.float32(a+b):.0f} + {c:.0f}') # b is absorbed!
print(f' = {left:.0f} ← b was absorbed into a!')
print(f' a + (b + c) = {a:.0f} + ({b:.0f} + {c:.0f})')
print(f' = {a:.0f} + {np.float32(b+c):.0f}')
print(f' = {right:.0f} ← correct!')
print(f'\n (a+b)+c = {left} ≠ a+(b+c) = {right}')
print(f' → Associativity FAILS in floating-point!')
print(f' → This is why summation order matters for gradient accumulation.')
Code cell 12
# ──────────────────────────────────────────────────────────────────
# 2.4 Special Values: Subnormals, Infinity, NaN
# ──────────────────────────────────────────────────────────────────
print('IEEE 754 Special Values')
print('=' * 65)
# Subnormal numbers - the "gradual underflow" zone
print('\n--- Subnormal Numbers (gradual underflow) ---')
normal_min = np.finfo(np.float32).tiny # smallest normal
subnormal = normal_min / 2 # enters subnormal territory
print(f' Smallest normal FP32: {normal_min:.6e}')
print(f' Half of that (subnormal): {subnormal:.6e}')
print(f' Is subnormal? biased_exp == 0: ', end='')
d = decode_fp32(subnormal)
print(f'{d["category"]} (biased_exp = {d["biased_exp"]})')
# Infinity arithmetic
print('\n--- Infinity Arithmetic ---')
inf = float('inf')
for expr, result in [
('inf + 1', inf + 1),
('inf + inf', inf + inf),
('inf * 2', inf * 2),
('1 / inf', 1 / inf),
('inf / inf', inf / inf), # NaN
('inf - inf', inf - inf), # NaN
('0 * inf', 0 * inf), # NaN
]:
print(f' {expr:>12} = {result}')
# NaN propagation β the training killer
print('\n--- NaN Propagation (why one NaN kills training) ---')
nan = float('nan')
print(f' nan + 1 = {nan + 1}')
print(f' nan * 0 = {nan * 0}')
print(f' nan == nan: {nan == nan} → NaN is not equal to itself!')
print(f' np.isnan(nan): {np.isnan(nan)} → use this to detect NaN')
print(f' → If a single weight becomes NaN, ALL subsequent matmuls produce NaN')
print(f' → The entire model is corrupted in one forward pass')
3. Floating-Point Formats for AI
The AI industry uses a zoo of number formats. Each trades precision for speed/memory:
| Format | Bits | Exp | Mantissa | Range | ML Role |
|---|---|---|---|---|---|
| FP64 | 64 | 11 | 52 | ±10³⁰⁸ | Scientific computing |
| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | Master weights, optimizer |
| TF32 | 19 | 8 | 10 | ±3.4×10³⁸ | Auto matmul on A100+ |
| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | Training default (2020+) |
| FP16 | 16 | 5 | 10 | ±65504 | Legacy training, inference |
| FP8 E4M3 | 8 | 4 | 3 | ±448 | Forward pass (H100+) |
| FP8 E5M2 | 8 | 5 | 2 | ±57344 | Backward pass (H100+) |
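One way to see why BF16 keeps FP32's range: a BF16 value is simply the top 16 bits of the FP32 encoding (sign, the full 8-bit exponent, and 7 mantissa bits). A minimal sketch of that truncation with round-to-nearest-even (illustration only, not how any particular library implements the cast):

```python
import struct

def fp32_to_bf16_bits(x: float) -> str:
    u = struct.unpack('>I', struct.pack('>f', x))[0]  # raw FP32 bits as a 32-bit integer
    u = (u + 0x7FFF + ((u >> 16) & 1)) >> 16          # round-to-nearest-even into the top 16 bits
    return f'{u:016b}'

for v in [1.0, -2.5, 0.1]:
    print(f'{v:>5}: {fp32_to_bf16_bits(v)}')
```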
Code cell 14
# ──────────────────────────────────────────────────────────────────
# 3.1 Format Range & Precision Comparison
# ──────────────────────────────────────────────────────────────────
# Define format properties (exp_bits, man_bits, max_exp_bias)
formats = {
'FP64': {'exp': 11, 'man': 52, 'max_val': 1.8e308, 'eps': 2**-52, 'bytes': 8},
'FP32': {'exp': 8, 'man': 23, 'max_val': 3.4e38, 'eps': 2**-23, 'bytes': 4},
'TF32': {'exp': 8, 'man': 10, 'max_val': 3.4e38, 'eps': 2**-10, 'bytes': 2.375},
'BF16': {'exp': 8, 'man': 7, 'max_val': 3.4e38, 'eps': 2**-7, 'bytes': 2},
'FP16': {'exp': 5, 'man': 10, 'max_val': 65504, 'eps': 2**-10, 'bytes': 2},
'FP8 E4M3': {'exp': 4, 'man': 3, 'max_val': 448, 'eps': 2**-3, 'bytes': 1},
'FP8 E5M2': {'exp': 5, 'man': 2, 'max_val': 57344, 'eps': 2**-2, 'bytes': 1},
}
print(f'{"Format":<12} {"Bits":>5} {"Max Value":>12} {"Epsilon":>12} {"Precision":>12} {"70B Model":>10}')
print('=' * 75)
for name, f in formats.items():
total_bits = 1 + f['exp'] + f['man']
mem_70b = 70e9 * f['bytes'] / 1e9
precision = f'~{int(-np.log10(f["eps"]))} digits'
print(f'{name:<12} {total_bits:>5} {f["max_val"]:>12.2e} {f["eps"]:>12.2e} '
f'{precision:>12} {mem_70b:>8.1f} GB')
print(f'\nKey insight: BF16 has the SAME range as FP32 (8-bit exponent) but only ~2 digits precision.')
print(f'This is why BF16 is preferred over FP16 for training β it never overflows on typical gradients.')
Code cell 15
# ──────────────────────────────────────────────────────────────────
# 3.2 BF16 vs FP16: The Critical Difference
# ──────────────────────────────────────────────────────────────────
if HAS_TORCH:
    print('BF16 vs FP16 - Range and Overflow Behaviour')
print('=' * 65)
# Test values near FP16 overflow boundary
test_vals = [100, 1000, 10000, 65504, 65505, 70000, 100000]
print(f'{"Value":>10} {"FP16":>12} {"BF16":>12} {"FP32":>12}')
print('-' * 50)
for v in test_vals:
fp32 = torch.tensor(float(v), dtype=torch.float32)
fp16 = fp32.half()
bf16 = fp32.bfloat16()
print(f'{v:>10} {float(fp16):>12.1f} {float(bf16):>12.1f} {float(fp32):>12.1f}')
# Precision comparison near 1.0
print(f'\n--- Precision near 1.0 ---')
base = torch.tensor(1.0)
for delta_exp in range(-1, -8, -1):
delta = 2.0 ** delta_exp
fp16_result = float((base + delta).half())
bf16_result = float((base + delta).bfloat16())
fp16_ok = fp16_result != 1.0
bf16_ok = bf16_result != 1.0
print(f' 1.0 + 2^{delta_exp} ({delta:.6f}): '
              f'FP16={"✓" if fp16_ok else "✗ LOST":<8} '
              f'BF16={"✓" if bf16_ok else "✗ LOST":<8}')
    print(f'\n→ FP16 has MORE precision (10 vs 7 mantissa bits) but LESS range')
    print(f'→ BF16 has LESS precision but the SAME range as FP32')
    print(f'→ For training: range wins (gradient magnitudes vary hugely)')
else:
print('PyTorch required for BF16/FP16 comparison')
Code cell 16
# ──────────────────────────────────────────────────────────────────
# 3.3 FP8 Format Deep Dive
# ──────────────────────────────────────────────────────────────────
def enumerate_fp8_values(exp_bits: int, man_bits: int) -> list:
"""Enumerate all representable positive normal values in an FP8-like format."""
bias = 2**(exp_bits - 1) - 1
values = []
for e in range(1, 2**exp_bits - 1): # skip 0 (subnormal) and all-1s (special)
for m in range(2**man_bits):
mantissa = 1.0 + m / (2**man_bits)
val = mantissa * (2 ** (e - bias))
values.append(val)
return sorted(values)
e4m3_vals = enumerate_fp8_values(4, 3)
e5m2_vals = enumerate_fp8_values(5, 2)
print('FP8 Representable Values - E4M3 vs E5M2')
print('=' * 60)
print(f' E4M3: {len(e4m3_vals)} positive normal values, range [{e4m3_vals[0]}, {e4m3_vals[-1]}]')
print(f' E5M2: {len(e5m2_vals)} positive normal values, range [{e5m2_vals[0]}, {e5m2_vals[-1]}]')
# Show values between 1.0 and 2.0 to compare precision
print(f'\nValues between 1.0 and 2.0 (precision comparison):')
e4m3_1to2 = [v for v in e4m3_vals if 1.0 <= v < 2.0]
e5m2_1to2 = [v for v in e5m2_vals if 1.0 <= v < 2.0]
print(f' E4M3 ({len(e4m3_1to2)} levels): {e4m3_1to2}')
print(f' E5M2 ({len(e5m2_1to2)} levels): {e5m2_1to2}')
print(f' E4M3 spacing: {e4m3_1to2[1]-e4m3_1to2[0]:.3f}')
print(f' E5M2 spacing: {e5m2_1to2[1]-e5m2_1to2[0]:.3f}')
print(f'\n→ E4M3 has 2× more levels (precision) → better for weights/activations')
print(f'→ E5M2 has wider range → better for gradients (which vary more in magnitude)')
4. Integer Formats for AI: Quantization in Practice
Quantization maps floating-point values to integers: $x_q = \mathrm{round}(x / s)$ with a scale $s$, and dequantization recovers $\hat{x} = s \cdot x_q$.
Code cell 18
# ──────────────────────────────────────────────────────────────────
# 4.1 Complete INT8 Quantization with Error Analysis
# ──────────────────────────────────────────────────────────────────
def quantize_symmetric(x: np.ndarray, bits: int = 8) -> Tuple[np.ndarray, float]:
"""Symmetric quantization: zero maps to zero, scale based on max abs value."""
alpha = np.max(np.abs(x))
qmax = 2**(bits - 1) - 1 # 127 for INT8
    scale = alpha / qmax if (alpha > 0 and qmax > 0) else 1.0  # guard qmax == 0 (degenerate 1-bit case)
x_q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8 if bits == 8 else np.int32)
return x_q, scale
def dequantize(x_q: np.ndarray, scale: float) -> np.ndarray:
"""Dequantize: INT β float approximation."""
return x_q.astype(np.float32) * scale
# Worked example from Β§9.1 of notes.md
weights = np.array([-0.45, 0.12, -0.03, 0.67, -0.89, 0.34], dtype=np.float32)
x_q, scale = quantize_symmetric(weights, bits=8)
reconstructed = dequantize(x_q, scale)
errors = np.abs(weights - reconstructed)
print('INT8 Symmetric Quantization β Worked Example')
print('=' * 70)
print(f'Scale factor s = max(|w|) / 127 = {np.max(np.abs(weights)):.2f} / 127 = {scale:.6f}')
print(f'\n{"Original":>10} {"w/s":>10} {"Rounded":>8} {"INT8":>6} {"Deq":>10} {"Error":>10}')
print('-' * 60)
for w, q, r, e in zip(weights, weights/scale, x_q, errors):
print(f'{w:>10.4f} {q:>10.2f} {np.round(q):>8.0f} {int(r):>6} {dequantize(np.array([r]), scale)[0]:>10.5f} {e:>10.5f}')
print(f'\nMax error: {np.max(errors):.5f} ≤ s/2 = {scale/2:.5f}')
print(f'MSE: {np.mean(errors**2):.8f}')
print(f'Memory: {weights.nbytes} bytes (FP32) → {x_q.nbytes} bytes (INT8) = {weights.nbytes/x_q.nbytes:.0f}× compression')
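The cell above is symmetric quantization (zero maps to zero). The asymmetric (affine) variant adds a zero-point so the integer grid covers the data's exact [min, max]; a minimal sketch on the same `weights` vector (redefined here so the snippet is self-contained):

```python
import numpy as np

weights = np.array([-0.45, 0.12, -0.03, 0.67, -0.89, 0.34], dtype=np.float32)

def quantize_asymmetric(x: np.ndarray, bits: int = 8):
    """Affine quantization: x ≈ (q - zero_point) * scale, with q stored as UINT8."""
    qmin, qmax = 0, 2**bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

q, s, z = quantize_asymmetric(weights)
print('q =', q, '| scale =', f'{s:.6f}', '| zero_point =', z)
print('dequantized:', (q.astype(np.float32) - z) * s)
```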
Code cell 19
# ──────────────────────────────────────────────────────────────────
# 4.2 Per-Tensor vs Per-Channel vs Per-Group Quantization
# ──────────────────────────────────────────────────────────────────
np.random.seed(42)
# Simulate a weight matrix with channels of VERY different magnitudes
# (this is realistic β transformer weights often have outlier channels)
W = np.random.randn(8, 256).astype(np.float32)
W[0] *= 10 # Channel 0: 10Γ larger (outlier)
W[1] *= 0.01 # Channel 1: 100Γ smaller
# Per-tensor: single scale for entire matrix
q_tensor, s_tensor = quantize_symmetric(W.flatten())
recon_tensor = dequantize(q_tensor, s_tensor).reshape(W.shape)
# Per-channel: one scale per output channel (row)
recon_channel = np.zeros_like(W)
for c in range(W.shape[0]):
q_c, s_c = quantize_symmetric(W[c])
recon_channel[c] = dequantize(q_c, s_c)
# Per-group: one scale per group of G elements
G = 64 # group size
recon_group = np.zeros_like(W)
for c in range(W.shape[0]):
for start in range(0, W.shape[1], G):
end = min(start + G, W.shape[1])
q_g, s_g = quantize_symmetric(W[c, start:end])
recon_group[c, start:end] = dequantize(q_g, s_g)
print('Quantization Granularity Comparison')
print('=' * 70)
print(f'{"Channel":>4} {"Weight Ο":>9} {"Per-Tensor MSE":>16} {"Per-Channel MSE":>16} {"Per-Group MSE":>14}')
print('-' * 65)
for c in range(W.shape[0]):
mse_t = np.mean((W[c] - recon_tensor[c])**2)
mse_c = np.mean((W[c] - recon_channel[c])**2)
mse_g = np.mean((W[c] - recon_group[c])**2)
    flag = ' ← outlier!' if c in [0, 1] else ''
print(f'{c:>4} {np.std(W[c]):>9.4f} {mse_t:>16.8f} {mse_c:>16.8f} {mse_g:>14.8f}{flag}')
total_t = np.mean((W - recon_tensor)**2)
total_c = np.mean((W - recon_channel)**2)
total_g = np.mean((W - recon_group)**2)
print(f'\nTotal MSE: Per-tensor={total_t:.6f} Per-channel={total_c:.6f} Per-group(64)={total_g:.6f}')
print(f'Per-channel is {total_t/total_c:.1f}× better than per-tensor')
print(f'Per-group is {total_t/total_g:.1f}× better than per-tensor')
print(f'\n→ Channel 1 (tiny weights) is crushed by per-tensor: the global scale is too large')
print(f'→ Per-group gives the best results with modest overhead ({W.size // G} extra scale factors)')
Code cell 20
# ──────────────────────────────────────────────────────────────────
# 4.3 INT4 Quantization - W4A16 Pipeline
# ──────────────────────────────────────────────────────────────────
def quantize_int4_symmetric(x: np.ndarray) -> Tuple[np.ndarray, float]:
"""Symmetric INT4 quantization: range [-8, 7]."""
alpha = np.max(np.abs(x))
scale = alpha / 7.0 if alpha > 0 else 1.0
x_q = np.clip(np.round(x / scale), -8, 7).astype(np.int8) # store in int8 container
return x_q, scale
# Simulate INT4 weight quantization (W4A16 = 4-bit weights, 16-bit activations)
np.random.seed(42)
W = np.random.randn(128, 128).astype(np.float32) * 0.02 # typical weight scale
x = np.random.randn(1, 128).astype(np.float32) # input activation
# Full precision matmul
y_fp32 = x @ W.T
# INT4 quantized matmul (per-group, G=32)
G = 32
y_int4 = np.zeros((1, 128), dtype=np.float32)
for col_start in range(0, 128, G):
col_end = col_start + G
W_group = W[:, col_start:col_end]
x_group = x[:, col_start:col_end]
# Quantize weights to INT4 per group
for row in range(128):
q, s = quantize_int4_symmetric(W_group[row])
W_deq = dequantize(q, s)
y_int4[0, row] += (x_group @ W_deq.reshape(-1, 1)).item()
# Compare
mse = np.mean((y_fp32 - y_int4)**2)
relative_error = np.mean(np.abs(y_fp32 - y_int4) / (np.abs(y_fp32) + 1e-10))
print('W4A16 (INT4 weight, FP16 activation) Pipeline')
print('=' * 55)
print(f'Weight matrix: {W.shape}, stored in INT4 with group_size={G}')
print(f'FP32 memory: {W.nbytes:,} bytes')
print(f'INT4 memory: ~{W.size // 2:,} bytes (+ scale overhead)')
print(f'Compression: ~{W.nbytes / (W.size // 2):.0f}×')
print(f'\nOutput MSE: {mse:.8f}')
print(f'Mean relative error: {relative_error:.4%}')
print(f'Max |y_fp32|: {np.max(np.abs(y_fp32)):.4f}')
print(f'Max |error|: {np.max(np.abs(y_fp32 - y_int4)):.6f}')
print(f'\n→ Only 16 quantization levels! Yet relative error is small because')
print(f' the group-wise scale adapts to local weight distribution.')
5. Non-Uniform and Specialised Formats
Neural network weights follow a bell-shaped distribution (roughly Gaussian). Uniform quantization wastes levels on the sparse tails. Non-uniform formats place more levels where the data is dense.
Code cell 22
# ──────────────────────────────────────────────────────────────────
# 5.1 NF4 - Normal Float 4-bit (QLoRA)
# ──────────────────────────────────────────────────────────────────
from scipy.stats import norm
def compute_nf4_levels(bits: int = 4) -> np.ndarray:
"""Compute NF4 quantisation levels as quantiles of N(0,1).
    These equal-probability quantiles are information-theoretically optimal for N(0,1) and sit close to (though not identical to) the MSE-optimal Lloyd-Max levels."""
n_levels = 2**bits
# Quantiles at evenly spaced probabilities
probs = np.linspace(1/(2*n_levels), 1 - 1/(2*n_levels), n_levels)
levels = norm.ppf(probs) # inverse CDF of N(0,1)
# Normalise to [-1, 1]
levels = levels / np.max(np.abs(levels))
return levels
nf4_levels = compute_nf4_levels(4)
print('NF4 (Normal Float 4-bit) - The 16 Quantisation Levels')
print('=' * 65)
print('These are the optimal quantisation levels for normally-distributed data:')
print(f'\nLevels: {np.array2string(nf4_levels, precision=4, separator=", ")}')
# Visualise the distribution of levels
print(f'\nLevel distribution (density near zero is highest):')
for i, level in enumerate(nf4_levels):
    bar = '█' * int((level + 1) * 25)
print(f' {i:>2}: {level:>7.4f} {bar}')
# Compare NF4 vs uniform INT4 on normally-distributed weights
np.random.seed(42)
weights = np.random.randn(10000).astype(np.float32)
# NF4 quantization
w_norm = weights / np.max(np.abs(weights)) # normalise to [-1, 1]
nf4_q = np.array([nf4_levels[np.argmin(np.abs(nf4_levels - w))] for w in w_norm])
nf4_recon = nf4_q * np.max(np.abs(weights))
nf4_mse = np.mean((weights - nf4_recon)**2)
# Uniform INT4 quantization
int4_q, int4_s = quantize_int4_symmetric(weights)
int4_recon = dequantize(int4_q, int4_s)
int4_mse = np.mean((weights - int4_recon)**2)
print(f'\nQuantisation quality comparison (10K normally-distributed weights):')
print(f' NF4 MSE: {nf4_mse:.6f}')
print(f' INT4 MSE: {int4_mse:.6f}')
print(f' NF4 is {int4_mse/nf4_mse:.1f}× better for Gaussian data')
print(f' → NF4 places more levels near zero where most weights live')
Code cell 23
# ──────────────────────────────────────────────────────────────────
# 5.2 Ternary Weights - BitNet b1.58 {-1, 0, +1}
# ──────────────────────────────────────────────────────────────────
def ternarize_weights(W: np.ndarray) -> Tuple[np.ndarray, float]:
"""Quantise weights to {-1, 0, +1} using mean absolute value as threshold."""
alpha = np.mean(np.abs(W)) # scale factor
W_ternary = np.zeros_like(W, dtype=np.int8)
W_ternary[W > alpha * 0.5] = 1
W_ternary[W < -alpha * 0.5] = -1
return W_ternary, alpha
# Simulate a ternary matmul vs full-precision
np.random.seed(42)
d = 512
W = np.random.randn(d, d).astype(np.float32) * 0.02
x = np.random.randn(1, d).astype(np.float32)
# Full precision
y_full = x @ W.T
# Ternary
W_tern, alpha = ternarize_weights(W)
y_tern = x @ (W_tern.astype(np.float32) * alpha).T
# Statistics
n_zero = np.sum(W_tern == 0)
n_pos = np.sum(W_tern == 1)
n_neg = np.sum(W_tern == -1)
total = W_tern.size
relative_error = np.mean(np.abs(y_full - y_tern) / (np.abs(y_full) + 1e-10))
print('Ternary Weights - BitNet b1.58 Simulation')
print('=' * 55)
print(f'Weight matrix: {W.shape}')
print(f'Scale factor α = mean(|W|) = {alpha:.6f}')
print(f'\nWeight distribution:')
print(f' -1: {n_neg:>6} ({n_neg/total:>6.1%})')
print(f' 0: {n_zero:>6} ({n_zero/total:>6.1%})')
print(f' +1: {n_pos:>6} ({n_pos/total:>6.1%})')
print(f'\nAverage entropy per weight: {-sum(p*np.log2(p+1e-10) for p in [n_neg/total, n_zero/total, n_pos/total]):.2f} bits')
print(f' ≈ log₂(3) = {np.log2(3):.2f} bits (1.58-bit, hence the name)')
print(f'\nOutput relative error: {relative_error:.4%}')
print(f'\nMatmul cost comparison:')
print(f' FP32: {d*d} multiply-accumulate operations')
print(f' Ternary: {n_pos + n_neg} additions only (no multiplications!)')
print(f' FLOPs reduction: {(total) / (n_pos + n_neg):.1f}× (only additions, no multiplies)')
6. Floating-Point Arithmetic Deep Dive
Understanding the mechanics of FP addition and multiplication explains why catastrophic cancellation occurs and why FMA (fused multiply-add) matters.
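FMA is easy to see with a toy case: the unfused path rounds the product before the add (two roundings), while FMA rounds once at the end. A minimal sketch that emulates the fused behaviour with a float64 intermediate, since NumPy does not expose an FMA primitive directly (Python 3.13+ could substitute `math.fma`):

```python
import numpy as np

a = np.float32(1.0 + 2.0**-12)
c = np.float32(-(1.0 + 2.0**-11))
# Exactly, a*a + c = 2^-24; the FP32 product alone rounds to 1 + 2^-11 and loses that information.
two_roundings = np.float32(np.float32(a * a) + c)                          # -> 0.0
one_rounding = np.float32(np.float64(a) * np.float64(a) + np.float64(c))  # FMA-like: -> 2^-24
print(f'unfused: {two_roundings}, fused-style: {one_rounding}, exact: {2.0**-24}')
```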
Code cell 25
# ──────────────────────────────────────────────────────────────────
# 6.1 Catastrophic Cancellation
# ──────────────────────────────────────────────────────────────────
print('Catastrophic Cancellation - when subtraction destroys precision')
print('=' * 65)
# Example: computing variance with naive formula vs stable formula
# Naive: var = E[x²] - (E[x])² → catastrophic cancellation when E[x] >> std(x)
# Stable: var = E[(x - mean(x))²] → no cancellation
np.random.seed(42)
# Data with large mean, small variance (extreme case)
data = np.float32(1e6) + np.random.randn(10000).astype(np.float32) * 0.01
true_var = np.float64(np.var(data.astype(np.float64))) # ground truth in FP64
# Naive formula in FP32
mean_sq = np.float32(np.mean(data.astype(np.float32)**2))
sq_mean = np.float32(np.mean(data.astype(np.float32)))**2
naive_var = np.float32(mean_sq - sq_mean)
# Stable formula in FP32
centered = data - np.float32(np.mean(data))
stable_var = np.float32(np.mean(centered**2))
print(f'Data: {len(data)} values ~ N(1000000, 0.01²)')
print(f'True variance (FP64): {true_var:.10f}')
print(f'\nNaive var = E[x²] - E[x]²:')
print(f' E[x²] = {mean_sq:.6f}')
print(f' E[x]² = {sq_mean:.6f}')
print(f' Difference = {naive_var:.6f}') # garbage!
print(f' Relative error: {abs(naive_var - true_var) / true_var:.2%}')
print(f'\nStable var = E[(x - μ)²]:')
print(f' Result = {stable_var:.10f}')
print(f' Relative error: {abs(stable_var - true_var) / true_var:.2%}')
print(f'\n→ The naive formula subtracts two nearly-equal large numbers')
print(f'→ Most significant bits cancel, leaving only rounding noise')
print(f'→ This is exactly what happens in LayerNorm if not implemented carefully')
Code cell 26
# ──────────────────────────────────────────────────────────────────
# 6.2 Kahan Summation Algorithm
# ──────────────────────────────────────────────────────────────────
def naive_sum_fp32(values: np.ndarray) -> float:
"""Simple left-to-right summation in FP32."""
total = np.float32(0.0)
for v in values:
total = np.float32(total + np.float32(v))
return float(total)
def kahan_sum(values: np.ndarray) -> float:
"""Kahan compensated summation β tracks and corrects rounding error."""
total = np.float32(0.0)
comp = np.float32(0.0) # running compensation for lost bits
for v in values:
v = np.float32(v)
y = np.float32(v - comp) # add back last error
        temp = np.float32(total + y) # large + small → error here
comp = np.float32(np.float32(temp - total) - y) # capture what was lost
total = temp
return float(total)
def pairwise_sum(values: np.ndarray) -> float:
"""Pairwise summation (what NumPy uses internally)."""
if len(values) <= 2:
return float(np.float32(np.sum(values.astype(np.float32))))
mid = len(values) // 2
return float(np.float32(pairwise_sum(values[:mid]) + pairwise_sum(values[mid:])))
# Test: sum 1,000,000 small values
n = 1_000_000
vals = np.full(n, 1e-5, dtype=np.float32)
true_sum = n * 1e-5 # exact: 10.0
r_naive = naive_sum_fp32(vals)
r_kahan = kahan_sum(vals)
r_numpy = float(np.sum(vals))
print(f'Summation of {n:,} × 1e-5 (expected: {true_sum})')
print('=' * 55)
print(f'{"Method":<20} {"Result":>12} {"Error":>12} {"Rel Error":>12}')
print('-' * 58)
for name, result in [('Naive FP32', r_naive), ('Kahan (compensated)', r_kahan), ('NumPy (pairwise)', r_numpy)]:
err = abs(result - true_sum)
rel = err / true_sum
print(f'{name:<20} {result:>12.6f} {err:>12.6f} {rel:>12.2e}')
print(f'\n→ Kahan summation reduces error by {abs(r_naive - true_sum) / max(abs(r_kahan - true_sum), 1e-15):.0f}×')
print(f'→ This is essential for gradient accumulation across large batches')
Code cell 27
# ──────────────────────────────────────────────────────────────────
# 6.3 BF16 Dot Product Accumulation Error
# ──────────────────────────────────────────────────────────────────
if HAS_TORCH:
# Demonstrate why FP32 accumulation is critical for BF16 matmul
d = 4096 # typical transformer hidden dimension
torch.manual_seed(42)
a = torch.randn(d, dtype=torch.float32)
b = torch.randn(d, dtype=torch.float32)
# Ground truth in FP32
dot_fp32 = torch.dot(a, b)
# BF16 with BF16 accumulation (BAD)
a_bf16 = a.bfloat16()
b_bf16 = b.bfloat16()
# Simulate BF16 accumulation (Python loop)
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
for i in range(min(d, 512)): # first 512 for speed
acc_bf16 += a_bf16[i] * b_bf16[i]
# Scale up for full vector
dot_bf16_acc = float(acc_bf16) * (d / 512)
# BF16 with FP32 accumulation (GOOD β what tensor cores do)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for i in range(min(d, 512)):
product = float(a_bf16[i]) * float(b_bf16[i]) # BF16 multiply
acc_fp32 += product # FP32 accumulate
dot_bf16_fp32_acc = float(acc_fp32) * (d / 512)
print('Dot Product Accumulation Precision')
print('=' * 55)
print(f'Vector dimension: {d}')
print(f'FP32 dot product (truth): {float(dot_fp32):.4f}')
print(f'BF16 with BF16 accumulation: {dot_bf16_acc:.4f} '
f'(error: {abs(dot_bf16_acc - float(dot_fp32)) / abs(float(dot_fp32)):.1%})')
print(f'BF16 with FP32 accumulation: {dot_bf16_fp32_acc:.4f} '
f'(error: {abs(dot_bf16_fp32_acc - float(dot_fp32)) / abs(float(dot_fp32)):.1%})')
    print(f'\n→ BF16 accumulation has ~{abs(dot_bf16_acc - float(dot_fp32)) / max(abs(dot_bf16_fp32_acc - float(dot_fp32)), 1e-10):.0f}× more error than FP32 accumulation')
    print(f'→ This is why tensor cores always accumulate in FP32')
else:
print('PyTorch required for BF16 dot product demo')
7. Numerical Stability in Neural Networks
The most common sources of training crashes in LLMs are:
- Softmax overflow: attention logits too large
- Log-sum-exp overflow: cross-entropy loss computation
- LayerNorm cancellation: subtracting mean from nearly-identical values
- Gradient vanishing: underflow in FP16 during backward pass
Code cell 29
# ──────────────────────────────────────────────────────────────────
# 7.1 Numerically Stable Softmax - The Max-Subtraction Trick
# ──────────────────────────────────────────────────────────────────
def naive_softmax(z: np.ndarray) -> np.ndarray:
"""Naive softmax β overflows for large logits."""
exp_z = np.exp(z)
return exp_z / np.sum(exp_z)
def stable_softmax(z: np.ndarray) -> np.ndarray:
"""Numerically stable softmax β used in all production code."""
m = np.max(z)
exp_z = np.exp(z - m) # max exponent is e^0 = 1
return exp_z / np.sum(exp_z)
# Normal case - both work
z_normal = np.array([1.0, 2.0, 3.0, 4.0])
print('Softmax: Naive vs Stable')
print('=' * 60)
print(f'Normal logits {z_normal}:')
print(f' Naive: {naive_softmax(z_normal)}')
print(f' Stable: {stable_softmax(z_normal)}')
# Large logits - naive breaks!
z_large = np.array([88.5, 88.7, 88.3, 88.6])
print(f'\nLarge logits {z_large} (near FP32 overflow for exp):')
with warnings.catch_warnings():
warnings.simplefilter('ignore')
naive_result = naive_softmax(z_large)
print(f' Naive: {naive_result} {"← overflow!" if np.any(np.isnan(naive_result)) else ""}')
print(f' Stable: {stable_softmax(z_large)}')
# Extreme logits - definitely breaks
z_extreme = np.array([1000.0, 1001.0, 999.0])
print(f'\nExtreme logits {z_extreme}:')
with warnings.catch_warnings():
warnings.simplefilter('ignore')
naive_result = naive_softmax(z_extreme)
print(f' Naive: {naive_result} ← NaN from inf/inf!')
print(f' Stable: {stable_softmax(z_extreme)}')
# Mathematical proof
print(f'\n--- Mathematical equivalence proof ---')
print(f' softmax(z-m)_i = exp(z_i - m) / Σ exp(z_j - m)')
print(f' = exp(z_i)·e^(-m) / [e^(-m) · Σ exp(z_j)]')
print(f' = exp(z_i) / Σ exp(z_j) ← e^(-m) cancels!')
print(f' The trick is mathematically exact, numerically essential.')
Code cell 30
# ──────────────────────────────────────────────────────────────────
# 7.2 Log-Sum-Exp Trick - Essential for Cross-Entropy Loss
# ──────────────────────────────────────────────────────────────────
def naive_logsumexp(z: np.ndarray) -> float:
"""Naive: log(sum(exp(z))) β overflows."""
return float(np.log(np.sum(np.exp(z))))
def stable_logsumexp(z: np.ndarray) -> float:
"""Stable: m + log(sum(exp(z - m))) β no overflow."""
m = np.max(z)
return float(m + np.log(np.sum(np.exp(z - m))))
# Normal case
z = np.array([1.0, 2.0, 3.0])
print('Log-Sum-Exp Trick')
print('=' * 55)
print(f'z = {z}')
print(f' Naive: {naive_logsumexp(z):.6f}')
print(f' Stable: {stable_logsumexp(z):.6f}')
# Overflow case
z_big = np.array([1000.0, 1001.0, 1002.0])
print(f'\nz = {z_big}')
with warnings.catch_warnings():
warnings.simplefilter('ignore')
    print(f' Naive: {naive_logsumexp(z_big)} ← overflow!')
print(f' Stable: {stable_logsumexp(z_big):.6f}')
# Cross-entropy loss uses log-sum-exp
print(f'\n--- Cross-Entropy Loss = LSE(z) - z_target ---')
vocab_size = 50000
np.random.seed(42)
logits = np.random.randn(vocab_size).astype(np.float32) * 5 # logits with std=5
target = 42 # target token
lse = stable_logsumexp(logits)
loss = lse - logits[target]
print(f' Vocab size: {vocab_size:,}')
print(f' LSE(logits): {lse:.4f}')
print(f' logits[target]: {logits[target]:.4f}')
print(f' Cross-entropy loss: {loss:.4f}')
print(f' → This computation happens millions of times per training step')
print(f' → NEVER compute log(sum(exp(z))) directly!')
Code cell 31
# ──────────────────────────────────────────────────────────────────
# 7.3 LayerNorm vs RMSNorm - Numerical Comparison
# ──────────────────────────────────────────────────────────────────
def layernorm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
"""Standard LayerNorm: y = (x - ΞΌ) / β(ΟΒ² + Ξ΅) Β· Ξ³ + Ξ²"""
mu = np.mean(x)
var = np.mean((x - mu)**2)
return gamma * (x - mu) / np.sqrt(var + eps) + beta
def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-5) -> np.ndarray:
"""RMSNorm: y = x / β(mean(xΒ²) + Ξ΅) Β· Ξ³ (no mean subtraction, no beta)"""
rms = np.sqrt(np.mean(x**2) + eps)
return gamma * x / rms
d = 128
gamma = np.ones(d, dtype=np.float32)
beta = np.zeros(d, dtype=np.float32)
print('LayerNorm vs RMSNorm - Numerical Properties')
print('=' * 65)
# Test 1: Normal input - both work fine
np.random.seed(42)
x_normal = np.random.randn(d).astype(np.float32)
ln_out = layernorm(x_normal, gamma, beta)
rms_out = rmsnorm(x_normal, gamma)
print(f'Normal input (μ≈0, σ≈1):')
print(f' LayerNorm output mean: {np.mean(ln_out):.6f}, std: {np.std(ln_out):.6f}')
print(f' RMSNorm output mean: {np.mean(rms_out):.6f}, std: {np.std(rms_out):.6f}')
# Test 2: Large-offset input - catastrophic cancellation risk
x_offset = x_normal + 1e6 # huge mean, small variance
# FP32: both OK
ln_offset = layernorm(x_offset, gamma, beta)
rms_offset = rmsnorm(x_offset, gamma)
print(f'\nLarge-offset input (μ≈10⁶, σ≈1) in FP32:')
print(f' LayerNorm mean: {np.mean(ln_offset):.6f}, std: {np.std(ln_offset):.6f}')
print(f' RMSNorm mean: {np.mean(rms_offset):.8f}, std: {np.std(rms_offset):.8f}')
# Test 3: Simulate BF16 precision for mean computation
# BF16 has ~7 mantissa bits → relative precision 2^(-7) ≈ 0.78%
def simulate_bf16_mean(x):
"""Simulate BF16 precision loss in mean computation."""
# Each value rounded to ~2 decimal digits of precision
precision = 2**(-7) # BF16 mantissa precision
x_bf16 = np.round(x / (np.abs(x) * precision + 1e-38)) * (np.abs(x) * precision + 1e-38)
return np.mean(x_bf16)
print(f'\nWhy RMSNorm is numerically superior:')
print(f' LayerNorm computes x - μ → catastrophic cancellation when x ≈ μ')
print(f' RMSNorm computes x² → sum of positive values → no cancellation')
print(f' → RMSNorm is used in LLaMA, Mistral, Gemma, and most 2024+ architectures')
Code cell 32
# ──────────────────────────────────────────────────────────────────
# 7.4 Gradient Vanishing/Exploding - Numerical View
# ──────────────────────────────────────────────────────────────────
print('Gradient Flow Through L Layers')
print('=' * 65)
print(f'If each layer\'s Jacobian has dominant eigenvalue λ:')
print(f' Gradient magnitude ∝ λ^L\n')
print(f'{"λ":>6} {"L=10":>12} {"L=50":>12} {"L=100":>12} {"L=200":>12} {"Status":>20}')
print('-' * 80)
for lam in [1.1, 1.01, 1.001, 1.0, 0.999, 0.99, 0.9]:
vals = [lam**L for L in [10, 50, 100, 200]]
if vals[-1] > 1e38:
        status = '💥 EXPLODES (overflow)'
    elif vals[-1] < 1e-10:
        status = '📉 VANISHES (underflow)'
    elif vals[-1] < 0.01:
        status = '⚠ Very small'
    else:
        status = '✓ Stable'
print(f'{lam:>6.3f} {vals[0]:>12.4e} {vals[1]:>12.4e} {vals[2]:>12.4e} {vals[3]:>12.4e} {status:>20}')
print(f'\n--- Residual Connections Fix This ---')
print(f' Without residual: x_l = f(x_{{l-1}}) → Jacobian eigenvalue = λ_f')
print(f' With residual: x_l = x_{{l-1}} + f(x_{{l-1}}) → Jacobian eigenvalue = 1 + λ_f')
print(f' Even if λ_f is small, 1 + λ_f ≈ 1 → gradient flows stably')
# FP16 underflow demonstration
print(f'\n--- FP16 Underflow Zone ---')
fp16_min = np.finfo(np.float16).tiny # smallest positive normal
print(f' FP16 smallest positive normal: {fp16_min:.2e}')
print(f' BF16 smallest positive normal: ~1.2e-38')
print(f' Gradient of 1e-5 in FP16: {np.float16(1e-5)} → {"OK" if np.float16(1e-5) != 0 else "LOST!"} (subnormal, reduced precision)')
print(f' Gradient of 1e-6 in FP16: {np.float16(1e-6)} → {"OK" if np.float16(1e-6) != 0 else "LOST!"} (subnormal, only a few significant bits)')
print(f' → Below ~6e-5 FP16 gradients go subnormal and lose precision; below ~6e-8 they flush to zero')
print(f' → Training stalls without any error message')
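The standard mitigation for this underflow zone is loss scaling: multiply the loss (and therefore every gradient) by a large constant before anything is stored in FP16, then divide it back out in FP32 before the weight update. A minimal sketch with a hypothetical static scale of 2^16 (production frameworks adjust the scale dynamically):

```python
import numpy as np

grad = np.float32(1e-8)      # true gradient, below FP16's smallest subnormal (~6e-8)
scale = np.float32(2.0**16)  # hypothetical static loss scale

lost = np.float16(grad)                               # 0.0: the gradient silently disappears
kept = np.float32(np.float16(grad * scale)) / scale   # survives the FP16 round trip
print(f'without scaling: {lost}, with scaling: {kept:.3e}')
```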
8. Quantization Mathematics
Beyond the mechanics of quantization (Β§4), this section covers the mathematical theory:
- Signal-to-quantization-noise ratio (SQNR)
- Lloyd-Max optimal quantization
- Hadamard transform for outlier suppression
- Error propagation through multiple layers
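The 6 dB/bit rule used in the next cell follows from a one-line noise model. A sketch, assuming a uniform $b$-bit quantizer with step $\Delta$ covering $[-x_{\max}, x_{\max}]$ and uncorrelated rounding error:

$$
\sigma_q^2 = \frac{\Delta^2}{12}, \qquad \Delta = \frac{2\,x_{\max}}{2^b}
\;\Longrightarrow\;
\mathrm{SQNR} = 10\log_{10}\frac{\sigma_x^2}{\sigma_q^2}
\approx 6.02\,b + 4.77 - 20\log_{10}\frac{x_{\max}}{\sigma_x}\ \text{dB}.
$$

For a full-scale uniform signal this reduces to roughly $6.02\,b$ dB; for Gaussian data the $x_{\max}/\sigma_x$ term is larger, which is why the measured SQNR below sits a few dB lower.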
Code cell 34
# ──────────────────────────────────────────────────────────────────
# 8.1 Signal-to-Quantization-Noise Ratio (SQNR)
# ──────────────────────────────────────────────────────────────────
def compute_sqnr(x: np.ndarray, x_hat: np.ndarray) -> float:
"""Compute SQNR in dB."""
signal_power = np.mean(x**2)
noise_power = np.mean((x - x_hat)**2)
if noise_power == 0:
return float('inf')
return 10 * np.log10(signal_power / noise_power)
print('SQNR vs Bit Width - The 6 dB/bit Rule')
print('=' * 60)
np.random.seed(42)
# Uniformly distributed signal (theoretical case)
x_uniform = np.random.uniform(-1, 1, 100000).astype(np.float32)
# Normally distributed signal (realistic weights)
x_normal = np.random.randn(100000).astype(np.float32)
print(f'{"Bits":>5} {"Theory (dB)":>12} {"Uniform (dB)":>13} {"Normal (dB)":>12} {"Noise ratio":>12}')
print('-' * 60)
for bits in [1, 2, 3, 4, 6, 8, 16]:
theory = 6.02 * bits
# Quantize uniform signal
q_u, s_u = quantize_symmetric(x_uniform, bits=min(bits, 8))
if bits > 8:
# For >8 bits, simulate with scaled INT16
alpha = np.max(np.abs(x_uniform))
qmax = 2**(bits-1) - 1
scale = alpha / qmax
q_u = np.clip(np.round(x_uniform / scale), -qmax-1, qmax)
r_u = q_u * scale
else:
r_u = dequantize(q_u, s_u)
sqnr_u = compute_sqnr(x_uniform, r_u)
# Quantize normal signal
if bits <= 8:
q_n, s_n = quantize_symmetric(x_normal, bits=bits)
r_n = dequantize(q_n, s_n)
else:
alpha = np.max(np.abs(x_normal))
qmax = 2**(bits-1) - 1
scale = alpha / qmax
q_n = np.clip(np.round(x_normal / scale), -qmax-1, qmax)
r_n = q_n * scale
sqnr_n = compute_sqnr(x_normal, r_n)
noise_ratio = 10**(-sqnr_u / 10) * 100
print(f'{bits:>5} {theory:>12.1f} {sqnr_u:>13.1f} {sqnr_n:>12.1f} {noise_ratio:>11.4f}%')
print(f'\n→ Each additional bit adds ~6 dB of SQNR (= 4× less noise power)')
print(f'→ Normal distribution SQNR is slightly worse because outliers waste quantisation levels')
Code cell 35
# ──────────────────────────────────────────────────────────────────
# 8.2 Lloyd-Max Optimal Quantisation
# ──────────────────────────────────────────────────────────────────
def lloyd_max(data: np.ndarray, n_levels: int, max_iter: int = 100) -> Tuple[np.ndarray, np.ndarray]:
"""Lloyd-Max algorithm: find optimal quantisation levels for given data.
Returns (levels, boundaries)."""
# Initialise levels uniformly
levels = np.linspace(np.min(data), np.max(data), n_levels)
for iteration in range(max_iter):
# Step 1: boundaries = midpoints between adjacent levels
boundaries = (levels[:-1] + levels[1:]) / 2
# Step 2: levels = centroids of data in each bin
new_levels = np.zeros_like(levels)
all_boundaries = np.concatenate([[-np.inf], boundaries, [np.inf]])
for i in range(n_levels):
mask = (data >= all_boundaries[i]) & (data < all_boundaries[i+1])
if np.sum(mask) > 0:
new_levels[i] = np.mean(data[mask])
else:
new_levels[i] = levels[i]
if np.allclose(levels, new_levels, atol=1e-8):
break
levels = new_levels
return levels, boundaries
# Optimal quantisation for normal distribution
np.random.seed(42)
data = np.random.randn(100000).astype(np.float64)
print('Lloyd-Max Optimal Quantisation for N(0,1)')
print('=' * 65)
for bits in [2, 3, 4]:
n_levels = 2**bits
levels, boundaries = lloyd_max(data, n_levels)
# Quantise and compute MSE
indices = np.digitize(data, boundaries)
quantized = levels[indices]
mse_lm = np.mean((data - quantized)**2)
# Compare with uniform quantisation
alpha = np.max(np.abs(data))
uniform_levels = np.linspace(-alpha, alpha, n_levels)
uniform_boundaries = (uniform_levels[:-1] + uniform_levels[1:]) / 2
u_indices = np.digitize(data, uniform_boundaries)
u_quantized = uniform_levels[u_indices]
mse_uniform = np.mean((data - u_quantized)**2)
print(f'\n{bits}-bit ({n_levels} levels):')
print(f' Lloyd-Max levels: {np.array2string(levels, precision=3, separator=", ")}')
print(f' Lloyd-Max MSE: {mse_lm:.6f}')
print(f' Uniform MSE: {mse_uniform:.6f}')
    print(f' Improvement: {mse_uniform/mse_lm:.2f}× better')
print(f'\n→ At 4-bit, Lloyd-Max levels closely match the NF4 levels used in QLoRA!')
print(f'→ Non-uniform quantisation significantly outperforms uniform for bell-shaped distributions')
Code cell 36
# ──────────────────────────────────────────────────────────────────
# 8.3 Hadamard Transform for Outlier Suppression (QuIP/QuaRot)
# ──────────────────────────────────────────────────────────────────
def hadamard_matrix(n: int) -> np.ndarray:
"""Construct normalised Hadamard matrix of size n (n must be power of 2)."""
if n == 1:
return np.array([[1.0]])
H_half = hadamard_matrix(n // 2)
H = np.block([[H_half, H_half], [H_half, -H_half]]) / np.sqrt(2)
return H
# Demonstrate outlier suppression
np.random.seed(42)
d = 64 # dimension
# Create a weight vector with outliers (realistic: some channels are 10-100Γ larger)
w = np.random.randn(d).astype(np.float64) * 0.1
w[0] = 10.0 # massive outlier
w[1] = -8.0 # another outlier
w[2] = 5.0 # moderate outlier
H = hadamard_matrix(d)
# Rotate weights
w_rotated = H @ w
print('Hadamard Transform for Outlier Suppression')
print('=' * 65)
print(f'Original weights:')
print(f' Max |w|: {np.max(np.abs(w)):.4f}')
print(f' Min |w|: {np.min(np.abs(w)):.4f}')
print(f' Max/Min ratio: {np.max(np.abs(w)) / np.min(np.abs(w[w != 0])):.1f}×')
print(f' Std of |w|: {np.std(np.abs(w)):.4f}')
print(f'\nAfter Hadamard rotation (w\' = Hw):')
print(f' Max |w\'|: {np.max(np.abs(w_rotated)):.4f}')
print(f' Min |w\'|: {np.min(np.abs(w_rotated)):.4f}')
print(f' Max/Min ratio: {np.max(np.abs(w_rotated)) / np.min(np.abs(w_rotated[w_rotated != 0])):.1f}×')
print(f' Std of |w\'|: {np.std(np.abs(w_rotated)):.4f}')
# Verify orthogonality: H @ H^T = I
identity_check = H @ H.T
print(f'\nOrthogonality check: ||HH^T - I|| = {np.linalg.norm(identity_check - np.eye(d)):.2e} (should be ~0)')
# Quantise both and compare
q_orig, s_orig = quantize_int4_symmetric(w.astype(np.float32))
r_orig = dequantize(q_orig, s_orig)
mse_orig = np.mean((w.astype(np.float32) - r_orig)**2)
q_rot, s_rot = quantize_int4_symmetric(w_rotated.astype(np.float32))
r_rot = dequantize(q_rot, s_rot)
w_recon = (H.T @ r_rot).astype(np.float32) # rotate back
mse_hadamard = np.mean((w.astype(np.float32) - w_recon)**2)
print(f'\nINT4 quantisation comparison:')
print(f' Direct quantisation MSE: {mse_orig:.6f}')
print(f' Hadamard + quantise MSE: {mse_hadamard:.6f}')
print(f' Improvement: {mse_orig / mse_hadamard:.1f}×')
print(f'\n→ Hadamard "spreads" outliers across all dimensions → more uniform range → better quantisation')
Code cell 37
# ──────────────────────────────────────────────────────────────────
# 8.4 Error Propagation Through Layers
# ──────────────────────────────────────────────────────────────────
np.random.seed(42)
d = 256
n_layers = 12
# Create a simple multi-layer network: y = W_L ... W_2 W_1 x
# Each weight matrix is quantised to INT8
x = np.random.randn(d).astype(np.float32)
x = x / np.linalg.norm(x) # normalise input
# Create weight matrices (orthogonal for stability)
weights = []
for _ in range(n_layers):
W = np.random.randn(d, d).astype(np.float32)
U, _, Vt = np.linalg.svd(W, full_matrices=False)
weights.append(U @ Vt) # orthogonal matrix (preserves norms)
# Forward pass: exact vs quantised for different bit widths
print('Error Propagation Through Quantised Layers')
print('=' * 70)
print(f'{n_layers} layers, dimension {d}')
print(f'\n{"Bits":>5} {"After L=1":>12} {"After L=4":>12} {"After L=8":>12} {"After L=12":>13}')
print('-' * 60)
for bits in [2, 4, 8, 16]:
# Exact forward pass
y_exact = x.copy()
for W in weights:
y_exact = W @ y_exact
# Quantised forward pass with error tracking
y_quant = x.copy()
errors = []
y_ex = x.copy()
for i, W in enumerate(weights):
# Quantise weight matrix
if bits <= 8:
W_flat_q, W_s = quantize_symmetric(W.flatten(), bits=bits)
W_deq = dequantize(W_flat_q, W_s).reshape(W.shape)
else:
alpha = np.max(np.abs(W))
qmax = 2**(bits-1) - 1
scale = alpha / qmax
W_q = np.clip(np.round(W / scale), -qmax-1, qmax)
W_deq = W_q * scale
y_quant = W_deq @ y_quant
y_ex = W @ y_ex
if (i+1) in [1, 4, 8, 12]:
rel_error = np.linalg.norm(y_quant - y_ex) / (np.linalg.norm(y_ex) + 1e-10)
errors.append(rel_error)
error_strs = [f'{e:>12.4e}' for e in errors]
print(f'{bits:>5} {" ".join(error_strs)}')
print(f'\n→ Error grows roughly linearly with number of layers (first-order approximation)')
print(f'→ INT4 after 12 layers: significant error → keep first/last layers in higher precision')
print(f'→ INT8 error stays manageable even through 12 layers')
9. Mixed Precision Training: Complete Pipeline
The standard recipe for all large-scale LLM training since 2020:
- FP32 master weights (never lost, never quantised)
- BF16 forward/backward (fast, good range)
- FP32 optimizer states (Adam m and v need high precision)
- FP32 loss (cross-entropy needs log/exp precision)
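In day-to-day code this recipe is not hand-rolled; PyTorch's AMP utilities wrap it. A minimal sketch of the usual loop (the `loader` is assumed to exist; with BF16 the `GradScaler` is often unnecessary, but it is what the FP16 path relies on):

```python
import torch
from torch import nn

model = nn.Linear(4096, 4096).cuda()          # parameters stay in FP32: the master weights
opt = torch.optim.AdamW(model.parameters())   # so Adam's m and v states stay FP32 too
scaler = torch.cuda.amp.GradScaler()          # loss scaling, mainly needed for FP16

for x, y in loader:                           # `loader` assumed defined elsewhere
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = nn.functional.mse_loss(model(x), y)   # forward matmuls run in BF16
    scaler.scale(loss).backward()             # backward; grads land in FP32 on the FP32 params
    scaler.step(opt)                          # unscales the grads, then FP32 optimizer update
    scaler.update()
```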
Code cell 39
# ──────────────────────────────────────────────────────────────────
# 9.1 Simulating Mixed Precision Training
# ──────────────────────────────────────────────────────────────────
if HAS_TORCH:
def simulate_training(model, dtype_forward, use_fp32_master=True, steps=200, lr=0.01):
"""Simulate training with different precision configurations."""
torch.manual_seed(42)
# Simple target function: y = sin(x)
x_train = torch.linspace(-3, 3, 100).unsqueeze(1)
y_train = torch.sin(x_train)
if use_fp32_master:
master_params = [p.clone().float() for p in model.parameters()]
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
losses = []
for step in range(steps):
if use_fp32_master:
                # Copy FP32 master → working precision
for p, mp in zip(model.parameters(), master_params):
p.data = mp.data.to(dtype_forward)
# Forward pass in specified precision
x_in = x_train.to(dtype_forward)
y_pred = model(x_in)
loss = F.mse_loss(y_pred, y_train.to(dtype_forward))
losses.append(loss.float().item())
# Backward
optimizer.zero_grad()
loss.backward()
if use_fp32_master:
# Update FP32 master weights with FP32 gradients
for mp, p in zip(master_params, model.parameters()):
mp.data -= lr * p.grad.float()
else:
optimizer.step()
return losses
# Compare different precision configurations
configs = [
('FP32 (baseline)', torch.float32, True),
('BF16 + FP32 master', torch.bfloat16, True),
('BF16 without master', torch.bfloat16, False),
('FP16 + FP32 master', torch.float16, True),
]
print('Mixed Precision Training Comparison')
print('=' * 70)
print(f'{"Config":<25} {"Final Loss":>12} {"Min Loss":>12} {"Converged?":>12}')
print('-' * 65)
for name, dtype, use_master in configs:
model = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
losses = simulate_training(model, dtype, use_master, steps=200, lr=0.005)
final = losses[-1]
min_loss = min(losses)
converged = final < 0.1
        print(f'{name:<25} {final:>12.6f} {min_loss:>12.6f} {"✓" if converged else "✗":>12}')
    print(f'\n→ BF16 + FP32 master matches FP32 baseline')
    print(f'→ BF16 without FP32 master may diverge or stall on longer training')
    print(f'→ FP16 + FP32 master also works (with loss scaling for larger models)')
else:
print('PyTorch required for mixed precision training demo')
Code cell 40
# ──────────────────────────────────────────────────────────────────
# 9.2 Memory Budget Calculator
# ──────────────────────────────────────────────────────────────────
def memory_budget(n_params: float, mode: str = 'train_bf16') -> dict:
"""Calculate memory requirements for different configurations."""
configs = {
'train_fp32': {
'weights': 4, # FP32
'adam_m': 4, # FP32
'adam_v': 4, # FP32
'gradients': 4, # FP32
'activations': 4, # FP32 (estimate: ~1Γ params)
},
'train_bf16': {
'weights_master': 4, # FP32 master
'weights_working': 2, # BF16
'adam_m': 4, # FP32
'adam_v': 4, # FP32
'gradients': 2, # BF16
'activations': 2, # BF16
},
'qlora_nf4': {
'base_weights': 0.5, # NF4 (4 bits)
'lora_weights': 0.01, # BF16 but << 1% of params
'adam_m': 0.01, # only for LoRA params
'adam_v': 0.01, # only for LoRA params
'gradients': 0.01,
'activations': 2, # BF16
},
'inference_bf16': {'weights': 2},
'inference_int8': {'weights': 1},
'inference_int4': {'weights': 0.5},
}
config = configs[mode]
total_bytes = sum(v * n_params for v in config.values())
return {
'components': {k: v * n_params / 1e9 for k, v in config.items()},
'total_gb': total_bytes / 1e9
}
# Calculate for common model sizes
print('Memory Budget Calculator')
print('=' * 80)
for model_name, n_params in [('7B', 7e9), ('13B', 13e9), ('70B', 70e9), ('405B', 405e9)]:
print(f'\n--- {model_name} Parameters ---')
print(f'{"Mode":<22} {"Total (GB)":>10} {"Fits in 80GB?":>14} {"GPUs needed":>12}')
print('-' * 62)
for mode in ['train_fp32', 'train_bf16', 'qlora_nf4', 'inference_bf16', 'inference_int8', 'inference_int4']:
budget = memory_budget(n_params, mode)
total = budget['total_gb']
fits = total <= 80
gpus = max(1, int(np.ceil(total / 80)))
print(f'{mode:<22} {total:>10.1f} {"β" if fits else "β":>14} {gpus:>12}')
print(f'\n→ QLoRA shrinks the frozen base weights to NF4; with gradient checkpointing to keep activations small, 70B fine-tuning fits on a single 80GB GPU')
print(f'→ INT4 inference enables 70B serving on a single GPU')
print(f'→ Full FP32 training of 70B requires 14+ GPUs just for weights, gradients, and optimizer state')
10. Hardware & Memory Analysis
For LLM inference, memory bandwidth (not compute) is the bottleneck. Smaller number formats directly reduce memory traffic → direct speedup.
Code cell 42
# ──────────────────────────────────────────────────────────────────
# 10.1 Arithmetic Intensity Analysis
# ──────────────────────────────────────────────────────────────────
print('Arithmetic Intensity - Why Quantisation Speeds Up Inference')
print('=' * 70)
# For autoregressive generation (batch_size=1), each token requires:
# - Loading full weight matrix: d² × bytes_per_weight
# - Computing matrix-vector multiply: 2 × d² FLOPs
# Arithmetic intensity = FLOPs / bytes = 2 / bytes_per_weight
# H100 specs
h100_bandwidth = 3.35e12 # bytes/sec (HBM3)
h100_fp16_flops = 989.5e12 # TFLOPS
h100_int8_ops = 1979e12 # TOPS
# Compute roofline
print(f'H100 SXM5 Specifications:')
print(f' HBM3 Bandwidth: {h100_bandwidth/1e12:.2f} TB/s')
print(f' FP16 Compute: {h100_fp16_flops/1e12:.1f} TFLOPS')
print(f' INT8 Compute: {h100_int8_ops/1e12:.0f} TOPS')
ridge_point = h100_fp16_flops / h100_bandwidth
print(f' Ridge point: {ridge_point:.1f} FLOPs/byte')
print(f'\n{"Format":>10} {"Bytes":>6} {"Arith Intensity":>16} {"Bound":>10} {"Theoretical speedup":>20}')
print('-' * 68)
for name, bpw in [('FP32', 4), ('BF16', 2), ('INT8', 1), ('INT4', 0.5), ('INT2', 0.25)]:
ai = 2 / bpw
bound = 'Compute' if ai >= ridge_point else 'Memory'
# Speedup relative to FP32 (memory-bound)
speedup = 4 / bpw # linear with compression ratio when memory-bound
    print(f'{name:>10} {bpw:>6.1f} {ai:>16.1f} {bound:>10} {speedup:>18.0f}×')
print(f'\n→ For single-token generation, ALL formats are memory-bound on H100')
print(f'→ This means reducing weight precision from BF16→INT4 gives ~4× real speedup')
print(f'→ The speedup comes from less data to load, not faster compute')
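A corollary: in the memory-bound regime, the decode-speed ceiling is just bandwidth divided by the bytes of weights streamed per token. A back-of-envelope sketch (batch size 1, ignoring the KV cache and all overheads):

```python
bandwidth = 3.35e12   # H100 HBM3 bytes/s, from above
n_params = 70e9       # 70B-parameter model
for name, bytes_per_weight in [('BF16', 2), ('INT8', 1), ('INT4', 0.5)]:
    tokens_per_s = bandwidth / (n_params * bytes_per_weight)  # each token streams every weight once
    print(f'{name}: ~{tokens_per_s:.0f} tokens/s upper bound')
```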
Code cell 43
# ──────────────────────────────────────────────────────────────────
# 10.2 Energy Cost Analysis
# ──────────────────────────────────────────────────────────────────
# Approximate energy per operation (7nm process, picojoules)
energy_pj = {
'INT1 XNOR': 0.02,
'INT4 MAC': 0.05,
'INT8 MAC': 0.2,
'FP8 FMA': 0.4,
'BF16 FMA': 0.8,
'FP16 FMA': 1.0,
'FP32 FMA': 3.7,
'DRAM read (64B)': 12.5,
}
print('Energy Cost per Operation (7nm process)')
print('=' * 55)
print(f'{"Operation":<18} {"Energy (pJ)":>12} {"Relative to INT8":>18}')
print('-' * 50)
ref = energy_pj['INT8 MAC']
for op, pj in energy_pj.items():
print(f'{op:<18} {pj:>12.2f} {pj/ref:>18.1f}×')
print(f'\nKey insight: a DRAM access costs {energy_pj["DRAM read (64B)"]/energy_pj["INT8 MAC"]:.0f}× more '
f'energy than an INT8 MAC!')
print(f'→ Most inference energy is spent moving data, not computing')
print(f'→ Smaller formats reduce BOTH compute AND memory energy')
# Estimate energy for one forward pass of a 70B model
n_params = 70e9
n_ops = 2 * n_params # ~2 FLOPs per parameter for a forward pass
print(f'\nEstimated energy for one 70B forward pass (single token):')
print(f'{"Format":<10} {"Compute (J)":>12} {"Memory (J)":>12} {"Total (J)":>12} {"Relative":>10}')
print('-' * 60)
for name, compute_pj, bpw in [
('FP32', 3.7, 4), ('BF16', 0.8, 2), ('INT8', 0.2, 1), ('INT4', 0.05, 0.5)]:
compute_j = n_ops * compute_pj * 1e-12
memory_bytes = n_params * bpw
memory_j = (memory_bytes / 64) * energy_pj['DRAM read (64B)'] * 1e-12
total_j = compute_j + memory_j
print(f'{name:<10} {compute_j:>12.1f} {memory_j:>12.1f} {total_j:>12.1f} {total_j/(2*70e9*3.7e-12 + (70e9*4/64)*12.5e-12):>10.2f}×')
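One caveat on the single-token estimate above: the DRAM weight read is paid once per forward pass, so serving with larger batches amortises it across tokens while the compute energy grows with the batch. The sketch below applies the same crude per-op figures to a 70B INT8 model at a few illustrative batch sizes (the batch sizes themselves are arbitrary choices, not from the notebook).

# --- Supplementary sketch: batching amortises the weight-read energy ---
print(f'\nEnergy per token for a 70B model in INT8 (same crude model as above):')
print(f'{"Batch":>6} {"Compute (J/tok)":>16} {"Memory (J/tok)":>15} {"Total (J/tok)":>14}')
weight_bytes = 70e9 * 1                                               # INT8 weights
mem_j = (weight_bytes / 64) * energy_pj['DRAM read (64B)'] * 1e-12    # per forward pass
comp_j_per_tok = 2 * 70e9 * energy_pj['INT8 MAC'] * 1e-12             # per token
for batch in [1, 8, 64]:
    total = comp_j_per_tok + mem_j / batch
    print(f'{batch:>6} {comp_j_per_tok:>16.4f} {mem_j/batch:>15.4f} {total:>14.4f}')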
11. Training Stability & Precision
Precision failures manifest as training failures: loss spikes, NaN crashes, or silent divergence. Understanding the numerical mechanisms enables prevention.
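A cheap preventive habit before looking at specific mechanisms: check tensors for NaN/Inf as close to the source as possible, so a blow-up is caught at the layer that produced it rather than thousands of steps later in the loss. Below is a minimal sketch of such a check; the layer size and the injected Inf are arbitrary illustrations, not taken from any model in this notebook.

# --- Supplementary sketch: catch NaN/Inf where they first appear ---
if HAS_TORCH:
    def check_finite(name: str, t: torch.Tensor) -> bool:
        """Return True if t is fully finite; otherwise print a short diagnostic."""
        if torch.isfinite(t).all():
            return True
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        print(f'  !! {name}: {n_nan} NaN, {n_inf} Inf values detected')
        return False

    layer = nn.Linear(16, 16)
    x_clean = torch.randn(4, 16)
    x_bad = x_clean.clone()
    x_bad[0, 0] = float('inf')   # inject a corrupted activation
    for label, inp in [('clean input', x_clean), ('corrupted input', x_bad)]:
        ok = check_finite(f'output ({label})', layer(inp))
        print(f'{label}: finite = {ok}')
else:
    print('PyTorch required for the finiteness-check sketch')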
Code cell 45
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 11.1 Stochastic Rounding: Unbiased Low-Precision Updates
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def round_to_nearest(x: float, resolution: float) -> float:
"""Deterministic round-to-nearest."""
return round(x / resolution) * resolution
def stochastic_round(x: float, resolution: float) -> float:
"""Stochastic rounding: E[SR(x)] = x."""
lower = np.floor(x / resolution) * resolution
upper = lower + resolution
prob_up = (x - lower) / resolution
return upper if np.random.random() < prob_up else lower
# Simulate gradient accumulation over many steps
resolution = 0.0078125 # BF16 ULP near 1.0 (= 2^(-7))
gradient = 0.001 # small gradient (< resolution/2)
n_steps = 10000
# Deterministic accumulation
weight_det = 1.0
for _ in range(n_steps):
weight_det = round_to_nearest(weight_det + gradient, resolution)
# Stochastic accumulation
np.random.seed(42)
weight_stoch = 1.0
for _ in range(n_steps):
weight_stoch = stochastic_round(weight_stoch + gradient, resolution)
expected = 1.0 + n_steps * gradient
print('Stochastic Rounding vs Deterministic (RNE)')
print('=' * 60)
print(f'BF16 resolution (ULP near 1.0): {resolution}')
print(f'Gradient per step: {gradient}')
print(f'Gradient / resolution = {gradient/resolution:.3f} (< 0.5 → always rounds to the same value!)')
print(f'Number of steps: {n_steps:,}')
print(f'\nExpected final weight: {expected:.1f}')
print(f'Deterministic (RNE): {weight_det:.4f} (error: {abs(weight_det - expected):.4f})')
print(f'Stochastic rounding: {weight_stoch:.4f} (error: {abs(weight_stoch - expected):.4f})')
print(f'\n→ Deterministic: the gradient is ALWAYS rounded away, so the weight NEVER updates!')
print(f'→ Stochastic: the gradient contributes probabilistically and is correct on average')
print(f'\n→ This is why FP8 training on H100 uses stochastic rounding: without it,')
print(f' many gradient updates would be silently lost.')
# Statistical verification: E[SR(x)] = x
print(f'\n--- Bias verification ---')
np.random.seed(0)
x_test = 1.003 # between two BF16-representable values
sr_results = [stochastic_round(x_test, resolution) for _ in range(100000)]
print(f' x = {x_test}')
print(f' E[SR(x)] over 100K trials: {np.mean(sr_results):.6f} (should be ≈ {x_test})')
print(f' RNE(x) = {round_to_nearest(x_test, resolution):.6f} (biased!)')
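The scalar helper above makes the mechanism clear; in practice stochastic rounding is applied to whole tensors at once. Below is a vectorised NumPy version of the same idea. Note the assumption: it rounds onto a uniform grid with spacing `resolution`, matching the simplified demo above rather than real BF16 rounding, whose grid spacing changes with the exponent.

# --- Supplementary sketch: vectorised stochastic rounding ---
def stochastic_round_array(x: np.ndarray, resolution: float,
                           rng: np.random.Generator) -> np.ndarray:
    """Vectorised stochastic rounding onto a uniform grid of the given spacing."""
    lower = np.floor(x / resolution) * resolution
    prob_up = (x - lower) / resolution            # distance to the lower grid point
    round_up = rng.random(x.shape) < prob_up      # round up with that probability
    return lower + round_up * resolution

rng = np.random.default_rng(0)
grads = np.full(10000, 0.001)                     # same sub-ULP gradient as above
sr = stochastic_round_array(grads, resolution, rng)
print(f'Mean of SR gradients: {sr.mean():.6f} (target 0.001)')
print(f'Fraction rounded up:  {(sr > 0).mean():.3f} (target {0.001/resolution:.3f})')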
Code cell 46
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 11.2 Adam Optimizer Numerical Edge Cases
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print('Adam Optimizer: Numerical Failure Modes')
print('=' * 65)
# Failure mode 1: epsilon too small
print('\n--- Failure Mode 1: ε too small ---')
eta = 1e-4 # learning rate
m_t = 1e-4 # momentum (typical small gradient)
for eps in [1e-8, 1e-6, 1e-4]:
for v_t in [1e-8, 1e-4, 1.0]:
update = eta * m_t / (np.sqrt(v_t) + eps)
print(f' ε={eps:.0e}, v_t={v_t:.0e} → update = {update:.6f}'
f'{" ← HUGE!" if update > 0.01 else ""}')
print(f'\n When v_t ≈ 0 and ε = 1e-8: update ≈ η·m_t/ε = {eta * m_t / 1e-8:.0f}')
print(f' → This can cause catastrophic weight jumps!')
print(f' → Use ε = 1e-6 or 1e-4 for BF16 training')
# Failure mode 2: bias correction amplification
print(f'\n--- Failure Mode 2: Bias correction early in training ---')
beta2 = 0.999
print(f' Bias correction factor 1/(1 - β₂^t) for β₂ = {beta2}:')
for t in [1, 5, 10, 50, 100, 1000, 10000]:
correction = 1 / (1 - beta2**t)
print(f' t={t:>5}: 1/(1-β₂^t) = {correction:>10.1f}×'
f'{" ← 1000× amplification!" if correction > 100 else ""}')
print(f'\n → At t=1, the bias correction amplifies v_t by 1000×! This can cause overflow in low precision.')
print(f' → Learning rate warmup helps by keeping η small during the high-amplification phase.')
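One concrete precision mechanism that produces the v_t ≈ 0 case from Failure Mode 1: if the second-moment state is stored in FP16, squaring a typical small gradient underflows to exactly zero (FP16's smallest subnormal is about 6e-8). The quick check below uses NumPy's float16 as a stand-in; NumPy has no BF16 dtype, but BF16 keeps FP32's exponent range and would not underflow here.

# --- Supplementary sketch: why v_t can underflow to exactly zero in FP16 ---
for g in [1e-2, 1e-3, 1e-4]:
    g_sq_fp32 = np.float32(g) ** 2
    g_sq_fp16 = np.float16(g) ** 2
    print(f'g = {g:.0e}:  g^2 in FP32 = {g_sq_fp32:.3e},  in FP16 = {float(g_sq_fp16):.3e}'
          f'{"   <- underflows to 0!" if g_sq_fp16 == 0 else ""}')
print('\nOnce v_t is exactly 0, the update becomes η·m_t/ε, exactly as in Failure Mode 1 above.')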
Code cell 47
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 11.3 Attention Logit Growth: The Slow Training Killer
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if HAS_TORCH:
d_k = 128 # head dimension
print('Attention Logit Growth Simulation')
print('=' * 65)
print(f'Head dimension d_k = {d_k}')
print(f'Standard scaling: QK^T / √d_k = QK^T / {np.sqrt(d_k):.1f}')
# Simulate Q, K with increasing norms (as happens during training)
print(f'\n{"||Q||":>8} {"Max logit":>10} {"Softmax max":>13} {"Grad magnitude":>15} {"Risk":>12}')
print('-' * 65)
for qk_scale in [1.0, 2.0, 5.0, 10.0, 15.0, 20.0, 25.0]:
torch.manual_seed(42)
seq_len = 64
Q = torch.randn(1, seq_len, d_k) * qk_scale
K = torch.randn(1, seq_len, d_k) * qk_scale
# Attention logits
logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
max_logit = logits.max().item()
# Softmax
attn = F.softmax(logits, dim=-1)
max_attn = attn.max().item()
# Gradient magnitude (proxy: entropy of attention distribution)
entropy = -(attn * torch.log(attn + 1e-10)).sum(-1).mean().item()
grad_proxy = entropy / np.log(seq_len) # normalised entropy
# Risk assessment
if max_logit > 88:
risk = 'OVERFLOW'
elif max_attn > 0.999:
risk = 'Near one-hot'
elif max_attn > 0.99:
risk = 'Spiky'
else:
risk = 'Safe'
print(f'{qk_scale:>8.1f} {max_logit:>10.1f} {max_attn:>13.6f} {grad_proxy:>15.6f} {risk:>12}')
print(f'\n→ As Q, K norms grow during training, attention becomes "spiky"')
print(f'→ Spiky attention → vanishing gradients → learning stops for that head')
print(f'→ Eventually max logit > 88 → exp() overflow → NaN → training crash')
print(f'\nMitigations:')
print(f' 1. QK-Norm: normalise Q and K before computing attention')
print(f' 2. Logit capping: clamp logits to [-50, 50] before softmax')
print(f' 3. Gradient clipping: clip global norm to prevent explosive updates')
else:
print('PyTorch required for attention logit demo')
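Mitigation 1 is easy to sketch: L2-normalise each query and key vector before the dot product, so the logits become bounded cosine similarities times a fixed scale no matter how large the projections grow. The block below reuses the worst case from the table above (scale 25); the logit scale of 10 is an arbitrary illustrative choice, and real QK-Norm implementations typically use an RMSNorm/LayerNorm with a learnable scale instead of plain normalisation.

# --- Supplementary sketch: QK-Norm bounds the attention logits ---
if HAS_TORCH:
    torch.manual_seed(42)
    seq_len, d_k, qk_scale = 64, 128, 25.0
    Q = torch.randn(1, seq_len, d_k) * qk_scale
    K = torch.randn(1, seq_len, d_k) * qk_scale

    # Plain scaled dot-product logits (can overflow as norms grow)
    plain = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)

    # QK-Norm: normalise each query/key vector, then apply a bounded scale
    Qn = F.normalize(Q, dim=-1)
    Kn = F.normalize(K, dim=-1)
    qknorm = torch.matmul(Qn, Kn.transpose(-2, -1)) * 10.0   # logits now in [-10, 10]

    print(f'Max |logit| without QK-Norm: {plain.abs().max().item():8.1f}')
    print(f'Max |logit| with    QK-Norm: {qknorm.abs().max().item():8.1f}')
else:
    print('PyTorch required for the QK-Norm sketch')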
12. Practical Guide & Real-World Applications
Code cell 49
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 12.1 Real PyTorch Quantisation Demo
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if HAS_TORCH:
print('PyTorch Quantisation: Production Code')
print('=' * 55)
# Build a small model
model = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
# Size before quantisation
orig_size = sum(p.numel() * p.element_size() for p in model.parameters())
# Dynamic quantisation (real PyTorch API)
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear},
dtype=torch.qint8
)
# Compare outputs
x = torch.randn(1, 512)
with torch.no_grad():
y_orig = model(x)
y_quant = quantized_model(x)
diff = (y_orig - y_quant).abs().mean().item()
print(f'Original model: {orig_size:,} bytes')
print(f'Output difference (FP32 vs INT8): {diff:.6f}')
# Show quantised model structure
print(f'\nQuantised model layers:')
for name, module in quantized_model.named_modules():
if name:
print(f' {name}: {module.__class__.__name__}')
# Show production LLM quantisation code
print(f'\n--- LLM Quantisation in Practice ---')
print('''
# Method 1: bitsandbytes (easiest)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3-8b",
load_in_4bit=True, # ← NF4 quantisation
bnb_4bit_compute_dtype=torch.bfloat16 # ← BF16 compute
)
# Method 2: GPTQ (highest quality INT4)
model = AutoModelForCausalLM.from_pretrained(
"TheBloke/Llama-3-8B-GPTQ",
device_map="auto"
)
# Method 3: AWQ (fastest INT4 inference)
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized(
"TheBloke/Llama-3-8B-AWQ"
)
# Method 4: QLoRA fine-tuning
from peft import prepare_model_for_kbit_training, LoraConfig
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
''')
else:
print('PyTorch required for quantisation demo')
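A small note on Method 1 above: recent transformers releases prefer passing a BitsAndBytesConfig object instead of the bare load_in_4bit kwarg. The snippet below shows that style; it is illustrative only (it needs the bitsandbytes package, a supported GPU, and access to the weights), and the model id is just an example.

# --- Supplementary sketch: BitsAndBytesConfig style (illustrative, not executed here) ---
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 levels from Section 5
    bnb_4bit_use_double_quant=True,         # quantise the quantisation scales too
    bnb_4bit_compute_dtype=torch.bfloat16,  # BF16 matmuls
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",           # example model id
    quantization_config=bnb_config,
    device_map="auto",
)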
Code cell 50
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 12.2 KV Cache Memory Calculator
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def kv_cache_memory(n_layers: int, n_kv_heads: int, d_head: int,
seq_len: int, batch_size: int, bytes_per_element: float) -> float:
"""Calculate KV cache memory in GB."""
# 2 (K and V) × layers × heads × head_dim × seq_len × batch × bytes
return 2 * n_layers * n_kv_heads * d_head * seq_len * batch_size * bytes_per_element / 1e9
print('KV Cache Memory Analysis')
print('=' * 70)
# Common LLM configurations
models = {
'LLaMA-3 8B': {'n_layers': 32, 'n_kv_heads': 8, 'd_head': 128},
'LLaMA-3 70B': {'n_layers': 80, 'n_kv_heads': 8, 'd_head': 128},
'GPT-4 (est)': {'n_layers': 120, 'n_kv_heads': 16, 'd_head': 128},
}
for model_name, config in models.items():
print(f'\n--- {model_name} ---')
print(f' {config["n_layers"]}L × {config["n_kv_heads"]} KV heads × {config["d_head"]}d')
print(f' {"Seq Len":>10} {"BF16 (GB)":>10} {"INT8 (GB)":>10} {"INT4 (GB)":>10}')
print(f' {"":>10} {"-"*10} {"-"*10} {"-"*10}')
for seq_len in [2048, 8192, 32768, 131072]:
bf16 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=2)
int8 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=1)
int4 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=0.5)
print(f' {seq_len:>10,} {bf16:>10.2f} {int8:>10.2f} {int4:>10.2f}')
print(f'\n→ 70B model with 128K context: the KV cache alone is ~43 GB in BF16!')
print(f'→ INT4 KV cache reduces this to ~10.7 GB, critical for long-context serving')
print(f'→ Per-channel INT8 KV quantisation is nearly lossless and should always be used')
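To back up that last claim, here is a small NumPy sketch comparing per-tensor and per-channel symmetric INT8 quantisation on a fake key cache: random Gaussian values with one artificially boosted channel, a stand-in for the channel-wise outliers real KV caches exhibit.

# --- Supplementary sketch: per-tensor vs per-channel INT8 on a toy KV cache ---
rng = np.random.default_rng(42)
K_cache = rng.normal(0, 1, size=(1024, 128)).astype(np.float32)   # (tokens, head_dim)
K_cache[:, 7] *= 20.0                                              # one outlier channel

def int8_error(x: np.ndarray, axis=None) -> float:
    """Relative L2 error of symmetric INT8 quantisation with scales over `axis`."""
    scale = np.abs(x).max(axis=axis, keepdims=True) / 127.0
    x_hat = np.clip(np.round(x / scale), -127, 127) * scale
    return float(np.linalg.norm(x - x_hat) / np.linalg.norm(x))

print(f'Per-tensor  INT8 relative error: {int8_error(K_cache):.5f}')
print(f'Per-channel INT8 relative error: {int8_error(K_cache, axis=0):.5f}')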
Code cell 51
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# 12.3 Per-Layer Quantisation Sensitivity Analysis
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if HAS_TORCH:
# Build a mini transformer-like model to test per-layer sensitivity
torch.manual_seed(42)
class MiniTransformerBlock(nn.Module):
def __init__(self, d=256):
super().__init__()
self.q_proj = nn.Linear(d, d)
self.k_proj = nn.Linear(d, d)
self.v_proj = nn.Linear(d, d)
self.out_proj = nn.Linear(d, d)
self.ffn_gate = nn.Linear(d, d * 4)
self.ffn_down = nn.Linear(d * 4, d)
self.norm = nn.LayerNorm(d)
def forward(self, x):
# Simplified attention (no actual attention pattern)
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
attn_out = self.out_proj(v) # simplified
x = self.norm(x + attn_out)
# FFN
ffn = F.silu(self.ffn_gate(x))
ffn = self.ffn_down(ffn)
return self.norm(x + ffn)
model = MiniTransformerBlock(256)
x = torch.randn(1, 32, 256)
with torch.no_grad():
y_baseline = model(x)
# Test sensitivity: quantise each layer individually
print('Per-Layer Quantisation Sensitivity (INT4)')
print('=' * 60)
print(f'{"Layer":<15} {"Output Ξ (L2)":>14} {"Relative Ξ":>12} {"Sensitivity":>12}')
print('-' * 55)
for layer_name in ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'ffn_gate', 'ffn_down']:
# Clone model and quantise only this layer
test_model = MiniTransformerBlock(256)
test_model.load_state_dict(model.state_dict())
layer = getattr(test_model, layer_name)
with torch.no_grad():
w = layer.weight.data.numpy()
q, s = quantize_int4_symmetric(w.flatten())
w_deq = dequantize(q, s).reshape(w.shape)
layer.weight.data = torch.tensor(w_deq, dtype=torch.float32)  # keep FP32 to match the model
with torch.no_grad():
y_quant = test_model(x)
delta = torch.norm(y_baseline - y_quant).item()
relative = delta / torch.norm(y_baseline).item()
sensitivity = 'HIGH' if relative > 0.1 else ('MED' if relative > 0.01 else 'LOW')
print(f'{layer_name:<15} {delta:>14.6f} {relative:>12.4%} {sensitivity:>12}')
print(f'\n→ ffn_down is the most sensitive layer (it writes directly into the residual stream)')
print(f'→ Strategy: keep ffn_down in INT8, quantise the others to INT4')
else:
print('PyTorch required for per-layer sensitivity demo')
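Following the strategy the cell prints, here is a sketch of that mixed assignment: ffn_down kept at INT8, everything else at INT4. It reuses model, x, y_baseline and the quantize_int4_symmetric / dequantize helpers defined earlier in the notebook, and adds a tiny INT8 helper only for this sketch, so the exact error you see depends on those earlier definitions.

# --- Supplementary sketch: mixed INT4/INT8 assignment across all layers ---
if HAS_TORCH:
    def quantize_int8_symmetric_np(w: np.ndarray):
        """Minimal symmetric INT8 quantiser used only in this sketch."""
        scale = np.abs(w).max() / 127.0
        return np.clip(np.round(w / scale), -127, 127), scale

    mixed = MiniTransformerBlock(256)
    mixed.load_state_dict(model.state_dict())
    with torch.no_grad():
        for name in ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'ffn_gate', 'ffn_down']:
            layer = getattr(mixed, name)
            w = layer.weight.data.numpy()
            if name == 'ffn_down':
                # keep the most sensitive layer at INT8
                q, s = quantize_int8_symmetric_np(w.flatten())
                w_deq = (q * s).reshape(w.shape)
            else:
                # everything else at INT4, reusing the notebook's helpers
                q, s = quantize_int4_symmetric(w.flatten())
                w_deq = dequantize(q, s).reshape(w.shape)
            layer.weight.data = torch.tensor(w_deq, dtype=torch.float32)
        y_mixed = mixed(x)
    rel = (torch.norm(y_baseline - y_mixed) / torch.norm(y_baseline)).item()
    print(f'All layers quantised (INT4 everywhere, INT8 for ffn_down): relative error = {rel:.4%}')
else:
    print('PyTorch required for the mixed-precision assignment sketch')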
Summary: Key Takeaways
| Concept | Why It Matters | Action |
|---|---|---|
| IEEE 754 layout | Debug NaN, precision bugs | Inspect bits with struct.pack |
| Machine epsilon | Minimum useful learning rate | Check np.finfo() for your dtype |
| BF16 > FP16 | Same range as FP32, no overflow | Always use BF16 for training |
| Softmax stability | Prevents NaN in every attention layer | Always subtract max first |
| Log-sum-exp | Cross-entropy loss without overflow | Use torch.logsumexp() |
| RMSNorm > LayerNorm | No catastrophic cancellation | Preferred in 2024+ architectures |
| Per-group quantisation | Much better than per-tensor | Use group_size=64 or 128 |
| NF4 (QLoRA) | Optimal for Gaussian-distributed weights | load_in_4bit=True |
| SQNR: 6 dB/bit | Each bit doubles precision | Choose bit width by quality need |
| Hadamard rotation | Suppresses outliers for better quant | Used in QuIP, QuaRot |
| FP32 master weights | Prevent training divergence | NEVER skip this |
| Stochastic rounding | Enables sub-8-bit training | Used in FP8 on H100 |
| Memory bandwidth | The true inference bottleneck | Smaller = faster (linear!) |
| Error propagation | Quantisation error grows through layers | Keep first/last layers higher precision |
| KV cache quantisation | Essential for long-context serving | INT8 per-channel is nearly lossless |
Best Practices
- Training: BF16 forward/backward + FP32 master weights + FP32 Adam states
- Inference (balanced): INT8 weights with SmoothQuant, nearly lossless
- Inference (compressed): INT4 with GPTQ or AWQ + per-group quantisation
- Fine-tuning (budget): QLoRA with NF4 base + BF16 LoRA adapters
- Long context: INT8 KV cache per-channel quantisation
- Never compare floats with ==; use np.isclose() or torch.allclose()
- Never compute log(sum(exp(x))) directly; use torch.logsumexp() (see the short sketch below)
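A quick sketch of those last two rules in action (the inputs are chosen only to trigger the failure modes):

# --- Supplementary sketch: float comparison and log-sum-exp pitfalls ---
a = 0.1 + 0.2                         # not exactly 0.3 in binary floating point
print(f'0.1 + 0.2 == 0.3    -> {a == 0.3}')
print(f'np.isclose(a, 0.3)  -> {np.isclose(a, 0.3)}')

x = np.array([800.0, 801.0])          # naive log(sum(exp(x))) overflows: exp(800) = inf
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))
print(f'naive log-sum-exp   -> {naive}')
if HAS_TORCH:
    print(f'torch.logsumexp     -> {torch.logsumexp(torch.tensor(x), dim=0).item():.3f}')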
Practice Questions
- IEEE 754: Encode -13.625 in FP32 binary. Show sign, exponent, mantissa.
- Precision: Why does 1.0 + 1e-8 == 1.0 in FP32 but not in FP64?
- BF16 vs FP16: A gradient of 70000.0 arrives; which format handles it without overflow?
- Softmax: Given logits [89.0, 89.1, 89.2], what happens with naive softmax in FP32?
- SQNR: Calculate the theoretical SQNR for 3-bit uniform quantisation.
- Memory: How much VRAM for a 13B model in INT4 inference? In BF16 training?
- Stochastic rounding: If gradient=0.003 and BF16 ULP=0.0078, what does RNE do? SR?
- KV Cache: Calculate the KV cache size for a 32-layer, 32-head, d=128 model at 32K context.
- Hadamard: Why does rotating weights before quantisation reduce error?
- Error propagation: Why are the first and last layers of an LLM more sensitive to quantisation?
Next Steps
- exercises.ipynb: Extended practice problems with detailed solutions
- notes.md: Complete 17-section mathematical reference
- Continue to: 02-Sets-and-Logic