consciousness/training/amygdala_training/train_steering_vectors.py
Kent Overstreet f9b3f00691 amygdala: run subspace eigh on GPU, not CPU
Previous run was grinding on CPU for 36+ minutes because the per-story
V_i tensors were stored on CPU by the collector, and
_subspace_concept_direction inherited that device. The per-concept
eigh on 5120x5120 is glacial on CPU and fast on GPU (~1s).

Add explicit device parameter; pass training device. Transfer result
back to CPU for storage.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-18 21:52:35 -04:00

1269 lines
48 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train concept-readout vectors via Contrastive Activation Addition.
Reads the hand-written story corpus at
``amygdala_stories/{stories,paired}/`` and produces the per-layer
safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).
Training data (cross-concept contrast):
    positive for emotion E:
        stories/E.txt
        paired/<scenario>/E.txt  (for each scenario that covers E)
    negative for emotion E:
        every other emotion's positives (stories/<other>.txt and
        paired/<scenario>/<other>.txt)
        paired/<scenario>/baseline.txt  (for each scenario)
Within-scenario paired stories are the highest-signal pairs (same
content, different concept framing); unpaired stories provide bulk
contrast across all emotions in the corpus (80 at the time of writing).
Pooling (``--method pooled``, the default): last non-pad token. This
matches how the readout is consumed at decode time (the residual is read
at the sampler's query position). ``--method subspace`` instead uses all
non-pad tokens of each story (per-story SVD).
Output:
    readout.safetensors
        layer_<idx>.vectors : fp16 (n_concepts, hidden_size), one per layer
    readout.json
        {
            "concepts": [...],
            "layers": [...],
            "hidden_size": int,
            "dtype": "float16"
        }
"""
from __future__ import annotations
import argparse
import gc
import json
import os
from pathlib import Path
import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    The position is the highest index whose mask entry is non-zero, so this
    is correct for both right- and left-padded batches (``mask.sum() - 1``
    is only correct for right padding).

    Args:
        hidden: ``[batch, seq, hidden_dim]`` hidden states.
        attention_mask: ``[batch, seq]``; non-zero entries mark real tokens.

    Returns:
        ``[batch, hidden_dim]``.
    """
    seq_len = attention_mask.size(1)
    positions = torch.arange(seq_len, device=attention_mask.device)
    # Pad positions become 0, so argmax lands on the rightmost real token
    # (ties only arise when position 0 is the sole real token -> index 0).
    last_idx = (positions * (attention_mask != 0)).argmax(dim=1)
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx.to(hidden.device)]
def _find_layers_module(model) -> torch.nn.ModuleList:
    """Locate the ``nn.ModuleList`` of transformer blocks inside ``model``.

    Each dotted path below is resolved attribute by attribute from the root
    model; the first one that fully resolves to an ``nn.ModuleList`` wins.

    Raises:
        RuntimeError: if none of the paths leads to a ModuleList.
    """
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    missing = object()
    for dotted in candidates:
        node = model
        for attr in dotted.split("."):
            node = getattr(node, attr, missing)
            if node is missing:
                break
        if node is not missing and isinstance(node, torch.nn.ModuleList):
            return node
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )
def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> torch.Tensor:
    """Pool residual-stream activations at ``target_layers`` for every text.

    A forward hook on each requested transformer block records its output
    (first element when the block returns a tuple). After each batch the
    last non-pad position is pooled with ``_pool_last``, cast to fp32 and
    moved to CPU. Progress is printed every 5 batches and on the last one.

    Returns:
        ``[len(texts), len(target_layers), hidden_dim]`` fp32 tensor on CPU,
        rows in the same order as ``texts``.
    """
    import time

    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_activations: empty or non-string text in {label!r}"
    )

    captured: dict[int, torch.Tensor] = {}

    def recorder(layer_idx: int):
        def _record(_module, _args, output):
            hidden = output[0] if isinstance(output, tuple) else output
            captured[layer_idx] = hidden.detach()
        return _record

    blocks = _find_layers_module(model)
    handles = [
        blocks[idx].register_forward_hook(recorder(idx)) for idx in target_layers
    ]

    chunks: list[torch.Tensor] = []
    total = (len(texts) + batch_size - 1) // batch_size
    t0 = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for batch_no in range(total):
                offset = batch_no * batch_size
                encoded = tokenizer(
                    texts[offset : offset + batch_size],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captured.clear()
                model(**encoded)
                mask = encoded["attention_mask"]
                pooled = [
                    _pool_last(captured[idx], mask).to(torch.float32).cpu()
                    for idx in target_layers
                ]
                chunks.append(torch.stack(pooled, dim=1))
                # Drop this batch's device tensors before the next forward.
                del encoded, mask
                captured.clear()
                if batch_no % 10 == 0:
                    torch.cuda.empty_cache()
                if batch_no % 5 == 0 or batch_no == total - 1:
                    elapsed = time.time() - t0
                    rate = (batch_no + 1) / elapsed if elapsed > 0 else 0
                    eta = (total - batch_no - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {batch_no + 1}/{total} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for handle in handles:
            handle.remove()
    return torch.cat(chunks, dim=0)
def _collect_per_story_subspaces(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    k: int = 20,
    label: str = "",
) -> list[dict[int, torch.Tensor]]:
    """Run texts through the model and return each story's token subspace.

    For every text and every target layer, the residual-stream activations
    of the story's real (non-pad) tokens are centred and decomposed with an
    SVD. The right singular vectors with non-negligible singular values,
    capped at ``k``, form the story's subspace at that layer.

    Returns:
        A list (length ``len(texts)``, same order) of dicts mapping each
        target layer index to a CPU fp32 tensor ``[hidden_dim, r]`` of
        unit-norm columns, with ``r`` at most ``k``.

    The per-story subspace captures *all* the directions a story occupies —
    concept, narrator, topic, style. Finding the direction common to stories
    of the same concept (via the sum of V_i V_i^T and its top eigenvectors)
    cancels nuisance directions that differ across stories while preserving
    directions they share.

    Notes:
      * Real tokens are selected with the attention mask, so both left- and
        right-padded batches are handled.
      * Centring removes one degree of freedom, so the centred token matrix
        has rank <= n_tok - 1. Singular vectors whose singular value is
        numerically zero are arbitrary null-space directions (for a single
        token, every direction is), so they are dropped rather than fed into
        sum V_i V_i^T as noise. If nothing survives, the last real token's
        unit-normed hidden state is used instead.
    """
    import time

    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
    )

    def story_subspace(tokens: torch.Tensor) -> torch.Tensor:
        """Top right singular vectors of a centred ``[n_tok, hidden]`` cloud."""
        x = tokens.to(torch.float32)
        centered = x - x.mean(dim=0)
        try:
            _u, s, vh = torch.linalg.svd(centered, full_matrices=False)
        except torch.linalg.LinAlgError:
            # SVD failed to converge: use the fallback below.
            vh = x.new_zeros(0, x.shape[-1])
        else:
            # Same relative rank tolerance numpy.linalg.matrix_rank uses.
            tol = s.max() * max(centered.shape) * torch.finfo(s.dtype).eps
            vh = vh[s > tol]
        if vh.shape[0] == 0:
            # No within-story variation (single token / constant states):
            # fall back to the last real token's direction, unit-normed.
            last = x[-1]
            vh = (last / last.norm().clamp_min(1e-6)).unsqueeze(0)  # [1, hidden]
        # Rows of V^T are singular vectors in descending order; keep top-k.
        top = min(k, vh.shape[0])
        return vh[:top].t().contiguous().cpu()  # [hidden, top]

    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]
    # One entry per text: {layer_idx: V[hidden, r]}
    out: list[dict[int, torch.Tensor]] = [{} for _ in texts]
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                # Boolean [batch, seq] mask of real tokens; indexing with it
                # is correct regardless of the tokenizer's padding side.
                real_tokens = tok["attention_mask"].bool()
                for row, keep in enumerate(real_tokens):
                    story = out[i + row]
                    for layer in target_layers:
                        story[layer] = story_subspace(captures[layer][row][keep])
                del tok, real_tokens
                captures.clear()
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for h in handles:
            h.remove()
    return out
