consciousness/sa-schedule-layer-variation.py

238 lines
10 KiB
Python
Raw Normal View History

"""After removing the known gauge freedoms (per-head d_h rotation tying
W_Q/W_K/W_V/W_O together, per-layer d_ff rotation tying gate/up/down),
measure per-family Frobenius distance between consecutive layers within a
middle block. Families with low post-alignment distance are candidates for
"shared operator" across the block; high distance carries the schedule.
Normalize each matrix by its Frobenius norm first (so scale differences
don't dominate). We want to see direction of drift, not magnitude.
Gauge freedoms being removed:
- Per-head d_h rotation R ∈ O(d_h): W_Q^h, W_K^h, W_V^h ← R W^h;
  W_O^h ← W_O^h R^T. Softmax attention is invariant under this.
- Per-layer d_ff rotation S ∈ O(d_ff): gate_proj, up_proj ← S W;
  down_proj ← W S^T. SwiGLU/GLU is NOT fully invariant under d_ff
  rotation (because the elementwise gate*up is coordinate-dependent),
  so this is an approximate alignment — still better than raw.
Families that have no gauge freedom (layernorm γ, q_norm, k_norm): compare
directly after scale normalization.
"""
import argparse
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM
def procrustes(M):
    """Return the orthogonal matrix R that maximizes tr(R M).

    With the SVD ``M = U Σ V^T`` we have ``tr(R M) = tr((V^T R U) Σ)``,
    which is largest when ``V^T R U = I``, i.e. ``R = V U^T``.

    This is the rotation wanted by least-squares alignment: minimizing
    ``||R A - B||_F`` over orthogonal R is the same as maximizing
    ``tr(R A B^T)``, so callers pass ``M = A @ B.T`` (and, for a
    right-multiplied factor with ``A R^T ≈ B``, add ``A.T @ B``).

    Args:
        M: square 2-D array (the summed cross-correlation matrix).

    Returns:
        Orthogonal array of the same shape as ``M``. It may be a reflection
        (determinant -1); no sign correction is applied.
    """
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    # V U^T, not U V^T: the latter maximizes tr(R^T M) and is the inverse
    # of the rotation that actually maps A onto B.
    return Vh.T @ U.T
def fro(x):
    """Return the Frobenius norm of ``x`` as a built-in Python float.

    ``np.linalg.norm`` with its default arguments computes the 2-norm of
    the flattened array, which is the Frobenius norm for a matrix.
    """
    magnitude = np.linalg.norm(x)
    return float(magnitude)
def normalize_fro(x, eps=1e-12):
    """Scale ``x`` so that its Frobenius norm is 1.

    The divisor is clamped from below at ``eps``; an all-zero input is
    therefore returned as zeros rather than causing a division by zero.
    """
    scale = fro(x)
    if scale < eps:
        scale = eps
    return x / scale
# Families reported by the script, in the order they appear in the JSON output.
_FAMILIES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "input_ln", "post_attn_ln", "q_norm", "k_norm",
)

# Vector-valued families that are compared directly (no rotation alignment).
_NORM_FAMILIES = ("input_ln", "post_attn_ln", "q_norm", "k_norm")


def _to_numpy_f32(param):
    """Return a float32 numpy copy of a torch parameter.

    ``astype`` allocates a new array, so the result does not alias the
    model's storage and stays valid after the model is deleted.
    """
    return param.detach().numpy().astype(np.float32)


def _collect_layer_weights(layer):
    """Extract the compared weight families from one decoder layer.

    ``q_norm`` / ``k_norm`` are included only when the attention module
    exposes them (looked up with ``getattr``).
    """
    attn = layer.self_attn
    mlp = layer.mlp
    weights = {
        "q_proj": _to_numpy_f32(attn.q_proj.weight),  # (nh*hd, hidden)
        "k_proj": _to_numpy_f32(attn.k_proj.weight),  # (nkv*hd, hidden)
        "v_proj": _to_numpy_f32(attn.v_proj.weight),
        "o_proj": _to_numpy_f32(attn.o_proj.weight),  # (hidden, nh*hd)
        "gate_proj": _to_numpy_f32(mlp.gate_proj.weight),
        "up_proj": _to_numpy_f32(mlp.up_proj.weight),
        "down_proj": _to_numpy_f32(mlp.down_proj.weight),
        "input_ln": _to_numpy_f32(layer.input_layernorm.weight),
        "post_attn_ln": _to_numpy_f32(layer.post_attention_layernorm.weight),
    }
    for name in ("q_norm", "k_norm"):
        norm = getattr(attn, name, None)
        if norm is not None:
            weights[name] = _to_numpy_f32(norm.weight)
    return weights


def _attention_pair_residuals(A, B, num_heads, num_kv_heads, head_dim, hidden):
    """Per-head rotation alignment of layer ``A`` onto layer ``B``.

    For every query head a single orthogonal R (hd × hd) is fitted jointly to
    the unit-Frobenius-norm Q, K, V and O slices, then the Frobenius residual
    of each aligned slice is recorded.

    K/V slices are looked up through the query-head → kv-head mapping, so
    under GQA each kv head is visited once per query head that maps to it.
    NOTE(review): the mapping assumes num_heads is a multiple of
    num_kv_heads — confirm for the target model.

    Returns:
        ``(q, k, v, o)`` — mean residual over query heads for each family.
    """
    def split_heads(W, n):
        # (n*hd, hidden) -> (n, hd, hidden)
        return W.reshape(n, head_dim, hidden)

    def split_heads_o(W_o):
        # (hidden, nh*hd) -> (nh, hidden, hd)
        return W_o.reshape(hidden, num_heads, head_dim).transpose(1, 0, 2)

    Q1, Q2 = split_heads(A["q_proj"], num_heads), split_heads(B["q_proj"], num_heads)
    K1, K2 = split_heads(A["k_proj"], num_kv_heads), split_heads(B["k_proj"], num_kv_heads)
    V1, V2 = split_heads(A["v_proj"], num_kv_heads), split_heads(B["v_proj"], num_kv_heads)
    O1, O2 = split_heads_o(A["o_proj"]), split_heads_o(B["o_proj"])

    q_res, k_res, v_res, o_res = [], [], [], []
    for h in range(num_heads):
        kv_h = (h * num_kv_heads) // num_heads  # query head -> kv head

        # Normalize each matrix by its Frobenius norm.
        qa, qb = normalize_fro(Q1[h]), normalize_fro(Q2[h])
        ka, kb = normalize_fro(K1[kv_h]), normalize_fro(K2[kv_h])
        va, vb = normalize_fro(V1[kv_h]), normalize_fro(V2[kv_h])
        oa, ob = normalize_fro(O1[h]), normalize_fro(O2[h])

        # min ||R xa - xb||² = const - 2 tr(R xa xb^T) for x in {q, k, v};
        # min ||oa R^T - ob||² = const - 2 tr(R oa^T ob).
        # Summing the (hd × hd) terms gives the matrix whose max-trace
        # rotation aligns all four factors at once.
        M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
        R = procrustes(M)

        q_res.append(fro(R @ qa - qb))
        k_res.append(fro(R @ ka - kb))
        v_res.append(fro(R @ va - vb))
        o_res.append(fro(oa @ R.T - ob))

    return (float(np.mean(q_res)), float(np.mean(k_res)),
            float(np.mean(v_res)), float(np.mean(o_res)))


def _mlp_pair_residuals(A, B):
    """d_ff rotation alignment of layer ``A``'s MLP onto layer ``B``'s.

    Fits one orthogonal S (d_ff × d_ff) with S·gate_a ≈ gate_b,
    S·up_a ≈ up_b and down_a·S^T ≈ down_b on unit-Frobenius-norm matrices.

    Returns:
        ``(gate, up, down)`` Frobenius residuals after applying S.
    """
    ga, gb = normalize_fro(A["gate_proj"]), normalize_fro(B["gate_proj"])  # (d_ff, hidden)
    ua, ub = normalize_fro(A["up_proj"]), normalize_fro(B["up_proj"])      # (d_ff, hidden)
    da, db = normalize_fro(A["down_proj"]), normalize_fro(B["down_proj"])  # (hidden, d_ff)

    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db  # all (d_ff × d_ff)
    S = procrustes(M_ff)
    return fro(S @ ga - gb), fro(S @ ua - ub), fro(da @ S.T - db)


