"""Measure the full inter-layer geometric relationship between per-head metrics.

For attention head h of layer L, let W_Q^h and W_K^h be the (head_dim, hidden)
row blocks of q_proj / k_proj and define the bilinear form

    g_L^h = (W_K^h)^T W_Q^h  in R^{hidden x hidden}      (rank <= head_dim).

For every (L, L', h) we compute the Frobenius inner product

    <g_L^h, g_L'^h> = tr((g_L^h)^T g_L'^h) = tr(A B^T) = sum(A * B),
    A = W_K_L W_K_L'^T,   B = W_Q_L W_Q_L'^T      (both head_dim x head_dim),

so no hidden x hidden operator is ever materialized.

Outputs
-------
JSON file (``--out``):
    gram[L][L'][h]  <g_L^h, g_L'^h>
    fro_sq[L][h]    ||g_L^h||_F^2 (agrees with gram[L][L][h] up to rounding)
From these every layer-pair comparison is derivable without saving the full
operators.

Torch file (``<out without extension>-eigdirs.pt``):
    eig_dirs[L, h, k, :]  top-k right singular vectors of g_L^h: orthonormal
                          directions in R^hidden (for non-zero singular values
                          they lie in row-space(W_Q), i.e. the Q-side
                          directions), for downstream subspace-overlap analysis.
    sym_eigs[L, h, :]     the 2*head_dim signed eigenvalues of g_L^h + (g_L^h)^T
                          (twice the symmetric part), sorted by |value|
                          descending; their signs give the curvature sign of
                          each eigen-direction.

Only the q_proj / k_proj weight matrices are read: any projection bias, rotary
embedding, or per-head normalization applied by the model is not reflected.
"""
import argparse
import json
import os

import numpy as np
import torch


