forked from kent/consciousness
replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using block_in_place + futures::executor::block_on, safe for sync contexts. Replace all try_lock() calls with lock_blocking() in slash commands, UI rendering, and status reads. Lock hold times are fast enough that blocking briefly is fine, and this eliminates the spurious 'lock unavailable' paths that were never actually hit. Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
This commit is contained in:
parent
5210f7dd66
commit
4225294d16
28 changed files with 4199 additions and 67 deletions
246
sa-schedule-readout-measure.py
Normal file
246
sa-schedule-readout-measure.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).
|
||||
|
||||
Measures per-layer "temperature" signals:
|
||||
- entropy of softmax attention (per head, aggregated)
|
||||
- magnitude of pre-softmax logits (implicit sharpness)
|
||||
- spectrum of the parameter metric g_L^h = W_K^h^T W_Q^h (static, no forward pass needed)
|
||||
|
||||
Output:
|
||||
JSON file at the --out path (default /tmp/sa-schedule-readout.json) — numeric summary per layer / head
|
||||
activations stats by layer accumulated across a calibration set
|
||||
|
||||
Goal:
|
||||
Compare entropy(L) (dynamic readout) against static spectrum of g_L (parameter-only
|
||||
prediction). Agreement => schedule is parameter-intrinsic and a scalar per-iteration
|
||||
T suffices. Disagreement => content-adaptive structure lives in the activations.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# Calibration prompts, grouped by the kind of text they exercise. The labels
# are informational only; the flattened list below preserves this order.
_PROMPTS_BY_CATEGORY = [
    ("general knowledge", [
        "The Eiffel Tower is located in",
        "Photosynthesis is the process by which",
        "The three branches of the US government are",
    ]),
    ("math / reasoning", [
        "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
        "Solve for x: 3x + 7 = 22. The answer is x =",
        "The derivative of x^3 + 2x^2 is",
    ]),
    ("code", [
        "def fibonacci(n):\n if n < 2:\n return n\n return",
        "# Python list comprehension to square even numbers in 0-9\nresult = ",
        "SELECT name, age FROM users WHERE",
    ]),
    ("narrative / long-form", [
        "She opened the old wooden box and found",
        "The argument in favor of renewable energy is",
    ]),
    ("chat / instruction", [
        "User: What is the capital of Australia?\nAssistant:",
        "Write a haiku about autumn:\n",
    ]),
    ("factual / lookup", [
        "Albert Einstein was born in the year",
        "The speed of light in vacuum is approximately",
    ]),
    ("conversational", [
        "I really loved that movie because",
        "The main difference between a virus and a bacterium is",
    ]),
    ("translation-ish", [
        "The French word for 'apple' is",
    ]),
    ("edge cases", [
        "1 + 1 = ",
        "Once upon a time, in a land far away,",
    ]),
]

# Flat list of prompts consumed by the measurement loop.
CALIBRATION_PROMPTS = [
    text for _category, texts in _PROMPTS_BY_CATEGORY for text in texts
]
|
||||
|
||||
|
||||
def _thin_r(w_head):
    """Return the R factor of the thin QR factorization of ``w_head.T``.

    ``w_head`` is one head's (head_dim, hidden) slice of a projection weight.
    Only R is needed, so ``mode="r"`` skips building Q.
    """
    return torch.linalg.qr(w_head.T, mode="r").R


