forked from kent/consciousness
108 lines
4.5 KiB
Python
108 lines
4.5 KiB
Python
|
|
"""Analyze the SA schedule readout JSON: per-head variance, static/dynamic
|
||
|
|
correlation, and a plot."""
|
||
|
|
import argparse
import json
import os

import matplotlib
import numpy as np

# Select the non-interactive Agg backend before pyplot is imported so the
# figures can be rendered to files without a display.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
|
||
|
|
|
||
|
|
|
||
|
|
def _stack_per_head(rows, key):
    """Collect ``row[key]`` from every record in ``rows`` into one NumPy array.

    ``rows`` is one of the per-layer lists of the readout ("dynamic" or
    "static"), so the result is indexed ``[layer, head]`` — assuming each
    row holds one value per head (not validated here).
    """
    return np.array([row[key] for row in rows])


def _mean_cv(std, mean):
    """Return the coefficient of variation ``std / mean`` averaged over layers.

    The denominator is clamped to at least 1e-6 so a zero mean cannot cause
    a division by zero.
    """
    return (std / np.maximum(mean, 1e-6)).mean()


def _per_head_correlations(x, y):
    """Return the finite Pearson correlations of ``x[:, h]`` vs ``y[:, h]``.

    One candidate value per column (head). Columns whose correlation is not
    finite (e.g. a column that is constant across layers) are skipped, so the
    returned list may be shorter than the number of columns, or empty.
    """
    corrs = []
    for h in range(x.shape[1]):
        c = np.corrcoef(x[:, h], y[:, h])[0, 1]
        if np.isfinite(c):
            corrs.append(c)
    return corrs


def _heatmap_path(out_plot):
    """Return the heatmap output path derived from the summary-plot path.

    "-heatmap" is inserted before the final extension only, e.g.
    ``/tmp/sa-schedule.png`` -> ``/tmp/sa-schedule-heatmap.png``. Unlike a
    plain ``str.replace(".png", ...)``, this never returns ``out_plot``
    itself (which made the heatmap overwrite the summary plot for non-.png
    outputs) and never touches ".png" occurring inside directory names.
    """
    root, ext = os.path.splitext(out_plot)
    return f"{root}-heatmap{ext}"


def _plot_band(ax, x, mean, std, color, label, ylabel):
    """Plot ``mean`` over ``x`` on ``ax`` with a shaded mean ± std band."""
    ax.fill_between(x, mean - std, mean + std, alpha=0.2, color=color,
                    label="±1 std across heads")
    ax.plot(x, mean, color=color, marker="o", label=label)
    ax.set_ylabel(ylabel)
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3)


def main():
    """Summarize an SA schedule readout JSON and write two plots.

    Command line:
        input_json   path to the readout JSON file.
        --out-plot   path of the 3-panel summary figure
                     (default: /tmp/sa-schedule.png). A per-head entropy
                     heatmap is written alongside it, with "-heatmap"
                     inserted before the file extension.

    JSON keys read: "num_layers", "num_heads", "model", "dynamic" (rows with
    "mean_attention_entropy_per_head" and "mean_logit_std_per_head") and
    "static" (rows with "metric_op_per_head").

    Prints per-layer coefficient-of-variation summaries, across-layer Pearson
    correlations of the head-averaged curves, and per-head entropy vs
    metric_op correlations.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input_json")
    ap.add_argument("--out-plot", default="/tmp/sa-schedule.png")
    args = ap.parse_args()

    with open(args.input_json, encoding="utf-8") as f:
        data = json.load(f)

    num_layers = data["num_layers"]
    num_heads = data["num_heads"]
    Ls = np.arange(num_layers)

    ent = _stack_per_head(data["dynamic"], "mean_attention_entropy_per_head")
    logit_std = _stack_per_head(data["dynamic"], "mean_logit_std_per_head")
    metric_op = _stack_per_head(data["static"], "metric_op_per_head")

    # Per-layer mean and spread across heads (reduce over axis 1).
    mean_ent, std_ent = ent.mean(axis=1), ent.std(axis=1)
    mean_logit, std_logit = logit_std.mean(axis=1), logit_std.std(axis=1)
    mean_metric, std_metric = metric_op.mean(axis=1), metric_op.std(axis=1)

    # Per-head variance summary
    print("\nPer-head variance across heads (coefficient of variation = std/mean):")
    print(f" entropy: mean CV = {_mean_cv(std_ent, mean_ent):.3f}")
    print(f" logit_std: mean CV = {_mean_cv(std_logit, mean_logit):.3f}")
    print(f" metric_op: mean CV = {_mean_cv(std_metric, mean_metric):.3f}")

    # Correlations of the head-averaged curves across layers
    corr_ent_metric = np.corrcoef(mean_ent, mean_metric)[0, 1]
    corr_logit_metric = np.corrcoef(mean_logit, mean_metric)[0, 1]
    corr_ent_logit = np.corrcoef(mean_ent, mean_logit)[0, 1]
    print("\nAcross-layer Pearson correlations (averaged over heads):")
    print(f" entropy vs metric_op: {corr_ent_metric:+.3f}")
    print(f" logit_std vs metric_op: {corr_logit_metric:+.3f}")
    print(f" entropy vs logit_std: {corr_ent_logit:+.3f}")

    # Per-head correlation (one value per head): does each head's entropy
    # across layers track its own metric_op across layers?
    head_corrs = _per_head_correlations(ent, metric_op)
    if head_corrs:
        print(f" per-head entropy vs metric_op: mean {np.mean(head_corrs):+.3f} "
              f"std {np.std(head_corrs):.3f} min {min(head_corrs):+.3f} max {max(head_corrs):+.3f}")
    else:
        # min()/max() of an empty list would raise ValueError; report instead.
        print(" per-head entropy vs metric_op: undefined for every head")

    # Summary figure: one panel per quantity, shared layer axis.
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    _plot_band(axes[0], Ls, mean_ent, std_ent, "tab:blue",
               "mean entropy", "attention entropy (nats)")
    axes[0].set_title(f"{data['model']} — SA schedule readout ({num_layers} layers, {num_heads} heads)")
    _plot_band(axes[1], Ls, mean_logit, std_logit, "tab:orange",
               "mean logit std", "pre-softmax logit std\n(= implicit sharpness)")
    _plot_band(axes[2], Ls, mean_metric, std_metric, "tab:green",
               "mean metric op-norm (static)",
               "||W_K^T W_Q|| operator norm\n(static, parameter-only)")
    axes[2].set_xlabel("layer index L")
    fig.tight_layout()
    fig.savefig(args.out_plot, dpi=100, bbox_inches="tight")
    plt.close(fig)
    print(f"\nWrote plot to {args.out_plot}")

    # Also save a small heatmap of per-head entropy for visual spread
    # (transposed so heads run along y and layers along x).
    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(ent.T, aspect="auto", cmap="viridis", origin="lower")
    fig.colorbar(im, ax=ax, label="attention entropy")
    ax.set_xlabel("layer L")
    ax.set_ylabel("head h")
    ax.set_title(f"{data['model']} — per-head entropy heatmap")
    heatmap_path = _heatmap_path(args.out_plot)
    fig.tight_layout()
    fig.savefig(heatmap_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote heatmap to {heatmap_path}")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|