"""
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).

Measures per-layer "temperature" signals:
  - entropy of softmax attention (per head, aggregated)
  - magnitude of pre-softmax logits (implicit sharpness), recovered from the
    attention probabilities as the per-row std of log p over unmasked keys
  - spectrum of the parameter metric g_L^h = W_K^h^T W_Q^h
    (static, no forward pass needed)

Output:
  A JSON file (path given by --out) holding a numeric summary per layer / head:
    - "static":  singular-spectrum statistics computed from the weights alone
    - "dynamic": activation statistics accumulated across a calibration set

Goal:
  Compare entropy(L) (dynamic readout) against the static spectrum of g_L
  (parameter-only prediction).
  Agreement    => schedule is parameter-intrinsic and a scalar per-iteration T suffices.
  Disagreement => content-adaptive structure lives in the activations.
"""

import argparse
import json
import math
import os

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


CALIBRATION_PROMPTS = [
    # general knowledge
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are",
    # math / reasoning
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",
    # code
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",
    # narrative / long-form
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",
    # chat / instruction
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",
    # factual / lookup
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",
    # conversational
    "I really loved that movie because",
    "The main difference between a virus and a bacterium is",
    # translation-ish
    "The French word for 'apple' is",
    # edge cases
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
]


def _static_readout(layers, num_heads: int, num_kv_heads: int, head_dim: int) -> list:
    """Compute the parameter-only ("static") readout for every layer.

    For each query head h the singular values of ``W_K^h @ W_Q^h.T`` (a
    head_dim x head_dim matrix) are recorded; they are the non-zero spectrum of
    the hidden-space metric ``W_K^h^T W_Q^h``. With grouped-query attention
    several query heads share one KV head, so the metric is computed per query
    head using its shared KV head.

    Only ``q_proj.weight`` and ``k_proj.weight`` are read, so anything applied
    to queries/keys after the projections (e.g. positional rotation or
    normalisation layers) is not reflected in this spectrum.

    Args:
        layers: iterable of decoder layers exposing ``self_attn.q_proj`` and
            ``self_attn.k_proj``.
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads.
        head_dim: per-head dimension.

    Returns:
        One dict per layer with keys ``layer``, ``metric_fro_per_head``,
        ``metric_op_per_head`` and ``metric_singvals_per_head``.
    """
    static_stats = []
    for layer_idx, layer in enumerate(layers):
        attn = layer.self_attn
        w_q = attn.q_proj.weight.detach().float().cpu()  # (num_heads*head_dim, hidden)
        w_k = attn.k_proj.weight.detach().float().cpu()  # (num_kv_heads*head_dim, hidden)

        fro_per_head = []
        op_per_head = []
        singvals_per_head = []
        for h in range(num_heads):
            kv_h = (h * num_kv_heads) // num_heads
            wq_h = w_q[h * head_dim:(h + 1) * head_dim]        # (head_dim, hidden)
            wk_h = w_k[kv_h * head_dim:(kv_h + 1) * head_dim]  # (head_dim, hidden)
            # The (hidden x hidden) metric has rank <= head_dim; the small
            # head_dim x head_dim product has the same non-zero singular values.
            s = torch.linalg.svdvals(wk_h @ wq_h.T)  # (head_dim,)
            fro_per_head.append(float(s.pow(2).sum().sqrt()))
            op_per_head.append(float(s.max()))
            singvals_per_head.append(s.tolist())

        static_stats.append({
            "layer": layer_idx,
            "metric_fro_per_head": fro_per_head,
            "metric_op_per_head": op_per_head,
            "metric_singvals_per_head": singvals_per_head,
        })
        if layer_idx % 8 == 0:
            print(f" static layer {layer_idx}: mean op-norm over heads = "
                  f"{sum(op_per_head)/len(op_per_head):.3f}", flush=True)
    return static_stats