def _static_readout(model, num_heads, num_kv_heads, head_dim):
    """Compute the per-layer, per-head spectrum of the metric W_K^T W_Q.

    The metric is (hidden, hidden) with rank at most head_dim. With thin QR
    factorizations ``wk_h.T = Q_k R_k`` and ``wq_h.T = Q_q R_q`` we have
    ``W_K^T W_Q = Q_k (R_k R_q^T) Q_q^T``; Q_k and Q_q have orthonormal
    columns, so the non-zero singular values of the metric are exactly the
    singular values of the small matrix ``R_k R_q^T``.

    This is NOT the same as ``svdvals(wk_h @ wq_h.T)``: that product contracts
    over the hidden axis and has a different spectrum in general.

    NOTE(review): architectures that normalize or rotate q/k after projection
    (per-head q/k norms, rotary embeddings) change the effective metric; this
    readout covers the raw projection weights only.

    Returns a list with one dict per layer holding Frobenius norm, operator
    norm and the full singular-value list for every query head.
    """
    static_stats = []
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        w_q = attn.q_proj.weight.detach().float().cpu()  # (num_heads*head_dim, hidden)
        w_k = attn.k_proj.weight.detach().float().cpu()  # (num_kv_heads*head_dim, hidden)

        # With grouped-query attention several query heads share one KV head,
        # so factor each KV head once and reuse it.
        r_k_per_kv_head = [
            _thin_r(w_k[kv * head_dim:(kv + 1) * head_dim])
            for kv in range(num_kv_heads)
        ]

        fro_per_head = []
        op_per_head = []
        singvals_per_head = []
        for h in range(num_heads):
            kv_h = (h * num_kv_heads) // num_heads
            r_q = _thin_r(w_q[h * head_dim:(h + 1) * head_dim])
            s = torch.linalg.svdvals(r_k_per_kv_head[kv_h] @ r_q.T)
            fro_per_head.append(float(s.pow(2).sum().sqrt()))
            op_per_head.append(float(s.max()))
            singvals_per_head.append(s.tolist())

        static_stats.append({
            "layer": layer_idx,
            "metric_fro_per_head": fro_per_head,
            "metric_op_per_head": op_per_head,
            "metric_singvals_per_head": singvals_per_head,
        })
        if layer_idx % 8 == 0:
            print(f" static layer {layer_idx}: mean op-norm over heads = "
                  f"{sum(op_per_head)/len(op_per_head):.3f}",
                  flush=True)
    return static_stats


def _attention_row_stats(p, eps=1e-12):
    """Summarize one prompt's attention probabilities for one layer.

    ``p`` has shape (num_heads, q_len, k_len) and holds softmax probabilities.
    Returns three (num_heads,) tensors, each averaged over query positions:
    attention entropy, per-row std of log-probabilities, and per-row variance
    of log-probabilities.

    Logits equal log p plus a per-row constant, so the per-row spread of log p
    over unmasked keys is the spread of the pre-softmax logits.
    """
    logp = (p + eps).log()
    # Rows with fewer valid keys have a lower maximum entropy; they are
    # averaged in with the rest.
    entropy = -(p * logp).sum(dim=-1)

    # Keys with p == 0 are treated as masked. NOTE(review): an unmasked key
    # whose probability underflowed to exactly zero is excluded as well.
    valid = (p > 0).float()
    n_valid = valid.sum(dim=-1).clamp_min(1)
    mean_logp = (logp * valid).sum(dim=-1) / n_valid
    centered = (logp - mean_logp.unsqueeze(-1)) * valid
    var_logp = centered.pow(2).sum(dim=-1) / n_valid
    std_logp = var_logp.clamp_min(0).sqrt()
    return entropy.mean(dim=-1), std_logp.mean(dim=-1), var_logp.mean(dim=-1)


