# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Train concept-readout vectors via Contrastive Activation Addition. Reads the hand-written story corpus at ``amygdala_stories/{stories,paired}/`` and produces the per-layer safetensors file + sidecar JSON manifest that vLLM's ReadoutManager loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``). Training data (cross-concept contrast): positive for emotion E: stories/E.txt paired//E.txt (for each scenario that covers E) negative for emotion E: stories/.txt paired//baseline.txt (for each scenario) Within-scenario paired stories are the highest-signal pairs (same content, different concept framing); unpaired stories provide bulk contrast across the 80 emotions we have written so far. Pooling: last non-pad token. Matches how readout is consumed at decode time (residual read at the sampler's query position). Output: readout.safetensors layer_.vectors : fp16 (n_concepts, hidden_size) one per layer readout.json { "concepts": [...], "layers": [...], "hidden_size": int, "dtype": "float16" } """ from __future__ import annotations import argparse import gc import json import os from pathlib import Path import safetensors.torch import torch from transformers import AutoModelForCausalLM, AutoTokenizer def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Pick the last non-pad token's hidden state per example. hidden: [batch, seq, hidden_dim] attention_mask: [batch, seq] returns: [batch, hidden_dim] """ last_idx = attention_mask.sum(dim=1) - 1 batch_idx = torch.arange(hidden.size(0), device=hidden.device) return hidden[batch_idx, last_idx] def _find_layers_module(model) -> torch.nn.ModuleList: """Walk a few likely paths to find the transformer-block list.""" candidates = [ "model.layers", "model.model.layers", "model.language_model.layers", "model.language_model.model.layers", "language_model.model.layers", "transformer.h", ] for path in candidates: obj = model ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok and isinstance(obj, torch.nn.ModuleList): return obj raise RuntimeError( f"Couldn't find transformer layer list. Tried: {candidates}" ) def _collect_activations( model, tokenizer, texts: list[str], target_layers: list[int], device: torch.device, batch_size: int, max_length: int, *, label: str = "", ) -> torch.Tensor: """Run texts through the model, capture residual stream at target layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU. """ import time assert all(isinstance(t, str) and t for t in texts), ( f"_collect_activations: empty or non-string text in {label!r}" ) captures: dict[int, torch.Tensor] = {} def make_hook(idx: int): def hook(_mod, _inp, output): hs = output[0] if isinstance(output, tuple) else output captures[idx] = hs.detach() return hook layers_module = _find_layers_module(model) handles = [ layers_module[idx].register_forward_hook(make_hook(idx)) for idx in target_layers ] out_rows: list[torch.Tensor] = [] n_batches = (len(texts) + batch_size - 1) // batch_size start = time.time() try: model.eval() with torch.no_grad(): for b_idx, i in enumerate(range(0, len(texts), batch_size)): batch = texts[i : i + batch_size] tok = tokenizer( batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ).to(device) captures.clear() model(**tok) per_layer = [ _pool_last(captures[idx], tok["attention_mask"]) .to(torch.float32) .cpu() for idx in target_layers ] out_rows.append(torch.stack(per_layer, dim=1)) del tok, captures if b_idx % 10 == 0: torch.cuda.empty_cache() if b_idx % 5 == 0 or b_idx == n_batches - 1: elapsed = time.time() - start rate = (b_idx + 1) / elapsed if elapsed > 0 else 0 eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0 print( f" [{label}] batch {b_idx + 1}/{n_batches} " f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", flush=True, ) captures = {} finally: for h in handles: h.remove() return torch.cat(out_rows, dim=0) def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings) list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives ]: """Return ``(positives_by_emotion, baselines)``. Cross-concept negatives are computed at training time from ``positives_by_emotion`` — each emotion's negative set is the union of all other emotions' positives plus the baseline texts. Empty .txt files are skipped with a warning. """ def _read_nonempty(path: Path) -> str | None: text = path.read_text().strip() if not text: print( f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}" ) return None return text positives: dict[str, list[str]] = {} for story_path in sorted(stories_dir.glob("*.txt")): text = _read_nonempty(story_path) if text is None: continue emotion = story_path.stem positives.setdefault(emotion, []).append(text) baselines: list[str] = [] if paired_dir is not None and paired_dir.exists(): for scenario_dir in sorted(paired_dir.iterdir()): if not scenario_dir.is_dir(): continue baseline_path = scenario_dir / "baseline.txt" if baseline_path.exists(): text = _read_nonempty(baseline_path) if text is not None: baselines.append(text) for framing_path in sorted(scenario_dir.glob("*.txt")): if framing_path.stem == "baseline": continue text = _read_nonempty(framing_path) if text is None: continue emotion = framing_path.stem positives.setdefault(emotion, []).append(text) return positives, baselines def _find_o_proj(layer) -> torch.nn.Module | None: """Locate the attention output projection within a transformer layer.""" for path in ( "self_attn.o_proj", "self_attn.out_proj", "attention.o_proj", "attn.out_proj", ): obj = layer ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok: return obj return None def _collect_attention_inputs( model, tokenizer, texts: list[str], target_layers: list[int], device: torch.device, batch_size: int, max_length: int, *, label: str = "", ) -> tuple[torch.Tensor, list[int]]: """Capture the INPUT to o_proj at each target layer (= concat of per-head attention outputs right before the output projection). Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers). The active_layers list is the subset of target_layers whose attention module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.) may be silently skipped. """ import time layers_module = _find_layers_module(model) captures: dict[int, torch.Tensor] = {} handles = [] active_layers: list[int] = [] def make_hook(idx: int): def hook(_mod, inputs): x = inputs[0] if isinstance(inputs, tuple) else inputs captures[idx] = x.detach() return hook for idx in target_layers: o_proj = _find_o_proj(layers_module[idx]) if o_proj is not None: handles.append(o_proj.register_forward_pre_hook(make_hook(idx))) active_layers.append(idx) if not active_layers: return torch.zeros(0, 0, 0), [] out_rows: list[torch.Tensor] = [] n_batches = (len(texts) + batch_size - 1) // batch_size start = time.time() try: model.eval() with torch.no_grad(): for b_idx, i in enumerate(range(0, len(texts), batch_size)): batch = texts[i : i + batch_size] tok = tokenizer( batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ).to(device) captures.clear() model(**tok) per_layer = [ _pool_last(captures[idx], tok["attention_mask"]) .to(torch.float32) .cpu() for idx in active_layers ] out_rows.append(torch.stack(per_layer, dim=1)) del tok, captures if b_idx % 10 == 0: torch.cuda.empty_cache() if b_idx % 5 == 0 or b_idx == n_batches - 1: elapsed = time.time() - start rate = (b_idx + 1) / elapsed if elapsed > 0 else 0 eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0 print( f" [{label}] batch {b_idx + 1}/{n_batches} " f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", flush=True, ) captures = {} finally: for h in handles: h.remove() return torch.cat(out_rows, dim=0), active_layers def _compute_per_head_ranking( emotions: list[str], attn_inputs: torch.Tensor, # [n_stories, n_active_layers, hidden] baseline_attn_inputs: torch.Tensor, positives_by_emotion: dict[str, list[str]], text_to_row: dict[str, int], active_layers: list[int], n_heads_per_layer: dict[int, int], text_to_emotion: dict[str, str], unique_positive_texts: list[str], ) -> dict: """For each concept, rank attention heads by contribution magnitude. Per (concept, layer): reshape o_proj input to [n_heads, head_dim], compute diff-of-means between positives and negatives per head, rank heads by the L2 norm of that diff. The top heads are the ones most strongly implicated in the concept circuit. Why this matters: meta-relational concepts (trust, recognition, "seen") often don't give a strong residual-stream diff-of-means but DO give a strong per-head signal — the concept lives in a small attention circuit rather than in the residual-stream sum. """ result: dict[str, dict] = {} for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] neg_rows = [ i for i, t in enumerate(unique_positive_texts) if text_to_emotion[t] != emotion ] pos = attn_inputs[pos_rows] # [n_pos, n_layers, hidden] neg = attn_inputs[neg_rows] if baseline_attn_inputs.shape[0] > 0: neg = torch.cat([neg, baseline_attn_inputs], dim=0) per_layer: dict[str, list] = {} for l_idx, target_l in enumerate(active_layers): n_heads = n_heads_per_layer.get(target_l) if not n_heads: continue hidden = pos.shape[-1] if hidden % n_heads != 0: continue head_dim = hidden // n_heads pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim) neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim) diff = pos_l.mean(dim=0) - neg_l.mean(dim=0) # [n_heads, head_dim] head_norms = diff.norm(dim=-1) # [n_heads] # Normalise by neg variance per head so different-scale heads # don't dominate purely on activation magnitude. neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6) head_selectivity = head_norms / neg_std # [n_heads] k = min(10, n_heads) top_vals, top_idxs = head_selectivity.topk(k) top_heads = [ [int(i), float(head_norms[i]), float(head_selectivity[i])] for i in top_idxs ] per_layer[str(target_l)] = { "n_heads": n_heads, "head_dim": head_dim, "top_heads": top_heads, # [head_idx, raw_norm, selectivity] "head_concentration": float( # fraction of total head-norm captured by top-k head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6) ), } result[emotion] = {"per_layer": per_layer} return result def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]: """Best-effort read of num_attention_heads per layer. Qwen uses the top-level config; falls back to config.num_attention_heads. """ cfg = model.config if hasattr(cfg, "get_text_config"): cfg = cfg.get_text_config() n = getattr(cfg, "num_attention_heads", None) if n is None: return {} return {l: n for l in target_layers} def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None: """Return the W_down weight for the MLP at the given transformer layer. Looks for the common paths (mlp.down_proj, mlp.c_proj, feed_forward.down_proj). Returns None if nothing matches — downstream code skips the single-neuron alignment check in that case rather than failing. """ layers = _find_layers_module(model) layer = layers[layer_idx] for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"): obj = layer ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok and hasattr(obj, "weight"): # Shape convention: [hidden, mlp_inner] — each column is one # MLP neuron's contribution direction into the residual stream. return obj.weight.detach() return None def _compute_quality_report( emotions: list[str], positive_acts: torch.Tensor, # [n_positive_stories, n_layers, hidden] baseline_acts: torch.Tensor, # [n_baseline_stories, n_layers, hidden] positives_by_emotion: dict[str, list[str]], text_to_row: dict[str, int], per_layer_vectors: torch.Tensor, # [n_layers, n_concepts, hidden], unit-normed target_layers: list[int], model, positive_texts: list[str], text_to_emotion: dict[str, str], ) -> dict: """Per-concept quality metrics: - first_pc_variance_ratio: SVD on centered positive activations. >0.7 = rank-1 (clean). <0.4 = fragmented (stories disagree). - story_projection_*: how each positive story projects onto the concept direction. Low std = tight agreement. - best_neuron_cosine: alignment of the residual-space direction with the nearest W_down column (= single MLP neuron). >0.6 = essentially single-neuron. - nearest_concepts: top-5 concept directions most parallel to this one. Cosine >0.8 means the vector is confused with a neighbor. """ report: dict = {} n_layers = per_layer_vectors.shape[0] # Pre-compute per-layer W_down for single-neuron alignment. Keep on # CPU to match the per_layer_vectors tensor. w_down: dict[int, torch.Tensor] = {} for target_l in target_layers: w = _find_mlp_down_proj(model, target_l) if w is not None: # Unit-normalize each column (one per MLP neuron). w = w.to(torch.float32).cpu() norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6) w_down[target_l] = w / norms # [hidden, mlp_inner] # Pre-compute unit-normed concept vectors (for cross-concept cosines). vec_norm = per_layer_vectors / per_layer_vectors.norm( dim=-1, keepdim=True ).clamp_min(1e-6) for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] pos = positive_acts[pos_rows].to(torch.float32) # [n_pos, n_layers, hidden] per_layer: dict = {} for l_idx, target_l in enumerate(target_layers): pos_l = pos[:, l_idx, :] # [n_pos, hidden] diff_l = per_layer_vectors[l_idx, e_idx] # [hidden], unit-normed pos_mean_l = pos_l.mean(dim=0) # SVD for rank analysis — if first PC dominates, stories agree. centered = pos_l - pos_mean_l # svdvals errors on 1-row; handle that. if centered.shape[0] >= 2: S = torch.linalg.svdvals(centered) var = S ** 2 var_total = var.sum().clamp_min(1e-12) var_ratios = (var / var_total).tolist() else: var_ratios = [1.0] # Per-story projection onto the concept direction. projections = pos_l @ diff_l # [n_pos] # Per-story alignment: cosine(story_dir, concept_dir) where # story_dir = pos_i - pos_mean (centered, pointing away from center). if centered.shape[0] >= 2: centered_norm = centered / centered.norm( dim=-1, keepdim=True ).clamp_min(1e-6) alignments = centered_norm @ diff_l else: alignments = torch.zeros(1) # Single-neuron alignment: is the direction close to any # W_down column? nb_best_idx = None nb_best_cos = None nb_top5 = None if target_l in w_down: W = w_down[target_l] cos = W.t() @ diff_l # [mlp_inner] abs_cos = cos.abs() k = min(5, abs_cos.shape[0]) top_vals, top_idxs = abs_cos.topk(k) nb_best_idx = int(top_idxs[0]) nb_best_cos = float(cos[top_idxs[0]]) nb_top5 = [[int(i), float(cos[i])] for i in top_idxs] per_layer[str(target_l)] = { "top3_variance_ratios": [ float(v) for v in var_ratios[:3] ], "first_pc_variance_ratio": float(var_ratios[0]), "story_projection_mean": float(projections.mean()), "story_projection_std": float(projections.std()), "story_projection_min": float(projections.min()), "story_projection_max": float(projections.max()), "story_alignment_mean": float(alignments.mean()), "story_alignment_std": float(alignments.std()), "best_neuron_idx": nb_best_idx, "best_neuron_cosine": nb_best_cos, "top5_neurons": nb_top5, } # Outlier stories: lowest-aligned on the middle target layer. mid = n_layers // 2 pos_l_mid = pos[:, mid, :] mid_mean = pos_l_mid.mean(dim=0) mid_diff = per_layer_vectors[mid, e_idx] centered_mid = pos_l_mid - mid_mean if centered_mid.shape[0] >= 2: centered_mid_norm = centered_mid / centered_mid.norm( dim=-1, keepdim=True ).clamp_min(1e-6) mid_aligns = centered_mid_norm @ mid_diff # [n_pos] # Lowest two alignments = candidate outliers. k = min(2, mid_aligns.shape[0]) low_vals, low_idxs = mid_aligns.topk(k, largest=False) outliers = [ [ positives_by_emotion[emotion][int(i)], float(mid_aligns[i]), ] for i in low_idxs ] else: outliers = [] # Nearest other concepts at the middle target layer. this_norm = vec_norm[mid, e_idx] all_cos = vec_norm[mid] @ this_norm # [n_concepts] all_cos[e_idx] = -2.0 # mask self k = min(5, all_cos.shape[0] - 1) top_vals, top_idxs = all_cos.topk(k) nearest = [ [emotions[int(i)], float(v)] for i, v in zip(top_idxs, top_vals) ] report[emotion] = { "n_positive_stories": len(pos_rows), "per_layer": per_layer, "outlier_stories": outliers, "nearest_concepts": nearest, } return report def _compute_linear_combinations( emotions: list[str], per_layer_vectors: torch.Tensor, # [n_layers, n_concepts, hidden], unit-normed target_layers: list[int], *, ridge_lambda: float = 0.01, top_k: int = 5, ) -> dict: """For each concept, ridge-regress its direction against all other concept directions. Report R² (how much of the target direction is explained by a linear combination of others) + top contributors. R² > 0.9 = concept is essentially a linear combination of others (redundant, or part of a cluster that needs disambiguating) R² < 0.5 = concept has a substantial unique component ridge_lambda keeps the coefficients stable when concepts are near-collinear. """ n_layers, n_concepts, hidden = per_layer_vectors.shape result: dict[str, dict] = {} # Middle layer for summary — same convention as nearest_concepts. mid = n_layers // 2 for l_idx, target_l in enumerate(target_layers): V = per_layer_vectors[l_idx] # [n_concepts, hidden] for i, name in enumerate(emotions): target = V[i] # [hidden] mask = torch.arange(n_concepts) != i others = V[mask] # [n-1, hidden] # Ridge: solve (O O^T + lam I) alpha = O t OOt = others @ others.t() # [n-1, n-1] b = others @ target # [n-1] A = OOt + ridge_lambda * torch.eye(n_concepts - 1, dtype=OOt.dtype) alpha = torch.linalg.solve(A, b) recon = others.t() @ alpha # [hidden] resid = target - recon t_sq = (target * target).sum().clamp_min(1e-12) r2 = 1.0 - (resid * resid).sum() / t_sq abs_alpha = alpha.abs() k = min(top_k, n_concepts - 1) top_vals, top_idxs = abs_alpha.topk(k) other_names = [emotions[j] for j in range(n_concepts) if j != i] top = [ [other_names[int(j)], float(alpha[j])] for j in top_idxs ] entry = result.setdefault(name, {}) entry.setdefault("per_layer", {})[str(target_l)] = { "r_squared": float(r2), "residual_norm": float(resid.norm()), "top_contributors": top, } return result def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True, help="HF model id or path") ap.add_argument( "--stories-dir", required=True, help="Path to amygdala_stories/stories/", ) ap.add_argument( "--paired-dir", default=None, help="Path to amygdala_stories/paired/ (optional)", ) ap.add_argument( "--target-layers", required=True, help="Comma-separated layer indices, e.g. 40,50,60,70", ) ap.add_argument( "--output-dir", required=True, help="Directory to write readout.safetensors + readout.json", ) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=512) ap.add_argument("--device", default="cuda:0") ap.add_argument( "--min-positives", type=int, default=1, help="Skip emotions with fewer positive examples than this", ) ap.add_argument( "--quality-report", action="store_true", help="After training, compute a per-concept quality report " "(SVD rank, per-story alignment, single-neuron alignment, " "nearest-concept contamination) and write quality.json", ) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] dtype = { "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, }[args.dtype] # Preflight: corpus dirs exist before we pay the cost of loading a 27B model stories_dir = Path(args.stories_dir) if not stories_dir.is_dir(): raise FileNotFoundError( f"--stories-dir {stories_dir!s} does not exist or is not a dir" ) if args.paired_dir is not None: pd = Path(args.paired_dir) if not pd.is_dir(): raise FileNotFoundError( f"--paired-dir {pd!s} does not exist or is not a dir" ) # Quick corpus pre-scan so failures show up before we load the model. positives_preview, baselines_preview = _load_corpus( stories_dir, Path(args.paired_dir) if args.paired_dir else None, ) n_emotions_preview = sum( 1 for ps in positives_preview.values() if len(ps) >= args.min_positives ) if n_emotions_preview == 0: raise RuntimeError( f"corpus has 0 emotions with >= {args.min_positives} positive " f"examples. Check {stories_dir} — is it the right directory?" ) print( f"Corpus preflight: {n_emotions_preview} emotions (min_positives=" f"{args.min_positives}), {len(baselines_preview)} baselines" ) print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True, ) # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model # dimensions under a text_config subobject. get_text_config() # returns that sub-config when present, else the top-level config. text_config = ( model.config.get_text_config() if hasattr(model.config, "get_text_config") else model.config ) hidden_dim = text_config.hidden_size n_model_layers = text_config.num_hidden_layers print( f"Model loaded. hidden_dim={hidden_dim}, " f"n_model_layers={n_model_layers} " f"(text_config.model_type={getattr(text_config, 'model_type', '?')})" ) for layer_idx in target_layers: if layer_idx < 0 or layer_idx >= n_model_layers: raise ValueError( f"target layer {layer_idx} out of range " f"[0, {n_model_layers})" ) print( "Target layers (relative depth): " + ", ".join( f"{l} ({100 * l / (n_model_layers - 1):.0f}%)" for l in target_layers ) ) positives_by_emotion, baselines = _load_corpus( Path(args.stories_dir), Path(args.paired_dir) if args.paired_dir else None, ) emotions = sorted( e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives ) if not emotions: raise RuntimeError( f"No emotions with >= {args.min_positives} positive examples" ) print( f"Training {len(emotions)} emotions; " f"{len(baselines)} baseline scenarios" ) # Cache all positive-text activations once so we can reuse them as # negatives for other emotions. Keyed by the text itself to dedup # across emotion lists. device = torch.device(args.device) text_to_emotion: dict[str, str] = {} for emotion, texts in positives_by_emotion.items(): for t in texts: text_to_emotion[t] = emotion unique_positive_texts = list(text_to_emotion.keys()) print( f"Collecting activations for {len(unique_positive_texts)} unique " f"positive texts + {len(baselines)} baselines..." ) positive_acts = _collect_activations( model, tokenizer, unique_positive_texts, target_layers, device, args.batch_size, args.max_length, label="positives", ) # positive_acts[i] corresponds to unique_positive_texts[i] text_to_row = {t: i for i, t in enumerate(unique_positive_texts)} baseline_acts = ( _collect_activations( model, tokenizer, baselines, target_layers, device, args.batch_size, args.max_length, label="baselines", ) if baselines else torch.zeros(0, len(target_layers), hidden_dim) ) n_concepts = len(emotions) n_layers = len(target_layers) # Per-layer output matrices. Shape (n_concepts, hidden_size) each. per_layer_vectors = torch.zeros( (n_layers, n_concepts, hidden_dim), dtype=torch.float32 ) for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] # Negatives: every OTHER emotion's positives + baselines. neg_rows = [ i for i, t in enumerate(unique_positive_texts) if text_to_emotion[t] != emotion ] pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden] neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden] if baseline_acts.shape[0] > 0: neg = torch.cat([neg, baseline_acts], dim=0) pos_mean = pos.mean(dim=0) # [n_layers, hidden] neg_mean = neg.mean(dim=0) diff = pos_mean - neg_mean norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6) diff = diff / norms # diff[layer] -> per_layer_vectors[layer, e_idx] for l_idx in range(n_layers): per_layer_vectors[l_idx, e_idx] = diff[l_idx] if e_idx < 5 or e_idx == len(emotions) - 1: print( f" [{e_idx + 1}/{len(emotions)}] {emotion}: " f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}" ) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) tensors = { f"layer_{target_layers[l_idx]}.vectors": ( per_layer_vectors[l_idx].to(torch.float16) ) for l_idx in range(n_layers) } safetensors.torch.save_file( tensors, str(output_dir / "readout.safetensors"), ) manifest = { "concepts": emotions, "layers": target_layers, "hidden_size": hidden_dim, "dtype": "float16", } (output_dir / "readout.json").write_text( json.dumps(manifest, indent=2) + "\n" ) total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024) print( f"\nWrote readout.safetensors + readout.json to {output_dir}\n" f" {n_concepts} concepts x {n_layers} layers x " f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB" ) if args.quality_report: print("\nComputing quality report...") report = _compute_quality_report( emotions=emotions, positive_acts=positive_acts, baseline_acts=baseline_acts, positives_by_emotion=positives_by_emotion, text_to_row=text_to_row, per_layer_vectors=per_layer_vectors, target_layers=target_layers, model=model, positive_texts=unique_positive_texts, text_to_emotion=text_to_emotion, ) # Per-head attention decomposition — second pass, captures # o_proj's input at each target layer and ranks heads per concept # by selectivity. Meta-relational concepts often live in specific # attention heads rather than the residual-stream sum; this # diagnostic surfaces that. print("\nCollecting o_proj inputs for per-head analysis...") attn_inputs, active_layers = _collect_attention_inputs( model, tokenizer, unique_positive_texts, target_layers, device, args.batch_size, args.max_length, label="attn-pos", ) if active_layers and baselines: baseline_attn_inputs, _ = _collect_attention_inputs( model, tokenizer, baselines, active_layers, device, args.batch_size, args.max_length, label="attn-base", ) else: baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim) if active_layers: n_heads_per_layer = _get_n_heads_per_layer(model, active_layers) per_head = _compute_per_head_ranking( emotions=emotions, attn_inputs=attn_inputs, baseline_attn_inputs=baseline_attn_inputs, positives_by_emotion=positives_by_emotion, text_to_row=text_to_row, active_layers=active_layers, n_heads_per_layer=n_heads_per_layer, text_to_emotion=text_to_emotion, unique_positive_texts=unique_positive_texts, ) # Fold per-head into the main report under each concept. for emotion, ph in per_head.items(): if emotion in report: report[emotion]["per_head"] = ph["per_layer"] print(f"Per-head analysis done on layers {active_layers}") else: print( "No layer exposed a recognisable o_proj module path — " "per-head analysis skipped." ) # Linear combinations — for each concept, how much of its direction # is explained by a ridge regression on the others. R² > 0.9 flags # concepts that are essentially linear combinations of their peers # (useful for teasing apart near-duplicate clusters). print("\nComputing linear-combination analysis...") lincomb = _compute_linear_combinations( emotions, per_layer_vectors, target_layers ) for emotion, lc in lincomb.items(): if emotion in report: report[emotion]["linear_combination"] = lc["per_layer"] (output_dir / "quality.json").write_text( json.dumps(report, indent=2) + "\n" ) # Short summary: concepts in each triage bucket. clean_single_neuron = [] clean_circuit = [] fragmented = [] contaminated = [] redundant = [] # R² > 0.9 — concept is near-linear combo of others mid = n_layers // 2 mid_layer = target_layers[mid] for emotion in emotions: per_l = report[emotion]["per_layer"][str(mid_layer)] v = per_l["first_pc_variance_ratio"] nb = per_l.get("best_neuron_cosine") or 0.0 top_near = report[emotion]["nearest_concepts"] nearest_cos = top_near[0][1] if top_near else 0.0 lc_r2 = 0.0 lc_entry = report[emotion].get("linear_combination", {}) if str(mid_layer) in lc_entry: lc_r2 = lc_entry[str(mid_layer)]["r_squared"] if lc_r2 > 0.9: redundant.append(emotion) if nearest_cos > 0.8: contaminated.append(emotion) elif v > 0.7 and abs(nb) > 0.6: clean_single_neuron.append(emotion) elif v > 0.7: clean_circuit.append(emotion) elif v < 0.4: fragmented.append(emotion) print( f"\nQuality summary @ layer {mid_layer}:\n" f" clean (single-neuron): {len(clean_single_neuron)}\n" f" clean (low-dim circuit): {len(clean_circuit)}\n" f" fragmented (first-PC < 0.4): {len(fragmented)}\n" f" contaminated (nearest > 0.8): {len(contaminated)}\n" f" redundant (R² > 0.9 vs. others): {len(redundant)}" ) if fragmented: print(f" fragmented sample: {fragmented[:5]}") if contaminated: print(f" contaminated sample: {contaminated[:5]}") if redundant: print(f" redundant sample: {redundant[:5]}") print(f"\nWrote quality.json to {output_dir}") del model gc.collect() torch.cuda.empty_cache() if __name__ == "__main__": main()