forked from kent/consciousness
238 lines
10 KiB
Python
238 lines
10 KiB
Python
|
|
"""After removing the known gauge freedoms (per-head d_h rotation tying
|
|||
|
|
W_Q/W_K/W_V/W_O together, per-layer d_ff rotation tying gate/up/down),
|
|||
|
|
measure per-family Frobenius distance between consecutive layers within a
|
|||
|
|
middle block. Families with low post-alignment distance are candidates for
|
|||
|
|
"shared operator" across the block; high distance → carries the schedule.
|
|||
|
|
|
|||
|
|
Normalize each matrix by its Frobenius norm first (so scale differences
|
|||
|
|
don't dominate). We want to see direction of drift, not magnitude.
|
|||
|
|
|
|||
|
|
Gauge freedoms being removed:
|
|||
|
|
- Per-head d_h rotation R ∈ O(d_h): W_Q^h, W_K^h, W_V^h → R W^h;
|
|||
|
|
W_O^h → W_O^h R^T. Softmax attention is invariant under this.
|
|||
|
|
- Per-layer d_ff rotation S ∈ O(d_ff): gate_proj, up_proj → S W;
|
|||
|
|
down_proj → W S^T. SwiGLU/GLU is NOT fully invariant under d_ff
|
|||
|
|
rotation (because the elementwise gate*up is coordinate-dependent),
|
|||
|
|
so this is an approximate alignment — still better than raw.
|
|||
|
|
|
|||
|
|
Families that have no gauge freedom (layernorm γ, q_norm, k_norm): compare
|
|||
|
|
directly after scale normalization.
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from transformers import AutoModelForCausalLM
|
|||
|
|
|
|||
|
|
|
|||
|
|
def procrustes(M):
    """Return the orthogonal matrix R that maximizes tr(R M).

    With the SVD M = U Σ V^T we have tr(R M) = tr((V^T R U) Σ). The factor
    V^T R U is orthogonal, so the trace is at most the sum of the singular
    values, and that bound is attained when V^T R U = I, i.e. R = V U^T.

    Consequence used by callers: for M = A B^T the returned R minimizes
    ||R A - B||_F, because ||R A - B||² = const - 2 tr(R A B^T).

    Args:
        M: square 2-D array (cross-correlation matrix).

    Returns:
        Orthogonal array of the same shape as M.
    """
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    # numpy returns Vh = V^T, so V U^T is Vh.T @ U.T. Returning U @ Vh would
    # give the transpose, which maximizes tr(R^T M) instead of tr(R M).
    return Vh.T @ U.T
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fro(x):
    """Return the norm of *x* as a plain Python float.

    Uses ``np.linalg.norm`` with its default order, which is the Frobenius
    norm for 2-D input and the Euclidean norm for 1-D input.
    """
    magnitude = np.linalg.norm(x)
    return float(magnitude)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_fro(x, eps=1e-12):
    """Scale *x* to unit norm, as measured by ``fro``.

    The divisor is floored at *eps*, so an all-zero input is returned as
    zeros instead of triggering a division by zero.
    """
    scale = fro(x)
    if scale < eps:
        scale = eps
    return x / scale
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Family names in reporting order; also the keys of the saved JSON.
_FAMILY_NAMES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "input_ln", "post_attn_ln", "q_norm", "k_norm",
)

# Norm-gain vectors: compared directly, no rotation alignment.
_NORM_FAMILIES = ("input_ln", "post_attn_ln", "q_norm", "k_norm")


def _weight_as_np(module):
    """Return ``module.weight`` as a float32 numpy array."""
    return module.weight.detach().numpy().astype(np.float32)


def _extract_layer_weights(layer):
    """Collect the weight arrays of one decoder layer into a dict by family."""
    attn = layer.self_attn
    mlp = layer.mlp
    weights = {
        "q_proj": _weight_as_np(attn.q_proj),      # (nh*hd, hidden)
        "k_proj": _weight_as_np(attn.k_proj),      # (nkv*hd, hidden)
        "v_proj": _weight_as_np(attn.v_proj),
        "o_proj": _weight_as_np(attn.o_proj),      # (hidden, nh*hd)
        "gate_proj": _weight_as_np(mlp.gate_proj),
        "up_proj": _weight_as_np(mlp.up_proj),
        "down_proj": _weight_as_np(mlp.down_proj),
        "input_ln": _weight_as_np(layer.input_layernorm),
        "post_attn_ln": _weight_as_np(layer.post_attention_layernorm),
    }
    # Qwen3 has q_norm / k_norm inside self_attn; other models may not.
    for name in ("q_norm", "k_norm"):
        norm = getattr(attn, name, None)
        if norm is not None:
            weights[name] = _weight_as_np(norm)
    return weights


def _attention_residuals(A, B, num_heads, num_kv_heads, head_dim, hidden):
    """Per-family attention residuals between layers A and B.

    For every query head one rotation R (hd × hd) is fitted jointly to that
    head's Q, K, V and O slices, then the post-alignment Frobenius distance
    of each slice is recorded. Returns the mean over query heads for each of
    q_proj / k_proj / v_proj / o_proj.
    """
    def split_in(W, n):
        # (n*hd, hidden) → (n, hd, hidden)
        return W.reshape(n, head_dim, hidden)

    def split_out(W):
        # (hidden, nh*hd) → (nh, hidden, hd)
        return W.reshape(hidden, num_heads, head_dim).transpose(1, 0, 2)

    Q1, Q2 = split_in(A["q_proj"], num_heads), split_in(B["q_proj"], num_heads)
    K1, K2 = split_in(A["k_proj"], num_kv_heads), split_in(B["k_proj"], num_kv_heads)
    V1, V2 = split_in(A["v_proj"], num_kv_heads), split_in(B["v_proj"], num_kv_heads)
    O1, O2 = split_out(A["o_proj"]), split_out(B["o_proj"])

    q_res, k_res, v_res, o_res = [], [], [], []
    for h in range(num_heads):
        # GQA: K/V head that query head h is paired with.
        kv_h = (h * num_kv_heads) // num_heads

        # Normalize each slice by its Frobenius norm.
        qa, qb = normalize_fro(Q1[h]), normalize_fro(Q2[h])
        ka, kb = normalize_fro(K1[kv_h]), normalize_fro(K2[kv_h])
        va, vb = normalize_fro(V1[kv_h]), normalize_fro(V2[kv_h])
        oa, ob = normalize_fro(O1[h]), normalize_fro(O2[h])

        # Joint objective: R qa ≈ qb, R ka ≈ kb, R va ≈ vb, oa R^T ≈ ob.
        #   ||R xa - xb||²   = const - 2 tr(R xa xb^T)
        #   ||oa R^T - ob||² = const - 2 tr(R oa^T ob)
        # so the sum is minimized by maximizing tr(R M) with M as below
        # (every term is hd × hd). This relies on procrustes() honouring
        # its documented contract of returning argmax_R tr(R M).
        M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
        R = procrustes(M)

        q_res.append(fro(R @ qa - qb))
        k_res.append(fro(R @ ka - kb))
        v_res.append(fro(R @ va - vb))
        o_res.append(fro(oa @ R.T - ob))

    return {
        "q_proj": float(np.mean(q_res)),
        "k_proj": float(np.mean(k_res)),
        "v_proj": float(np.mean(v_res)),
        "o_proj": float(np.mean(o_res)),
    }


def _mlp_residuals(A, B):
    """Per-family MLP residuals between layers A and B.

    Fits one rotation S (d_ff × d_ff) with S gate_a ≈ gate_b, S up_a ≈ up_b
    and down_a S^T ≈ down_b simultaneously, then returns the post-alignment
    Frobenius distance for gate_proj / up_proj / down_proj.
    """
    ga, gb = normalize_fro(A["gate_proj"]), normalize_fro(B["gate_proj"])
    ua, ub = normalize_fro(A["up_proj"]), normalize_fro(B["up_proj"])
    da, db = normalize_fro(A["down_proj"]), normalize_fro(B["down_proj"])  # (hidden, d_ff)

    # Same construction as for attention; all three terms are d_ff × d_ff.
    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
    S = procrustes(M_ff)
    return {
        "gate_proj": fro(S @ ga - gb),
        "up_proj": fro(S @ ua - ub),
        "down_proj": fro(da @ S.T - db),
    }


def _print_report(family_residuals, block_start, block_end):
    """Print the per-family summary table, sorted by mean residual."""
    print("=== Per-family Frobenius residual between consecutive layers, "
          f"block L={block_start}..{block_end}, after alignment + scale-norm ===\n")
    print(" (Residual = Frobenius distance between L and L+1 after rotation alignment;")
    print(" lower = more shared across block; higher = carries layer-to-layer drift)\n")
    print(f" {'family':>14} {'mean':>8} {'min':>8} {'max':>8} {'std':>8} n")

    # Families with no samples (e.g. q_norm on models without it) are skipped.
    items = [(fam, np.array(v)) for fam, v in family_residuals.items() if len(v) > 0]
    items.sort(key=lambda kv: float(kv[1].mean()))
    for fam, v in items:
        print(f" {fam:>14} {v.mean():>8.4f} {v.min():>8.4f} {v.max():>8.4f} {v.std():>8.4f} {len(v)}")

    print("\n Families ranked least-to-most variation:")
    for i, (fam, v) in enumerate(items):
        print(f" {i+1}. {fam} (mean residual {v.mean():.4f})")


@torch.no_grad()
def main():
    """Measure layer-to-layer weight drift inside a block of decoder layers.

    Loads the model named by --model on CPU in float32, and for every pair of
    consecutive layers in [--block-start, --block-end] computes a residual
    per weight family (after rotation alignment for attention and MLP
    weights, directly for norm gains). Prints a summary table and writes the
    raw per-pair residuals as JSON to --out.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--block-start", type=int, default=10)
    ap.add_argument("--block-end", type=int, default=25,
                    help="inclusive; this is mid-block of 36-layer model")
    ap.add_argument("--out", default="/tmp/sa-layer-variation.json")
    args = ap.parse_args()

    print(f"Loading {args.model} ...", flush=True)
    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to run at load time; confirm this is intended for every
    # --model value this script is used with.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    # `or` also covers configs that define the attribute but leave it None.
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", None) or hidden // num_heads
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    block = list(range(args.block_start, args.block_end + 1))
    pairs = list(zip(block, block[1:]))
    # Fail with a readable message instead of a KeyError deep in the loop.
    if pairs and (block[0] < 0 or block[-1] >= num_layers):
        ap.error(f"--block-start/--block-end must lie within 0..{num_layers - 1} "
                 f"(got {args.block_start}..{args.block_end})")

    # Copy only the layers that take part in a comparison; astype() returns
    # a new array, so these stay valid after the model is released.
    needed = block if pairs else []
    layers = {idx: _extract_layer_weights(model.model.layers[idx]) for idx in needed}

    del model  # free memory

    print(f"\nAnalyzing block layers {args.block_start}..{args.block_end} "
          f"({len(pairs)} consecutive pairs)\n")

    family_residuals = {fam: [] for fam in _FAMILY_NAMES}

    for lower, upper in pairs:
        A = layers[lower]
        B = layers[upper]

        attn_res = _attention_residuals(A, B, num_heads, num_kv_heads, head_dim, hidden)
        for fam, value in attn_res.items():
            family_residuals[fam].append(value)

        for fam, value in _mlp_residuals(A, B).items():
            family_residuals[fam].append(value)

        # Norm γ vectors — no rotation gauge; just scale-normalize and diff.
        for ln_name in _NORM_FAMILIES:
            if ln_name in A and ln_name in B:
                diff = normalize_fro(A[ln_name]) - normalize_fro(B[ln_name])
                family_residuals[ln_name].append(fro(diff))

    _print_report(family_residuals, args.block_start, args.block_end)

    # Save the raw per-pair residuals.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "block_start": args.block_start,
            "block_end": args.block_end,
            "family_residuals": {k: list(v) for k, v in family_residuals.items()},
        }, f, indent=2)
    print(f"\nSaved: {args.out}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Run the analysis only when this file is executed as a script.
    main()
|