"""Measure how similar adjacent transformer layers are once gauge freedoms are removed.

For every adjacent layer pair (L, L+1) the script first applies orthogonal
Procrustes alignment to factor out rotations that are modelled as leaving a
layer's function unchanged:

* per attention head, one d_h x d_h rotation R shared by Q/K/V/O
  (Q, K, V rows rotated by R; O's per-head input columns rotated by R^T);
* per layer, one d_ff x d_ff rotation S shared by gate/up/down
  (gate, up rows rotated by S; down columns rotated by S^T).

It then reports, per weight family, the cosine similarity between the aligned
layer-L weights and the layer-(L+1) weights (both scaled to unit Frobenius
norm).  Procrustes SVDs run on ``--device`` (GPU by default) for speed.

NOTE(review): an arbitrary rotation is an exact symmetry only if nothing acts
per-coordinate on the rotated axis.  Per-dimension operations applied to q/k
(e.g. rotary embeddings or per-dim norms) and the elementwise MLP activation
restrict the true symmetry group, so these scores are best read as an upper
bound on "same function up to gauge" -- confirm for the target model.
"""

import argparse
import json

import numpy as np
import torch

# Weight families compared between adjacent layers, in report order.
FAMILIES = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
ATTN_FAMILIES = FAMILIES[:4]


def procrustes_gpu(M):
    """Return the orthogonal R maximizing tr(R @ M).

    With the SVD M = U S V^T, tr(R U S V^T) = tr((V^T R U) S) is maximized
    when V^T R U = I, i.e. R = V U^T.  (U V^T would instead maximize
    tr(R^T M) -- the transposed problem.)  Runs on whatever device M lives
    on and returns R on that device.
    """
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return Vh.transpose(-2, -1) @ U.transpose(-2, -1)


def frob_gpu(x):
    """Return the Frobenius norm of ``x`` as a Python float."""
    return float(torch.linalg.norm(x).item())


def normalize_fro_gpu(x, eps=1e-12):
    """Scale ``x`` to unit Frobenius norm.

    The norm is clamped to ``eps`` so an all-zero tensor stays zero instead
    of turning into NaNs.
    """
    n = torch.linalg.norm(x)
    return x / n.clamp_min(eps)


def parse_pairs(spec, num_layers):
    """Turn the ``--pairs`` string into the start indices L of pairs (L, L+1).

    A blank ``spec`` selects every adjacent pair ``0 .. num_layers - 2``.
    Otherwise ``spec`` is a comma-separated list of ints; blank items are
    skipped and duplicates dropped (first occurrence kept, order preserved).

    Raises:
        ValueError: an item is not an int, an index lies outside
            ``[0, num_layers - 2]`` (so layer L + 1 would not exist), or no
            index remains after parsing.
    """
    if not spec.strip():
        return list(range(num_layers - 1))
    pairs = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        L = int(token)
        if not 0 <= L < num_layers - 1:
            raise ValueError(f"pair index {L} out of range [0, {num_layers - 2}]")
        pairs.append(L)
    if not pairs:
        raise ValueError(f"no pair indices in {spec!r}")
    return list(dict.fromkeys(pairs))


