Graph Neural Networks — Theory Notebook

"A graph neural network is a machine that reads a graph and learns by listening to its neighbors."

Interactive derivations covering: GCN propagation, over-smoothing dynamics, WL color refinement, GAT attention, GIN expressiveness, graph pooling, positional encodings, and training at scale.

Companion: notes.md | exercises.ipynb

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import scipy.linalg as la
import scipy.sparse as sp

try:
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    try:
        import seaborn as sns
        sns.set_theme(style='whitegrid', palette='colorblind')
        HAS_SNS = True
    except ImportError:
        plt.style.use('seaborn-v0_8-whitegrid')
        HAS_SNS = False
    mpl.rcParams.update({
        'figure.figsize': (10, 6), 'figure.dpi': 120,
        'font.size': 13, 'axes.titlesize': 15, 'axes.labelsize': 13,
        'xtick.labelsize': 11, 'ytick.labelsize': 11,
        'legend.fontsize': 11, 'lines.linewidth': 2.0,
        'axes.spines.top': False, 'axes.spines.right': False,
    })
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

COLORS = {
    'primary':   '#0077BB',
    'secondary': '#EE7733',
    'tertiary':  '#009988',
    'error':     '#CC3311',
    'neutral':   '#555555',
    'highlight': '#EE3377',
}

np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)
print('Setup complete.')

1. Graph Data Structures

We build graphs as adjacency matrices and edge lists, then verify the key representations used in GNNs.

Code cell 5

# === 1. Graph Data Structures ===

def make_graph(n, edges):
    """Build adjacency matrix from edge list (undirected)."""
    A = np.zeros((n, n))
    for u, v in edges:
        A[u, v] = 1
        A[v, u] = 1
    return A

# Karate-club-like small graph: 8 nodes
n = 8
edges = [(0,1),(0,2),(0,3),(1,4),(2,4),(3,5),(4,6),(5,6),(6,7),(1,2)]
A = make_graph(n, edges)
degrees = A.sum(axis=1)

print('Adjacency matrix A:')
print(A.astype(int))
print(f'\nDegrees: {degrees.astype(int)}')
print(f'Edges: {int(A.sum()//2)}')

# Node feature matrix: one-hot encoding
X = np.eye(n)
print(f'\nNode features X shape: {X.shape}')
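
The dense matrices above are convenient at n = 8, but GNN pipelines at scale store graphs sparsely. A minimal sketch, converting the A built above into the two layouts most libraries expect: a (2, num_edges) edge index (the convention used by PyTorch Geometric) and a CSR sparse matrix.

# === (Sketch) Sparse graph representations ===
# Converts the dense A above into an edge index and a CSR matrix.
# CSR needs O(n + m) memory instead of O(n^2) for the dense array.
import scipy.sparse as sp

rows, cols = np.nonzero(A)            # each undirected edge appears in both directions
edge_index = np.vstack([rows, cols])  # shape (2, 2*|E|)
A_csr = sp.csr_matrix(A)

print('edge_index shape:', edge_index.shape)
print('CSR nnz:', A_csr.nnz, 'vs dense entries:', A.size)
print('Degrees from CSR:', np.asarray(A_csr.sum(axis=1)).ravel().astype(int))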

2. GCN Propagation Matrix

Derive $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ and study its spectral properties.

Code cell 7

# === 2. GCN Propagation Matrix ===

def gcn_propagation_matrix(A):
    """Compute A_hat = D_tilde^{-1/2} A_tilde D_tilde^{-1/2}."""
    n = A.shape[0]
    A_tilde = A + np.eye(n)          # Add self-loops
    D_tilde = np.diag(A_tilde.sum(axis=1))  # Degree matrix
    D_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(D_tilde)))
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt
    return A_hat, A_tilde, D_tilde

A_hat, A_tilde, D_tilde = gcn_propagation_matrix(A)

# Verify: eigenvalues should be in (-1, 1]
eigvals = np.linalg.eigvalsh(A_hat)
print('Eigenvalues of A_hat:')
print(np.sort(eigvals)[::-1].round(4))
print(f'\nMax eigenvalue: {eigvals.max():.6f} (should be <= 1.0)')
print(f'Min eigenvalue: {eigvals.min():.6f} (should be > -1.0)')

ok_max = eigvals.max() <= 1.0 + 1e-10
ok_min = eigvals.min() > -1.0 - 1e-10
print('\nPASS: eigenvalues in (-1,1]' if (ok_max and ok_min) else '\nFAIL: eigenvalue out of range')

Code cell 8

# === 2.1 Two-Layer GCN Forward Pass ===

def relu(x):
    return np.maximum(0, x)

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

np.random.seed(42)
d0, d1, d2 = n, 4, 3  # input_dim, hidden_dim, output_dim
W0 = np.random.randn(d0, d1) * 0.5
W1 = np.random.randn(d1, d2) * 0.5

# Two-layer GCN
H1 = relu(A_hat @ X @ W0)      # Shape: (n, d1)
H2 = softmax(A_hat @ H1 @ W1)  # Shape: (n, d2)

print('H1 (hidden layer):')
print(H1.round(4))
print(f'\nH2 (output, softmax): shape {H2.shape}')
print(H2.round(4))
print(f'\nRow sums of H2 (should be 1.0): {H2.sum(axis=1).round(6)}')

ok = np.allclose(H2.sum(axis=1), 1.0)
print(f"{'PASS' if ok else 'FAIL'} - softmax rows sum to 1")

2.2 Permutation Equivariance Verification

Verify that GCN satisfies $f(PX, PAP^\top) = P \, f(X, A)$ for any permutation matrix $P$.

Code cell 10

# === 2.2 Permutation Equivariance ===

# Create a random permutation
perm = np.random.permutation(n)
P = np.eye(n)[perm]  # Permutation matrix

# Permuted graph
A_perm = P @ A @ P.T
X_perm = P @ X

# GCN on permuted graph
A_hat_perm, _, _ = gcn_propagation_matrix(A_perm)
H1_perm = relu(A_hat_perm @ X_perm @ W0)
H2_perm = softmax(A_hat_perm @ H1_perm @ W1)

# Check: H2_perm should equal P @ H2
H2_expected = P @ H2
diff = np.abs(H2_perm - H2_expected).max()
print(f'Max difference |H2_perm - P@H2|: {diff:.2e}')

ok = np.allclose(H2_perm, H2_expected, atol=1e-10)
print(f"{'PASS' if ok else 'FAIL'} - GCN is permutation equivariant")

3. Over-Smoothing: Dirichlet Energy Decay

Visualize how the Dirichlet energy $E(H) = \operatorname{tr}(H^\top L H)$ decays toward zero as GCN depth increases.

Code cell 12

# === 3. Over-Smoothing Analysis ===

# Unnormalized Laplacian L = D - A
L = np.diag(A.sum(axis=1)) - A

def dirichlet_energy(H, L):
    return np.trace(H.T @ L @ H)

# Apply pure smoothing (A_hat, no weight matrix) for increasing depths
depths = list(range(0, 33))
energies = []

H_iter = X.copy()  # Start from identity (one-hot) features
for d in range(depths[-1] + 1):
    energies.append(dirichlet_energy(H_iter, L))
    H_iter = A_hat @ H_iter  # One smoothing step

print('Dirichlet energy at selected depths:')
for d in [0, 1, 2, 4, 8, 16, 32]:
    print(f'  Depth {d:2d}: E = {energies[d]:.6f}')

ok = energies[32] < energies[0] * 0.01
print(f"\n{'PASS' if ok else 'FAIL'} - energy decays by 99%+ at depth 32")

Code cell 13

# === 3.1 Over-Smoothing Visualization ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Dirichlet energy vs depth
    ax = axes[0]
    ax.semilogy(depths, energies, color=COLORS['primary'], linewidth=2)
    ax.axvline(2, color=COLORS['neutral'], linestyle='--', alpha=0.7, label='Depth 2 (typical GCN)')
    ax.set_title('Over-smoothing: Dirichlet energy vs depth')
    ax.set_xlabel('Number of propagation steps')
    ax.set_ylabel('Dirichlet energy $E(H)$ (log scale)')
    ax.legend()

    # Right: Node representations at depth 0 vs 16
    ax = axes[1]
    H_shallow = np.linalg.matrix_power(A_hat, 2) @ X
    H_deep = np.linalg.matrix_power(A_hat, 16) @ X
    # Project to 2D via first two columns
    ax.scatter(H_shallow[:, 0], H_shallow[:, 1],
               color=COLORS['primary'], s=120, label='Depth 2', zorder=5)
    ax.scatter(H_deep[:, 0], H_deep[:, 1],
               color=COLORS['error'], s=120, marker='x', linewidths=2,
               label='Depth 16 (over-smoothed)', zorder=5)
    ax.set_title('Node representations: shallow vs deep')
    ax.set_xlabel('Feature dim 1')
    ax.set_ylabel('Feature dim 2')
    ax.legend()

    fig.tight_layout()
    plt.show()
    print('Plot displayed.')

Code cell 14

# === 3.2 Convergence to Stationary Distribution ===

# Random-walk stationary distribution: pi_v = d_tilde_v / vol(G_tilde)
A_rw = (A_tilde.T / A_tilde.sum(axis=1)).T  # row-normalized (random-walk) matrix
pi = A_tilde.sum(axis=1) / A_tilde.sum()    # degree / volume

# After many steps, each row of A_rw^k should converge to pi
A_rw_k = np.linalg.matrix_power(A_rw, 50)
print('First 3 rows of A_rw^50 (should all be approx. pi):')
print(A_rw_k[:3].round(6))
print(f'\nStationary distribution pi:')
print(pi.round(6))

ok = np.allclose(A_rw_k[:3], pi[np.newaxis, :], atol=1e-4)
print(f"\n{'PASS' if ok else 'FAIL'} - rows converge to stationary distribution")

4. Weisfeiler-Leman Color Refinement

Implement 1-WL and test it on pairs of graphs to understand GNN expressiveness limits.

Code cell 16

# === 4. Weisfeiler-Leman Color Refinement ===

from collections import Counter

def wl_refinement(adj, max_iter=10):
    """Run 1-WL color refinement. adj: dict {node: set_of_neighbors}."""
    n = len(adj)
    # Initial: all same color
    colors = {v: 0 for v in adj}
    color_history = [dict(colors)]

    for t in range(max_iter):
        new_colors = {}
        color_map = {}  # (color, sorted nbr colors) -> new color
        counter = [0]

        def get_color(key):
            if key not in color_map:
                color_map[key] = counter[0]
                counter[0] += 1
            return color_map[key]

        for v in adj:
            nbr_colors = tuple(sorted(colors[u] for u in adj[v]))
            key = (colors[v], nbr_colors)
            new_colors[v] = get_color(key)

        if new_colors == colors:
            break
        colors = new_colors
        color_history.append(dict(colors))

    return colors, color_history

def adj_from_edges(n, edges):
    adj = {i: set() for i in range(n)}
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    return adj

# C6 (6-cycle) vs C3+C3 (two disjoint triangles)
C6_adj = adj_from_edges(6, [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0)])
C3C3_adj = adj_from_edges(6, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3)])

c_C6, hist_C6 = wl_refinement(C6_adj)
c_C3C3, hist_C3C3 = wl_refinement(C3C3_adj)

print('C6 final colors:', c_C6)
print('C3+C3 final colors:', c_C3C3)
print(f'C6 color histogram: {dict(Counter(c_C6.values()))}')
print(f'C3+C3 color histogram: {dict(Counter(c_C3C3.values()))}')

# 1-WL reports "possibly isomorphic" iff the color histograms match
wl_distinguishes = Counter(c_C6.values()) != Counter(c_C3C3.values())
print(f'\n1-WL distinguishes C6 vs C3+C3: {wl_distinguishes}')
print('(Expected: False — 1-WL cannot distinguish them)')

Code cell 17

# === 4.1 WL on Distinguishable Graphs ===

# K_{1,3} (star with 3 leaves) vs P4 (path of 4 nodes)
K13_adj = adj_from_edges(4, [(0,1),(0,2),(0,3)])  # node 0 is hub
P4_adj = adj_from_edges(4, [(0,1),(1,2),(2,3)])   # path

c_K13, _ = wl_refinement(K13_adj)
c_P4, _ = wl_refinement(P4_adj)

print('K_{1,3} colors:', c_K13)
print('P4 colors:', c_P4)

hist_K13 = Counter(c_K13.values())
hist_P4 = Counter(c_P4.values())
distinguishes = hist_K13 != hist_P4
print(f'\n1-WL distinguishes K_{{1,3}} vs P4: {distinguishes}')
print(f'K_{{1,3}} histogram: {dict(hist_K13)}')
print(f'P4 histogram: {dict(hist_P4)}')

ok = distinguishes
print(f"\n{'PASS' if ok else 'FAIL'} - WL correctly distinguishes K_{{1,3}} vs P4")

5. Graph Attention Network (GAT)

Implement GAT and GATv2 attention from scratch and visualize attention patterns.

Code cell 19

# === 5. GAT Attention ===

np.random.seed(42)

def gat_layer(H, A, W, a, leaky_slope=0.2):
    """
    GAT layer (single head).
    H: (n, d) node features
    A: (n, n) adjacency (with self-loops)
    W: (d, d') linear transform
    a: (2*d',) attention vector
    Returns: H_new (n, d'), alpha (n, n)
    """
    n, d = H.shape
    Z = H @ W  # (n, d')
    d_prime = Z.shape[1]

    # Compute attention scores for all edges
    a_left = a[:d_prime]    # (d',)
    a_right = a[d_prime:]   # (d',)

    # e_ij = LeakyReLU(a^T [z_i || z_j])
    score_left = Z @ a_left     # (n,)
    score_right = Z @ a_right   # (n,)

    # e[i,j] = score_left[i] + score_right[j] (for neighbors)
    E = score_left[:, np.newaxis] + score_right[np.newaxis, :]  # (n, n)

    # LeakyReLU
    E = np.where(E >= 0, E, leaky_slope * E)

    # Mask: only attend to neighbors (where A > 0)
    mask = (A > 0).astype(float)
    E_masked = np.where(mask > 0, E, -1e9)

    # Softmax over neighbors
    E_exp = np.exp(E_masked - E_masked.max(axis=1, keepdims=True))
    alpha = E_exp * mask
    alpha = alpha / (alpha.sum(axis=1, keepdims=True) + 1e-10)

    # Aggregate
    H_new = np.tanh(alpha @ Z)  # (n, d')
    return H_new, alpha

# Setup
n_small = 6
edges_small = [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0),(0,2),(1,4)]
A_small = make_graph(n_small, edges_small)
A_small_sl = A_small + np.eye(n_small)  # with self-loops

d_in, d_out = 4, 3
H0 = np.random.randn(n_small, d_in)
W_gat = np.random.randn(d_in, d_out) * 0.5
a_gat = np.random.randn(2 * d_out) * 0.5

H_gat, alpha = gat_layer(H0, A_small_sl, W_gat, a_gat)
print('GAT output shape:', H_gat.shape)
print('\nAttention matrix alpha (rows = target nodes):')
print(alpha.round(3))
print('\nRow sums (should be 1 for each node):', alpha.sum(axis=1).round(4))

Code cell 20

# === 5.1 GATv2 vs GAT: Static vs Dynamic Attention ===

def gatv2_layer(H, A, W, a, leaky_slope=0.2):
    """GATv2: dynamic attention e_ij = a^T LeakyReLU(W[h_i || h_j])."""
    n, d = H.shape
    n_out = W.shape[1]

    # Concatenate all pairs (broadcast)
    # h_i repeated across columns, h_j repeated across rows
    H_i = np.repeat(H[:, np.newaxis, :], n, axis=1)  # (n, n, d)
    H_j = np.repeat(H[np.newaxis, :, :], n, axis=0)  # (n, n, d)
    H_cat = np.concatenate([H_i, H_j], axis=2)       # (n, n, 2d)

    # W in R^{2d x d_out} maps concat to features
    # For simplicity, use W and a directly on concatenated
    scores = H_cat @ W  # (n, n, n_out) -- W should be (2d, n_out)
    # Apply LeakyReLU
    scores = np.where(scores >= 0, scores, leaky_slope * scores)
    # Dot with a: (n, n, n_out) . (n_out,) -> (n, n)
    E = scores @ a[:n_out]  # (n, n)

    mask = (A > 0).astype(float)
    E_masked = np.where(mask > 0, E, -1e9)
    E_exp = np.exp(E_masked - E_masked.max(axis=1, keepdims=True))
    alpha_v2 = E_exp * mask
    alpha_v2 = alpha_v2 / (alpha_v2.sum(axis=1, keepdims=True) + 1e-10)

    W_feat = np.random.randn(d, n_out) * 0.3  # separate value transform (demo only)
    H_new = np.tanh(alpha_v2 @ (H @ W_feat))
    return H_new, alpha_v2

W_v2 = np.random.randn(2 * d_in, d_out) * 0.3
a_v2 = np.random.randn(d_out) * 0.3
H_v2, alpha_v2 = gatv2_layer(H0, A_small_sl, W_v2, a_v2)

# Compare attention patterns for node 0
print('Node 0 neighbor attention weights:')
nbrs_0 = [i for i in range(n_small) if A_small_sl[0, i] > 0]
print(f'  GAT  neighbors {nbrs_0}: {alpha[0, nbrs_0].round(4)}')
print(f'  GATv2 neighbors {nbrs_0}: {alpha_v2[0, nbrs_0].round(4)}')
print('\n(Different weights = dynamic attention working in GATv2)')
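
"Static" has a precise meaning here: the GAT score decomposes as e_ij = LeakyReLU(a_L^T z_i + a_R^T z_j), and LeakyReLU is monotone, so for every target node the ranking over candidate sources is fixed by a_R^T z_j alone. A sketch making this visible with the layer computed above:

# === (Sketch) Why GAT attention is 'static' ===
# For each node, the most-attended neighbor under GAT must be the neighbor
# with the largest right-score a_R^T z_j, because the target's own term only
# shifts every score by the same constant.
Z_chk = H0 @ W_gat
score_right_chk = Z_chk @ a_gat[d_out:]

for v in range(n_small):
    nbrs = [u for u in range(n_small) if A_small_sl[v, u] > 0]
    top_by_alpha = nbrs[int(np.argmax(alpha[v, nbrs]))]
    top_by_score = nbrs[int(np.argmax(score_right_chk[nbrs]))]
    assert top_by_alpha == top_by_score
print('PASS: every node ranks shared neighbors identically (static attention)')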

Code cell 21

# === 5.2 Visualize Attention Matrix ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    for ax, mat, title in zip(axes,
                               [alpha, alpha_v2],
                               ['GAT (static attention)', 'GATv2 (dynamic attention)']):
        im = ax.imshow(mat, cmap='viridis', vmin=0, vmax=mat.max())
        fig.colorbar(im, ax=ax, label='Attention weight $\\alpha_{uv}$')
        ax.set_title(title)
        ax.set_xlabel('Source node $u$')
        ax.set_ylabel('Target node $v$')
        ax.set_xticks(range(n_small))
        ax.set_yticks(range(n_small))

    fig.suptitle('Attention patterns: GAT vs GATv2', fontsize=15)
    fig.tight_layout()
    plt.show()
    print('Attention heatmaps displayed.')

6. GIN: Graph Isomorphism Network

Demonstrate that sum aggregation (GIN) distinguishes graphs that mean aggregation (GCN) cannot.

Code cell 23

# === 6. Sum vs Mean Aggregation Expressiveness ===

# Multiset counterexample: {1,1} vs {1,1,1}
M1 = np.array([1.0, 1.0])
M2 = np.array([1.0, 1.0, 1.0])

print('Multiset expressiveness comparison:')
print(f'  M1 = {M1}, M2 = {M2}')
print(f'  sum(M1) = {M1.sum():.1f}, sum(M2) = {M2.sum():.1f}  -> distinguishable')
print(f'  mean(M1)= {M1.mean():.1f}, mean(M2)= {M2.mean():.1f} -> INDISTINGUISHABLE')
print(f'  max(M1) = {M1.max():.1f},  max(M2) = {M2.max():.1f}  -> INDISTINGUISHABLE')

# Graph-level: C6 vs C3+C3 with degree as feature
# C6: all nodes degree 2
# C3+C3: all nodes degree 2
# With one-hot initial features based on degree:
print('\nGIN sum-readout on C6 vs C3+C3 (with degree feature):')

def gin_layer(H, adj_dict, eps=0.0):
    """One GIN layer: MLP((1+eps)*h_v + sum_{u in N(v)} h_u)."""
    H_new = np.zeros_like(H)
    for v in adj_dict:
        nbr_sum = sum(H[u] for u in adj_dict[v])
        H_new[v] = (1 + eps) * H[v] + nbr_sum
    return np.tanh(H_new)  # MLP approximated by tanh

# Initialize with degree as feature
C6_feats = np.array([[d] for d in [2,2,2,2,2,2]], dtype=float)
C3C3_feats = np.array([[d] for d in [2,2,2,2,2,2]], dtype=float)

H_C6 = gin_layer(C6_feats, C6_adj)
H_C3C3 = gin_layer(C3C3_feats, C3C3_adj)

print(f'  C6 after 1 GIN layer (sum readout): {H_C6.sum():.6f}')
print(f'  C3C3 after 1 GIN layer (sum readout): {H_C3C3.sum():.6f}')

# Add second layer
H_C6_2 = gin_layer(H_C6, C6_adj)
H_C3C3_2 = gin_layer(H_C3C3, C3C3_adj)
print(f'\n  C6 after 2 GIN layers (sum readout): {H_C6_2.sum():.6f}')
print(f'  C3C3 after 2 GIN layers (sum readout): {H_C3C3_2.sum():.6f}')
print('\n(With uniform initial features, GIN still cannot distinguish them!)')
print('WL bound: both graphs have identical 1-WL color multisets at every iteration.')

Code cell 24

# === 6.1 Adding Structural Features Breaks the Tie ===

# Add RWSE-like feature: self-loop probability at step 2
def rwse(adj_matrix, p=2):
    """Random walk structural encoding: return prob at step p."""
    n = adj_matrix.shape[0]
    D_inv = np.diag(1.0 / (adj_matrix.sum(axis=1) + 1e-10))
    P = D_inv @ adj_matrix  # Row-stochastic
    Pp = np.linalg.matrix_power(P, p)
    return np.diag(Pp)

# Build adjacency matrices for C6 and C3+C3
C6_A = make_graph(6, [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0)])
C3C3_A = make_graph(6, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3)])

rwse_C6 = rwse(C6_A, p=3)
rwse_C3C3 = rwse(C3C3_A, p=3)

print('RWSE (3-step return probability) for C6 and C3+C3:')
print(f'  C6 (all values): {rwse_C6.round(4)}')
print(f'  C3+C3 (all values): {rwse_C3C3.round(4)}')
print(f'\n  C6 unique RWSE: {set(rwse_C6.round(4))}')
print(f'  C3+C3 unique RWSE: {set(rwse_C3C3.round(4))}')

distinguishable = not np.allclose(sorted(rwse_C6), sorted(rwse_C3C3))
print(f'\nRWSE distinguishes C6 vs C3+C3: {distinguishable}')
print('PASS: structural features (RWSE) break the 1-WL tie' if distinguishable
      else 'FAIL: RWSE cannot distinguish them')

7. GraphSAGE: Inductive Learning and Neighbor Sampling

Implement neighbor sampling and the GraphSAGE aggregation step.

Code cell 26

# === 7. GraphSAGE Neighbor Sampling ===

np.random.seed(42)

def build_random_graph(n, p_edge=0.15):
    """Erdos-Renyi G(n,p) graph."""
    adj = {i: set() for i in range(n)}
    for i in range(n):
        for j in range(i+1, n):
            if np.random.rand() < p_edge:
                adj[i].add(j)
                adj[j].add(i)
    return adj

def sample_neighbors(adj, node, k):
    """Sample k neighbors of node uniformly (with replacement if needed)."""
    nbrs = list(adj[node])
    if len(nbrs) == 0:
        return []
    if len(nbrs) <= k:
        return nbrs
    return list(np.random.choice(nbrs, k, replace=False))

def sage_mean_layer(H, adj, target_nodes, S=10):
    """
    GraphSAGE mean aggregation layer.
    Returns embeddings only for target_nodes.
    """
    d = H.shape[1]
    W = np.random.randn(2*d, d) * 0.3  # untrained demo weights (learned in practice)
    H_new = np.zeros((len(target_nodes), d))

    for idx, v in enumerate(target_nodes):
        sampled = sample_neighbors(adj, v, S)
        if sampled:
            nbr_mean = H[sampled].mean(axis=0)
        else:
            nbr_mean = np.zeros(d)
        concat = np.concatenate([H[v], nbr_mean])
        H_new[idx] = np.tanh(W.T @ concat)

    return H_new

n_large = 100
adj_large = build_random_graph(n_large, p_edge=0.08)
X_large = np.random.randn(n_large, 8)

# Compute embeddings for a mini-batch of 10 target nodes
batch = list(range(10))
H_sage = sage_mean_layer(X_large, adj_large, batch, S=5)

degrees = [len(adj_large[v]) for v in range(n_large)]
print(f'Graph: {n_large} nodes, avg degree {np.mean(degrees):.1f}')
print(f'Mini-batch size: {len(batch)} nodes')
print(f'GraphSAGE output shape: {H_sage.shape}')
print(f'Output norms: {np.linalg.norm(H_sage, axis=1).round(4)}')

ok = H_sage.shape == (len(batch), 8)
print(f"\n{'PASS' if ok else 'FAIL'} - GraphSAGE output has correct shape")

Code cell 27

# === 7.1 Inductive Inference: New Node ===

# Simulate adding a new node to the graph
# New node features + connections to some existing nodes
new_node_features = np.random.randn(8)
new_node_neighbors = [0, 5, 12, 23]  # Indices of existing nodes

# Compute embedding for new node using LEARNED aggregation function
# (In practice, W would be trained — here we just demonstrate the pipeline)
d = 8
W_inductive = np.random.randn(2*d, d) * 0.3

# Aggregate neighbor features
nbr_feats = X_large[new_node_neighbors]
nbr_mean = nbr_feats.mean(axis=0)
concat = np.concatenate([new_node_features, nbr_mean])
new_node_embedding = np.tanh(W_inductive.T @ concat)

print('Inductive embedding for new node:')
print(f'  Features: {new_node_features[:3].round(3)}...')
print(f'  Neighbors: {new_node_neighbors}')
print(f'  Embedding: {new_node_embedding[:4].round(4)}...')
print(f'  Embedding norm: {np.linalg.norm(new_node_embedding):.4f}')
print('\nKey insight: embedding computed WITHOUT retraining the model!')

8. Graph Pooling Methods

Compare global pooling strategies and implement DiffPool soft cluster assignment.

Code cell 29

# === 8. Global Pooling Methods ===

np.random.seed(42)

# Simulate node embeddings for a small graph
n_mol = 12  # 12-atom molecule
d_emb = 8
H_mol = np.random.randn(n_mol, d_emb)

# Global sum pooling
h_sum = H_mol.sum(axis=0)
# Global mean pooling
h_mean = H_mol.mean(axis=0)
# Global max pooling
h_max = H_mol.max(axis=0)

print(f'Node embeddings shape: {H_mol.shape}')
print(f'\nSum pooling  -> norm: {np.linalg.norm(h_sum):.4f}')
print(f'Mean pooling -> norm: {np.linalg.norm(h_mean):.4f}')
print(f'Max pooling  -> norm: {np.linalg.norm(h_max):.4f}')
print(f'\nSum vs Mean ratio: {np.linalg.norm(h_sum) / np.linalg.norm(h_mean):.4f}')
print(f'(Expected: ~sqrt(n) = {np.sqrt(n_mol):.4f} for random embeddings)')

# Expressiveness: two graphs with same avg but different sum
G1_nodes = np.ones((4, 2))         # 4 nodes, features all 1
G2_nodes = np.ones((8, 2))         # 8 nodes, features all 1
print(f'\nExpressiveness demo (all-ones graphs):')
print(f'  G1 (4 nodes) mean: {G1_nodes.mean(axis=0)}, sum: {G1_nodes.sum(axis=0)}')
print(f'  G2 (8 nodes) mean: {G2_nodes.mean(axis=0)}, sum: {G2_nodes.sum(axis=0)}')
print('Mean cannot distinguish G1 from G2; Sum can.')

Code cell 30

# === 8.1 DiffPool Soft Assignment ===

np.random.seed(42)

def diffpool_step(A, H, k):
    """
    Simplified DiffPool: compute soft cluster assignments.
    A: (n, n) adjacency
    H: (n, d) node embeddings
    k: target number of clusters
    Returns: S (n, k), H_pooled (k, d), A_pooled (k, k)
    """
    n, d = H.shape
    # Soft assignment: S = softmax(H @ W_pool)
    W_pool = np.random.randn(d, k) * 0.5
    S_logits = H @ W_pool     # (n, k)
    S = np.exp(S_logits - S_logits.max(axis=1, keepdims=True))
    S = S / S.sum(axis=1, keepdims=True)  # (n, k) - soft assignments

    # Coarsened embeddings: H_pooled = S^T @ H
    H_pooled = S.T @ H   # (k, d)

    # Coarsened adjacency: A_pooled = S^T @ A @ S
    A_pooled = S.T @ A @ S  # (k, k)

    # Auxiliary losses
    lp_loss = np.linalg.norm(A - S @ S.T, 'fro')  # Link prediction
    ent_loss = -(S * np.log(S + 1e-10)).sum() / n  # Entropy (min for crisp)

    return S, H_pooled, A_pooled, lp_loss, ent_loss

# Run on our small graph
A_hat_dense = A_hat.copy()
H_demo = np.random.randn(n, 5)  # n=8 nodes from earlier
k_clusters = 3

S, H_pool, A_pool, lp, ent = diffpool_step(A_hat_dense, H_demo, k_clusters)

print(f'DiffPool: {n} nodes -> {k_clusters} clusters')
print(f'Soft assignment matrix S shape: {S.shape}')
print(f'S row sums (should be 1): {S.sum(axis=1).round(4)}')
print(f'H_pooled shape: {H_pool.shape}')
print(f'A_pooled shape: {A_pool.shape}')
print(f'\nAuxiliary losses:')
print(f'  Link prediction loss: {lp:.4f}')
print(f'  Entropy loss: {ent:.4f} (lower = crisper assignments)')

9. Positional Encodings for Graph Transformers

Compute Laplacian PE (LapPE) and Random Walk SE (RWSE) and visualize them.

Code cell 32

# === 9. Laplacian Positional Encodings ===

np.random.seed(42)

def compute_lapPE(A, k=4):
    """
    Compute first k non-trivial eigenvectors of normalized Laplacian.
    Returns: eigvecs (n, k), eigvals (k,)
    """
    n = A.shape[0]
    D_inv_sqrt = np.diag(1.0 / (np.sqrt(A.sum(axis=1)) + 1e-10))
    L_sym = np.eye(n) - D_inv_sqrt @ A @ D_inv_sqrt  # normalized Laplacian

    eigvals, eigvecs = np.linalg.eigh(L_sym)
    # Skip first eigvec (constant, eigenvalue ~0)
    # Take next k eigvecs
    return eigvecs[:, 1:k+1], eigvals[1:k+1]

def compute_rwse(A, max_p=5):
    """RWSE: landing probability at each step p."""
    n = A.shape[0]
    D_inv = np.diag(1.0 / (A.sum(axis=1) + 1e-10))
    P = D_inv @ A  # Row-stochastic
    rwse = np.zeros((n, max_p))
    Pp = np.eye(n)
    for p in range(1, max_p+1):
        Pp = Pp @ P
        rwse[:, p-1] = np.diag(Pp)
    return rwse

# Use larger graph for interesting PEs
# Build a ring graph (C_12)
n_ring = 12
ring_edges = [(i, (i+1) % n_ring) for i in range(n_ring)]
A_ring = make_graph(n_ring, ring_edges)

lapPE, lapEigvals = compute_lapPE(A_ring, k=4)
rwse_ring = compute_rwse(A_ring, max_p=5)

print(f'Graph: C_{n_ring} (ring with {n_ring} nodes)')
print(f'LapPE shape: {lapPE.shape} (n x k)')
print(f'First 4 non-trivial Laplacian eigenvalues: {lapEigvals.round(4)}')
print(f'\nLapPE for first 4 nodes (each row = position encoding):')
print(lapPE[:4].round(4))
print(f'\nRWSE for first 4 nodes (5-step return probs):')
print(rwse_ring[:4].round(4))

Code cell 33

# === 9.1 Visualize LapPE as 'Coordinates' ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Left: LapPE eigvec 1 vs eigvec 2 for ring graph
    ax = axes[0]
    x_coord = lapPE[:, 0]
    y_coord = lapPE[:, 1]
    sc = ax.scatter(x_coord, y_coord,
                    c=np.arange(n_ring), cmap='plasma', s=120, zorder=5)
    for i in range(n_ring):
        ax.annotate(str(i), (x_coord[i]+0.01, y_coord[i]+0.01), fontsize=10)
    plt.colorbar(sc, ax=ax, label='Node index')
    ax.set_title(f'LapPE: $C_{{12}}$ ring — eigvec 1 vs 2')
    ax.set_xlabel('LapPE dim 1 ($\\mathbf{u}_2$)')
    ax.set_ylabel('LapPE dim 2 ($\\mathbf{u}_3$)')
    ax.set_aspect('equal')

    # Right: RWSE for each node
    ax = axes[1]
    im = ax.imshow(rwse_ring.T, cmap='viridis', aspect='auto')
    plt.colorbar(im, ax=ax, label='Return probability')
    ax.set_title('RWSE for all nodes (C_{12} ring)')
    ax.set_xlabel('Node index')
    ax.set_ylabel('Walk length $p$')
    ax.set_yticks(range(5))
    ax.set_yticklabels([f'p={p+1}' for p in range(5)])

    fig.tight_layout()
    plt.show()
    print('LapPE and RWSE visualizations displayed.')
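
One caveat the plot hides: Laplacian eigenvectors are only defined up to sign (and up to an arbitrary basis within repeated eigenvalues), so -lapPE is an equally valid encoding of the same graph. A common mitigation is random sign flipping during training so downstream models learn not to rely on the arbitrary sign; a minimal sketch:

# === (Sketch) LapPE sign ambiguity ===
# If u is an eigenvector of L_sym, so is -u. Randomly flipping each LapPE
# column's sign (a standard training-time augmentation) leaves it a valid
# eigenbasis, which we confirm via the eigen-residual.
signs = np.random.choice([-1.0, 1.0], size=lapPE.shape[1])
lapPE_flipped = lapPE * signs

D_is = np.diag(1.0 / (np.sqrt(A_ring.sum(axis=1)) + 1e-10))
L_sym_ring = np.eye(n_ring) - D_is @ A_ring @ D_is
resid = np.abs(L_sym_ring @ lapPE_flipped - lapPE_flipped * lapEigvals).max()
print(f'Signs used: {signs}')
print(f'Eigen-residual after sign flip: {resid:.2e} (still an eigenbasis)')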

10. Over-Squashing: Jacobian Analysis

Visualize how the GCN Jacobian $\partial \mathbf{h}_v^{[k]} / \partial \mathbf{x}_u$ decays with distance for bottleneck graphs.

Code cell 35

# === 10. Over-Squashing: Jacobian via A_hat powers ===

np.random.seed(42)

# Build a 'dumbbell' graph: two cliques connected by a single bridge edge
def make_dumbbell(k1=5, k2=5):
    """Two k-cliques connected by a bridge (bottleneck)."""
    n = k1 + k2
    A = np.zeros((n, n))
    # Clique 1: nodes 0..k1-1
    for i in range(k1):
        for j in range(i+1, k1):
            A[i,j] = A[j,i] = 1
    # Clique 2: nodes k1..n-1
    for i in range(k1, n):
        for j in range(i+1, n):
            A[i,j] = A[j,i] = 1
    # Bridge: last node of clique 1 to first of clique 2
    A[k1-1, k1] = A[k1, k1-1] = 1
    return A

A_db = make_dumbbell(k1=5, k2=5)
A_hat_db, _, _ = gcn_propagation_matrix(A_db)
n_db = A_db.shape[0]

# Jacobian proxy: (A_hat^k)[v, u] = effective influence of u on v at depth k
depths_sq = [1, 2, 3, 4, 5]

# Nodes of interest: node 0 (clique 1) receiving info from node 9 (clique 2)
v_target = 0
u_source = n_db - 1
print(f'Dumbbell graph: {n_db} nodes, bridge edge (4, 5)')
print(f'Tracking influence of node {u_source} on node {v_target}:')
print()

A_hat_power = np.eye(n_db)
for k in depths_sq:
    A_hat_power = A_hat_power @ A_hat_db
    influence = A_hat_power[v_target, u_source]
    print(f'  Depth {k}: (A_hat^{k})[{v_target},{u_source}] = {influence:.6f}')

print()
print('Over-squashing: cross-bottleneck influence exponentially small.')
ok = A_hat_power[v_target, u_source] < 0.01
print(f"{'PASS' if ok else 'FAIL'} - influence < 0.01 after 5 steps")

Code cell 36

# === 10.1 Visualize Influence Matrix at Different Depths ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle('GCN influence matrix $(\\hat{A}^k)_{vu}$: dumbbell graph', fontsize=14)

    for ax, k in zip(axes, [1, 2, 5]):
        Ak = np.linalg.matrix_power(A_hat_db, k)
        im = ax.imshow(Ak, cmap='plasma', vmin=0)
        plt.colorbar(im, ax=ax)
        ax.set_title(f'Depth $k={k}$')
        ax.set_xlabel('Source node $u$')
        ax.set_ylabel('Target node $v$')
        # Mark the bottleneck
        ax.axhline(4.5, color='white', linewidth=1.5, linestyle='--')
        ax.axvline(4.5, color='white', linewidth=1.5, linestyle='--')

    fig.tight_layout()
    plt.show()
    print('Influence matrices displayed. Note the dark cross-cluster region.')

11. Community Detection with Graph Spectral Clustering vs GCN

Compare spectral (Fiedler vector) and GCN-based node classification on a planted community graph.

Code cell 38

# === 11. Stochastic Block Model Experiment ===

np.random.seed(42)

def stochastic_block_model(n_per_block, n_blocks, p_in, p_out):
    """Generate SBM adjacency matrix."""
    n = n_per_block * n_blocks
    A = np.zeros((n, n))
    labels = np.repeat(np.arange(n_blocks), n_per_block)
    for i in range(n):
        for j in range(i+1, n):
            p = p_in if labels[i] == labels[j] else p_out
            if np.random.rand() < p:
                A[i,j] = A[j,i] = 1
    return A, labels

n_per = 20
n_bl = 3
A_sbm, true_labels = stochastic_block_model(n_per, n_bl, p_in=0.4, p_out=0.02)
n_sbm = A_sbm.shape[0]

# Spectral clustering via Fiedler vector
D_sbm = np.diag(A_sbm.sum(axis=1))
L_sbm = D_sbm - A_sbm
D_inv_sqrt = np.diag(1.0 / (np.sqrt(A_sbm.sum(axis=1)) + 1e-10))
L_sym_sbm = np.eye(n_sbm) - D_inv_sqrt @ A_sbm @ D_inv_sqrt

eigvals_sbm, eigvecs_sbm = np.linalg.eigh(L_sym_sbm)
# Use first n_blocks non-trivial eigenvectors for clustering
spectral_coords = eigvecs_sbm[:, 1:n_bl+1]

# k-means on spectral coords (manual)
from scipy.spatial.distance import cdist

def kmeans_simple(X, k, max_iter=100):
    np.random.seed(42)
    centers = X[np.random.choice(len(X), k, replace=False)]
    for _ in range(max_iter):
        dists = cdist(X, centers)
        assignments = dists.argmin(axis=1)
        new_centers = np.array([X[assignments == c].mean(axis=0) for c in range(k)])
        if np.allclose(centers, new_centers):
            break
        centers = new_centers
    return assignments

spectral_pred = kmeans_simple(spectral_coords, n_bl)

# Accuracy via best label permutation (brute force; exact for small k)
from itertools import permutations
def clustering_accuracy(true, pred, k):
    best_acc = 0
    for perm in permutations(range(k)):
        mapped = np.array([perm[p] for p in pred])
        acc = (mapped == true).mean()
        best_acc = max(best_acc, acc)
    return best_acc

acc_spectral = clustering_accuracy(true_labels, spectral_pred, n_bl)
print(f'SBM: {n_sbm} nodes, {n_bl} blocks ({n_per} nodes each)')
print(f'Spectral clustering accuracy: {acc_spectral:.3f}')
print(f'\nTrue labels: {true_labels}')
print(f'Predicted:   {spectral_pred}')

ok = acc_spectral > 0.85
print(f"\n{'PASS' if ok else 'FAIL'} - spectral clustering >85% accuracy")

Code cell 39

# === 11.1 GCN Propagation on SBM ===

# Propagate the true-label one-hot signal and track its Dirichlet energy

A_hat_sbm, _, _ = gcn_propagation_matrix(A_sbm)

# Compute Dirichlet energy for label signal
L_sbm_unnorm = D_sbm - A_sbm
H_onehot = np.zeros((n_sbm, n_bl))
H_onehot[np.arange(n_sbm), true_labels] = 1.0

print('Dirichlet energy of label signal at increasing GCN depth:')
for depth in [0, 1, 2, 3, 5, 10]:
    H_iter_d = np.linalg.matrix_power(A_hat_sbm, depth) @ H_onehot
    e = np.trace(H_iter_d.T @ L_sbm_unnorm @ H_iter_d)
    print(f'  Depth {depth:2d}: E = {e:.4f}')

print()
print('Key: at depth 2-3, label signal is smooth but discriminative.')
print('At depth 10, over-smoothing destroys cluster boundaries.')
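
The "smooth but discriminative" claim can be quantified directly, as in the sketch below: track the mean pairwise embedding distance within blocks versus between blocks as depth grows. Between-block distances dominate at shallow depth; both shrink and the gap closes as over-smoothing sets in.

# === (Sketch) Smoothness vs discriminability at each depth ===
def class_distances(H, labels):
    """Mean pairwise distance within and between classes."""
    same, diff = [], []
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            d_ij = np.linalg.norm(H[i] - H[j])
            (same if labels[i] == labels[j] else diff).append(d_ij)
    return np.mean(same), np.mean(diff)

print('Mean within-class / between-class distances vs depth:')
for depth in [0, 1, 2, 3, 5, 10, 20]:
    H_d = np.linalg.matrix_power(A_hat_sbm, depth) @ H_onehot
    w, b = class_distances(H_d, true_labels)
    print(f'  Depth {depth:2d}: within = {w:.4f}, between = {b:.4f}')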

12. Training Dynamics and Learning Curves

Simulate GNN training on a synthetic node classification task and visualize convergence.

Code cell 41

# === 12. Synthetic GCN Training Simulation ===

np.random.seed(42)

# Generate a simple 2-class node classification problem on SBM
n2 = 40
A2, labels2 = stochastic_block_model(n2//2, 2, p_in=0.4, p_out=0.04)
A_hat2, _, _ = gcn_propagation_matrix(A2)

# Features: class-correlated with noise
X2 = np.zeros((n2, 4))
X2[:n2//2, 0] = 1.0  # Class 0: feature 0 = 1
X2[n2//2:, 1] = 1.0  # Class 1: feature 1 = 1
X2 += np.random.randn(n2, 4) * 0.3  # Add noise

# Labels: 20% labeled (4 from each class)
labeled_idx = list(range(0, 4)) + list(range(n2//2, n2//2+4))
Y = labels2[labeled_idx]

# Two-layer GCN with SGD (manual backprop)
W1_train = np.random.randn(4, 4) * 0.5
W2_train = np.random.randn(4, 2) * 0.5
lr = 0.05

def forward(X, A_hat, W1, W2):
    H1 = np.maximum(0, A_hat @ X @ W1)  # ReLU
    logits = A_hat @ H1 @ W2            # No activation
    # Softmax
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    return H1, probs

losses = []
for step in range(300):
    H1, probs = forward(X2, A_hat2, W1_train, W2_train)

    # Cross-entropy loss on labeled nodes
    loss = -np.log(probs[labeled_idx, Y] + 1e-10).mean()
    losses.append(loss)

    # Backprop: cross-entropy gradient flows only through labeled rows
    dlogits = np.zeros_like(probs)
    dlogits[labeled_idx] = probs[labeled_idx]
    dlogits[labeled_idx, Y] -= 1
    dlogits /= len(labeled_idx)
    dW2 = (A_hat2 @ H1).T @ dlogits
    dH1 = (A_hat2 @ dlogits) @ W2_train.T  # A_hat2 is symmetric, so A_hat2.T = A_hat2
    dH1_relu = dH1 * (H1 > 0).astype(float)
    dW1 = (A_hat2 @ X2).T @ dH1_relu

    W2_train -= lr * dW2
    W1_train -= lr * dW1

_, final_probs = forward(X2, A_hat2, W1_train, W2_train)
pred_all = final_probs.argmax(axis=1)
acc = (pred_all == labels2).mean()
print(f'Training loss after 300 steps: {losses[-1]:.4f}')
print(f'Final accuracy (all nodes): {acc:.3f}')

ok = acc > 0.7
print(f"{'PASS' if ok else 'FAIL'} - GCN achieves >70% accuracy on SBM")

Code cell 42

# === 12.1 Training Curve Visualization ===

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Left: Loss curve
    ax = axes[0]
    ax.plot(losses, color=COLORS['primary'], linewidth=2)
    ax.set_title('GCN training loss (semi-supervised)')
    ax.set_xlabel('Training step')
    ax.set_ylabel('Cross-entropy loss')
    ax.axhline(losses[-1], color=COLORS['neutral'], linestyle='--',
               alpha=0.6, label=f'Final: {losses[-1]:.4f}')
    ax.legend()

    # Right: Final node embeddings colored by predicted label
    ax = axes[1]
    # Use GCN output as 2D coordinates (take first 2 logit dims)
    _, probs_viz = forward(X2, A_hat2, W1_train, W2_train)
    # Color by true label, shape by prediction
    for cls, color, label in [(0, COLORS['primary'], 'Class 0'),
                               (1, COLORS['secondary'], 'Class 1')]:
        mask = labels2 == cls
        ax.scatter(probs_viz[mask, 0], probs_viz[mask, 1],
                   color=color, s=60, alpha=0.8, label=label)
    # Mark labeled nodes with larger markers
    ax.scatter(probs_viz[labeled_idx, 0], probs_viz[labeled_idx, 1],
               c='none', edgecolors='black', s=150, linewidths=2,
               label='Labeled (train)')
    ax.set_title('Node prediction probabilities after training')
    ax.set_xlabel('$P$(class 0)')
    ax.set_ylabel('$P$(class 1)')
    ax.legend(markerscale=1.2)

    fig.tight_layout()
    plt.show()
    print('Training curves and classification plot displayed.')

13. Architecture Comparison Summary

Tabulate key properties of the major GNN architectures.

Code cell 44

# === 13. Architecture Comparison ===

architectures = [
    ('GCN',        'Fixed (degree-norm)',  'Sum/Mean', 'MLP', 'O(m*d)', '< 1-WL', 'Transductive'),
    ('GraphSAGE',  'Fixed (mean/max)',     'Mean/Max', 'MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
    ('GAT',        'Learned (static)',     'Attention','MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
    ('GATv2',      'Learned (dynamic)',    'Attention','MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
    ('GIN',        'Fixed (none)',         'Sum',      'MLP', 'O(m*d)', '= 1-WL', 'Inductive'),
    ('GPS',        'Local MPNN + Global Attn','Hybrid','MLP', 'O((m+n^2)*d)','> 1-WL w/ PE','Inductive'),
    ('Graphormer', 'Full attention + bias','Attention','MLP', 'O(n^2*d)','> 1-WL w/ PE','Inductive'),
]

header = ['Architecture', 'Aggregation', 'Agg. fn', 'Update', 'Complexity', 'Expressiveness', 'Setting']
col_widths = [13, 22, 10, 5, 15, 16, 12]

def row_str(row):
    return ' | '.join(str(v).ljust(w) for v, w in zip(row, col_widths))

print(row_str(header))
print('-+-'.join('-'*w for w in col_widths))
for arch in architectures:
    print(row_str(arch))

print()
print('Key takeaways:')
print('  1. GIN is the only MPNN matching 1-WL expressiveness (sum agg + MLP)')
print('  2. Graph Transformers exceed 1-WL through structural PEs')
print('  3. GraphSAGE+Cluster-GCN scale to billion-node graphs; GPS does not')
print('  4. GAT vs GATv2: same complexity, GATv2 has dynamic (stronger) attention')
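
Takeaway 3 names Cluster-GCN; its core trick is simple enough to sketch. Partition the nodes, then run each training step on one partition's induced subgraph, so memory scales with cluster size rather than graph size. The sketch below uses a random partition purely for illustration; the actual Cluster-GCN paper uses METIS graph partitioning.

# === (Sketch) Cluster-GCN-style subgraph batching ===
# One GCN layer per partition's induced subgraph. Edges that cross partitions
# are dropped; accepting that approximation is what makes the method scale.
rng = np.random.default_rng(0)
n_cg, d_cg, n_parts = 60, 4, 3
A_cg = (rng.random((n_cg, n_cg)) < 0.08).astype(float)
A_cg = np.triu(A_cg, 1)
A_cg = A_cg + A_cg.T  # undirected, no self-loops yet
X_cg = rng.standard_normal((n_cg, d_cg))
W_cg = rng.standard_normal((d_cg, d_cg)) * 0.5

parts = np.array_split(rng.permutation(n_cg), n_parts)
for p_idx, part in enumerate(parts):
    A_sub = A_cg[np.ix_(part, part)]                 # induced subgraph
    A_hat_sub, _, _ = gcn_propagation_matrix(A_sub)  # self-loops added inside
    H_sub = np.tanh(A_hat_sub @ X_cg[part] @ W_cg)   # one GCN layer per batch
    print(f'Partition {p_idx}: {len(part)} nodes, '
          f'{int(A_sub.sum() // 2)} intra-edges, output {H_sub.shape}')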

Summary

This notebook has covered:

  1. GCN propagation matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ and its spectral properties (eigenvalues in $(-1,1]$, permutation equivariance)
  2. Over-smoothing as Dirichlet energy decay: $E(H^{[L]}) \to 0$ geometrically in $L$
  3. Weisfeiler-Leman test: 1-WL cannot distinguish $C_6$ from $C_3 \cup C_3$ — any MPNN without structural features faces the same limit
  4. GAT vs GATv2: static vs dynamic attention — GATv2's $a^\top \operatorname{LeakyReLU}(W[h_i \| h_j])$ creates genuine query-dependent attention
  5. GIN sum aggregation: distinguishes multisets that mean/max cannot
  6. RWSE breaks 1-WL symmetry by encoding local loop structure
  7. GraphSAGE: inductive inference on unseen nodes using learned aggregation functions
  8. DiffPool: soft cluster assignment $S = \operatorname{softmax}(\operatorname{GNN}(A,H))$, coarsened graph $A' = S^\top A S$
  9. LapPE and RWSE: positional encodings for graph transformers that capture relative structure
  10. Over-squashing: dumbbell graph shows cross-bottleneck influence decays exponentially

Companion materials: notes.md | exercises.ipynb


14. Graph Rewiring Effect on Information Flow

Compare the Fiedler value and cross-cluster influence before and after adding a rewiring edge.

Code cell 47

# === 14. Graph Rewiring ===

import numpy as np
np.random.seed(42)

def gcn_prop(A):
    n = A.shape[0]
    At = A + np.eye(n)
    d = At.sum(axis=1)
    D_inv_sqrt = np.diag(1.0/np.sqrt(d))
    return D_inv_sqrt @ At @ D_inv_sqrt

def fiedler(A):
    D = np.diag(A.sum(axis=1))
    L = D - A
    return np.sort(np.linalg.eigvalsh(L))[1]

# Dumbbell: two 5-cliques connected by one bridge edge
k = 5
n_db = 2*k
A_db = np.zeros((n_db, n_db))
for i in range(k):
    for j in range(i+1,k):
        A_db[i,j]=A_db[j,i]=1
for i in range(k, n_db):
    for j in range(i+1,n_db):
        A_db[i,j]=A_db[j,i]=1
A_db[k-1, k]=A_db[k, k-1]=1  # bridge

A_rew = A_db.copy()
A_rew[0, n_db-1]=A_rew[n_db-1,0]=1  # extra rewiring edge

lam2_orig = fiedler(A_db)
lam2_rew  = fiedler(A_rew)
inf_orig = np.linalg.matrix_power(gcn_prop(A_db), 3)[0, n_db-1]
inf_rew  = np.linalg.matrix_power(gcn_prop(A_rew), 3)[0, n_db-1]

print('Rewiring effect on dumbbell graph:')
print(f'  lambda_2 original: {lam2_orig:.6f}')
print(f'  lambda_2 rewired:  {lam2_rew:.6f}')
print(f'  Cross-cluster influence (A_hat^3)[0,9]: orig={inf_orig:.6f}, rewired={inf_rew:.6f}')
ok = inf_rew > inf_orig * 1.5
print(f"\n{'PASS' if ok else 'FAIL'} - rewiring improves cross-cluster influence")

15. MPNN Unification: GCN, SAGE, GIN Side-by-Side

Verify that all three fit the template $\mathbf{h}_v^{[l+1]} = \operatorname{UPDATE}\big(\mathbf{h}_v^{[l]}, \operatorname{AGGREGATE}(\{\mathbf{h}_u : u \in \mathcal{N}(v)\})\big)$ and differ only in aggregation and update design.

Code cell 49

# === 15. MPNN Unification ===

import numpy as np
np.random.seed(42)

def make_adj(n, edges):
    adj = {i: set() for i in range(n)}
    for u,v in edges:
        adj[u].add(v); adj[v].add(u)
    return adj

n_u = 5
edges_u = [(0,1),(1,2),(2,3),(3,4),(0,3),(1,4)]
adj_u = make_adj(n_u, edges_u)
H_u = np.random.randn(n_u, 3)
W1u = np.random.randn(3,3)*0.5
W2u = np.random.randn(6,3)*0.3

# GCN-style: mean(normalized) + linear update
def layer_gcn(H, adj):
    H_new = np.zeros_like(H)
    for v in adj:
        all_nodes = list(adj[v]) + [v]  # include self
        norms = np.array([np.sqrt(len(adj[u])+1) for u in all_nodes])
        dv = np.sqrt(len(adj[v])+1)
        m = sum(H[u]/norms[i]/dv for i,u in enumerate(all_nodes))
        H_new[v] = np.tanh(W1u.T @ m)
    return H_new

# SAGE: mean + concat update
def layer_sage(H, adj):
    H_new = np.zeros_like(H)
    for v in adj:
        nbrs = list(adj[v])
        m = H[nbrs].mean(axis=0) if nbrs else np.zeros(H.shape[1])
        H_new[v] = np.tanh(W2u.T @ np.concatenate([H[v], m]))
    return H_new

# GIN: sum + MLP update
W3u = np.random.randn(3,3)*0.3
def layer_gin(H, adj, eps=0.0):
    H_new = np.zeros_like(H)
    for v in adj:
        s = sum(H[u] for u in adj[v]) if adj[v] else np.zeros(H.shape[1])
        H_new[v] = np.tanh(W3u.T @ np.tanh(W1u.T @ ((1+eps)*H[v] + s)))
    return H_new

H_gcn2 = layer_gcn(H_u, adj_u)
H_sage2 = layer_sage(H_u, adj_u)
H_gin2  = layer_gin(H_u, adj_u)

print('One MPNN layer on 5-node graph:')
for name, H_out in [('GCN ', H_gcn2), ('SAGE', H_sage2), ('GIN ', H_gin2)]:
    print(f'  {name}: shape={H_out.shape}, norms={np.linalg.norm(H_out,axis=1).round(4)}')

print('\nAll fit MPNN template. Key differences:')
print('  GCN:  mean-normalized agg, linear update (fixed by spectral derivation)')
print('  SAGE: mean agg, concat update (inductive, scalable)')
print('  GIN:  sum agg, 2-layer MLP update (maximally expressive, = 1-WL)')
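
As a sanity check, the per-node loop in layer_gcn should reproduce the vectorized matrix form tanh(A_hat @ H @ W) from Section 2; they are the same operator written node-by-node versus all at once:

# === (Sketch) Per-node loop == matrix form ===
A_u = np.zeros((n_u, n_u))
for u, v in edges_u:
    A_u[u, v] = A_u[v, u] = 1
At_u = A_u + np.eye(n_u)                      # self-loops
Dis_u = np.diag(1.0 / np.sqrt(At_u.sum(axis=1)))
A_hat_u = Dis_u @ At_u @ Dis_u                # symmetric normalization

H_matrix = np.tanh(A_hat_u @ H_u @ W1u)
print(f'Max |loop - matrix|: {np.abs(layer_gcn(H_u, adj_u) - H_matrix).max():.2e}')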

Code cell 50

# === 15.1 Expressiveness: Which Aggregator Distinguishes Most Graphs? ===

import numpy as np
np.random.seed(42)

# Three informative multisets
tests = [
    ('M={1,1} vs M={1,1,1}',   [1.0,1.0], [1.0,1.0,1.0]),
    ('M={1,2} vs M={2}',       [1.0,2.0], [2.0]),
    ('M={1,1,2} vs M={1,2,2}', [1.0,1.0,2.0], [1.0,2.0,2.0]),
]

print('Aggregation expressiveness on multisets:')
print(f'{"Test":<30} {"Sum?":<6} {"Mean?":<6} {"Max?"}')
print('-'*55)
for name, m1, m2 in tests:
    m1, m2 = np.array(m1), np.array(m2)
    s = not np.isclose(m1.sum(), m2.sum())
    me = not np.isclose(m1.mean(), m2.mean())
    mx = not np.isclose(m1.max(), m2.max())
    print(f'{name:<30} {str(s):<6} {str(me):<6} {str(mx)}')

print()
print('PASS: sum is strictly more powerful than mean and max')

16. GPS Layer: Local MPNN + Global Attention

Implement a simplified GPS layer combining neighborhood aggregation with full pairwise attention.

Code cell 52

# === 16. Simplified GPS Layer ===

import numpy as np
np.random.seed(42)

def softmax_rows(X):
    e = np.exp(X - X.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def scaled_dot_product_attention(H, W_Q, W_K, W_V):
    Q = H @ W_Q; K = H @ W_K; V = H @ W_V
    dk = Q.shape[1]
    scores = Q @ K.T / np.sqrt(dk)
    alpha = softmax_rows(scores)
    return alpha @ V, alpha

def local_mpnn(H, adj):
    """Simple mean aggregation."""
    H_new = np.zeros_like(H)
    for v in adj:
        nbrs = list(adj[v])
        m = H[nbrs].mean(axis=0) if nbrs else np.zeros(H.shape[1])
        H_new[v] = np.tanh(0.5*(H[v] + m))
    return H_new

def layer_norm(H, eps=1e-6):
    mu = H.mean(axis=1, keepdims=True)
    sigma = H.std(axis=1, keepdims=True) + eps
    return (H - mu) / sigma

n_gps = 8; d_gps = 6
edges_gps = [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(0,4),(2,6)]
adj_gps = {i: set() for i in range(n_gps)}
for u,v in edges_gps:
    adj_gps[u].add(v); adj_gps[v].add(u)

H_gps = np.random.randn(n_gps, d_gps)
W_Q = np.random.randn(d_gps, d_gps) * 0.3
W_K = np.random.randn(d_gps, d_gps) * 0.3
W_V = np.random.randn(d_gps, d_gps) * 0.3

# GPS layer: H' = LayerNorm(H + MPNN(H,A) + Attention(H))
mpnn_out = local_mpnn(H_gps, adj_gps)
attn_out, attn_weights = scaled_dot_product_attention(H_gps, W_Q, W_K, W_V)
H_gps_out = layer_norm(H_gps + mpnn_out + attn_out)

print(f'GPS layer: {n_gps} nodes, d={d_gps}')
print(f'MPNN output norm: {np.linalg.norm(mpnn_out, axis=1).round(4)}')
print(f'Attention output norm: {np.linalg.norm(attn_out, axis=1).round(4)}')
print(f'GPS output norm: {np.linalg.norm(H_gps_out, axis=1).round(4)}')
print(f'\nAttention matrix shape: {attn_weights.shape}')
print(f'Attention row sums (should be 1): {attn_weights.sum(axis=1).round(4)}')

ok = np.allclose(attn_weights.sum(axis=1), 1.0, atol=1e-6)
print(f"\n{'PASS' if ok else 'FAIL'} - GPS attention rows sum to 1")
print('\nGPS combines: local structure (MPNN) + global context (Attention)')
