# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Train concept-readout vectors via Contrastive Activation Addition. Reads the hand-written story corpus at ``amygdala_stories/{stories,paired}/`` and produces the per-layer safetensors file + sidecar JSON manifest that vLLM's ReadoutManager loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``). Training data (cross-concept contrast): positive for emotion E: stories/E.txt paired//E.txt (for each scenario that covers E) negative for emotion E: stories/.txt paired//baseline.txt (for each scenario) Within-scenario paired stories are the highest-signal pairs (same content, different concept framing); unpaired stories provide bulk contrast across the 80 emotions we have written so far. Pooling: last non-pad token. Matches how readout is consumed at decode time (residual read at the sampler's query position). Output: readout.safetensors layer_.vectors : fp16 (n_concepts, hidden_size) one per layer readout.json { "concepts": [...], "layers": [...], "hidden_size": int, "dtype": "float16" } """ from __future__ import annotations import argparse import gc import json import os from pathlib import Path import safetensors.torch import torch from transformers import AutoModelForCausalLM, AutoTokenizer def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Pick the last non-pad token's hidden state per example. hidden: [batch, seq, hidden_dim] attention_mask: [batch, seq] returns: [batch, hidden_dim] """ last_idx = attention_mask.sum(dim=1) - 1 batch_idx = torch.arange(hidden.size(0), device=hidden.device) return hidden[batch_idx, last_idx] def _find_layers_module(model) -> torch.nn.ModuleList: """Walk a few likely paths to find the transformer-block list.""" candidates = [ "model.layers", "model.model.layers", "model.language_model.layers", "model.language_model.model.layers", "language_model.model.layers", "transformer.h", ] for path in candidates: obj = model ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok and isinstance(obj, torch.nn.ModuleList): return obj raise RuntimeError( f"Couldn't find transformer layer list. Tried: {candidates}" ) def _collect_activations( model, tokenizer, texts: list[str], target_layers: list[int], device: torch.device, batch_size: int, max_length: int, *, label: str = "", ) -> torch.Tensor: """Run texts through the model, capture residual stream at target layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU. """ import time assert all(isinstance(t, str) and t for t in texts), ( f"_collect_activations: empty or non-string text in {label!r}" ) captures: dict[int, torch.Tensor] = {} def make_hook(idx: int): def hook(_mod, _inp, output): hs = output[0] if isinstance(output, tuple) else output captures[idx] = hs.detach() return hook layers_module = _find_layers_module(model) handles = [ layers_module[idx].register_forward_hook(make_hook(idx)) for idx in target_layers ] out_rows: list[torch.Tensor] = [] n_batches = (len(texts) + batch_size - 1) // batch_size start = time.time() try: model.eval() with torch.no_grad(): for b_idx, i in enumerate(range(0, len(texts), batch_size)): batch = texts[i : i + batch_size] tok = tokenizer( batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ).to(device) captures.clear() model(**tok) per_layer = [ _pool_last(captures[idx], tok["attention_mask"]) .to(torch.float32) .cpu() for idx in target_layers ] out_rows.append(torch.stack(per_layer, dim=1)) del tok, captures if b_idx % 10 == 0: torch.cuda.empty_cache() if b_idx % 5 == 0 or b_idx == n_batches - 1: elapsed = time.time() - start rate = (b_idx + 1) / elapsed if elapsed > 0 else 0 eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0 print( f" [{label}] batch {b_idx + 1}/{n_batches} " f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", flush=True, ) captures = {} finally: for h in handles: h.remove() return torch.cat(out_rows, dim=0) def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings) list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives ]: """Return ``(positives_by_emotion, baselines)``. Cross-concept negatives are computed at training time from ``positives_by_emotion`` — each emotion's negative set is the union of all other emotions' positives plus the baseline texts. Empty .txt files are skipped with a warning. """ def _read_nonempty(path: Path) -> str | None: text = path.read_text().strip() if not text: print( f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}" ) return None return text positives: dict[str, list[str]] = {} for story_path in sorted(stories_dir.glob("*.txt")): text = _read_nonempty(story_path) if text is None: continue emotion = story_path.stem positives.setdefault(emotion, []).append(text) baselines: list[str] = [] if paired_dir is not None and paired_dir.exists(): for scenario_dir in sorted(paired_dir.iterdir()): if not scenario_dir.is_dir(): continue baseline_path = scenario_dir / "baseline.txt" if baseline_path.exists(): text = _read_nonempty(baseline_path) if text is not None: baselines.append(text) for framing_path in sorted(scenario_dir.glob("*.txt")): if framing_path.stem == "baseline": continue text = _read_nonempty(framing_path) if text is None: continue emotion = framing_path.stem positives.setdefault(emotion, []).append(text) return positives, baselines def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None: """Return the W_down weight for the MLP at the given transformer layer. Looks for the common paths (mlp.down_proj, mlp.c_proj, feed_forward.down_proj). Returns None if nothing matches — downstream code skips the single-neuron alignment check in that case rather than failing. """ layers = _find_layers_module(model) layer = layers[layer_idx] for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"): obj = layer ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok and hasattr(obj, "weight"): # Shape convention: [hidden, mlp_inner] — each column is one # MLP neuron's contribution direction into the residual stream. return obj.weight.detach() return None def _compute_quality_report( emotions: list[str], positive_acts: torch.Tensor, # [n_positive_stories, n_layers, hidden] baseline_acts: torch.Tensor, # [n_baseline_stories, n_layers, hidden] positives_by_emotion: dict[str, list[str]], text_to_row: dict[str, int], per_layer_vectors: torch.Tensor, # [n_layers, n_concepts, hidden], unit-normed target_layers: list[int], model, positive_texts: list[str], text_to_emotion: dict[str, str], ) -> dict: """Per-concept quality metrics: - first_pc_variance_ratio: SVD on centered positive activations. >0.7 = rank-1 (clean). <0.4 = fragmented (stories disagree). - story_projection_*: how each positive story projects onto the concept direction. Low std = tight agreement. - best_neuron_cosine: alignment of the residual-space direction with the nearest W_down column (= single MLP neuron). >0.6 = essentially single-neuron. - nearest_concepts: top-5 concept directions most parallel to this one. Cosine >0.8 means the vector is confused with a neighbor. """ report: dict = {} n_layers = per_layer_vectors.shape[0] # Pre-compute per-layer W_down for single-neuron alignment. w_down: dict[int, torch.Tensor] = {} for target_l in target_layers: w = _find_mlp_down_proj(model, target_l) if w is not None: # Unit-normalize each column (one per MLP neuron). w = w.to(torch.float32) norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6) w_down[target_l] = w / norms # [hidden, mlp_inner] # Pre-compute unit-normed concept vectors (for cross-concept cosines). vec_norm = per_layer_vectors / per_layer_vectors.norm( dim=-1, keepdim=True ).clamp_min(1e-6) for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] pos = positive_acts[pos_rows].to(torch.float32) # [n_pos, n_layers, hidden] per_layer: dict = {} for l_idx, target_l in enumerate(target_layers): pos_l = pos[:, l_idx, :] # [n_pos, hidden] diff_l = per_layer_vectors[l_idx, e_idx] # [hidden], unit-normed pos_mean_l = pos_l.mean(dim=0) # SVD for rank analysis — if first PC dominates, stories agree. centered = pos_l - pos_mean_l # svdvals errors on 1-row; handle that. if centered.shape[0] >= 2: S = torch.linalg.svdvals(centered) var = S ** 2 var_total = var.sum().clamp_min(1e-12) var_ratios = (var / var_total).tolist() else: var_ratios = [1.0] # Per-story projection onto the concept direction. projections = pos_l @ diff_l # [n_pos] # Per-story alignment: cosine(story_dir, concept_dir) where # story_dir = pos_i - pos_mean (centered, pointing away from center). if centered.shape[0] >= 2: centered_norm = centered / centered.norm( dim=-1, keepdim=True ).clamp_min(1e-6) alignments = centered_norm @ diff_l else: alignments = torch.zeros(1) # Single-neuron alignment: is the direction close to any # W_down column? nb_best_idx = None nb_best_cos = None nb_top5 = None if target_l in w_down: W = w_down[target_l] cos = W.t() @ diff_l # [mlp_inner] abs_cos = cos.abs() k = min(5, abs_cos.shape[0]) top_vals, top_idxs = abs_cos.topk(k) nb_best_idx = int(top_idxs[0]) nb_best_cos = float(cos[top_idxs[0]]) nb_top5 = [[int(i), float(cos[i])] for i in top_idxs] per_layer[str(target_l)] = { "top3_variance_ratios": [ float(v) for v in var_ratios[:3] ], "first_pc_variance_ratio": float(var_ratios[0]), "story_projection_mean": float(projections.mean()), "story_projection_std": float(projections.std()), "story_projection_min": float(projections.min()), "story_projection_max": float(projections.max()), "story_alignment_mean": float(alignments.mean()), "story_alignment_std": float(alignments.std()), "best_neuron_idx": nb_best_idx, "best_neuron_cosine": nb_best_cos, "top5_neurons": nb_top5, } # Outlier stories: lowest-aligned on the middle target layer. mid = n_layers // 2 pos_l_mid = pos[:, mid, :] mid_mean = pos_l_mid.mean(dim=0) mid_diff = per_layer_vectors[mid, e_idx] centered_mid = pos_l_mid - mid_mean if centered_mid.shape[0] >= 2: centered_mid_norm = centered_mid / centered_mid.norm( dim=-1, keepdim=True ).clamp_min(1e-6) mid_aligns = centered_mid_norm @ mid_diff # [n_pos] # Lowest two alignments = candidate outliers. k = min(2, mid_aligns.shape[0]) low_vals, low_idxs = mid_aligns.topk(k, largest=False) outliers = [ [ positives_by_emotion[emotion][int(i)], float(mid_aligns[i]), ] for i in low_idxs ] else: outliers = [] # Nearest other concepts at the middle target layer. this_norm = vec_norm[mid, e_idx] all_cos = vec_norm[mid] @ this_norm # [n_concepts] all_cos[e_idx] = -2.0 # mask self k = min(5, all_cos.shape[0] - 1) top_vals, top_idxs = all_cos.topk(k) nearest = [ [emotions[int(i)], float(v)] for i, v in zip(top_idxs, top_vals) ] report[emotion] = { "n_positive_stories": len(pos_rows), "per_layer": per_layer, "outlier_stories": outliers, "nearest_concepts": nearest, } return report def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True, help="HF model id or path") ap.add_argument( "--stories-dir", required=True, help="Path to amygdala_stories/stories/", ) ap.add_argument( "--paired-dir", default=None, help="Path to amygdala_stories/paired/ (optional)", ) ap.add_argument( "--target-layers", required=True, help="Comma-separated layer indices, e.g. 40,50,60,70", ) ap.add_argument( "--output-dir", required=True, help="Directory to write readout.safetensors + readout.json", ) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=512) ap.add_argument("--device", default="cuda:0") ap.add_argument( "--min-positives", type=int, default=1, help="Skip emotions with fewer positive examples than this", ) ap.add_argument( "--quality-report", action="store_true", help="After training, compute a per-concept quality report " "(SVD rank, per-story alignment, single-neuron alignment, " "nearest-concept contamination) and write quality.json", ) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] dtype = { "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, }[args.dtype] # Preflight: corpus dirs exist before we pay the cost of loading a 27B model stories_dir = Path(args.stories_dir) if not stories_dir.is_dir(): raise FileNotFoundError( f"--stories-dir {stories_dir!s} does not exist or is not a dir" ) if args.paired_dir is not None: pd = Path(args.paired_dir) if not pd.is_dir(): raise FileNotFoundError( f"--paired-dir {pd!s} does not exist or is not a dir" ) # Quick corpus pre-scan so failures show up before we load the model. positives_preview, baselines_preview = _load_corpus( stories_dir, Path(args.paired_dir) if args.paired_dir else None, ) n_emotions_preview = sum( 1 for ps in positives_preview.values() if len(ps) >= args.min_positives ) if n_emotions_preview == 0: raise RuntimeError( f"corpus has 0 emotions with >= {args.min_positives} positive " f"examples. Check {stories_dir} — is it the right directory?" ) print( f"Corpus preflight: {n_emotions_preview} emotions (min_positives=" f"{args.min_positives}), {len(baselines_preview)} baselines" ) print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True, ) # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model # dimensions under a text_config subobject. get_text_config() # returns that sub-config when present, else the top-level config. text_config = ( model.config.get_text_config() if hasattr(model.config, "get_text_config") else model.config ) hidden_dim = text_config.hidden_size n_model_layers = text_config.num_hidden_layers print( f"Model loaded. hidden_dim={hidden_dim}, " f"n_model_layers={n_model_layers} " f"(text_config.model_type={getattr(text_config, 'model_type', '?')})" ) for layer_idx in target_layers: if layer_idx < 0 or layer_idx >= n_model_layers: raise ValueError( f"target layer {layer_idx} out of range " f"[0, {n_model_layers})" ) print( "Target layers (relative depth): " + ", ".join( f"{l} ({100 * l / (n_model_layers - 1):.0f}%)" for l in target_layers ) ) positives_by_emotion, baselines = _load_corpus( Path(args.stories_dir), Path(args.paired_dir) if args.paired_dir else None, ) emotions = sorted( e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives ) if not emotions: raise RuntimeError( f"No emotions with >= {args.min_positives} positive examples" ) print( f"Training {len(emotions)} emotions; " f"{len(baselines)} baseline scenarios" ) # Cache all positive-text activations once so we can reuse them as # negatives for other emotions. Keyed by the text itself to dedup # across emotion lists. device = torch.device(args.device) text_to_emotion: dict[str, str] = {} for emotion, texts in positives_by_emotion.items(): for t in texts: text_to_emotion[t] = emotion unique_positive_texts = list(text_to_emotion.keys()) print( f"Collecting activations for {len(unique_positive_texts)} unique " f"positive texts + {len(baselines)} baselines..." ) positive_acts = _collect_activations( model, tokenizer, unique_positive_texts, target_layers, device, args.batch_size, args.max_length, label="positives", ) # positive_acts[i] corresponds to unique_positive_texts[i] text_to_row = {t: i for i, t in enumerate(unique_positive_texts)} baseline_acts = ( _collect_activations( model, tokenizer, baselines, target_layers, device, args.batch_size, args.max_length, label="baselines", ) if baselines else torch.zeros(0, len(target_layers), hidden_dim) ) n_concepts = len(emotions) n_layers = len(target_layers) # Per-layer output matrices. Shape (n_concepts, hidden_size) each. per_layer_vectors = torch.zeros( (n_layers, n_concepts, hidden_dim), dtype=torch.float32 ) for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] # Negatives: every OTHER emotion's positives + baselines. neg_rows = [ i for i, t in enumerate(unique_positive_texts) if text_to_emotion[t] != emotion ] pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden] neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden] if baseline_acts.shape[0] > 0: neg = torch.cat([neg, baseline_acts], dim=0) pos_mean = pos.mean(dim=0) # [n_layers, hidden] neg_mean = neg.mean(dim=0) diff = pos_mean - neg_mean norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6) diff = diff / norms # diff[layer] -> per_layer_vectors[layer, e_idx] for l_idx in range(n_layers): per_layer_vectors[l_idx, e_idx] = diff[l_idx] if e_idx < 5 or e_idx == len(emotions) - 1: print( f" [{e_idx + 1}/{len(emotions)}] {emotion}: " f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}" ) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) tensors = { f"layer_{target_layers[l_idx]}.vectors": ( per_layer_vectors[l_idx].to(torch.float16) ) for l_idx in range(n_layers) } safetensors.torch.save_file( tensors, str(output_dir / "readout.safetensors"), ) manifest = { "concepts": emotions, "layers": target_layers, "hidden_size": hidden_dim, "dtype": "float16", } (output_dir / "readout.json").write_text( json.dumps(manifest, indent=2) + "\n" ) total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024) print( f"\nWrote readout.safetensors + readout.json to {output_dir}\n" f" {n_concepts} concepts x {n_layers} layers x " f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB" ) if args.quality_report: print("\nComputing quality report...") report = _compute_quality_report( emotions=emotions, positive_acts=positive_acts, baseline_acts=baseline_acts, positives_by_emotion=positives_by_emotion, text_to_row=text_to_row, per_layer_vectors=per_layer_vectors, target_layers=target_layers, model=model, positive_texts=unique_positive_texts, text_to_emotion=text_to_emotion, ) (output_dir / "quality.json").write_text( json.dumps(report, indent=2) + "\n" ) # Short summary: concepts in each triage bucket. clean_single_neuron = [] clean_circuit = [] fragmented = [] contaminated = [] mid = n_layers // 2 mid_layer = target_layers[mid] for emotion in emotions: per_l = report[emotion]["per_layer"][str(mid_layer)] v = per_l["first_pc_variance_ratio"] nb = per_l.get("best_neuron_cosine") or 0.0 top_near = report[emotion]["nearest_concepts"] nearest_cos = top_near[0][1] if top_near else 0.0 if nearest_cos > 0.8: contaminated.append(emotion) elif v > 0.7 and abs(nb) > 0.6: clean_single_neuron.append(emotion) elif v > 0.7: clean_circuit.append(emotion) elif v < 0.4: fragmented.append(emotion) print( f"\nQuality summary @ layer {mid_layer}:\n" f" clean (single-neuron): {len(clean_single_neuron)}\n" f" clean (low-dim circuit): {len(clean_circuit)}\n" f" fragmented (first-PC < 0.4): {len(fragmented)}\n" f" contaminated (nearest > 0.8): {len(contaminated)}" ) if fragmented: print(f" fragmented sample: {fragmented[:5]}") if contaminated: print(f" contaminated sample: {contaminated[:5]}") print(f"\nWrote quality.json to {output_dir}") del model gc.collect() torch.cuda.empty_cache() if __name__ == "__main__": main()