def _print_report(family_residuals, block_start, block_end):
    """Print the per-family summary table and the least-to-most ranking.

    Families with no recorded residuals are omitted.
    """
    print("=== Per-family Frobenius residual between consecutive layers, "
          f"block L={block_start}..{block_end}, after alignment + scale-norm ===\n")
    print(" (Residual = Frobenius distance between L and L+1 after rotation alignment;")
    print(" lower = more shared across block; higher = carries layer-to-layer drift)\n")
    print(f" {'family':>14} {'mean':>8} {'min':>8} {'max':>8} {'std':>8} n")

    # Report families sorted by mean variation.
    items = [(fam, np.array(v)) for fam, v in family_residuals.items() if len(v) > 0]
    items.sort(key=lambda kv: float(kv[1].mean()))
    for fam, v in items:
        print(f" {fam:>14} {v.mean():>8.4f} {v.min():>8.4f} {v.max():>8.4f} {v.std():>8.4f} {len(v)}")
    print("\n Families ranked least-to-most variation:")
    for i, (fam, v) in enumerate(items):
        print(f" {i+1}. {fam} (mean residual {v.mean():.4f})")


@torch.no_grad()
def main():
    """CLI entry point: measure layer-to-layer drift per weight family.

    Loads the model on CPU in float32, copies the weights of the layers in
    ``[--block-start, --block-end]`` to numpy, and for each consecutive pair
    records Frobenius residuals after gauge alignment (attention: per-head
    rotation; MLP: d_ff rotation; norm vectors: direct difference after
    scale normalization). Prints a summary and writes the raw per-pair
    residuals to ``--out`` as JSON.

    Exits through ``argparse`` error handling when the requested block is
    not a range of at least two layers inside the model.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--block-start", type=int, default=10)
    ap.add_argument("--block-end", type=int, default=25,
                    help="inclusive; this is mid-block of 36-layer model")
    ap.add_argument("--out", default="/tmp/sa-layer-variation.json")
    args = ap.parse_args()

    # Fail before the expensive model load when the range cannot yield a pair.
    if args.block_start < 0 or args.block_end <= args.block_start:
        ap.error("need 0 <= --block-start < --block-end "
                 f"(got {args.block_start}..{args.block_end})")

    print(f"Loading {args.model} ...", flush=True)
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code; only pass --model values you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    if args.block_end >= num_layers:
        ap.error(f"--block-end {args.block_end} is out of range for a "
                 f"{num_layers}-layer model (max {num_layers - 1})")

    block = list(range(args.block_start, args.block_end + 1))
    # Only the analyzed layers are copied to numpy; the rest are never read.
    layers = {idx: _collect_layer_weights(model.model.layers[idx]) for idx in block}
    del model  # free memory

    pairs = list(zip(block, block[1:]))
    print(f"\nAnalyzing block layers {args.block_start}..{args.block_end} "
          f"({len(pairs)} consecutive pairs)\n")

    family_residuals = {fam: [] for fam in _FAMILIES}
    for L1, L2 in pairs:
        A = layers[L1]
        B = layers[L2]

        q, k, v, o = _attention_pair_residuals(
            A, B, num_heads, num_kv_heads, head_dim, hidden)
        family_residuals["q_proj"].append(q)
        family_residuals["k_proj"].append(k)
        family_residuals["v_proj"].append(v)
        family_residuals["o_proj"].append(o)

        gate, up, down = _mlp_pair_residuals(A, B)
        family_residuals["gate_proj"].append(gate)
        family_residuals["up_proj"].append(up)
        family_residuals["down_proj"].append(down)

        # Norm weight vectors: no rotation gauge; scale-normalize and diff.
        for ln_name in _NORM_FAMILIES:
            if ln_name in A and ln_name in B:
                family_residuals[ln_name].append(
                    fro(normalize_fro(A[ln_name]) - normalize_fro(B[ln_name])))

    _print_report(family_residuals, args.block_start, args.block_end)

    # Save
    with open(args.out, "w") as f:
        json.dump({
            "model": args.model,
            "block_start": args.block_start,
            "block_end": args.block_end,
            "family_residuals": {k: list(v) for k, v in family_residuals.items()},
        }, f, indent=2)
    print(f"\nSaved: {args.out}")
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()