def aligned_cos_attention(A, B, num_heads, num_kv_heads, head_dim, hidden, dev):
    """Mean-over-query-heads aligned cos-sim for q/k/v/o between layers A and B.

    ``A``/``B`` map family name -> weight in nn.Linear layout
    (out_features, in_features).  For each query head h one d_h x d_h
    rotation R is fit jointly to that head's Q, K, V rows and O columns.
    Query head h reads kv head ``h * num_kv_heads // num_heads``
    (grouped-query sharing).

    Returns:
        dict mapping each of ``ATTN_FAMILIES`` to a float mean cosine.
    """
    Qa = A["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Qb = B["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Ka = A["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Kb = B["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Va = A["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Vb = B["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    # o_proj is (hidden, num_heads*head_dim); split its input columns per head
    # -> (num_heads, hidden, head_dim).
    Oa = A["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()
    Ob = B["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()

    cos = {fam: [] for fam in ATTN_FAMILIES}
    for h in range(num_heads):
        kv_h = (h * num_kv_heads) // num_heads
        qa, qb = normalize_fro_gpu(Qa[h]), normalize_fro_gpu(Qb[h])
        ka, kb = normalize_fro_gpu(Ka[kv_h]), normalize_fro_gpu(Kb[kv_h])
        va, vb = normalize_fro_gpu(Va[kv_h]), normalize_fro_gpu(Vb[kv_h])
        oa, ob = normalize_fro_gpu(Oa[h]), normalize_fro_gpu(Ob[h])
        # Joint alignment: R qa ~ qb, R ka ~ kb, R va ~ vb, oa R^T ~ ob.
        # Minimizing the summed squared residuals == maximizing tr(R M).
        M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
        R = procrustes_gpu(M)
        # Inputs are unit-norm, so <R qa, qb>_F is the post-alignment cosine.
        cos["q_proj"].append(float(torch.sum((R @ qa) * qb).item()))
        cos["k_proj"].append(float(torch.sum((R @ ka) * kb).item()))
        cos["v_proj"].append(float(torch.sum((R @ va) * vb).item()))
        # O's head columns transform as oa -> oa R^T.
        cos["o_proj"].append(float(torch.sum((oa @ R.T) * ob).item()))
    return {fam: float(np.mean(vals)) for fam, vals in cos.items()}


def aligned_cos_mlp(A, B, dev):
    """Aligned cos-sim for gate/up/down between layers A and B.

    One d_ff x d_ff rotation S is fit jointly: S ga ~ gb, S ua ~ ub,
    da S^T ~ db, where gate/up are (d_ff, hidden) and down is (hidden, d_ff).

    Returns:
        dict with float cosines for "gate_proj", "up_proj", "down_proj".
    """
    ga = normalize_fro_gpu(A["gate_proj"].to(dev))
    gb = normalize_fro_gpu(B["gate_proj"].to(dev))
    ua = normalize_fro_gpu(A["up_proj"].to(dev))
    ub = normalize_fro_gpu(B["up_proj"].to(dev))
    da = normalize_fro_gpu(A["down_proj"].to(dev))
    db = normalize_fro_gpu(B["down_proj"].to(dev))
    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
    S = procrustes_gpu(M_ff)
    return {
        "gate_proj": float(torch.sum((S @ ga) * gb).item()),
        "up_proj": float(torch.sum((S @ ua) * ub).item()),
        "down_proj": float(torch.sum((da @ S.T) * db).item()),
    }


def print_report(aligned_cos, pair_L_list):
    """Print the per-pair table and the per-family summary to stdout."""
    print("\n=== Adjacent-layer cos-sim AFTER Procrustes alignment ===\n")
    print(" cos=1 means identical after gauge rotation; cos=0 means orthogonal\n")
    header = " L "
    for fam in aligned_cos:
        header += f" {fam:>12}"
    print(header)
    for L in sorted(set(pair_L_list)):
        if L not in aligned_cos["q_proj"]:
            continue
        row = f" {L:>2}"
        for fam in aligned_cos:
            row += f" {aligned_cos[fam][L]:+12.4f}"
        print(row)

    print("\n=== Per-family summary (aligned) ===")
    print(f" {'family':>14} {'mean_cos':>10} {'median_cos':>11} " f"{'aligned_resid':>14}")
    for fam, vals_dict in aligned_cos.items():
        vs = np.array(list(vals_dict.values()))
        if vs.size == 0:
            continue
        # Mean sine of the residual angle, sqrt(1 - cos^2), clipped at 0.
        resid = float(np.sqrt(np.maximum(1.0 - vs**2, 0.0)).mean())
        print(f" {fam:>14} {vs.mean():>+10.4f} {np.median(vs):>+11.4f} " f"{resid:>14.4f}")


@torch.no_grad()
def main():
    """CLI entry point: load the model, align each requested pair, report, save JSON."""
    # Imported lazily so the numeric helpers above can be imported (and
    # tested) without transformers installed.
    from transformers import AutoModelForCausalLM

    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--out", default="/tmp/sa-aligned-variation.json")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--pairs", default="",
                    help="Comma-separated list of L indices to run pair (L, L+1) for. "
                         "Empty = all pairs. E.g. '0,20,30,38,46,52,57' samples phases.")
    args = ap.parse_args()
    dev = torch.device(args.device)

    print(f"Loading {args.model} ...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    try:
        pair_L_list = parse_pairs(args.pairs, num_layers)
    except ValueError as err:
        ap.error(f"--pairs: {err}")

    # Collect per-layer weights (CPU, float32), then drop the model wrapper.
    layers = []
    for L in range(num_layers):
        layer = model.model.layers[L]
        attn = layer.self_attn
        mlp = layer.mlp
        layers.append({
            "q_proj": attn.q_proj.weight.detach().float(),
            "k_proj": attn.k_proj.weight.detach().float(),
            "v_proj": attn.v_proj.weight.detach().float(),
            "o_proj": attn.o_proj.weight.detach().float(),
            "gate_proj": mlp.gate_proj.weight.detach().float(),
            "up_proj": mlp.up_proj.weight.detach().float(),
            "down_proj": mlp.down_proj.weight.detach().float(),
        })
    del model

    # aligned_cos[family][L] = cos-sim between aligned layer L and layer L+1.
    aligned_cos = {fam: {} for fam in FAMILIES}
    for L in pair_L_list:
        A, B = layers[L], layers[L + 1]
        results = aligned_cos_attention(A, B, num_heads, num_kv_heads, head_dim, hidden, dev)
        results.update(aligned_cos_mlp(A, B, dev))
        for fam, value in results.items():
            aligned_cos[fam][L] = value
        # Per-pair device tensors were local to the helpers; release the cache.
        if dev.type == "cuda":
            torch.cuda.empty_cache()
        print(f" done pair L={L}->L+1 "
              f"(q={aligned_cos['q_proj'][L]:+.4f} gate={aligned_cos['gate_proj'][L]:+.4f})",
              flush=True)

    print_report(aligned_cos, pair_L_list)

    # Note: json turns the int layer keys into strings.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "num_layers": num_layers,
            "aligned_cos": aligned_cos,
        }, f, indent=2)
    print(f"\nSaved: {args.out}")


if __name__ == "__main__":
    main()