# Source path: consciousness/sa-schedule-aligned-variation.py
# (Repository web-viewer metadata: 200 lines, 7.6 KiB, Python.)

"""After applying Procrustes alignment to remove known gauge freedoms
(per-head d_h rotation tying Q/K/V/O, per-layer d_ff rotation tying
gate/up/down), measure per-family cos-sim between adjacent layers across
the whole network.
Runs Procrustes SVDs on GPU for speed.
"""
import argparse
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM
def procrustes_gpu(M):
    """Return the orthogonal R that maximizes tr(R M).

    With the thin SVD M = U Σ V^T we have tr(R M) = tr((V^T R U) Σ), which is
    largest when V^T R U = I, i.e. R = V U^T.  Note the order: U V^T (the
    transpose) maximizes tr(R^T M) instead, which is the inverse rotation.

    Args:
        M: square matrix (or a batch of square matrices in the trailing two
            dimensions).  The SVD runs on whatever device M lives on.

    Returns:
        Orthogonal tensor with the same shape and device as M.
    """
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    # (U @ Vh)^T == V @ U^T; transposing only the last two dims keeps
    # batched inputs working.
    return (U @ Vh).transpose(-2, -1)
def frob_gpu(x):
    """Return the norm of tensor ``x`` (``torch.linalg.norm`` with default
    arguments) as a plain Python float."""
    norm_tensor = torch.linalg.norm(x)
    # .item() pulls the scalar back to the host.
    return float(norm_tensor.item())
def normalize_fro_gpu(x, eps=1e-12):
    """Scale ``x`` to unit norm.

    The divisor is the norm of ``x`` floored at ``eps``, so an all-zero
    input is returned as zeros rather than NaNs.
    """
    divisor = torch.clamp(torch.linalg.norm(x), min=eps)
    return x / divisor
def _parse_pair_spec(spec):
    """Parse the ``--pairs`` string into a list of ints.

    Blank tokens (e.g. from a trailing comma) are ignored; an empty result
    means "run every adjacent pair".  Raises ValueError when a token is not
    an integer.
    """
    tokens = (tok.strip() for tok in spec.split(","))
    return [int(tok) for tok in tokens if tok]


