replace try_lock() with lock_blocking() across UI thread

Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
This commit is contained in:
Kent Overstreet 2026-04-25 15:35:14 -04:00
commit 4225294d16
28 changed files with 4199 additions and 67 deletions

View file

@ -0,0 +1,168 @@
"""Measure the full inter-layer geometric relationship between per-head metrics.
For each (L, L', h) pair, compute the Frobenius inner product
<g_L^h, g_L'^h> = tr(g_L^h^T g_L'^h)
where g^h = W_K^h^T W_Q^h R^{hidden × hidden} (rank head_dim).
Using the head_dim × head_dim shortcut:
<g_L, g_L'> = tr(A B^T) with A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T.
Output: gram[L, L', h] and fro_sq[L, h]. From these every layer-pair comparison
is derivable without saving the full operators.
Also saves top-k principal directions per head (as right singular vectors of g,
which are the Q-side eigen-directions) so subspace overlap across layers can be
computed downstream.
"""
import argparse
import json
import os
import numpy as np
import torch
from transformers import AutoModelForCausalLM
@torch.no_grad()
def measure(model_name: str, out_path: str, topk: int = 8,
dtype=torch.bfloat16):
print(f"Loading {model_name} ...", flush=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="cuda",
trust_remote_code=True,
attn_implementation="eager",
)
model.eval()
cfg = model.config
num_layers = cfg.num_hidden_layers
num_heads = cfg.num_attention_heads
hidden = cfg.hidden_size
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim}", flush=True)
# Collect W_Q, W_K per layer as (num_heads, head_dim, hidden) on GPU float32.
Wq_list = []
Wk_list = []
for L, layer in enumerate(model.model.layers):
attn = layer.self_attn
Wq = attn.q_proj.weight.detach().to(torch.float32) # (nh*hd, hidden)
Wk = attn.k_proj.weight.detach().to(torch.float32) # (nkv*hd, hidden)
Wq = Wq.view(num_heads, head_dim, hidden)
# Repeat kv heads so every query head has a matching k-row
Wk = Wk.view(num_kv_heads, head_dim, hidden)
# Broadcast to num_heads via (h // (num_heads // num_kv_heads))? safer: mapping
Wk_full = torch.zeros(num_heads, head_dim, hidden,
device=Wk.device, dtype=Wk.dtype)
for h in range(num_heads):
kv_h = (h * num_kv_heads) // num_heads
Wk_full[h] = Wk[kv_h]
Wq_list.append(Wq)
Wk_list.append(Wk_full)
print(f" loaded weights: {num_layers} layers", flush=True)
# Per-head top-k right singular vectors of g^h = W_K^T W_Q (hidden, hidden).
# The non-zero right singular vectors of g lie in row-space(W_Q) ⊂ R^hidden.
# For subspace comparison we need vectors in hidden-space.
#
# We also need SIGNED eigenvalues of the symmetric part (g + g^T)/2 to
# determine curvature signs per eigen-direction. Since g has rank ≤ d_h,
# (g + g^T) has rank ≤ 2 d_h, and we can compute its signed non-zero
# eigenvalues via the Jordan-Wielandt-style trick:
# eigs(X^T J X) = eigs(J X X^T) for X = [W_Q; W_K], J = [[0, I], [I, 0]].
# The resulting 2d_h × 2d_h matrix gives us all non-zero eigenvalues of
# (g + g^T) cheaply.
topk_eff = min(topk, head_dim)
eig_dirs = torch.zeros(num_layers, num_heads, topk_eff, hidden,
dtype=torch.float32)
fro_sq = torch.zeros(num_layers, num_heads, dtype=torch.float64)
sym_eigs = torch.zeros(num_layers, num_heads, 2 * head_dim,
dtype=torch.float64) # signed
for L in range(num_layers):
for h in range(num_heads):
Wq = Wq_list[L][h] # (hd, hidden)
Wk = Wk_list[L][h] # (hd, hidden)
small = Wk @ Wq.T # (hd, hd)
U, S, Vh = torch.linalg.svd(small, full_matrices=False)
dirs = Vh @ Wq # (hd, hidden)
dirs = dirs / dirs.norm(dim=-1, keepdim=True).clamp_min(1e-12)
eig_dirs[L, h] = dirs[:topk_eff].cpu()
fro_sq[L, h] = float((S * S).sum())
# Signed eigenvalues of (g + g^T) via 2d_h × 2d_h matrix
# J (X X^T) where X = [Wq; Wk] (stacked)
XXT = torch.zeros(2 * head_dim, 2 * head_dim,
device=Wq.device, dtype=Wq.dtype)
XXT[:head_dim, :head_dim] = Wq @ Wq.T
XXT[:head_dim, head_dim:] = Wq @ Wk.T
XXT[head_dim:, :head_dim] = Wk @ Wq.T
XXT[head_dim:, head_dim:] = Wk @ Wk.T
# J matrix is off-diagonal block identity
J = torch.zeros(2 * head_dim, 2 * head_dim,
device=Wq.device, dtype=Wq.dtype)
J[:head_dim, head_dim:] = torch.eye(head_dim,
device=Wq.device, dtype=Wq.dtype)
J[head_dim:, :head_dim] = torch.eye(head_dim,
device=Wq.device, dtype=Wq.dtype)
M = J @ XXT
# M is not symmetric, but its non-zero eigenvalues are those of
# (g + g^T)/2 times 2 → real (since (g + g^T) is symmetric).
# Use general eigvals; imag parts should be near zero up to
# numerical noise.
ev = torch.linalg.eigvals(M)
ev_real = ev.real.cpu().double()
# sort by magnitude descending so top eigenvalues come first
order = torch.argsort(ev_real.abs(), descending=True)
sym_eigs[L, h] = ev_real[order]
if L % 8 == 0:
print(f" eigdecomp L={L}", flush=True)
# Gram matrix: gram[L, L', h] = <g_L^h, g_L'^h>.
# Using A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T, <g, g'> = tr(A B^T) = sum(A * B).
gram = torch.zeros(num_layers, num_layers, num_heads, dtype=torch.float64)
for L in range(num_layers):
for Lp in range(L, num_layers):
for h in range(num_heads):
Wq_L = Wq_list[L][h]
Wk_L = Wk_list[L][h]
Wq_Lp = Wq_list[Lp][h]
Wk_Lp = Wk_list[Lp][h]
A = Wk_L @ Wk_Lp.T # (hd, hd)
B = Wq_L @ Wq_Lp.T # (hd, hd)
v = float((A * B).sum())
gram[L, Lp, h] = v
gram[Lp, L, h] = v
if L % 4 == 0:
print(f" gram row L={L}", flush=True)
# Save
out = {
"model": model_name,
"num_layers": num_layers,
"num_heads": num_heads,
"head_dim": head_dim,
"hidden_size": hidden,
"topk": topk_eff,
"gram": gram.tolist(),
"fro_sq": fro_sq.tolist(),
}
with open(out_path, "w") as f:
json.dump(out, f)
torch.save({"eig_dirs": eig_dirs, "sym_eigs": sym_eigs},
out_path.replace(".json", "-eigdirs.pt"))
print(f"Wrote {out_path} and {out_path.replace('.json', '-eigdirs.pt')}",
flush=True)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--model", default="Qwen/Qwen3-4B")
ap.add_argument("--out", default="/tmp/sa-grams.json")
ap.add_argument("--topk", type=int, default=8)
args = ap.parse_args()
measure(args.model, args.out, topk=args.topk)
if __name__ == "__main__":
main()