# consciousness/sa-schedule-readout-measure.py

"""
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).
Measures per-layer "temperature" signals:
- entropy of softmax attention (per head, aggregated)
- magnitude of pre-softmax logits (implicit sharpness)
- spectrum of the parameter metric g_L^h = W_K^h^T W_Q^h (static, no forward pass needed)
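Output:
    A single JSON file (path given by --out) containing
    - "static":  per-layer, per-head spectrum and norms of the parameter metric
    - "dynamic": per-layer, per-head attention statistics averaged over a calibration set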
Goal:
Compare entropy(L) (dynamic readout) against static spectrum of g_L (parameter-only
prediction). Agreement => schedule is parameter-intrinsic and a scalar per-iteration
T suffices. Disagreement => content-adaptive structure lives in the activations.
"""
import argparse
import json
import os
import math
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# Small fixed calibration set used for the dynamic (activation) readout.
# Prompts are grouped by flavour so that attention statistics are not dominated
# by a single kind of text. Each prompt is run through the model on its own.
CALIBRATION_PROMPTS = [
    # --- general knowledge ---
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are",

    # --- math / reasoning ---
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",

    # --- code ---
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",

    # --- narrative / long-form ---
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",

    # --- chat / instruction ---
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",

    # --- factual / lookup ---
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",

    # --- conversational ---
    "I really loved that movie because",
    "The main difference between a virus and a bacterium is",

    # --- translation-ish ---
    "The French word for 'apple' is",

    # --- edge cases (very short prompt, story opener) ---
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
]
def metric_singular_values(wq_h, wk_h):
    """Return the non-zero singular spectrum of the metric ``g = wk_h^T @ wq_h``.

    Args:
        wq_h: query projection rows for one head, shape ``(head_dim, hidden)``.
        wk_h: key projection rows for the matching KV head, shape ``(head_dim, hidden)``.

    Returns:
        1-D tensor of ``min(head_dim, hidden)`` singular values, in descending order.

    ``g`` is ``(hidden, hidden)`` but has rank at most ``head_dim``, so we never build it.
    With reduced QR factorizations ``wk_h^T = Q_k R_k`` and ``wq_h^T = Q_q R_q`` we have
    ``g = Q_k (R_k R_q^T) Q_q^T``; ``Q_k`` and ``Q_q`` have orthonormal columns, so the
    singular values of ``g`` are exactly those of the small matrix ``R_k R_q^T``.

    Note that ``svdvals(wk_h @ wq_h.T)`` is NOT equivalent: e.g. ``wk_h = [[1, 0]]``,
    ``wq_h = [[0, 1]]`` gives ``g = [[0, 1], [0, 0]]`` (singular value 1) while
    ``wk_h @ wq_h.T = [[0]]``.
    """
    # mode="r" skips forming Q, which we do not need.
    _, r_q = torch.linalg.qr(wq_h.T, mode="r")
    _, r_k = torch.linalg.qr(wk_h.T, mode="r")
    return torch.linalg.svdvals(r_k @ r_q.T)


def attention_row_stats(p, eps=1e-12):
    """Summarize attention probabilities of one sequence into per-head scalars.

    Args:
        p: attention probabilities, shape ``(num_heads, q_len, k_len)``; each row is a
            softmax over keys, with masked keys exactly zero.
        eps: added inside the logarithm to avoid ``log(0)``.

    Returns:
        Tuple ``(entropy, logit_std, logit_var)``, each of shape ``(num_heads,)``:
        the mean over query positions of the row entropy, of the per-row standard
        deviation of ``log p`` over unmasked keys, and of the corresponding variance.

    For a softmax row, ``logit_ij = log p_ij + c(i)`` for a row constant ``c(i)``, so the
    spread of ``log p`` over the unmasked keys equals the spread of the pre-softmax
    logits; this is what is used as the logit-magnitude (sharpness) readout.
    """
    logp = (p + eps).log()
    # Rows with few valid keys (early causal positions) have low maximum entropy;
    # they are averaged in with all other positions.
    entropy = -(p * logp).sum(dim=-1)                      # (num_heads, q_len)
    # p == 0 is treated as "masked key".
    valid = (p > 0).to(p.dtype)
    n_valid = valid.sum(dim=-1).clamp_min(1)
    row_mean = (logp * valid).sum(dim=-1) / n_valid
    centered = (logp - row_mean.unsqueeze(-1)) * valid
    row_var = centered.pow(2).sum(dim=-1) / n_valid        # (num_heads, q_len)
    row_std = row_var.clamp_min(0).sqrt()
    return entropy.mean(dim=-1), row_std.mean(dim=-1), row_var.mean(dim=-1)


@torch.no_grad()
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256, dtype=torch.bfloat16):
    """Measure static and dynamic per-layer "temperature" signals and write them as JSON.

    Static readout (parameters only): for every layer and query head, the singular
    spectrum, Frobenius norm and operator norm of ``g = W_K^T W_Q``.
    Dynamic readout (activations): for every layer and head, attention entropy and the
    spread of the (recovered) pre-softmax logits, averaged over ``CALIBRATION_PROMPTS``.

    Args:
        model_name: model identifier passed to ``from_pretrained``.
        out_path: path of the JSON file to write.
        max_seq_len: prompts are truncated to this many tokens.
        dtype: dtype used to load the model weights.

    Side effects:
        Loads the model onto ``"cuda"``, writes ``out_path`` and prints progress and a
        per-layer summary table to stdout.
    """
    print(f"Loading {model_name} ...", flush=True)
    # NOTE(security): trust_remote_code=True lets the model repository run its own
    # Python code; only use with repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        attn_implementation="eager",  # need raw attention probabilities
    )
    model.eval()
    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
          f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}",
          flush=True)

    # ---- Static (parameter-only) readout ----
    # Per layer and query head h, record the singular spectrum of g^h = W_K^h^T W_Q^h.
    # Its norm is the "static temperature" prediction. With grouped-query attention
    # several query heads share one KV head, so the metric is computed per query head
    # using the shared KV head.
    # NOTE(review): only the raw projection weights are used here. If the model applies
    # extra transforms between projection and dot product (e.g. per-head Q/K
    # normalization, rotary embeddings), they are not reflected - confirm that this
    # approximation is intended.
    static_stats = []
    for L, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        W_Q = attn.q_proj.weight.detach().float().cpu()  # (num_heads*head_dim, hidden)
        W_K = attn.k_proj.weight.detach().float().cpu()  # (num_kv_heads*head_dim, hidden)
        per_head_metric_fro = []
        per_head_metric_op = []
        per_head_metric_singvals = []
        for h in range(num_heads):
            kv_h = (h * num_kv_heads) // num_heads
            wq_h = W_Q[h * head_dim:(h + 1) * head_dim]        # (head_dim, hidden)
            wk_h = W_K[kv_h * head_dim:(kv_h + 1) * head_dim]  # (head_dim, hidden)
            s = metric_singular_values(wq_h, wk_h)
            per_head_metric_fro.append(float(s.pow(2).sum().sqrt()))
            per_head_metric_op.append(float(s.max()))
            per_head_metric_singvals.append(s.tolist())
        static_stats.append({
            "layer": L,
            "metric_fro_per_head": per_head_metric_fro,
            "metric_op_per_head": per_head_metric_op,
            "metric_singvals_per_head": per_head_metric_singvals,
        })
        if L % 8 == 0:
            print(f" static layer {L}: mean op-norm over heads = "
                  f"{sum(per_head_metric_op)/len(per_head_metric_op):.3f}",
                  flush=True)

    # ---- Dynamic (activation) readout ----
    # A forward hook on every attention module captures the attention probabilities
    # that the eager implementation returns. Pre-softmax logits are not captured
    # directly; their per-row spread is recovered from log-probabilities
    # (see attention_row_stats).
    acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_mag = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    # Number of prompts that actually contributed to each layer's accumulators.
    acc_n_prompts = torch.zeros(num_layers, dtype=torch.float64)
    captured = {}

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # The attention module is expected to return a tuple whose second item is
            # attn_weights of shape (bsz, num_heads, q_len, k_len), or None.
            if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
                captured[layer_idx] = out[1].detach()
            else:
                captured[layer_idx] = None
        return hook

    hooks = [
        layer.self_attn.register_forward_hook(make_hook(L))
        for L, layer in enumerate(model.model.layers)
    ]
    try:
        for i, prompt in enumerate(CALIBRATION_PROMPTS):
            inp = tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=max_seq_len).to("cuda")
            captured.clear()
            model(**inp, output_attentions=True, use_cache=False)
            seq_len = inp["input_ids"].shape[1]
            for L in range(num_layers):
                aw = captured.get(L)
                if aw is None:
                    continue
                # aw: (1, num_heads, q_len, k_len) -> drop the batch dimension.
                ent, std, var = attention_row_stats(aw.float().squeeze(0))
                acc_entropy[L] += ent.cpu().double()
                acc_logit_mag[L] += std.cpu().double()
                acc_logit_var[L] += var.cpu().double()
                acc_n_prompts[L] += 1
            if i % 5 == 0:
                print(f" prompt {i+1}/{len(CALIBRATION_PROMPTS)} len={seq_len}", flush=True)
    finally:
        # Always detach the hooks, even if a forward pass raised.
        for h in hooks:
            h.remove()

    # Normalize each layer by the number of prompts that really contributed to it, so a
    # layer whose attention weights were not captured is not silently biased toward 0.
    n_unmeasured = int((acc_n_prompts == 0).sum())
    if n_unmeasured and CALIBRATION_PROMPTS:
        print(f" WARNING: no attention weights captured for {n_unmeasured}/{num_layers} "
              f"layers; their dynamic stats are reported as 0", flush=True)
    counts = acc_n_prompts.clamp_min(1).unsqueeze(-1)
    mean_entropy = (acc_entropy / counts).tolist()
    mean_logit_mag = (acc_logit_mag / counts).tolist()
    mean_logit_var = (acc_logit_var / counts).tolist()

    # Assemble output
    dynamic_stats = []
    for L in range(num_layers):
        dynamic_stats.append({
            "layer": L,
            "n_prompts_measured": int(acc_n_prompts[L]),
            "mean_attention_entropy_per_head": mean_entropy[L],
            "mean_logit_std_per_head": mean_logit_mag[L],
            "mean_logit_var_per_head": mean_logit_var[L],
            "mean_attention_entropy": sum(mean_entropy[L]) / num_heads,
            "mean_logit_std": sum(mean_logit_mag[L]) / num_heads,
        })
    output = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "n_prompts": len(CALIBRATION_PROMPTS),
        "static": static_stats,
        "dynamic": dynamic_stats,
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nWrote {out_path}", flush=True)

    # Quick summary to console
    print("\nPer-layer schedule readout (averaged over heads):")
    print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
    for L in range(num_layers):
        mean_op = sum(static_stats[L]["metric_op_per_head"]) / num_heads
        print(f" {L:>3} "
              f"{dynamic_stats[L]['mean_attention_entropy']:>14.4f} "
              f"{dynamic_stats[L]['mean_logit_std']:>16.4f} "
              f"{mean_op:>16.4f}")
def main():
    """Command-line entry point: parse options and run the schedule readout.

    Options:
        --model        model identifier to measure (default "Qwen/Qwen3-8B")
        --out          path of the JSON report (default "/tmp/sa-schedule-readout.json")
        --max-seq-len  token limit applied to each calibration prompt (default 256)
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--model", {"default": "Qwen/Qwen3-8B"}),
        ("--out", {"default": "/tmp/sa-schedule-readout.json"}),
        ("--max-seq-len", {"type": int, "default": 256}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()
    measure_model(opts.model, opts.out, max_seq_len=opts.max_seq_len)


if __name__ == "__main__":
    main()