def expand_kv_heads(Wk: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Give every query head its matching key head (grouped-query attention).

    Args:
        Wk: (num_kv_heads, head_dim, hidden) per-KV-head key projection rows.
        num_heads: number of query heads.

    Returns:
        A new (num_heads, head_dim, hidden) tensor whose entry h is key head
        ``(h * num_kv_heads) // num_heads``. When num_heads is a multiple of
        num_kv_heads this equals ``h // (num_heads // num_kv_heads)``, i.e.
        each KV head is repeated consecutively.
    """
    num_kv_heads = Wk.shape[0]
    kv_index = torch.arange(num_heads, device=Wk.device) * num_kv_heads // num_heads
    return Wk.index_select(0, kv_index)


def head_spectrum(Wq: torch.Tensor, Wk: torch.Tensor, topk: int):
    """Spectral summary of g^h = (W_K^h)^T W_Q^h for a batch of heads.

    Method: with X^T = [W_Q^T | W_K^T] = Q R (thin QR, Q has orthonormal
    columns) we get W_Q^T = Q R_Q and W_K^T = Q R_K, hence
    g = Q C Q^T with the small core C = R_K R_Q^T. Therefore
      * singular values of g are those of C, and its right singular vectors
        are the right singular vectors of C mapped through Q;
      * ||g||_F = ||C||_F;
      * the non-zero eigenvalues of g + g^T are those of the symmetric matrix
        C + C^T, so ``eigvalsh`` applies and the results are exactly real.

    Args:
        Wq: (heads, head_dim, hidden) query projection rows.
        Wk: (heads, head_dim, hidden) key projection rows (already expanded
            to one key head per query head).
        topk: number of leading right singular vectors to return.

    Returns:
        dirs:     (heads, topk, hidden) leading right singular vectors of g^h,
                  orthonormal rows ordered by decreasing singular value.
        fro_sq:   (heads,) float64, ||g^h||_F^2.
        sym_eigs: (heads, 2*head_dim) float64 eigenvalues of g^h + (g^h)^T,
                  sorted by |value| descending (zero-padded when
                  hidden < 2*head_dim).
    """
    head_dim = Wq.shape[-2]
    stacked_t = torch.cat([Wq.transpose(-1, -2), Wk.transpose(-1, -2)], dim=-1)
    Q, R = torch.linalg.qr(stacked_t, mode="reduced")  # Q: (heads, hidden, r)
    R_q = R[..., :head_dim]  # W_Q^T = Q R_q
    R_k = R[..., head_dim:]  # W_K^T = Q R_k
    core = R_k @ R_q.transpose(-1, -2)  # g = Q core Q^T, (heads, r, r)

    _, S, Vh = torch.linalg.svd(core, full_matrices=False)
    dirs = Vh[..., :topk, :] @ Q.transpose(-1, -2)
    fro_sq = (S.double() ** 2).sum(dim=-1)

    ev = torch.linalg.eigvalsh(core + core.transpose(-1, -2)).double()
    order = torch.argsort(ev.abs(), dim=-1, descending=True)
    ev = torch.gather(ev, -1, order)
    # r = min(hidden, 2*head_dim); the missing eigenvalues of the
    # 2*head_dim-sized problem are exactly zero.
    missing = 2 * head_dim - ev.shape[-1]
    if missing > 0:
        ev = torch.cat([ev, ev.new_zeros((*ev.shape[:-1], missing))], dim=-1)
    return dirs, fro_sq, ev


def frobenius_inner(
    Wq_a: torch.Tensor, Wk_a: torch.Tensor, Wq_b: torch.Tensor, Wk_b: torch.Tensor
) -> torch.Tensor:
    """Per-head Frobenius inner product <g_a^h, g_b^h>.

    Uses tr(g_a^T g_b) = tr(A B^T) = sum(A * B) with A = W_K_a W_K_b^T and
    B = W_Q_a W_Q_b^T, both head_dim x head_dim.

    Args:
        Wq_a, Wk_a, Wq_b, Wk_b: (heads, head_dim, hidden) projection rows.

    Returns:
        (heads,) tensor in the inputs' dtype and device.
    """
    A = Wk_a @ Wk_b.transpose(-1, -2)
    B = Wq_a @ Wq_b.transpose(-1, -2)
    return (A * B).sum(dim=(-2, -1))


@torch.no_grad()
def measure(model_name: str, out_path: str, topk: int = 8, dtype=torch.bfloat16, device: str = "cuda"):
    """Load ``model_name``, compute per-head Gram / spectral statistics, save them.

    Args:
        model_name: model id or path passed to ``AutoModelForCausalLM.from_pretrained``.
        out_path: destination of the JSON summary; the tensor file is written
            to ``<out_path without extension>-eigdirs.pt``.
        topk: right singular vectors kept per head (clamped to head_dim).
        dtype: dtype used to load the model; weights are upcast to float32
            before any linear algebra.
        device: single torch device (e.g. "cuda", "cpu") used as device_map
            and for all computation.
    """
    # Imported lazily so the linear-algebra helpers above work without transformers.
    from transformers import AutoModelForCausalLM

    print(f"Loading {model_name} ...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model.eval()
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    # Some configs define these attributes but leave them as None.
    head_dim = getattr(cfg, "head_dim", None) or hidden // num_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim}", flush=True)

    # Per layer: W_Q, W_K as (num_heads, head_dim, hidden) float32, with key
    # heads repeated so every query head has a matching key block.
    Wq_list = []
    Wk_list = []
    for layer in model.model.layers:
        attn = layer.self_attn
        Wq = attn.q_proj.weight.detach().to(device=device, dtype=torch.float32)
        Wk = attn.k_proj.weight.detach().to(device=device, dtype=torch.float32)
        Wq_list.append(Wq.view(num_heads, head_dim, hidden))
        Wk_list.append(expand_kv_heads(Wk.view(num_kv_heads, head_dim, hidden), num_heads))
    print(f" loaded weights: {num_layers} layers", flush=True)
    # Only the extracted weight tensors are needed from here on.
    del model

    topk_eff = min(topk, head_dim, hidden)
    eig_dirs = torch.zeros(num_layers, num_heads, topk_eff, hidden, dtype=torch.float32)
    fro_sq = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    sym_eigs = torch.zeros(num_layers, num_heads, 2 * head_dim, dtype=torch.float64)  # signed
    for L in range(num_layers):
        dirs, layer_fro_sq, layer_eigs = head_spectrum(Wq_list[L], Wk_list[L], topk_eff)
        eig_dirs[L] = dirs.cpu()
        fro_sq[L] = layer_fro_sq.cpu()
        sym_eigs[L] = layer_eigs.cpu()
        if L % 8 == 0:
            print(f" eigdecomp L={L}", flush=True)

    # gram[L, L', h] = <g_L^h, g_L'^h>; symmetric in (L, L'), so fill both halves.
    gram = torch.zeros(num_layers, num_layers, num_heads, dtype=torch.float64)
    for L in range(num_layers):
        for Lp in range(L, num_layers):
            v = frobenius_inner(Wq_list[L], Wk_list[L], Wq_list[Lp], Wk_list[Lp]).double().cpu()
            gram[L, Lp] = v
            gram[Lp, L] = v
        if L % 4 == 0:
            print(f" gram row L={L}", flush=True)

    out = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "topk": topk_eff,
        "gram": gram.tolist(),
        "fro_sq": fro_sq.tolist(),
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f)
    # splitext (not str.replace) so a path without ".json" never overwrites the JSON.
    eig_path = os.path.splitext(out_path)[0] + "-eigdirs.pt"
    torch.save({"eig_dirs": eig_dirs, "sym_eigs": sym_eigs}, eig_path)
    print(f"Wrote {out_path} and {eig_path}", flush=True)


def main():
    """Command-line entry point: parse arguments and run ``measure``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--out", default="/tmp/sa-grams.json")
    ap.add_argument("--topk", type=int, default=8)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()
    measure(args.model, args.out, topk=args.topk, device=args.device)


if __name__ == "__main__":
    main()