consciousness/training/amygdala_training/train_steering_vectors.py
Kent Overstreet af17b0f0df amygdala: per-head attention decomposition diagnostic
As part of --quality-report, run a second forward pass capturing the
input to each target layer's o_proj (= concat of per-head attention
outputs before the output projection). For each concept, reshape to
[n_heads, head_dim], compute each head's diff-of-means magnitude, and
rank heads by selectivity (that magnitude normalised by the std of the
negatives for the same head).

Motivation: the Wang et al. paper (2510.11328) — whose paired-scenario
methodology we already lifted — further decomposes concept circuits at
the attention-head level. Meta-relational concepts (recognition, trust,
vulnerability) plausibly live in a sparse attention-head circuit rather
than in the residual-stream sum, which would explain why diff-of-means
on the residual blurs them. This diagnostic surfaces that.

Output is folded into quality.json under each concept as "per_head":
a dict keyed by layer index whose entries hold n_heads, head_dim,
top_heads (up to 10 heads ranked by selectivity, each as [head_idx,
raw_norm, selectivity]) and head_concentration (fraction of the total
head-norm captured by those top heads).

Interpretation:
- head_concentration > 0.5 = sparse head circuit; a handful of heads
  route the concept. Worth building a head-level readout for.
- head_concentration ~= k/n (k = number of top heads reported, n = heads
  in the layer) = concept is distributed across all heads roughly evenly;
  residual-stream diff-of-means is doing fine.

Hybrid layers (Mamba, GatedDeltaNet) whose attention path doesn't
match the standard module layout are silently skipped.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-18 20:37:44 -04:00

930 lines
33 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train concept-readout vectors via Contrastive Activation Addition.

Reads the hand-written story corpus at
``amygdala_stories/{stories,paired}/`` and produces the per-layer
safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).

Training data (cross-concept contrast):
    positive for emotion E:
        stories/E.txt
        paired/<scenario>/E.txt              (for each scenario that covers E)
    negative for emotion E:
        stories/<all other emotions>.txt
        paired/<scenario>/<other emotion>.txt
        paired/<scenario>/baseline.txt       (for each scenario)

Within-scenario paired stories are the highest-signal pairs (same
content, different concept framing); unpaired stories provide bulk
contrast across every emotion in the corpus.

Pooling: last non-pad token. Matches how readout is consumed at decode
time (residual read at the sampler's query position).

Output:
    readout.safetensors
        layer_<idx>.vectors : fp16 (n_concepts, hidden_size), one per layer
    readout.json
        {
          "concepts": [...],
          "layers": [...],
          "hidden_size": int,
          "dtype": "float16"
        }
"""
from __future__ import annotations
import argparse
import gc
import json
import os
from pathlib import Path
import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    The index is the largest position whose mask entry is non-zero, so
    this is correct for both right- and left-padded batches (the previous
    ``mask.sum() - 1`` is only right for right padding). A row with no
    valid token maps to index -1, i.e. the final position.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq] (non-zero = real token)
    returns: [batch, hidden_dim]
    """
    valid = attention_mask.to(hidden.device) != 0
    # Positions are 1-based here so that "no valid token" yields 0 - 1 = -1.
    positions = torch.arange(1, valid.size(1) + 1, device=hidden.device)
    last_idx = (valid.long() * positions).max(dim=1).values - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
def _find_layers_module(model) -> torch.nn.ModuleList:
    """Locate the list of transformer blocks inside ``model``.

    Each dotted attribute path below is resolved in order; the first one
    that fully resolves to an ``nn.ModuleList`` is returned.

    Raises:
        RuntimeError: if none of the paths leads to a ``ModuleList``.
    """
    search_paths = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    missing = object()  # sentinel: an attribute along the path is absent
    for dotted in search_paths:
        node = model
        for attr in dotted.split("."):
            node = getattr(node, attr, missing)
            if node is missing:
                break
        if isinstance(node, torch.nn.ModuleList):
            return node
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {search_paths}"
    )
