forked from kent/consciousness
246 lines
10 KiB
Python
246 lines
10 KiB
Python
|
|
"""
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).

Measures per-layer "temperature" signals:
  - entropy of the softmax attention distribution (per head, then aggregated)
  - magnitude of the pre-softmax logits (implicit sharpness), read out as the
    per-row standard deviation of the logits
  - spectrum of the parameter metric g_L^h = (W_K^h)^T W_Q^h (static; no forward
    pass needed)

Output:
  A single JSON file (path set by --out) with a numeric summary per layer / head:
  the static parameter spectrum, plus activation statistics per layer accumulated
  across a calibration set of prompts.

Goal:
  Compare entropy(L) (dynamic readout) against the static spectrum of g_L
  (parameter-only prediction). Agreement => the schedule is parameter-intrinsic and
  a scalar per-iteration T suffices. Disagreement => content-adaptive structure
  lives in the activations.
"""
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import math
|
||
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
|
||
|
|
|
||
|
|
# Calibration set for the dynamic (activation) readout: short, deliberately diverse
# prompts. measure_model runs each prompt once through the model and averages the
# per-layer attention statistics over them.
CALIBRATION_PROMPTS = [
    # general knowledge
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are",
    # math / reasoning
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",
    # code (indentation is significant: the fibonacci prefix must stay valid Python)
    "def fibonacci(n):\n    if n < 2:\n        return n\n    return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",
    # narrative / long-form
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",
    # chat / instruction
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",
    # factual / lookup
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",
    # conversational
    "I really loved that movie because",
    # explanatory
    "The main difference between a virus and a bacterium is",
    # translation-ish
    "The French word for 'apple' is",
    # edge cases
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
]


def _qr_r_factor(w_head: torch.Tensor) -> torch.Tensor:
    """Return R from the reduced QR factorization of ``w_head.T``.

    ``w_head`` is one head's row-slice of a projection weight, shape
    (head_dim, hidden), so ``w_head.T = Q @ R`` with Q (hidden, head_dim) having
    orthonormal columns and R (head_dim, head_dim). Because Q is orthonormal,
    R has the same singular values as ``w_head.T``.
    """
    return torch.linalg.qr(w_head.T, mode="r")[1]


def _static_metric_readout(model, num_heads: int, num_kv_heads: int, head_dim: int) -> list:
    """Per layer, per query head: spectrum of the attention bilinear form (weights only).

    Ignoring position encodings and normalization, the score between hidden states
    x_i and x_j for query head h is x_i^T (W_Q^h)^T W_K^h x_j. The metric on hidden
    space is therefore M = (W_Q^h)^T W_K^h: a hidden x hidden matrix of rank <= head_dim.
    Its singular values are the same as those of its transpose (W_K^h)^T W_Q^h.

    With (W_Q^h)^T = Q_q R_q and (W_K^h)^T = Q_k R_k (reduced QR), M equals
    Q_q (R_q R_k^T) Q_k^T. Both Q factors have orthonormal columns, so the nonzero
    singular values of M are exactly those of the small head_dim x head_dim matrix
    R_q R_k^T.

    Note: svdvals(W_K^h @ (W_Q^h).T) is NOT equivalent. It contracts over the hidden
    dimension rather than the head dimension. For head_dim == 1 it yields the dot
    product a.b instead of |a||b|.

    NOTE(review): this ignores the pre-attention layer norm, RoPE, and any per-head
    q/k normalization the architecture applies (Qwen3 is believed to use q_norm /
    k_norm). Confirm how much that decouples the static metric from the dynamic
    readout.

    Args:
        model: causal LM exposing ``model.model.layers[i].self_attn.{q_proj,k_proj}``
            (as used below).
        num_heads, num_kv_heads, head_dim: attention geometry from the model config.

    Returns:
        One dict per layer with keys "layer", "metric_fro_per_head",
        "metric_op_per_head" and "metric_singvals_per_head".
    """
    static_stats = []
    for L, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        W_Q = attn.q_proj.weight.detach().float().cpu()  # (num_heads*head_dim, hidden)
        W_K = attn.k_proj.weight.detach().float().cpu()  # (num_kv_heads*head_dim, hidden)

        # Each KV head is shared by a group of query heads (GQA), so factor it once.
        r_kv = [_qr_r_factor(W_K[g * head_dim:(g + 1) * head_dim])
                for g in range(num_kv_heads)]

        per_head_metric_fro = []
        per_head_metric_op = []
        per_head_metric_singvals = []
        for h in range(num_heads):
            kv_h = (h * num_kv_heads) // num_heads
            r_q = _qr_r_factor(W_Q[h * head_dim:(h + 1) * head_dim])
            s = torch.linalg.svdvals(r_q @ r_kv[kv_h].T)  # (head_dim,)
            per_head_metric_fro.append(float(s.pow(2).sum().sqrt()))
            per_head_metric_op.append(float(s.max()))
            per_head_metric_singvals.append(s.tolist())

        static_stats.append({
            "layer": L,
            "metric_fro_per_head": per_head_metric_fro,
            "metric_op_per_head": per_head_metric_op,
            "metric_singvals_per_head": per_head_metric_singvals,
        })
        if L % 8 == 0:
            print(f" static layer {L}: mean op-norm over heads = "
                  f"{sum(per_head_metric_op)/len(per_head_metric_op):.3f}",
                  flush=True)
    return static_stats