def _attention_row_stats(probs: torch.Tensor):
    """Summarise one prompt's attention probabilities for one layer.

    Args:
        probs: float tensor of shape (num_heads, q_len, k_len) whose last
            dimension is a softmax distribution; masked keys are exactly 0.

    Returns:
        Tuple ``(entropy, logit_std, logit_var)``, each of shape (num_heads,),
        averaged over query positions. Early positions see fewer valid keys
        and therefore have a lower maximum entropy; they are averaged in
        regardless.
    """
    # Entries equal to zero are treated as masked. A valid key whose
    # probability underflowed to exactly 0 is indistinguishable from a masked
    # one here and is excluded as well.
    valid = (probs > 0).float()
    n_valid = valid.sum(dim=-1).clamp_min(1)

    # Exact log on non-zero entries. Zeros are clamped to the smallest normal
    # float only so that the products with `probs` / `valid` below stay finite
    # (0 * finite == 0); no epsilon is added to the non-zero probabilities.
    logp = probs.clamp_min(torch.finfo(probs.dtype).tiny).log()

    entropy = -(probs * logp).sum(dim=-1)  # (num_heads, q_len)

    # logit_ij = log p_ij + c(i) for a per-row constant c(i), so the per-row
    # variance of log p over unmasked keys equals the variance of the
    # pre-softmax logits over those keys.
    mean_logp = (logp * valid).sum(dim=-1) / n_valid
    centered = (logp - mean_logp.unsqueeze(-1)) * valid
    var_logp = centered.pow(2).sum(dim=-1) / n_valid
    row_std = var_logp.clamp_min(0).sqrt()  # (num_heads, q_len)

    return entropy.mean(dim=-1), row_std.mean(dim=-1), var_logp.mean(dim=-1)


@torch.no_grad()
def _dynamic_readout(model, tokenizer, prompts, num_layers: int, num_heads: int,
                     max_seq_len: int, device: str):
    """Run the calibration prompts and accumulate per-layer / per-head stats.

    A forward hook on every ``self_attn`` module captures the attention
    weights (second element of the module's output tuple). Pre-softmax logits
    are not captured; their spread is backed out from the probabilities, see
    ``_attention_row_stats``.

    Returns:
        Tuple ``(mean_entropy, mean_logit_std, mean_logit_var, n_samples)``.
        The first three are nested lists indexed ``[layer][head]``;
        ``n_samples`` is a list with the number of prompts that contributed to
        each layer.
    """
    acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_std = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
    n_samples = torch.zeros(num_layers, dtype=torch.float64)

    captured = {}

    def make_hook(layer_idx):
        def hook(module, inp, out):
            # Expected output: (attn_output, attn_weights, ...), with
            # attn_weights of shape (bsz, num_heads, q_len, k_len).
            if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
                captured[layer_idx] = out[1].detach()
            else:
                captured[layer_idx] = None
        return hook

    handles = []
    try:
        for layer_idx, layer in enumerate(model.model.layers):
            handles.append(layer.self_attn.register_forward_hook(make_hook(layer_idx)))

        for i, prompt in enumerate(prompts):
            inp = tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=max_seq_len).to(device)
            captured.clear()
            model(**inp, output_attentions=True, use_cache=False)
            seq_len = inp["input_ids"].shape[1]

            for layer_idx in range(num_layers):
                attn_weights = captured.get(layer_idx)
                if attn_weights is None:
                    continue
                probs = attn_weights.float().squeeze(0)  # (num_heads, q_len, k_len)
                entropy, logit_std, logit_var = _attention_row_stats(probs)
                acc_entropy[layer_idx] += entropy.cpu().double()
                acc_logit_std[layer_idx] += logit_std.cpu().double()
                acc_logit_var[layer_idx] += logit_var.cpu().double()
                n_samples[layer_idx] += 1

            if i % 5 == 0:
                print(f" prompt {i+1}/{len(prompts)} len={seq_len}", flush=True)
    finally:
        # Always detach the hooks, even if a forward pass raised, so the model
        # is not left with stale hooks holding references to `captured`.
        for handle in handles:
            handle.remove()

    missing = [idx for idx in range(num_layers) if n_samples[idx] == 0]
    if missing and len(prompts) > 0:
        print(f" WARNING: no attention weights captured for {len(missing)} of "
              f"{num_layers} layers; their dynamic stats are reported as 0.", flush=True)

    # Normalise each layer by the number of prompts that actually contributed
    # to it (clamped to 1 so layers without samples stay at 0 instead of NaN).
    denom = n_samples.clamp_min(1).unsqueeze(-1)
    return (
        (acc_entropy / denom).tolist(),
        (acc_logit_std / denom).tolist(),
        (acc_logit_var / denom).tolist(),
        [int(count) for count in n_samples.tolist()],
    )


