"""What does per-head T (entropy) correlate with geometrically?

For each (layer, head) we already have singular values of the metric
M^h = (W_K^h)^T W_Q^h (up to the low-rank structure — strictly SVD of the
head_dim x head_dim product). Derive richer per-head geometric descriptors
and test which ones predict dynamic entropy.

Descriptors per head:
    op_norm       σ_max                         — global "capacity for sharpness"
    fro_norm      √Σ σ_i²                       — total metric "energy"
    rank_eff      Σσ / σ_max                    — effective number of modes
    spec_entropy  -Σ p_i log p_i,
                  p_i = σ_i² / Σσ_j²            — flatness of spectrum (nats)
    anisotropy    σ_max / σ_mean                — how "peaked" the top mode is
    condition     σ_max / σ_min                 — ratio of biggest to smallest
    trace         Σσ_i                          — sum of modes (L1-like)

Correlate each of these per-head descriptors against per-head dynamic
entropy, across all (layer, head) pairs. Also stratified by layer position
(early/mid/late).
"""

import argparse
import json

import numpy as np

# Floor applied to denominators so all-zero spectra do not divide by zero.
_EPS = 1e-12

# Descriptor names, in the order they are reported.
_DESCRIPTORS = ["op", "fro", "trace", "rank_eff", "spec_ent", "anisotropy",
                "condition"]


def compute_per_head_geometry(singvals_list):
    """Compute geometric descriptors from each head's singular values.

    Args:
        singvals_list: one entry per head, each a non-empty sequence of
            singular values.

    Returns:
        dict mapping descriptor name ("op", "fro", "trace", "rank_eff",
        "spec_ent", "anisotropy", "condition") to a float64 array with one
        value per head.

    Raises:
        ValueError: if any head has an empty list of singular values.
    """
    spectra = [np.asarray(s, dtype=np.float64) for s in singvals_list]
    for head, s in enumerate(spectra):
        if s.size == 0:
            # Without this check s.max() fails with an opaque NumPy error.
            raise ValueError(f"head {head} has no singular values")

    op = np.array([s.max() for s in spectra])
    fro = np.array([np.sqrt((s ** 2).sum()) for s in spectra])
    trace = np.array([s.sum() for s in spectra])
    rank_eff = np.array([s.sum() / max(s.max(), _EPS) for s in spectra])

    # Spectral entropy: use normalized σ² as probabilities.
    spec_ent = np.zeros(len(spectra))
    for i, s in enumerate(spectra):
        energy = s ** 2
        p = energy / max(energy.sum(), _EPS)
        # Convention 0 * log 0 = 0: drop zero-probability modes instead of
        # clipping them to a tiny positive value, which would add a spurious
        # contribution per zero mode.
        p = p[p > 0]
        spec_ent[i] = float(-(p * np.log(p)).sum())

    anis = np.array([s.max() / max(s.mean(), _EPS) for s in spectra])
    cond = np.array([s.max() / max(s.min(), _EPS) for s in spectra])
    return dict(op=op, fro=fro, trace=trace, rank_eff=rank_eff,
                spec_ent=spec_ent, anisotropy=anis, condition=cond)


def _pearson(x, y):
    """Return the Pearson correlation of x and y, or NaN when undefined.

    The correlation is undefined with fewer than two samples or when either
    input is constant; np.corrcoef would emit RuntimeWarnings in those cases.
    """
    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    if x.size < 2 or y.size < 2:
        return float("nan")
    if x.min() == x.max() or y.min() == y.max():
        return float("nan")
    return float(np.corrcoef(x, y)[0, 1])


def main():
    """Load the analysis JSON and print descriptor/entropy correlations.

    The input file must provide "num_layers", "num_heads", a "dynamic" list
    whose rows hold "mean_attention_entropy_per_head" and
    "mean_logit_std_per_head", and a "static" list whose rows hold
    "metric_singvals_per_head".
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input_json")
    args = ap.parse_args()

    with open(args.input_json, encoding="utf-8") as f:
        data = json.load(f)

    num_layers = data["num_layers"]
    num_heads = data["num_heads"]

    # Entropy and logit std per (layer, head); expected shape (L, H).
    ent = np.array([row["mean_attention_entropy_per_head"]
                    for row in data["dynamic"]])
    logit_std = np.array([row["mean_logit_std_per_head"]
                          for row in data["dynamic"]])

    # Geometric descriptors per (layer, head).
    geom = {k: np.zeros((num_layers, num_heads)) for k in _DESCRIPTORS}
    for layer, row in enumerate(data["static"]):
        per_head = compute_per_head_geometry(row["metric_singvals_per_head"])
        for k, v in per_head.items():
            geom[k][layer] = v

    # Flatten across (layer, head) and correlate.
    print("All (layer, head) pairs — Pearson correlation with dynamic entropy:")
    ent_flat = ent.flatten()
    logit_flat = logit_std.flatten()
    results = {}
    for k, v in geom.items():
        v_flat = v.flatten()
        c_ent = _pearson(v_flat, ent_flat)
        c_logit = _pearson(v_flat, logit_flat)
        results[k] = (c_ent, c_logit)
        print(f" {k:12} vs entropy: {c_ent:+.3f} vs logit_std: {c_logit:+.3f}")

    # Stratify by layer position: split the layers into early / mid / late
    # thirds derived from num_layers.
    thirds = [(0, num_layers // 3, "early"),
              (num_layers // 3, 2 * num_layers // 3, "mid"),
              (2 * num_layers // 3, num_layers, "late")]
    print("\nStratified by layer position (entropy correlation):")
    for lo, hi, name in thirds:
        print(f" [{name} L{lo}-{hi-1}]", end="")
        for k in ["op", "fro", "rank_eff", "spec_ent", "anisotropy",
                  "condition"]:
            c = _pearson(geom[k][lo:hi], ent[lo:hi])
            print(f" {k}:{c:+.2f}", end="")
        print()

    # Best single predictor across all pairs. NaN correlations are excluded:
    # max() over NaN keys depends on iteration order and could report an
    # undefined correlation as the winner.
    print("\nBest single geometric predictor of entropy (abs):")
    finite = {k: v for k, v in results.items() if np.isfinite(v[0])}
    if finite:
        best = max(finite.items(), key=lambda kv: abs(kv[1][0]))
        print(f" {best[0]} r = {best[1][0]:+.3f}")
    else:
        print(" (no descriptor has a defined correlation)")

    # Multi-regression: try op, spec_ent, rank_eff, anisotropy jointly.
    print("\nLinear regression of entropy on multiple descriptors (standardized):")
    X_cols = ["op", "spec_ent", "rank_eff", "anisotropy"]
    X = np.stack([geom[k].flatten() for k in X_cols], axis=1)
    # Standardize predictors and target so coefficients are comparable.
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + _EPS)
    y = (ent_flat - ent_flat.mean()) / (ent_flat.std() + _EPS)
    X1 = np.concatenate([X, np.ones((X.shape[0], 1))], axis=1)
    coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
    y_pred = X1 @ coef
    ss_res = float(((y - y_pred) ** 2).sum())
    ss_tot = float(((y - y.mean()) ** 2).sum())
    # R² is undefined when the target has no variance.
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    print(f" R² = {r2:.3f}")
    print(" standardized coefficients:")
    for name, c in zip(X_cols, coef[:-1]):
        print(f" {name:12} {c:+.3f}")


if __name__ == "__main__":
    main()