def _make_capture_hook(store: dict, layer_idx: int):
    """Build a forward hook that records a layer's attention weights in ``store[layer_idx]``.

    The hook stores ``out[1]`` (detached) when the module output is a tuple with a
    non-None second element, and None otherwise.

    Assumes that with ``output_attentions=True`` and the eager implementation,
    ``out[1]`` is the attention-probability tensor (bsz, num_heads, q_len, k_len).
    TODO: confirm this for the installed transformers version.
    """
    def hook(module, inp, out):
        if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
            store[layer_idx] = out[1].detach()
        else:
            store[layer_idx] = None
    return hook


def _attention_row_stats(p: torch.Tensor):
    """Per-row entropy and pre-softmax logit spread of attention probabilities.

    Args:
        p: float attention probabilities, shape (num_heads, q_len, k_len), with
            softmax taken over the last dim. Exact zeros are treated as masked keys.

    Returns:
        (entropy, logit_std, logit_var), each of shape (num_heads, q_len).

    Within a row, logit_j = log p_j + c for some row constant c. The spread of
    log p over the valid keys therefore equals the spread of the logits.

    Logs are taken only on valid entries. Adding an epsilon (log(p + 1e-12)) would
    floor log p at about -27.6 and understate the spread of exactly the sharpest
    rows.
    """
    ent = -torch.xlogy(p, p).sum(dim=-1)  # xlogy(0, 0) == 0: masked keys add nothing
    valid = p > 0
    n_valid = valid.sum(dim=-1).clamp_min(1)
    logp = torch.where(valid, p, torch.ones_like(p)).log()  # log(1) == 0 on masked keys
    mean_logp = logp.sum(dim=-1) / n_valid
    centered = torch.where(valid, logp - mean_logp.unsqueeze(-1), torch.zeros_like(p))
    var_logp = centered.pow(2).sum(dim=-1) / n_valid
    return ent, var_logp.clamp_min(0).sqrt(), var_logp


