consciousness/sa-schedule-aligned-variation.py
Kent Overstreet 4225294d16 replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
2026-04-25 15:35:14 -04:00

200 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""After applying Procrustes alignment to remove known gauge freedoms
(per-head d_h rotation tying Q/K/V/O, per-layer d_ff rotation tying
gate/up/down), measure per-family cos-sim between adjacent layers across
the whole network.
Runs Procrustes SVDs on GPU for speed.
"""
import argparse
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM
def procrustes_gpu(M):
    """Return the orthogonal R that maximizes tr(R M).

    With the thin SVD ``M = U Σ V^T``, any orthogonal R satisfies
    ``tr(R M) = tr((V^T R U) Σ) <= tr(Σ)``, with equality when
    ``V^T R U = I``, i.e. ``R = V U^T``.  (``U V^T`` is the transpose of
    that and maximizes ``tr(R^T M)`` instead; the two only coincide when
    M is symmetric.)

    Args:
        M: real square matrix, or a batch of them with shape
            ``(..., n, n)``, on any device.

    Returns:
        Tensor with the same shape, dtype and device as ``M`` holding the
        maximizing orthogonal matrix (may be a reflection, det = -1).
    """
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    # (U @ Vh)^T == Vh^T @ U^T == V U^T; transposing the last two dims
    # keeps this correct for batched input as well.
    return (U @ Vh).transpose(-2, -1)
def frob_gpu(x):
    """Norm of ``x`` (Frobenius for a matrix) as a plain Python float.

    ``torch.linalg.norm`` is evaluated on the device ``x`` lives on;
    ``.item()`` then copies the resulting scalar back to the host.
    """
    fro = torch.linalg.norm(x)
    return float(fro.item())
def normalize_fro_gpu(x, eps=1e-12):
    """Scale ``x`` to unit norm, staying on the tensor's own device.

    The divisor is the norm of ``x`` clamped from below by ``eps``, so an
    all-zero input comes back as zeros instead of NaN.
    """
    divisor = torch.clamp(torch.linalg.norm(x), min=eps)
    return x / divisor
def _parse_pair_indices(spec):
    """Parse the ``--pairs`` string into a list of layer indices.

    Blank tokens (stray commas or whitespace) are ignored and duplicates
    are dropped, keeping first-occurrence order.  Returns an empty list
    when ``spec`` contains no tokens.  Raises ``ValueError`` if a token is
    not an integer.
    """
    tokens = (tok.strip() for tok in spec.split(","))
    return list(dict.fromkeys(int(tok) for tok in tokens if tok))


def _attn_aligned_cos(A, B, num_heads, num_kv_heads, head_dim, hidden, dev):
    """Mean per-head aligned cos-sim of q/k/v/o between layers A and B.

    ``A`` and ``B`` are the per-layer weight dicts built in ``main``.  For
    every query head one d_h x d_h rotation R is fitted jointly on Q, K,
    V and O, then the cos-sim of each family after applying R is
    recorded.  Returns a ``(q, k, v, o)`` tuple of floats, each the mean
    over heads.  All temporaries are local, so their device memory is
    released when the function returns.
    """
    Qa = A["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Qb = B["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
    Ka = A["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Kb = B["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Va = A["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    Vb = B["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
    # o_proj is (hidden, num_heads*head_dim); split per head so that each
    # slice is (hidden, head_dim).
    Oa = (A["o_proj"].to(dev)
          .reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous())
    Ob = (B["o_proj"].to(dev)
          .reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous())

    q_cos, k_cos, v_cos, o_cos = [], [], [], []
    for h in range(num_heads):
        # Map query head h onto the K/V head index it is paired with.
        kv_h = (h * num_kv_heads) // num_heads
        qa, qb = normalize_fro_gpu(Qa[h]), normalize_fro_gpu(Qb[h])
        ka, kb = normalize_fro_gpu(Ka[kv_h]), normalize_fro_gpu(Kb[kv_h])
        va, vb = normalize_fro_gpu(Va[kv_h]), normalize_fro_gpu(Vb[kv_h])
        oa, ob = normalize_fro_gpu(Oa[h]), normalize_fro_gpu(Ob[h])

        # Joint alignment: we want R with R X_a ≈ X_b for X in {q, k, v}
        # and oa R^T ≈ ob.  Minimizing the summed squared residuals is
        # the same as maximizing tr(R M) for the cross-correlation
        #   M = qa qb^T + ka kb^T + va vb^T + oa^T ob,
        # which is what procrustes_gpu is relied on to solve.
        M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
        R = procrustes_gpu(M)

        # Inputs are unit-normalized, so the Frobenius inner product of
        # the rotated A-matrix with the B-matrix is the cos-sim.
        q_cos.append(float(torch.sum((R @ qa) * qb).item()))
        k_cos.append(float(torch.sum((R @ ka) * kb).item()))
        v_cos.append(float(torch.sum((R @ va) * vb).item()))
        # O is rotated on its input side: oa -> oa R^T.
        o_cos.append(float(torch.sum((oa @ R.T) * ob).item()))

    return tuple(float(np.mean(c)) for c in (q_cos, k_cos, v_cos, o_cos))


def _mlp_aligned_cos(A, B, dev):
    """Aligned cos-sim of gate/up/down between layers A and B.

    Fits one d_ff x d_ff rotation S jointly on the three MLP matrices and
    returns a ``(gate, up, down)`` tuple of floats.
    """
    # gate/up are (d_ff, hidden); down is (hidden, d_ff).
    ga = normalize_fro_gpu(A["gate_proj"].to(dev))
    gb = normalize_fro_gpu(B["gate_proj"].to(dev))
    ua = normalize_fro_gpu(A["up_proj"].to(dev))
    ub = normalize_fro_gpu(B["up_proj"].to(dev))
    da = normalize_fro_gpu(A["down_proj"].to(dev))
    db = normalize_fro_gpu(B["down_proj"].to(dev))

    # Cross-correlation (d_ff x d_ff): M = ga gb^T + ua ub^T + da^T db
    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
    S = procrustes_gpu(M_ff)

    return (
        float(torch.sum((S @ ga) * gb).item()),
        float(torch.sum((S @ ua) * ub).item()),
        float(torch.sum((da @ S.T) * db).item()),
    )


@torch.no_grad()
def main():
    """Measure adjacent-layer cos-sim per weight family after alignment.

    Loads the model on CPU in float32, keeps only the seven projection
    weights per layer, and for every requested pair (L, L+1) fits a
    per-head rotation for attention and a per-layer rotation for the MLP
    on ``--device``.  Prints a per-pair table plus a per-family summary
    and writes the per-pair values to ``--out`` as JSON.

    ``--pairs`` is validated: non-integer tokens and indices outside
    ``[0, num_layers - 2]`` abort with a usage error instead of crashing
    mid-run or silently wrapping around to a (last, first) layer pair.
    Helpers run inside this function's ``no_grad`` context.
    """
    attn_families = ("q_proj", "k_proj", "v_proj", "o_proj")
    mlp_families = ("gate_proj", "up_proj", "down_proj")

    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--out", default="/tmp/sa-aligned-variation.json")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--pairs", default="",
                    help="Comma-separated list of L indices to run pair (L, L+1) for. "
                         "Empty = all pairs. E.g. '0,20,30,38,46,52,57' samples phases.")
    args = ap.parse_args()

    # Reject malformed --pairs before paying for the model load.
    try:
        requested = _parse_pair_indices(args.pairs)
    except ValueError:
        ap.error(f"--pairs must be a comma-separated list of integers, "
                 f"got {args.pairs!r}")

    dev = torch.device(args.device)
    print(f"Loading {args.model} ...", flush=True)
    # NOTE: trust_remote_code=True executes code shipped with the
    # checkpoint; only point --model at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    # Pair (L, L+1) needs both layers to exist; negative indices would
    # otherwise wrap around via Python list indexing.
    if requested:
        out_of_range = [L for L in requested if not 0 <= L < num_layers - 1]
        if out_of_range:
            ap.error(f"--pairs indices must be in [0, {num_layers - 2}], "
                     f"got {out_of_range}")
        pair_L_list = requested
    else:
        pair_L_list = list(range(num_layers - 1))

    # Collect per-layer weights
    layers = []
    for L in range(num_layers):
        layer = model.model.layers[L]
        weights = {name: getattr(layer.self_attn, name).weight.detach().float()
                   for name in attn_families}
        weights.update({name: getattr(layer.mlp, name).weight.detach().float()
                        for name in mlp_families})
        layers.append(weights)
    del model

    # Per-adjacent-pair analysis
    aligned_cos = {fam: {} for fam in attn_families + mlp_families}
    for L in pair_L_list:
        A, B = layers[L], layers[L + 1]

        # -------- Per-head attention alignment (d_h × d_h) --------
        attn_vals = _attn_aligned_cos(
            A, B, num_heads, num_kv_heads, head_dim, hidden, dev)
        for fam, val in zip(attn_families, attn_vals):
            aligned_cos[fam][L] = val

        # -------- d_ff × d_ff alignment for gate/up/down --------
        mlp_vals = _mlp_aligned_cos(A, B, dev)
        for fam, val in zip(mlp_families, mlp_vals):
            aligned_cos[fam][L] = val

        # Helper locals are already gone; hand cached blocks back too.
        if dev.type == "cuda":
            torch.cuda.empty_cache()
        print(f" done pair L={L}->L+1 "
              f"(q={aligned_cos['q_proj'][L]:+.4f} gate={aligned_cos['gate_proj'][L]:+.4f})",
              flush=True)

    # Report
    print("\n=== Adjacent-layer cos-sim AFTER Procrustes alignment ===\n")
    print(" cos=1 means identical after gauge rotation; cos=0 means orthogonal\n")
    print(" L " + "".join(f" {fam:>12}" for fam in aligned_cos))
    for L in sorted(pair_L_list):
        print(f" {L:>2}"
              + "".join(f" {aligned_cos[fam][L]:+12.4f}" for fam in aligned_cos))

    print("\n=== Per-family summary (aligned) ===")
    print(f" {'family':>14} {'mean_cos':>10} {'median_cos':>11} "
          f"{'aligned_resid':>14}")
    for fam, vals_dict in aligned_cos.items():
        vs = np.array(list(vals_dict.values()))
        if vs.size == 0:
            continue
        # sqrt(1 - cos^2), clipped at 0 against rounding, averaged over pairs.
        resid = float(np.sqrt(np.maximum(1.0 - vs**2, 0.0)).mean())
        print(f" {fam:>14} {vs.mean():>+10.4f} {np.median(vs):>+11.4f} "
              f"{resid:>14.4f}")

    # Integer layer indices become string keys in the JSON output.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "num_layers": num_layers,
            "aligned_cos": aligned_cos,
        }, f, indent=2)
    print(f"\nSaved: {args.out}")
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()