consciousness/sa-schedule-analyze.py
Kent Overstreet 4225294d16 replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
2026-04-25 15:35:14 -04:00

108 lines
4.5 KiB
Python

"""Analyze the SA schedule readout JSON: per-head variance, static/dynamic
correlation, and a plot."""
import argparse
import json
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def _stack_per_head(rows, key):
    """Stack ``row[key]`` for every row into one array (one row per layer)."""
    return np.array([row[key] for row in rows])


def _layer_stats(per_head):
    """Return ``(mean, std)`` taken across axis 1 (the head axis) of *per_head*."""
    return per_head.mean(axis=1), per_head.std(axis=1)


def _mean_cv(mean, std):
    """Return the coefficient of variation (std / mean) averaged over layers.

    The denominator is floored at 1e-6 so a zero mean cannot divide by zero.
    """
    return (std / np.maximum(mean, 1e-6)).mean()


def _pearson(a, b):
    """Return the Pearson correlation coefficient between two 1-D series."""
    return np.corrcoef(a, b)[0, 1]


def _heatmap_path_for(plot_path):
    """Derive the heatmap output path by inserting ``-heatmap`` before the extension.

    Splitting on the real extension (rather than replacing the substring
    ".png") guarantees the result differs from *plot_path*, so the heatmap can
    never overwrite the main plot, and directory names are left alone.
    """
    import os.path  # local import: only this helper needs it

    root, ext = os.path.splitext(plot_path)
    return f"{root}-heatmap{ext}"


def _draw_band_panel(ax, xs, mean, std, color, mean_label, ylabel):
    """Draw one panel: the per-layer mean with a ±1 std band across heads."""
    ax.fill_between(xs, mean - std, mean + std, alpha=0.2, color=color,
                    label="±1 std across heads")
    ax.plot(xs, mean, color=color, marker="o", label=mean_label)
    ax.set_ylabel(ylabel)
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3)


def main(argv=None):
    """Analyze an SA schedule readout JSON and write two plots.

    Reads ``num_layers``, ``num_heads``, ``model``, and the per-head lists
    under ``dynamic`` (``mean_attention_entropy_per_head``,
    ``mean_logit_std_per_head``) and ``static`` (``metric_op_per_head``).
    Prints coefficient-of-variation and Pearson-correlation summaries to
    stdout, then saves a three-panel per-layer plot to ``--out-plot`` and a
    per-head entropy heatmap next to it (``<stem>-heatmap<ext>``).

    Args:
        argv: Optional list of command-line arguments; ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input_json", help="path to the SA schedule readout JSON")
    ap.add_argument("--out-plot", default="/tmp/sa-schedule.png",
                    help="where to write the per-layer plot (default: %(default)s)")
    args = ap.parse_args(argv)

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(args.input_json, encoding="utf-8") as f:
        data = json.load(f)

    num_layers = data["num_layers"]
    num_heads = data["num_heads"]
    layer_idx = np.arange(num_layers)

    # Each array is (L, H): one row per layer, one column per head.
    ent = _stack_per_head(data["dynamic"], "mean_attention_entropy_per_head")
    logit_std = _stack_per_head(data["dynamic"], "mean_logit_std_per_head")
    metric_op = _stack_per_head(data["static"], "metric_op_per_head")

    mean_ent, std_ent = _layer_stats(ent)
    mean_logit, std_logit = _layer_stats(logit_std)
    mean_metric, std_metric = _layer_stats(metric_op)

    # Per-head variance summary
    print("\nPer-head variance across heads (coefficient of variation = std/mean):")
    print(f" entropy: mean CV = {_mean_cv(mean_ent, std_ent):.3f}")
    print(f" logit_std: mean CV = {_mean_cv(mean_logit, std_logit):.3f}")
    print(f" metric_op: mean CV = {_mean_cv(mean_metric, std_metric):.3f}")

    # Correlations across layers
    corr_ent_metric = _pearson(mean_ent, mean_metric)
    corr_logit_metric = _pearson(mean_logit, mean_metric)
    corr_ent_logit = _pearson(mean_ent, mean_logit)
    print("\nAcross-layer Pearson correlations (averaged over heads):")
    print(f" entropy vs metric_op: {corr_ent_metric:+.3f}")
    print(f" logit_std vs metric_op: {corr_logit_metric:+.3f}")
    print(f" entropy vs logit_std: {corr_ent_logit:+.3f}")

    # Per-head correlation (one value per head): does each head's entropy
    # across layers track its own metric_op across layers?  Non-finite values
    # (e.g. a head whose series is constant) are dropped before summarizing.
    per_head = np.array(
        [_pearson(ent[:, h], metric_op[:, h]) for h in range(num_heads)],
        dtype=float,
    )
    head_corrs = per_head[np.isfinite(per_head)]
    if head_corrs.size:
        print(f" per-head entropy vs metric_op: mean {np.mean(head_corrs):+.3f} "
              f"std {np.std(head_corrs):.3f} min {head_corrs.min():+.3f} max {head_corrs.max():+.3f}")
    else:
        # min()/max() of an empty sequence would raise; report it instead.
        print(" per-head entropy vs metric_op: n/a (no head has a finite correlation)")

    # Plot: three stacked panels sharing the layer axis.
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    _draw_band_panel(axes[0], layer_idx, mean_ent, std_ent, "tab:blue",
                     "mean entropy", "attention entropy (nats)")
    axes[0].set_title(f"{data['model']} — SA schedule readout ({num_layers} layers, {num_heads} heads)")
    _draw_band_panel(axes[1], layer_idx, mean_logit, std_logit, "tab:orange",
                     "mean logit std", "pre-softmax logit std\n(= implicit sharpness)")
    _draw_band_panel(axes[2], layer_idx, mean_metric, std_metric, "tab:green",
                     "mean metric op-norm (static)",
                     "||W_K^T W_Q|| operator norm\n(static, parameter-only)")
    axes[2].set_xlabel("layer index L")
    fig.tight_layout()
    fig.savefig(args.out_plot, dpi=100, bbox_inches="tight")
    plt.close(fig)  # release the figure once it is on disk
    print(f"\nWrote plot to {args.out_plot}")

    # Also save a small heatmap of per-head entropy for visual spread
    # (transposed so layers run along x and heads along y).
    heat_fig, heat_ax = plt.subplots(figsize=(10, 6))
    image = heat_ax.imshow(ent.T, aspect="auto", cmap="viridis", origin="lower")
    heat_fig.colorbar(image, ax=heat_ax, label="attention entropy")
    heat_ax.set_xlabel("layer L")
    heat_ax.set_ylabel("head h")
    heat_ax.set_title(f"{data['model']} — per-head entropy heatmap")
    heatmap_path = _heatmap_path_for(args.out_plot)
    heat_fig.tight_layout()
    heat_fig.savefig(heatmap_path, dpi=100, bbox_inches="tight")
    plt.close(heat_fig)
    print(f"Wrote heatmap to {heatmap_path}")
# Script entry point: run the analysis only when executed directly, so that
# importing this module has no side effects.
if __name__ == "__main__":
    main()