def _print_summary(static_stats, dynamic_stats, num_heads: int) -> None:
    """Print a per-layer table of the head-averaged dynamic and static readouts."""
    print("\nPer-layer schedule readout (averaged over heads):")
    print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
    for static_row, dynamic_row in zip(static_stats, dynamic_stats):
        mean_op = sum(static_row["metric_op_per_head"]) / num_heads
        print(f" {dynamic_row['layer']:>3} "
              f"{dynamic_row['mean_attention_entropy']:>14.4f} "
              f"{dynamic_row['mean_logit_std']:>16.4f} "
              f"{mean_op:>16.4f}")


@torch.no_grad()
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256,
                  dtype: torch.dtype = torch.bfloat16, device: str = "cuda"):
    """Measure static and dynamic attention "temperature" signals and write them as JSON.

    Args:
        model_name: Hugging Face model id or local path of a causal LM whose
            decoder layers live at ``model.model.layers`` and expose
            ``self_attn.q_proj`` / ``self_attn.k_proj``.
        out_path: destination of the JSON report.
        max_seq_len: prompts are truncated to this many tokens.
        dtype: dtype the model weights are loaded in.
        device: device string used both for loading the model and for the
            tokenized inputs (defaults to "cuda", the previous hard-coded value).

    Returns:
        None. The report is written to ``out_path`` and a summary table is
        printed to stdout.
    """
    print(f"Loading {model_name} ...", flush=True)
    # NOTE(security): trust_remote_code=True executes Python code shipped with
    # the model repository; only point this script at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # need raw attention probabilities
    )
    model.eval()

    cfg = model.config
    num_layers = cfg.num_hidden_layers
    num_heads = cfg.num_attention_heads
    hidden = cfg.hidden_size
    head_dim = getattr(cfg, "head_dim", hidden // num_heads)
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
    print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
          f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}", flush=True)

    # ---- Static (parameter-only) readout ----
    static_stats = _static_readout(model.model.layers, num_heads, num_kv_heads, head_dim)

    # ---- Dynamic (activation) readout ----
    mean_entropy, mean_logit_std, mean_logit_var, n_samples = _dynamic_readout(
        model, tokenizer, CALIBRATION_PROMPTS, num_layers, num_heads, max_seq_len, device
    )

    # ---- Assemble output ----
    dynamic_stats = [
        {
            "layer": layer_idx,
            "mean_attention_entropy_per_head": mean_entropy[layer_idx],
            "mean_logit_std_per_head": mean_logit_std[layer_idx],
            "mean_logit_var_per_head": mean_logit_var[layer_idx],
            "mean_attention_entropy": sum(mean_entropy[layer_idx]) / num_heads,
            "mean_logit_std": sum(mean_logit_std[layer_idx]) / num_heads,
            # Number of prompts that produced attention weights for this layer.
            "n_samples": n_samples[layer_idx],
        }
        for layer_idx in range(num_layers)
    ]

    output = {
        "model": model_name,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": hidden,
        "n_prompts": len(CALIBRATION_PROMPTS),
        "static": static_stats,
        "dynamic": dynamic_stats,
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nWrote {out_path}", flush=True)

    # Quick summary to console
    _print_summary(static_stats, dynamic_stats, num_heads)


def main():
    """Parse command-line arguments and run the readout."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-8B")
    parser.add_argument("--out", default="/tmp/sa-schedule-readout.json")
    parser.add_argument("--max-seq-len", type=int, default=256)
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()
    measure_model(args.model, args.out, max_seq_len=args.max_seq_len, device=args.device)


if __name__ == "__main__":
    main()