"""Measure layer-to-layer weight drift inside a middle block, modulo gauge.

After removing the known gauge freedoms (per-head d_h rotation tying
W_Q/W_K/W_V/W_O together, per-layer d_ff rotation tying gate/up/down),
measure the per-family Frobenius distance between consecutive layers within
a middle block. Families with low post-alignment distance are candidates for
a "shared operator" across the block; high distance → carries the schedule.

Each matrix is normalized by its Frobenius norm first, so scale differences
don't dominate: we want to see direction of drift, not magnitude.

Gauge freedoms being removed:

- Per-head d_h rotation R ∈ O(d_h):
  W_Q^h, W_K^h, W_V^h → R W^h;  W_O^h → W_O^h R^T.
  Plain softmax attention is invariant under this.
  NOTE(review): position encodings applied to Q/K and the per-head
  q_norm/k_norm weights may break exact invariance for some models; treat
  the alignment as approximate there — confirm per architecture.
- Per-layer d_ff rotation S ∈ O(d_ff):
  gate_proj, up_proj → S W;  down_proj → W S^T.
  SwiGLU/GLU is NOT fully invariant under d_ff rotation (the elementwise
  gate*up product is coordinate-dependent), so this is an approximate
  alignment — still better than raw.

Families that have no gauge freedom (layernorm γ, q_norm, k_norm) are
compared directly after scale normalization.
"""

import argparse
import json

import numpy as np
import torch

# NOTE: `transformers` is imported lazily inside main() so that the numeric
# helpers below can be imported (and tested) without that heavy dependency.

# Weight families reported; the order fixes the key order of the JSON output.
_FAMILIES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "input_ln", "post_attn_ln", "q_norm", "k_norm",
)

# Families compared directly (no rotation gauge to remove).
_NORM_FAMILIES = ("input_ln", "post_attn_ln", "q_norm", "k_norm")


def procrustes(M):
    """Return the orthogonal matrix R that maximizes tr(R M).

    With the SVD M = U Σ V^T we have tr(R M) = tr(V^T R U Σ), which reaches
    its maximum tr(Σ) when V^T R U = I, i.e. R = V U^T.

    Usage: to find R with R A ≈ B (least squares), pass M = A B^T; to find R
    with A R^T ≈ B, pass M = A^T B. Both reduce to maximizing tr(R M).
    """
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    # V U^T == (U V^T)^T. Returning U @ Vh would instead maximize tr(R^T M),
    # i.e. yield the inverse of the rotation the callers need.
    return (U @ Vh).T


def fro(x):
    """Return the Frobenius (2-) norm of ``x`` as a Python float."""
    return float(np.linalg.norm(x))


def normalize_fro(x, eps=1e-12):
    """Return ``x`` scaled to unit Frobenius norm.

    The divisor is clamped at ``eps`` so an all-zero input stays all-zero
    instead of producing NaNs.
    """
    n = fro(x)
    return x / max(n, eps)


def _weight_np(module):
    """Copy ``module.weight`` into an independent float32 numpy array."""
    # astype() copies, so the array does not keep the torch storage alive.
    return module.weight.detach().numpy().astype(np.float32)


def _collect_layer_weights(model, num_layers):
    """Return ``{layer_idx: {family: ndarray}}`` for every decoder layer.

    ``q_norm`` / ``k_norm`` are included only when the attention module has
    them (e.g. Qwen3).
    """
    layers = {}
    for idx in range(num_layers):
        layer = model.model.layers[idx]
        attn = layer.self_attn
        mlp = layer.mlp
        entry = {
            "q_proj": _weight_np(attn.q_proj),        # (nh*hd, hidden)
            "k_proj": _weight_np(attn.k_proj),        # (nkv*hd, hidden)
            "v_proj": _weight_np(attn.v_proj),
            "o_proj": _weight_np(attn.o_proj),        # (hidden, nh*hd)
            "gate_proj": _weight_np(mlp.gate_proj),
            "up_proj": _weight_np(mlp.up_proj),
            "down_proj": _weight_np(mlp.down_proj),
            "input_ln": _weight_np(layer.input_layernorm),
            "post_attn_ln": _weight_np(layer.post_attention_layernorm),
        }
        for norm_name in ("q_norm", "k_norm"):
            norm = getattr(attn, norm_name, None)
            if norm is not None:
                entry[norm_name] = _weight_np(norm)
        layers[idx] = entry
    return layers


def _attention_residuals(A, B, num_heads, num_kv_heads, head_dim, hidden):
    """Align layer A to layer B per head and return mean residuals.

    For each query head a single rotation R (hd × hd) is fitted so that
    R Q_a ≈ Q_b, R K_a ≈ K_b, R V_a ≈ V_b and O_a R^T ≈ O_b simultaneously.
    Returns ``{"q_proj": .., "k_proj": .., "v_proj": .., "o_proj": ..}``,
    each the mean over query heads of the post-alignment Frobenius distance.

    NOTE(review): under GQA a K/V head is shared by several query heads, and
    each of those query heads fits its own R against the same K/V slice. A
    strictly consistent gauge would fit one R per KV group.
    """
    def split_in(W, n):
        # (n*hd, hidden) → (n, hd, hidden)
        return W.reshape(n, head_dim, hidden)

    def split_out(W):
        # (hidden, nh*hd) → (nh, hidden, hd)
        return W.reshape(hidden, num_heads, head_dim).transpose(1, 0, 2)

    Q1, Q2 = split_in(A["q_proj"], num_heads), split_in(B["q_proj"], num_heads)
    K1, K2 = split_in(A["k_proj"], num_kv_heads), split_in(B["k_proj"], num_kv_heads)
    V1, V2 = split_in(A["v_proj"], num_kv_heads), split_in(B["v_proj"], num_kv_heads)
    O1, O2 = split_out(A["o_proj"]), split_out(B["o_proj"])

    per_head = {"q_proj": [], "k_proj": [], "v_proj": [], "o_proj": []}
    for h in range(num_heads):
        # Query head h reads the KV head of its GQA group.
        kv_h = (h * num_kv_heads) // num_heads

        qa, qb = normalize_fro(Q1[h]), normalize_fro(Q2[h])
        ka, kb = normalize_fro(K1[kv_h]), normalize_fro(K2[kv_h])
        va, vb = normalize_fro(V1[kv_h]), normalize_fro(V2[kv_h])
        oa, ob = normalize_fro(O1[h]), normalize_fro(O2[h])

        # min ||R qa - qb||²   = const - 2 tr(R qa qb^T)
        # min ||oa R^T - ob||² = const - 2 tr(R oa^T ob)
        # so the joint optimum maximizes tr(R M) with M the sum below
        # (every term is hd × hd).
        M = qa @ qb.T + ka @ kb.T + va @ vb.T
        M = M + oa.T @ ob
        R = procrustes(M)

        per_head["q_proj"].append(fro(R @ qa - qb))
        per_head["k_proj"].append(fro(R @ ka - kb))
        per_head["v_proj"].append(fro(R @ va - vb))
        per_head["o_proj"].append(fro(oa @ R.T - ob))

    return {fam: float(np.mean(vals)) for fam, vals in per_head.items()}


def _mlp_residuals(A, B):
    """Align the MLP of layer A to layer B with one d_ff rotation.

    Fits S (d_ff × d_ff, orthogonal) with S gate_a ≈ gate_b, S up_a ≈ up_b
    and down_a S^T ≈ down_b, then returns the three post-alignment
    Frobenius distances keyed by family name.
    """
    ga, gb = normalize_fro(A["gate_proj"]), normalize_fro(B["gate_proj"])  # (d_ff, hidden)
    ua, ub = normalize_fro(A["up_proj"]), normalize_fro(B["up_proj"])
    da, db = normalize_fro(A["down_proj"]), normalize_fro(B["down_proj"])  # (hidden, d_ff)

    # Same tr(S M) convention as the attention case; all terms d_ff × d_ff.
    M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
    S = procrustes(M_ff)

    return {
        "gate_proj": fro(S @ ga - gb),
        "up_proj": fro(S @ ua - ub),
        "down_proj": fro(da @ S.T - db),
    }


