consciousness/sa-schedule-derive-from-last.py

214 lines
9.2 KiB
Python
Raw Normal View History

"""Under the SA-schedule hypothesis, earlier layers should be approximately
a temperature-rescaled version of a shared operator. The simplest test:
pick the last layer's per-head metric spectrum as anchor, and ask whether
earlier layers' spectra are scalar rescales of it.
Three experiments on the existing per-head singular values:
(1) Spectral shape invariance. For each head h, normalize σ_L^h by σ_max
and compare the shape vector across layers. If shapes match, scale is
the only free parameter.
(2) Scalar rescale fit. For each (L, h), find T_L^h minimizing
||σ_L^h - T_L^h σ_last^h||². Optimal T_L^h = <σ_L^h, σ_last^h>/||σ_last^h||².
Report residual = ||σ_L^h - T_L^h σ_last^h|| / ||σ_L^h||.
(3) Cross-head sharing. If the *shape* is the same across heads too (not
just across layers), we could use a single anchor per *layer* (last
layer, one head) and reconstruct everything. Report mean shape
correlation across heads within a layer.
The anchor doesn't have to be the last layer: --anchor also accepts "middle"
(layer num_layers // 2) and "best" (the single layer whose per-head shapes have
the highest total cosine similarity to all other layers). The purpose is not to
pick the best anchor but to understand which choice lets reconstruction succeed.
"""
import argparse
import json
import numpy as np
def pad_to(arr, n):
    """Right-pad the 1D array ``arr`` with zeros so it has length ``n``.

    Intended for lining up singular-value vectors of heads with different
    rank. An array that already has length ``n`` is returned unchanged (the
    same object, no copy); otherwise a new array of ``arr``'s dtype is built.
    An input longer than ``n`` cannot be padded and raises ValueError.
    """
    current = arr.shape[0]
    if current != n:
        padded = np.zeros(n, dtype=arr.dtype)
        padded[:current] = arr
        return padded
    return arr
def collect_spectra(data):
    """Stack every head's singular values into one zero-padded array.

    Reads ``data["num_layers"]``, ``data["num_heads"]`` and ``data["static"]``
    (one row per layer, each with a ``"metric_singvals_per_head"`` list of
    per-head singular-value lists). Returns a float64 array
    ``sigma[L, h, k]`` whose last axis is as long as the longest per-head
    list; shorter lists are padded with trailing zeros.
    """
    layers_heads = (data["num_layers"], data["num_heads"])
    rows = data["static"]
    # Longest per-head spectrum over the whole model (0 if there are none).
    longest = max(
        (len(vals) for row in rows for vals in row["metric_singvals_per_head"]),
        default=0,
    )
    sigma = np.zeros(layers_heads + (longest,), dtype=np.float64)
    for layer, row in enumerate(rows):
        for head, vals in enumerate(row["metric_singvals_per_head"]):
            sigma[layer, head, :len(vals)] = vals
    return sigma
def scalar_rescale_fit(x, y):
    """Least-squares scalar T minimizing ||x - T y||.

    The closed form is T = <x, y> / ||y||^2.

    Returns:
        ``(T, residual_frac)`` with ``residual_frac = ||x - T y|| / ||x||``.
        If ``x`` has (near-)zero norm, ``residual_frac`` is 0.0: there is
        nothing left to explain. If ``y`` has (near-)zero squared norm, no
        scale can be fitted, so T = 0.0 and the residual is all of ``x``.
        That gives 1.0, or 0.0 when ``x`` is also ~0, consistent with the
        rule above.
    """
    xn = float(np.linalg.norm(x))
    denom = float((y * y).sum())
    if denom < 1e-20:
        # Degenerate anchor: T = 0 leaves x entirely unexplained, unless x is
        # itself ~0, in which case the (trivial) fit is exact.
        return 0.0, (1.0 if xn > 1e-20 else 0.0)
    T = float((x * y).sum() / denom)
    rn = float(np.linalg.norm(x - T * y))
    return T, (rn / xn if xn > 1e-20 else 0.0)
def cos_sim(x, y):
    """Return the cosine of the angle between arrays ``x`` and ``y``.

    If either array has norm below 1e-20 the angle is undefined, and 0.0 is
    returned instead of dividing by (nearly) zero.
    """
    norm_x, norm_y = (float(np.linalg.norm(v)) for v in (x, y))
    if norm_x < 1e-20 or norm_y < 1e-20:
        return 0.0
    dot = (x * y).sum()
    return float(dot / (norm_x * norm_y))
