# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train concept-readout vectors via Contrastive Activation Addition.

Reads the hand-written story corpus at
``amygdala_stories/{stories,paired}/`` and produces the per-layer
safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).

Training data (cross-concept contrast):

    positive for emotion E:
        stories/E.txt
        paired/<scenario>/E.txt (for each scenario that covers E)

    negative for emotion E:
        stories/<all other emotions>.txt
        paired/<scenario>/baseline.txt (for each scenario)

Within-scenario paired stories are the highest-signal pairs (same
content, different concept framing); unpaired stories provide bulk
contrast across the 80 emotions we have written so far.

Pooling: last non-pad token. Matches how readout is consumed at decode
time (residual read at the sampler's query position).

Output:

    readout.safetensors
        layer_<idx>.vectors : fp16 (n_concepts, hidden_size), one per layer
    readout.json
        {
            "concepts": [...],
            "layers": [...],
            "hidden_size": int,
            "dtype": "float16"
        }
"""
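
# Illustrative invocation (a sketch; the script filename, model id and paths
# below are placeholders, not a tested configuration; the layer indices follow
# the "e.g. 40,50,60,70" example in the --target-layers help text):
#
#   python train_readout_vectors.py \
#       --model <hf-model-id-or-path> \
#       --stories-dir amygdala_stories/stories \
#       --paired-dir amygdala_stories/paired \
#       --target-layers 40,50,60,70 \
#       --output-dir out/readout \
#       --method subspace \
#       --quality-report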

from __future__ import annotations

import argparse
import gc
import json
import time
from pathlib import Path

import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> torch.Tensor:
    """Run texts through the model, capture residual stream at target
    layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
    """
    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_activations: empty or non-string text in {label!r}"
    )

    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in target_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                # Rebind so the hooks (which close over ``captures``) have a
                # live dict next iteration; the ``del`` above released the
                # big activation tensors promptly.
                captures = {}
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0)


def _collect_per_story_subspaces(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    k: int = 20,
    label: str = "",
) -> list[dict[int, torch.Tensor]]:
    """Run texts through the model, capture the full per-token residual-stream
    activations at each target layer, do SVD per story, return the top-k right
    singular vectors.

    Returns: list (length n_texts) of dicts; each dict maps target_layer_idx to
    a tensor ``[hidden_dim, k]`` of unit-normed right singular vectors (the
    subspace the story's tokens span in activation space at that layer).

    The per-story subspace captures *all* the directions a story occupies —
    concept, narrator, topic, style. Finding the direction common to stories of
    the same concept (via the sum of V_i V_i^T and its top eigenvector)
    cancels nuisance directions that differ across stories while preserving
    directions they share.
    """
    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
    )

    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    # One entry per text: {layer_idx: V[hidden, k]}
    out: list[dict[int, torch.Tensor]] = [
        {} for _ in range(len(texts))
    ]
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                # For each item in the batch, for each layer, SVD on the
                # non-pad tokens.
                attn = tok["attention_mask"]
                for t_idx_in_batch, n_tok in enumerate(attn.sum(dim=1).tolist()):
                    story_idx = i + t_idx_in_batch
                    for layer in target_layers:
                        hs = captures[layer][t_idx_in_batch, :n_tok, :]
                        # Center tokens so SVD captures variation within story,
                        # not the story's center-of-mass:
                        hs = hs.to(torch.float32)
                        hs = hs - hs.mean(dim=0)
                        # SVD: hs = U Σ V^T; V has hidden-dim columns.
                        # For n_tok < k, the subspace rank is bounded by n_tok.
                        try:
                            _u, _s, vh = torch.linalg.svd(hs, full_matrices=False)
                        except Exception:
                            # Degenerate case (all-zero hs, n_tok=1): fall back
                            # to the last-token vector itself, unit-normed.
                            vec = captures[layer][t_idx_in_batch, n_tok - 1, :]
                            vec = vec.to(torch.float32)
                            nrm = vec.norm().clamp_min(1e-6)
                            vh = (vec / nrm).unsqueeze(0)  # [1, hidden]
                        # Take top-k rows of V^T (= top-k right singular vecs).
                        top = min(k, vh.shape[0])
                        V = vh[:top].t().contiguous().cpu()  # [hidden, top]
                        out[story_idx][layer] = V
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                # Rebind: the hooks close over ``captures``.
                captures = {}
    finally:
        for h in handles:
            h.remove()

    return out


def _subspace_concept_direction(
    pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
    base_V: list[torch.Tensor],
    hidden: int,
    *,
    top_k: int = 5,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Subspace-common-direction CAA alternative.

    Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
    same over baselines. Returns a weighted sum of the top-k eigenvectors of
    (M_pos - M_base), weights = eigenvalues (so stronger common directions
    contribute more), unit-normed. Also returns the full eigenvalue spectrum
    for diagnostics.

    top_k=1 recovers the previous behavior (top eigenvector only). top_k>1
    captures richer structure when the concept lives in a multi-dimensional
    shared subspace — which the flat eigenvalue spectrum observed in
    practice suggests is the common case. Selection happens AFTER the
    eigendecomposition so nothing is lost up to that point.

    Pass ``device`` explicitly (e.g. the training GPU): the per-story V_i
    tensors are stored on CPU by the collector, and a 5120x5120 eigh is
    glacial on CPU but takes ~1s on GPU. The caller is responsible for
    moving the result back to CPU for storage.
    """
    if device is None:
        device = pos_V[0].device if pos_V else torch.device("cpu")
    dtype = torch.float32

    def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
        if not Vs:
            return torch.zeros(hidden, hidden, dtype=dtype, device=device)
        M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
        for V in Vs:
            V = V.to(dtype=dtype, device=device)
            M.addmm_(V, V.t())
        M /= len(Vs)
        return M

    M_pos = acc(pos_V)
    M_base = acc(base_V)
    M = M_pos - M_base

    # Symmetric eigendecomposition.
    eigvals, eigvecs = torch.linalg.eigh(M)
    # eigh returns ascending; top-k are the last k columns.
    k = max(1, min(top_k, eigvecs.shape[1]))
    top_vals = eigvals[-k:]  # [k], ascending within top-k
    top_vecs = eigvecs[:, -k:]  # [hidden, k]
    # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp
    # negative weights to 0 (wrong-sign directions shouldn't contribute).
    w = top_vals.clamp_min(0.0)
    combined = top_vecs @ w  # [hidden]
    combined = combined / combined.norm().clamp_min(1e-6)
    return combined, eigvals
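

# Minimal sanity check for _subspace_concept_direction (illustrative only;
# nothing in this script calls it). It plants one shared unit direction in
# every positive story's subspace alongside random nuisance directions, then
# checks that the recovered concept direction matches the planted one. The
# sizes, seed and 0.9 threshold are arbitrary demo choices, not tuned values.
def _demo_subspace_concept_direction(hidden: int = 64, n_stories: int = 8) -> None:
    torch.manual_seed(0)
    shared = torch.randn(hidden)
    shared = shared / shared.norm()
    pos_V: list[torch.Tensor] = []
    base_V: list[torch.Tensor] = []
    for _ in range(n_stories):
        # Positive story: the shared direction + 3 random nuisance directions.
        q, _ = torch.linalg.qr(torch.randn(hidden, 3))
        pos_V.append(torch.cat([shared.unsqueeze(1), q], dim=1))  # [hidden, 4]
        # Baseline story: nuisance directions only.
        q, _ = torch.linalg.qr(torch.randn(hidden, 4))
        base_V.append(q)
    vec, eigvals = _subspace_concept_direction(
        pos_V, base_V, hidden=hidden, top_k=1, device=torch.device("cpu")
    )
    cos = float((vec @ shared).abs())
    print(f"demo: |cos(recovered, planted)| = {cos:.3f}, "
          f"top eigenvalue = {float(eigvals[-1]):.3f}")
    assert cos > 0.9, "shared direction not recovered"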


def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
    dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
    list[str],  # all baseline texts (one per scenario), as scenario-agnostic negatives
]:
    """Return ``(positives_by_emotion, baselines)``.

    Cross-concept negatives are computed at training time from
    ``positives_by_emotion`` — each emotion's negative set is the
    union of all other emotions' positives plus the baseline texts.
    Empty .txt files are skipped with a warning.
    """
    def _read_nonempty(path: Path) -> str | None:
        text = path.read_text().strip()
        if not text:
            rel = (
                path.relative_to(path.parents[1])
                if len(path.parents) >= 2
                else path
            )
            print(f" WARN: skipping empty story file {rel}")
            return None
        return text

    positives: dict[str, list[str]] = {}
    for story_path in sorted(stories_dir.glob("*.txt")):
        text = _read_nonempty(story_path)
        if text is None:
            continue
        emotion = story_path.stem
        positives.setdefault(emotion, []).append(text)

    baselines: list[str] = []
    if paired_dir is not None and paired_dir.exists():
        for scenario_dir in sorted(paired_dir.iterdir()):
            if not scenario_dir.is_dir():
                continue
            baseline_path = scenario_dir / "baseline.txt"
            if baseline_path.exists():
                text = _read_nonempty(baseline_path)
                if text is not None:
                    baselines.append(text)
            for framing_path in sorted(scenario_dir.glob("*.txt")):
                if framing_path.stem == "baseline":
                    continue
                text = _read_nonempty(framing_path)
                if text is None:
                    continue
                emotion = framing_path.stem
                positives.setdefault(emotion, []).append(text)

    return positives, baselines


def _find_o_proj(layer) -> torch.nn.Module | None:
    """Locate the attention output projection within a transformer layer."""
    for path in (
        "self_attn.o_proj",
        "self_attn.out_proj",
        "attention.o_proj",
        "attn.out_proj",
    ):
        obj = layer
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok:
            return obj
    return None


def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer (= concat of per-head
    attention outputs right before the output projection).

    Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers).
    The active_layers list is the subset of target_layers whose attention
    module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.)
    may be silently skipped.
    """
    layers_module = _find_layers_module(model)
    captures: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def make_hook(idx: int):
        def hook(_mod, inputs):
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            captures[idx] = x.detach()
        return hook

    for idx in target_layers:
        o_proj = _find_o_proj(layers_module[idx])
        if o_proj is not None:
            handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
            active_layers.append(idx)

    if not active_layers:
        return torch.zeros(0, 0, 0), []

    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in active_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                # Rebind: the hooks close over ``captures``.
                captures = {}
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0), active_layers


def _compute_per_head_ranking(
    emotions: list[str],
    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, hidden]
    baseline_attn_inputs: torch.Tensor,
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    active_layers: list[int],
    n_heads_per_layer: dict[int, int],
    text_to_emotion: dict[str, str],
    unique_positive_texts: list[str],
) -> dict:
    """For each concept, rank attention heads by contribution magnitude.

    Per (concept, layer): reshape o_proj input to [n_heads, head_dim],
    compute diff-of-means between positives and negatives per head, rank
    heads by the L2 norm of that diff. The top heads are the ones most
    strongly implicated in the concept circuit.

    Why this matters: meta-relational concepts (trust, recognition,
    "seen") often don't give a strong residual-stream diff-of-means but
    DO give a strong per-head signal — the concept lives in a small
    attention circuit rather than in the residual-stream sum.
    """
    result: dict[str, dict] = {}

    for emotion in emotions:
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        neg_rows = [
            i
            for i, t in enumerate(unique_positive_texts)
            if text_to_emotion[t] != emotion
        ]
        pos = attn_inputs[pos_rows]  # [n_pos, n_layers, hidden]
        neg = attn_inputs[neg_rows]
        if baseline_attn_inputs.shape[0] > 0:
            neg = torch.cat([neg, baseline_attn_inputs], dim=0)

        per_layer: dict[str, dict] = {}
        for l_idx, target_l in enumerate(active_layers):
            n_heads = n_heads_per_layer.get(target_l)
            if not n_heads:
                continue
            hidden = pos.shape[-1]
            if hidden % n_heads != 0:
                continue
            head_dim = hidden // n_heads

            pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim)
            neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim)

            diff = pos_l.mean(dim=0) - neg_l.mean(dim=0)  # [n_heads, head_dim]
            head_norms = diff.norm(dim=-1)  # [n_heads]
            # Normalise by neg variance per head so different-scale heads
            # don't dominate purely on activation magnitude.
            neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
            head_selectivity = head_norms / neg_std  # [n_heads]

            k = min(10, n_heads)
            _, top_idxs = head_selectivity.topk(k)
            top_heads = [
                [int(i), float(head_norms[i]), float(head_selectivity[i])]
                for i in top_idxs
            ]
            per_layer[str(target_l)] = {
                "n_heads": n_heads,
                "head_dim": head_dim,
                "top_heads": top_heads,  # [head_idx, raw_norm, selectivity]
                "head_concentration": float(
                    # fraction of total head-norm captured by top-k
                    head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6)
                ),
            }

        result[emotion] = {"per_layer": per_layer}

    return result
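

# Small self-check for _compute_per_head_ranking (illustrative only; nothing
# in this script calls it). It fabricates pooled o_proj inputs with a strong
# planted signal in one head for one concept and checks that that head ranks
# first. Shapes, seed and magnitudes are arbitrary demo choices.
def _demo_per_head_ranking() -> None:
    torch.manual_seed(0)
    n_heads, head_dim = 4, 8
    hidden = n_heads * head_dim
    texts_a = ["a story 1", "a story 2", "a story 3"]
    texts_b = ["b story 1", "b story 2", "b story 3"]
    all_texts = texts_a + texts_b
    acts = torch.randn(len(all_texts), 1, hidden) * 0.1  # [n_stories, 1 layer, hidden]
    # Plant a large offset for concept "a" in head 2 only.
    acts[: len(texts_a), 0, 2 * head_dim : 3 * head_dim] += 5.0
    out = _compute_per_head_ranking(
        emotions=["a", "b"],
        attn_inputs=acts,
        baseline_attn_inputs=torch.zeros(0, 1, hidden),
        positives_by_emotion={"a": texts_a, "b": texts_b},
        text_to_row={t: i for i, t in enumerate(all_texts)},
        active_layers=[0],
        n_heads_per_layer={0: n_heads},
        text_to_emotion={t: ("a" if t in texts_a else "b") for t in all_texts},
        unique_positive_texts=all_texts,
    )
    top_head = out["a"]["per_layer"]["0"]["top_heads"][0]
    print(f"demo: top head for 'a' = {top_head[0]}, selectivity {top_head[2]:.1f}")
    assert top_head[0] == 2, "planted head not ranked first"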


def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
    """Best-effort read of num_attention_heads per layer. Qwen uses the
    top-level config; falls back to config.num_attention_heads.
    """
    cfg = model.config
    if hasattr(cfg, "get_text_config"):
        cfg = cfg.get_text_config()
    n = getattr(cfg, "num_attention_heads", None)
    if n is None:
        return {}
    return {l: n for l in target_layers}


def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
    """Return the W_down weight for the MLP at the given transformer layer.

    Looks for the common paths (mlp.down_proj, mlp.c_proj, feed_forward.down_proj).
    Returns None if nothing matches — downstream code skips the single-neuron
    alignment check in that case rather than failing.
    """
    layers = _find_layers_module(model)
    layer = layers[layer_idx]
    for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"):
        obj = layer
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and hasattr(obj, "weight"):
            # Shape convention: [hidden, mlp_inner] — each column is one
            # MLP neuron's contribution direction into the residual stream.
            return obj.weight.detach()
    return None


def _compute_quality_report(
    emotions: list[str],
    positive_acts: torch.Tensor,  # [n_positive_stories, n_layers, hidden]
    baseline_acts: torch.Tensor,  # [n_baseline_stories, n_layers, hidden]
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
    target_layers: list[int],
    model,
    positive_texts: list[str],
    text_to_emotion: dict[str, str],
) -> dict:
    """Per-concept quality metrics:

    - first_pc_variance_ratio: SVD on centered positive activations.
      >0.7 = rank-1 (clean). <0.4 = fragmented (stories disagree).
    - story_projection_*: how each positive story projects onto the
      concept direction. Low std = tight agreement.
    - best_neuron_cosine: alignment of the residual-space direction with
      the nearest W_down column (= single MLP neuron). >0.6 = essentially
      single-neuron.
    - nearest_concepts: top-5 concept directions most parallel to this
      one. Cosine >0.8 means the vector is confused with a neighbor.
    """
    report: dict = {}
    n_layers = per_layer_vectors.shape[0]

    # Pre-compute per-layer W_down for single-neuron alignment. Keep on
    # CPU to match the per_layer_vectors tensor.
    w_down: dict[int, torch.Tensor] = {}
    for target_l in target_layers:
        w = _find_mlp_down_proj(model, target_l)
        if w is not None:
            # Unit-normalize each column (one per MLP neuron).
            w = w.to(torch.float32).cpu()
            norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6)
            w_down[target_l] = w / norms  # [hidden, mlp_inner]

    # Pre-compute unit-normed concept vectors (for cross-concept cosines).
    vec_norm = per_layer_vectors / per_layer_vectors.norm(
        dim=-1, keepdim=True
    ).clamp_min(1e-6)

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        pos = positive_acts[pos_rows].to(torch.float32)  # [n_pos, n_layers, hidden]

        per_layer: dict = {}
        for l_idx, target_l in enumerate(target_layers):
            pos_l = pos[:, l_idx, :]  # [n_pos, hidden]
            diff_l = per_layer_vectors[l_idx, e_idx]  # [hidden], unit-normed
            pos_mean_l = pos_l.mean(dim=0)

            # SVD for rank analysis — if first PC dominates, stories agree.
            centered = pos_l - pos_mean_l
            # svdvals errors on 1-row; handle that.
            if centered.shape[0] >= 2:
                S = torch.linalg.svdvals(centered)
                var = S ** 2
                var_total = var.sum().clamp_min(1e-12)
                var_ratios = (var / var_total).tolist()
            else:
                var_ratios = [1.0]

            # Per-story projection onto the concept direction.
            projections = pos_l @ diff_l  # [n_pos]

            # Per-story alignment: cosine(story_dir, concept_dir) where
            # story_dir = pos_i - pos_mean (centered, pointing away from center).
            if centered.shape[0] >= 2:
                centered_norm = centered / centered.norm(
                    dim=-1, keepdim=True
                ).clamp_min(1e-6)
                alignments = centered_norm @ diff_l
            else:
                alignments = torch.zeros(1)

            # Single-neuron alignment: is the direction close to any
            # W_down column?
            nb_best_idx = None
            nb_best_cos = None
            nb_top5 = None
            if target_l in w_down:
                W = w_down[target_l]
                cos = W.t() @ diff_l  # [mlp_inner]
                abs_cos = cos.abs()
                k = min(5, abs_cos.shape[0])
                top_vals, top_idxs = abs_cos.topk(k)
                nb_best_idx = int(top_idxs[0])
                nb_best_cos = float(cos[top_idxs[0]])
                nb_top5 = [[int(i), float(cos[i])] for i in top_idxs]

            per_layer[str(target_l)] = {
                "top3_variance_ratios": [
                    float(v) for v in var_ratios[:3]
                ],
                "first_pc_variance_ratio": float(var_ratios[0]),
                "story_projection_mean": float(projections.mean()),
                "story_projection_std": float(projections.std()),
                "story_projection_min": float(projections.min()),
                "story_projection_max": float(projections.max()),
                "story_alignment_mean": float(alignments.mean()),
                "story_alignment_std": float(alignments.std()),
                "best_neuron_idx": nb_best_idx,
                "best_neuron_cosine": nb_best_cos,
                "top5_neurons": nb_top5,
            }

        # Outlier stories: lowest-aligned on the middle target layer.
        mid = n_layers // 2
        pos_l_mid = pos[:, mid, :]
        mid_mean = pos_l_mid.mean(dim=0)
        mid_diff = per_layer_vectors[mid, e_idx]
        centered_mid = pos_l_mid - mid_mean
        if centered_mid.shape[0] >= 2:
            centered_mid_norm = centered_mid / centered_mid.norm(
                dim=-1, keepdim=True
            ).clamp_min(1e-6)
            mid_aligns = centered_mid_norm @ mid_diff  # [n_pos]
            # Lowest two alignments = candidate outliers.
            k = min(2, mid_aligns.shape[0])
            low_vals, low_idxs = mid_aligns.topk(k, largest=False)
            outliers = [
                [
                    positives_by_emotion[emotion][int(i)],
                    float(mid_aligns[i]),
                ]
                for i in low_idxs
            ]
        else:
            outliers = []

        # Nearest other concepts at the middle target layer.
        this_norm = vec_norm[mid, e_idx]
        all_cos = vec_norm[mid] @ this_norm  # [n_concepts]
        all_cos[e_idx] = -2.0  # mask self
        k = min(5, all_cos.shape[0] - 1)
        top_vals, top_idxs = all_cos.topk(k)
        nearest = [
            [emotions[int(i)], float(v)]
            for i, v in zip(top_idxs, top_vals)
        ]

        report[emotion] = {
            "n_positive_stories": len(pos_rows),
            "per_layer": per_layer,
            "outlier_stories": outliers,
            "nearest_concepts": nearest,
        }

    return report


def _compute_linear_combinations(
    emotions: list[str],
    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
    target_layers: list[int],
    *,
    ridge_lambda: float = 0.01,
    top_k: int = 5,
) -> dict:
    """For each concept, ridge-regress its direction against all other
    concept directions. Report R² (how much of the target direction is
    explained by a linear combination of others) + top contributors.

    R² > 0.9 = concept is essentially a linear combination of others
        (redundant, or part of a cluster that needs disambiguating)
    R² < 0.5 = concept has a substantial unique component

    ridge_lambda keeps the coefficients stable when concepts are near-collinear.
    """
    n_layers, n_concepts, hidden = per_layer_vectors.shape
    result: dict[str, dict] = {}

    for l_idx, target_l in enumerate(target_layers):
        V = per_layer_vectors[l_idx]  # [n_concepts, hidden]

        for i, name in enumerate(emotions):
            target = V[i]  # [hidden]
            mask = torch.arange(n_concepts) != i
            others = V[mask]  # [n-1, hidden]

            # Ridge: solve (O O^T + lam I) alpha = O t
            OOt = others @ others.t()  # [n-1, n-1]
            b = others @ target  # [n-1]
            A = OOt + ridge_lambda * torch.eye(n_concepts - 1, dtype=OOt.dtype)
            alpha = torch.linalg.solve(A, b)

            recon = others.t() @ alpha  # [hidden]
            resid = target - recon
            t_sq = (target * target).sum().clamp_min(1e-12)
            r2 = 1.0 - (resid * resid).sum() / t_sq

            abs_alpha = alpha.abs()
            k = min(top_k, n_concepts - 1)
            top_vals, top_idxs = abs_alpha.topk(k)
            other_names = [emotions[j] for j in range(n_concepts) if j != i]
            top = [
                [other_names[int(j)], float(alpha[j])]
                for j in top_idxs
            ]

            entry = result.setdefault(name, {})
            entry.setdefault("per_layer", {})[str(target_l)] = {
                "r_squared": float(r2),
                "residual_norm": float(resid.norm()),
                "top_contributors": top,
            }

    return result
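

# Tiny sanity check for _compute_linear_combinations (illustrative only;
# nothing in this script calls it). a, b and c = (a+b)/|a+b| span a 2-D plane,
# so each of the three is an exact linear combination of the other two and its
# R² comes out near 1, while the orthogonal d has no linear overlap and scores
# near 0. Names and sizes are arbitrary demo choices.
def _demo_linear_combinations(hidden: int = 32) -> None:
    torch.manual_seed(0)
    q, _ = torch.linalg.qr(torch.randn(hidden, 3))
    a, b, d = q[:, 0], q[:, 1], q[:, 2]  # mutually orthonormal
    c = (a + b) / (a + b).norm()  # exact combination of a and b
    vecs = torch.stack([a, b, c, d]).unsqueeze(0)  # [n_layers=1, n_concepts=4, hidden]
    out = _compute_linear_combinations(
        ["a", "b", "c", "d"], vecs, target_layers=[0]
    )
    for name in ("a", "b", "c", "d"):
        r2 = out[name]["per_layer"]["0"]["r_squared"]
        print(f"demo: R^2({name}) = {r2:.3f}")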


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument(
        "--stories-dir",
        required=True,
        help="Path to amygdala_stories/stories/",
    )
    ap.add_argument(
        "--paired-dir",
        default=None,
        help="Path to amygdala_stories/paired/ (optional)",
    )
    ap.add_argument(
        "--target-layers",
        required=True,
        help="Comma-separated layer indices, e.g. 40,50,60,70",
    )
    ap.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write readout.safetensors + readout.json",
    )
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument(
        "--min-positives",
        type=int,
        default=1,
        help="Skip emotions with fewer positive examples than this",
    )
    ap.add_argument(
        "--method",
        default="pooled",
        choices=["pooled", "subspace"],
        help="Concept-extraction method: 'pooled' (classic CAA, "
        "pos_mean - neg_mean on last-token activations) or 'subspace' "
        "(per-story SVD; top eigenvector of Σ V_i V_i^T for positives "
        "minus same for baselines — captures what's common across "
        "stories' full-trajectory subspaces)",
    )
    ap.add_argument(
        "--subspace-k",
        type=int,
        default=99999,
        help="Max top-k right singular vectors per story for subspace method "
        "(clamped to min(n_tokens, hidden_dim) per story). Default is "
        "effectively 'keep full per-story subspace' — each story's V_i "
        "spans its entire natural row space. On a hidden_dim=5120 "
        "residual and ~500-token stories, that's ~500 vectors per story. "
        "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
    )
    ap.add_argument(
        "--subspace-eigen-k",
        type=int,
        default=5,
        help="Number of top eigenvectors of M_pos - M_base to combine into "
        "the concept direction. Weighted sum by eigenvalue (so strongest "
        "common directions contribute most). eigen_k=1 recovers "
        "single-eigenvector behavior. Higher values (5-10) capture "
        "richer structure when the concept's shared-subspace spectrum "
        "is flat (which it tends to be in practice).",
    )
    ap.add_argument(
        "--quality-report",
        action="store_true",
        help="After training, compute a per-concept quality report "
        "(SVD rank, per-story alignment, single-neuron alignment, "
        "nearest-concept contamination) and write quality.json",
    )
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    # Preflight: make sure the corpus dirs exist before we pay the cost of
    # loading a 27B model.
    stories_dir = Path(args.stories_dir)
    if not stories_dir.is_dir():
        raise FileNotFoundError(
            f"--stories-dir {stories_dir!s} does not exist or is not a dir"
        )
    if args.paired_dir is not None:
        pd = Path(args.paired_dir)
        if not pd.is_dir():
            raise FileNotFoundError(
                f"--paired-dir {pd!s} does not exist or is not a dir"
            )

    # Quick corpus pre-scan so failures show up before we load the model.
    positives_preview, baselines_preview = _load_corpus(
        stories_dir,
        Path(args.paired_dir) if args.paired_dir else None,
    )
    n_emotions_preview = sum(
        1 for ps in positives_preview.values()
        if len(ps) >= args.min_positives
    )
    if n_emotions_preview == 0:
        raise RuntimeError(
            f"corpus has 0 emotions with >= {args.min_positives} positive "
            f"examples. Check {stories_dir} — is it the right directory?"
        )
    print(
        f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
        f"{args.min_positives}), {len(baselines_preview)} baselines"
    )

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
    # dimensions under a text_config subobject. get_text_config()
    # returns that sub-config when present, else the top-level config.
    text_config = (
        model.config.get_text_config()
        if hasattr(model.config, "get_text_config")
        else model.config
    )
    hidden_dim = text_config.hidden_size
    n_model_layers = text_config.num_hidden_layers
    print(
        f"Model loaded. hidden_dim={hidden_dim}, "
        f"n_model_layers={n_model_layers} "
        f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
    )

    for layer_idx in target_layers:
        if layer_idx < 0 or layer_idx >= n_model_layers:
            raise ValueError(
                f"target layer {layer_idx} out of range "
                f"[0, {n_model_layers})"
            )
    print(
        "Target layers (relative depth): "
        + ", ".join(
            f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
            for l in target_layers
        )
    )

    positives_by_emotion, baselines = _load_corpus(
        Path(args.stories_dir),
        Path(args.paired_dir) if args.paired_dir else None,
    )
    emotions = sorted(
        e for e, ps in positives_by_emotion.items()
        if len(ps) >= args.min_positives
    )
    if not emotions:
        raise RuntimeError(
            f"No emotions with >= {args.min_positives} positive examples"
        )
    print(
        f"Training {len(emotions)} emotions; "
        f"{len(baselines)} baseline scenarios"
    )

    # Cache all positive-text activations once so we can reuse them as
    # negatives for other emotions. Keyed by the text itself to dedup
    # across emotion lists.
    device = torch.device(args.device)
    text_to_emotion: dict[str, str] = {}
    for emotion, texts in positives_by_emotion.items():
        for t in texts:
            text_to_emotion[t] = emotion

    unique_positive_texts = list(text_to_emotion.keys())
    print(
        f"Collecting activations for {len(unique_positive_texts)} unique "
        f"positive texts + {len(baselines)} baselines..."
    )

    positive_acts = _collect_activations(
        model, tokenizer, unique_positive_texts, target_layers, device,
        args.batch_size, args.max_length, label="positives",
    )
    # positive_acts[i] corresponds to unique_positive_texts[i]
    text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}

    baseline_acts = (
        _collect_activations(
            model, tokenizer, baselines, target_layers, device,
            args.batch_size, args.max_length, label="baselines",
        )
        if baselines
        else torch.zeros(0, len(target_layers), hidden_dim)
    )

    n_concepts = len(emotions)
    n_layers = len(target_layers)

    # Per-layer output matrices. Shape (n_concepts, hidden_size) each.
    per_layer_vectors = torch.zeros(
        (n_layers, n_concepts, hidden_dim), dtype=torch.float32
    )

    # --- Subspace method: collect per-story right-singular-vector subspaces
    # and use sum-of-projection-operators per concept. --------------------
    pos_subspaces: list[dict[int, torch.Tensor]] | None = None
    base_subspaces: list[dict[int, torch.Tensor]] | None = None
    # Per (concept, layer): top-20 eigenvalues of (M_pos - M_base), descending.
    # Populated only when --method subspace.
    subspace_eigvals: dict[str, dict[int, list[float]]] = {}
    if args.method == "subspace":
        print("\nCollecting per-story subspaces (SVD, top-k right singular "
              f"vectors, k={args.subspace_k})...")
        pos_subspaces = _collect_per_story_subspaces(
            model, tokenizer, unique_positive_texts, target_layers, device,
            args.batch_size, args.max_length, k=args.subspace_k,
            label="subsp-pos",
        )
        if baselines:
            base_subspaces = _collect_per_story_subspaces(
                model, tokenizer, baselines, target_layers, device,
                args.batch_size, args.max_length, k=args.subspace_k,
                label="subsp-base",
            )
        else:
            base_subspaces = []

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        # Negatives: every OTHER emotion's positives + baselines.
        neg_rows = [
            i
            for i, t in enumerate(unique_positive_texts)
            if text_to_emotion[t] != emotion
        ]

        if args.method == "subspace":
            # For each layer, build M_pos = Σ V V^T / n_pos, baseline same
            # (using all other concepts' positive subspaces + baseline
            # subspaces as the contrast set), top eigenvector of difference.
            for l_idx, target_l in enumerate(target_layers):
                pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
                base_V = [pos_subspaces[j][target_l] for j in neg_rows]
                base_V += [bs[target_l] for bs in (base_subspaces or [])]
                top_vec, eigvals = _subspace_concept_direction(
                    pos_V, base_V, hidden=hidden_dim,
                    top_k=args.subspace_eigen_k,
                    device=device,
                )
                # The eigh ran on the training device; store results on CPU.
                top_vec = top_vec.cpu()
                eigvals = eigvals.cpu()
                per_layer_vectors[l_idx, e_idx] = top_vec
                # Keep the top-20 eigenvalues for quality-report diagnostics.
                subspace_eigvals.setdefault(emotion, {})[target_l] = (
                    eigvals[-20:].flip(0).tolist()
                )
        else:
            pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
            neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
            if baseline_acts.shape[0] > 0:
                neg = torch.cat([neg, baseline_acts], dim=0)

            pos_mean = pos.mean(dim=0)  # [n_layers, hidden]
            neg_mean = neg.mean(dim=0)
            diff = pos_mean - neg_mean
            norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
            diff = diff / norms

            # diff[layer] -> per_layer_vectors[layer, e_idx]
            for l_idx in range(n_layers):
                per_layer_vectors[l_idx, e_idx] = diff[l_idx]

        if e_idx < 5 or e_idx == len(emotions) - 1:
            print(
                f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
                f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
                f" (method={args.method})"
            )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tensors = {
        f"layer_{target_layers[l_idx]}.vectors": (
            per_layer_vectors[l_idx].to(torch.float16)
        )
        for l_idx in range(n_layers)
    }
    safetensors.torch.save_file(
        tensors,
        str(output_dir / "readout.safetensors"),
    )
    manifest = {
        "concepts": emotions,
        "layers": target_layers,
        "hidden_size": hidden_dim,
        "dtype": "float16",
    }
    (output_dir / "readout.json").write_text(
        json.dumps(manifest, indent=2) + "\n"
    )

    total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
    print(
        f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
        f" {n_concepts} concepts x {n_layers} layers x "
        f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
    )

    if args.quality_report:
        print("\nComputing quality report...")
        report = _compute_quality_report(
            emotions=emotions,
            positive_acts=positive_acts,
            baseline_acts=baseline_acts,
            positives_by_emotion=positives_by_emotion,
            text_to_row=text_to_row,
            per_layer_vectors=per_layer_vectors,
            target_layers=target_layers,
            model=model,
            positive_texts=unique_positive_texts,
            text_to_emotion=text_to_emotion,
        )

        # Per-head attention decomposition — second pass, captures
        # o_proj's input at each target layer and ranks heads per concept
        # by selectivity. Meta-relational concepts often live in specific
        # attention heads rather than the residual-stream sum; this
        # diagnostic surfaces that.
        print("\nCollecting o_proj inputs for per-head analysis...")
        attn_inputs, active_layers = _collect_attention_inputs(
            model, tokenizer, unique_positive_texts, target_layers, device,
            args.batch_size, args.max_length, label="attn-pos",
        )
        if active_layers and baselines:
            baseline_attn_inputs, _ = _collect_attention_inputs(
                model, tokenizer, baselines, active_layers, device,
                args.batch_size, args.max_length, label="attn-base",
            )
        else:
            baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)

        if active_layers:
            n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
            per_head = _compute_per_head_ranking(
                emotions=emotions,
                attn_inputs=attn_inputs,
                baseline_attn_inputs=baseline_attn_inputs,
                positives_by_emotion=positives_by_emotion,
                text_to_row=text_to_row,
                active_layers=active_layers,
                n_heads_per_layer=n_heads_per_layer,
                text_to_emotion=text_to_emotion,
                unique_positive_texts=unique_positive_texts,
            )
            # Fold per-head into the main report under each concept.
            for emotion, ph in per_head.items():
                if emotion in report:
                    report[emotion]["per_head"] = ph["per_layer"]
            print(f"Per-head analysis done on layers {active_layers}")
        else:
            print(
                "No layer exposed a recognisable o_proj module path — "
                "per-head analysis skipped."
            )

        # Eigenvalue spectrum from the subspace method — if populated, report
        # the top-20 eigenvalues per concept per layer. Tells us whether the
        # concept direction lives in a single dominant dimension (λ_0 >> λ_1)
        # or a spread of common directions (λ_0 ≈ λ_1 ≈ ...).
        if subspace_eigvals:
            for emotion, per_l in subspace_eigvals.items():
                if emotion in report:
                    report[emotion]["subspace_eigvals"] = {
                        str(l): vals for l, vals in per_l.items()
                    }

        # Linear combinations — for each concept, how much of its direction
        # is explained by a ridge regression on the others. R² > 0.9 flags
        # concepts that are essentially linear combinations of their peers
        # (useful for teasing apart near-duplicate clusters).
        print("\nComputing linear-combination analysis...")
        lincomb = _compute_linear_combinations(
            emotions, per_layer_vectors, target_layers
        )
        for emotion, lc in lincomb.items():
            if emotion in report:
                report[emotion]["linear_combination"] = lc["per_layer"]

        (output_dir / "quality.json").write_text(
            json.dumps(report, indent=2) + "\n"
        )

        # Short summary: concepts in each triage bucket.
        clean_single_neuron = []
        clean_circuit = []
        fragmented = []
        contaminated = []
        redundant = []  # R² > 0.9 — concept is near-linear combo of others
        mid = n_layers // 2
        mid_layer = target_layers[mid]
        for emotion in emotions:
            per_l = report[emotion]["per_layer"][str(mid_layer)]
            v = per_l["first_pc_variance_ratio"]
            nb = per_l.get("best_neuron_cosine") or 0.0
            top_near = report[emotion]["nearest_concepts"]
            nearest_cos = top_near[0][1] if top_near else 0.0
            lc_r2 = 0.0
            lc_entry = report[emotion].get("linear_combination", {})
            if str(mid_layer) in lc_entry:
                lc_r2 = lc_entry[str(mid_layer)]["r_squared"]
            if lc_r2 > 0.9:
                redundant.append(emotion)
            if nearest_cos > 0.8:
                contaminated.append(emotion)
            elif v > 0.7 and abs(nb) > 0.6:
                clean_single_neuron.append(emotion)
            elif v > 0.7:
                clean_circuit.append(emotion)
            elif v < 0.4:
                fragmented.append(emotion)
        print(
            f"\nQuality summary @ layer {mid_layer}:\n"
            f" clean (single-neuron): {len(clean_single_neuron)}\n"
            f" clean (low-dim circuit): {len(clean_circuit)}\n"
            f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
            f" contaminated (nearest > 0.8): {len(contaminated)}\n"
            f" redundant (R² > 0.9 vs. others): {len(redundant)}"
        )
        if fragmented:
            print(f" fragmented sample: {fragmented[:5]}")
        if contaminated:
            print(f" contaminated sample: {contaminated[:5]}")
        if redundant:
            print(f" redundant sample: {redundant[:5]}")
        print(f"\nWrote quality.json to {output_dir}")

    del model
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()