def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> torch.Tensor:
    """Run texts through the model, capture residual stream at target
    layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.

    A forward hook on each target decoder layer records that layer's
    output hidden states; each example is then pooled to its last non-pad
    token with ``_pool_last``. Row ``i`` of the result belongs to
    ``texts[i]`` and column ``j`` to ``target_layers[j]``. The hooks are
    always removed before returning, even if a forward pass fails.

    Raises:
        ValueError: if ``texts`` is empty, contains an empty or non-string
            entry, or ``batch_size`` is < 1.
    """
    import time

    # Explicit checks instead of ``assert`` (stripped under ``python -O``);
    # an empty list would otherwise only fail later inside ``torch.cat``.
    if not texts:
        raise ValueError(f"_collect_activations: no texts given for {label!r}")
    if not all(isinstance(t, str) and t for t in texts):
        raise ValueError(
            f"_collect_activations: empty or non-string text in {label!r}"
        )
    if batch_size < 1:
        raise ValueError(
            f"_collect_activations: batch_size must be >= 1, got {batch_size}"
        )

    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            # Decoder layers return either the hidden-state tensor itself
            # or a tuple whose first element is it.
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()

        return hook

    layers_module = _find_layers_module(model)
    handles = []
    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        # Register inside ``try`` so that if one registration fails (e.g.
        # an out-of-range layer index) the hooks already attached are
        # still removed by ``finally``.
        for idx in target_layers:
            handles.append(
                layers_module[idx].register_forward_hook(make_hook(idx))
            )
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in target_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                # Drop the full [batch, seq, hidden] captures and the
                # on-device batch before the next forward pass.
                captures.clear()
                del tok
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for h in handles:
            h.remove()
    return torch.cat(out_rows, dim=0)
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
    dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
    list[str],  # all baseline texts (one per scenario), as scenario-agnostic negatives
]:
    """Return ``(positives_by_emotion, baselines)``.

    ``stories_dir/<emotion>.txt`` and ``paired_dir/<scenario>/<emotion>.txt``
    each add one positive text for ``<emotion>`` (the file stem);
    ``paired_dir/<scenario>/baseline.txt`` adds one baseline text. Files
    and scenario dirs are visited in sorted order, so the result is
    deterministic. ``paired_dir`` may be None or missing.

    Cross-concept negatives are computed at training time from
    ``positives_by_emotion`` — each emotion's negative set is the
    union of all other emotions' positives plus the baseline texts.
    Empty (whitespace-only) .txt files are skipped with a warning.
    Files are always decoded as UTF-8.
    """

    def _read_nonempty(path: Path) -> str | None:
        # Explicit UTF-8: the default encoding follows the platform locale,
        # so stories with non-ASCII punctuation could fail to decode (or be
        # silently mis-decoded) on a non-UTF-8 system.
        text = path.read_text(encoding="utf-8").strip()
        if not text:
            shown = (
                path.relative_to(path.parents[1])
                if len(path.parents) >= 2
                else path
            )
            print(f" WARN: skipping empty story file {shown}")
            return None
        return text

    positives: dict[str, list[str]] = {}
    for story_path in sorted(stories_dir.glob("*.txt")):
        text = _read_nonempty(story_path)
        if text is not None:
            positives.setdefault(story_path.stem, []).append(text)

    baselines: list[str] = []
    if paired_dir is None or not paired_dir.exists():
        return positives, baselines
    for scenario_dir in sorted(paired_dir.iterdir()):
        if not scenario_dir.is_dir():
            continue
        baseline_path = scenario_dir / "baseline.txt"
        if baseline_path.exists():
            text = _read_nonempty(baseline_path)
            if text is not None:
                baselines.append(text)
        for framing_path in sorted(scenario_dir.glob("*.txt")):
            if framing_path.stem == "baseline":
                continue
            text = _read_nonempty(framing_path)
            if text is not None:
                positives.setdefault(framing_path.stem, []).append(text)
    return positives, baselines
def _find_o_proj(layer) -> torch.nn.Module | None:
    """Locate the attention output projection within a transformer layer.

    Tries the attribute paths below in order and returns the first one
    that resolves to an ``nn.Module`` (something a forward pre-hook can be
    attached to). Returns None when no path matches, e.g. for layers whose
    token mixer is not laid out like a standard attention block.
    """
    for path in (
        "self_attn.o_proj",
        "self_attn.out_proj",
        "attention.o_proj",
        "attn.out_proj",
        # GPT-2 style blocks; this file already supports that family via
        # the "transformer.h" layer list and the "mlp.c_proj" MLP path.
        "attn.c_proj",
    ):
        obj = layer
        for part in path.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        # A None-valued or non-Module attribute is not a usable match; keep
        # looking at the remaining paths instead of stopping early.
        if isinstance(obj, torch.nn.Module):
            return obj
    return None