def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when it is empty.

    ``np.mean([])`` also yields NaN but emits a RuntimeWarning. Empty inputs
    occur legitimately with a single layer (no layer pairs) or a single
    head (no head pairs).
    """
    if len(values) == 0:
        return float("nan")
    return float(np.mean(values))


def _normalized_shapes(sigma):
    """Divide every (layer, head) spectrum by its own maximum.

    Spectra whose maximum is not above 1e-20 are returned unscaled rather
    than divided by ~0.
    """
    peak = sigma.max(axis=-1, keepdims=True)
    return sigma / np.where(peak > 1e-20, peak, 1.0)


def _pick_anchor(mode, shape):
    """Return the anchor layer index for an ``--anchor`` mode.

    "last" -> final layer; "middle" -> ``num_layers // 2``; anything else
    ("best") -> the layer whose per-head shapes have the largest summed
    cosine similarity to the same heads in every other layer (the earliest
    such layer wins ties).
    """
    num_layers, num_heads = shape.shape[:2]
    if mode == "last":
        return num_layers - 1
    if mode == "middle":
        return num_layers // 2
    best_score = -np.inf
    best_layer = num_layers - 1
    for cand in range(num_layers):
        score = 0.0
        for h in range(num_heads):
            for L in range(num_layers):
                if L != cand:
                    score += cos_sim(shape[cand, h], shape[L, h])
        if score > best_score:
            best_score = score
            best_layer = cand
    return best_layer


def _fit_to_anchor(sigma, anchor_L):
    """Fit every (layer, head) spectrum as a scalar multiple of the anchor's.

    Returns ``(T_map, resid_map)``, both of shape (num_layers, num_heads),
    holding the fitted scale and relative residual from
    ``scalar_rescale_fit`` for each (layer, head).
    """
    num_layers, num_heads = sigma.shape[:2]
    T_map = np.zeros((num_layers, num_heads))
    resid_map = np.zeros((num_layers, num_heads))
    for L in range(num_layers):
        for h in range(num_heads):
            T_map[L, h], resid_map[L, h] = scalar_rescale_fit(
                sigma[L, h], sigma[anchor_L, h])
    return T_map, resid_map


def main(argv=None):
    """Run the spectral-rescale experiments on a JSON file and print results.

    Args:
        argv: Optional argument list; ``None`` makes argparse read
            ``sys.argv[1:]``. Expects a positional JSON path and an optional
            ``--anchor {last,middle,best}``.

    The JSON must provide ``num_layers``, ``num_heads``, ``static`` rows with
    ``metric_singvals_per_head``, and ``dynamic`` rows with
    ``mean_attention_entropy_per_head``. All output goes to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input_json")
    ap.add_argument("--anchor", choices=["last", "middle", "best"], default="last")
    args = ap.parse_args(argv)
    with open(args.input_json, encoding="utf-8") as f:
        data = json.load(f)
    num_layers = data["num_layers"]
    num_heads = data["num_heads"]
    sigma = collect_spectra(data)  # (L, H, K)
    print(f"Loaded sigma: shape {sigma.shape}, max rank {sigma.shape[-1]}")

    # ------------------------------------------------------------------
    # Experiment 1: spectral shape invariance across layers (per head)
    # ------------------------------------------------------------------
    print("\n=== (1) Spectral shape invariance across layers ===")
    # Normalize each spectrum by its max, then, per head, average the cosine
    # similarity of those shapes over all layer pairs.
    shape = _normalized_shapes(sigma)
    per_head_cos = np.array([
        _mean_or_nan([cos_sim(shape[L1, h], shape[L2, h])
                      for L1 in range(num_layers)
                      for L2 in range(L1 + 1, num_layers)])
        for h in range(num_heads)
    ], dtype=np.float64)
    print(" per-head mean pairwise cosine of shape across layers:")
    print(f" mean {per_head_cos.mean():.4f} std {per_head_cos.std():.4f} "
          f"min {per_head_cos.min():.4f} max {per_head_cos.max():.4f}")
    # Rough reading: mean > ~0.99 -> shapes identical, pure scalar rescale
    # works; ~0.85-0.95 -> close but structure changes layer to layer;
    # < 0.8 -> shape varies meaningfully, scalar rescale insufficient.

    # ------------------------------------------------------------------
    # Experiment 2: scalar rescale fit to an anchor layer
    # ------------------------------------------------------------------
    anchor_L = _pick_anchor(args.anchor, shape)
    if args.anchor == "best":
        print(f" [auto-anchor] best layer by total shape-cosine: L={anchor_L}")
    print(f"\n=== (2) Scalar rescale fit to anchor L={anchor_L} ===")
    T_map, resid_map = _fit_to_anchor(sigma, anchor_L)
    print(" per-layer residual fraction ||σ_L^h - T σ_anchor^h|| / ||σ_L^h||:")
    print(f" {'L':>3} {'mean resid':>10} {'max resid':>10} {'mean T':>8}")
    for L in range(num_layers):
        print(f" {L:>3} {resid_map[L].mean():>10.4f} "
              f"{resid_map[L].max():>10.4f} {T_map[L].mean():>8.3f}")
    # For any non-degenerate spectrum the anchor layer's self-fit is a
    # trivially perfect fit (T = 1, residual 0). Counting it would bias the
    # aggregate statistics and the verdict toward support, so aggregate over
    # the other layers whenever there are any.
    if num_layers > 1:
        fit_resid = np.delete(resid_map, anchor_L, axis=0)
        scope = "non-anchor layers"
    else:
        fit_resid = resid_map
        scope = "anchor layer only"
    print(f"\n overall mean residual ({scope}): {fit_resid.mean():.4f}")
    print(f" overall max residual ({scope}): {fit_resid.max():.4f}")
    print(f" frac of (L,h) with resid < 0.10 ({scope}): "
          f"{(fit_resid < 0.10).mean():.3f}")
    print(f" frac of (L,h) with resid < 0.20 ({scope}): "
          f"{(fit_resid < 0.20).mean():.3f}")

    # ------------------------------------------------------------------
    # Experiment 2b: does T match per-head dynamic entropy?
    # ------------------------------------------------------------------
    ent = np.array([row["mean_attention_entropy_per_head"]
                    for row in data["dynamic"]])  # expected (L, H)
    # T acts as a scalar temperature on the metric. Under the hypothesis,
    # higher T means sharper attention (lower entropy), so corr(T, entropy)
    # should be negative if the scalar rescale captures the schedule.
    c = float(np.corrcoef(T_map.flatten(), ent.flatten())[0, 1])
    print(f"\n correlation corr(T_L^h, entropy_L^h) = {c:+.3f} "
          f"(negative expected: larger T → sharper → lower entropy)")
    # Baseline: does T predict entropy better than the raw operator norm
    # (largest singular value)? An earlier geometry analysis reportedly
    # found op_norm r = +0.45.
    op_norm = sigma.max(axis=-1)  # (L, H)
    c_op = float(np.corrcoef(op_norm.flatten(), ent.flatten())[0, 1])
    print(f" for comparison, corr(op_norm, entropy) = {c_op:+.3f}")

    # ------------------------------------------------------------------
    # Experiment 3: shape similarity across heads within a layer
    # ------------------------------------------------------------------
    print("\n=== (3) Cross-head shape similarity within each layer ===")
    print(f" {'L':>3} {'mean pair-cos':>14}")
    for L in range(num_layers):
        cs = [cos_sim(shape[L, h1], shape[L, h2])
              for h1 in range(num_heads)
              for h2 in range(h1 + 1, num_heads)]
        print(f" {L:>3} {_mean_or_nan(cs):>14.4f}")

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print("\n=== Summary ===")
    print(f" anchor layer: {anchor_L}")
    shape_cos = per_head_cos.mean()
    if np.isnan(shape_cos):
        # Only reachable with fewer than two layers (no layer pairs).
        stability = "undetermined (needs >= 2 layers)"
    elif shape_cos > 0.98:
        stability = "very stable"
    elif shape_cos > 0.9:
        stability = "approximately stable"
    else:
        stability = "not stable"
    print(f" spectral shape is {stability} "
          f"across layers (per-head mean pairwise cos = {shape_cos:.3f})")
    mean_resid = fit_resid.mean()
    print(f" scalar-rescale fit residual ({scope}): mean {mean_resid:.3f}")
    if mean_resid < 0.1:
        verdict = "HYPOTHESIS SUPPORTED — scalar temperature rescale of a shared operator reconstructs earlier layers to within 10% Frobenius residual."
    elif mean_resid < 0.3:
        verdict = "PARTIALLY SUPPORTED — scalar rescale captures most of the structure; a low-rank correction on top is likely enough."
    else:
        verdict = "HYPOTHESIS REJECTED for pure scalar rescale — spectra differ substantially in shape; need full layer-by-layer operators or rank-k delta."
    print(f"\n {verdict}")


if __name__ == "__main__":
    main()