def _attention_pair_cos(A, B, dev, num_heads, num_kv_heads, head_dim, hidden):
    """Aligned cos-sim of q/k/v/o between two layers, averaged over heads.

    ``A`` and ``B`` are the per-layer weight dicts of the two layers.  For
    every query head one head_dim x head_dim rotation R is fitted jointly to
    that head's Q and O slices and the K/V slices of the kv head it maps to.
    Returns ``{family: mean cos-sim over query heads}``.
    """
    # Per-head views: (heads, head_dim, hidden) for q/k/v.
    Qa = A["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Qb = B["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Ka = A["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Kb = B["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Va = A["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Vb = B["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    # o_proj is (hidden, num_heads*head_dim); split per head into
    # (num_heads, hidden, head_dim).
    Oa = A["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()
    Ob = B["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()

    per_head = {"q_proj": [], "k_proj": [], "v_proj": [], "o_proj": []}
    for h in range(num_heads):
        # Index of the kv head this query head is grouped with.
        kv_h = (h * num_kv_heads) // num_heads
        qa = normalize_fro_gpu(Qa[h])
        qb = normalize_fro_gpu(Qb[h])
        ka = normalize_fro_gpu(Ka[kv_h])
        kb = normalize_fro_gpu(Kb[kv_h])
        va = normalize_fro_gpu(Va[kv_h])
        vb = normalize_fro_gpu(Vb[kv_h])
        oa = normalize_fro_gpu(Oa[h])
        ob = normalize_fro_gpu(Ob[h])
        # Joint alignment: we want R with R X_a ≈ X_b for X in {q, k, v} and
        # oa R^T ≈ ob.  Minimizing the summed squared residuals is the same
        # as maximizing tr(R M) with the cross-correlation below;
        # procrustes_gpu is documented to return that maximizer.
        M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
        R = procrustes_gpu(M)
        # Inputs are unit-normalized, so the Frobenius inner product of the
        # rotated A-slice with the B-slice is their cosine similarity.
        per_head["q_proj"].append(float(torch.sum((R @ qa) * qb).item()))
        per_head["k_proj"].append(float(torch.sum((R @ ka) * kb).item()))
        per_head["v_proj"].append(float(torch.sum((R @ va) * vb).item()))
        # O acts on the rotated basis from the right: O after rotation = oa R^T.
        per_head["o_proj"].append(float(torch.sum((oa @ R.T) * ob).item()))
    return {fam: float(np.mean(vals)) for fam, vals in per_head.items()}


def _mlp_pair_cos(A, B, dev):
    """Aligned cos-sim of gate/up/down between two layers.

    A single d_ff x d_ff rotation S is fitted jointly to the three MLP
    matrices.  Returns ``{family: cos-sim}``.
    """
    ga = normalize_fro_gpu(A["gate_proj"].to(dev))
    gb = normalize_fro_gpu(B["gate_proj"].to(dev))
    ua = normalize_fro_gpu(A["up_proj"].to(dev))
    ub = normalize_fro_gpu(B["up_proj"].to(dev))
    da = normalize_fro_gpu(A["down_proj"].to(dev))
    db = normalize_fro_gpu(B["down_proj"].to(dev))
    # gate/up are (d_ff, hidden) and down is (hidden, d_ff), so every term of
    # the cross-correlation is d_ff x d_ff.
    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
    S = procrustes_gpu(M_ff)
    return {
        "gate_proj": float(torch.sum((S @ ga) * gb).item()),
        "up_proj": float(torch.sum((S @ ua) * ub).item()),
        "down_proj": float(torch.sum((da @ S.T) * db).item()),
    }


@torch.no_grad()
def main():
    """Measure adjacent-layer weight similarity after Procrustes alignment.

    Command-line options:
        --model   Hugging Face model id or path to load.
        --out     Path of the JSON file the results are written to.
        --device  Torch device used for the alignment math.
        --pairs   Comma-separated layer indices L; pair (L, L+1) is analysed
                  for each.  Empty means every adjacent pair.

    The model is loaded on CPU in float32, its attention and MLP projection
    weights are collected per layer, and for each requested pair the
    per-family aligned cos-sim is computed, printed and finally saved as JSON.
    Exits through ``argparse`` error handling (status 2) when ``--pairs`` is
    malformed or names a layer without a successor.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--out", default="/tmp/sa-aligned-variation.json")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--pairs", default="",
                    help="Comma-separated list of L indices to run pair (L, L+1) for. "
                         "Empty = all pairs. E.g. '0,20,30,38,46,52,57' samples phases.")
    args = ap.parse_args()

    # Reject a malformed --pairs before paying for the model load.
    try:
        requested = _parse_pair_spec(args.pairs)
    except ValueError:
        ap.error(f"--pairs must be a comma-separated list of integers, got {args.pairs!r}")

    dev = torch.device(args.device)
    print(f"Loading {args.model} ...", flush=True)
    # NOTE(security): trust_remote_code=True executes code shipped with the
    # model repository; only point --model at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    # Some configs define these attributes but leave them as None; fall back
    # to the conventional defaults in that case too.
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    head_dim = getattr(cfg, "head_dim", None) or hidden // num_heads
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    # Pair (L, L+1) needs a successor layer, so valid L is 0..num_layers-2.
    # Without this check a negative L would silently wrap around and a too
    # large L would raise IndexError halfway through the run.
    if requested:
        bad = [L for L in requested if not 0 <= L < num_layers - 1]
        if bad:
            ap.error(f"--pairs indices {bad} out of range; "
                     f"valid L is 0..{num_layers - 2}")
        # Drop duplicates but keep the requested order.
        pair_L_list = list(dict.fromkeys(requested))
    else:
        pair_L_list = list(range(num_layers - 1))

    # Collect per-layer weights
    families = ["q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"]
    layers = []
    for layer in model.model.layers[:num_layers]:
        attn = layer.self_attn
        mlp = layer.mlp
        layers.append({
            "q_proj": attn.q_proj.weight.detach().float(),
            "k_proj": attn.k_proj.weight.detach().float(),
            "v_proj": attn.v_proj.weight.detach().float(),
            "o_proj": attn.o_proj.weight.detach().float(),
            "gate_proj": mlp.gate_proj.weight.detach().float(),
            "up_proj": mlp.up_proj.weight.detach().float(),
            "down_proj": mlp.down_proj.weight.detach().float(),
        })
    del model

    # Per-adjacent-pair analysis
    aligned_cos = {fam: {} for fam in families}
    for L in pair_L_list:
        A = layers[L]
        B = layers[L + 1]
        pair_cos = _attention_pair_cos(
            A, B, dev, num_heads, num_kv_heads, head_dim, hidden)
        pair_cos.update(_mlp_pair_cos(A, B, dev))
        for fam, value in pair_cos.items():
            aligned_cos[fam][L] = value
        # The helpers' device tensors are out of scope by now; hand the cached
        # blocks back so the next pair's d_ff x d_ff work has room.
        if dev.type == "cuda":
            torch.cuda.empty_cache()
        print(f" done pair L={L}->L+1 "
              f"(q={aligned_cos['q_proj'][L]:+.4f} gate={aligned_cos['gate_proj'][L]:+.4f})",
              flush=True)

    # Report
    print("\n=== Adjacent-layer cos-sim AFTER Procrustes alignment ===\n")
    print(" cos=1 means identical after gauge rotation; cos=0 means orthogonal\n")
    header = " L "
    for fam in aligned_cos:
        header += f" {fam:>12}"
    print(header)
    for L in sorted(aligned_cos["q_proj"]):
        row = f" {L:>2}"
        for fam in aligned_cos:
            row += f" {aligned_cos[fam][L]:+12.4f}"
        print(row)

    print("\n=== Per-family summary (aligned) ===")
    print(f" {'family':>14} {'mean_cos':>10} {'median_cos':>11} "
          f"{'aligned_resid':>14}")
    for fam, vals_dict in aligned_cos.items():
        vs = np.array(list(vals_dict.values()))
        if vs.size == 0:
            continue
        # Mean of sqrt(1 - cos^2), clipped at 0 in case rounding pushes
        # |cos| slightly above 1.
        resid = float(np.sqrt(np.maximum(1.0 - vs**2, 0.0)).mean())
        print(f" {fam:>14} {vs.mean():>+10.4f} {np.median(vs):>+11.4f} "
              f"{resid:>14.4f}")

    # The integer layer keys are written as strings by json.dump.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "num_layers": num_layers,
            "aligned_cos": aligned_cos,
        }, f, indent=2)
    print(f"\nSaved: {args.out}")
# Run the analysis only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()