# consciousness/sa-schedule-analyze.py
# (Repository viewer metadata: 108 lines, 4.5 KiB, Python.)

"""Analyze the SA schedule readout JSON: per-head variance, static/dynamic
correlation, and a plot."""
import argparse
import json
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def _heatmap_path_for(out_plot):
    """Return the heatmap filename that accompanies the main plot path.

    "-heatmap" is inserted before the final extension, so the result always
    differs from ``out_plot``. A plain ``str.replace(".png", ...)`` would
    return the path unchanged for e.g. ``out.pdf`` or ``out`` (making the
    heatmap overwrite the main plot) and would also rewrite ".png" inside
    directory names. When ``out_plot`` has no extension, ".png" is used.
    """
    import os  # local import: the file's import header does not bring in os

    root, ext = os.path.splitext(out_plot)
    return f"{root}-heatmap{ext or '.png'}"


def _plot_band(ax, xs, mean, std, color, line_label, ylabel):
    """Draw one panel: a mean line with a shaded +/-1 std band around it."""
    ax.fill_between(xs, mean - std, mean + std, alpha=0.2, color=color,
                    label="±1 std across heads")
    ax.plot(xs, mean, color=color, marker="o", label=line_label)
    ax.set_ylabel(ylabel)
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3)


def main():
    """Summarize an SA schedule readout JSON and write two plots.

    Command line: ``input_json`` (positional) and ``--out-plot`` (default
    ``/tmp/sa-schedule.png``).

    The JSON is read for the keys ``num_layers``, ``num_heads``, ``model``,
    ``dynamic`` (rows with ``mean_attention_entropy_per_head`` and
    ``mean_logit_std_per_head``) and ``static`` (rows with
    ``metric_op_per_head``). Each per-head list is stacked into a 2-D array
    that is reduced over axis 1 and indexed ``[:, h]`` per head, i.e. it is
    treated as (layers, heads).

    Prints coefficient-of-variation and Pearson-correlation summaries, then
    saves a three-panel line plot to ``--out-plot`` and a per-head entropy
    heatmap next to it.

    Raises:
        KeyError: if a required key is missing from the JSON.
        OSError: if the input cannot be read or a plot cannot be written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input_json")
    ap.add_argument("--out-plot", default="/tmp/sa-schedule.png")
    args = ap.parse_args()

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args.input_json, encoding="utf-8") as f:
        data = json.load(f)

    num_layers = data["num_layers"]
    num_heads = data["num_heads"]
    Ls = np.arange(num_layers)

    # One row per layer, one column per head: (L, H).
    ent = np.array([row["mean_attention_entropy_per_head"] for row in data["dynamic"]])
    logit_std = np.array([row["mean_logit_std_per_head"] for row in data["dynamic"]])
    metric_op = np.array([row["metric_op_per_head"] for row in data["static"]])

    # Reduce over heads -> one value per layer.
    mean_ent = ent.mean(axis=1)
    std_ent = ent.std(axis=1)
    mean_logit = logit_std.mean(axis=1)
    std_logit = logit_std.std(axis=1)
    mean_metric = metric_op.mean(axis=1)
    std_metric = metric_op.std(axis=1)

    # Per-head variance summary (the 1e-6 floor avoids division by zero).
    print("\nPer-head variance across heads (coefficient of variation = std/mean):")
    print(f" entropy: mean CV = {(std_ent / np.maximum(mean_ent, 1e-6)).mean():.3f}")
    print(f" logit_std: mean CV = {(std_logit / np.maximum(mean_logit, 1e-6)).mean():.3f}")
    print(f" metric_op: mean CV = {(std_metric / np.maximum(mean_metric, 1e-6)).mean():.3f}")

    # Correlations across layers
    corr_ent_metric = np.corrcoef(mean_ent, mean_metric)[0, 1]
    corr_logit_metric = np.corrcoef(mean_logit, mean_metric)[0, 1]
    corr_ent_logit = np.corrcoef(mean_ent, mean_logit)[0, 1]
    print("\nAcross-layer Pearson correlations (averaged over heads):")
    print(f" entropy vs metric_op: {corr_ent_metric:+.3f}")
    print(f" logit_std vs metric_op: {corr_logit_metric:+.3f}")
    print(f" entropy vs logit_std: {corr_ent_logit:+.3f}")

    # Per-head correlation (one value per head): does each head's entropy
    # across layers track its own metric_op across layers? Non-finite results
    # (e.g. from a constant column) are dropped.
    per_head = (np.corrcoef(ent[:, h], metric_op[:, h])[0, 1] for h in range(num_heads))
    head_corrs = [c for c in per_head if np.isfinite(c)]
    if head_corrs:
        print(f" per-head entropy vs metric_op: mean {np.mean(head_corrs):+.3f} "
              f"std {np.std(head_corrs):.3f} min {min(head_corrs):+.3f} max {max(head_corrs):+.3f}")
    else:
        # min()/max() raise on an empty list; report the situation instead.
        print(" per-head entropy vs metric_op: n/a (no finite per-head correlations)")

    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    _plot_band(axes[0], Ls, mean_ent, std_ent, "tab:blue",
               "mean entropy", "attention entropy (nats)")
    axes[0].set_title(f"{data['model']} — SA schedule readout ({num_layers} layers, {num_heads} heads)")
    _plot_band(axes[1], Ls, mean_logit, std_logit, "tab:orange",
               "mean logit std", "pre-softmax logit std\n(= implicit sharpness)")
    _plot_band(axes[2], Ls, mean_metric, std_metric, "tab:green",
               "mean metric op-norm (static)",
               "||W_K^T W_Q|| operator norm\n(static, parameter-only)")
    axes[2].set_xlabel("layer index L")
    plt.tight_layout()
    plt.savefig(args.out_plot, dpi=100, bbox_inches="tight")
    plt.close(fig)  # release the figure once it is on disk
    print(f"\nWrote plot to {args.out_plot}")

    # Also save a small heatmap of per-head entropy for visual spread
    heatmap_fig = plt.figure(figsize=(10, 6))
    plt.imshow(ent.T, aspect="auto", cmap="viridis", origin="lower")
    plt.colorbar(label="attention entropy")
    plt.xlabel("layer L")
    plt.ylabel("head h")
    plt.title(f"{data['model']} — per-head entropy heatmap")
    heatmap_path = _heatmap_path_for(args.out_plot)
    plt.tight_layout()
    plt.savefig(heatmap_path, dpi=100, bbox_inches="tight")
    plt.close(heatmap_fig)
    print(f"Wrote heatmap to {heatmap_path}")
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()