consciousness/sa-schedule-readout-measure.py
Kent Overstreet 4225294d16 replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
2026-04-25 15:35:14 -04:00

246 lines
10 KiB
Python

"""
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).
Measures per-layer "temperature" signals:
- entropy of softmax attention (per head, aggregated)
- spread of pre-softmax logits (per-query std over unmasked keys, recovered from log-probabilities; implicit sharpness)
- spectrum of the parameter metric g_L^h = W_K^h^T W_Q^h (static, no forward pass needed)
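Output:
A single JSON file (path given by --out) with a "static" section holding per-layer,
per-head metric spectra and a "dynamic" section holding per-layer, per-head attention
statistics averaged over a calibration set of prompts.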
Goal:
Compare entropy(L) (dynamic readout) against static spectrum of g_L (parameter-only
prediction). Agreement => schedule is parameter-intrinsic and a scalar per-iteration
T suffices. Disagreement => content-adaptive structure lives in the activations.
"""
import argparse
import json
import os
import math
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# Calibration prompts, grouped by the kind of text each one exercises. Dict
# insertion order is preserved, so the flattened CALIBRATION_PROMPTS list below
# visits the groups (and the prompts within each group) in exactly this order.
_CALIBRATION_PROMPTS_BY_KIND = {
    "general knowledge": (
        "The Eiffel Tower is located in",
        "Photosynthesis is the process by which",
        "The three branches of the US government are",
    ),
    "math / reasoning": (
        "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
        "Solve for x: 3x + 7 = 22. The answer is x =",
        "The derivative of x^3 + 2x^2 is",
    ),
    "code": (
        "def fibonacci(n):\n if n < 2:\n return n\n return",
        "# Python list comprehension to square even numbers in 0-9\nresult = ",
        "SELECT name, age FROM users WHERE",
    ),
    "narrative / long-form": (
        "She opened the old wooden box and found",
        "The argument in favor of renewable energy is",
    ),
    "chat / instruction": (
        "User: What is the capital of Australia?\nAssistant:",
        "Write a haiku about autumn:\n",
    ),
    "factual / lookup": (
        "Albert Einstein was born in the year",
        "The speed of light in vacuum is approximately",
    ),
    "conversational": (
        "I really loved that movie because",
        "The main difference between a virus and a bacterium is",
    ),
    "translation-ish": (
        "The French word for 'apple' is",
    ),
    "edge cases": (
        "1 + 1 = ",
        "Once upon a time, in a land far away,",
    ),
}

# Flat list of prompts fed one at a time through the model by the dynamic readout.
CALIBRATION_PROMPTS = [
    prompt for group in _CALIBRATION_PROMPTS_BY_KIND.values() for prompt in group
]
def _head_r_factors(weight: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Return the stacked R factors of each head's block of a projection weight.

    Rows ``j*head_dim:(j+1)*head_dim`` of ``weight`` form head ``j``'s projection
    W_j (head_dim x hidden). Each W_j^T is factored by reduced QR as
    W_j^T = Q_j @ R_j (Q_j with orthonormal columns). The result has shape
    (n_heads, min(hidden, head_dim), head_dim).
    """
    blocks = weight[: n_heads * head_dim].reshape(n_heads, head_dim, -1)
    _, r = torch.linalg.qr(blocks.transpose(-2, -1), mode="reduced")
    return r


def _qk_metric_spectra(w_q: torch.Tensor, w_k: torch.Tensor, num_heads: int,
                       num_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Return the nonzero singular values of g^h = W_K^h^T W_Q^h for every query head.

    g^h is hidden x hidden with rank <= head_dim. Writing W^T = Q R (reduced QR),
    W_K^T W_Q = Q_K (R_K R_Q^T) Q_Q^T, and since Q_K, Q_Q have orthonormal
    columns the nonzero singular values of g^h equal those of the small matrix
    R_K R_Q^T. Note that W_K W_Q^T, though also head_dim x head_dim, generally
    has *different* singular values and is not a valid shortcut.

    Returns a tensor of shape (num_heads, min(hidden, head_dim)); each row is
    in descending order (torch.linalg.svdvals convention).
    """
    r_q = _head_r_factors(w_q, num_heads, head_dim)
    r_k = _head_r_factors(w_k, num_kv_heads, head_dim)  # one factor per KV head, shared by its group
    # Grouped-query attention: query head h reads KV head (h * num_kv_heads) // num_heads.
    kv_index = torch.tensor([(h * num_kv_heads) // num_heads for h in range(num_heads)])
    return torch.linalg.svdvals(r_k[kv_index] @ r_q.transpose(-2, -1))


def _attention_row_stats(attn_weights: torch.Tensor, eps: float = 1e-12):
    """Summarize one layer's attention probabilities, averaged over query positions.

    Args:
        attn_weights: attention probabilities whose leading batch dim has size 1,
            i.e. (1, heads, q_len, k_len) -- assumed; the code squeezes dim 0 and
            reduces over the last dim. Keys with probability exactly 0 are treated
            as masked (so a valid key whose probability underflowed to 0 is also
            dropped).
        eps: added inside the log so that p == 0 contributes 0 to the entropy.

    Returns:
        ``(entropy, logit_std, logit_var)``, each a float32 tensor of shape
        (heads,): per-row entropy, per-row std and per-row variance of the
        logits over valid keys, each averaged over query rows.
    """
    p = attn_weights.float().squeeze(0)  # (heads, q_len, k_len)
    logp = (p + eps).log()
    entropy = -(p * logp).sum(dim=-1)  # (heads, q_len)
    # Softmax logits equal log p up to a per-row constant, so the per-row spread
    # of log p over valid keys is exactly the spread of the logits.
    valid = (p > 0).float()
    denom = valid.sum(dim=-1).clamp_min(1)
    mean_logp = (logp * valid).sum(dim=-1) / denom
    centered = (logp - mean_logp.unsqueeze(-1)) * valid
    var_logp = centered.pow(2).sum(dim=-1) / denom  # (heads, q_len)
    std_logp = var_logp.clamp_min(0).sqrt()
    return entropy.mean(dim=-1), std_logp.mean(dim=-1), var_logp.mean(dim=-1)


@torch.no_grad()
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256, dtype=torch.bfloat16,
                  device: str = "cuda"):
    """Measure static and dynamic per-layer attention "temperature" signals and write JSON.

    Static readout: per layer/head, the singular spectrum of g^h = W_K^h^T W_Q^h
    (weights only). Dynamic readout: per layer/head attention entropy and
    logit std/variance, averaged over CALIBRATION_PROMPTS. Both are written to
    ``out_path`` and a per-layer summary is printed.

    Args:
        model_name: model id or path passed to ``from_pretrained``.
        out_path: destination path of the JSON report.
        max_seq_len: tokenizer truncation length for each calibration prompt.
        dtype: dtype the model weights are loaded in.
        device: used as ``device_map`` for the model and as the target of the
            tokenized inputs.

    Raises:
        RuntimeError: if some layer never produced attention weights (e.g. the
            attention implementation did not return probabilities), since its
            dynamic averages would be meaningless.
    """
    print(f"Loading {model_name} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # need raw attention probabilities
    )
    model.eval()
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    # Attributes may exist but be None in some configs; fall back in that case too.
    head_dim = getattr(cfg, "head_dim", None) or hidden // num_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
          f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}",
          flush=True)

    # ---- Static (parameter-only) readout ----
    # Per layer and query head h, the singular spectrum of g^h = W_K^h^T W_Q^h
    # (query heads share their group's KV head under GQA). Its norms are the
    # "static temperature" prediction.
    # NOTE(review): q/k biases, positional rotation, and any per-head q/k
    # normalization applied after projection are not part of this metric. If the
    # attention module has q_norm/k_norm submodules (Qwen3 is believed to --
    # confirm), logit scale is set largely by those norms rather than by
    # W_Q/W_K, so "has_qk_norm" is recorded to flag that caveat in the output.
    static_stats = []
    for L, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        W_Q = attn.q_proj.weight.detach().float().cpu()  # (num_heads*head_dim, hidden)
        W_K = attn.k_proj.weight.detach().float().cpu()  # (num_kv_heads*head_dim, hidden)
        s = _qk_metric_spectra(W_Q, W_K, num_heads, num_kv_heads, head_dim)
        metric_op = s.amax(dim=-1).tolist()
        static_stats.append({
            "layer": L,
            "metric_fro_per_head": torch.linalg.vector_norm(s, dim=-1).tolist(),
            "metric_op_per_head": metric_op,
            "metric_singvals_per_head": s.tolist(),
            "has_qk_norm": hasattr(attn, "q_norm") or hasattr(attn, "k_norm"),
        })
        if L % 8 == 0:
            print(f" static layer {L}: mean op-norm over heads = "
                  f"{sum(metric_op)/len(metric_op):.3f}",
                  flush=True)

    # ---- Dynamic (activation) readout ----
    # A forward hook on each attention module captures the attention
    # probabilities that eager attention returns when output_attentions=True.
    # Pre-softmax logits are not exposed, so their per-row spread is backed out
    # of log p (logits are log p plus a per-row constant). Statistics are summed
    # per layer/head and divided by the number of prompts each layer actually saw.
    acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_mag = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    n_samples = [0] * num_layers  # prompts that yielded attention weights, per layer
    captured = {}

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # Attention weights (bsz, num_heads, q_len, k_len) are expected as the
            # second tuple element; the number of trailing elements varies across
            # library versions. Anything else means "not available for this prompt".
            if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
                captured[layer_idx] = out[1].detach()
            else:
                captured[layer_idx] = None
        return hook

    hooks = [layer.self_attn.register_forward_hook(make_hook(L))
             for L, layer in enumerate(model.model.layers)]
    try:
        for i, prompt in enumerate(CALIBRATION_PROMPTS):
            batch = tokenizer(prompt, return_tensors="pt", truncation=True,
                              max_length=max_seq_len).to(device)
            captured.clear()
            model(**batch, output_attentions=True, use_cache=False)
            seq_len = batch["input_ids"].shape[1]
            for L in range(num_layers):
                aw = captured.get(L)
                if aw is None:
                    continue
                entropy, logit_std, logit_var = _attention_row_stats(aw)
                acc_entropy[L] += entropy.cpu().double()
                acc_logit_mag[L] += logit_std.cpu().double()
                acc_logit_var[L] += logit_var.cpu().double()
                n_samples[L] += 1
            if i % 5 == 0:
                print(f" prompt {i+1}/{len(CALIBRATION_PROMPTS)} len={seq_len}", flush=True)
    finally:
        # Detach hooks and drop captured device tensors even if a forward pass fails.
        for handle in hooks:
            handle.remove()
        captured.clear()

    missing = [L for L, count in enumerate(n_samples) if count == 0]
    if missing:
        raise RuntimeError(
            f"no attention weights captured for layers {missing}; the attention "
            "implementation must return probabilities (attn_implementation='eager')"
        )

    # Normalize by the number of prompts each layer actually contributed.
    counts = torch.tensor(n_samples, dtype=torch.float64).unsqueeze(-1)
    mean_entropy = (acc_entropy / counts).tolist()
    mean_logit_mag = (acc_logit_mag / counts).tolist()
    mean_logit_var = (acc_logit_var / counts).tolist()

    # Assemble output
    dynamic_stats = []
    for L in range(num_layers):
        dynamic_stats.append({
            "layer": L,
            "mean_attention_entropy_per_head": mean_entropy[L],
            "mean_logit_std_per_head": mean_logit_mag[L],
            "mean_logit_var_per_head": mean_logit_var[L],
            "mean_attention_entropy": sum(mean_entropy[L]) / num_heads,
            "mean_logit_std": sum(mean_logit_mag[L]) / num_heads,
            "n_samples": n_samples[L],
        })
    output = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "n_prompts": len(CALIBRATION_PROMPTS),
        "static": static_stats,
        "dynamic": dynamic_stats,
    }
    with open(out_path, "w") as f:
        json.dump(output, f, indent=2)
    print(f"\nWrote {out_path}", flush=True)

    # Quick summary to console
    print("\nPer-layer schedule readout (averaged over heads):")
    print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
    for L in range(num_layers):
        mean_op = sum(static_stats[L]["metric_op_per_head"]) / num_heads
        print(f" {L:>3} "
              f"{dynamic_stats[L]['mean_attention_entropy']:>14.4f} "
              f"{dynamic_stats[L]['mean_logit_std']:>16.4f} "
              f"{mean_op:>16.4f}")
def main():
    """Command-line entry point: read flags, then run the schedule readout."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="Qwen/Qwen3-8B")
    cli.add_argument("--out", default="/tmp/sa-schedule-readout.json")
    cli.add_argument("--max-seq-len", type=int, default=256)
    opts = vars(cli.parse_args())
    measure_model(opts["model"], opts["out"], max_seq_len=opts["max_seq_len"])


if __name__ == "__main__":
    # Only run the (model-loading) readout when executed as a script.
    main()