@torch.no_grad()
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256,
                  dtype=torch.bfloat16, device: str = "cuda"):
    """Run the static and dynamic attention-temperature readouts and write them as JSON.

    Args:
        model_name: name passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        out_path: JSON file to write. Its parent directory is created if missing.
        max_seq_len: tokenizer truncation length for each calibration prompt.
        dtype: dtype the model weights are loaded in.
        device: passed as ``device_map`` when loading. Inputs are moved to
            ``model.device``.

    The JSON holds model/config metadata plus two sections:
      - ``static``: per layer, per query head, the Frobenius norm, operator norm and
        singular values of the attention metric.
      - ``dynamic``: per layer, per head, the mean attention entropy and mean
        per-row logit std / variance over CALIBRATION_PROMPTS.

    A per-layer summary table is also printed to stdout.
    """
    print(f"Loading {model_name} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # need raw attention probabilities
    )
    model.eval()

    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
          f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}",
          flush=True)

    # ---- Static (parameter-only) readout ----
    static_stats = _static_metric_readout(model, num_heads, num_kv_heads, head_dim)

    # ---- Dynamic (activation) readout ----
    # Per layer and head, accumulate the per-prompt mean (over query positions) of
    # attention entropy, logit std and logit variance.
    acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_mag = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    # Number of prompts that actually produced attention weights for each layer.
    # This is the per-layer averaging denominator, so a skipped layer is not
    # averaged as if every prompt had contributed to it.
    acc_n_prompts = torch.zeros(num_layers, dtype=torch.float64)

    captured = {}
    hooks = [layer.self_attn.register_forward_hook(_make_capture_hook(captured, L))
             for L, layer in enumerate(model.model.layers)]
    try:
        for i, prompt in enumerate(CALIBRATION_PROMPTS):
            inp = tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=max_seq_len).to(model.device)
            captured.clear()
            model(**inp, output_attentions=True, use_cache=False)
            seq_len = inp["input_ids"].shape[1]

            for L in range(num_layers):
                aw = captured.get(L)
                if aw is None:
                    continue
                # aw: (1, num_heads, q_len, k_len). Under a causal mask, early rows
                # have few valid keys (row 0 has one, so its entropy and logit std
                # are 0). All query positions are averaged anyway.
                p = aw.float().squeeze(0)  # (num_heads, q_len, k_len)
                ent, row_std, row_var = _attention_row_stats(p)
                acc_entropy[L] += ent.mean(dim=-1).cpu().double()
                acc_logit_mag[L] += row_std.mean(dim=-1).cpu().double()
                acc_logit_var[L] += row_var.mean(dim=-1).cpu().double()
                acc_n_prompts[L] += 1

            if i % 5 == 0:
                print(f" prompt {i+1}/{len(CALIBRATION_PROMPTS)} len={seq_len}", flush=True)
    finally:
        # Detach the hooks even if a forward pass raised.
        for handle in hooks:
            handle.remove()

    missing = [L for L in range(num_layers) if acc_n_prompts[L] == 0]
    if missing:
        print(f" WARNING: no attention weights captured for layers {missing}; "
              "their dynamic stats are reported as 0.", flush=True)

    # Average over the prompts that contributed to each layer.
    denom = acc_n_prompts.clamp_min(1).unsqueeze(-1)  # (num_layers, 1)
    mean_entropy = (acc_entropy / denom).tolist()
    mean_logit_mag = (acc_logit_mag / denom).tolist()
    mean_logit_var = (acc_logit_var / denom).tolist()

    dynamic_stats = [
        {
            "layer": L,
            "mean_attention_entropy_per_head": mean_entropy[L],
            "mean_logit_std_per_head": mean_logit_mag[L],
            "mean_logit_var_per_head": mean_logit_var[L],
            "mean_attention_entropy": sum(mean_entropy[L]) / num_heads,
            "mean_logit_std": sum(mean_logit_mag[L]) / num_heads,
        }
        for L in range(num_layers)
    ]

    output = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "n_prompts": len(CALIBRATION_PROMPTS),
        "static": static_stats,
        "dynamic": dynamic_stats,
    }

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nWrote {out_path}", flush=True)

    # Quick summary to console
    print("\nPer-layer schedule readout (averaged over heads):")
    print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
    for L in range(num_layers):
        mean_op = sum(static_stats[L]["metric_op_per_head"]) / num_heads
        print(f" {L:>3} "
              f"{dynamic_stats[L]['mean_attention_entropy']:>14.4f} "
              f"{dynamic_stats[L]['mean_logit_std']:>16.4f} "
              f"{mean_op:>16.4f}")


def main():
    """Command-line entry point: read the options and hand them to ``measure_model``.

    Options:
      --model        model name forwarded to measure_model (default "Qwen/Qwen3-8B")
      --out          path of the JSON results file
      --max-seq-len  tokenizer truncation length per calibration prompt (int)
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="Qwen/Qwen3-8B")
    cli.add_argument("--out", default="/tmp/sa-schedule-readout.json")
    cli.add_argument("--max-seq-len", type=int, default=256)
    opts = cli.parse_args()

    model_id, json_path = opts.model, opts.out
    measure_model(model_id, json_path, max_seq_len=opts.max_seq_len)


if __name__ == "__main__":
    main()