# consciousness/sa-schedule-measure-grams.py

"""Measure the full inter-layer geometric relationship between per-head metrics.
For each (L, L', h) triple, compute the Frobenius inner product
    <g_L^h, g_L'^h> = tr((g_L^h)^T g_L'^h)
where g^h = (W_K^h)^T W_Q^h ∈ R^{hidden × hidden} (rank ≤ head_dim).
Using the head_dim × head_dim shortcut:
    <g_L, g_L'> = tr(A B^T)  with  A = W_K_L W_K_L'^T,  B = W_Q_L W_Q_L'^T.
Output: gram[L, L', h] and fro_sq[L, h]. From these every layer-pair comparison
is derivable without saving the full operators.
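Also saves, to a sibling "-eigdirs.pt" file, the top-k principal directions per
head (right singular vectors of g, which lie in the row space of W_Q, i.e. the
Q-side directions) so subspace overlap across layers can be computed
downstream, together with the signed eigenvalues of g + g^T for each head.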
"""
import argparse
import json
import os
import numpy as np
import torch
from transformers import AutoModelForCausalLM
@torch.no_grad()
def measure(model_name: str, out_path: str, topk: int = 8,
dtype=torch.bfloat16):
print(f"Loading {model_name} ...", flush=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="cuda",
trust_remote_code=True,
attn_implementation="eager",
)
model.eval()
cfg = model.config
num_layers = cfg.num_hidden_layers
num_heads = cfg.num_attention_heads
hidden = cfg.hidden_size
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim}", flush=True)
# Collect W_Q, W_K per layer as (num_heads, head_dim, hidden) on GPU float32.
Wq_list = []
Wk_list = []
for L, layer in enumerate(model.model.layers):
attn = layer.self_attn
Wq = attn.q_proj.weight.detach().to(torch.float32) # (nh*hd, hidden)
Wk = attn.k_proj.weight.detach().to(torch.float32) # (nkv*hd, hidden)
Wq = Wq.view(num_heads, head_dim, hidden)
# Repeat kv heads so every query head has a matching k-row
Wk = Wk.view(num_kv_heads, head_dim, hidden)
# Broadcast to num_heads via (h // (num_heads // num_kv_heads))? safer: mapping
Wk_full = torch.zeros(num_heads, head_dim, hidden,
device=Wk.device, dtype=Wk.dtype)
for h in range(num_heads):
kv_h = (h * num_kv_heads) // num_heads
Wk_full[h] = Wk[kv_h]
Wq_list.append(Wq)
Wk_list.append(Wk_full)
print(f" loaded weights: {num_layers} layers", flush=True)
# Per-head top-k right singular vectors of g^h = W_K^T W_Q (hidden, hidden).
# The non-zero right singular vectors of g lie in row-space(W_Q) ⊂ R^hidden.
# For subspace comparison we need vectors in hidden-space.
#
# We also need SIGNED eigenvalues of the symmetric part (g + g^T)/2 to
# determine curvature signs per eigen-direction. Since g has rank ≤ d_h,
# (g + g^T) has rank ≤ 2 d_h, and we can compute its signed non-zero
# eigenvalues via the Jordan-Wielandt-style trick:
# eigs(X^T J X) = eigs(J X X^T) for X = [W_Q; W_K], J = [[0, I], [I, 0]].
# The resulting 2d_h × 2d_h matrix gives us all non-zero eigenvalues of
# (g + g^T) cheaply.
topk_eff = min(topk, head_dim)
eig_dirs = torch.zeros(num_layers, num_heads, topk_eff, hidden,
dtype=torch.float32)
fro_sq = torch.zeros(num_layers, num_heads, dtype=torch.float64)
sym_eigs = torch.zeros(num_layers, num_heads, 2 * head_dim,
dtype=torch.float64) # signed
for L in range(num_layers):
for h in range(num_heads):
Wq = Wq_list[L][h] # (hd, hidden)
Wk = Wk_list[L][h] # (hd, hidden)
small = Wk @ Wq.T # (hd, hd)
U, S, Vh = torch.linalg.svd(small, full_matrices=False)
dirs = Vh @ Wq # (hd, hidden)
dirs = dirs / dirs.norm(dim=-1, keepdim=True).clamp_min(1e-12)
eig_dirs[L, h] = dirs[:topk_eff].cpu()
fro_sq[L, h] = float((S * S).sum())
# Signed eigenvalues of (g + g^T) via 2d_h × 2d_h matrix
# J (X X^T) where X = [Wq; Wk] (stacked)
XXT = torch.zeros(2 * head_dim, 2 * head_dim,
device=Wq.device, dtype=Wq.dtype)
XXT[:head_dim, :head_dim] = Wq @ Wq.T
XXT[:head_dim, head_dim:] = Wq @ Wk.T
XXT[head_dim:, :head_dim] = Wk @ Wq.T
XXT[head_dim:, head_dim:] = Wk @ Wk.T
# J matrix is off-diagonal block identity
J = torch.zeros(2 * head_dim, 2 * head_dim,
device=Wq.device, dtype=Wq.dtype)
J[:head_dim, head_dim:] = torch.eye(head_dim,
device=Wq.device, dtype=Wq.dtype)
J[head_dim:, :head_dim] = torch.eye(head_dim,
device=Wq.device, dtype=Wq.dtype)
M = J @ XXT
# M is not symmetric, but its non-zero eigenvalues are those of
# (g + g^T)/2 times 2 → real (since (g + g^T) is symmetric).
# Use general eigvals; imag parts should be near zero up to
# numerical noise.
ev = torch.linalg.eigvals(M)
ev_real = ev.real.cpu().double()
# sort by magnitude descending so top eigenvalues come first
order = torch.argsort(ev_real.abs(), descending=True)
sym_eigs[L, h] = ev_real[order]
if L % 8 == 0:
print(f" eigdecomp L={L}", flush=True)
# Gram matrix: gram[L, L', h] = <g_L^h, g_L'^h>.
# Using A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T, <g, g'> = tr(A B^T) = sum(A * B).
gram = torch.zeros(num_layers, num_layers, num_heads, dtype=torch.float64)
for L in range(num_layers):
for Lp in range(L, num_layers):
for h in range(num_heads):
Wq_L = Wq_list[L][h]
Wk_L = Wk_list[L][h]
Wq_Lp = Wq_list[Lp][h]
Wk_Lp = Wk_list[Lp][h]
A = Wk_L @ Wk_Lp.T # (hd, hd)
B = Wq_L @ Wq_Lp.T # (hd, hd)
v = float((A * B).sum())
gram[L, Lp, h] = v
gram[Lp, L, h] = v
if L % 4 == 0:
print(f" gram row L={L}", flush=True)
# Save
out = {
"model": model_name,
"num_layers": num_layers,
"num_heads": num_heads,
"head_dim": head_dim,
"hidden_size": hidden,
"topk": topk_eff,
"gram": gram.tolist(),
"fro_sq": fro_sq.tolist(),
}
with open(out_path, "w") as f:
json.dump(out, f)
torch.save({"eig_dirs": eig_dirs, "sym_eigs": sym_eigs},
out_path.replace(".json", "-eigdirs.pt"))
print(f"Wrote {out_path} and {out_path.replace('.json', '-eigdirs.pt')}",
flush=True)
def main():
    """Command-line entry point: read the options and hand them to ``measure``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-4B")
    parser.add_argument("--out", default="/tmp/sa-grams.json")
    parser.add_argument("--topk", type=int, default=8)
    opts = parser.parse_args()
    measure(opts.model, opts.out, topk=opts.topk)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()