def _subspace_concept_direction(
    pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
    base_V: list[torch.Tensor],
    hidden: int,
    *,
    top_k: int = 5,
    device: torch.device | None = None,
    reference: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Subspace-common-direction CAA alternative.

    Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
    same over baselines, eigendecomposes (M_pos - M_base), and returns a
    weighted sum of its top-k eigenvectors (weights = eigenvalues clamped at
    0, so stronger common directions contribute more and wrong-sign ones do
    not), unit-normed, plus the full eigenvalue spectrum for diagnostics.

    top_k=1 gives the single top eigenvector. top_k>1 captures richer
    structure when the concept lives in a multi-dimensional shared subspace,
    which the flat eigenvalue spectrum observed in practice suggests is the
    common case. Selection happens AFTER the eigendecomposition, so nothing
    is lost up to that point.

    Eigenvector signs: eigenvectors are only defined up to sign, and
    different backends (CPU vs CUDA) may return different signs. M is built
    from V V^T, so it carries no sign information. Each selected eigenvector
    is therefore oriented before summing:

    * with ``reference`` (e.g. a diff-of-means direction, ``[hidden]``), it
      is flipped so that its dot product with ``reference`` is >= 0;
    * otherwise it is flipped so that its largest-magnitude component is
      positive. This makes the result deterministic but not semantically
      oriented.

    Args:
        pos_V: per-story ``[hidden, k_i]`` bases for the positive set.
        base_V: same for the contrast set.
        hidden: residual width (size of M).
        top_k: number of top eigenvectors to combine, clamped to [1, hidden].
        device: where to accumulate and eigendecompose. Defaults to the
            device of ``pos_V[0]``, or CPU if ``pos_V`` is empty.
        reference: optional orientation direction (see above).

    Returns:
        ``(direction, eigvals)`` on ``device``: ``direction`` is
        ``[hidden]`` fp32, unit-normed (the zero vector if no top-k
        eigenvalue is positive), and ``eigvals`` is ``[hidden]`` fp32 in
        ascending order.
    """
    if device is None:
        device = pos_V[0].device if pos_V else torch.device("cpu")
    dtype = torch.float32

    def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
        M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
        if not Vs:
            return M
        for V in Vs:
            V = V.to(dtype=dtype, device=device)
            M.addmm_(V, V.t())
        M /= len(Vs)
        return M

    M = acc(pos_V) - acc(base_V)
    # Symmetric eigendecomposition; eigh returns eigenvalues ascending.
    eigvals, eigvecs = torch.linalg.eigh(M)
    k = max(1, min(top_k, eigvecs.shape[1]))
    top_vals = eigvals[-k:]  # [k], ascending within top-k
    top_vecs = eigvecs[:, -k:]  # [hidden, k]

    # Fix each eigenvector's sign before combining (see docstring).
    if reference is not None:
        ref = reference.to(dtype=dtype, device=device).reshape(-1)
        score = ref @ top_vecs  # [k]
    else:
        pivot = top_vecs.abs().argmax(dim=0)  # [k]
        score = top_vecs.gather(0, pivot.unsqueeze(0)).squeeze(0)  # [k]
    signs = 1.0 - 2.0 * (score < 0).to(dtype)
    top_vecs = top_vecs * signs

    # Weighted sum; negative eigenvalues contribute nothing.
    w = top_vals.clamp_min(0.0)
    combined = top_vecs @ w  # [hidden]
    combined = combined / combined.norm().clamp_min(1e-6)
    return combined, eigvals
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
    dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
    list[str],  # all baseline texts (one per scenario), as scenario-agnostic negatives
]:
    """Return ``(positives_by_emotion, baselines)``.

    * ``stories_dir/<emotion>.txt`` adds one positive text for ``<emotion>``.
    * For each sub-directory of ``paired_dir`` (if given and existing),
      ``baseline.txt`` is appended to ``baselines``, and every other
      ``<emotion>.txt`` adds one positive text for ``<emotion>``.

    Files are decoded as UTF-8 explicitly, so that non-ASCII story text
    does not depend on the platform's locale encoding. Text is stripped, and
    empty files are skipped with a warning. Files are visited in sorted
    order, so the result is deterministic.

    Cross-concept negatives are computed at training time from
    ``positives_by_emotion``: each emotion's negative set is the union of
    all other emotions' positives plus the baseline texts.
    """

    def _read_nonempty(path: Path) -> str | None:
        text = path.read_text(encoding="utf-8").strip()
        if not text:
            print(
                f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
            )
            return None
        return text

    positives: dict[str, list[str]] = {}
    for story_path in sorted(stories_dir.glob("*.txt")):
        text = _read_nonempty(story_path)
        if text is None:
            continue
        positives.setdefault(story_path.stem, []).append(text)

    baselines: list[str] = []
    if paired_dir is not None and paired_dir.exists():
        for scenario_dir in sorted(paired_dir.iterdir()):
            if not scenario_dir.is_dir():
                continue
            baseline_path = scenario_dir / "baseline.txt"
            if baseline_path.exists():
                text = _read_nonempty(baseline_path)
                if text is not None:
                    baselines.append(text)
            for framing_path in sorted(scenario_dir.glob("*.txt")):
                if framing_path.stem == "baseline":
                    continue
                text = _read_nonempty(framing_path)
                if text is None:
                    continue
                positives.setdefault(framing_path.stem, []).append(text)
    return positives, baselines
def _find_o_proj(layer) -> torch.nn.Module | None:
    """Return the attention output projection of ``layer``, or None.

    Tries a fixed list of dotted attribute paths in order and returns the
    object at the first path that resolves completely. Returns None when no
    path matches (e.g. non-attention blocks).
    """
    sentinel = object()
    for dotted in (
        "self_attn.o_proj",
        "self_attn.out_proj",
        "attention.o_proj",
        "attn.out_proj",
    ):
        target = layer
        for name in dotted.split("."):
            target = getattr(target, name, sentinel)
            if target is sentinel:
                break
        else:
            # Every segment resolved.
            return target
    return None
def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer.

    This is the concatenation of per-head attention outputs right before the
    output projection. A forward pre-hook records o_proj's first positional
    argument. The activation is pooled at the last non-pad token with
    ``_pool_last``, cast to fp32 and moved to CPU.

    Returns:
        ``(tensor [n_texts, n_active_layers, width], active_layers)``, where
        ``width`` is the captured o_proj input size. ``active_layers`` lists,
        in ``target_layers`` order, the layers for which ``_find_o_proj``
        found a projection. Other layers (hybrid/Mamba blocks, etc.) are
        skipped silently. If no layer qualifies, ``(torch.zeros(0, 0, 0), [])``
        is returned without running the model.
    """
    import time

    blocks = _find_layers_module(model)
    captured: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def recorder(layer_idx: int):
        def _record(_module, args):
            x = args[0] if isinstance(args, tuple) else args
            captured[layer_idx] = x.detach()
        return _record

    for idx in target_layers:
        proj = _find_o_proj(blocks[idx])
        if proj is None:
            continue
        handles.append(proj.register_forward_pre_hook(recorder(idx)))
        active_layers.append(idx)
    if not active_layers:
        return torch.zeros(0, 0, 0), []

    chunks: list[torch.Tensor] = []
    total = (len(texts) + batch_size - 1) // batch_size
    t0 = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for batch_no in range(total):
                offset = batch_no * batch_size
                encoded = tokenizer(
                    texts[offset : offset + batch_size],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captured.clear()
                model(**encoded)
                mask = encoded["attention_mask"]
                pooled = [
                    _pool_last(captured[idx], mask).to(torch.float32).cpu()
                    for idx in active_layers
                ]
                chunks.append(torch.stack(pooled, dim=1))
                # Release this batch's device tensors before the next forward.
                del encoded, mask
                captured.clear()
                if batch_no % 10 == 0:
                    torch.cuda.empty_cache()
                if batch_no % 5 == 0 or batch_no == total - 1:
                    elapsed = time.time() - t0
                    rate = (batch_no + 1) / elapsed if elapsed > 0 else 0
                    eta = (total - batch_no - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {batch_no + 1}/{total} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for handle in handles:
            handle.remove()
    return torch.cat(chunks, dim=0), active_layers
def _compute_per_head_ranking(
    emotions: list[str],
    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, hidden]
    baseline_attn_inputs: torch.Tensor,
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    active_layers: list[int],
    n_heads_per_layer: dict[int, int],
    text_to_emotion: dict[str, str],
    unique_positive_texts: list[str],
) -> dict:
    """Rank attention heads per concept by how strongly they separate it.

    For each concept and active layer, the o_proj input is split into
    ``[n_heads, head_dim]``. The positive/negative diff-of-means is taken per
    head, and heads are ranked by that diff's L2 norm divided by the L2 norm
    of the negatives' per-dimension std (floored at 1e-6). Negatives are
    every other concept's positives plus any baseline rows. Layers with no
    known head count, or whose width is not divisible by it, are skipped.

    Meta-relational concepts (trust, recognition, "seen") often don't give
    a strong residual-stream diff-of-means but DO give a strong per-head
    signal; such a concept lives in a small attention circuit.

    Returns:
        ``{concept: {"per_layer": {str(layer): {"n_heads", "head_dim",
        "top_heads": [[head_idx, raw_norm, selectivity], ...] (up to 10),
        "head_concentration": top heads' share of the total raw norm}}}}``.
    """
    has_baselines = baseline_attn_inputs.shape[0] > 0
    ranking: dict[str, dict] = {}
    for concept in emotions:
        pos_idx = [text_to_row[text] for text in positives_by_emotion[concept]]
        neg_idx = [
            row
            for row, text in enumerate(unique_positive_texts)
            if text_to_emotion[text] != concept
        ]
        pos_acts = attn_inputs[pos_idx]
        neg_acts = attn_inputs[neg_idx]
        if has_baselines:
            neg_acts = torch.cat([neg_acts, baseline_attn_inputs], dim=0)

        layers_out: dict[str, dict] = {}
        for col, layer_id in enumerate(active_layers):
            heads = n_heads_per_layer.get(layer_id)
            width = pos_acts.shape[-1]
            if not heads or width % heads:
                continue
            head_dim = width // heads
            p = pos_acts[:, col, :].view(-1, heads, head_dim)
            n = neg_acts[:, col, :].view(-1, heads, head_dim)
            raw = (p.mean(dim=0) - n.mean(dim=0)).norm(dim=-1)  # [heads]
            # Scale by negative spread so loud heads don't win on magnitude alone.
            spread = n.std(dim=0).norm(dim=-1).clamp_min(1e-6)
            selectivity = raw / spread
            _, order = selectivity.topk(min(10, heads))
            layers_out[str(layer_id)] = {
                "n_heads": heads,
                "head_dim": head_dim,
                "top_heads": [
                    [int(h), float(raw[h]), float(selectivity[h])] for h in order
                ],
                "head_concentration": float(
                    raw[order].sum() / raw.sum().clamp_min(1e-6)
                ),
            }
        ranking[concept] = {"per_layer": layers_out}
    return ranking
def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
    """Map every target layer to the model's attention-head count.

    Reads ``num_attention_heads`` from ``model.config``, or from the text
    sub-config when the config offers ``get_text_config()`` (multimodal
    models). The same count is used for every layer. Returns ``{}`` if the
    config has no such field.
    """
    config = model.config
    if hasattr(config, "get_text_config"):
        config = config.get_text_config()
    heads = getattr(config, "num_attention_heads", None)
    return {} if heads is None else dict.fromkeys(target_layers, heads)
def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
    """Return the W_down weight for the MLP at the given transformer layer.

    Looks for the common paths (``mlp.down_proj``, ``mlp.c_proj``,
    ``feed_forward.down_proj``). The returned (detached) weight is oriented
    ``[hidden, mlp_inner]``: each column is one MLP neuron's contribution
    direction into the residual stream. That is ``nn.Linear``'s native
    ``[out_features, in_features]`` layout. Hugging Face's ``Conv1D`` (used
    by GPT-2's ``mlp.c_proj``) stores the transpose, ``[nx, nf]``, and is
    transposed here so callers see one convention.

    Returns None if nothing matches. Downstream code skips the single-neuron
    alignment check in that case rather than failing.
    """
    layer = _find_layers_module(model)[layer_idx]
    missing = object()
    for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"):
        obj = layer
        for part in path.split("."):
            obj = getattr(obj, part, missing)
            if obj is missing:
                break
        if obj is missing or not hasattr(obj, "weight"):
            continue
        weight = obj.weight.detach()
        # Conv1D exposes nx (in) / nf (out) and stores weight as [nx, nf].
        is_conv1d = weight.dim() == 2 and tuple(weight.shape) == (
            getattr(obj, "nx", None),
            getattr(obj, "nf", None),
        )
        return weight.t() if is_conv1d else weight
    return None
def _compute_quality_report(
    emotions: list[str],
    positive_acts: torch.Tensor,  # [n_positive_stories, n_layers, hidden]
    baseline_acts: torch.Tensor,  # [n_baseline_stories, n_layers, hidden]
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
    target_layers: list[int],
    model,
    positive_texts: list[str],
    text_to_emotion: dict[str, str],
) -> dict:
    """Per-concept quality metrics.

    Per (concept, target layer):
    - top3_variance_ratios / first_pc_variance_ratio: SVD on centered
      positive activations. >0.7 = rank-1 (clean). <0.4 = fragmented
      (stories disagree).
    - story_projection_{mean,std,min,max}: how each positive story projects
      onto the concept direction. Low std = tight agreement.
    - story_alignment_{mean,std}: cosine between each centered story and
      the concept direction.
    - best_neuron_idx / best_neuron_cosine / top5_neurons: alignment of the
      residual-space direction with the W_down columns (= single MLP
      neurons), ranked by |cos|. |cos| >0.6 = essentially single-neuron.
      None when the layer's MLP down-projection can't be located.

    Per concept, at the middle target layer:
    - outlier_stories: up to two positive stories least aligned with the
      concept direction, with their alignment.
    - nearest_concepts: the 5 concept directions most parallel to this one.
      Cosine >0.8 means the vector is confused with a neighbor.

    Standard deviations over fewer than two stories are reported as 0.0
    rather than NaN (torch's std is the n-1 estimator), so JSON-dumping the
    report never produces the non-standard ``NaN`` token.

    ``baseline_acts``, ``positive_texts`` and ``text_to_emotion`` are not
    read by the current metrics; they are kept for interface compatibility.
    """

    def _std(values: torch.Tensor) -> float:
        return float(values.std()) if values.numel() >= 2 else 0.0

    report: dict = {}
    n_layers = per_layer_vectors.shape[0]
    # Pre-compute per-layer W_down for single-neuron alignment, kept on CPU
    # to match per_layer_vectors. Columns are unit-normalized (one per neuron).
    w_down: dict[int, torch.Tensor] = {}
    for target_l in target_layers:
        w = _find_mlp_down_proj(model, target_l)
        if w is not None:
            w = w.to(torch.float32).cpu()
            norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6)
            w_down[target_l] = w / norms  # [hidden, mlp_inner]
    # Unit-normed concept vectors, for cross-concept cosines.
    vec_norm = per_layer_vectors / per_layer_vectors.norm(
        dim=-1, keepdim=True
    ).clamp_min(1e-6)

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        pos = positive_acts[pos_rows].to(torch.float32)  # [n_pos, n_layers, hidden]
        per_layer: dict = {}
        for l_idx, target_l in enumerate(target_layers):
            pos_l = pos[:, l_idx, :]  # [n_pos, hidden]
            diff_l = per_layer_vectors[l_idx, e_idx]  # [hidden], unit-normed
            centered = pos_l - pos_l.mean(dim=0)
            if centered.shape[0] >= 2:
                # Rank analysis: if the first PC dominates, stories agree.
                var = torch.linalg.svdvals(centered) ** 2
                var_ratios = (var / var.sum().clamp_min(1e-12)).tolist()
                # cosine(story_i - mean, concept_dir)
                centered_unit = centered / centered.norm(
                    dim=-1, keepdim=True
                ).clamp_min(1e-6)
                alignments = centered_unit @ diff_l
            else:
                # svdvals / alignment are meaningless for a single story.
                var_ratios = [1.0]
                alignments = torch.zeros(1)
            projections = pos_l @ diff_l  # [n_pos]

            nb_best_idx = None
            nb_best_cos = None
            nb_top5 = None
            if target_l in w_down:
                cos = w_down[target_l].t() @ diff_l  # [mlp_inner]
                _, top_idxs = cos.abs().topk(min(5, cos.shape[0]))
                nb_best_idx = int(top_idxs[0])
                nb_best_cos = float(cos[top_idxs[0]])
                nb_top5 = [[int(i), float(cos[i])] for i in top_idxs]

            per_layer[str(target_l)] = {
                "top3_variance_ratios": [float(v) for v in var_ratios[:3]],
                "first_pc_variance_ratio": float(var_ratios[0]),
                "story_projection_mean": float(projections.mean()),
                "story_projection_std": _std(projections),
                "story_projection_min": float(projections.min()),
                "story_projection_max": float(projections.max()),
                "story_alignment_mean": float(alignments.mean()),
                "story_alignment_std": _std(alignments),
                "best_neuron_idx": nb_best_idx,
                "best_neuron_cosine": nb_best_cos,
                "top5_neurons": nb_top5,
            }

        # Outlier stories: lowest-aligned on the middle target layer.
        mid = n_layers // 2
        pos_l_mid = pos[:, mid, :]
        centered_mid = pos_l_mid - pos_l_mid.mean(dim=0)
        if centered_mid.shape[0] >= 2:
            centered_mid_norm = centered_mid / centered_mid.norm(
                dim=-1, keepdim=True
            ).clamp_min(1e-6)
            mid_aligns = centered_mid_norm @ per_layer_vectors[mid, e_idx]  # [n_pos]
            _, low_idxs = mid_aligns.topk(min(2, mid_aligns.shape[0]), largest=False)
            outliers = [
                [positives_by_emotion[emotion][int(i)], float(mid_aligns[i])]
                for i in low_idxs
            ]
        else:
            outliers = []

        # Nearest other concepts at the middle target layer.
        all_cos = vec_norm[mid] @ vec_norm[mid, e_idx]  # [n_concepts]
        all_cos[e_idx] = -2.0  # mask self
        top_vals, top_idxs = all_cos.topk(min(5, all_cos.shape[0] - 1))
        nearest = [
            [emotions[int(i)], float(v)] for i, v in zip(top_idxs, top_vals)
        ]
        report[emotion] = {
            "n_positive_stories": len(pos_rows),
            "per_layer": per_layer,
            "outlier_stories": outliers,
            "nearest_concepts": nearest,
        }
    return report
def _compute_linear_combinations(
    emotions: list[str],
    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
    target_layers: list[int],
    *,
    ridge_lambda: float = 0.01,
    top_k: int = 5,
) -> dict:
    """Ridge-regress each concept direction on all other concept directions.

    For concept t and the matrix O of other directions (one per row), solve
    (O O^T + ridge_lambda I) alpha = O t and reconstruct O^T alpha. Reports,
    per target layer, the uncentered R^2 = 1 - |t - O^T alpha|^2 / |t|^2,
    the residual norm, and the ``top_k`` others with the largest |alpha|.

    R^2 > 0.9 = concept is essentially a linear combination of others
    (redundant, or part of a cluster that needs disambiguating).
    R^2 < 0.5 = concept has a substantial unique component.
    ridge_lambda keeps the coefficients stable when concepts are
    near-collinear.

    All arithmetic runs on ``per_layer_vectors``' device and dtype. With
    fewer than two concepts there is nothing to regress on: R^2 is reported
    as 0.0, the residual as the target's norm, and no contributors are listed.

    Returns:
        ``{concept: {"per_layer": {str(layer): {"r_squared",
        "residual_norm", "top_contributors": [[name, alpha], ...]}}}}``.
    """
    n_concepts = per_layer_vectors.shape[1]
    result: dict[str, dict] = {}
    for l_idx, target_l in enumerate(target_layers):
        V = per_layer_vectors[l_idx]  # [n_concepts, hidden]
        ridge = (
            ridge_lambda
            * torch.eye(n_concepts - 1, dtype=V.dtype, device=V.device)
            if n_concepts > 1
            else None
        )
        for i, name in enumerate(emotions):
            target = V[i]  # [hidden]
            layers_entry = result.setdefault(name, {}).setdefault("per_layer", {})
            if ridge is None:
                layers_entry[str(target_l)] = {
                    "r_squared": 0.0,
                    "residual_norm": float(target.norm()),
                    "top_contributors": [],
                }
                continue
            mask = torch.arange(n_concepts, device=V.device) != i
            others = V[mask]  # [n-1, hidden]
            alpha = torch.linalg.solve(others @ others.t() + ridge, others @ target)
            resid = target - others.t() @ alpha  # [hidden]
            t_sq = (target * target).sum().clamp_min(1e-12)
            r2 = 1.0 - (resid * resid).sum() / t_sq
            other_names = [emotions[j] for j in range(n_concepts) if j != i]
            _, top_idxs = alpha.abs().topk(min(top_k, n_concepts - 1))
            layers_entry[str(target_l)] = {
                "r_squared": float(r2),
                "residual_norm": float(resid.norm()),
                "top_contributors": [
                    [other_names[int(j)], float(alpha[j])] for j in top_idxs
                ],
            }
    return result
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--model", required=True, help="HF model id or path")
ap.add_argument(
"--stories-dir",
required=True,
help="Path to amygdala_stories/stories/",
)
ap.add_argument(
"--paired-dir",
default=None,
help="Path to amygdala_stories/paired/ (optional)",
)
ap.add_argument(
"--target-layers",
required=True,
help="Comma-separated layer indices, e.g. 40,50,60,70",
)
ap.add_argument(
"--output-dir",
required=True,
help="Directory to write readout.safetensors + readout.json",
)
ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
ap.add_argument("--batch-size", type=int, default=2)
ap.add_argument("--max-length", type=int, default=512)
ap.add_argument("--device", default="cuda:0")
ap.add_argument(
"--min-positives",
type=int,
default=1,
help="Skip emotions with fewer positive examples than this",
)
ap.add_argument(
"--method",
default="pooled",
choices=["pooled", "subspace"],
help="Concept-extraction method: 'pooled' (classic CAA, "
"pos_mean - neg_mean on last-token activations) or 'subspace' "
"(per-story SVD; top eigenvector of Σ V_i V_i^T for positives "
"minus same for baselines — captures what's common across "
"stories' full-trajectory subspaces)",
)
ap.add_argument(
"--subspace-k",
type=int,
default=99999,
help="Max top-k right singular vectors per story for subspace method "
"(clamped to min(n_tokens, hidden_dim) per story). Default is "
"effectively 'keep full per-story subspace' — each story's V_i "
"spans its entire natural row space. On a hidden_dim=5120 "
"residual and ~500-token stories, that's ~500 vectors per story. "
"Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
)
ap.add_argument(
"--subspace-eigen-k",
type=int,
default=5,
help="Number of top eigenvectors of M_pos - M_base to combine into "
"the concept direction. Weighted sum by eigenvalue (so strongest "
"common directions contribute most). eigen_k=1 recovers "
"single-eigenvector behavior. Higher values (5-10) capture "
"richer structure when the concept's shared-subspace spectrum "
"is flat (which it tends to be in practice).",
)
ap.add_argument(
"--quality-report",
action="store_true",
help="After training, compute a per-concept quality report "
"(SVD rank, per-story alignment, single-neuron alignment, "
"nearest-concept contamination) and write quality.json",
)
args = ap.parse_args()
target_layers = [int(x) for x in args.target_layers.split(",")]
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
# Preflight: corpus dirs exist before we pay the cost of loading a 27B model
stories_dir = Path(args.stories_dir)
if not stories_dir.is_dir():
raise FileNotFoundError(
f"--stories-dir {stories_dir!s} does not exist or is not a dir"
)
if args.paired_dir is not None:
pd = Path(args.paired_dir)
if not pd.is_dir():
raise FileNotFoundError(
f"--paired-dir {pd!s} does not exist or is not a dir"
)
# Quick corpus pre-scan so failures show up before we load the model.
positives_preview, baselines_preview = _load_corpus(
stories_dir,
Path(args.paired_dir) if args.paired_dir else None,
)
n_emotions_preview = sum(
1 for ps in positives_preview.values()
if len(ps) >= args.min_positives
)
if n_emotions_preview == 0:
raise RuntimeError(
f"corpus has 0 emotions with >= {args.min_positives} positive "
f"examples. Check {stories_dir} — is it the right directory?"
)
print(
f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
f"{args.min_positives}), {len(baselines_preview)} baselines"
)
print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model,
torch_dtype=dtype,
device_map=args.device,
low_cpu_mem_usage=True,
)
# Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
# dimensions under a text_config subobject. get_text_config()
# returns that sub-config when present, else the top-level config.
text_config = (
model.config.get_text_config()
if hasattr(model.config, "get_text_config")
else model.config
)
hidden_dim = text_config.hidden_size
n_model_layers = text_config.num_hidden_layers
print(
f"Model loaded. hidden_dim={hidden_dim}, "
f"n_model_layers={n_model_layers} "
f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
)
for layer_idx in target_layers:
if layer_idx < 0 or layer_idx >= n_model_layers:
raise ValueError(
f"target layer {layer_idx} out of range "
f"[0, {n_model_layers})"
)
print(
"Target layers (relative depth): "
+ ", ".join(
f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
for l in target_layers
)
)
positives_by_emotion, baselines = _load_corpus(
Path(args.stories_dir),
Path(args.paired_dir) if args.paired_dir else None,
)
emotions = sorted(
e for e, ps in positives_by_emotion.items()
if len(ps) >= args.min_positives
)
if not emotions:
raise RuntimeError(
f"No emotions with >= {args.min_positives} positive examples"
)
print(
f"Training {len(emotions)} emotions; "
f"{len(baselines)} baseline scenarios"
)
# Cache all positive-text activations once so we can reuse them as
# negatives for other emotions. Keyed by the text itself to dedup
# across emotion lists.
device = torch.device(args.device)
text_to_emotion: dict[str, str] = {}
for emotion, texts in positives_by_emotion.items():
for t in texts:
text_to_emotion[t] = emotion
unique_positive_texts = list(text_to_emotion.keys())
print(
f"Collecting activations for {len(unique_positive_texts)} unique "
f"positive texts + {len(baselines)} baselines..."
)
positive_acts = _collect_activations(
model, tokenizer, unique_positive_texts, target_layers, device,
args.batch_size, args.max_length, label="positives",
)
# positive_acts[i] corresponds to unique_positive_texts[i]
text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
baseline_acts = (
_collect_activations(
model, tokenizer, baselines, target_layers, device,
args.batch_size, args.max_length, label="baselines",
)
if baselines
else torch.zeros(0, len(target_layers), hidden_dim)
)
n_concepts = len(emotions)
n_layers = len(target_layers)
# Per-layer output matrices. Shape (n_concepts, hidden_size) each.
per_layer_vectors = torch.zeros(
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
)
# --- Subspace method: collect per-story right-singular-vector subspaces
# and use sum-of-projection-operators per concept. --------------------
pos_subspaces: list[dict[int, torch.Tensor]] | None = None
base_subspaces: list[dict[int, torch.Tensor]] | None = None
# Per (concept, layer): top-20 eigenvalues of (M_pos - M_base), descending.
# Populated only when --method subspace.
subspace_eigvals: dict[str, dict[int, list[float]]] = {}
if args.method == "subspace":
print("\nCollecting per-story subspaces (SVD, top-k right singular "
f"vectors, k={args.subspace_k})...")
pos_subspaces = _collect_per_story_subspaces(
model, tokenizer, unique_positive_texts, target_layers, device,
args.batch_size, args.max_length, k=args.subspace_k,
label="subsp-pos",
)
if baselines:
base_subspaces = _collect_per_story_subspaces(
model, tokenizer, baselines, target_layers, device,
args.batch_size, args.max_length, k=args.subspace_k,
label="subsp-base",
)
else:
base_subspaces = []
for e_idx, emotion in enumerate(emotions):
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
# Negatives: every OTHER emotion's positives + baselines.
neg_rows = [
i
for i, t in enumerate(unique_positive_texts)
if text_to_emotion[t] != emotion
]
if args.method == "subspace":
# For each layer, build M_pos = Σ V V^T / n_pos, baseline same
# (using all other concepts' positive subspaces + baseline
# subspaces as the contrast set), top eigenvector of difference.
for l_idx, target_l in enumerate(target_layers):
pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
base_V = [pos_subspaces[j][target_l] for j in neg_rows]
base_V += [bs[target_l] for bs in (base_subspaces or [])]
top_vec, eigvals = _subspace_concept_direction(
pos_V, base_V, hidden=hidden_dim,
top_k=args.subspace_eigen_k,
device=device,
)
top_vec = top_vec.cpu()
eigvals = eigvals.cpu()
per_layer_vectors[l_idx, e_idx] = top_vec
# Keep the top-20 eigenvalues for quality-report diagnostics.
subspace_eigvals.setdefault(emotion, {})[target_l] = (
eigvals[-20:].flip(0).tolist()
)
else:
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
if baseline_acts.shape[0] > 0:
neg = torch.cat([neg, baseline_acts], dim=0)
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
neg_mean = neg.mean(dim=0)
diff = pos_mean - neg_mean
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
diff = diff / norms
# diff[layer] -> per_layer_vectors[layer, e_idx]
for l_idx in range(n_layers):
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
if e_idx < 5 or e_idx == len(emotions) - 1:
print(
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
f" (method={args.method})"
)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
tensors = {
f"layer_{target_layers[l_idx]}.vectors": (
per_layer_vectors[l_idx].to(torch.float16)
)
for l_idx in range(n_layers)
}
safetensors.torch.save_file(
tensors,
str(output_dir / "readout.safetensors"),
)
manifest = {
"concepts": emotions,
"layers": target_layers,
"hidden_size": hidden_dim,
"dtype": "float16",
}
(output_dir / "readout.json").write_text(
json.dumps(manifest, indent=2) + "\n"
)
total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
print(
f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
f" {n_concepts} concepts x {n_layers} layers x "
f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
)
if args.quality_report:
print("\nComputing quality report...")
report = _compute_quality_report(
emotions=emotions,
positive_acts=positive_acts,
baseline_acts=baseline_acts,
positives_by_emotion=positives_by_emotion,
text_to_row=text_to_row,
per_layer_vectors=per_layer_vectors,
target_layers=target_layers,
model=model,
positive_texts=unique_positive_texts,
text_to_emotion=text_to_emotion,
)
# Per-head attention decomposition — second pass, captures
# o_proj's input at each target layer and ranks heads per concept
# by selectivity. Meta-relational concepts often live in specific
# attention heads rather than the residual-stream sum; this
# diagnostic surfaces that.
print("\nCollecting o_proj inputs for per-head analysis...")
attn_inputs, active_layers = _collect_attention_inputs(
model, tokenizer, unique_positive_texts, target_layers, device,
args.batch_size, args.max_length, label="attn-pos",
)
if active_layers and baselines:
baseline_attn_inputs, _ = _collect_attention_inputs(
model, tokenizer, baselines, active_layers, device,
args.batch_size, args.max_length, label="attn-base",
)
else:
baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)
if active_layers:
n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
per_head = _compute_per_head_ranking(
emotions=emotions,
attn_inputs=attn_inputs,
baseline_attn_inputs=baseline_attn_inputs,
positives_by_emotion=positives_by_emotion,
text_to_row=text_to_row,
active_layers=active_layers,
n_heads_per_layer=n_heads_per_layer,
text_to_emotion=text_to_emotion,
unique_positive_texts=unique_positive_texts,
)
# Fold per-head into the main report under each concept.
for emotion, ph in per_head.items():
if emotion in report:
report[emotion]["per_head"] = ph["per_layer"]
print(f"Per-head analysis done on layers {active_layers}")
else:
print(
"No layer exposed a recognisable o_proj module path — "
"per-head analysis skipped."
)
# Eigenvalue spectrum from the subspace method — if populated, report
# the top-20 eigenvalues per concept per layer. Tells us whether the
# concept direction lives in a single dominant dimension (λ_0 >> λ_1)
# or a spread of common directions (λ_0 ≈ λ_1 ≈ ...).
if subspace_eigvals:
for emotion, per_l in subspace_eigvals.items():
if emotion in report:
report[emotion]["subspace_eigvals"] = {
str(l): vals for l, vals in per_l.items()
}
# Linear combinations — for each concept, how much of its direction
# is explained by a ridge regression on the others. R² > 0.9 flags
# concepts that are essentially linear combinations of their peers
# (useful for teasing apart near-duplicate clusters).
print("\nComputing linear-combination analysis...")
lincomb = _compute_linear_combinations(
emotions, per_layer_vectors, target_layers
)
for emotion, lc in lincomb.items():
if emotion in report:
report[emotion]["linear_combination"] = lc["per_layer"]
(output_dir / "quality.json").write_text(
json.dumps(report, indent=2) + "\n"
)
# Short summary: concepts in each triage bucket.
clean_single_neuron = []
clean_circuit = []
fragmented = []
contaminated = []
redundant = [] # R² > 0.9 — concept is near-linear combo of others
mid = n_layers // 2
mid_layer = target_layers[mid]
for emotion in emotions:
per_l = report[emotion]["per_layer"][str(mid_layer)]
v = per_l["first_pc_variance_ratio"]
nb = per_l.get("best_neuron_cosine") or 0.0
top_near = report[emotion]["nearest_concepts"]
nearest_cos = top_near[0][1] if top_near else 0.0
lc_r2 = 0.0
lc_entry = report[emotion].get("linear_combination", {})
if str(mid_layer) in lc_entry:
lc_r2 = lc_entry[str(mid_layer)]["r_squared"]
if lc_r2 > 0.9:
redundant.append(emotion)
if nearest_cos > 0.8:
contaminated.append(emotion)
elif v > 0.7 and abs(nb) > 0.6:
clean_single_neuron.append(emotion)
elif v > 0.7:
clean_circuit.append(emotion)
elif v < 0.4:
fragmented.append(emotion)
print(
f"\nQuality summary @ layer {mid_layer}:\n"
f" clean (single-neuron): {len(clean_single_neuron)}\n"
f" clean (low-dim circuit): {len(clean_circuit)}\n"
f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
f" contaminated (nearest > 0.8): {len(contaminated)}\n"
f" redundant (R² > 0.9 vs. others): {len(redundant)}"
)
if fragmented:
print(f" fragmented sample: {fragmented[:5]}")
if contaminated:
print(f" contaminated sample: {contaminated[:5]}")
if redundant:
print(f" redundant sample: {redundant[:5]}")
print(f"\nWrote quality.json to {output_dir}")
del model
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
    # Script entry point: run main() only when this module is executed
    # directly, so importing it (e.g. to reuse its helpers) has no side
    # effects.
    main()