forked from kent/consciousness
200 lines
7.6 KiB
Python
200 lines
7.6 KiB
Python
|
|
"""After applying Procrustes alignment to remove known gauge freedoms
|
|||
|
|
(per-head d_h rotation tying Q/K/V/O, per-layer d_ff rotation tying
|
|||
|
|
gate/up/down), measure per-family cos-sim between adjacent layers across
|
|||
|
|
the whole network.
|
|||
|
|
|
|||
|
|
Runs Procrustes SVDs on GPU for speed.
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from transformers import AutoModelForCausalLM
|
|||
|
|
|
|||
|
|
|
|||
|
|
def procrustes_gpu(M):
    """Return the orthogonal matrix R that maximizes tr(R M).

    Write the SVD as M = U Σ V^T. Then tr(R M) = tr((V^T R U) Σ), and since
    V^T R U is orthogonal its diagonal entries are at most 1, so the trace is
    largest when V^T R U = I, i.e. R = V U^T. The maximum value equals the
    sum of the singular values of M.

    Note the transpose: U V^T (the polar factor of M) maximizes tr(R^T M),
    not tr(R M). Callers here build M as sum(X_a X_b^T) and then apply R on
    the left of X_a, which requires the tr(R M) maximizer.

    Args:
        M: square matrix (or batch of square matrices in the last two
            dimensions) on any device.

    Returns:
        Tensor of the same shape, dtype and device as M holding R = V U^T.
    """
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    # (U @ Vh)^T == Vh^T @ U^T == V U^T
    return (U @ Vh).transpose(-2, -1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def frob_gpu(x):
    """Return the Frobenius (flattened 2-) norm of tensor *x* as a Python float."""
    norm = torch.linalg.norm(x)
    # .item() pulls the 0-d result back to the host as a Python scalar.
    return float(norm.item())
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_fro_gpu(x, eps=1e-12):
    """Scale *x* to unit Frobenius norm.

    The divisor is floored at *eps*, so an all-zero input comes back as zeros
    rather than NaNs.
    """
    denom = torch.clamp(torch.linalg.norm(x), min=eps)
    return x / denom
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.no_grad()
def main():
    """Measure adjacent-layer weight similarity after gauge alignment.

    Steps:
      1. Parse CLI options and load the model's weights on CPU in float32.
      2. For every requested pair of layers (L, L+1):
         * per attention head, fit one head_dim x head_dim orthogonal matrix
           shared by Q/K/V/O and record the post-alignment cosine per family
           (averaged over heads);
         * fit one d_ff x d_ff orthogonal matrix shared by gate/up/down and
           record the post-alignment cosine per family.
      3. Print a per-pair table and a per-family summary, then write the
         results as JSON to ``--out``.

    ``--pairs`` is validated before any alignment work starts: entries must
    be integers in ``[0, num_layers - 2]``; blank entries are ignored and
    duplicates are collapsed. Invalid input terminates via ``ap.error``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--out", default="/tmp/sa-aligned-variation.json")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--pairs", default="",
                    help="Comma-separated list of L indices to run pair (L, L+1) for. "
                         "Empty = all pairs. E.g. '0,20,30,38,46,52,57' samples phases.")
    args = ap.parse_args()

    # Parse --pairs before the (slow) model load so malformed input fails fast.
    requested_pairs = None
    if args.pairs:
        try:
            requested_pairs = [int(tok) for tok in args.pairs.split(",") if tok.strip()]
        except ValueError:
            ap.error(f"--pairs must be a comma-separated list of integers, got {args.pairs!r}")

    dev = torch.device(args.device)
    print(f"Loading {args.model} ...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    if requested_pairs is None:
        pair_L_list = list(range(num_layers - 1))
    else:
        # A pair (L, L+1) needs both layers to exist. Negative L would be
        # accepted by Python indexing but would compare the wrong layers.
        bad = [p for p in requested_pairs if not 0 <= p < num_layers - 1]
        if bad:
            ap.error(f"--pairs indices must be in [0, {num_layers - 2}]; got {bad}")
        # Drop duplicates while preserving the requested order.
        pair_L_list = list(dict.fromkeys(requested_pairs))

    # Collect per-layer weights (kept on CPU; moved to `dev` pair by pair).
    weight_names = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"]
    layers = []
    for layer in model.model.layers[:num_layers]:
        attn = layer.self_attn
        mlp = layer.mlp
        layers.append({
            "q_proj": attn.q_proj.weight.detach().float(),
            "k_proj": attn.k_proj.weight.detach().float(),
            "v_proj": attn.v_proj.weight.detach().float(),
            "o_proj": attn.o_proj.weight.detach().float(),
            "gate_proj": mlp.gate_proj.weight.detach().float(),
            "up_proj": mlp.up_proj.weight.detach().float(),
            "down_proj": mlp.down_proj.weight.detach().float(),
        })
    del model

    # family -> {L: aligned cosine between layer L and layer L+1}
    aligned_cos = {fam: {} for fam in weight_names}

    for idx in pair_L_list:
        A = layers[idx]
        B = layers[idx + 1]

        # -------- Per-head attention alignment (d_h x d_h) --------
        Qa = A["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
        Qb = B["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
        Ka = A["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
        Kb = B["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
        Va = A["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
        Vb = B["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
        # o_proj is (hidden, num_heads*head_dim); split per head so that
        # Oa/Ob are (num_heads, hidden, head_dim).
        Oa = A["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()
        Ob = B["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()

        q_cos = []
        k_cos = []
        v_cos = []
        o_cos = []
        for h in range(num_heads):
            # Map the query head to the key/value head it shares (equals h
            # when num_kv_heads == num_heads).
            kv_h = (h * num_kv_heads) // num_heads
            qa = normalize_fro_gpu(Qa[h])
            qb = normalize_fro_gpu(Qb[h])
            ka = normalize_fro_gpu(Ka[kv_h])
            kb = normalize_fro_gpu(Kb[kv_h])
            va = normalize_fro_gpu(Va[kv_h])
            vb = normalize_fro_gpu(Vb[kv_h])
            oa = normalize_fro_gpu(Oa[h])
            ob = normalize_fro_gpu(Ob[h])

            # Joint alignment: we want one R with R qa ~ qb (etc.), i.e.
            # minimize sum ||R X_a - X_b||^2, which is equivalent to
            # max tr(R M) with M = qa qb^T + ka kb^T + va vb^T + oa^T ob.
            M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
            R = procrustes_gpu(M)

            # Inputs are unit-normalized, so the Frobenius inner product
            # <R qa, qb> = tr(R qa qb^T) is the cosine similarity.
            q_cos.append(float(torch.sum(R @ qa * qb).item()))
            k_cos.append(float(torch.sum(R @ ka * kb).item()))
            v_cos.append(float(torch.sum(R @ va * vb).item()))
            # For O the rotation acts on the right: cos = <oa R^T, ob>.
            o_cos.append(float(torch.sum(oa @ R.T * ob).item()))

        aligned_cos["q_proj"][idx] = float(np.mean(q_cos))
        aligned_cos["k_proj"][idx] = float(np.mean(k_cos))
        aligned_cos["v_proj"][idx] = float(np.mean(v_cos))
        aligned_cos["o_proj"][idx] = float(np.mean(o_cos))

        # -------- d_ff x d_ff alignment for gate/up/down --------
        ga = normalize_fro_gpu(A["gate_proj"].to(dev))
        gb = normalize_fro_gpu(B["gate_proj"].to(dev))
        ua = normalize_fro_gpu(A["up_proj"].to(dev))
        ub = normalize_fro_gpu(B["up_proj"].to(dev))
        da = normalize_fro_gpu(A["down_proj"].to(dev))  # (hidden, d_ff)
        db = normalize_fro_gpu(B["down_proj"].to(dev))

        # ga, gb, ua, ub are (d_ff, hidden); da, db are (hidden, d_ff).
        # Cross-correlation: M = ga gb^T + ua ub^T + da^T db  (d_ff x d_ff)
        M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
        S = procrustes_gpu(M_ff)

        aligned_cos["gate_proj"][idx] = float(torch.sum(S @ ga * gb).item())
        aligned_cos["up_proj"][idx] = float(torch.sum(S @ ua * ub).item())
        aligned_cos["down_proj"][idx] = float(torch.sum(da @ S.T * db).item())

        # Drop the large per-pair tensors before the next pair is loaded.
        del Qa, Qb, Ka, Kb, Va, Vb, Oa, Ob
        del ga, gb, ua, ub, da, db, M_ff, S
        if dev.type == "cuda":
            torch.cuda.empty_cache()

        print(f" done pair L={idx}->L+1 "
              f"(q={aligned_cos['q_proj'][idx]:+.4f} gate={aligned_cos['gate_proj'][idx]:+.4f})",
              flush=True)

    # Report
    print("\n=== Adjacent-layer cos-sim AFTER Procrustes alignment ===\n")
    print(" cos=1 means identical after gauge rotation; cos=0 means orthogonal\n")
    header = " L " + "".join(f" {fam:>12}" for fam in aligned_cos)
    print(header)
    for idx in sorted(pair_L_list):
        row = f" {idx:>2}" + "".join(
            f" {aligned_cos[fam][idx]:+12.4f}" for fam in aligned_cos
        )
        print(row)

    print("\n=== Per-family summary (aligned) ===")
    print(f" {'family':>14} {'mean_cos':>10} {'median_cos':>11} "
          f"{'aligned_resid':>14}")
    for fam, vals_dict in aligned_cos.items():
        vs = np.array(list(vals_dict.values()))
        if vs.size == 0:
            continue
        # sqrt(1 - cos^2), clipped at 0 against round-off, averaged over pairs.
        resid = float(np.sqrt(np.maximum(1.0 - vs**2, 0.0)).mean())
        print(f" {fam:>14} {vs.mean():>+10.4f} {np.median(vs):>+11.4f} "
              f"{resid:>14.4f}")

    # Integer layer indices become string keys in the JSON output.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "num_layers": num_layers,
            "aligned_cos": aligned_cos,
        }, f, indent=2)
    print(f"\nSaved: {args.out}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: run the analysis only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|