def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer (= concat of per-head
    attention outputs right before the output projection).

    Returns (tensor [n_texts, n_active_layers, attn_dim], active_layers),
    pooled to each text's last non-pad token, fp32 on CPU. ``attn_dim`` is
    whatever o_proj receives (it need not equal the residual width).

    The active_layers list is the subset of target_layers whose attention
    module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.)
    may be silently skipped. If none do, returns ``(zeros(0, 0, 0), [])``
    without running the model. Hooks are always removed before returning.

    Raises:
        ValueError: if there are active layers but ``texts`` is empty, or
            ``batch_size`` is < 1.
    """
    import time

    layers_module = _find_layers_module(model)
    captures: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def make_hook(idx: int):
        def hook(_mod, inputs):
            # Forward pre-hooks get the positional args as a tuple; the
            # first one is o_proj's input.
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            captures[idx] = x.detach()

        return hook

    out_rows: list[torch.Tensor] = []
    try:
        # Register inside ``try`` so a failure part-way through still
        # removes the hooks already attached.
        for idx in target_layers:
            o_proj = _find_o_proj(layers_module[idx])
            if o_proj is not None:
                handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
                active_layers.append(idx)
        if not active_layers:
            return torch.zeros(0, 0, 0), []
        if not texts:
            raise ValueError(
                f"_collect_attention_inputs: no texts given for {label!r}"
            )
        if batch_size < 1:
            raise ValueError(
                f"_collect_attention_inputs: batch_size must be >= 1, "
                f"got {batch_size}"
            )

        n_batches = (len(texts) + batch_size - 1) // batch_size
        start = time.time()
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in active_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                # Free the full [batch, seq, attn_dim] captures and the
                # on-device batch before the next forward pass.
                captures.clear()
                del tok
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for h in handles:
            h.remove()
    return torch.cat(out_rows, dim=0), active_layers
def _compute_per_head_ranking(
    emotions: list[str],
    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, attn_dim]
    baseline_attn_inputs: torch.Tensor,
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    active_layers: list[int],
    n_heads_per_layer: dict[int, int],
    text_to_emotion: dict[str, str],
    unique_positive_texts: list[str],
) -> dict:
    """For each concept, rank attention heads by contribution magnitude.

    Per (concept, layer): reshape o_proj input to [n_heads, head_dim],
    compute diff-of-means between positives and negatives per head, and
    rank heads by selectivity = L2 norm of that diff divided by the L2 norm
    of the negatives' per-dimension std for the same head. The top heads
    are the ones most strongly implicated in the concept circuit.

    Negatives are every cached text that is not one of the concept's own
    positives (excluded by row, so a text shared between two emotions is
    never both positive and negative) plus all baseline rows.

    Returns ``{emotion: {"per_layer": {str(layer): {...}}}}``. A layer is
    omitted when its head count is unknown or does not divide the input
    width; ``per_layer`` is empty when there are fewer than two negatives,
    since the negatives' std (and so selectivity) is undefined there.

    Why this matters: meta-relational concepts (trust, recognition,
    "seen") often don't give a strong residual-stream diff-of-means but
    DO give a strong per-head signal — the concept lives in a small
    attention circuit rather than in the residual-stream sum.

    ``text_to_emotion`` is accepted for call-site compatibility but unused.
    """
    result: dict[str, dict] = {}
    n_texts = len(unique_positive_texts)
    attn_dim = attn_inputs.shape[-1]
    has_baselines = baseline_attn_inputs.shape[0] > 0
    for emotion in emotions:
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        pos_set = set(pos_rows)
        neg_rows = [i for i in range(n_texts) if i not in pos_set]
        pos = attn_inputs[pos_rows]  # [n_pos, n_layers, attn_dim]
        neg = attn_inputs[neg_rows]
        if has_baselines:
            neg = torch.cat([neg, baseline_attn_inputs], dim=0)
        per_layer: dict[str, dict] = {}
        # torch.std is unbiased, so it needs >= 2 negatives; with fewer the
        # selectivity would be NaN (and 0 rows would crash the reshape).
        if neg.shape[0] >= 2:
            for l_idx, target_l in enumerate(active_layers):
                n_heads = n_heads_per_layer.get(target_l)
                if not n_heads or attn_dim % n_heads != 0:
                    continue
                head_dim = attn_dim // n_heads
                pos_l = pos[:, l_idx, :].reshape(pos.shape[0], n_heads, head_dim)
                neg_l = neg[:, l_idx, :].reshape(neg.shape[0], n_heads, head_dim)
                diff = pos_l.mean(dim=0) - neg_l.mean(dim=0)  # [n_heads, head_dim]
                head_norms = diff.norm(dim=-1)  # [n_heads]
                # Normalise by neg spread per head so different-scale heads
                # don't dominate purely on activation magnitude.
                neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
                head_selectivity = head_norms / neg_std  # [n_heads]
                k = min(10, n_heads)
                top_idxs = head_selectivity.topk(k).indices
                per_layer[str(target_l)] = {
                    "n_heads": n_heads,
                    "head_dim": head_dim,
                    # Each entry: [head_idx, raw_norm, selectivity].
                    "top_heads": [
                        [int(i), float(head_norms[i]), float(head_selectivity[i])]
                        for i in top_idxs
                    ],
                    # Fraction of total head-norm captured by the top-k heads.
                    "head_concentration": float(
                        head_norms[top_idxs].sum()
                        / head_norms.sum().clamp_min(1e-6)
                    ),
                }
        result[emotion] = {"per_layer": per_layer}
    return result
def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
    """Best-effort map of target layer -> number of attention heads.

    Reads ``num_attention_heads`` from the text sub-config when the model
    config offers ``get_text_config()`` (multimodal wrappers nest it
    there), otherwise from the top-level config. Every target layer gets
    the same count; returns ``{}`` if the attribute is absent.
    """
    config = model.config
    text_cfg = (
        config.get_text_config() if hasattr(config, "get_text_config") else config
    )
    heads = getattr(text_cfg, "num_attention_heads", None)
    return {} if heads is None else dict.fromkeys(target_layers, heads)
def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
    """Return the W_down weight for the MLP at the given transformer layer,
    oriented ``[hidden, mlp_inner]``.

    Looks for the common paths (mlp.down_proj, mlp.c_proj,
    feed_forward.down_proj). In the returned matrix each column is one MLP
    neuron's contribution direction into the residual stream.

    ``nn.Linear`` stores its weight as ``[out_features, in_features]``,
    which is already ``[hidden, mlp_inner]`` for a down projection.
    transformers' ``Conv1D`` (GPT-2's ``mlp.c_proj``) stores the transpose,
    ``[in, out]``, and keeps the output width in its ``nf`` attribute; such
    weights are transposed here so every architecture comes back in the
    same orientation.

    Returns None if nothing matches — downstream code skips the
    single-neuron alignment check in that case rather than failing.
    """
    layer = _find_layers_module(model)[layer_idx]
    for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"):
        obj = layer
        for part in path.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        if obj is None or not hasattr(obj, "weight"):
            continue
        weight = obj.weight.detach()
        if (
            not isinstance(obj, torch.nn.Linear)
            and weight.dim() == 2
            and getattr(obj, "nf", None) == weight.shape[1]
        ):
            # Conv1D-style [mlp_inner, hidden] -> [hidden, mlp_inner].
            weight = weight.t()
        return weight
    return None
def _compute_quality_report(
    emotions: list[str],
    positive_acts: torch.Tensor,  # [n_positive_stories, n_layers, hidden]
    baseline_acts: torch.Tensor,  # [n_baseline_stories, n_layers, hidden]
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
    target_layers: list[int],
    model,
    positive_texts: list[str],
    text_to_emotion: dict[str, str],
) -> dict:
    """Per-concept quality metrics:

    - first_pc_variance_ratio: SVD on centered positive activations.
      >0.7 = rank-1 (clean). <0.4 = fragmented (stories disagree).
    - story_projection_*: how each positive story projects onto the
      concept direction. Low std = tight agreement.
    - best_neuron_cosine: alignment of the residual-space direction with
      the nearest W_down column (= single MLP neuron). >0.6 = essentially
      single-neuron.
    - nearest_concepts: top-5 concept directions most parallel to this
      one. Cosine >0.8 means the vector is confused with a neighbor.

    ``story_projection_std`` / ``story_alignment_std`` are None for a
    concept with a single positive story: the sample std is undefined
    there, and a NaN would be written by ``json.dumps`` as the non-standard
    token ``NaN``.

    The single-neuron metrics are None for a layer whose W_down cannot be
    found or is not shaped ``[hidden, mlp_inner]``.

    ``baseline_acts``, ``positive_texts`` and ``text_to_emotion`` are
    accepted for call-site compatibility but are not used.
    """
    report: dict = {}
    n_layers, _, hidden = per_layer_vectors.shape

    def _sample_std(x: torch.Tensor) -> float | None:
        # torch.std is unbiased by default, hence NaN for one sample.
        return float(x.std()) if x.numel() > 1 else None

    # Single-neuron alignment, precomputed once per layer for all concepts:
    # neuron_cos[target_l][j, c] = cosine(W_down column j, concept c).
    neuron_cos: dict[int, torch.Tensor] = {}
    for l_idx, target_l in enumerate(target_layers):
        w = _find_mlp_down_proj(model, target_l)
        if w is None:
            continue
        # The weight lives on the model's device while the concept vectors
        # live wherever the caller built them (CPU in main()); a matmul
        # across devices raises, so move the weight over first.
        w = w.to(device=per_layer_vectors.device, dtype=torch.float32)
        if w.dim() != 2 or w.shape[0] != hidden:
            # Not the [hidden, mlp_inner] layout assumed below: skip the
            # check for this layer rather than fail.
            continue
        # Unit-normalize each column (one per MLP neuron).
        w = w / w.norm(dim=0, keepdim=True).clamp_min(1e-6)
        neuron_cos[target_l] = (
            w.t() @ per_layer_vectors[l_idx].to(torch.float32).t()
        )  # [mlp_inner, n_concepts]
        del w

    # Pre-compute unit-normed concept vectors (for cross-concept cosines).
    vec_norm = per_layer_vectors / per_layer_vectors.norm(
        dim=-1, keepdim=True
    ).clamp_min(1e-6)
    mid = n_layers // 2

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        pos = positive_acts[pos_rows].to(torch.float32)  # [n_pos, n_layers, hidden]
        per_layer: dict = {}
        for l_idx, target_l in enumerate(target_layers):
            pos_l = pos[:, l_idx, :]  # [n_pos, hidden]
            diff_l = per_layer_vectors[l_idx, e_idx]  # [hidden], unit-normed
            centered = pos_l - pos_l.mean(dim=0)
            if centered.shape[0] >= 2:
                # SVD for rank analysis — if first PC dominates, stories agree.
                var = torch.linalg.svdvals(centered) ** 2
                var_ratios = (var / var.sum().clamp_min(1e-12)).tolist()
                # Per-story alignment: cosine(pos_i - pos_mean, concept_dir).
                centered_norm = centered / centered.norm(
                    dim=-1, keepdim=True
                ).clamp_min(1e-6)
                alignments = centered_norm @ diff_l
            else:
                # svdvals errors on a single row, and a lone story has no
                # spread to align; report the degenerate values.
                var_ratios = [1.0]
                alignments = torch.zeros(1)
            # Per-story projection onto the concept direction.
            projections = pos_l @ diff_l  # [n_pos]

            nb_best_idx = None
            nb_best_cos = None
            nb_top5 = None
            if target_l in neuron_cos:
                cos = neuron_cos[target_l][:, e_idx]  # [mlp_inner]
                k = min(5, cos.shape[0])
                top_idxs = cos.abs().topk(k).indices
                nb_best_idx = int(top_idxs[0])
                nb_best_cos = float(cos[top_idxs[0]])
                nb_top5 = [[int(i), float(cos[i])] for i in top_idxs]

            per_layer[str(target_l)] = {
                "top3_variance_ratios": [float(v) for v in var_ratios[:3]],
                "first_pc_variance_ratio": float(var_ratios[0]),
                "story_projection_mean": float(projections.mean()),
                "story_projection_std": _sample_std(projections),
                "story_projection_min": float(projections.min()),
                "story_projection_max": float(projections.max()),
                "story_alignment_mean": float(alignments.mean()),
                "story_alignment_std": _sample_std(alignments),
                "best_neuron_idx": nb_best_idx,
                "best_neuron_cosine": nb_best_cos,
                "top5_neurons": nb_top5,
            }

        # Outlier stories: lowest-aligned on the middle target layer.
        pos_l_mid = pos[:, mid, :]
        centered_mid = pos_l_mid - pos_l_mid.mean(dim=0)
        if centered_mid.shape[0] >= 2:
            centered_mid_norm = centered_mid / centered_mid.norm(
                dim=-1, keepdim=True
            ).clamp_min(1e-6)
            mid_aligns = centered_mid_norm @ per_layer_vectors[mid, e_idx]  # [n_pos]
            # Lowest two alignments = candidate outliers.
            k = min(2, mid_aligns.shape[0])
            low_idxs = mid_aligns.topk(k, largest=False).indices
            outliers = [
                [positives_by_emotion[emotion][int(i)], float(mid_aligns[i])]
                for i in low_idxs
            ]
        else:
            outliers = []

        # Nearest other concepts at the middle target layer.
        all_cos = vec_norm[mid] @ vec_norm[mid, e_idx]  # [n_concepts]
        all_cos[e_idx] = -2.0  # mask self
        k = min(5, all_cos.shape[0] - 1)
        top_vals, top_idxs = all_cos.topk(k)
        nearest = [
            [emotions[int(i)], float(v)] for i, v in zip(top_idxs, top_vals)
        ]

        report[emotion] = {
            "n_positive_stories": len(pos_rows),
            "per_layer": per_layer,
            "outlier_stories": outliers,
            "nearest_concepts": nearest,
        }
    return report
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring is laid out by hand; keep its line breaks
        # in --help instead of re-flowing it into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument(
        "--stories-dir",
        required=True,
        help="Path to amygdala_stories/stories/",
    )
    ap.add_argument(
        "--paired-dir",
        default=None,
        help="Path to amygdala_stories/paired/ (optional)",
    )
    ap.add_argument(
        "--target-layers",
        required=True,
        help="Comma-separated layer indices, e.g. 40,50,60,70",
    )
    ap.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write readout.safetensors + readout.json",
    )
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument(
        "--min-positives",
        type=int,
        default=1,
        help="Skip emotions with fewer positive examples than this",
    )
    ap.add_argument(
        "--quality-report",
        action="store_true",
        help="After training, compute a per-concept quality report "
        "(SVD rank, per-story alignment, single-neuron alignment, "
        "nearest-concept contamination) and write quality.json",
    )
    return ap.parse_args()


def _parse_target_layers(spec: str) -> list[int]:
    """Parse a ``--target-layers`` spec such as ``"40,50,60,70"``.

    Blank entries (e.g. from a trailing comma) are ignored. Duplicates are
    rejected: each layer becomes one ``layer_<idx>.vectors`` key in the
    safetensors file, so a repeated index would silently overwrite a tensor
    while readout.json still listed the layer twice.

    Raises:
        ValueError: on a non-integer entry, an empty spec, or duplicates.
    """
    layers = [int(part) for part in spec.split(",") if part.strip()]
    if not layers:
        raise ValueError(f"--target-layers {spec!r} lists no layer indices")
    dupes = sorted({x for x in layers if layers.count(x) > 1})
    if dupes:
        raise ValueError(f"--target-layers has duplicate indices: {dupes}")
    return layers


def _train_concept_vectors(
    emotions: list[str],
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    positive_acts: torch.Tensor,
    baseline_acts: torch.Tensor,
) -> torch.Tensor:
    """Compute one unit-norm diff-of-means direction per (layer, concept).

    ``positive_acts`` is ``[n_unique_texts, n_layers, hidden]`` with row
    ``text_to_row[t]`` holding text ``t``; ``baseline_acts`` is
    ``[n_baselines, n_layers, hidden]`` and may have 0 rows.

    For emotion E the positives are E's own texts; the negatives are every
    other cached text plus all baselines. Negatives are excluded by row,
    so a text shared between two emotions is never on both sides.

    Returns ``[n_layers, n_concepts, hidden]`` fp32 on CPU.

    Raises:
        RuntimeError: if an emotion has no negatives at all (a single
            emotion and no baselines), where diff-of-means is undefined.
    """
    n_texts, n_layers, hidden = positive_acts.shape
    vectors = torch.zeros((n_layers, len(emotions), hidden), dtype=torch.float32)
    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        pos_set = set(pos_rows)
        neg_rows = [i for i in range(n_texts) if i not in pos_set]
        pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
        neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
        if baseline_acts.shape[0] > 0:
            neg = torch.cat([neg, baseline_acts], dim=0)
        if neg.shape[0] == 0:
            raise RuntimeError(
                f"{emotion}: no negative examples — need at least two "
                f"emotions or a paired baseline"
            )
        diff = pos.mean(dim=0) - neg.mean(dim=0)  # [n_layers, hidden]
        diff = diff / diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        vectors[:, e_idx] = diff
        if e_idx < 5 or e_idx == len(emotions) - 1:
            print(
                f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
                f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
            )
    return vectors


def _print_quality_summary(
    report: dict, emotions: list[str], target_layers: list[int]
) -> None:
    """Bucket concepts by quality at the middle target layer and print counts.

    Checked in order: nearest concept cosine > 0.8 -> contaminated; first-PC
    ratio > 0.7 with |best neuron cosine| > 0.6 -> clean single-neuron;
    first-PC ratio > 0.7 -> clean low-dim circuit; first-PC ratio < 0.4 ->
    fragmented. Anything else is not counted.
    """
    clean_single_neuron = []
    clean_circuit = []
    fragmented = []
    contaminated = []
    mid_layer = target_layers[len(target_layers) // 2]
    for emotion in emotions:
        per_l = report[emotion]["per_layer"][str(mid_layer)]
        v = per_l["first_pc_variance_ratio"]
        nb = per_l.get("best_neuron_cosine") or 0.0
        top_near = report[emotion]["nearest_concepts"]
        nearest_cos = top_near[0][1] if top_near else 0.0
        if nearest_cos > 0.8:
            contaminated.append(emotion)
        elif v > 0.7 and abs(nb) > 0.6:
            clean_single_neuron.append(emotion)
        elif v > 0.7:
            clean_circuit.append(emotion)
        elif v < 0.4:
            fragmented.append(emotion)
    print(
        f"\nQuality summary @ layer {mid_layer}:\n"
        f" clean (single-neuron): {len(clean_single_neuron)}\n"
        f" clean (low-dim circuit): {len(clean_circuit)}\n"
        f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
        f" contaminated (nearest > 0.8): {len(contaminated)}"
    )
    if fragmented:
        print(f" fragmented sample: {fragmented[:5]}")
    if contaminated:
        print(f" contaminated sample: {contaminated[:5]}")


def main() -> None:
    """Train per-layer concept readout vectors and write them to disk.

    Writes readout.safetensors + readout.json to --output-dir and, with
    --quality-report, also quality.json (including per-head analysis).
    """
    args = _parse_args()
    target_layers = _parse_target_layers(args.target_layers)
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    # Preflight: corpus dirs exist before we pay the cost of loading a 27B model
    stories_dir = Path(args.stories_dir)
    if not stories_dir.is_dir():
        raise FileNotFoundError(
            f"--stories-dir {stories_dir!s} does not exist or is not a dir"
        )
    if args.paired_dir is not None:
        pd = Path(args.paired_dir)
        if not pd.is_dir():
            raise FileNotFoundError(
                f"--paired-dir {pd!s} does not exist or is not a dir"
            )
    # Read the corpus once, before the model, so failures show up early
    # and empty-file warnings are printed only once.
    positives_by_emotion, baselines = _load_corpus(
        stories_dir,
        Path(args.paired_dir) if args.paired_dir else None,
    )
    emotions = sorted(
        e for e, ps in positives_by_emotion.items()
        if len(ps) >= args.min_positives
    )
    if not emotions:
        raise RuntimeError(
            f"corpus has 0 emotions with >= {args.min_positives} positive "
            f"examples. Check {stories_dir} — is it the right directory?"
        )
    print(
        f"Corpus preflight: {len(emotions)} emotions (min_positives="
        f"{args.min_positives}), {len(baselines)} baselines"
    )

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
    # dimensions under a text_config subobject. get_text_config()
    # returns that sub-config when present, else the top-level config.
    text_config = (
        model.config.get_text_config()
        if hasattr(model.config, "get_text_config")
        else model.config
    )
    hidden_dim = text_config.hidden_size
    n_model_layers = text_config.num_hidden_layers
    print(
        f"Model loaded. hidden_dim={hidden_dim}, "
        f"n_model_layers={n_model_layers} "
        f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
    )
    for layer_idx in target_layers:
        if layer_idx < 0 or layer_idx >= n_model_layers:
            raise ValueError(
                f"target layer {layer_idx} out of range "
                f"[0, {n_model_layers})"
            )
    # Relative depth is undefined for a 1-layer model; avoid dividing by 0.
    depth_denom = max(n_model_layers - 1, 1)
    print(
        "Target layers (relative depth): "
        + ", ".join(
            f"{layer} ({100 * layer / depth_denom:.0f}%)"
            for layer in target_layers
        )
    )
    print(
        f"Training {len(emotions)} emotions; "
        f"{len(baselines)} baseline scenarios"
    )

    # Cache all positive-text activations once so we can reuse them as
    # negatives for other emotions. Keyed by the text itself to dedup
    # across emotion lists.
    device = torch.device(args.device)
    text_to_emotion: dict[str, str] = {}
    for emotion, texts in positives_by_emotion.items():
        for t in texts:
            text_to_emotion[t] = emotion
    unique_positive_texts = list(text_to_emotion.keys())
    print(
        f"Collecting activations for {len(unique_positive_texts)} unique "
        f"positive texts + {len(baselines)} baselines..."
    )
    positive_acts = _collect_activations(
        model, tokenizer, unique_positive_texts, target_layers, device,
        args.batch_size, args.max_length, label="positives",
    )
    if positive_acts.shape[-1] != hidden_dim:
        # readout.json advertises hidden_dim; vectors of another width
        # would not match the manifest.
        raise RuntimeError(
            f"captured activation width {positive_acts.shape[-1]} does not "
            f"match hidden_size {hidden_dim}"
        )
    # positive_acts[i] corresponds to unique_positive_texts[i]
    text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
    baseline_acts = (
        _collect_activations(
            model, tokenizer, baselines, target_layers, device,
            args.batch_size, args.max_length, label="baselines",
        )
        if baselines
        else torch.zeros(0, len(target_layers), hidden_dim)
    )
    n_concepts = len(emotions)
    n_layers = len(target_layers)
    per_layer_vectors = _train_concept_vectors(
        emotions, positives_by_emotion, text_to_row, positive_acts, baseline_acts,
    )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    tensors = {
        f"layer_{target_layers[l_idx]}.vectors": (
            per_layer_vectors[l_idx].to(torch.float16)
        )
        for l_idx in range(n_layers)
    }
    safetensors.torch.save_file(
        tensors,
        str(output_dir / "readout.safetensors"),
    )
    manifest = {
        "concepts": emotions,
        "layers": target_layers,
        "hidden_size": hidden_dim,
        "dtype": "float16",
    }
    (output_dir / "readout.json").write_text(
        json.dumps(manifest, indent=2) + "\n"
    )
    total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
    print(
        f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
        f" {n_concepts} concepts x {n_layers} layers x "
        f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
    )

    if args.quality_report:
        print("\nComputing quality report...")
        report = _compute_quality_report(
            emotions=emotions,
            positive_acts=positive_acts,
            baseline_acts=baseline_acts,
            positives_by_emotion=positives_by_emotion,
            text_to_row=text_to_row,
            per_layer_vectors=per_layer_vectors,
            target_layers=target_layers,
            model=model,
            positive_texts=unique_positive_texts,
            text_to_emotion=text_to_emotion,
        )
        # Per-head attention decomposition — second pass, captures
        # o_proj's input at each target layer and ranks heads per concept
        # by selectivity. Meta-relational concepts often live in specific
        # attention heads rather than the residual-stream sum; this
        # diagnostic surfaces that.
        print("\nCollecting o_proj inputs for per-head analysis...")
        attn_inputs, active_layers = _collect_attention_inputs(
            model, tokenizer, unique_positive_texts, target_layers, device,
            args.batch_size, args.max_length, label="attn-pos",
        )
        if active_layers and baselines:
            baseline_attn_inputs, _ = _collect_attention_inputs(
                model, tokenizer, baselines, active_layers, device,
                args.batch_size, args.max_length, label="attn-base",
            )
        else:
            baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)
        if active_layers:
            per_head = _compute_per_head_ranking(
                emotions=emotions,
                attn_inputs=attn_inputs,
                baseline_attn_inputs=baseline_attn_inputs,
                positives_by_emotion=positives_by_emotion,
                text_to_row=text_to_row,
                active_layers=active_layers,
                n_heads_per_layer=_get_n_heads_per_layer(model, active_layers),
                text_to_emotion=text_to_emotion,
                unique_positive_texts=unique_positive_texts,
            )
            # Fold per-head into the main report under each concept.
            for emotion, ph in per_head.items():
                if emotion in report:
                    report[emotion]["per_head"] = ph["per_layer"]
            print(f"Per-head analysis done on layers {active_layers}")
        else:
            print(
                "No layer exposed a recognisable o_proj module path — "
                "per-head analysis skipped."
            )
        (output_dir / "quality.json").write_text(
            json.dumps(report, indent=2) + "\n"
        )
        _print_quality_summary(report, emotions, target_layers)
        print(f"\nWrote quality.json to {output_dir}")

    del model
    gc.collect()
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module
    # does not call main().
    main()