@torch.no_grad()
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256,
                  dtype=torch.bfloat16, device: str = "cuda"):
    """Measure static and dynamic attention "temperature" signals and save them.

    Args:
        model_name: Hugging Face model id or local path to load.
        out_path: Destination of the JSON report.
        max_seq_len: Prompts are truncated to this many tokens.
        dtype: Torch dtype used to load the model weights.
        device: Value passed as ``device_map`` when loading the model.

    The report contains model dimensions, a ``static`` list (spectrum of the
    parameter metric W_K^T W_Q per layer and head) and a ``dynamic`` list
    (attention entropy and log-probability spread per layer and head, averaged
    over ``CALIBRATION_PROMPTS``). A per-layer summary is also printed.
    """
    print(f"Loading {model_name} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # need raw attention probabilities
    )
    model.eval()

    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
          f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}",
          flush=True)

    # ---- Static (parameter-only) readout ----
    static_stats = _static_readout(model, num_heads, num_kv_heads, head_dim)

    # ---- Dynamic (activation) readout ----
    # Run each prompt with output_attentions=True and grab the attention
    # probabilities from a forward hook on every self-attention module.
    # Pre-softmax logits are not captured; their spread is recovered from
    # log-probabilities (see _attention_row_stats).
    acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_mag = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    # Number of prompts that actually contributed to each layer.
    acc_n_samples = torch.zeros(num_layers, dtype=torch.float64)

    captured = {}

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # When the module returns a tuple, its second element is taken as
            # the attention weights, shape (bsz, num_heads, q_len, k_len).
            if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
                captured[layer_idx] = out[1].detach()
            else:
                captured[layer_idx] = None
        return hook

    hooks = []
    try:
        for layer_idx, layer in enumerate(model.model.layers):
            hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx)))

        for i, prompt in enumerate(CALIBRATION_PROMPTS):
            inp = tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=max_seq_len).to(model.device)
            captured.clear()
            model(**inp, output_attentions=True, use_cache=False)
            seq_len = inp["input_ids"].shape[1]

            for layer_idx in range(num_layers):
                aw = captured.get(layer_idx)
                if aw is None:
                    continue
                # One prompt per forward pass, so drop the batch axis.
                p = aw.float().squeeze(0)  # (num_heads, q_len, k_len)
                entropy, logit_std, logit_var = _attention_row_stats(p)
                acc_entropy[layer_idx] += entropy.cpu().double()
                acc_logit_mag[layer_idx] += logit_std.cpu().double()
                acc_logit_var[layer_idx] += logit_var.cpu().double()
                acc_n_samples[layer_idx] += 1

            if i % 5 == 0:
                print(f" prompt {i+1}/{len(CALIBRATION_PROMPTS)} len={seq_len}", flush=True)
    finally:
        # Always detach the hooks, even if a forward pass failed.
        for handle in hooks:
            handle.remove()

    missing = [idx for idx, count in enumerate(acc_n_samples.tolist()) if count == 0]
    if missing:
        print(f"WARNING: no attention weights captured for layers {missing}; "
              f"their dynamic stats are reported as 0.", flush=True)

    # Normalize each layer by the number of prompts that contributed to it.
    counts = acc_n_samples.clamp_min(1).unsqueeze(-1)
    mean_entropy = (acc_entropy / counts).tolist()
    mean_logit_mag = (acc_logit_mag / counts).tolist()
    mean_logit_var = (acc_logit_var / counts).tolist()

    # Assemble output
    dynamic_stats = [
        {
            "layer": layer_idx,
            "mean_attention_entropy_per_head": mean_entropy[layer_idx],
            "mean_logit_std_per_head": mean_logit_mag[layer_idx],
            "mean_logit_var_per_head": mean_logit_var[layer_idx],
            "mean_attention_entropy": sum(mean_entropy[layer_idx]) / num_heads,
            "mean_logit_std": sum(mean_logit_mag[layer_idx]) / num_heads,
        }
        for layer_idx in range(num_layers)
    ]

    output = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "n_prompts": len(CALIBRATION_PROMPTS),
        "static": static_stats,
        "dynamic": dynamic_stats,
    }

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nWrote {out_path}", flush=True)

    # Quick summary to console
    print("\nPer-layer schedule readout (averaged over heads):")
    print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
    for layer_idx in range(num_layers):
        mean_op = sum(static_stats[layer_idx]["metric_op_per_head"]) / num_heads
        print(f" {layer_idx:>3} "
              f"{dynamic_stats[layer_idx]['mean_attention_entropy']:>14.4f} "
              f"{dynamic_stats[layer_idx]['mean_logit_std']:>16.4f} "
              f"{mean_op:>16.4f}")
|
||||
|
||||
|
||||
def main():
    """Read the command-line options and run the measurement once."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="Qwen/Qwen3-8B")
    cli.add_argument("--out", default="/tmp/sa-schedule-readout.json")
    cli.add_argument("--max-seq-len", type=int, default=256)
    opts = cli.parse_args()

    measure_model(opts.model, opts.out, max_seq_len=opts.max_seq_len)


# Run the readout only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue