"""Analyze the SA schedule readout JSON: per-head variance, static/dynamic correlation, and a plot.""" import argparse import json import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt def main(): ap = argparse.ArgumentParser() ap.add_argument("input_json") ap.add_argument("--out-plot", default="/tmp/sa-schedule.png") args = ap.parse_args() with open(args.input_json) as f: data = json.load(f) num_layers = data["num_layers"] num_heads = data["num_heads"] Ls = np.arange(num_layers) ent = np.array([row["mean_attention_entropy_per_head"] for row in data["dynamic"]]) # (L, H) logit_std = np.array([row["mean_logit_std_per_head"] for row in data["dynamic"]]) # (L, H) metric_op = np.array([row["metric_op_per_head"] for row in data["static"]]) # (L, H) metric_fro = np.array([row["metric_fro_per_head"] for row in data["static"]]) mean_ent = ent.mean(axis=1) std_ent = ent.std(axis=1) mean_logit = logit_std.mean(axis=1) std_logit = logit_std.std(axis=1) mean_metric = metric_op.mean(axis=1) std_metric = metric_op.std(axis=1) # Per-head variance summary print("\nPer-head variance across heads (coefficient of variation = std/mean):") print(f" entropy: mean CV = {(std_ent / np.maximum(mean_ent, 1e-6)).mean():.3f}") print(f" logit_std: mean CV = {(std_logit / np.maximum(mean_logit, 1e-6)).mean():.3f}") print(f" metric_op: mean CV = {(std_metric / np.maximum(mean_metric, 1e-6)).mean():.3f}") # Correlations across layers corr_ent_metric = np.corrcoef(mean_ent, mean_metric)[0, 1] corr_logit_metric = np.corrcoef(mean_logit, mean_metric)[0, 1] corr_ent_logit = np.corrcoef(mean_ent, mean_logit)[0, 1] print("\nAcross-layer Pearson correlations (averaged over heads):") print(f" entropy vs metric_op: {corr_ent_metric:+.3f}") print(f" logit_std vs metric_op: {corr_logit_metric:+.3f}") print(f" entropy vs logit_std: {corr_ent_logit:+.3f}") # Per-head correlation (one value per head): does each head's entropy # across layers track its own metric_op across layers? head_corrs = [] for h in range(num_heads): c = np.corrcoef(ent[:, h], metric_op[:, h])[0, 1] if np.isfinite(c): head_corrs.append(c) print(f" per-head entropy vs metric_op: mean {np.mean(head_corrs):+.3f} " f"std {np.std(head_corrs):.3f} min {min(head_corrs):+.3f} max {max(head_corrs):+.3f}") # Plot fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True) ax = axes[0] ax.fill_between(Ls, mean_ent - std_ent, mean_ent + std_ent, alpha=0.2, color="tab:blue", label="±1 std across heads") ax.plot(Ls, mean_ent, color="tab:blue", marker="o", label="mean entropy") ax.set_ylabel("attention entropy (nats)") ax.set_title(f"{data['model']} — SA schedule readout ({num_layers} layers, {num_heads} heads)") ax.legend(loc="upper right") ax.grid(alpha=0.3) ax = axes[1] ax.fill_between(Ls, mean_logit - std_logit, mean_logit + std_logit, alpha=0.2, color="tab:orange", label="±1 std across heads") ax.plot(Ls, mean_logit, color="tab:orange", marker="o", label="mean logit std") ax.set_ylabel("pre-softmax logit std\n(= implicit sharpness)") ax.legend(loc="upper right") ax.grid(alpha=0.3) ax = axes[2] ax.fill_between(Ls, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2, color="tab:green", label="±1 std across heads") ax.plot(Ls, mean_metric, color="tab:green", marker="o", label="mean metric op-norm (static)") ax.set_ylabel("||W_K^T W_Q|| operator norm\n(static, parameter-only)") ax.set_xlabel("layer index L") ax.legend(loc="upper right") ax.grid(alpha=0.3) plt.tight_layout() plt.savefig(args.out_plot, dpi=100, bbox_inches="tight") print(f"\nWrote plot to {args.out_plot}") # Also save a small heatmap of per-head entropy for visual spread plt.figure(figsize=(10, 6)) plt.imshow(ent.T, aspect="auto", cmap="viridis", origin="lower") plt.colorbar(label="attention entropy") plt.xlabel("layer L") plt.ylabel("head h") plt.title(f"{data['model']} — per-head entropy heatmap") heatmap_path = args.out_plot.replace(".png", "-heatmap.png") plt.tight_layout() plt.savefig(heatmap_path, dpi=100, bbox_inches="tight") print(f"Wrote heatmap to {heatmap_path}") if __name__ == "__main__": main()