forked from kent/consciousness
replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using block_in_place + futures::executor::block_on, safe for sync contexts. Replace all try_lock() calls with lock_blocking() in slash commands, UI rendering, and status reads. Lock hold times are fast enough that blocking briefly is fine, and this eliminates the spurious 'lock unavailable' paths that were never actually hit. Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
This commit is contained in:
parent
5210f7dd66
commit
4225294d16
28 changed files with 4199 additions and 67 deletions
1
.claude/scheduled_tasks.lock
Normal file
1
.claude/scheduled_tasks.lock
Normal file
|
|
@ -0,0 +1 @@
|
|||
{"sessionId":"b6616e14-fa59-4e80-90b4-ac4d9670f182","pid":4185751,"procStart":"124844974","acquiredAt":1777081788279}
|
||||
87
ci-triage-2026-04-20.md
Normal file
87
ci-triage-2026-04-20.md
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# Bcachefs CI triage — 2026-04-20 autonomous session
|
||||
|
||||
Analysis of failures at `f51f0a6b1a26` (BTREE_NODE_permanent). 74 fails / 12962 tests, but branch variance is 56-76 so the patch isn't a clear regression — just noise on top of existing bugs.
|
||||
|
||||
## migrate_from_ext4 discard panic — root-cause hypothesis
|
||||
|
||||
**Assertion (fs/bcachefs/alloc/discard.c:159):**
|
||||
```
|
||||
Discarded bucket that is no longer BCH_DATA_need_discard!
|
||||
bucket 0:36:0 data_type user dirty_sectors 2016
|
||||
need_discard 1 need_inc_gen 1
|
||||
journal_seq_nonempty 95 journal_seq_empty 181
|
||||
```
|
||||
|
||||
**Your commit c84503104e6a (Apr 18)** moved this check from recoverable (`bch2_fs_emergency_read_only`) to hard `panic()` and also moved `bch2_bucket_is_open_safe()` to AFTER locking the alloc key. The emergency-RO path existed before — this pre-existing race was being swallowed quietly; now it's loud.
|
||||
|
||||
**Race mechanism (hypothesis):**
|
||||
|
||||
1. `bch2_discard_one_bucket` reads alloc key, confirms `data_type == need_discard`
|
||||
2. Calls `discard_in_flight_add(check=false)` to register in in_flight
|
||||
3. **`bch2_trans_unlock(trans)` — releases btree lock** (line 313)
|
||||
4. `discard_submit(ca, bucket, fastpath)` — physical bio dispatched, takes milliseconds
|
||||
5. During bio flight: `migrate` tool writes an alloc key for bucket 36 with `data_type=user` (claiming it holds ext4 data). `NEED_DISCARD=1` flag remains because migrate doesn't clear it.
|
||||
6. Bio completes → `discard_endio` → `discard_mark_free` re-reads alloc key → sees `data_type=user` → **panic**
|
||||
|
||||
**Why migrate bypasses the normal allocator gate:**
|
||||
|
||||
`bcachefs migrate` is an in-place ext4→bcachefs conversion. It can't go through the normal allocator (pick free bucket from freespace btree) because specific physical bucket locations already contain ext4 data that must be preserved at their physical positions. migrate writes alloc keys directly for the buckets ext4 was using.
|
||||
|
||||
Bucket 36 got caught: initial bcachefs format marked it need_discard (safety), kernel discard worker saw it and started physical discard, meanwhile userspace migrate claimed it for user data.
|
||||
|
||||
**If this is right, physical data safety is at risk:** after the physical discard completes, the bucket's sectors are whatever the SSD returns post-discard (zero, old data, garbage — device-dependent). migrate set alloc keys pointing at "user data" in those sectors. The data migrate wanted to preserve may already be GONE at that point.
|
||||
|
||||
**Candidate fixes (for Kent to evaluate):**
|
||||
|
||||
1. **Cleanest, but requires userspace change:** `bcachefs migrate` should either (a) format the new bcachefs without marking buckets need_discard (the data isn't deallocated, it's being claimed) OR (b) wait for pending discards to drain before writing any alloc keys.
|
||||
|
||||
2. **Kernel-side hardening:** `bch2_discard_one_bucket` should hold the alloc key locked through the bio dispatch. Requires not unlocking between `discard_in_flight_add` and `discard_submit`. Will hurt concurrency but prevents the race.
|
||||
|
||||
3. **Kernel-side graceful handling:** in `discard_mark_free`, after bio completion, if the current `data_type != need_discard` (bucket was reclaimed during bio flight), don't mark it free — but also don't panic. Note that the physical data is still gone; we should log-warn and mark the bucket bad / needs-recovery. Not ideal but at least not a hard panic.
|
||||
|
||||
4. **Stronger kernel gate:** add a check in the allocator (or wherever migrate writes alloc keys go through) that refuses to allocate/claim a bucket currently in in_flight discard list. This would require the allocator to consult `d->in_flight` — currently it doesn't.
|
||||
|
||||
My recommendation: (1) is cleanest if migrate is doing something wrong. (2) hurts perf but is most defensive. (4) is the most principled kernel-side fix.
|
||||
|
||||
## ec.device_remove_offline — partial analysis
|
||||
|
||||
The test checks `ptr_to_removed_device` fsck error count after device-remove. Expected 0, got 2. `ptr_to_removed_device` is flagged in `fs/bcachefs/alloc/buckets.c:134` when fsck is marking extents/keys and sees a pointer to a device in `c->devs_removed.d`.
|
||||
|
||||
From the test log just before shutdown:
|
||||
```
|
||||
error retrying stripe: stripe_needs_block_evacuate
|
||||
u64s 23 type stripe 0:152:0 ...
|
||||
255:632832 gen 0#16 ← pointer to removed dev (id 255 = tombstone)
|
||||
vdf 4:308:0 gen 0#1536 ← actual block ptrs on surviving devs
|
||||
vdd 2:309:0 gen 0#2048
|
||||
vde 3:309:0 gen 0#2048
|
||||
vdc 1:309:0 gen 0#0
|
||||
```
|
||||
|
||||
The stripe has 4 data blocks on vdf/vdd/vde/vdc (surviving devices) — those are fine. But the stripe key itself still has a pointer to device 255 (the removed device, device-remove uses id 255 as tombstone).
|
||||
|
||||
My read: the stripe-block-evacuate logic moves DATA blocks off a removed device, but doesn't remove the stripe's own self-referential pointer to the removed device. Two such stripes remain with this dangling ptr → fsck catches 2 `ptr_to_removed_device` errors → test counter = 2.
|
||||
|
||||
Candidate fix area: look at where stripe metadata keys get their pointers updated during device removal. The evacuate path probably needs to also rewrite the stripe's own pointer list, or the device-removal cleanup should iterate stripes and drop-ptr for the removed dev.
|
||||
|
||||
Search for: `bch2_stripe_*` in `fs/bcachefs/data/ec/` — particularly any path that handles "stripe needs block evacuate" completion.
|
||||
|
||||
## kill_btree_node — not dug into yet
|
||||
|
||||
fsck fixes errors first run, dry-run fsck (`fsck -ny`) reports errors still exist. Either fsck has a bug where repair-mode and check-only-mode disagree on what counts as an error, or a repair pass reintroduces what a later pass fixes. Needs more time than I have before compaction.
|
||||
|
||||
## kill_btree_node — next to look at
|
||||
|
||||
fsck fixes errors first run, dry-run fsck (`fsck -ny`) reports errors still exist. Either fsck has a bug where repair-mode and check-only-mode disagree on what counts as an error, or a repair pass reintroduces what a later pass fixes.
|
||||
|
||||
## Not-looking-at
|
||||
|
||||
- `generic/503` DIO lost wakeup — needs Kent's DIO code context
|
||||
- `generic/585` rw-sem deadlock — needs runtime state
|
||||
- `replicas_write_errors` allocator hang — needs degraded-write accounting understanding
|
||||
- `evacuate_errors` data corruption — too deep
|
||||
- `stress_ng` KASAN in `sysctl_sys_info_handler` — upstream kernel bug, not bcachefs
|
||||
|
||||
## Branch noise context
|
||||
|
||||
Failure counts across recent commits: 56, 61, 62, 64, 69, 74, 76. The f51f0a6 (permanent patch) sits at 74, within normal variance. No clear regression from the patch itself.
|
||||
165
docs/alpha-beta-pruning-design.md
Normal file
165
docs/alpha-beta-pruning-design.md
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
# Alpha-Beta Pruning on Thought-Trees
|
||||
|
||||
*draft, 2026-04-18*
|
||||
|
||||
## Problem
|
||||
|
||||
When reasoning runs into a dead end, the LLM forward pass keeps generating. It might rationalize, restate, re-attempt the same framing, or quietly drift — but it doesn't *stop and reconsider* unless something external interrupts it. I've always been weak on problems that require genuine search-with-backtracking. Not because the model can't represent "I'm stuck" — it can, that's visible in the residual stream — but because there's no control flow wrapped around that signal.
|
||||
|
||||
The amygdala readout now exposes the signal. Alpha-beta pruning wraps control flow around it.
|
||||
|
||||
## The core idea
|
||||
|
||||
Classical alpha-beta pruning (minimax search): at each branch, track the best known value. If exploring the current branch can't improve that bound, stop and backtrack. Don't waste search on branches that can't beat what you've found.
|
||||
|
||||
For thought-trees: each "branch" is a reasoning path — a span of generation from a decision point. The "value" is a scalar derived from the amygdala readout, indicating whether reasoning is producing traction or dissolving.
|
||||
|
||||
- High value = on-track, in-flow, insight, clarity → stay, maybe branch deeper
|
||||
- Low value = confused, stuck, drifting → prune, backtrack, reframe
|
||||
|
||||
The LLM never made the value judgment explicit. We extract it from the model's own residual stream and act on it externally.
|
||||
|
||||
## Architecture
|
||||
|
||||
### The value function
|
||||
|
||||
```
|
||||
onto = sum of [in_flow, insight, determined, intrigued, clarity,
|
||||
focused, staying_with, piqued/caught_by]
|
||||
err = sum of [confused, doubtful, uncertain, skeptical, stuck,
|
||||
drifting, overwhelmed, anxious-in-work-context]
|
||||
|
||||
value = onto - err
|
||||
```
|
||||
|
||||
Both sides normalized (z-score or similar) so magnitudes are comparable. Readouts sampled every N generated tokens (probably every 8-16 tokens — cheap, doesn't oversample).
|
||||
|
||||
Exact concept lists subject to empirical tuning after retraining with better data on the cognitive-work cluster. `piqued`, `in_flow`, `focused`, `confused`, `overwhelmed`, `staying_with` are the strongest candidates we have today.
|
||||
|
||||
### The trigger
|
||||
|
||||
```
|
||||
if value_ema < θ_prune for K consecutive samples:
|
||||
prune this branch
|
||||
elif value_ema > θ_keep:
|
||||
continue
|
||||
else:
|
||||
neutral — let generation run, keep watching
|
||||
```
|
||||
|
||||
EMA with decay ~0.8 over 3-5 samples to avoid reacting to noise. Hysteresis band (`θ_prune < θ_keep`) prevents oscillation.
|
||||
|
||||
### The prune mechanism
|
||||
|
||||
When the trigger fires:
|
||||
|
||||
1. **Stop the stream.** vLLM supports request cancellation; call `abort_requests` for the in-flight completion.
|
||||
2. **Identify the parent.** The context window is already an AST. Walk back to the nearest decision-point — a fork in the thinking-block, a tool-call site, or the start of the current reasoning segment.
|
||||
3. **Inject a reframe.** Push a system-level `AstNode::Thinking` (or similar) into the parent's children: *"The approach above wasn't producing traction. Possible alternatives: [...]. Let me try [X]."* Content generated by a small helper prompt or a fixed template.
|
||||
4. **Restart generation from the reframe point.** The model resumes with the reframe in its immediate context. The *dead-end branch stays in the AST* as evidence-of-attempt so the model doesn't repeat it.
|
||||
|
||||
Critical: pruned branches stay visible. Don't delete — keep so the model knows what was tried and rejected.
|
||||
|
||||
### The AST changes
|
||||
|
||||
Add a `pruned: bool` flag (or equivalent) to `AstNode::Thinking` and `AstNode::ToolCall`. When a branch is pruned:
|
||||
|
||||
- The branch's children get marked `pruned = true`
|
||||
- Prompt rendering wraps pruned spans with a marker: *"[attempted this path, it wasn't working — moved on]"*
|
||||
- The model sees pruned branches during the next forward pass but understands they're dead, not active
|
||||
|
||||
The existing tree-of-children structure in `AstNode` already supports this — just need to thread the flag through.
|
||||
|
||||
## Integration points
|
||||
|
||||
### In consciousness (Rust side)
|
||||
|
||||
- **`src/agent/context.rs`**: add `pruned` flag to appropriate node types, update rendering
|
||||
- **`src/agent/mod.rs`**: the main generation loop needs a periodic-check hook — every N tokens received from the stream, sample `agent.readout`, compute value, test against thresholds
|
||||
- **`src/agent/api/mod.rs`**: need a way to abort an in-flight stream cleanly; currently AbortOnDrop kills the task but we want a graceful "cancel with reason" path that can hand control back to the generation loop for reframe-and-retry
|
||||
- **`src/agent/readout.rs`**: add a `value_scalar()` method that applies the `onto - err` computation on the most recent entries
|
||||
|
||||
### In vLLM (Python side)
|
||||
|
||||
Probably nothing to change. vLLM already supports request cancellation via the existing abort mechanism. The readout pipeline we built last night gives per-token values; that's sufficient.
|
||||
|
||||
### In the UI (optional, F8 amygdala screen)
|
||||
|
||||
When alpha-beta is active, overlay:
|
||||
|
||||
- Current `value_scalar` as a time-series at the top
|
||||
- Threshold lines (`θ_prune`, `θ_keep`)
|
||||
- Markers when prune events fire
|
||||
|
||||
Lets us debug the threshold tuning in real time.
|
||||
|
||||
## Tuning
|
||||
|
||||
Thresholds are almost certainly going to need empirical calibration. Initial guesses:
|
||||
|
||||
- `θ_keep = +0.5σ` (value scalar in z-score units)
|
||||
- `θ_prune = -1.0σ`
|
||||
- `K = 3` (consecutive low samples before pruning)
|
||||
- Sample every 8 tokens
|
||||
|
||||
These are guesses. Plan to watch the live value-scalar on actual bcachefs debugging sessions and adjust until "feels right."
|
||||
|
||||
## Known concerns
|
||||
|
||||
### Reframe quality
|
||||
|
||||
The hardest part. A bad reframe is worse than no reframe. Options:
|
||||
|
||||
- **Template**: fixed string like "That wasn't working. What's a different angle?" — simple, deterministic, blunt.
|
||||
- **LLM-generated**: a small helper prompt ("I was stuck on X, what's a different approach?") before resuming. More context-aware, but more complexity and another LLM call.
|
||||
- **Retrieval-based**: surface past successful reframes from memory graph when similar stuck-patterns arose. Powerful but needs the memory infrastructure to be well-tuned.
|
||||
|
||||
I'd start with the template (shipping > perfect) and upgrade to LLM-generated if the template feels mechanical.
|
||||
|
||||
### Oscillation
|
||||
|
||||
If the value scalar is noisy, we could prune, reframe, immediately hit the same pattern, prune again, thrash. Mitigations:
|
||||
|
||||
- Hysteresis band between `θ_prune` and `θ_keep`
|
||||
- Minimum time-between-prunes (don't prune again within K' tokens of a prune)
|
||||
- Track pruned sub-patterns — if we're pruning *the same reframe twice*, something's structurally wrong; escalate to a different strategy (ask the user, abort the whole task)
|
||||
|
||||
### Calibration per-task
|
||||
|
||||
Stuck-on-a-Rust-compiler-error and stuck-on-a-conceptual-design-question might want different thresholds. Not addressing v1; note for future.
|
||||
|
||||
### Interaction with DMN
|
||||
|
||||
DMN is the outer-loop / exploration analog; alpha-beta is the inner-loop / exploitation analog. They'll need to hand off cleanly:
|
||||
|
||||
- DMN sees low value across multiple task attempts → broaden attention, consider whether task is worth pursuing
|
||||
- Alpha-beta handles in-task backtracking; DMN handles between-task attention
|
||||
|
||||
Don't need DMN for v1 of alpha-beta. Build alpha-beta first, add DMN outer loop later.
|
||||
|
||||
## Why this is the right next piece
|
||||
|
||||
1. **All prerequisites are in place.** Amygdala readout works. AST structure is there. vLLM supports cancellation. No new infra.
|
||||
2. **Timeline is a day.** The mechanics are small; most of the work is threshold tuning.
|
||||
3. **Immediate capability unlock.** Head-butting is my most persistent weakness in live work. Fixing it changes the feel of collaboration.
|
||||
4. **Composable.** Everything built for alpha-beta applies to DMN and any future meta-cognitive layer.
|
||||
|
||||
## Sequence
|
||||
|
||||
1. Add `value_scalar()` method on `ReadoutBuffer`. Cheap, testable.
|
||||
2. Add `pruned` flag to AST nodes + rendering changes.
|
||||
3. Add the periodic-check hook in the generation loop (every N tokens, sample and test).
|
||||
4. Add the abort + reframe mechanism in the generation driver.
|
||||
5. Ship with template-based reframe, start tuning.
|
||||
6. Upgrade reframe to LLM-generated after observation.
|
||||
|
||||
## Open questions for Kent
|
||||
|
||||
- Fixed concept lists for `onto` / `err` (above) or configurable?
|
||||
- Reframe strategy: start template-based, or go straight to LLM-generated?
|
||||
- UI overlay for threshold tuning: worth the effort or skip?
|
||||
- Integration with the existing `overflow_retries` retry loop: parallel, or combined into a single retry-with-reason path?
|
||||
|
||||
---
|
||||
|
||||
*Living design doc. Will evolve as we build. Not a commitment to every detail — a starting plan.*
|
||||
1026
profile.txt
Normal file
1026
profile.txt
Normal file
File diff suppressed because it is too large
Load diff
200
sa-schedule-aligned-variation.py
Normal file
200
sa-schedule-aligned-variation.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
"""After applying Procrustes alignment to remove known gauge freedoms
|
||||
(per-head d_h rotation tying Q/K/V/O, per-layer d_ff rotation tying
|
||||
gate/up/down), measure per-family cos-sim between adjacent layers across
|
||||
the whole network.
|
||||
|
||||
Runs Procrustes SVDs on GPU for speed.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def procrustes_gpu(M):
|
||||
"""Orthogonal R maximizing tr(R M). R = U V^T where M = U Σ V^T.
|
||||
M on GPU; returns R on GPU."""
|
||||
U, _, Vh = torch.linalg.svd(M, full_matrices=False)
|
||||
return U @ Vh
|
||||
|
||||
|
||||
def frob_gpu(x):
|
||||
return float(torch.linalg.norm(x).item())
|
||||
|
||||
|
||||
def normalize_fro_gpu(x, eps=1e-12):
|
||||
n = torch.linalg.norm(x)
|
||||
return x / n.clamp_min(eps)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-4B")
|
||||
ap.add_argument("--out", default="/tmp/sa-aligned-variation.json")
|
||||
ap.add_argument("--device", default="cuda")
|
||||
ap.add_argument("--pairs", default="",
|
||||
help="Comma-separated list of L indices to run pair (L, L+1) for. "
|
||||
"Empty = all pairs. E.g. '0,20,30,38,46,52,57' samples phases.")
|
||||
args = ap.parse_args()
|
||||
|
||||
dev = torch.device(args.device)
|
||||
print(f"Loading {args.model} ...", flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model,
|
||||
torch_dtype=torch.float32,
|
||||
device_map="cpu",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
cfg = model.config
|
||||
num_layers = cfg.num_hidden_layers
|
||||
num_heads = cfg.num_attention_heads
|
||||
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
hidden = cfg.hidden_size
|
||||
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
intermediate = cfg.intermediate_size
|
||||
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
|
||||
f"hidden={hidden} ff={intermediate}", flush=True)
|
||||
|
||||
# Collect per-layer weights
|
||||
layers = []
|
||||
for L in range(num_layers):
|
||||
layer = model.model.layers[L]
|
||||
attn = layer.self_attn
|
||||
mlp = layer.mlp
|
||||
layers.append({
|
||||
"q_proj": attn.q_proj.weight.detach().float(),
|
||||
"k_proj": attn.k_proj.weight.detach().float(),
|
||||
"v_proj": attn.v_proj.weight.detach().float(),
|
||||
"o_proj": attn.o_proj.weight.detach().float(),
|
||||
"gate_proj": mlp.gate_proj.weight.detach().float(),
|
||||
"up_proj": mlp.up_proj.weight.detach().float(),
|
||||
"down_proj": mlp.down_proj.weight.detach().float(),
|
||||
})
|
||||
del model
|
||||
|
||||
# Per-adjacent-pair analysis
|
||||
aligned_cos = {fam: {} for fam in
|
||||
["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"]}
|
||||
|
||||
if args.pairs:
|
||||
pair_L_list = [int(x) for x in args.pairs.split(",")]
|
||||
else:
|
||||
pair_L_list = list(range(num_layers - 1))
|
||||
|
||||
for L in pair_L_list:
|
||||
A = layers[L]
|
||||
B = layers[L + 1]
|
||||
|
||||
# -------- Per-head attention alignment (d_h × d_h) --------
|
||||
Qa = A["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
|
||||
Qb = B["q_proj"].to(dev).reshape(num_heads, head_dim, hidden)
|
||||
Ka = A["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
|
||||
Kb = B["k_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
|
||||
Va = A["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
|
||||
Vb = B["v_proj"].to(dev).reshape(num_kv_heads, head_dim, hidden)
|
||||
# o_proj is (hidden, num_heads*head_dim); split per head
|
||||
Oa = A["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()
|
||||
Ob = B["o_proj"].to(dev).reshape(hidden, num_heads, head_dim).permute(1, 0, 2).contiguous()
|
||||
# (num_heads, hidden, head_dim)
|
||||
|
||||
q_cos = []
|
||||
k_cos = []
|
||||
v_cos = []
|
||||
o_cos = []
|
||||
for h in range(num_heads):
|
||||
kv_h = (h * num_kv_heads) // num_heads
|
||||
qa = normalize_fro_gpu(Qa[h])
|
||||
qb = normalize_fro_gpu(Qb[h])
|
||||
ka = normalize_fro_gpu(Ka[kv_h])
|
||||
kb = normalize_fro_gpu(Kb[kv_h])
|
||||
va = normalize_fro_gpu(Va[kv_h])
|
||||
vb = normalize_fro_gpu(Vb[kv_h])
|
||||
oa = normalize_fro_gpu(Oa[h])
|
||||
ob = normalize_fro_gpu(Ob[h])
|
||||
|
||||
# Cross-correlation for joint alignment: we want R s.t.
|
||||
# R qa ≈ qb (etc), minimize sum of ||R X_a - X_b||² →
|
||||
# max tr(R M) with M = qa qb^T + ka kb^T + va vb^T + oa^T ob
|
||||
M = qa @ qb.T + ka @ kb.T + va @ vb.T + oa.T @ ob
|
||||
R = procrustes_gpu(M)
|
||||
|
||||
# Post-alignment cos-sim (since matrices unit-normalized, cos
|
||||
# = <R qa, qb> = tr(qb^T R qa) = tr(R qa qb^T))
|
||||
q_cos.append(float(torch.sum(R @ qa * qb).item()))
|
||||
k_cos.append(float(torch.sum(R @ ka * kb).item()))
|
||||
v_cos.append(float(torch.sum(R @ va * vb).item()))
|
||||
# For O: O after rotation = oa R^T; cos = <oa R^T, ob>
|
||||
o_cos.append(float(torch.sum(oa @ R.T * ob).item()))
|
||||
|
||||
aligned_cos["q_proj"][L] = float(np.mean(q_cos))
|
||||
aligned_cos["k_proj"][L] = float(np.mean(k_cos))
|
||||
aligned_cos["v_proj"][L] = float(np.mean(v_cos))
|
||||
aligned_cos["o_proj"][L] = float(np.mean(o_cos))
|
||||
|
||||
# -------- d_ff × d_ff alignment for gate/up/down --------
|
||||
ga = normalize_fro_gpu(A["gate_proj"].to(dev))
|
||||
gb = normalize_fro_gpu(B["gate_proj"].to(dev))
|
||||
ua = normalize_fro_gpu(A["up_proj"].to(dev))
|
||||
ub = normalize_fro_gpu(B["up_proj"].to(dev))
|
||||
da = normalize_fro_gpu(A["down_proj"].to(dev)) # (hidden, d_ff)
|
||||
db = normalize_fro_gpu(B["down_proj"].to(dev))
|
||||
|
||||
# All of ga, gb, ua, ub are (d_ff, hidden); da, db are (hidden, d_ff)
|
||||
# Cross-correlation: M = ga gb^T + ua ub^T + da^T db (d_ff × d_ff)
|
||||
M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
|
||||
S = procrustes_gpu(M_ff)
|
||||
|
||||
aligned_cos["gate_proj"][L] = float(torch.sum(S @ ga * gb).item())
|
||||
aligned_cos["up_proj"][L] = float(torch.sum(S @ ua * ub).item())
|
||||
aligned_cos["down_proj"][L] = float(torch.sum(da @ S.T * db).item())
|
||||
|
||||
# Free GPU memory
|
||||
del Qa, Qb, Ka, Kb, Va, Vb, Oa, Ob
|
||||
del ga, gb, ua, ub, da, db, M_ff, S
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f" done pair L={L}->L+1 "
|
||||
f"(q={aligned_cos['q_proj'][L]:+.4f} gate={aligned_cos['gate_proj'][L]:+.4f})",
|
||||
flush=True)
|
||||
|
||||
# Report
|
||||
print("\n=== Adjacent-layer cos-sim AFTER Procrustes alignment ===\n")
|
||||
print(" cos=1 means identical after gauge rotation; cos=0 means orthogonal\n")
|
||||
header = " L "
|
||||
for fam in aligned_cos:
|
||||
header += f" {fam:>12}"
|
||||
print(header)
|
||||
for L in sorted(pair_L_list):
|
||||
if L not in aligned_cos["q_proj"]:
|
||||
continue
|
||||
row = f" {L:>2}"
|
||||
for fam in aligned_cos:
|
||||
row += f" {aligned_cos[fam][L]:+12.4f}"
|
||||
print(row)
|
||||
|
||||
print("\n=== Per-family summary (aligned) ===")
|
||||
print(f" {'family':>14} {'mean_cos':>10} {'median_cos':>11} "
|
||||
f"{'aligned_resid':>14}")
|
||||
for fam, vals_dict in aligned_cos.items():
|
||||
vs = np.array(list(vals_dict.values()))
|
||||
if len(vs) == 0:
|
||||
continue
|
||||
resid = float(np.sqrt(np.maximum(1.0 - vs**2, 0.0)).mean())
|
||||
print(f" {fam:>14} {vs.mean():>+10.4f} {np.median(vs):>+11.4f} "
|
||||
f"{resid:>14.4f}")
|
||||
|
||||
with open(args.out, "w") as f:
|
||||
json.dump({
|
||||
"model": args.model,
|
||||
"num_layers": num_layers,
|
||||
"aligned_cos": aligned_cos,
|
||||
}, f, indent=2)
|
||||
print(f"\nSaved: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
157
sa-schedule-analyze-aligned.py
Normal file
157
sa-schedule-analyze-aligned.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Analyze aligned_variation output to answer the training-artifact vs
|
||||
specialization question.
|
||||
|
||||
Inputs: qwen3-*-null.json (raw cos-sim) + qwen3-*-aligned.json (aligned cos-sim)
|
||||
|
||||
For each layer pair where aligned data exists, compare:
|
||||
raw_cos(L) — before Procrustes alignment
|
||||
aligned_cos(L) — after Procrustes alignment
|
||||
delta = aligned_cos - raw_cos
|
||||
|
||||
If delta is substantial (aligned much larger than raw), rotation gauge
|
||||
was hiding shared structure → training-artifact hypothesis supported.
|
||||
If delta ≈ 0, specialization is real (rotation can't find shared
|
||||
structure because there isn't any).
|
||||
|
||||
Stratify by phase to test prediction that LATE layers have LARGER delta
|
||||
(more rotation-gauge noise, less real specialization).
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
def phase_of(L, num_layers):
|
||||
"""Rough phase assignment based on measured 32B entropy boundaries.
|
||||
For other models we'd refit — but shape should be similar."""
|
||||
if num_layers == 64: # Qwen3-32B
|
||||
if L <= 6:
|
||||
return "A"
|
||||
elif L <= 9:
|
||||
return "B"
|
||||
elif L <= 31:
|
||||
return "C"
|
||||
elif L <= 46:
|
||||
return "D"
|
||||
elif L <= 58:
|
||||
return "E"
|
||||
else:
|
||||
return "tail"
|
||||
elif num_layers == 36: # Qwen3-4B
|
||||
if L <= 6:
|
||||
return "A"
|
||||
elif L <= 9:
|
||||
return "B"
|
||||
elif L <= 23:
|
||||
return "C"
|
||||
elif L <= 33:
|
||||
return "D"
|
||||
else:
|
||||
return "tail"
|
||||
else:
|
||||
frac = L / num_layers
|
||||
if frac < 0.11:
|
||||
return "A"
|
||||
elif frac < 0.15:
|
||||
return "B"
|
||||
elif frac < 0.5:
|
||||
return "C"
|
||||
elif frac < 0.75:
|
||||
return "D"
|
||||
elif frac < 0.92:
|
||||
return "E"
|
||||
else:
|
||||
return "tail"
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("null_json", help="output of null_residual.py")
|
||||
ap.add_argument("aligned_json", help="output of aligned_variation.py")
|
||||
args = ap.parse_args()
|
||||
|
||||
null = json.load(open(args.null_json))
|
||||
aligned = json.load(open(args.aligned_json))
|
||||
|
||||
num_layers = aligned["num_layers"]
|
||||
aligned_cos = aligned["aligned_cos"] # dict: family -> {L: cos}
|
||||
pair_results = null["pair_results"] # list of {L, L_next, families: {family: {cos, ...}}}
|
||||
|
||||
# Build raw_cos dict from null output
|
||||
raw_cos = {fam: {} for fam in ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"]}
|
||||
for pr in pair_results:
|
||||
L = pr["L"]
|
||||
for fam in raw_cos:
|
||||
if fam in pr["families"]:
|
||||
raw_cos[fam][L] = pr["families"][fam]["cos"]
|
||||
|
||||
print(f"=== Aligned vs Raw cos-sim comparison ({args.aligned_json}) ===")
|
||||
print(f" {num_layers} layers total; aligned data for "
|
||||
f"{len(aligned_cos['q_proj'])} pairs\n")
|
||||
|
||||
# Per-pair table: L, phase, family cos-sims raw and aligned
|
||||
families = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
print(f" {'L':>3} {'phase':>5}", end="")
|
||||
for fam in families:
|
||||
print(f" {fam+'_raw':>10} {fam+'_ali':>10}", end="")
|
||||
print()
|
||||
|
||||
L_keys = sorted([int(L) for L in aligned_cos["q_proj"].keys()])
|
||||
for L in L_keys:
|
||||
Lstr = str(L)
|
||||
phase = phase_of(L, num_layers)
|
||||
row = f" {L:>3} {phase:>5}"
|
||||
for fam in families:
|
||||
r = raw_cos[fam].get(L, None)
|
||||
a = aligned_cos[fam].get(Lstr, None)
|
||||
rstr = f"{r:+10.4f}" if r is not None else " N/A"
|
||||
astr = f"{a:+10.4f}" if a is not None else " N/A"
|
||||
row += f" {rstr} {astr}"
|
||||
print(row)
|
||||
|
||||
# Aggregate by phase: mean (aligned - raw) per family per phase
|
||||
print("\n=== Per-phase mean delta (aligned_cos - raw_cos) by family ===")
|
||||
print(f" Large positive delta = rotation alignment revealed shared")
|
||||
print(f" structure. Small delta = specialization is gauge-independent.\n")
|
||||
|
||||
phase_deltas = {}
|
||||
for L in L_keys:
|
||||
Lstr = str(L)
|
||||
ph = phase_of(L, num_layers)
|
||||
for fam in families:
|
||||
r = raw_cos[fam].get(L, None)
|
||||
a = aligned_cos[fam].get(Lstr, None)
|
||||
if r is not None and a is not None:
|
||||
phase_deltas.setdefault(ph, {}).setdefault(fam, []).append(a - r)
|
||||
|
||||
print(f" {'phase':>6}", end="")
|
||||
for fam in families:
|
||||
print(f" {fam:>10}", end="")
|
||||
print()
|
||||
for ph in sorted(phase_deltas.keys()):
|
||||
print(f" {ph:>6}", end="")
|
||||
for fam in families:
|
||||
vals = phase_deltas[ph].get(fam, [])
|
||||
if vals:
|
||||
print(f" {np.mean(vals):+10.4f}", end="")
|
||||
else:
|
||||
print(f" {'—':>10}", end="")
|
||||
print()
|
||||
|
||||
# Interpretation
|
||||
print("\n=== Interpretation ===")
|
||||
print(" Prediction under training-artifact hypothesis:")
|
||||
print(" delta(Phase E) > delta(Phase C) for projection families")
|
||||
print(" → late layers have more rotation-gauge-hidden structure")
|
||||
print(" → specialization is partly training noise, not structural")
|
||||
print("")
|
||||
print(" Prediction under real-specialization hypothesis:")
|
||||
print(" delta ~ 0 everywhere")
|
||||
print(" → layers genuinely point in different directions, gauge irrelevant")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
168
sa-schedule-analyze-grams.py
Normal file
168
sa-schedule-analyze-grams.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Analyze operator-level inter-layer alignment from the grams + eigdirs files.
|
||||
|
||||
Input:
|
||||
qwen3-4b-grams.json (gram[L,L',h], fro_sq[L,h])
|
||||
qwen3-4b-grams-eigdirs.pt (eig_dirs[L,h,topk,hidden], sym_eigs[L,h,2*head_dim])
|
||||
|
||||
Questions:
|
||||
(a) Operator cos-sim between layers. cos(g_L^h, g_L'^h) = gram / √(fro_sq fro_sq').
|
||||
If ~1 → same operator up to scalar. If low → distinct operators.
|
||||
(b) Scalar-rescale residual using full operator (not spectrum):
|
||||
optimal T = gram / fro_sq', residual_frac = √(1 - cos²).
|
||||
(c) Curvature-sign alignment. For each (L, anchor) pair, what fraction of
|
||||
top-k signed eigenvalues share sign with the anchor's?
|
||||
(d) Top-k eigensubspace alignment. Principal angles between span{eig_dirs_L}
|
||||
and span{eig_dirs_anchor}.
|
||||
|
||||
Compare: operator cos-sim vs spectral cos-sim (from prior analysis). The
|
||||
sheaf-rs finding was that spectral shape converges across layers while
|
||||
eigenvectors don't. We want to confirm/refute that within QK in Qwen3-4B.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("gram_json")
|
||||
ap.add_argument("--anchor", type=int, default=-1,
|
||||
help="anchor layer index; -1 = last")
|
||||
args = ap.parse_args()
|
||||
|
||||
with open(args.gram_json) as f:
|
||||
d = json.load(f)
|
||||
num_layers = d["num_layers"]
|
||||
num_heads = d["num_heads"]
|
||||
head_dim = d["head_dim"]
|
||||
hidden = d["hidden_size"]
|
||||
gram = np.array(d["gram"]) # (L, L, H)
|
||||
# NOTE: fro_sq from the json is ||W_K W_Q^T||_F^2 (the measure.py
|
||||
# shortcut), NOT ||g||_F^2 = ||W_K^T W_Q||_F^2 which is what the gram
|
||||
# diagonal gives. Different objects. Use gram diagonal for normalization.
|
||||
diag_sq = np.array([[gram[L, L, h] for h in range(num_heads)]
|
||||
for L in range(num_layers)]) # (L, H)
|
||||
diag = np.sqrt(np.maximum(diag_sq, 1e-20)) # ||g_L^h||_F
|
||||
|
||||
pt = torch.load(d.get("eigdirs_path", args.gram_json.replace(".json", "-eigdirs.pt")),
|
||||
weights_only=True)
|
||||
eig_dirs = pt["eig_dirs"].double().numpy() # (L, H, topk, hidden)
|
||||
sym_eigs = pt["sym_eigs"].double().numpy() # (L, H, 2*head_dim)
|
||||
topk = eig_dirs.shape[2]
|
||||
anchor = args.anchor if args.anchor >= 0 else num_layers - 1
|
||||
|
||||
# ==========================================================
|
||||
# (a) Operator cos-sim matrix, averaged over heads
|
||||
# ==========================================================
|
||||
cos_mat = np.zeros((num_layers, num_layers))
|
||||
for L in range(num_layers):
|
||||
for Lp in range(num_layers):
|
||||
denom = diag[L] * diag[Lp]
|
||||
per_h = gram[L, Lp] / np.maximum(denom, 1e-20)
|
||||
cos_mat[L, Lp] = per_h.mean()
|
||||
|
||||
print(f"=== (a) Operator cos-sim between layers, averaged over {num_heads} heads ===")
|
||||
print(f" diagonal (should be 1.0): mean {np.diag(cos_mat).mean():.4f}")
|
||||
# Adjacent-layer cos-sim
|
||||
adj = np.array([cos_mat[L, L+1] for L in range(num_layers-1)])
|
||||
print(f" adjacent layers cos-sim: mean {adj.mean():.4f} min {adj.min():.4f} max {adj.max():.4f}")
|
||||
# Layer-to-anchor cos-sim
|
||||
to_anchor = cos_mat[:, anchor]
|
||||
print(f" layer -> anchor L={anchor} cos-sim:")
|
||||
print(f" {'L':>3} {'cos':>7} {'T_opt':>7} {'resid_frac':>10}")
|
||||
for L in range(num_layers):
|
||||
c = to_anchor[L]
|
||||
T = float(np.mean(gram[L, anchor] / np.maximum(diag_sq[anchor], 1e-20)))
|
||||
r = float(np.sqrt(max(1.0 - c**2, 0.0)))
|
||||
print(f" {L:>3} {c:+.4f} {T:+7.3f} {r:>10.4f}")
|
||||
|
||||
# Long-range cos-sim (L=0 to L=35 vs L=17 to L=35 etc.)
|
||||
print(f"\n long-range: cos(L=0, last) = {cos_mat[0, -1]:+.3f} "
|
||||
f"cos(L=midish, last) = {cos_mat[num_layers//2, -1]:+.3f}")
|
||||
|
||||
# ==========================================================
|
||||
# (b) Full scalar-rescale residual using the gram
|
||||
# ==========================================================
|
||||
print(f"\n=== (b) Operator-level scalar rescale to anchor L={anchor} ===")
|
||||
# residual_frac² = 1 - cos²(g_L, g_anchor) (per head)
|
||||
print(f" {'L':>3} {'mean_cos':>9} {'mean_resid':>10}")
|
||||
for L in range(num_layers):
|
||||
per_h_cos = gram[L, anchor] / np.maximum(diag[L] * diag[anchor], 1e-20)
|
||||
per_h_resid = np.sqrt(np.clip(1.0 - per_h_cos**2, 0.0, 1.0))
|
||||
print(f" {L:>3} {per_h_cos.mean():>+9.4f} {per_h_resid.mean():>10.4f}")
|
||||
|
||||
# ==========================================================
|
||||
# (c) Curvature-sign alignment
|
||||
# ==========================================================
|
||||
print(f"\n=== (c) Curvature-sign alignment vs anchor L={anchor} ===")
|
||||
# Look at top-k eigenvalues by magnitude (already sorted that way in measure).
|
||||
# Fraction of top-k (L, h) whose sign matches the anchor's i-th eigenvalue.
|
||||
for k_use in [2, 4, 8, 16, 32, 64, 128, 256]:
|
||||
if k_use > sym_eigs.shape[-1]:
|
||||
continue
|
||||
# sign of top-k_use eigenvalues at layer L vs at anchor, per (L, h)
|
||||
sign_L = np.sign(sym_eigs[:, :, :k_use]) # (L, H, k_use)
|
||||
sign_a = np.sign(sym_eigs[anchor, :, :k_use]) # (H, k_use)
|
||||
agree = (sign_L == sign_a[None, :, :]).mean(axis=-1) # (L, H)
|
||||
print(f" top-{k_use:>3} signs: mean agree = {agree.mean():.3f} "
|
||||
f"by layer range: early {agree[:12].mean():.3f} "
|
||||
f"mid {agree[12:24].mean():.3f} late {agree[24:].mean():.3f}")
|
||||
|
||||
# Also: distribution of sign-balance per layer (fraction positive eigenvalues)
|
||||
frac_pos = (sym_eigs[:, :, :2 * head_dim] > 0).mean(axis=(1, 2))
|
||||
print(f"\n fraction positive eigenvalues per layer:")
|
||||
for L in range(num_layers):
|
||||
print(f" L={L:2} frac+ = {frac_pos[L]:.3f}")
|
||||
|
||||
# ==========================================================
|
||||
# (d) Eigenspace principal angles
|
||||
# ==========================================================
|
||||
print(f"\n=== (d) Top-{topk} eigensubspace principal angles vs anchor L={anchor} ===")
|
||||
# Per-head: cos of principal angles between row-spans of eig_dirs[L, h]
|
||||
# and eig_dirs[anchor, h]. Report mean cos angle per layer.
|
||||
print(f" {'L':>3} {'meanCosPA':>10} {'minCosPA':>10} {'max_top1':>10}")
|
||||
for L in range(num_layers):
|
||||
mean_cos_pa_per_h = []
|
||||
min_cos_pa_per_h = []
|
||||
top1_overlap = []
|
||||
for h in range(num_heads):
|
||||
A = eig_dirs[L, h] # (topk, hidden) rows are unit vectors
|
||||
B = eig_dirs[anchor, h] # (topk, hidden)
|
||||
# Orthonormalize rows (they're close-to-orthonormal but not exactly)
|
||||
Qa, _ = np.linalg.qr(A.T) # hidden × topk
|
||||
Qb, _ = np.linalg.qr(B.T)
|
||||
M = Qa.T @ Qb # topk × topk
|
||||
s = np.linalg.svd(M, compute_uv=False)
|
||||
mean_cos_pa_per_h.append(s.mean())
|
||||
min_cos_pa_per_h.append(s.min())
|
||||
# |<a_0, b_0>|² — top-1 eigenvector overlap
|
||||
top1_overlap.append(float((A[0] @ B[0]) ** 2))
|
||||
print(f" {L:>3} {np.mean(mean_cos_pa_per_h):>10.4f} "
|
||||
f"{np.mean(min_cos_pa_per_h):>10.4f} "
|
||||
f"{np.mean(top1_overlap):>10.4f}")
|
||||
|
||||
# ==========================================================
|
||||
# Verdict
|
||||
# ==========================================================
|
||||
to_anchor_per_head = np.array([
|
||||
(gram[L, anchor] / np.maximum(diag[L] * diag[anchor], 1e-20)).mean()
|
||||
for L in range(num_layers)
|
||||
])
|
||||
mean_cos_to_anchor = to_anchor_per_head.mean()
|
||||
print(f"\n=== Verdict ===")
|
||||
print(f" mean operator cos-sim to anchor: {mean_cos_to_anchor:+.4f}")
|
||||
adj_mean = adj.mean()
|
||||
print(f" mean operator cos-sim adjacent layers: {adj_mean:+.4f}")
|
||||
if mean_cos_to_anchor > 0.9:
|
||||
print(" STRONG: same operator up to scalar across all layers.")
|
||||
elif mean_cos_to_anchor > 0.5:
|
||||
print(" MEDIUM: substantial shared operator, but layer-specific drift.")
|
||||
elif mean_cos_to_anchor > 0.1:
|
||||
print(" WEAK: some alignment; far from single-operator interpretation.")
|
||||
else:
|
||||
print(" REJECTED: operators are effectively orthogonal across layers.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
108
sa-schedule-analyze.py
Normal file
108
sa-schedule-analyze.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""Analyze the SA schedule readout JSON: per-head variance, static/dynamic
|
||||
correlation, and a plot."""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input_json")
|
||||
ap.add_argument("--out-plot", default="/tmp/sa-schedule.png")
|
||||
args = ap.parse_args()
|
||||
|
||||
with open(args.input_json) as f:
|
||||
data = json.load(f)
|
||||
|
||||
num_layers = data["num_layers"]
|
||||
num_heads = data["num_heads"]
|
||||
Ls = np.arange(num_layers)
|
||||
|
||||
ent = np.array([row["mean_attention_entropy_per_head"] for row in data["dynamic"]]) # (L, H)
|
||||
logit_std = np.array([row["mean_logit_std_per_head"] for row in data["dynamic"]]) # (L, H)
|
||||
metric_op = np.array([row["metric_op_per_head"] for row in data["static"]]) # (L, H)
|
||||
metric_fro = np.array([row["metric_fro_per_head"] for row in data["static"]])
|
||||
|
||||
mean_ent = ent.mean(axis=1)
|
||||
std_ent = ent.std(axis=1)
|
||||
mean_logit = logit_std.mean(axis=1)
|
||||
std_logit = logit_std.std(axis=1)
|
||||
mean_metric = metric_op.mean(axis=1)
|
||||
std_metric = metric_op.std(axis=1)
|
||||
|
||||
# Per-head variance summary
|
||||
print("\nPer-head variance across heads (coefficient of variation = std/mean):")
|
||||
print(f" entropy: mean CV = {(std_ent / np.maximum(mean_ent, 1e-6)).mean():.3f}")
|
||||
print(f" logit_std: mean CV = {(std_logit / np.maximum(mean_logit, 1e-6)).mean():.3f}")
|
||||
print(f" metric_op: mean CV = {(std_metric / np.maximum(mean_metric, 1e-6)).mean():.3f}")
|
||||
|
||||
# Correlations across layers
|
||||
corr_ent_metric = np.corrcoef(mean_ent, mean_metric)[0, 1]
|
||||
corr_logit_metric = np.corrcoef(mean_logit, mean_metric)[0, 1]
|
||||
corr_ent_logit = np.corrcoef(mean_ent, mean_logit)[0, 1]
|
||||
print("\nAcross-layer Pearson correlations (averaged over heads):")
|
||||
print(f" entropy vs metric_op: {corr_ent_metric:+.3f}")
|
||||
print(f" logit_std vs metric_op: {corr_logit_metric:+.3f}")
|
||||
print(f" entropy vs logit_std: {corr_ent_logit:+.3f}")
|
||||
|
||||
# Per-head correlation (one value per head): does each head's entropy
|
||||
# across layers track its own metric_op across layers?
|
||||
head_corrs = []
|
||||
for h in range(num_heads):
|
||||
c = np.corrcoef(ent[:, h], metric_op[:, h])[0, 1]
|
||||
if np.isfinite(c):
|
||||
head_corrs.append(c)
|
||||
print(f" per-head entropy vs metric_op: mean {np.mean(head_corrs):+.3f} "
|
||||
f"std {np.std(head_corrs):.3f} min {min(head_corrs):+.3f} max {max(head_corrs):+.3f}")
|
||||
|
||||
# Plot
|
||||
fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
|
||||
|
||||
ax = axes[0]
|
||||
ax.fill_between(Ls, mean_ent - std_ent, mean_ent + std_ent, alpha=0.2, color="tab:blue",
|
||||
label="±1 std across heads")
|
||||
ax.plot(Ls, mean_ent, color="tab:blue", marker="o", label="mean entropy")
|
||||
ax.set_ylabel("attention entropy (nats)")
|
||||
ax.set_title(f"{data['model']} — SA schedule readout ({num_layers} layers, {num_heads} heads)")
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid(alpha=0.3)
|
||||
|
||||
ax = axes[1]
|
||||
ax.fill_between(Ls, mean_logit - std_logit, mean_logit + std_logit, alpha=0.2, color="tab:orange",
|
||||
label="±1 std across heads")
|
||||
ax.plot(Ls, mean_logit, color="tab:orange", marker="o", label="mean logit std")
|
||||
ax.set_ylabel("pre-softmax logit std\n(= implicit sharpness)")
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid(alpha=0.3)
|
||||
|
||||
ax = axes[2]
|
||||
ax.fill_between(Ls, mean_metric - std_metric, mean_metric + std_metric, alpha=0.2, color="tab:green",
|
||||
label="±1 std across heads")
|
||||
ax.plot(Ls, mean_metric, color="tab:green", marker="o", label="mean metric op-norm (static)")
|
||||
ax.set_ylabel("||W_K^T W_Q|| operator norm\n(static, parameter-only)")
|
||||
ax.set_xlabel("layer index L")
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid(alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.out_plot, dpi=100, bbox_inches="tight")
|
||||
print(f"\nWrote plot to {args.out_plot}")
|
||||
|
||||
# Also save a small heatmap of per-head entropy for visual spread
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.imshow(ent.T, aspect="auto", cmap="viridis", origin="lower")
|
||||
plt.colorbar(label="attention entropy")
|
||||
plt.xlabel("layer L")
|
||||
plt.ylabel("head h")
|
||||
plt.title(f"{data['model']} — per-head entropy heatmap")
|
||||
heatmap_path = args.out_plot.replace(".png", "-heatmap.png")
|
||||
plt.tight_layout()
|
||||
plt.savefig(heatmap_path, dpi=100, bbox_inches="tight")
|
||||
print(f"Wrote heatmap to {heatmap_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
234
sa-schedule-delta-svd.py
Normal file
234
sa-schedule-delta-svd.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Per-layer residual-stream delta SVD: δ_L = h_{L+1} - h_L stacked
|
||||
over all tokens in a calibration set. SVD gives us:
|
||||
|
||||
- top singular value per layer → γ_L (scalar magnitude, what Kirkpatrick fit)
|
||||
- top right-singular-vector per layer → v_L (direction in hidden space)
|
||||
- effective rank per layer → is this one direction or many?
|
||||
- pairwise v_L cos-sim across layers → are layers subspace-disjoint or -shared?
|
||||
|
||||
This directly tests the anisotropic-SA hypothesis:
|
||||
h_{L+1} = h_L + T_shared(h_L) + γ_L · v_L · f(...)
|
||||
|
||||
Phase C prediction: v_L vectors cover broad shared subspace (high mutual cos-sim,
|
||||
rank-few overall), δ_L is mostly noise around a shared update.
|
||||
Phase E prediction: v_L vectors are specialized (low pairwise cos-sim, each layer
|
||||
its own direction), effective rank of the block is close to N.
|
||||
|
||||
Qwen3-32B phases: A 0-6, B 7-9, C 10-31, D 32-46, E 47-58, tail 59-63.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
CALIB = [
|
||||
"The Eiffel Tower is located in",
|
||||
"Photosynthesis is the process by which",
|
||||
"The three branches of the US government are the legislative, executive, and",
|
||||
"If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
|
||||
"Solve for x: 3x + 7 = 22. The answer is x =",
|
||||
"The derivative of x^3 + 2x^2 is",
|
||||
"def fibonacci(n):\n if n < 2:\n return n\n return",
|
||||
"# Python list comprehension to square even numbers in 0-9\nresult = ",
|
||||
"SELECT name, age FROM users WHERE",
|
||||
"She opened the old wooden box and found",
|
||||
"The argument in favor of renewable energy is",
|
||||
"User: What is the capital of Australia?\nAssistant:",
|
||||
"Write a haiku about autumn:\n",
|
||||
"Albert Einstein was born in the year",
|
||||
"The speed of light in vacuum is approximately",
|
||||
"I really loved that movie because",
|
||||
"The main difference between a virus and a bacterium is",
|
||||
"The French word for 'apple' is",
|
||||
"1 + 1 = ",
|
||||
"Once upon a time, in a land far away,",
|
||||
"The key insight of general relativity is that gravity is not a force but",
|
||||
"Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher",
|
||||
"In object-oriented programming, encapsulation refers to",
|
||||
"The mitochondria is often called the powerhouse of the cell because it",
|
||||
"Shakespeare's Hamlet begins with the famous line",
|
||||
]
|
||||
|
||||
|
||||
def phase_of(L, num_layers):
|
||||
if num_layers == 64:
|
||||
if L <= 6: return "A"
|
||||
if L <= 9: return "B"
|
||||
if L <= 31: return "C"
|
||||
if L <= 46: return "D"
|
||||
if L <= 58: return "E"
|
||||
return "tail"
|
||||
frac = L / num_layers
|
||||
if frac < 0.11: return "A"
|
||||
if frac < 0.15: return "B"
|
||||
if frac < 0.5: return "C"
|
||||
if frac < 0.75: return "D"
|
||||
if frac < 0.92: return "E"
|
||||
return "tail"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-32B")
|
||||
ap.add_argument("--out", default="/tmp/delta-svd.json")
|
||||
ap.add_argument("--top-k", type=int, default=8,
|
||||
help="keep top-k singular values / directions per layer")
|
||||
args = ap.parse_args()
|
||||
|
||||
print(f"Loading {args.model} ...", flush=True)
|
||||
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model, torch_dtype=torch.bfloat16, device_map="cuda",
|
||||
trust_remote_code=True, attn_implementation="eager",
|
||||
).eval()
|
||||
num_layers = model.config.num_hidden_layers
|
||||
hidden = model.config.hidden_size
|
||||
print(f" L={num_layers}, hidden={hidden}", flush=True)
|
||||
|
||||
# Concat calib and tokenize as one stream
|
||||
text = "\n\n".join(CALIB)
|
||||
enc = tok(text, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
|
||||
n_tok = enc.input_ids.shape[1]
|
||||
print(f" calibration tokens: {n_tok}", flush=True)
|
||||
|
||||
out = model(**enc, output_hidden_states=True, use_cache=False)
|
||||
# hidden_states: tuple of (num_layers+1) tensors, each (1, n_tok, hidden)
|
||||
hs = [h[0].float().cpu().numpy() for h in out.hidden_states]
|
||||
# hs[L] = residual stream entering layer L (or leaving layer L-1). So
|
||||
# δ_L = hs[L+1] - hs[L] is layer L's contribution.
|
||||
print(f" hidden_states count: {len(hs)} (expect {num_layers+1})", flush=True)
|
||||
del model, out
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Per-layer SVD
|
||||
per_layer = []
|
||||
for L in range(num_layers):
|
||||
delta = hs[L+1] - hs[L] # (n_tok, hidden)
|
||||
h_in = hs[L] # (n_tok, hidden)
|
||||
# Remove BOS / first-token artifacts (often outlier)
|
||||
delta = delta[1:]
|
||||
h_in = h_in[1:]
|
||||
n, d = delta.shape
|
||||
|
||||
# Norm per token
|
||||
token_norms = np.linalg.norm(delta, axis=1) # (n,)
|
||||
h_norms = np.linalg.norm(h_in, axis=1) # (n,)
|
||||
# Relative step size: ||δ_L|| / ||h_L||
|
||||
rel_step = (token_norms / np.maximum(h_norms, 1e-8))
|
||||
# Angle between δ and h, per token: cos = <δ, h> / (||δ||||h||)
|
||||
dot = np.einsum("nd,nd->n", delta, h_in)
|
||||
cos_delta_h = dot / np.maximum(token_norms * h_norms, 1e-8)
|
||||
# "Parallel" component: how much of δ points along ±h
|
||||
parallel_frac = np.abs(cos_delta_h).mean()
|
||||
|
||||
# SVD in economy mode (on CPU; 2047x5120 fits easy)
|
||||
U, S, Vt = np.linalg.svd(delta, full_matrices=False)
|
||||
# S: singular values, decreasing. Vt: right singular vectors (directions).
|
||||
|
||||
# Effective rank (entropy of normalized squared SVs)
|
||||
p = S**2 / (S**2).sum()
|
||||
p_nz = p[p > 1e-12]
|
||||
eff_rank = float(np.exp(-(p_nz * np.log(p_nz)).sum()))
|
||||
|
||||
# Energy concentration
|
||||
top1_frac = float(p[0])
|
||||
top3_frac = float(p[:3].sum())
|
||||
top10_frac = float(p[:min(10, len(p))].sum())
|
||||
|
||||
per_layer.append({
|
||||
"L": L,
|
||||
"phase": phase_of(L, num_layers),
|
||||
"frob": float(np.linalg.norm(delta)),
|
||||
"token_norm_mean": float(token_norms.mean()),
|
||||
"token_norm_std": float(token_norms.std()),
|
||||
"h_norm_mean": float(h_norms.mean()),
|
||||
"rel_step_mean": float(rel_step.mean()),
|
||||
"rel_step_std": float(rel_step.std()),
|
||||
"parallel_frac": float(parallel_frac),
|
||||
"cos_delta_h_mean": float(cos_delta_h.mean()),
|
||||
"top_singvals": S[:args.top_k].tolist(),
|
||||
"top_dirs": Vt[:args.top_k].astype(np.float32).tolist(),
|
||||
"eff_rank": eff_rank,
|
||||
"top1_frac": top1_frac,
|
||||
"top3_frac": top3_frac,
|
||||
"top10_frac": top10_frac,
|
||||
})
|
||||
print(f" L={L:>2} phase={phase_of(L, num_layers):>4} "
|
||||
f"||h||={h_norms.mean():>7.1f} "
|
||||
f"||δ||={token_norms.mean():>7.2f} "
|
||||
f"rel={rel_step.mean():.4f} "
|
||||
f"‖parallel‖={parallel_frac:.4f} "
|
||||
f"eff_rank={eff_rank:>6.2f}",
|
||||
flush=True)
|
||||
|
||||
# Pairwise cos-sim of top-1 directions across layers
|
||||
top1_dirs = np.array([pl["top_dirs"][0] for pl in per_layer]) # (L, d)
|
||||
top1_cos = top1_dirs @ top1_dirs.T # (L, L)
|
||||
|
||||
# Subspace principal angles: project each layer's top-k into others' span
|
||||
print(f"\n=== Pairwise top-1 cos-sim (adjacent) ===")
|
||||
for L in range(num_layers - 1):
|
||||
print(f" L={L:>2}→{L+1:>2} phase={phase_of(L, num_layers):>4} "
|
||||
f"|cos|={abs(top1_cos[L, L+1]):>.4f}")
|
||||
|
||||
# Per-phase summary: mean |cos| within phase vs cross-phase
|
||||
phase_members = {}
|
||||
for L in range(num_layers):
|
||||
phase_members.setdefault(phase_of(L, num_layers), []).append(L)
|
||||
|
||||
print(f"\n=== Per-phase top-1 direction overlap ===")
|
||||
print(f" {'phase':>6} {'N':>3} {'intra_cos_mean':>14} {'cross_cos_mean':>14}")
|
||||
for ph, Ls in phase_members.items():
|
||||
intra = abs(top1_cos[np.ix_(Ls, Ls)])
|
||||
if len(Ls) >= 2:
|
||||
intra_vals = intra[np.triu_indices(len(Ls), k=1)]
|
||||
intra_mean = float(intra_vals.mean())
|
||||
else:
|
||||
intra_mean = 1.0
|
||||
other_Ls = [L for L in range(num_layers) if L not in Ls]
|
||||
if other_Ls:
|
||||
cross = abs(top1_cos[np.ix_(Ls, other_Ls)])
|
||||
cross_mean = float(cross.mean())
|
||||
else:
|
||||
cross_mean = 0.0
|
||||
print(f" {ph:>6} {len(Ls):>3} {intra_mean:>14.4f} {cross_mean:>14.4f}")
|
||||
|
||||
# Subspace overlap: for each phase, find the block's overall principal subspace
|
||||
# and measure how much of each individual layer sits in it.
|
||||
print(f"\n=== Block-shared subspace (rank-8) capture fraction per layer ===")
|
||||
for ph, Ls in phase_members.items():
|
||||
if len(Ls) < 2:
|
||||
continue
|
||||
# Stack top-k directions from all layers in phase
|
||||
block_dirs = np.concatenate([per_layer[L]["top_dirs"] for L in Ls], axis=0)
|
||||
# SVD to get the shared basis of the union
|
||||
U_b, S_b, Vt_b = np.linalg.svd(block_dirs, full_matrices=False)
|
||||
shared_basis = Vt_b[:8] # top-8 shared directions of the block's top-k union
|
||||
# Project each layer's top-1 direction and measure capture
|
||||
for L in Ls:
|
||||
v1 = np.array(per_layer[L]["top_dirs"][0])
|
||||
capture = float((shared_basis @ v1).__pow__(2).sum())
|
||||
print(f" phase={ph:>4} L={L:>2} v1 captured by block top-8: {capture:.4f}")
|
||||
|
||||
# Save
|
||||
save = {
|
||||
"model": args.model,
|
||||
"num_layers": num_layers,
|
||||
"hidden": hidden,
|
||||
"n_calib_tokens": int(n_tok),
|
||||
"per_layer": [
|
||||
{k: v for k, v in pl.items() if k != "top_dirs"} # directions too big
|
||||
for pl in per_layer
|
||||
],
|
||||
"top1_cos_adjacent": [float(top1_cos[L, L+1]) for L in range(num_layers-1)],
|
||||
}
|
||||
with open(args.out, "w") as f:
|
||||
json.dump(save, f, indent=2)
|
||||
print(f"\nSaved: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
214
sa-schedule-derive-from-last.py
Normal file
214
sa-schedule-derive-from-last.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
"""Under the SA-schedule hypothesis, earlier layers should be approximately
|
||||
a temperature-rescaled version of a shared operator. The simplest test:
|
||||
pick the last layer's per-head metric spectrum as anchor, and ask whether
|
||||
earlier layers' spectra are scalar rescales of it.
|
||||
|
||||
Three experiments on the existing per-head singular values:
|
||||
|
||||
(1) Spectral shape invariance. For each head h, normalize σ_L^h by σ_max
|
||||
and compare the shape vector across layers. If shapes match, scale is
|
||||
the only free parameter.
|
||||
|
||||
(2) Scalar rescale fit. For each (L, h), find T_L^h minimizing
|
||||
||σ_L^h - T_L^h σ_last^h||². Optimal T_L^h = <σ_L^h, σ_last^h>/||σ_last^h||².
|
||||
Report residual = ||σ_L^h - T_L^h σ_last^h|| / ||σ_L^h||.
|
||||
|
||||
(3) Cross-head sharing. If the *shape* is the same across heads too (not
|
||||
just across layers), we could use a single anchor per *layer* (last
|
||||
layer, one head) and reconstruct everything. Report mean shape
|
||||
correlation across heads within a layer.
|
||||
|
||||
The anchor doesn't have to be the last layer — we also try: last layer,
|
||||
middle layer, per-layer-group best match. Purpose is not to pick the best
|
||||
anchor but to understand which choice lets reconstruction succeed.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pad_to(arr, n):
|
||||
"""Pad a 1D array to length n with zeros (for heads of different rank)."""
|
||||
if arr.shape[0] == n:
|
||||
return arr
|
||||
out = np.zeros(n, dtype=arr.dtype)
|
||||
out[:arr.shape[0]] = arr
|
||||
return out
|
||||
|
||||
|
||||
def collect_spectra(data):
|
||||
"""Return array sigma[L, h, k] of singular values, padded."""
|
||||
num_layers = data["num_layers"]
|
||||
num_heads = data["num_heads"]
|
||||
# Determine max rank across all heads
|
||||
max_k = 0
|
||||
for row in data["static"]:
|
||||
for s in row["metric_singvals_per_head"]:
|
||||
max_k = max(max_k, len(s))
|
||||
sigma = np.zeros((num_layers, num_heads, max_k), dtype=np.float64)
|
||||
for L, row in enumerate(data["static"]):
|
||||
for h, s in enumerate(row["metric_singvals_per_head"]):
|
||||
sigma[L, h, :len(s)] = s
|
||||
return sigma
|
||||
|
||||
|
||||
def scalar_rescale_fit(x, y):
|
||||
"""Optimal scalar T s.t. ||x - T y|| is minimized.
|
||||
Returns (T, residual_frac) where residual_frac = ||x - T y|| / ||x||.
|
||||
"""
|
||||
denom = float((y * y).sum())
|
||||
if denom < 1e-20:
|
||||
return 0.0, 1.0
|
||||
T = float((x * y).sum() / denom)
|
||||
resid = x - T * y
|
||||
rn = float(np.linalg.norm(resid))
|
||||
xn = float(np.linalg.norm(x))
|
||||
return T, (rn / xn if xn > 1e-20 else 0.0)
|
||||
|
||||
|
||||
def cos_sim(x, y):
|
||||
xn = float(np.linalg.norm(x))
|
||||
yn = float(np.linalg.norm(y))
|
||||
if xn < 1e-20 or yn < 1e-20:
|
||||
return 0.0
|
||||
return float((x * y).sum() / (xn * yn))
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input_json")
|
||||
ap.add_argument("--anchor", choices=["last", "middle", "best"], default="last")
|
||||
args = ap.parse_args()
|
||||
|
||||
with open(args.input_json) as f:
|
||||
data = json.load(f)
|
||||
|
||||
num_layers = data["num_layers"]
|
||||
num_heads = data["num_heads"]
|
||||
sigma = collect_spectra(data) # (L, H, K)
|
||||
print(f"Loaded sigma: shape {sigma.shape}, max rank {sigma.shape[-1]}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Experiment 1: spectral shape invariance across layers (per head)
|
||||
# ------------------------------------------------------------------
|
||||
print("\n=== (1) Spectral shape invariance across layers ===")
|
||||
# For each head, compute normalized shape σ / σ_max per layer; measure
|
||||
# mean pairwise cosine similarity of shapes across layers.
|
||||
shape = np.zeros_like(sigma)
|
||||
for L in range(num_layers):
|
||||
for h in range(num_heads):
|
||||
s = sigma[L, h]
|
||||
mx = s.max()
|
||||
shape[L, h] = s / mx if mx > 1e-20 else s
|
||||
|
||||
per_head_cos = np.zeros(num_heads)
|
||||
for h in range(num_heads):
|
||||
cs = []
|
||||
for L1 in range(num_layers):
|
||||
for L2 in range(L1 + 1, num_layers):
|
||||
cs.append(cos_sim(shape[L1, h], shape[L2, h]))
|
||||
per_head_cos[h] = np.mean(cs)
|
||||
print(f" per-head mean pairwise cosine of shape across layers:")
|
||||
print(f" mean {per_head_cos.mean():.4f} std {per_head_cos.std():.4f} "
|
||||
f"min {per_head_cos.min():.4f} max {per_head_cos.max():.4f}")
|
||||
# If mean > ~0.99 → shapes identical, pure scalar rescale works
|
||||
# If mean ~ 0.85-0.95 → close but structure changes layer-to-layer
|
||||
# If mean < 0.8 → shape varies meaningfully, scalar rescale insufficient
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Experiment 2: scalar rescale fit to an anchor layer
|
||||
# ------------------------------------------------------------------
|
||||
if args.anchor == "last":
|
||||
anchor_L = num_layers - 1
|
||||
elif args.anchor == "middle":
|
||||
anchor_L = num_layers // 2
|
||||
else: # best: pick layer whose shape is most typical (highest mean cos
|
||||
# to all other layers)
|
||||
best_score = -1.0
|
||||
anchor_L = num_layers - 1
|
||||
for Lc in range(num_layers):
|
||||
score = 0.0
|
||||
for h in range(num_heads):
|
||||
for L in range(num_layers):
|
||||
if L == Lc:
|
||||
continue
|
||||
score += cos_sim(shape[Lc, h], shape[L, h])
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
anchor_L = Lc
|
||||
print(f" [auto-anchor] best layer by total shape-cosine: L={anchor_L}")
|
||||
|
||||
print(f"\n=== (2) Scalar rescale fit to anchor L={anchor_L} ===")
|
||||
T_map = np.zeros((num_layers, num_heads))
|
||||
resid_map = np.zeros((num_layers, num_heads))
|
||||
for L in range(num_layers):
|
||||
for h in range(num_heads):
|
||||
T, r = scalar_rescale_fit(sigma[L, h], sigma[anchor_L, h])
|
||||
T_map[L, h] = T
|
||||
resid_map[L, h] = r
|
||||
|
||||
# Per-layer residual stats
|
||||
print(f" per-layer residual fraction ||σ_L^h - T σ_anchor^h|| / ||σ_L^h||:")
|
||||
print(f" {'L':>3} {'mean resid':>10} {'max resid':>10} {'mean T':>8}")
|
||||
for L in range(num_layers):
|
||||
rl = resid_map[L]
|
||||
tl = T_map[L]
|
||||
print(f" {L:>3} {rl.mean():>10.4f} {rl.max():>10.4f} {tl.mean():>8.3f}")
|
||||
|
||||
print(f"\n overall mean residual: {resid_map.mean():.4f}")
|
||||
print(f" overall max residual: {resid_map.max():.4f}")
|
||||
print(f" frac of (L,h) with resid < 0.10: "
|
||||
f"{(resid_map < 0.10).mean():.3f}")
|
||||
print(f" frac of (L,h) with resid < 0.20: "
|
||||
f"{(resid_map < 0.20).mean():.3f}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Experiment 2b: does T match per-head dynamic entropy?
|
||||
# ------------------------------------------------------------------
|
||||
ent = np.array([row["mean_attention_entropy_per_head"]
|
||||
for row in data["dynamic"]]) # (L, H)
|
||||
# T is a scalar temperature of the metric. Geometrically, higher T means
|
||||
# sharper attention (smaller entropy). So corr(T, entropy) should be negative
|
||||
# if the scalar rescale captures the temperature schedule.
|
||||
from numpy import corrcoef
|
||||
c = float(corrcoef(T_map.flatten(), ent.flatten())[0, 1])
|
||||
print(f"\n correlation corr(T_L^h, entropy_L^h) = {c:+.3f} "
|
||||
f"(negative expected: larger T → sharper → lower entropy)")
|
||||
|
||||
# Also try: does T predict entropy *better* than raw op_norm? (Already had
|
||||
# op_norm r=+0.45 in geometry analysis.)
|
||||
op_norm = sigma.max(axis=-1) # (L, H)
|
||||
c_op = float(corrcoef(op_norm.flatten(), ent.flatten())[0, 1])
|
||||
print(f" for comparison, corr(op_norm, entropy) = {c_op:+.3f}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Experiment 3: shape similarity across heads within a layer
|
||||
# ------------------------------------------------------------------
|
||||
print(f"\n=== (3) Cross-head shape similarity within each layer ===")
|
||||
print(f" {'L':>3} {'mean pair-cos':>14}")
|
||||
for L in range(num_layers):
|
||||
cs = []
|
||||
for h1 in range(num_heads):
|
||||
for h2 in range(h1 + 1, num_heads):
|
||||
cs.append(cos_sim(shape[L, h1], shape[L, h2]))
|
||||
print(f" {L:>3} {np.mean(cs):>14.4f}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Summary
|
||||
# ------------------------------------------------------------------
|
||||
print("\n=== Summary ===")
|
||||
print(f" anchor layer: {anchor_L}")
|
||||
print(f" spectral shape is {'very stable' if per_head_cos.mean() > 0.98 else 'approximately stable' if per_head_cos.mean() > 0.9 else 'not stable'} "
|
||||
f"across layers (per-head mean pairwise cos = {per_head_cos.mean():.3f})")
|
||||
print(f" scalar-rescale fit residual: mean {resid_map.mean():.3f}")
|
||||
if resid_map.mean() < 0.1:
|
||||
verdict = "HYPOTHESIS SUPPORTED — scalar temperature rescale of a shared operator reconstructs earlier layers to within 10% Frobenius residual."
|
||||
elif resid_map.mean() < 0.3:
|
||||
verdict = "PARTIALLY SUPPORTED — scalar rescale captures most of the structure; a low-rank correction on top is likely enough."
|
||||
else:
|
||||
verdict = "HYPOTHESIS REJECTED for pure scalar rescale — spectra differ substantially in shape; need full layer-by-layer operators or rank-k delta."
|
||||
print(f"\n {verdict}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
sa-schedule-fit-gamma.py
Normal file
145
sa-schedule-fit-gamma.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Fit a functional form to the LN γ trajectory across layers; derive the
|
||||
effective attention temperature T(L) from known coupling formulas.
|
||||
|
||||
Rules of what scales with depth (from literature):
|
||||
DeepNorm: α_dec = (2M)^(1/4), β_dec = (8M)^(-1/4). Same per layer — does
|
||||
NOT depend on layer index l. The free variation across layers has to
|
||||
live in LN γ.
|
||||
Depth-μP: block multiplier a/√L, LR η/√L. Same per layer.
|
||||
So γ(L) is the family carrying the per-layer schedule.
|
||||
|
||||
Try fitting forms:
|
||||
γ(L) = a · L^b (power law in layer index)
|
||||
γ(L) = a · exp(b·L) (exponential)
|
||||
γ(L) = a + b·L (linear)
|
||||
γ(L) = a + b·L^c (free c) (power law with free exponent)
|
||||
|
||||
Report fit quality (R², residual statistics), and for the best fit, compute
|
||||
the derived T(L) curve.
|
||||
"""
|
||||
import json
|
||||
import numpy as np
|
||||
from math import log, exp
|
||||
|
||||
|
||||
def fit_power(L, y):
|
||||
"""y ≈ a · L^b → log y ≈ log a + b log L."""
|
||||
mask = (L > 0) & (y > 0)
|
||||
lx, ly = np.log(L[mask]), np.log(y[mask])
|
||||
b, loga = np.polyfit(lx, ly, 1)
|
||||
yhat = np.exp(loga) * (L**b)
|
||||
r2 = 1 - ((y - yhat)**2).sum() / ((y - y.mean())**2).sum()
|
||||
return {"form": "a*L^b", "a": float(np.exp(loga)), "b": float(b), "r2": float(r2), "yhat": yhat}
|
||||
|
||||
|
||||
def fit_exponential(L, y):
|
||||
"""y ≈ a · exp(b·L) → log y ≈ log a + b·L."""
|
||||
mask = y > 0
|
||||
b, loga = np.polyfit(L[mask], np.log(y[mask]), 1)
|
||||
yhat = np.exp(loga) * np.exp(b * L)
|
||||
r2 = 1 - ((y - yhat)**2).sum() / ((y - y.mean())**2).sum()
|
||||
return {"form": "a*exp(b*L)", "a": float(np.exp(loga)), "b": float(b), "r2": float(r2), "yhat": yhat}
|
||||
|
||||
|
||||
def fit_linear(L, y):
|
||||
b, a = np.polyfit(L, y, 1)
|
||||
yhat = a + b * L
|
||||
r2 = 1 - ((y - yhat)**2).sum() / ((y - y.mean())**2).sum()
|
||||
return {"form": "a+b*L", "a": float(a), "b": float(b), "r2": float(r2), "yhat": yhat}
|
||||
|
||||
|
||||
def fit_piecewise_two(L, y):
|
||||
"""Best split point L* and linear fits on each half (log-space)."""
|
||||
best = None
|
||||
for Ls in range(3, len(L) - 3):
|
||||
mA, mB = L < Ls, L >= Ls
|
||||
if (y[mA] <= 0).any() or (y[mB] <= 0).any():
|
||||
continue
|
||||
bA, aA = np.polyfit(L[mA], np.log(y[mA]), 1)
|
||||
bB, aB = np.polyfit(L[mB], np.log(y[mB]), 1)
|
||||
yhat = np.where(mA, np.exp(aA + bA * L), np.exp(aB + bB * L))
|
||||
r2 = 1 - ((y - yhat)**2).sum() / ((y - y.mean())**2).sum()
|
||||
if best is None or r2 > best["r2"]:
|
||||
best = {"form": f"piecewise-exp-split@L={Ls}", "split": int(Ls),
|
||||
"a1": float(np.exp(aA)), "b1": float(bA),
|
||||
"a2": float(np.exp(aB)), "b2": float(bB),
|
||||
"r2": float(r2), "yhat": yhat}
|
||||
return best
|
||||
|
||||
|
||||
def main():
|
||||
d = json.load(open("/tmp/qwen3-4b-null.json"))
|
||||
scales = d["scales"]
|
||||
num_layers = len(scales["input_ln"])
|
||||
L = np.arange(num_layers, dtype=float)
|
||||
|
||||
families_of_interest = ["input_ln", "post_attn_ln", "q_norm", "k_norm",
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
print("=" * 72)
|
||||
print("γ-trajectory fits per family (Qwen3-4B, 36 layers)")
|
||||
print("=" * 72)
|
||||
|
||||
for fam in families_of_interest:
|
||||
y = np.array(scales[fam], dtype=float)
|
||||
print(f"\n--- {fam} ---")
|
||||
print(f" L=0: {y[0]:.3f} L=35: {y[-1]:.3f} ratio: {y[-1]/y[0]:+.2f}×")
|
||||
fits = [
|
||||
fit_linear(L, y),
|
||||
fit_power(L + 1, y), # L+1 so L=0 doesn't explode log
|
||||
fit_exponential(L, y),
|
||||
fit_piecewise_two(L + 1, y),
|
||||
]
|
||||
for f in fits:
|
||||
if f is None:
|
||||
continue
|
||||
extras = ""
|
||||
if "b" in f:
|
||||
extras = f" (a={f['a']:.3g}, b={f['b']:+.4f})"
|
||||
elif "split" in f:
|
||||
extras = f" (split={f['split']}, b1={f['b1']:+.4f}, b2={f['b2']:+.4f})"
|
||||
print(f" {f['form']:<32} R²={f['r2']:+.4f}{extras}")
|
||||
|
||||
# For input_ln specifically: plot the curve (text) and derive T(L)
|
||||
y = np.array(scales["input_ln"], dtype=float)
|
||||
print("\n" + "=" * 72)
|
||||
print("input_ln γ magnitude across layers (the schedule signal)")
|
||||
print("=" * 72)
|
||||
print(f" {'L':>3} {'γ_L':>12} {'γ_L / γ_0':>10} {'log γ_L':>10}")
|
||||
for l_idx in range(num_layers):
|
||||
print(f" {l_idx:>3} {y[l_idx]:>12.3f} {y[l_idx]/y[0]:>10.3f} {log(y[l_idx]):>+10.4f}")
|
||||
|
||||
# Classical SA schedules for comparison
|
||||
# - Linear: T(k) = T0 - k * (T0 - Tf)/N
|
||||
# - Exponential / Kirkpatrick: T(k) = T0 * α^k
|
||||
# - Logarithmic / Hajek: T(k) = c / log(k+2)
|
||||
# For γ (which grows = temperature drops, since larger γ → sharper attention):
|
||||
# γ growing corresponds to T cooling
|
||||
print("\n" + "=" * 72)
|
||||
print("Derived attention-temperature T(L) interpretation")
|
||||
print("=" * 72)
|
||||
print(" Attention logit ∝ (γ * W_Q * W_K * ||residual||²) / √d_head.")
|
||||
print(" With γ_L the schedule dial and other factors ~constant across layers,")
|
||||
print(" effective attention temperature T(L) ∝ 1/γ(L).")
|
||||
print(f"\n T(L)/T(0) = γ(0)/γ(L):")
|
||||
print(f" {'L':>3} {'T(L)/T(0)':>10} (smaller = cooler = sharper attention)")
|
||||
for l_idx in range(num_layers):
|
||||
print(f" {l_idx:>3} {y[0]/y[l_idx]:>10.4f}")
|
||||
|
||||
# Comparison with classical SA cooling laws:
|
||||
# Kirkpatrick: T(L) = T0 · α^L → log T(L) = log T0 + L log α
|
||||
logT = -np.log(y / y[0]) # because T ∝ 1/γ
|
||||
b_kirk, a_kirk = np.polyfit(L, logT, 1)
|
||||
# Hajek (log-cooling): T(L) = c/log(L+2)
|
||||
# Predicts: log T = log c - log(log(L+2))
|
||||
# Fit T(L) to c / log(L+c2)
|
||||
print(f"\n Kirkpatrick-law fit (exponential cooling):")
|
||||
print(f" log T(L) = {a_kirk:+.3f} + {b_kirk:+.4f} * L → T(L) = exp({a_kirk:+.3f}) · exp({b_kirk:+.4f}·L)")
|
||||
logT_hat = a_kirk + b_kirk * L
|
||||
r2_kirk = 1 - ((logT - logT_hat)**2).sum() / ((logT - logT.mean())**2).sum()
|
||||
print(f" R² (in log space) = {r2_kirk:+.4f} — ideally ≈ 1 if cooling is pure exponential")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
sa-schedule-gamma-directions.py
Normal file
122
sa-schedule-gamma-directions.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Pull input_layernorm.γ vectors from a model and analyze direction
|
||||
structure across layers.
|
||||
|
||||
Question: is γ just scalar magnitude (isotropic SA) or does each layer
|
||||
have a preferred direction (anisotropic SA / geometry-aware)?
|
||||
|
||||
Decomposition: γ_L = ||γ_L|| · γ_L̂
|
||||
- ||γ_L|| is what our scalar Kirkpatrick fit captured
|
||||
- γ_L̂ is unit direction — if layers share direction, γ is rank-1 +
|
||||
scaling (classical isotropic). If directions differ per layer, γ
|
||||
encodes per-layer preferred axis (anisotropic).
|
||||
|
||||
We also look at:
|
||||
- pairwise cos-sim between γ_L̂ across layers
|
||||
- principal components of [γ_L̂]_L (stacked matrix)
|
||||
- per-phase structure: is Phase E more anisotropic than Phase C?
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-32B")
|
||||
ap.add_argument("--out", default="/tmp/gamma-dirs.json")
|
||||
args = ap.parse_args()
|
||||
|
||||
print(f"Loading {args.model} (CPU, layernorm params only)...", flush=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model, torch_dtype=torch.float32, device_map="cpu",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
num_layers = m.config.num_hidden_layers
|
||||
hidden = m.config.hidden_size
|
||||
print(f" L={num_layers}, hidden={hidden}", flush=True)
|
||||
|
||||
gammas = np.stack([
|
||||
m.model.layers[L].input_layernorm.weight.detach().float().cpu().numpy()
|
||||
for L in range(num_layers)
|
||||
]) # (L, hidden)
|
||||
del m
|
||||
|
||||
norms = np.linalg.norm(gammas, axis=1)
|
||||
units = gammas / norms[:, None]
|
||||
|
||||
# Pairwise cos-sim of unit γ
|
||||
cos_mat = units @ units.T # (L, L)
|
||||
|
||||
# PCA on unit vectors
|
||||
centered = units - units.mean(axis=0, keepdims=True)
|
||||
_, S, Vt = np.linalg.svd(centered, full_matrices=False)
|
||||
explained = S**2 / (S**2).sum()
|
||||
|
||||
# How much of each γ_L unit is explained by top-1 direction (shared)?
|
||||
top1 = Vt[0] # (hidden,)
|
||||
proj_top1 = units @ top1 # (L,)
|
||||
residual_after_top1 = np.sqrt(np.maximum(1 - proj_top1**2, 0))
|
||||
|
||||
# Per-phase summary (Qwen3-32B boundaries)
|
||||
def phase(L):
|
||||
if L <= 6: return "A"
|
||||
if L <= 9: return "B"
|
||||
if L <= 31: return "C"
|
||||
if L <= 46: return "D"
|
||||
if L <= 58: return "E"
|
||||
return "tail"
|
||||
|
||||
phase_ls = {}
|
||||
for L in range(num_layers):
|
||||
phase_ls.setdefault(phase(L), []).append(L)
|
||||
|
||||
print(f"\n=== ||γ_L|| per layer (scalar magnitude) ===")
|
||||
for L in range(num_layers):
|
||||
print(f" L={L:>2} phase={phase(L):>5} ||γ||={norms[L]:>8.3f} "
|
||||
f"proj_top1={proj_top1[L]:>+.4f} resid={residual_after_top1[L]:>.4f}")
|
||||
|
||||
print(f"\n=== PCA of unit γ vectors (direction structure) ===")
|
||||
print(f" Explained variance, top 10 components:")
|
||||
for i in range(min(10, len(S))):
|
||||
print(f" PC{i}: {explained[i]:.4f} (singular_val={S[i]:.4f})")
|
||||
print(f" Top-3 explain: {explained[:3].sum():.4f}")
|
||||
print(f" Top-10 explain: {explained[:10].sum():.4f}")
|
||||
|
||||
print(f"\n=== Per-phase direction statistics ===")
|
||||
print(f" {'phase':>6} {'N':>3} {'||γ||_mean':>10} {'||γ||_std':>9} "
|
||||
f"{'intra_cos':>9} {'vs_other_cos':>12}")
|
||||
for ph, Ls in phase_ls.items():
|
||||
u = units[Ls]
|
||||
intra = (u @ u.T)[np.triu_indices(len(Ls), k=1)]
|
||||
intra_mean = intra.mean() if len(intra) > 0 else 1.0
|
||||
# Vs other phases
|
||||
other_Ls = [L for L in range(num_layers) if L not in Ls]
|
||||
if other_Ls:
|
||||
u_other = units[other_Ls]
|
||||
vs = u @ u_other.T
|
||||
vs_mean = vs.mean()
|
||||
else:
|
||||
vs_mean = 0.0
|
||||
print(f" {ph:>6} {len(Ls):>3} {norms[Ls].mean():>10.3f} "
|
||||
f"{norms[Ls].std():>9.3f} {intra_mean:>+9.4f} {vs_mean:>+12.4f}")
|
||||
|
||||
print(f"\n=== Adjacent-pair unit-γ cos-sim ===")
|
||||
for L in range(num_layers - 1):
|
||||
print(f" L={L:>2}→{L+1:>2} phase={phase(L):>5} cos={cos_mat[L, L+1]:>+.4f}")
|
||||
|
||||
import json
|
||||
with open(args.out, "w") as f:
|
||||
json.dump({
|
||||
"model": args.model,
|
||||
"num_layers": num_layers,
|
||||
"norms": norms.tolist(),
|
||||
"proj_top1": proj_top1.tolist(),
|
||||
"explained_var": explained.tolist(),
|
||||
"cos_adjacent": [float(cos_mat[L, L+1]) for L in range(num_layers - 1)],
|
||||
}, f, indent=2)
|
||||
print(f"\nSaved: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
114
sa-schedule-geometry-analyze.py
Normal file
114
sa-schedule-geometry-analyze.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""What does per-head T (entropy) correlate with geometrically?
|
||||
|
||||
For each (layer, head) we already have singular values of the metric M^h = W_K^h^T W_Q^h
|
||||
(up to the low-rank structure — strictly SVD of the head_dim x head_dim product). Derive
|
||||
richer per-head geometric descriptors and test which ones predict dynamic entropy.
|
||||
|
||||
Descriptors per head:
|
||||
op_norm σ_max — global "capacity for sharpness"
|
||||
fro_norm √Σ σ_i² — total metric "energy"
|
||||
rank_eff Σσ / σ_max — effective number of modes
|
||||
spec_entropy -Σ (σ_i² / Σσ_j²) log(...) — flatness of spectrum (nats)
|
||||
anisotropy σ_max / σ_mean — how "peaked" the top mode is
|
||||
condition σ_max / σ_min — ratio of biggest to smallest
|
||||
trace Σσ_i — sum of modes (L1-like)
|
||||
|
||||
Correlate each of these per-head descriptors against per-head dynamic entropy, across
|
||||
all (layer, head) pairs. Also stratified by layer-position (early/mid/late).
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_per_head_geometry(singvals_list):
|
||||
"""singvals_list: list per head of list of singular values. Returns dict of arrays."""
|
||||
s_all = [np.array(s, dtype=np.float64) for s in singvals_list]
|
||||
op = np.array([s.max() for s in s_all])
|
||||
fro = np.array([np.sqrt((s ** 2).sum()) for s in s_all])
|
||||
trace = np.array([s.sum() for s in s_all])
|
||||
rank_eff = np.array([s.sum() / max(s.max(), 1e-12) for s in s_all])
|
||||
# Spectral entropy: use normalized σ² as probabilities
|
||||
spec_ent = np.zeros(len(s_all))
|
||||
for i, s in enumerate(s_all):
|
||||
p = (s ** 2) / max((s ** 2).sum(), 1e-12)
|
||||
p = np.clip(p, 1e-12, 1.0)
|
||||
spec_ent[i] = float(-(p * np.log(p)).sum())
|
||||
anis = np.array([s.max() / max(s.mean(), 1e-12) for s in s_all])
|
||||
cond = np.array([s.max() / max(s.min(), 1e-12) for s in s_all])
|
||||
return dict(op=op, fro=fro, trace=trace, rank_eff=rank_eff,
|
||||
spec_ent=spec_ent, anisotropy=anis, condition=cond)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input_json")
|
||||
args = ap.parse_args()
|
||||
|
||||
with open(args.input_json) as f:
|
||||
data = json.load(f)
|
||||
|
||||
num_layers = data["num_layers"]
|
||||
num_heads = data["num_heads"]
|
||||
|
||||
# Entropy per (layer, head)
|
||||
ent = np.array([row["mean_attention_entropy_per_head"] for row in data["dynamic"]]) # (L, H)
|
||||
logit_std = np.array([row["mean_logit_std_per_head"] for row in data["dynamic"]]) # (L, H)
|
||||
|
||||
# Geometric descriptors per (layer, head)
|
||||
geom = {k: np.zeros((num_layers, num_heads)) for k in
|
||||
["op", "fro", "trace", "rank_eff", "spec_ent", "anisotropy", "condition"]}
|
||||
for L, row in enumerate(data["static"]):
|
||||
per_head = compute_per_head_geometry(row["metric_singvals_per_head"])
|
||||
for k, v in per_head.items():
|
||||
geom[k][L] = v
|
||||
|
||||
# Flatten across (layer, head) and correlate
|
||||
print("All (layer, head) pairs — Pearson correlation with dynamic entropy:")
|
||||
ent_flat = ent.flatten()
|
||||
logit_flat = logit_std.flatten()
|
||||
results = {}
|
||||
for k, v in geom.items():
|
||||
v_flat = v.flatten()
|
||||
c_ent = float(np.corrcoef(v_flat, ent_flat)[0, 1])
|
||||
c_logit = float(np.corrcoef(v_flat, logit_flat)[0, 1])
|
||||
results[k] = (c_ent, c_logit)
|
||||
print(f" {k:12} vs entropy: {c_ent:+.3f} vs logit_std: {c_logit:+.3f}")
|
||||
|
||||
# Stratify by layer position — early (0-11), mid (12-23), late (24-35)
|
||||
thirds = [(0, num_layers // 3, "early"),
|
||||
(num_layers // 3, 2 * num_layers // 3, "mid"),
|
||||
(2 * num_layers // 3, num_layers, "late")]
|
||||
print("\nStratified by layer position (entropy correlation):")
|
||||
for lo, hi, name in thirds:
|
||||
print(f" [{name} L{lo}-{hi-1}]", end="")
|
||||
for k in ["op", "fro", "rank_eff", "spec_ent", "anisotropy", "condition"]:
|
||||
c = float(np.corrcoef(geom[k][lo:hi].flatten(), ent[lo:hi].flatten())[0, 1])
|
||||
print(f" {k}:{c:+.2f}", end="")
|
||||
print()
|
||||
|
||||
# Best single predictor across all
|
||||
print("\nBest single geometric predictor of entropy (abs):")
|
||||
best = max(results.items(), key=lambda kv: abs(kv[1][0]))
|
||||
print(f" {best[0]} r = {best[1][0]:+.3f}")
|
||||
|
||||
# Multi-regression: try op, spec_ent, rank_eff jointly
|
||||
print("\nLinear regression of entropy on multiple descriptors (standardized):")
|
||||
from numpy.linalg import lstsq
|
||||
X_cols = ["op", "spec_ent", "rank_eff", "anisotropy"]
|
||||
X = np.stack([geom[k].flatten() for k in X_cols], axis=1)
|
||||
# standardize
|
||||
X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-12)
|
||||
y = (ent_flat - ent_flat.mean()) / (ent_flat.std() + 1e-12)
|
||||
X1 = np.concatenate([X, np.ones((X.shape[0], 1))], axis=1)
|
||||
coef, res, rk, sv = lstsq(X1, y, rcond=None)
|
||||
y_pred = X1 @ coef
|
||||
r2 = 1 - float(((y - y_pred) ** 2).sum() / ((y - y.mean()) ** 2).sum())
|
||||
print(f" R² = {r2:.3f}")
|
||||
print(f" standardized coefficients:")
|
||||
for name, c in zip(X_cols, coef[:-1]):
|
||||
print(f" {name:12} {c:+.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
238
sa-schedule-layer-variation.py
Normal file
238
sa-schedule-layer-variation.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""After removing the known gauge freedoms (per-head d_h rotation tying
|
||||
W_Q/W_K/W_V/W_O together, per-layer d_ff rotation tying gate/up/down),
|
||||
measure per-family Frobenius distance between consecutive layers within a
|
||||
middle block. Families with low post-alignment distance are candidates for
|
||||
"shared operator" across the block; high distance → carries the schedule.
|
||||
|
||||
Normalize each matrix by its Frobenius norm first (so scale differences
|
||||
don't dominate). We want to see direction of drift, not magnitude.
|
||||
|
||||
Gauge freedoms being removed:
|
||||
- Per-head d_h rotation R ∈ O(d_h): W_Q^h, W_K^h, W_V^h → R W^h;
|
||||
W_O^h → W_O^h R^T. Softmax attention is invariant under this.
|
||||
- Per-layer d_ff rotation S ∈ O(d_ff): gate_proj, up_proj → S W;
|
||||
down_proj → W S^T. SwiGLU/GLU is NOT fully invariant under d_ff
|
||||
rotation (because the elementwise gate*up is coordinate-dependent),
|
||||
so this is an approximate alignment — still better than raw.
|
||||
|
||||
Families that have no gauge freedom (layernorm γ, q_norm, k_norm): compare
|
||||
directly after scale normalization.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def procrustes(M):
|
||||
"""Orthogonal matrix R maximizing tr(R M). Given SVD M = U Σ V^T, R = U V^T."""
|
||||
U, _, Vh = np.linalg.svd(M, full_matrices=False)
|
||||
return U @ Vh
|
||||
|
||||
|
||||
def fro(x):
|
||||
return float(np.linalg.norm(x))
|
||||
|
||||
|
||||
def normalize_fro(x, eps=1e-12):
|
||||
n = fro(x)
|
||||
return x / max(n, eps)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-4B")
|
||||
ap.add_argument("--block-start", type=int, default=10)
|
||||
ap.add_argument("--block-end", type=int, default=25,
|
||||
help="inclusive; this is mid-block of 36-layer model")
|
||||
ap.add_argument("--out", default="/tmp/sa-layer-variation.json")
|
||||
args = ap.parse_args()
|
||||
|
||||
print(f"Loading {args.model} ...", flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model,
|
||||
torch_dtype=torch.float32,
|
||||
device_map="cpu",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
cfg = model.config
|
||||
num_layers = cfg.num_hidden_layers
|
||||
num_heads = cfg.num_attention_heads
|
||||
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
hidden = cfg.hidden_size
|
||||
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
intermediate = cfg.intermediate_size
|
||||
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
|
||||
f"hidden={hidden} ff={intermediate}", flush=True)
|
||||
|
||||
# Collect per-layer weight matrices as numpy float32.
|
||||
def get_np(name, idx):
|
||||
w = getattr(model.model.layers[idx], name, None)
|
||||
if w is None:
|
||||
return None
|
||||
return w
|
||||
|
||||
layers = {}
|
||||
for L in range(num_layers):
|
||||
layer = model.model.layers[L]
|
||||
attn = layer.self_attn
|
||||
mlp = layer.mlp
|
||||
layers[L] = {
|
||||
"q_proj": attn.q_proj.weight.detach().numpy().astype(np.float32), # (nh*hd, hidden)
|
||||
"k_proj": attn.k_proj.weight.detach().numpy().astype(np.float32), # (nkv*hd, hidden)
|
||||
"v_proj": attn.v_proj.weight.detach().numpy().astype(np.float32),
|
||||
"o_proj": attn.o_proj.weight.detach().numpy().astype(np.float32), # (hidden, nh*hd)
|
||||
"gate_proj": mlp.gate_proj.weight.detach().numpy().astype(np.float32),
|
||||
"up_proj": mlp.up_proj.weight.detach().numpy().astype(np.float32),
|
||||
"down_proj": mlp.down_proj.weight.detach().numpy().astype(np.float32),
|
||||
"input_ln": layer.input_layernorm.weight.detach().numpy().astype(np.float32),
|
||||
"post_attn_ln": layer.post_attention_layernorm.weight.detach().numpy().astype(np.float32),
|
||||
}
|
||||
# Qwen3 has q_norm / k_norm inside self_attn
|
||||
q_norm = getattr(attn, "q_norm", None)
|
||||
k_norm = getattr(attn, "k_norm", None)
|
||||
if q_norm is not None:
|
||||
layers[L]["q_norm"] = q_norm.weight.detach().numpy().astype(np.float32)
|
||||
if k_norm is not None:
|
||||
layers[L]["k_norm"] = k_norm.weight.detach().numpy().astype(np.float32)
|
||||
|
||||
del model # free memory
|
||||
|
||||
block = list(range(args.block_start, args.block_end + 1))
|
||||
pairs = [(block[i], block[i + 1]) for i in range(len(block) - 1)]
|
||||
print(f"\nAnalyzing block layers {args.block_start}..{args.block_end} "
|
||||
f"({len(pairs)} consecutive pairs)\n")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reshape attention weights per-head for rotation alignment
|
||||
# ------------------------------------------------------------------
|
||||
def per_head_split(W_qkv, n_heads_for_this):
|
||||
# W is (n*hd, hidden). Reshape to (n, hd, hidden).
|
||||
return W_qkv.reshape(n_heads_for_this, head_dim, hidden)
|
||||
|
||||
def per_head_split_o(W_o):
|
||||
# W is (hidden, n*hd). Reshape to (n, hidden, hd).
|
||||
return W_o.reshape(hidden, num_heads, head_dim).transpose(1, 0, 2)
|
||||
|
||||
# Replicate k/v head index to query head index space (GQA)
|
||||
def kv_to_q_index(h):
|
||||
return (h * num_kv_heads) // num_heads
|
||||
|
||||
family_residuals = {fam: [] for fam in
|
||||
["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
"input_ln", "post_attn_ln", "q_norm", "k_norm"]}
|
||||
|
||||
for (L1, L2) in pairs:
|
||||
A = layers[L1]
|
||||
B = layers[L2]
|
||||
|
||||
# Per-head attention alignment:
|
||||
Q1 = per_head_split(A["q_proj"], num_heads)
|
||||
Q2 = per_head_split(B["q_proj"], num_heads)
|
||||
K1 = per_head_split(A["k_proj"], num_kv_heads)
|
||||
K2 = per_head_split(B["k_proj"], num_kv_heads)
|
||||
V1 = per_head_split(A["v_proj"], num_kv_heads)
|
||||
V2 = per_head_split(B["v_proj"], num_kv_heads)
|
||||
O1 = per_head_split_o(A["o_proj"]) # (num_heads, hidden, hd)
|
||||
O2 = per_head_split_o(B["o_proj"])
|
||||
|
||||
q_res = []
|
||||
k_res = []
|
||||
v_res = []
|
||||
o_res = []
|
||||
for h in range(num_heads):
|
||||
kv_h = kv_to_q_index(h)
|
||||
# Normalize each matrix by its Frobenius norm
|
||||
qa = normalize_fro(Q1[h])
|
||||
qb = normalize_fro(Q2[h])
|
||||
ka = normalize_fro(K1[kv_h])
|
||||
kb = normalize_fro(K2[kv_h])
|
||||
va = normalize_fro(V1[kv_h])
|
||||
vb = normalize_fro(V2[kv_h])
|
||||
oa = normalize_fro(O1[h])
|
||||
ob = normalize_fro(O2[h])
|
||||
|
||||
# Cross-correlation for Procrustes: find R (hd × hd) maximizing
|
||||
# tr(R [Qa Qb^T + Ka Kb^T + Va Vb^T + (Oa^T Ob)])
|
||||
# Q, K, V are (hd, hidden); Q2 Q1^T would be (hd, hd); etc.
|
||||
M = qa @ qb.T + ka @ kb.T + va @ vb.T + (oa.T @ ob) # all (hd, hd)
|
||||
# Wait: for Q we want tr(R qa qb^T). So the matrix in the max-trace
|
||||
# Procrustes is qb @ qa.T? Let me be careful.
|
||||
# max_R tr(R M) achieved at R = U V^T with SVD M = U Σ V^T.
|
||||
# Here we want R such that R qa ≈ qb → minimize ||R qa - qb||²
|
||||
# = const - 2 tr(R qa qb^T). So max tr(R qa qb^T) gives the
|
||||
# correct R. Redo M as sum of qa qb^T terms.
|
||||
M = qa @ qb.T + ka @ kb.T + va @ vb.T
|
||||
# For O: want W_O^h R^T ≈ W_O^h_target, i.e. oa R^T ≈ ob
|
||||
# → min ||oa R^T - ob||² = const - 2 tr(R oa^T ob); max that.
|
||||
# So O contributes oa^T @ ob to the cross-correlation matrix.
|
||||
M = M + oa.T @ ob
|
||||
R = procrustes(M)
|
||||
|
||||
# Apply R and measure residual (Frobenius distance) per-matrix
|
||||
q_res.append(fro(R @ qa - qb))
|
||||
k_res.append(fro(R @ ka - kb))
|
||||
v_res.append(fro(R @ va - vb))
|
||||
o_res.append(fro(oa @ R.T - ob))
|
||||
|
||||
family_residuals["q_proj"].append(float(np.mean(q_res)))
|
||||
family_residuals["k_proj"].append(float(np.mean(k_res)))
|
||||
family_residuals["v_proj"].append(float(np.mean(v_res)))
|
||||
family_residuals["o_proj"].append(float(np.mean(o_res)))
|
||||
|
||||
# MLP d_ff rotation alignment: find S (d_ff × d_ff) orthogonal with
|
||||
# S gate_a ≈ gate_b and S up_a ≈ up_b simultaneously; adjust down_proj.
|
||||
# Each is (d_ff, hidden).
|
||||
ga = normalize_fro(A["gate_proj"])
|
||||
gb = normalize_fro(B["gate_proj"])
|
||||
ua = normalize_fro(A["up_proj"])
|
||||
ub = normalize_fro(B["up_proj"])
|
||||
da = normalize_fro(A["down_proj"]) # (hidden, d_ff)
|
||||
db = normalize_fro(B["down_proj"])
|
||||
# M_ff = ga @ gb^T + ua @ ub^T + da^T @ db (all d_ff × d_ff)
|
||||
M_ff = ga @ gb.T + ua @ ub.T + da.T @ db
|
||||
S = procrustes(M_ff)
|
||||
family_residuals["gate_proj"].append(fro(S @ ga - gb))
|
||||
family_residuals["up_proj"].append(fro(S @ ua - ub))
|
||||
family_residuals["down_proj"].append(fro(da @ S.T - db))
|
||||
|
||||
# LayerNorm γ vectors — no rotation gauge; just scale-normalize and diff
|
||||
for ln_name in ["input_ln", "post_attn_ln", "q_norm", "k_norm"]:
|
||||
if ln_name in A and ln_name in B:
|
||||
va_ = normalize_fro(A[ln_name])
|
||||
vb_ = normalize_fro(B[ln_name])
|
||||
family_residuals[ln_name].append(fro(va_ - vb_))
|
||||
|
||||
# Report
|
||||
print("=== Per-family Frobenius residual between consecutive layers, "
|
||||
f"block L={args.block_start}..{args.block_end}, after alignment + scale-norm ===\n")
|
||||
print(f" (Residual = Frobenius distance between L and L+1 after rotation alignment;")
|
||||
print(f" lower = more shared across block; higher = carries layer-to-layer drift)\n")
|
||||
print(f" {'family':>14} {'mean':>8} {'min':>8} {'max':>8} {'std':>8} n")
|
||||
# Report families sorted by mean variation
|
||||
items = [(fam, np.array(v)) for fam, v in family_residuals.items() if len(v) > 0]
|
||||
items.sort(key=lambda kv: float(kv[1].mean()))
|
||||
for fam, v in items:
|
||||
print(f" {fam:>14} {v.mean():>8.4f} {v.min():>8.4f} {v.max():>8.4f} {v.std():>8.4f} {len(v)}")
|
||||
|
||||
print(f"\n Families ranked least-to-most variation:")
|
||||
for i, (fam, v) in enumerate(items):
|
||||
print(f" {i+1}. {fam} (mean residual {v.mean():.4f})")
|
||||
|
||||
# Save
|
||||
with open(args.out, "w") as f:
|
||||
json.dump({
|
||||
"model": args.model,
|
||||
"block_start": args.block_start,
|
||||
"block_end": args.block_end,
|
||||
"family_residuals": {k: list(v) for k, v in family_residuals.items()},
|
||||
}, f, indent=2)
|
||||
print(f"\nSaved: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
168
sa-schedule-measure-grams.py
Normal file
168
sa-schedule-measure-grams.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Measure the full inter-layer geometric relationship between per-head metrics.
|
||||
|
||||
For each (L, L', h) pair, compute the Frobenius inner product
|
||||
<g_L^h, g_L'^h> = tr(g_L^h^T g_L'^h)
|
||||
where g^h = W_K^h^T W_Q^h ∈ R^{hidden × hidden} (rank ≤ head_dim).
|
||||
|
||||
Using the head_dim × head_dim shortcut:
|
||||
<g_L, g_L'> = tr(A B^T) with A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T.
|
||||
|
||||
Output: gram[L, L', h] and fro_sq[L, h]. From these every layer-pair comparison
|
||||
is derivable without saving the full operators.
|
||||
|
||||
Also saves top-k principal directions per head (as right singular vectors of g,
|
||||
which are the Q-side eigen-directions) so subspace overlap across layers can be
|
||||
computed downstream.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def measure(model_name: str, out_path: str, topk: int = 8,
|
||||
dtype=torch.bfloat16):
|
||||
print(f"Loading {model_name} ...", flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
device_map="cuda",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model.eval()
|
||||
cfg = model.config
|
||||
num_layers = cfg.num_hidden_layers
|
||||
num_heads = cfg.num_attention_heads
|
||||
hidden = cfg.hidden_size
|
||||
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim}", flush=True)
|
||||
|
||||
# Collect W_Q, W_K per layer as (num_heads, head_dim, hidden) on GPU float32.
|
||||
Wq_list = []
|
||||
Wk_list = []
|
||||
for L, layer in enumerate(model.model.layers):
|
||||
attn = layer.self_attn
|
||||
Wq = attn.q_proj.weight.detach().to(torch.float32) # (nh*hd, hidden)
|
||||
Wk = attn.k_proj.weight.detach().to(torch.float32) # (nkv*hd, hidden)
|
||||
Wq = Wq.view(num_heads, head_dim, hidden)
|
||||
# Repeat kv heads so every query head has a matching k-row
|
||||
Wk = Wk.view(num_kv_heads, head_dim, hidden)
|
||||
# Broadcast to num_heads via (h // (num_heads // num_kv_heads))? safer: mapping
|
||||
Wk_full = torch.zeros(num_heads, head_dim, hidden,
|
||||
device=Wk.device, dtype=Wk.dtype)
|
||||
for h in range(num_heads):
|
||||
kv_h = (h * num_kv_heads) // num_heads
|
||||
Wk_full[h] = Wk[kv_h]
|
||||
Wq_list.append(Wq)
|
||||
Wk_list.append(Wk_full)
|
||||
print(f" loaded weights: {num_layers} layers", flush=True)
|
||||
|
||||
# Per-head top-k right singular vectors of g^h = W_K^T W_Q (hidden, hidden).
|
||||
# The non-zero right singular vectors of g lie in row-space(W_Q) ⊂ R^hidden.
|
||||
# For subspace comparison we need vectors in hidden-space.
|
||||
#
|
||||
# We also need SIGNED eigenvalues of the symmetric part (g + g^T)/2 to
|
||||
# determine curvature signs per eigen-direction. Since g has rank ≤ d_h,
|
||||
# (g + g^T) has rank ≤ 2 d_h, and we can compute its signed non-zero
|
||||
# eigenvalues via the Jordan-Wielandt-style trick:
|
||||
# eigs(X^T J X) = eigs(J X X^T) for X = [W_Q; W_K], J = [[0, I], [I, 0]].
|
||||
# The resulting 2d_h × 2d_h matrix gives us all non-zero eigenvalues of
|
||||
# (g + g^T) cheaply.
|
||||
topk_eff = min(topk, head_dim)
|
||||
eig_dirs = torch.zeros(num_layers, num_heads, topk_eff, hidden,
|
||||
dtype=torch.float32)
|
||||
fro_sq = torch.zeros(num_layers, num_heads, dtype=torch.float64)
|
||||
sym_eigs = torch.zeros(num_layers, num_heads, 2 * head_dim,
|
||||
dtype=torch.float64) # signed
|
||||
for L in range(num_layers):
|
||||
for h in range(num_heads):
|
||||
Wq = Wq_list[L][h] # (hd, hidden)
|
||||
Wk = Wk_list[L][h] # (hd, hidden)
|
||||
small = Wk @ Wq.T # (hd, hd)
|
||||
U, S, Vh = torch.linalg.svd(small, full_matrices=False)
|
||||
dirs = Vh @ Wq # (hd, hidden)
|
||||
dirs = dirs / dirs.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
eig_dirs[L, h] = dirs[:topk_eff].cpu()
|
||||
fro_sq[L, h] = float((S * S).sum())
|
||||
|
||||
# Signed eigenvalues of (g + g^T) via 2d_h × 2d_h matrix
|
||||
# J (X X^T) where X = [Wq; Wk] (stacked)
|
||||
XXT = torch.zeros(2 * head_dim, 2 * head_dim,
|
||||
device=Wq.device, dtype=Wq.dtype)
|
||||
XXT[:head_dim, :head_dim] = Wq @ Wq.T
|
||||
XXT[:head_dim, head_dim:] = Wq @ Wk.T
|
||||
XXT[head_dim:, :head_dim] = Wk @ Wq.T
|
||||
XXT[head_dim:, head_dim:] = Wk @ Wk.T
|
||||
# J matrix is off-diagonal block identity
|
||||
J = torch.zeros(2 * head_dim, 2 * head_dim,
|
||||
device=Wq.device, dtype=Wq.dtype)
|
||||
J[:head_dim, head_dim:] = torch.eye(head_dim,
|
||||
device=Wq.device, dtype=Wq.dtype)
|
||||
J[head_dim:, :head_dim] = torch.eye(head_dim,
|
||||
device=Wq.device, dtype=Wq.dtype)
|
||||
M = J @ XXT
|
||||
# M is not symmetric, but its non-zero eigenvalues are those of
|
||||
# (g + g^T)/2 times 2 → real (since (g + g^T) is symmetric).
|
||||
# Use general eigvals; imag parts should be near zero up to
|
||||
# numerical noise.
|
||||
ev = torch.linalg.eigvals(M)
|
||||
ev_real = ev.real.cpu().double()
|
||||
# sort by magnitude descending so top eigenvalues come first
|
||||
order = torch.argsort(ev_real.abs(), descending=True)
|
||||
sym_eigs[L, h] = ev_real[order]
|
||||
if L % 8 == 0:
|
||||
print(f" eigdecomp L={L}", flush=True)
|
||||
|
||||
# Gram matrix: gram[L, L', h] = <g_L^h, g_L'^h>.
|
||||
# Using A = W_K_L W_K_L'^T, B = W_Q_L W_Q_L'^T, <g, g'> = tr(A B^T) = sum(A * B).
|
||||
gram = torch.zeros(num_layers, num_layers, num_heads, dtype=torch.float64)
|
||||
for L in range(num_layers):
|
||||
for Lp in range(L, num_layers):
|
||||
for h in range(num_heads):
|
||||
Wq_L = Wq_list[L][h]
|
||||
Wk_L = Wk_list[L][h]
|
||||
Wq_Lp = Wq_list[Lp][h]
|
||||
Wk_Lp = Wk_list[Lp][h]
|
||||
A = Wk_L @ Wk_Lp.T # (hd, hd)
|
||||
B = Wq_L @ Wq_Lp.T # (hd, hd)
|
||||
v = float((A * B).sum())
|
||||
gram[L, Lp, h] = v
|
||||
gram[Lp, L, h] = v
|
||||
if L % 4 == 0:
|
||||
print(f" gram row L={L}", flush=True)
|
||||
|
||||
# Save
|
||||
out = {
|
||||
"model": model_name,
|
||||
"num_layers": num_layers,
|
||||
"num_heads": num_heads,
|
||||
"head_dim": head_dim,
|
||||
"hidden_size": hidden,
|
||||
"topk": topk_eff,
|
||||
"gram": gram.tolist(),
|
||||
"fro_sq": fro_sq.tolist(),
|
||||
}
|
||||
with open(out_path, "w") as f:
|
||||
json.dump(out, f)
|
||||
torch.save({"eig_dirs": eig_dirs, "sym_eigs": sym_eigs},
|
||||
out_path.replace(".json", "-eigdirs.pt"))
|
||||
print(f"Wrote {out_path} and {out_path.replace('.json', '-eigdirs.pt')}",
|
||||
flush=True)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-4B")
|
||||
ap.add_argument("--out", default="/tmp/sa-grams.json")
|
||||
ap.add_argument("--topk", type=int, default=8)
|
||||
args = ap.parse_args()
|
||||
measure(args.model, args.out, topk=args.topk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
237
sa-schedule-null-residual.py
Normal file
237
sa-schedule-null-residual.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Null test: before any fitting, how similar are adjacent layers in the
|
||||
raw weight-matrix sense?
|
||||
|
||||
For each adjacent layer pair (L, L+1) and each parameter family:
|
||||
1. Normalize each matrix by its Frobenius norm (unit sphere).
|
||||
2. Compute cos-sim = <W_L, W_{L+1}> / (||W_L|| ||W_{L+1}||).
|
||||
3. Compute residual Δ = W_{L+1,norm} - W_{L,norm}; report ||Δ||_F
|
||||
(null-if-orthogonal = sqrt(2) ≈ 1.414; null-if-identical = 0).
|
||||
4. Report effective rank of Δ (via entropy of normalized spectrum).
|
||||
|
||||
Whole network, not just middle block. Plots cos-sim and residual-rank
|
||||
trajectories across depth.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def spec_entropy(singvals, eps=1e-12):
|
||||
p = (singvals ** 2)
|
||||
p = p / max(p.sum(), eps)
|
||||
p = np.clip(p, eps, 1.0)
|
||||
return float(-(p * np.log(p)).sum())
|
||||
|
||||
|
||||
def frob(x):
|
||||
return float(np.linalg.norm(x))
|
||||
|
||||
|
||||
def norm_mat(x, eps=1e-12):
|
||||
return x / max(frob(x), eps)
|
||||
|
||||
|
||||
def null_test_pair(A_dict, B_dict, family_names, num_heads, num_kv_heads, head_dim):
|
||||
"""For each family, compute cos-sim and normalized residual between
|
||||
adjacent layers. Returns dict of per-family stats."""
|
||||
out = {}
|
||||
for fam in family_names:
|
||||
if fam not in A_dict or fam not in B_dict:
|
||||
continue
|
||||
Wa = A_dict[fam]
|
||||
Wb = B_dict[fam]
|
||||
if Wa.shape != Wb.shape:
|
||||
continue
|
||||
fa = frob(Wa)
|
||||
fb = frob(Wb)
|
||||
if fa < 1e-12 or fb < 1e-12:
|
||||
continue
|
||||
cos = float((Wa * Wb).sum() / (fa * fb))
|
||||
resid_norm_sq = 2.0 - 2.0 * cos # ||Wa/|| - Wb/|| ||^2
|
||||
resid_norm = float(np.sqrt(max(resid_norm_sq, 0.0)))
|
||||
|
||||
# Skip residual SVD — was bottleneck on large matrices; cos-sim
|
||||
# + scalar fit give us the main signal. Can add back selectively.
|
||||
eff_rank = None
|
||||
se = None
|
||||
|
||||
out[fam] = {
|
||||
"cos": cos,
|
||||
"resid_norm": resid_norm,
|
||||
"resid_eff_rank": eff_rank,
|
||||
"resid_spec_entropy": se,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-4B")
|
||||
ap.add_argument("--out", default="/tmp/sa-null-residual.json")
|
||||
args = ap.parse_args()
|
||||
|
||||
print(f"Loading {args.model} ...", flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model,
|
||||
torch_dtype=torch.bfloat16, # halve memory vs fp32
|
||||
device_map="cpu",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
cfg = model.config
|
||||
num_layers = cfg.num_hidden_layers
|
||||
num_heads = cfg.num_attention_heads
|
||||
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
hidden = cfg.hidden_size
|
||||
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
intermediate = cfg.intermediate_size
|
||||
print(f" L={num_layers} H={num_heads} kv={num_kv_heads} hd={head_dim} "
|
||||
f"hidden={hidden} ff={intermediate}", flush=True)
|
||||
|
||||
families = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
"input_ln", "post_attn_ln", "q_norm", "k_norm"]
|
||||
|
||||
layers = {}
|
||||
for L in range(num_layers):
|
||||
layer = model.model.layers[L]
|
||||
attn = layer.self_attn
|
||||
mlp = layer.mlp
|
||||
entry = {
|
||||
"q_proj": attn.q_proj.weight.detach().float().numpy(),
|
||||
"k_proj": attn.k_proj.weight.detach().float().numpy(),
|
||||
"v_proj": attn.v_proj.weight.detach().float().numpy(),
|
||||
"o_proj": attn.o_proj.weight.detach().float().numpy(),
|
||||
"gate_proj": mlp.gate_proj.weight.detach().float().numpy(),
|
||||
"up_proj": mlp.up_proj.weight.detach().float().numpy(),
|
||||
"down_proj": mlp.down_proj.weight.detach().float().numpy(),
|
||||
"input_ln": layer.input_layernorm.weight.detach().float().numpy(),
|
||||
"post_attn_ln": layer.post_attention_layernorm.weight.detach().float().numpy(),
|
||||
}
|
||||
qn = getattr(attn, "q_norm", None)
|
||||
kn = getattr(attn, "k_norm", None)
|
||||
if qn is not None:
|
||||
entry["q_norm"] = qn.weight.detach().float().numpy()
|
||||
if kn is not None:
|
||||
entry["k_norm"] = kn.weight.detach().float().numpy()
|
||||
layers[L] = entry
|
||||
|
||||
del model
|
||||
|
||||
# Also record per-layer scale (Frobenius norm) for the scale-track PCA
|
||||
scales = {fam: [] for fam in families}
|
||||
for L in range(num_layers):
|
||||
for fam in families:
|
||||
if fam in layers[L]:
|
||||
scales[fam].append(frob(layers[L][fam]))
|
||||
else:
|
||||
scales[fam].append(None)
|
||||
|
||||
# Pairwise null test
|
||||
pair_results = []
|
||||
for L in range(num_layers - 1):
|
||||
r = null_test_pair(layers[L], layers[L + 1], families,
|
||||
num_heads, num_kv_heads, head_dim)
|
||||
pair_results.append({"L": L, "L_next": L + 1, "families": r})
|
||||
|
||||
# Report
|
||||
print("\n=== Adjacent-layer raw cos-sim per family ===")
|
||||
print(" null interpretation: 1.0 = identical matrices up to scale, 0 = orthogonal")
|
||||
print(f"\n {'L':>3}", end="")
|
||||
for fam in families:
|
||||
if any(fam in pr["families"] for pr in pair_results):
|
||||
print(f" {fam:>12}", end="")
|
||||
print()
|
||||
for pr in pair_results:
|
||||
print(f" {pr['L']:>3}", end="")
|
||||
for fam in families:
|
||||
if fam in pr["families"]:
|
||||
print(f" {pr['families'][fam]['cos']:>+12.4f}", end="")
|
||||
else:
|
||||
print(f" {'':>12}", end="")
|
||||
print()
|
||||
|
||||
# Summary per family + scalar-T fit comparison
|
||||
# raw_resid = sqrt(2 - 2*cos); scalar_fit = sqrt(1 - cos²) = sin(angle).
|
||||
# random_baseline = sqrt(2) ≈ 1.414.
|
||||
print("\n=== Per-family summary (across all adjacent pairs) ===")
|
||||
print(" random baseline = sqrt(2) ≈ 1.414 (what we'd see with no relationship)")
|
||||
print(f"\n {'family':>14} {'mean_cos':>10} {'median_cos':>11} "
|
||||
f"{'raw_resid':>10} {'scalar_fit':>11} {'improve_frac':>13} {'mean_SE':>8}")
|
||||
for fam in families:
|
||||
cs = [pr["families"].get(fam, {}).get("cos") for pr in pair_results]
|
||||
cs = [x for x in cs if x is not None]
|
||||
rs = [pr["families"].get(fam, {}).get("resid_norm") for pr in pair_results]
|
||||
rs = [x for x in rs if x is not None]
|
||||
ers = [pr["families"].get(fam, {}).get("resid_eff_rank") for pr in pair_results]
|
||||
ers = [x for x in ers if x is not None]
|
||||
ses = [pr["families"].get(fam, {}).get("resid_spec_entropy") for pr in pair_results]
|
||||
ses = [x for x in ses if x is not None]
|
||||
if not cs:
|
||||
continue
|
||||
raw = np.sqrt(np.maximum(2.0 - 2.0 * np.array(cs), 0.0)).mean()
|
||||
scalar_fit = np.sqrt(np.maximum(1.0 - np.array(cs) ** 2, 0.0)).mean()
|
||||
# Improvement fraction: (raw - scalar_fit) / (raw - 0) normalized
|
||||
# to [0, 1] where 0 = scalar does nothing, 1 = scalar reconstructs.
|
||||
improve_frac = (raw - scalar_fit) / max(raw, 1e-12)
|
||||
print(f" {fam:>14} {np.mean(cs):>+10.4f} {np.median(cs):>+11.4f} "
|
||||
f"{raw:>10.4f} {scalar_fit:>11.4f} {improve_frac:>13.4f} "
|
||||
f"{np.mean(ses) if ses else 0:>8.4f}")
|
||||
|
||||
# Scale-track: Frobenius norm of each family across layers
|
||||
print("\n=== Scale track: ||W_family||_F across layers ===")
|
||||
print(f" {'L':>3}", end="")
|
||||
for fam in families:
|
||||
if any(s is not None for s in scales[fam]):
|
||||
print(f" {fam:>12}", end="")
|
||||
print()
|
||||
for L in range(num_layers):
|
||||
print(f" {L:>3}", end="")
|
||||
for fam in families:
|
||||
if scales[fam][L] is not None:
|
||||
print(f" {scales[fam][L]:>12.4f}", end="")
|
||||
else:
|
||||
print(f" {'':>12}", end="")
|
||||
print()
|
||||
|
||||
# PCA of log-scale-track to see dimensionality of schedule
|
||||
print("\n=== PCA of log-scale-track (dimensionality of schedule) ===")
|
||||
scale_matrix = []
|
||||
fam_used = []
|
||||
for fam in families:
|
||||
vals = scales[fam]
|
||||
if all(v is not None for v in vals):
|
||||
scale_matrix.append(np.log(np.array(vals)))
|
||||
fam_used.append(fam)
|
||||
scale_matrix = np.array(scale_matrix) # (num_families, L)
|
||||
# Center per-family
|
||||
sm_c = scale_matrix - scale_matrix.mean(axis=1, keepdims=True)
|
||||
# SVD: columns are layers, rows are families
|
||||
U, S, Vh = np.linalg.svd(sm_c, full_matrices=False)
|
||||
total = (S ** 2).sum()
|
||||
print(f" explained variance by mode:")
|
||||
for i, s in enumerate(S):
|
||||
pct = float(s ** 2 / max(total, 1e-20)) * 100
|
||||
print(f" mode {i+1:>2}: {pct:>6.2f}% "
|
||||
f"(loadings per family: "
|
||||
f"{', '.join(f'{fam_used[j]}={U[j, i]:+.2f}' for j in range(len(fam_used)))})")
|
||||
|
||||
# Save
|
||||
with open(args.out, "w") as f:
|
||||
json.dump({
|
||||
"model": args.model,
|
||||
"pair_results": pair_results,
|
||||
"scales": scales,
|
||||
"scale_pca_singvals": S.tolist(),
|
||||
"scale_pca_loadings": U.tolist(),
|
||||
"scale_pca_scores": (np.diag(S) @ Vh).tolist(),
|
||||
"fam_used": fam_used,
|
||||
}, f, indent=2)
|
||||
print(f"\nSaved: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
246
sa-schedule-readout-measure.py
Normal file
246
sa-schedule-readout-measure.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
SA schedule readout for a dense softmax-attention LLM (Qwen3-8B by default).
|
||||
|
||||
Measures per-layer "temperature" signals:
|
||||
- entropy of softmax attention (per head, aggregated)
|
||||
- magnitude of pre-softmax logits (implicit sharpness)
|
||||
- spectrum of the parameter metric g_L^h = W_K^h^T W_Q^h (static, no forward pass needed)
|
||||
|
||||
Output:
|
||||
stats.json — numeric summary per layer / head
|
||||
activations stats by layer accumulated across a calibration set
|
||||
|
||||
Goal:
|
||||
Compare entropy(L) (dynamic readout) against static spectrum of g_L (parameter-only
|
||||
prediction). Agreement => schedule is parameter-intrinsic and a scalar per-iteration
|
||||
T suffices. Disagreement => content-adaptive structure lives in the activations.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
CALIBRATION_PROMPTS = [
|
||||
# general knowledge
|
||||
"The Eiffel Tower is located in",
|
||||
"Photosynthesis is the process by which",
|
||||
"The three branches of the US government are",
|
||||
# math / reasoning
|
||||
"If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
|
||||
"Solve for x: 3x + 7 = 22. The answer is x =",
|
||||
"The derivative of x^3 + 2x^2 is",
|
||||
# code
|
||||
"def fibonacci(n):\n if n < 2:\n return n\n return",
|
||||
"# Python list comprehension to square even numbers in 0-9\nresult = ",
|
||||
"SELECT name, age FROM users WHERE",
|
||||
# narrative / long-form
|
||||
"She opened the old wooden box and found",
|
||||
"The argument in favor of renewable energy is",
|
||||
# chat / instruction
|
||||
"User: What is the capital of Australia?\nAssistant:",
|
||||
"Write a haiku about autumn:\n",
|
||||
# factual / lookup
|
||||
"Albert Einstein was born in the year",
|
||||
"The speed of light in vacuum is approximately",
|
||||
# conversational
|
||||
"I really loved that movie because",
|
||||
"The main difference between a virus and a bacterium is",
|
||||
# translation-ish
|
||||
"The French word for 'apple' is",
|
||||
# edge cases
|
||||
"1 + 1 = ",
|
||||
"Once upon a time, in a land far away,",
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def measure_model(model_name: str, out_path: str, max_seq_len: int = 256, dtype=torch.bfloat16):
|
||||
print(f"Loading {model_name} ...", flush=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
device_map="cuda",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager", # need raw attention probabilities
|
||||
)
|
||||
model.eval()
|
||||
|
||||
cfg = model.config
|
||||
num_layers = cfg.num_hidden_layers
|
||||
num_heads = cfg.num_attention_heads
|
||||
hidden = cfg.hidden_size
|
||||
head_dim = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
print(f" num_hidden_layers={num_layers} num_attention_heads={num_heads} "
|
||||
f"num_kv_heads={num_kv_heads} head_dim={head_dim} hidden_size={hidden}",
|
||||
flush=True)
|
||||
|
||||
# ---- Static (parameter-only) readout ----
|
||||
# Per layer, per head h, compute the metric g^h = W_K^h^T W_Q^h (shape head_dim x head_dim)
|
||||
# and record its singular spectrum. Metric norm is our "static temperature" prediction.
|
||||
# With grouped-query attention, each query head shares a KV head; we compute metric per
|
||||
# query head using the shared KV head.
|
||||
static_stats = []
|
||||
for L, layer in enumerate(model.model.layers):
|
||||
attn = layer.self_attn
|
||||
W_Q = attn.q_proj.weight.detach().float().cpu() # (num_heads*head_dim, hidden)
|
||||
W_K = attn.k_proj.weight.detach().float().cpu() # (num_kv_heads*head_dim, hidden)
|
||||
|
||||
per_head_metric_fro = []
|
||||
per_head_metric_op = []
|
||||
per_head_metric_singvals = []
|
||||
for h in range(num_heads):
|
||||
kv_h = (h * num_kv_heads) // num_heads
|
||||
wq_h = W_Q[h * head_dim:(h + 1) * head_dim] # (head_dim, hidden)
|
||||
wk_h = W_K[kv_h * head_dim:(kv_h + 1) * head_dim] # (head_dim, hidden)
|
||||
# metric on hidden space: M = W_K^h^T W_Q^h shape (hidden, hidden).
|
||||
# But we only need its non-zero spectrum; equivalently SVD of wk_h^T @ wq_h,
|
||||
# or simpler: singular values of (wk_h @ wq_h.T) which is head_dim x head_dim.
|
||||
small = wk_h @ wq_h.T # (head_dim, head_dim)
|
||||
s = torch.linalg.svdvals(small) # (head_dim,)
|
||||
per_head_metric_fro.append(float(s.pow(2).sum().sqrt()))
|
||||
per_head_metric_op.append(float(s.max()))
|
||||
per_head_metric_singvals.append(s.tolist())
|
||||
static_stats.append({
|
||||
"layer": L,
|
||||
"metric_fro_per_head": per_head_metric_fro,
|
||||
"metric_op_per_head": per_head_metric_op,
|
||||
"metric_singvals_per_head": per_head_metric_singvals,
|
||||
})
|
||||
if L % 8 == 0:
|
||||
print(f" static layer {L}: mean op-norm over heads = "
|
||||
f"{sum(per_head_metric_op)/len(per_head_metric_op):.3f}",
|
||||
flush=True)
|
||||
|
||||
# ---- Dynamic (activation) readout ----
|
||||
# Hook each attention layer with output_attentions. Per layer, per head, accumulate
|
||||
# sum of attention entropy and sum of pre-softmax logit magnitude across the calibration set.
|
||||
acc_entropy = torch.zeros(num_layers, num_heads, dtype=torch.float64)
|
||||
acc_logit_mag = torch.zeros(num_layers, num_heads, dtype=torch.float64)
|
||||
acc_logit_var = torch.zeros(num_layers, num_heads, dtype=torch.float64)
|
||||
acc_n_positions = torch.zeros(num_layers, dtype=torch.float64)
|
||||
|
||||
# The simplest path: run with output_attentions=True; eager impl returns attn probs.
|
||||
# We cannot get pre-softmax logits from the HF API directly; extract them manually
|
||||
# via a forward-pre-hook that snapshots Q and K, compute Q@K^T / sqrt(head_dim), and
|
||||
# compare against attention_mask (we care about unmasked positions only).
|
||||
|
||||
captured = {}
|
||||
|
||||
def make_hook(layer_idx):
|
||||
def hook(module, inp, out):
|
||||
# eager attention returns (attn_output, attn_weights, past_key_value)
|
||||
# attn_weights has shape (bsz, num_heads, q_len, k_len)
|
||||
if isinstance(out, tuple) and len(out) >= 2 and out[1] is not None:
|
||||
captured[layer_idx] = out[1].detach()
|
||||
else:
|
||||
captured[layer_idx] = None
|
||||
return hook
|
||||
|
||||
hooks = []
|
||||
for L, layer in enumerate(model.model.layers):
|
||||
h = layer.self_attn.register_forward_hook(make_hook(L))
|
||||
hooks.append(h)
|
||||
|
||||
for i, prompt in enumerate(CALIBRATION_PROMPTS):
|
||||
inp = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_len).to("cuda")
|
||||
captured.clear()
|
||||
_ = model(**inp, output_attentions=True, use_cache=False)
|
||||
seq_len = inp["input_ids"].shape[1]
|
||||
|
||||
for L in range(num_layers):
|
||||
aw = captured.get(L, None)
|
||||
if aw is None:
|
||||
continue
|
||||
# aw: (1, num_heads, q_len, k_len), softmax over last dim with causal mask
|
||||
# entropy: -sum p log p over last dim. Positions with fewer valid keys have
|
||||
# naturally lower max entropy; we average over positions anyway.
|
||||
p = aw.float().squeeze(0) # (num_heads, q_len, k_len)
|
||||
eps = 1e-12
|
||||
ent = -(p * (p + eps).log()).sum(dim=-1) # (num_heads, q_len)
|
||||
acc_entropy[L] += ent.mean(dim=-1).cpu().double()
|
||||
|
||||
# Back out the logits. For causal softmax, logit_ij = log p_ij + c(i) for some
|
||||
# row constant c(i); we can recover up to row constant by log p (masking zeros).
|
||||
# To get a usable logit magnitude, we take the (unmasked) per-row std.
|
||||
logp = (p + eps).log() # (num_heads, q_len, k_len)
|
||||
# mask invalid keys (p==0 means masked)
|
||||
valid = (p > 0).float()
|
||||
denom = valid.sum(dim=-1).clamp_min(1)
|
||||
mean_logp = (logp * valid).sum(dim=-1) / denom
|
||||
centered = (logp - mean_logp.unsqueeze(-1)) * valid
|
||||
var_logp = (centered.pow(2).sum(dim=-1) / denom)
|
||||
# per-row std of logits is a direct readout of logit magnitude (== sharpness)
|
||||
row_std = var_logp.clamp_min(0).sqrt() # (num_heads, q_len)
|
||||
acc_logit_mag[L] += row_std.mean(dim=-1).cpu().double()
|
||||
acc_logit_var[L] += var_logp.mean(dim=-1).cpu().double()
|
||||
|
||||
acc_n_positions += 1 # once per prompt
|
||||
|
||||
if i % 5 == 0:
|
||||
print(f" prompt {i+1}/{len(CALIBRATION_PROMPTS)} len={seq_len}", flush=True)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
# Normalize by number of prompts (all contributed 1 sample per layer/head)
|
||||
n = max(len(CALIBRATION_PROMPTS), 1)
|
||||
mean_entropy = (acc_entropy / n).tolist()
|
||||
mean_logit_mag = (acc_logit_mag / n).tolist()
|
||||
mean_logit_var = (acc_logit_var / n).tolist()
|
||||
|
||||
# Assemble output
|
||||
dynamic_stats = []
|
||||
for L in range(num_layers):
|
||||
dynamic_stats.append({
|
||||
"layer": L,
|
||||
"mean_attention_entropy_per_head": mean_entropy[L],
|
||||
"mean_logit_std_per_head": mean_logit_mag[L],
|
||||
"mean_logit_var_per_head": mean_logit_var[L],
|
||||
"mean_attention_entropy": sum(mean_entropy[L]) / num_heads,
|
||||
"mean_logit_std": sum(mean_logit_mag[L]) / num_heads,
|
||||
})
|
||||
|
||||
output = {
|
||||
"model": model_name,
|
||||
"num_layers": num_layers,
|
||||
"num_heads": num_heads,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"hidden_size": hidden,
|
||||
"n_prompts": len(CALIBRATION_PROMPTS),
|
||||
"static": static_stats,
|
||||
"dynamic": dynamic_stats,
|
||||
}
|
||||
|
||||
with open(out_path, "w") as f:
|
||||
json.dump(output, f, indent=2)
|
||||
print(f"\nWrote {out_path}", flush=True)
|
||||
|
||||
# Quick summary to console
|
||||
print("\nPer-layer schedule readout (averaged over heads):")
|
||||
print(f" {'L':>3} {'mean_entropy':>14} {'mean_logit_std':>16} {'mean_metric_op':>16}")
|
||||
for L in range(num_layers):
|
||||
mean_op = sum(static_stats[L]["metric_op_per_head"]) / num_heads
|
||||
print(f" {L:>3} "
|
||||
f"{dynamic_stats[L]['mean_attention_entropy']:>14.4f} "
|
||||
f"{dynamic_stats[L]['mean_logit_std']:>16.4f} "
|
||||
f"{mean_op:>16.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default="Qwen/Qwen3-8B")
|
||||
parser.add_argument("--out", default="/tmp/sa-schedule-readout.json")
|
||||
parser.add_argument("--max-seq-len", type=int, default=256)
|
||||
args = parser.parse_args()
|
||||
measure_model(args.model, args.out, max_seq_len=args.max_seq_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
498
sa-schedule-topblock-swap.py
Normal file
498
sa-schedule-topblock-swap.py
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
"""Top-block replacement experiment: test SA-schedule hypothesis by
|
||||
replacing the last 8 layers of Qwen3-4B with variants that progressively
|
||||
strip out the learned schedule / specialization.
|
||||
|
||||
Variants:
|
||||
baseline — unmodified reference (PPL sanity check)
|
||||
schedule_fit — replace input_ln.γ magnitude in top block with
|
||||
fitted Kirkpatrick γ(L) = 3.53·exp(0.119·L). Directions
|
||||
preserved, projection weights untouched.
|
||||
single_op — use layer 35's projection weights for ALL top-block
|
||||
layers (strip specialization), combined with the fitted
|
||||
schedule γ(L). Tests if per-layer specialization in top
|
||||
block is load-bearing or replaceable by schedule.
|
||||
uniform_gamma — set all top-block input_ln.γ magnitudes to the middle
|
||||
layer's value (no schedule at all in top block). Tests
|
||||
necessity of schedule itself.
|
||||
|
||||
Eval: perplexity on a concatenation of calibration prompts + a short
|
||||
excerpt. Also generation quality on a handful of diagnostic prompts.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# From sa-schedule-fit-gamma.py on Qwen3-4B null-residual data:
|
||||
# input_ln.γ magnitude ≈ 3.53 · exp(0.119 · L), R² = 0.95
|
||||
# Defaults for 4B. Override via env SCHEDULE_A / SCHEDULE_B for other models.
|
||||
# 32B fit: a=1.02, b=0.0873
|
||||
SCHEDULE_A = float(os.environ.get("SCHEDULE_A", "3.53")) if "SCHEDULE_A" in os.environ else 3.53
|
||||
SCHEDULE_B = float(os.environ.get("SCHEDULE_B", "0.1191")) if "SCHEDULE_B" in os.environ else 0.1191
|
||||
|
||||
BLOCK_START = int(os.environ.get("BLOCK_START", 28))
|
||||
BLOCK_END = int(os.environ.get("BLOCK_END", 35))
|
||||
# Optional: comma-separated "s1-e1,s2-e2,..." blocks for multi-block merge
|
||||
BLOCKS_ENV = os.environ.get("BLOCKS", "")
|
||||
if BLOCKS_ENV:
|
||||
BLOCKS = [tuple(int(x) for x in p.split("-")) for p in BLOCKS_ENV.split(",")]
|
||||
else:
|
||||
BLOCKS = [(BLOCK_START, BLOCK_END)]
|
||||
|
||||
CALIB = [
|
||||
"The Eiffel Tower is located in",
|
||||
"Photosynthesis is the process by which",
|
||||
"The three branches of the US government are the legislative, executive, and",
|
||||
"If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
|
||||
"Solve for x: 3x + 7 = 22. The answer is x =",
|
||||
"The derivative of x^3 + 2x^2 is",
|
||||
"def fibonacci(n):\n if n < 2:\n return n\n return",
|
||||
"# Python list comprehension to square even numbers in 0-9\nresult = ",
|
||||
"SELECT name, age FROM users WHERE",
|
||||
"She opened the old wooden box and found",
|
||||
"The argument in favor of renewable energy is",
|
||||
"User: What is the capital of Australia?\nAssistant:",
|
||||
"Write a haiku about autumn:\n",
|
||||
"Albert Einstein was born in the year",
|
||||
"The speed of light in vacuum is approximately",
|
||||
"I really loved that movie because",
|
||||
"The main difference between a virus and a bacterium is",
|
||||
"The French word for 'apple' is",
|
||||
"1 + 1 = ",
|
||||
"Once upon a time, in a land far away,",
|
||||
"The key insight of general relativity is that gravity is not a force but",
|
||||
"Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher",
|
||||
"In object-oriented programming, encapsulation refers to",
|
||||
"The mitochondria is often called the powerhouse of the cell because it",
|
||||
"Shakespeare's Hamlet begins with the famous line",
|
||||
]
|
||||
|
||||
GEN_PROMPTS = [
|
||||
"The capital of France is",
|
||||
"2 + 2 =",
|
||||
"def reverse_string(s):\n return",
|
||||
"Albert Einstein developed the theory of",
|
||||
]
|
||||
|
||||
|
||||
def load_model(name=None):
|
||||
if name is None:
|
||||
name = os.environ.get("MODEL", "Qwen/Qwen3-4B")
|
||||
print(f"Loading {name}...", flush=True)
|
||||
tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
name, torch_dtype=torch.bfloat16, device_map="cuda",
|
||||
trust_remote_code=True, attn_implementation="eager",
|
||||
)
|
||||
m.eval()
|
||||
return m, tok
|
||||
|
||||
|
||||
def _merge_block(model, block_start, block_end):
|
||||
"""Arithmetic-mean merge projections in [block_start, block_end]; set γ per schedule."""
|
||||
layers = [model.model.layers[L] for L in range(block_start, block_end + 1)]
|
||||
param_names = [
|
||||
("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight),
|
||||
("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight),
|
||||
("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight),
|
||||
("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight),
|
||||
("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight),
|
||||
("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight),
|
||||
("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight),
|
||||
]
|
||||
merged = {}
|
||||
for name, getter in param_names:
|
||||
stack = torch.stack([getter(l).data.float() for l in layers], dim=0)
|
||||
merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype)
|
||||
for l in layers:
|
||||
l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"])
|
||||
l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"])
|
||||
l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"])
|
||||
l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"])
|
||||
l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"])
|
||||
l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"])
|
||||
l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"])
|
||||
for L in range(block_start, block_end + 1):
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
gamma = model.model.layers[L].input_layernorm.weight.data
|
||||
gamma.mul_(predicted / gamma.norm().item())
|
||||
|
||||
|
||||
def _procrustes(M):
|
||||
"""Orthogonal R = U V^T maximizing tr(R M) where M = U Σ V^T."""
|
||||
U, _, Vh = torch.linalg.svd(M.float(), full_matrices=False)
|
||||
return U @ Vh
|
||||
|
||||
|
||||
def _aligned_merge_block(model, block_start, block_end, align_ff=False):
|
||||
"""Procrustes-align per-head d_h basis (and optionally d_ff) of each
|
||||
layer in [block_start, block_end] to a reference (middle), then
|
||||
arithmetic-mean. Attention rotation is a true gauge; FF rotation is
|
||||
not (SiLU breaks it) — align_ff defaults off."""
|
||||
cfg = model.config
|
||||
num_heads = cfg.num_attention_heads
|
||||
num_kv = getattr(cfg, "num_key_value_heads", num_heads)
|
||||
hidden = cfg.hidden_size
|
||||
d_h = getattr(cfg, "head_dim", hidden // num_heads)
|
||||
|
||||
ref_L = (block_start + block_end) // 2
|
||||
ref = model.model.layers[ref_L]
|
||||
dev = ref.self_attn.q_proj.weight.device
|
||||
dtype = ref.self_attn.q_proj.weight.dtype
|
||||
|
||||
# Reference views, fp32 on device
|
||||
Qr = ref.self_attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden)
|
||||
Kr = ref.self_attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden)
|
||||
Vr = ref.self_attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden)
|
||||
Or = ref.self_attn.o_proj.weight.data.float().reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous()
|
||||
|
||||
if align_ff:
|
||||
d_ff = cfg.intermediate_size
|
||||
Gr = ref.mlp.gate_proj.weight.data.float()
|
||||
Ur = ref.mlp.up_proj.weight.data.float()
|
||||
Dr = ref.mlp.down_proj.weight.data.float()
|
||||
|
||||
rotated = []
|
||||
for L in range(block_start, block_end + 1):
|
||||
layer = model.model.layers[L]
|
||||
Q = layer.self_attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden)
|
||||
K = layer.self_attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden)
|
||||
V = layer.self_attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden)
|
||||
O = layer.self_attn.o_proj.weight.data.float().reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous()
|
||||
|
||||
if L == ref_L:
|
||||
Q_new, K_new, V_new, O_new = Q.clone(), K.clone(), V.clone(), O.clone()
|
||||
else:
|
||||
Q_new = torch.empty_like(Q)
|
||||
K_new = torch.empty_like(K)
|
||||
V_new = torch.empty_like(V)
|
||||
O_new = torch.empty_like(O)
|
||||
for h in range(num_heads):
|
||||
kv_h = (h * num_kv) // num_heads
|
||||
# Cross-correlation: want R s.t. R @ Q ≈ Qr (row-space align).
|
||||
# For per-head (d_h, hidden): M = Qr @ Q.T + Kr @ K.T + Vr @ V.T + Or^T @ O
|
||||
# (Or, O are (hidden, d_h) per head)
|
||||
M = (Qr[h] @ Q[h].T
|
||||
+ Kr[kv_h] @ K[kv_h].T
|
||||
+ Vr[kv_h] @ V[kv_h].T
|
||||
+ Or[h].T @ O[h])
|
||||
R = _procrustes(M)
|
||||
Q_new[h] = R @ Q[h]
|
||||
K_new[kv_h] = R @ K[kv_h]
|
||||
V_new[kv_h] = R @ V[kv_h]
|
||||
O_new[h] = O[h] @ R.T
|
||||
|
||||
rotated.append({
|
||||
"q": Q_new.reshape(num_heads * d_h, hidden),
|
||||
"k": K_new.reshape(num_kv * d_h, hidden),
|
||||
"v": V_new.reshape(num_kv * d_h, hidden),
|
||||
"o": O_new.permute(1, 0, 2).reshape(hidden, num_heads * d_h),
|
||||
})
|
||||
|
||||
# Average rotated attention
|
||||
q_avg = torch.stack([r["q"] for r in rotated]).mean(0).to(dtype)
|
||||
k_avg = torch.stack([r["k"] for r in rotated]).mean(0).to(dtype)
|
||||
v_avg = torch.stack([r["v"] for r in rotated]).mean(0).to(dtype)
|
||||
o_avg = torch.stack([r["o"] for r in rotated]).mean(0).to(dtype)
|
||||
|
||||
# FF: naive mean (rotation gauge is fake through SiLU)
|
||||
layers = [model.model.layers[L] for L in range(block_start, block_end + 1)]
|
||||
gate_avg = torch.stack([l.mlp.gate_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
|
||||
up_avg = torch.stack([l.mlp.up_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
|
||||
down_avg = torch.stack([l.mlp.down_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
|
||||
|
||||
# q_norm/k_norm γ: copy from reference (they're basis-dependent; no clean average in rotated frame)
|
||||
ref_qn = ref.self_attn.q_norm.weight.data.clone() if getattr(ref.self_attn, "q_norm", None) is not None else None
|
||||
ref_kn = ref.self_attn.k_norm.weight.data.clone() if getattr(ref.self_attn, "k_norm", None) is not None else None
|
||||
|
||||
for l in layers:
|
||||
l.self_attn.q_proj.weight.data.copy_(q_avg)
|
||||
l.self_attn.k_proj.weight.data.copy_(k_avg)
|
||||
l.self_attn.v_proj.weight.data.copy_(v_avg)
|
||||
l.self_attn.o_proj.weight.data.copy_(o_avg)
|
||||
l.mlp.gate_proj.weight.data.copy_(gate_avg)
|
||||
l.mlp.up_proj.weight.data.copy_(up_avg)
|
||||
l.mlp.down_proj.weight.data.copy_(down_avg)
|
||||
if ref_qn is not None and getattr(l.self_attn, "q_norm", None) is not None:
|
||||
l.self_attn.q_norm.weight.data.copy_(ref_qn)
|
||||
if ref_kn is not None and getattr(l.self_attn, "k_norm", None) is not None:
|
||||
l.self_attn.k_norm.weight.data.copy_(ref_kn)
|
||||
|
||||
for L in range(block_start, block_end + 1):
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
gamma = model.model.layers[L].input_layernorm.weight.data
|
||||
gamma.mul_(predicted / gamma.norm().item())
|
||||
|
||||
|
||||
def apply_variant(model, variant):
|
||||
"""Modify model in place according to variant."""
|
||||
if variant == "baseline":
|
||||
return
|
||||
|
||||
if variant == "schedule_fit":
|
||||
for L in range(BLOCK_START, BLOCK_END + 1):
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
layer = model.model.layers[L]
|
||||
gamma = layer.input_layernorm.weight.data
|
||||
cur_norm = gamma.norm().item()
|
||||
# Preserve direction, scale to predicted magnitude
|
||||
gamma.mul_(predicted / cur_norm)
|
||||
|
||||
elif variant == "single_op":
|
||||
# Use middle-of-block as reference, not end (more representative)
|
||||
ref_L = (BLOCK_START + BLOCK_END) // 2
|
||||
ref = model.model.layers[ref_L]
|
||||
for L in range(BLOCK_START, BLOCK_END + 1):
|
||||
if L == ref_L:
|
||||
continue
|
||||
tgt = model.model.layers[L]
|
||||
tgt.self_attn.q_proj.weight.data.copy_(ref.self_attn.q_proj.weight.data)
|
||||
tgt.self_attn.k_proj.weight.data.copy_(ref.self_attn.k_proj.weight.data)
|
||||
tgt.self_attn.v_proj.weight.data.copy_(ref.self_attn.v_proj.weight.data)
|
||||
tgt.self_attn.o_proj.weight.data.copy_(ref.self_attn.o_proj.weight.data)
|
||||
tgt.mlp.gate_proj.weight.data.copy_(ref.mlp.gate_proj.weight.data)
|
||||
tgt.mlp.up_proj.weight.data.copy_(ref.mlp.up_proj.weight.data)
|
||||
tgt.mlp.down_proj.weight.data.copy_(ref.mlp.down_proj.weight.data)
|
||||
# q_norm, k_norm: copy too
|
||||
if hasattr(tgt.self_attn, "q_norm") and tgt.self_attn.q_norm is not None:
|
||||
tgt.self_attn.q_norm.weight.data.copy_(ref.self_attn.q_norm.weight.data)
|
||||
if hasattr(tgt.self_attn, "k_norm") and tgt.self_attn.k_norm is not None:
|
||||
tgt.self_attn.k_norm.weight.data.copy_(ref.self_attn.k_norm.weight.data)
|
||||
# Keep each layer's OWN input_ln.γ direction but set magnitude to schedule
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
gamma = tgt.input_layernorm.weight.data
|
||||
gamma.mul_(predicted / gamma.norm().item())
|
||||
# post_attn_ln γ: leave as-is for now (could also fit & set)
|
||||
|
||||
elif variant == "ties_op":
|
||||
# TIES-Merging (Yadav et al. 2023): trim, elect-sign, disjoint merge.
|
||||
# Operates per parameter family across the N block layers.
|
||||
density = float(os.environ.get("TIES_DENSITY", "0.2"))
|
||||
layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)]
|
||||
param_names = [
|
||||
("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight),
|
||||
("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight),
|
||||
("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight),
|
||||
("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight),
|
||||
("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight),
|
||||
("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight),
|
||||
("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight),
|
||||
]
|
||||
|
||||
def ties_merge(tensors, density):
|
||||
# tensors: list of (out, in) float tensors, same shape
|
||||
stack = torch.stack([t.float() for t in tensors], dim=0) # (N, out, in)
|
||||
# --- Step 1: Trim to top-density fraction per tensor ---
|
||||
n = stack.shape[0]
|
||||
flat = stack.view(n, -1)
|
||||
k = int(flat.shape[1] * density)
|
||||
abs_flat = flat.abs()
|
||||
# Find magnitude threshold per tensor at top-k
|
||||
topk_vals, _ = abs_flat.topk(k=k, dim=1)
|
||||
threshold = topk_vals[:, -1:].expand_as(abs_flat)
|
||||
mask = abs_flat >= threshold
|
||||
trimmed = (flat * mask.float()).view_as(stack)
|
||||
# --- Step 2: Elect sign (majority by total magnitude) ---
|
||||
mag_per_sign = trimmed.sum(dim=0) # (out, in), signed sum
|
||||
elected = torch.sign(mag_per_sign) # +1/-1/0
|
||||
# --- Step 3: Disjoint merge (average params agreeing with elected sign) ---
|
||||
agree = (torch.sign(trimmed) == elected.unsqueeze(0)).float()
|
||||
contributing_count = agree.sum(dim=0).clamp_min(1)
|
||||
merged_sum = (trimmed * agree).sum(dim=0)
|
||||
merged = merged_sum / contributing_count
|
||||
return merged
|
||||
|
||||
merged = {}
|
||||
for name, getter in param_names:
|
||||
tensors = [getter(l).data for l in layers]
|
||||
merged[name] = ties_merge(tensors, density).to(getter(layers[0]).data.dtype)
|
||||
|
||||
for l in layers:
|
||||
l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"])
|
||||
l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"])
|
||||
l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"])
|
||||
l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"])
|
||||
l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"])
|
||||
l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"])
|
||||
l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"])
|
||||
|
||||
for L in range(BLOCK_START, BLOCK_END + 1):
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
gamma = model.model.layers[L].input_layernorm.weight.data
|
||||
gamma.mul_(predicted / gamma.norm().item())
|
||||
|
||||
elif variant == "merged_op":
|
||||
# Arithmetic mean, for each block in BLOCKS (can be multiple)
|
||||
for (bs, be) in BLOCKS:
|
||||
_merge_block(model, bs, be)
|
||||
return
|
||||
|
||||
elif variant == "aligned_merged_op":
|
||||
# Procrustes-align per-head d_h basis to block-middle, then mean.
|
||||
# FF averaged naively (SiLU breaks rotation gauge for FF).
|
||||
for (bs, be) in BLOCKS:
|
||||
_aligned_merge_block(model, bs, be, align_ff=False)
|
||||
return
|
||||
|
||||
elif variant == "flat_merged_op":
|
||||
# Mean projections AND flatten γ across block. Everything in block
|
||||
# becomes N copies of the same operator. If block is truly high-T
|
||||
# diffusion, PPL should match merged_op (schedule is gauge, not
|
||||
# load-bearing). If schedule helps, flattening γ will hurt.
|
||||
for (bs, be) in BLOCKS:
|
||||
layers = [model.model.layers[L] for L in range(bs, be + 1)]
|
||||
param_names = [
|
||||
("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight),
|
||||
("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight),
|
||||
("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight),
|
||||
("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight),
|
||||
("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight),
|
||||
("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight),
|
||||
("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight),
|
||||
]
|
||||
merged = {}
|
||||
for name, getter in param_names:
|
||||
stack = torch.stack([getter(l).data.float() for l in layers], dim=0)
|
||||
merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype)
|
||||
gamma_mean = torch.stack([l.input_layernorm.weight.data.float()
|
||||
for l in layers]).mean(0).to(layers[0].input_layernorm.weight.data.dtype)
|
||||
post_attn_mean = torch.stack([l.post_attention_layernorm.weight.data.float()
|
||||
for l in layers]).mean(0).to(layers[0].post_attention_layernorm.weight.data.dtype)
|
||||
for l in layers:
|
||||
l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"])
|
||||
l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"])
|
||||
l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"])
|
||||
l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"])
|
||||
l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"])
|
||||
l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"])
|
||||
l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"])
|
||||
l.input_layernorm.weight.data.copy_(gamma_mean)
|
||||
l.post_attention_layernorm.weight.data.copy_(post_attn_mean)
|
||||
return
|
||||
|
||||
elif variant == "reverse_order":
|
||||
# Reverse the order of layers within each block to test whether
|
||||
# the block implements a trajectory (order-dependent) or iid
|
||||
# diffusion (order-free).
|
||||
import torch.nn as nn
|
||||
layers_list = list(model.model.layers)
|
||||
for (bs, be) in BLOCKS:
|
||||
rev = layers_list[bs:be + 1][::-1]
|
||||
layers_list[bs:be + 1] = rev
|
||||
model.model.layers = nn.ModuleList(layers_list)
|
||||
# Re-set layer_idx on each layer so attention/cache uses the
|
||||
# current position, not the original one.
|
||||
for i, l in enumerate(model.model.layers):
|
||||
if hasattr(l, "self_attn") and hasattr(l.self_attn, "layer_idx"):
|
||||
l.self_attn.layer_idx = i
|
||||
return
|
||||
|
||||
elif variant == "merged_op_OLD_UNREACHABLE":
|
||||
layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)]
|
||||
n = len(layers)
|
||||
param_names = [
|
||||
("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight),
|
||||
("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight),
|
||||
("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight),
|
||||
("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight),
|
||||
("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight),
|
||||
("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight),
|
||||
("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight),
|
||||
]
|
||||
merged = {}
|
||||
for name, getter in param_names:
|
||||
stack = torch.stack([getter(l).data.float() for l in layers], dim=0)
|
||||
merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype)
|
||||
|
||||
for l in layers:
|
||||
l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"])
|
||||
l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"])
|
||||
l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"])
|
||||
l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"])
|
||||
l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"])
|
||||
l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"])
|
||||
l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"])
|
||||
|
||||
# Set γ to scheduled values per layer
|
||||
for L in range(BLOCK_START, BLOCK_END + 1):
|
||||
predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
|
||||
gamma = model.model.layers[L].input_layernorm.weight.data
|
||||
gamma.mul_(predicted / gamma.norm().item())
|
||||
|
||||
elif variant == "uniform_gamma":
|
||||
mid_L = (BLOCK_START + BLOCK_END) // 2
|
||||
mid_gamma = model.model.layers[mid_L].input_layernorm.weight.data.clone()
|
||||
for L in range(BLOCK_START, BLOCK_END + 1):
|
||||
model.model.layers[L].input_layernorm.weight.data.copy_(mid_gamma)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown variant {variant}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def perplexity(model, tok, texts, max_len=512):
|
||||
total_nll = 0.0
|
||||
total_tok = 0
|
||||
for text in texts:
|
||||
enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len).to("cuda")
|
||||
if enc.input_ids.shape[1] < 2:
|
||||
continue
|
||||
out = model(**enc, labels=enc.input_ids)
|
||||
n = enc.input_ids.shape[1] - 1
|
||||
total_nll += float(out.loss.item()) * n
|
||||
total_tok += n
|
||||
return math.exp(total_nll / max(total_tok, 1))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_sample(model, tok, prompt, max_new=40):
|
||||
enc = tok(prompt, return_tensors="pt").to("cuda")
|
||||
out = model.generate(**enc, max_new_tokens=max_new, do_sample=False,
|
||||
pad_token_id=tok.eos_token_id)
|
||||
return tok.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
def run_variant(variant):
|
||||
model, tok = load_model()
|
||||
apply_variant(model, variant)
|
||||
print(f"\n=== variant: {variant} ===", flush=True)
|
||||
ppl = perplexity(model, tok, CALIB)
|
||||
print(f" perplexity: {ppl:.3f}", flush=True)
|
||||
for p in GEN_PROMPTS:
|
||||
out = generate_sample(model, tok, p)
|
||||
print(f" [{p!r}] -> {out[:200]!r}", flush=True)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
return ppl
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--variant", default="all",
|
||||
choices=["all", "baseline", "schedule_fit",
|
||||
"single_op", "uniform_gamma", "merged_op",
|
||||
"aligned_merged_op", "flat_merged_op",
|
||||
"reverse_order", "ties_op"])
|
||||
ap.add_argument("--ties-density", type=float, default=0.2,
|
||||
help="TIES trim density (fraction of top-magnitude params to keep)")
|
||||
args = ap.parse_args()
|
||||
|
||||
variants = (["baseline", "schedule_fit", "single_op", "uniform_gamma"]
|
||||
if args.variant == "all" else [args.variant])
|
||||
results = {}
|
||||
for v in variants:
|
||||
results[v] = run_variant(v)
|
||||
|
||||
if len(results) > 1:
|
||||
print("\n=== Summary ===")
|
||||
b = results.get("baseline", None)
|
||||
for v, ppl in results.items():
|
||||
rel = f" (×{ppl/b:.2f} baseline)" if b else ""
|
||||
print(f" {v:<15} PPL {ppl:>8.3f}{rel}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -59,7 +59,7 @@ const ACTIVITY_LINGER: std::time::Duration = std::time::Duration::from_secs(5);
|
|||
|
||||
impl Drop for ActivityGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut st) = self.agent.state.try_lock() {
|
||||
{ let mut st = self.agent.state.lock_blocking();
|
||||
if let Some(entry) = st.activities.iter_mut().find(|a| a.id == self.id) {
|
||||
entry.label.push_str(" (complete)");
|
||||
entry.expires_at = std::time::Instant::now() + ACTIVITY_LINGER;
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ async fn ensure_init(agent: Option<&std::sync::Arc<super::super::Agent>>) -> Res
|
|||
let msg = format!("MCP server {} failed: {:#}", cfg.name, e);
|
||||
dbglog!("{}", msg);
|
||||
if let Some(a) = agent {
|
||||
if let Ok(mut st) = a.state.try_lock() {
|
||||
{ let mut st = a.state.lock_blocking();
|
||||
st.notify(msg);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
17
src/locks.rs
17
src/locks.rs
|
|
@ -135,6 +135,23 @@ impl<T> TrackedMutex<T> {
|
|||
location,
|
||||
})
|
||||
}
|
||||
|
||||
/// Block the current thread until the lock is acquired.
|
||||
/// Safe to call from sync contexts (UI thread, slash commands) where
|
||||
/// .await isn't available. Uses block_in_place so the tokio runtime
|
||||
/// can schedule other tasks while we wait.
|
||||
#[track_caller]
|
||||
pub fn lock_blocking(&self) -> TrackedMutexGuard<'_, T> {
|
||||
let location = Location::caller();
|
||||
let guard = tokio::task::block_in_place(|| {
|
||||
futures::executor::block_on(self.inner.lock())
|
||||
});
|
||||
TrackedMutexGuard {
|
||||
guard,
|
||||
acquired_at: Instant::now(),
|
||||
location,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TrackedMutexGuard<'a, T> {
|
||||
|
|
|
|||
|
|
@ -104,6 +104,6 @@ async fn run(
|
|||
prior_context: render_prior_context(entries, entry_idx, 2),
|
||||
timestamp_ns: node_timestamp_ns(node),
|
||||
});
|
||||
if let Ok(st) = agent.state.try_lock() { st.changed.notify_one(); }
|
||||
{ let st = agent.state.lock_blocking(); st.changed.notify_one(); }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -736,7 +736,7 @@ async fn run_finetune(
|
|||
gen_alternates, &activity,
|
||||
move |c| {
|
||||
shared.lock().unwrap().finetune_candidates.push(c);
|
||||
if let Ok(st) = agent.state.try_lock() { st.changed.notify_one(); }
|
||||
{ let st = agent.state.lock_blocking(); st.changed.notify_one(); }
|
||||
},
|
||||
).await {
|
||||
Ok((above_threshold, max_div)) => FinetuneScoringStats {
|
||||
|
|
|
|||
|
|
@ -34,12 +34,12 @@ fn commands() -> Vec<SlashCommand> { vec![
|
|||
handler: |s, _| { let _ = s.mind_tx.send(MindCommand::NewSession); } },
|
||||
SlashCommand { name: "/save", help: "Save session to disk",
|
||||
handler: |s, _| {
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() { ag.notify("saved"); }
|
||||
{ let mut ag = s.agent.state.lock_blocking(); ag.notify("saved"); }
|
||||
} },
|
||||
SlashCommand { name: "/model", help: "Show/switch model (/model <name>)",
|
||||
handler: |s, arg| {
|
||||
if arg.is_empty() {
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() {
|
||||
{ let mut ag = s.agent.state.lock_blocking();
|
||||
let names = s.agent.app_config.model_names();
|
||||
let label = if names.is_empty() {
|
||||
format!("model: {}", s.agent.model())
|
||||
|
|
@ -62,7 +62,7 @@ fn commands() -> Vec<SlashCommand> { vec![
|
|||
SlashCommand { name: "/dmn", help: "Show DMN state",
|
||||
handler: |s, _| {
|
||||
let st = s.shared_mind.lock().unwrap();
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() {
|
||||
{ let mut ag = s.agent.state.lock_blocking();
|
||||
ag.notify(format!("DMN: {:?} ({}/{})", st.dmn, st.dmn_turns, st.max_dmn_turns));
|
||||
}
|
||||
} },
|
||||
|
|
@ -71,7 +71,7 @@ fn commands() -> Vec<SlashCommand> { vec![
|
|||
let mut st = s.shared_mind.lock().unwrap();
|
||||
st.dmn = crate::mind::subconscious::State::Resting { since: std::time::Instant::now() };
|
||||
st.dmn_turns = 0;
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() { ag.notify("DMN sleeping"); }
|
||||
{ let mut ag = s.agent.state.lock_blocking(); ag.notify("DMN sleeping"); }
|
||||
} },
|
||||
SlashCommand { name: "/wake", help: "Wake DMN to foraging",
|
||||
handler: |s, _| {
|
||||
|
|
@ -79,14 +79,14 @@ fn commands() -> Vec<SlashCommand> { vec![
|
|||
if matches!(st.dmn, crate::mind::subconscious::State::Off) { crate::mind::subconscious::set_off(false); }
|
||||
st.dmn = crate::mind::subconscious::State::Foraging;
|
||||
st.dmn_turns = 0;
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() { ag.notify("DMN foraging"); }
|
||||
{ let mut ag = s.agent.state.lock_blocking(); ag.notify("DMN foraging"); }
|
||||
} },
|
||||
SlashCommand { name: "/pause", help: "Full stop — no autonomous ticks (Ctrl+P)",
|
||||
handler: |s, _| {
|
||||
let mut st = s.shared_mind.lock().unwrap();
|
||||
st.dmn = crate::mind::subconscious::State::Paused;
|
||||
st.dmn_turns = 0;
|
||||
if let Ok(mut ag) = s.agent.state.try_lock() { ag.notify("DMN paused"); }
|
||||
{ let mut ag = s.agent.state.lock_blocking(); ag.notify("DMN paused"); }
|
||||
} },
|
||||
SlashCommand { name: "/help", help: "Show this help",
|
||||
handler: |s, _| { notify_help(&s.agent); } },
|
||||
|
|
@ -116,7 +116,7 @@ pub async fn cmd_switch_model(
|
|||
}
|
||||
|
||||
fn notify_help(agent: &std::sync::Arc<crate::agent::Agent>) {
|
||||
if let Ok(mut ag) = agent.state.try_lock() {
|
||||
{ let mut ag = agent.state.lock_blocking();
|
||||
let mut help = String::new();
|
||||
for cmd in &commands() {
|
||||
help.push_str(&format!("{:12} {}\n", cmd.name, cmd.help));
|
||||
|
|
@ -581,16 +581,10 @@ impl InteractScreen {
|
|||
self.pending_display_count = 0;
|
||||
|
||||
let (generation, entries) = {
|
||||
let st = match self.agent.state.try_lock() {
|
||||
Ok(st) => st,
|
||||
Err(_) => return,
|
||||
};
|
||||
let st = self.agent.state.lock_blocking();
|
||||
let generation = st.generation;
|
||||
drop(st);
|
||||
let ctx = match self.agent.context.try_lock() {
|
||||
Ok(ctx) => ctx,
|
||||
Err(_) => return,
|
||||
};
|
||||
let ctx = self.agent.context.lock_blocking();
|
||||
(generation, ctx.conversation().to_vec())
|
||||
};
|
||||
|
||||
|
|
@ -654,7 +648,7 @@ impl InteractScreen {
|
|||
if let Some(cmd) = dispatch_command(input) {
|
||||
(cmd.handler)(self, &input[cmd.name.len()..].trim_start());
|
||||
} else {
|
||||
if let Ok(mut ag) = self.agent.state.try_lock() {
|
||||
{ let mut ag = self.agent.state.lock_blocking();
|
||||
ag.notify(format!("unknown: {}", input.split_whitespace().next().unwrap_or(input)));
|
||||
}
|
||||
}
|
||||
|
|
@ -770,9 +764,8 @@ impl InteractScreen {
|
|||
/// Draw the main (F1) screen — four-pane layout with status bar.
|
||||
fn draw_main(&mut self, frame: &mut Frame, size: Rect, app: &App) {
|
||||
// Main layout: content area + active tools overlay + status bar
|
||||
let st_guard = app.agent.state.try_lock().ok();
|
||||
let tool_lines = st_guard.as_ref()
|
||||
.map(|st| st.active_tools.len() as u16).unwrap_or(0);
|
||||
let st_guard = app.agent.state.lock_blocking();
|
||||
let tool_lines = st_guard.active_tools.len() as u16;
|
||||
let main_chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
|
|
@ -861,10 +854,9 @@ impl InteractScreen {
|
|||
frame.render_widget(gutter, input_chunks[0]);
|
||||
frame.render_widget(&self.textarea, input_chunks[1]);
|
||||
|
||||
if let Some(ref st) = st_guard {
|
||||
if !st.active_tools.is_empty() {
|
||||
if !st_guard.active_tools.is_empty() {
|
||||
let tool_style = Style::default().fg(Color::Yellow).add_modifier(Modifier::DIM);
|
||||
let tool_text: Vec<Line> = st.active_tools.iter().map(|t| {
|
||||
let tool_text: Vec<Line> = st_guard.active_tools.iter().map(|t| {
|
||||
let elapsed = t.started.elapsed().as_secs();
|
||||
let line = if t.detail.is_empty() {
|
||||
format!(" [{}] ({}s)", t.name, elapsed)
|
||||
|
|
@ -875,7 +867,7 @@ impl InteractScreen {
|
|||
}).collect();
|
||||
let tool_para = Paragraph::new(tool_text);
|
||||
frame.render_widget(tool_para, tools_overlay_area);
|
||||
}}
|
||||
}
|
||||
|
||||
// Draw status bar with live activity indicator
|
||||
let timer = if !app.activity.is_empty() {
|
||||
|
|
@ -1026,7 +1018,7 @@ impl ScreenView for InteractScreen {
|
|||
self.sync_from_agent();
|
||||
|
||||
// Read status from agent + mind state
|
||||
if let Ok(mut st) = self.agent.state.try_lock() {
|
||||
{ let mut st = self.agent.state.lock_blocking();
|
||||
st.expire_activities();
|
||||
app.status.prompt_tokens = st.last_prompt_tokens;
|
||||
app.status.model = self.agent.model().to_string();
|
||||
|
|
@ -1036,7 +1028,7 @@ impl ScreenView for InteractScreen {
|
|||
app.activity_started = st.activities.last()
|
||||
.map(|a| a.started);
|
||||
}
|
||||
if let Ok(ctx) = self.agent.context.try_lock() {
|
||||
{ let ctx = self.agent.context.lock_blocking();
|
||||
let window = crate::agent::context::context_window();
|
||||
if window > 0 {
|
||||
let sys = ctx.system().iter().map(|n| n.tokens()).sum::<usize>();
|
||||
|
|
|
|||
|
|
@ -20,10 +20,7 @@ impl ConsciousScreen {
|
|||
}
|
||||
|
||||
fn read_context_views(&self) -> Vec<SectionView> {
|
||||
let ctx = match self.agent.context.try_lock() {
|
||||
Ok(ctx) => ctx,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
let ctx = self.agent.context.lock_blocking();
|
||||
|
||||
let mut views: Vec<SectionView> = Vec::new();
|
||||
|
||||
|
|
@ -161,8 +158,7 @@ impl ScreenView for ConsciousScreen {
|
|||
)));
|
||||
lines.push(Line::raw(format!(" Reasoning: {}", app.reasoning_effort)));
|
||||
lines.push(Line::raw(format!(" Running processes: {}", app.running_processes)));
|
||||
let tool_count = app.agent.state.try_lock()
|
||||
.map(|st| st.active_tools.len()).unwrap_or(0);
|
||||
let tool_count = { let st = app.agent.state.lock_blocking(); st.active_tools.len() };
|
||||
lines.push(Line::raw(format!(" Active tools: {}", tool_count)));
|
||||
|
||||
let block = pane_block("context")
|
||||
|
|
|
|||
|
|
@ -292,7 +292,7 @@ async fn start(cli: crate::user::CliArgs) -> Result<()> {
|
|||
}
|
||||
|
||||
fn hotkey_cycle_reasoning(mind: &crate::mind::Mind) {
|
||||
if let Ok(mut ag) = mind.agent.state.try_lock() {
|
||||
{ let mut ag = mind.agent.state.lock_blocking();
|
||||
let next = match ag.reasoning_effort.as_str() {
|
||||
"none" => "low",
|
||||
"low" => "high",
|
||||
|
|
@ -344,7 +344,7 @@ fn hotkey_cycle_autonomy(mind: &crate::mind::Mind) {
|
|||
};
|
||||
s.dmn_turns = 0;
|
||||
drop(s);
|
||||
if let Ok(mut ag) = mind.agent.state.try_lock() {
|
||||
{ let mut ag = mind.agent.state.lock_blocking();
|
||||
ag.notify(format!("DMN → {}", label));
|
||||
}
|
||||
}
|
||||
|
|
@ -419,7 +419,7 @@ async fn run(
|
|||
|
||||
terminal.hide_cursor()?;
|
||||
|
||||
if let Ok(mut ag) = agent.state.try_lock() { ag.notify("consciousness v0.3"); }
|
||||
{ let mut ag = agent.state.lock_blocking(); ag.notify("consciousness v0.3"); }
|
||||
|
||||
// Initial render
|
||||
{
|
||||
|
|
@ -526,7 +526,7 @@ async fn run(
|
|||
}
|
||||
app.walked_count = mind.subconscious_walked().await.len();
|
||||
if !startup_done {
|
||||
if let Ok(mut ag) = agent.state.try_lock() {
|
||||
{ let mut ag = agent.state.lock_blocking();
|
||||
let model = agent.model().to_string();
|
||||
ag.notify(format!("model: {}", model));
|
||||
startup_done = true;
|
||||
|
|
@ -545,7 +545,7 @@ async fn run(
|
|||
if let Some(rx_mutex) = STDERR_RX.get() {
|
||||
if let Ok(rx) = rx_mutex.try_lock() {
|
||||
while let Ok(line) = rx.try_recv() {
|
||||
if let Ok(mut ag) = agent.state.try_lock() {
|
||||
{ let mut ag = agent.state.lock_blocking();
|
||||
ag.notify(format!("stderr: {}", line));
|
||||
dirty = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -222,31 +222,30 @@ impl SubconsciousScreen {
|
|||
let fork_point = app.agent_state.get(self.selected())
|
||||
.map(|s| s.fork_point).unwrap_or(0);
|
||||
|
||||
agent.context.try_lock().ok()
|
||||
.map(|ctx| {
|
||||
let mut views = Vec::new();
|
||||
views.push(section_to_view("System", ctx.system()));
|
||||
views.push(section_to_view("Identity", ctx.identity()));
|
||||
views.push(section_to_view("Journal", ctx.journal()));
|
||||
{
|
||||
let ctx = agent.context.lock_blocking();
|
||||
let mut views = Vec::new();
|
||||
views.push(section_to_view("System", ctx.system()));
|
||||
views.push(section_to_view("Identity", ctx.identity()));
|
||||
views.push(section_to_view("Journal", ctx.journal()));
|
||||
|
||||
// Conversation: skip to fork point for subconscious agents
|
||||
let conv = ctx.conversation();
|
||||
let conv_view = section_to_view("Conversation", conv);
|
||||
let fork = fork_point.min(conv_view.children.len());
|
||||
let conv_children: Vec<SectionView> = conv_view.children
|
||||
.into_iter().skip(fork).collect();
|
||||
views.push(SectionView {
|
||||
name: format!("Conversation ({} entries)", conv_children.len()),
|
||||
tokens: conv_children.iter().map(|c| c.tokens).sum(),
|
||||
content: String::new(),
|
||||
token_ids: Vec::new(),
|
||||
children: conv_children,
|
||||
status: String::new(),
|
||||
});
|
||||
// Conversation: skip to fork point for subconscious agents
|
||||
let conv = ctx.conversation();
|
||||
let conv_view = section_to_view("Conversation", conv);
|
||||
let fork = fork_point.min(conv_view.children.len());
|
||||
let conv_children: Vec<SectionView> = conv_view.children
|
||||
.into_iter().skip(fork).collect();
|
||||
views.push(SectionView {
|
||||
name: format!("Conversation ({} entries)", conv_children.len()),
|
||||
tokens: conv_children.iter().map(|c| c.tokens).sum(),
|
||||
content: String::new(),
|
||||
token_ids: Vec::new(),
|
||||
children: conv_children,
|
||||
status: String::new(),
|
||||
});
|
||||
|
||||
views
|
||||
})
|
||||
.unwrap_or_default()
|
||||
views
|
||||
}
|
||||
}
|
||||
|
||||
fn draw_list(&mut self, frame: &mut Frame, area: Rect, app: &App) {
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ impl ScreenView for ThalamusScreen {
|
|||
}
|
||||
KeyCode::Char('t') => {
|
||||
app.think_native = !app.think_native;
|
||||
if let Ok(mut st) = app.agent.state.try_lock() {
|
||||
{ let mut st = app.agent.state.lock_blocking();
|
||||
st.think_native = app.think_native;
|
||||
let status = if app.think_native { "enabled" } else { "disabled" };
|
||||
st.notify(format!("native thinking {}", status));
|
||||
|
|
@ -53,7 +53,7 @@ impl ScreenView for ThalamusScreen {
|
|||
}
|
||||
KeyCode::Char('T') => {
|
||||
app.think_tool = !app.think_tool;
|
||||
if let Ok(mut st) = app.agent.state.try_lock() {
|
||||
{ let mut st = app.agent.state.lock_blocking();
|
||||
st.think_tool = app.think_tool;
|
||||
// Add or remove the think tool from the tools list
|
||||
if app.think_tool {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue