# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Train concept-readout vectors via Contrastive Activation Addition. Reads the hand-written story corpus at ``amygdala_stories/{stories,paired}/`` and produces the per-layer safetensors file + sidecar JSON manifest that vLLM's ReadoutManager loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``). Training data (cross-concept contrast): positive for emotion E: stories/E.txt paired//E.txt (for each scenario that covers E) negative for emotion E: stories/.txt paired//baseline.txt (for each scenario) Within-scenario paired stories are the highest-signal pairs (same content, different concept framing); unpaired stories provide bulk contrast across the 80 emotions we have written so far. Pooling: last non-pad token. Matches how readout is consumed at decode time (residual read at the sampler's query position). Output: readout.safetensors layer_.vectors : fp16 (n_concepts, hidden_size) one per layer readout.json { "concepts": [...], "layers": [...], "hidden_size": int, "dtype": "float16" } """ from __future__ import annotations import argparse import gc import json import os from pathlib import Path import safetensors.torch import torch from transformers import AutoModelForCausalLM, AutoTokenizer def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Pick the last non-pad token's hidden state per example. hidden: [batch, seq, hidden_dim] attention_mask: [batch, seq] returns: [batch, hidden_dim] """ last_idx = attention_mask.sum(dim=1) - 1 batch_idx = torch.arange(hidden.size(0), device=hidden.device) return hidden[batch_idx, last_idx] def _find_layers_module(model) -> torch.nn.ModuleList: """Walk a few likely paths to find the transformer-block list.""" candidates = [ "model.layers", "model.model.layers", "model.language_model.layers", "model.language_model.model.layers", "language_model.model.layers", "transformer.h", ] for path in candidates: obj = model ok = True for part in path.split("."): if not hasattr(obj, part): ok = False break obj = getattr(obj, part) if ok and isinstance(obj, torch.nn.ModuleList): return obj raise RuntimeError( f"Couldn't find transformer layer list. Tried: {candidates}" ) def _collect_activations( model, tokenizer, texts: list[str], target_layers: list[int], device: torch.device, batch_size: int, max_length: int, ) -> torch.Tensor: """Run texts through the model, capture residual stream at target layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU. """ captures: dict[int, torch.Tensor] = {} def make_hook(idx: int): def hook(_mod, _inp, output): hs = output[0] if isinstance(output, tuple) else output captures[idx] = hs.detach() return hook layers_module = _find_layers_module(model) handles = [ layers_module[idx].register_forward_hook(make_hook(idx)) for idx in target_layers ] out_rows: list[torch.Tensor] = [] try: model.eval() with torch.no_grad(): for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] tok = tokenizer( batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ).to(device) captures.clear() model(**tok) per_layer = [ _pool_last(captures[idx], tok["attention_mask"]) .to(torch.float32) .cpu() for idx in target_layers ] out_rows.append(torch.stack(per_layer, dim=1)) del tok, captures if (i // batch_size) % 10 == 0: torch.cuda.empty_cache() captures = {} finally: for h in handles: h.remove() return torch.cat(out_rows, dim=0) def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings) list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives ]: """Return ``(positives_by_emotion, baselines)``. Cross-concept negatives are computed at training time from ``positives_by_emotion`` — each emotion's negative set is the union of all other emotions' positives plus the baseline texts. """ positives: dict[str, list[str]] = {} for story_path in sorted(stories_dir.glob("*.txt")): emotion = story_path.stem positives.setdefault(emotion, []).append( story_path.read_text().strip() ) baselines: list[str] = [] if paired_dir is not None and paired_dir.exists(): for scenario_dir in sorted(paired_dir.iterdir()): if not scenario_dir.is_dir(): continue baseline_path = scenario_dir / "baseline.txt" if baseline_path.exists(): baselines.append(baseline_path.read_text().strip()) for framing_path in sorted(scenario_dir.glob("*.txt")): if framing_path.stem == "baseline": continue emotion = framing_path.stem positives.setdefault(emotion, []).append( framing_path.read_text().strip() ) return positives, baselines def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True, help="HF model id or path") ap.add_argument( "--stories-dir", required=True, help="Path to amygdala_stories/stories/", ) ap.add_argument( "--paired-dir", default=None, help="Path to amygdala_stories/paired/ (optional)", ) ap.add_argument( "--target-layers", required=True, help="Comma-separated layer indices, e.g. 40,50,60,70", ) ap.add_argument( "--output-dir", required=True, help="Directory to write readout.safetensors + readout.json", ) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=512) ap.add_argument("--device", default="cuda:0") ap.add_argument( "--min-positives", type=int, default=1, help="Skip emotions with fewer positive examples than this", ) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] dtype = { "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, }[args.dtype] print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True, ) hidden_dim = model.config.hidden_size n_model_layers = model.config.num_hidden_layers print( f"Model loaded. hidden_dim={hidden_dim}, " f"n_model_layers={n_model_layers}" ) for layer_idx in target_layers: if layer_idx < 0 or layer_idx >= n_model_layers: raise ValueError( f"target layer {layer_idx} out of range " f"[0, {n_model_layers})" ) positives_by_emotion, baselines = _load_corpus( Path(args.stories_dir), Path(args.paired_dir) if args.paired_dir else None, ) emotions = sorted( e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives ) if not emotions: raise RuntimeError( f"No emotions with >= {args.min_positives} positive examples" ) print( f"Training {len(emotions)} emotions; " f"{len(baselines)} baseline scenarios" ) # Cache all positive-text activations once so we can reuse them as # negatives for other emotions. Keyed by the text itself to dedup # across emotion lists. device = torch.device(args.device) text_to_emotion: dict[str, str] = {} for emotion, texts in positives_by_emotion.items(): for t in texts: text_to_emotion[t] = emotion unique_positive_texts = list(text_to_emotion.keys()) print( f"Collecting activations for {len(unique_positive_texts)} unique " f"positive texts + {len(baselines)} baselines..." ) positive_acts = _collect_activations( model, tokenizer, unique_positive_texts, target_layers, device, args.batch_size, args.max_length, ) # positive_acts[i] corresponds to unique_positive_texts[i] text_to_row = {t: i for i, t in enumerate(unique_positive_texts)} baseline_acts = ( _collect_activations( model, tokenizer, baselines, target_layers, device, args.batch_size, args.max_length, ) if baselines else torch.zeros(0, len(target_layers), hidden_dim) ) n_concepts = len(emotions) n_layers = len(target_layers) # Per-layer output matrices. Shape (n_concepts, hidden_size) each. per_layer_vectors = torch.zeros( (n_layers, n_concepts, hidden_dim), dtype=torch.float32 ) for e_idx, emotion in enumerate(emotions): pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]] # Negatives: every OTHER emotion's positives + baselines. neg_rows = [ i for i, t in enumerate(unique_positive_texts) if text_to_emotion[t] != emotion ] pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden] neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden] if baseline_acts.shape[0] > 0: neg = torch.cat([neg, baseline_acts], dim=0) pos_mean = pos.mean(dim=0) # [n_layers, hidden] neg_mean = neg.mean(dim=0) diff = pos_mean - neg_mean norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6) diff = diff / norms # diff[layer] -> per_layer_vectors[layer, e_idx] for l_idx in range(n_layers): per_layer_vectors[l_idx, e_idx] = diff[l_idx] if e_idx < 5 or e_idx == len(emotions) - 1: print( f" [{e_idx + 1}/{len(emotions)}] {emotion}: " f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}" ) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) tensors = { f"layer_{target_layers[l_idx]}.vectors": ( per_layer_vectors[l_idx].to(torch.float16) ) for l_idx in range(n_layers) } safetensors.torch.save_file( tensors, str(output_dir / "readout.safetensors"), ) manifest = { "concepts": emotions, "layers": target_layers, "hidden_size": hidden_dim, "dtype": "float16", } (output_dir / "readout.json").write_text( json.dumps(manifest, indent=2) + "\n" ) total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024) print( f"\nWrote readout.safetensors + readout.json to {output_dir}\n" f" {n_concepts} concepts x {n_layers} layers x " f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB" ) del model gc.collect() torch.cuda.empty_cache() if __name__ == "__main__": main()