forked from kent/consciousness
168 lines
7.1 KiB
Python
168 lines
7.1 KiB
Python
|
|
"""Measure the full inter-layer geometric relationship between per-head metrics.
|
|||
|
|
|
|||
|
|
For each (L, L', h) pair, compute the Frobenius inner product
|
|||
|
|
<g_L^h, g_L'^h> = tr(g_L^h^T g_L'^h)
|
|||
|
|
where g^h = W_K^h^T W_Q^h ∈ R^{hidden × hidden} (rank ≤ head_dim).
|
|||
|
|
|
|||
|
|
Using the head_dim × head_dim shortcut:
|
|||
|
|
<g_L, g_L'> = tr(A B^T) with A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T.
|
|||
|
|
|
|||
|
|
Output: gram[L, L', h] and fro_sq[L, h]. From these every layer-pair comparison
|
|||
|
|
is derivable without saving the full operators.
|
|||
|
|
|
|||
|
|
Also saves top-k principal directions per head (as right singular vectors of g,
|
|||
|
|
which are the Q-side eigen-directions) so subspace overlap across layers can be
|
|||
|
|
computed downstream.
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from transformers import AutoModelForCausalLM
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.no_grad()
def measure(model_name: str, out_path: str, topk: int = 8,
            dtype=torch.bfloat16):
    """Compute per-head inter-layer Gram statistics of g = W_K^T W_Q and save them.

    Args:
        model_name: Checkpoint name/path handed to
            ``AutoModelForCausalLM.from_pretrained``.
        out_path: Destination of the JSON summary (``gram``, ``fro_sq`` and the
            model dimensions). A companion ``*-eigdirs.pt`` file holding
            ``eig_dirs`` and ``sym_eigs`` is written next to it.
        topk: Number of leading right singular directions kept per head
            (capped at ``head_dim``).
        dtype: dtype used to load the checkpoint; the projection weights are
            upcast to float32 before any linear algebra.

    Written quantities:
        gram[L, L', h]  = <g_L^h, g_L'^h>_F
        fro_sq[L, h]    = ||g_L^h||_F^2 (equals gram[L, L, h] up to rounding)
        eig_dirs[L, h]  = top-k right singular vectors of g_L^h, unit norm
        sym_eigs[L, h]  = signed non-zero eigenvalues of g + g^T, sorted by
                          decreasing magnitude

    NOTE(review): only ``q_proj.weight`` / ``k_proj.weight`` are read. Biases,
    per-head q/k normalisation and rotary embeddings, if the architecture has
    them, are not reflected in g.
    """
    print(f"Loading {model_name} ...", flush=True)
    # NOTE(security): trust_remote_code=True executes Python shipped with the
    # checkpoint. Only point this script at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model.eval()
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    # Some configs define these attributes but leave them as None.
    head_dim = getattr(cfg, "head_dim", None) or hidden // num_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim}", flush=True)

    # Collect W_Q, W_K per layer as (num_heads, head_dim, hidden) float32.
    # Under grouped-query attention several query heads share one key head:
    # query head h reads key head (h * num_kv_heads) // num_heads.
    kv_index = (torch.arange(num_heads) * num_kv_heads) // num_heads
    Wq_list = []
    Wk_list = []
    for layer in model.model.layers:
        attn = layer.self_attn
        Wq = attn.q_proj.weight.detach().to(torch.float32)  # (nh*hd, hidden)
        Wk = attn.k_proj.weight.detach().to(torch.float32)  # (nkv*hd, hidden)
        Wq_list.append(Wq.reshape(num_heads, head_dim, hidden))
        Wk = Wk.reshape(num_kv_heads, head_dim, hidden)
        Wk_list.append(Wk[kv_index.to(Wk.device)])
    print(f" loaded weights: {num_layers} layers", flush=True)

    # Per-head spectral data, one batched call per layer.
    topk_eff = min(topk, head_dim, hidden)
    eig_dirs = torch.zeros(num_layers, num_heads, topk_eff, hidden,
                           dtype=torch.float32)
    fro_sq = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    sym_eigs = torch.zeros(num_layers, num_heads, 2 * head_dim,
                           dtype=torch.float64)  # signed
    for L in range(num_layers):
        dirs, fro, eigs = _layer_geometry(Wq_list[L], Wk_list[L], topk_eff)
        eig_dirs[L] = dirs
        fro_sq[L] = fro
        sym_eigs[L] = eigs
        if L % 8 == 0:
            print(f" eigdecomp L={L}", flush=True)

    gram = _layer_gram(Wq_list, Wk_list, num_heads)

    # Save
    out = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "topk": topk_eff,
        "gram": gram.tolist(),
        "fro_sq": fro_sq.tolist(),
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f)
    # Derive the tensor file name from the extension only. A plain
    # str.replace would return out_path unchanged when it has no ".json"
    # (torch.save would then overwrite the JSON just written) and would also
    # rewrite ".json" occurring inside directory names.
    root, ext = os.path.splitext(out_path)
    eig_path = (root if ext == ".json" else out_path) + "-eigdirs.pt"
    torch.save({"eig_dirs": eig_dirs, "sym_eigs": sym_eigs}, eig_path)
    print(f"Wrote {out_path} and {eig_path}", flush=True)


def _layer_geometry(Wq, Wk, topk):
    """Spectral summary of g^h = W_K^h^T W_Q^h for every head of one layer.

    Args:
        Wq, Wk: (num_heads, head_dim, hidden) float32 projection weights.
        topk: number of leading right singular vectors to return per head.

    Returns (all on CPU):
        dirs: (num_heads, topk, hidden) unit-norm right singular vectors of g.
        fro:  (num_heads,) float64, ||g||_F^2.
        eigs: (num_heads, 2*head_dim) float64 signed eigenvalues of g + g^T,
              sorted by decreasing magnitude.
    """
    head_dim = Wq.shape[-2]

    # Thin QR of the transposed weights: W_Q^T = Q_q R_q, W_K^T = Q_k R_k, so
    #   g = W_K^T W_Q = Q_k (R_k R_q^T) Q_q^T.
    # Q_k and Q_q have orthonormal columns, hence the SVD of the small core
    # R_k R_q^T yields the singular values of g, and Vh @ Q_q^T its right
    # singular vectors. (The SVD of W_K W_Q^T does NOT: that matrix has a
    # different spectrum unless the weight rows are orthonormal.)
    Qq, Rq = torch.linalg.qr(Wq.transpose(-1, -2))
    _, Rk = torch.linalg.qr(Wk.transpose(-1, -2), mode="r")
    core = Rk @ Rq.transpose(-1, -2)
    _, S, Vh = torch.linalg.svd(core, full_matrices=False)
    dirs = Vh[:, :topk] @ Qq.transpose(-1, -2)
    fro = (S.double() ** 2).sum(dim=-1)

    # Signed non-zero eigenvalues of g + g^T = X^T J X with X = [W_Q; W_K] and
    # J = [[0, I], [I, 0]]: they coincide with the eigenvalues of J (X X^T),
    # a 2d_h × 2d_h matrix. Left-multiplying by J just swaps the row blocks.
    X = torch.cat([Wq, Wk], dim=-2)
    XXT = X @ X.transpose(-1, -2)
    M = torch.cat([XXT[:, head_dim:], XXT[:, :head_dim]], dim=-2)
    # M is not symmetric, but its spectrum is real up to numerical noise
    # because g + g^T is symmetric; keep only the real parts.
    ev = torch.linalg.eigvals(M).real.double()
    order = torch.argsort(ev.abs(), dim=-1, descending=True)
    eigs = torch.gather(ev, -1, order)
    return dirs.cpu(), fro.cpu(), eigs.cpu()


def _layer_gram(Wq_list, Wk_list, num_heads):
    """Return gram[L, L', h] = <g_L^h, g_L'^h>_F as a CPU float64 tensor.

    Uses <g, g'> = tr(A B^T) = sum(A * B) with A = W_K_L W_K_L'^T and
    B = W_Q_L W_Q_L'^T, batched over heads; the matrix is symmetric in
    (L, L'), so only the upper triangle is computed.
    """
    num_layers = len(Wq_list)
    gram = torch.zeros(num_layers, num_layers, num_heads, dtype=torch.float64)
    for L in range(num_layers):
        for Lp in range(L, num_layers):
            A = Wk_list[L] @ Wk_list[Lp].transpose(-1, -2)  # (H, hd, hd)
            B = Wq_list[L] @ Wq_list[Lp].transpose(-1, -2)  # (H, hd, hd)
            row = (A * B).double().sum(dim=(-2, -1)).cpu()
            gram[L, Lp] = row
            gram[Lp, L] = row
        if L % 4 == 0:
            print(f" gram row L={L}", flush=True)
    return gram
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Parse the command-line options and hand them to ``measure``."""
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"default": "Qwen/Qwen3-4B"}),
        ("--out", {"default": "/tmp/sa-grams.json"}),
        ("--topk", {"type": int, "default": 8}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    cli = parser.parse_args()
    measure(cli.model, cli.out, topk=cli.topk)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()
|