def _print_report(family_residuals, block_start, block_end):
    """Print per-family residual statistics, least-varying family first."""
    print("=== Per-family Frobenius residual between consecutive layers, "
          f"block L={block_start}..{block_end}, after alignment + scale-norm ===\n")
    print(" (Residual = Frobenius distance between L and L+1 after rotation alignment;")
    print(" lower = more shared across block; higher = carries layer-to-layer drift)\n")
    print(f" {'family':>14} {'mean':>8} {'min':>8} {'max':>8} {'std':>8} n")

    # Families with no samples (e.g. q_norm on models without it) are skipped.
    stats = [(fam, np.array(vals)) for fam, vals in family_residuals.items() if len(vals) > 0]
    stats.sort(key=lambda item: float(item[1].mean()))

    for fam, v in stats:
        print(f" {fam:>14} {v.mean():>8.4f} {v.min():>8.4f} {v.max():>8.4f} {v.std():>8.4f} {len(v)}")
    print("\n Families ranked least-to-most variation:")
    for rank, (fam, v) in enumerate(stats, start=1):
        print(f" {rank}. {fam} (mean residual {v.mean():.4f})")


@torch.no_grad()
def main():
    """CLI entry point: load the model, align consecutive layers, report.

    Writes the raw per-pair residuals of every family to ``--out`` as JSON
    and prints a summary table to stdout.
    """
    # Deferred import: only the CLI path needs transformers.
    from transformers import AutoModelForCausalLM

    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-4B")
    ap.add_argument("--block-start", type=int, default=10)
    ap.add_argument("--block-end", type=int, default=25,
                    help="inclusive; this is mid-block of 36-layer model")
    ap.add_argument("--out", default="/tmp/sa-layer-variation.json")
    args = ap.parse_args()

    print(f"Loading {args.model} ...", flush=True)
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only point --model at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    intermediate = cfg.intermediate_size
    print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
          f"hidden={hidden} ff={intermediate}", flush=True)

    # Fail early with a clear message instead of a KeyError after the
    # expensive alignment of the first few pairs.
    if not 0 <= args.block_start <= args.block_end < num_layers:
        ap.error(
            f"block range {args.block_start}..{args.block_end} is not within "
            f"0..{num_layers - 1} (or start > end)"
        )

    # Collect per-layer weight matrices as numpy float32, then drop the model.
    layers = _collect_layer_weights(model, num_layers)
    del model  # free memory

    block = list(range(args.block_start, args.block_end + 1))
    pairs = list(zip(block, block[1:]))
    print(f"\nAnalyzing block layers {args.block_start}..{args.block_end} "
          f"({len(pairs)} consecutive pairs)\n")

    family_residuals = {fam: [] for fam in _FAMILIES}

    for idx_a, idx_b in pairs:
        A = layers[idx_a]
        B = layers[idx_b]

        # Per-head attention alignment.
        attn_res = _attention_residuals(A, B, num_heads, num_kv_heads, head_dim, hidden)
        for fam, value in attn_res.items():
            family_residuals[fam].append(value)

        # MLP d_ff rotation alignment.
        for fam, value in _mlp_residuals(A, B).items():
            family_residuals[fam].append(value)

        # Norm γ vectors — no rotation gauge; just scale-normalize and diff.
        for ln_name in _NORM_FAMILIES:
            if ln_name in A and ln_name in B:
                family_residuals[ln_name].append(
                    fro(normalize_fro(A[ln_name]) - normalize_fro(B[ln_name]))
                )

    _print_report(family_residuals, args.block_start, args.block_end)

    # Save
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump({
            "model": args.model,
            "block_start": args.block_start,
            "block_end": args.block_end,
            "family_residuals": {k: list(v) for k, v in family_residuals.items()},
        }, f, indent=2)
    print(f"\nSaved: {args.out}")


if __name__ == "__main__":
    main()