diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py
index a722298..21e5ed1 100644
--- a/training/amygdala_training/train_steering_vectors.py
+++ b/training/amygdala_training/train_steering_vectors.py
@@ -1,30 +1,48 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Train amygdala steering vectors via Contrastive Activation Addition.
+"""Train concept-readout vectors via Contrastive Activation Addition.
 
-Reads the per-emotion JSONL files produced by extract_training_pairs.py,
-runs the target model over each example, captures the residual-stream
-hidden state at the configured target layers, and computes
-`mean(positive) - mean(negative)` as the steering direction per layer
-per emotion.
-
-Output: a safetensors file matching the format AmygdalaConnector
-expects:
-
-    vectors: [n_emotions, n_target_layers, hidden_dim] fp16
-    emotion_names: [n_emotions] uint8
-
-Pooling: last-token residual-stream per example (CAA convention —
-the final token has seen the whole context and is where the model's
-"decision" lives). Alternative: mean across all tokens. The LAST
-convention is more common for steering vector work.
+Reads the hand-written story corpus at
+``amygdala_stories/{stories,paired}/`` and produces the per-layer
+safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
+loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).
+
+Training data (cross-concept contrast):
+
+    positive for emotion E:
+        stories/E.txt
+        paired/<scenario>/E.txt          (for each scenario that covers E)
+
+    negative for emotion E:
+        stories/<other emotion>.txt
+        paired/<scenario>/baseline.txt   (for each scenario)
+
+Within-scenario paired stories are the highest-signal pairs (same
+content, different concept framing); unpaired stories provide bulk
+contrast across the 80 emotions we have written so far.
+
+Pooling: last non-pad token. Matches how readout is consumed at decode
+time (residual read at the sampler's query position).
+
+Output:
+
+    readout.safetensors
+        layer_<idx>.vectors : fp16 (n_concepts, hidden_size), one per layer
+    readout.json
+        {
+            "concepts": [...],
+            "layers": [...],
+            "hidden_size": int,
+            "dtype": "float16"
+        }
 """
+from __future__ import annotations
+
 import argparse
 import gc
 import json
 import os
-from collections import defaultdict
 from pathlib import Path
 
 import safetensors.torch
@@ -39,81 +57,11 @@ def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tens
     attention_mask: [batch, seq]
     returns: [batch, hidden_dim]
     """
-    # last non-pad token index per row
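+    # Index of the last non-pad token in each row. This assumes right
+    # padding; main() forces tokenizer.padding_side = "right" so the
+    # assumption holds for every tokenizer.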
     last_idx = attention_mask.sum(dim=1) - 1
     batch_idx = torch.arange(hidden.size(0), device=hidden.device)
     return hidden[batch_idx, last_idx]
 
 
-def _collect_activations(
-    model,
-    tokenizer,
-    texts: list[str],
-    target_layers: list[int],
-    device: torch.device,
-    batch_size: int,
-    max_length: int,
-) -> torch.Tensor:
-    """Run texts through the model, capture residual stream at target
-    layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
-    """
-    # Register hooks on the target layers' outputs. We want the
-    # residual stream AFTER each layer, which is the output of the
-    # transformer block (hidden_states[layer_idx+1] in HF land).
-    captures: dict[int, torch.Tensor] = {}
-
-    def make_hook(idx):
-        def hook(_mod, _inp, output):
-            # output is typically (hidden_states, ...) — take the first
-            hs = output[0] if isinstance(output, tuple) else output
-            captures[idx] = hs.detach()
-        return hook
-
-    handles = []
-    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
-    # language_model.model.layers follows the same convention.
-    # Resolve the layer list by walking common paths.
-    layers_module = _find_layers_module(model)
-    for idx in target_layers:
-        handles.append(
-            layers_module[idx].register_forward_hook(make_hook(idx))
-        )
-
-    out_rows: list[torch.Tensor] = []
-    try:
-        model.eval()
-        with torch.no_grad():
-            for i in range(0, len(texts), batch_size):
-                batch = texts[i : i + batch_size]
-                tok = tokenizer(
-                    batch,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=max_length,
-                ).to(device)
-                captures.clear()
-                model(**tok)
-
-                per_layer = []
-                for idx in target_layers:
-                    hs = captures[idx]  # [batch, seq, hidden]
-                    pooled = _pool_last(hs, tok["attention_mask"])
-                    per_layer.append(pooled.to(torch.float32).cpu())
-                # Stack to [batch, n_layers, hidden_dim]
-                batched = torch.stack(per_layer, dim=1)
-                out_rows.append(batched)
-
-                del tok, captures
-                if (i // batch_size) % 10 == 0:
-                    torch.cuda.empty_cache()
-    finally:
-        for h in handles:
-            h.remove()
-
-    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]
-
-
 def _find_layers_module(model) -> torch.nn.ModuleList:
     """Walk a few likely paths to find the transformer-block list."""
     candidates = [
@@ -139,25 +87,143 @@ def _find_layers_module(model) -> torch.nn.ModuleList:
     )
 
 
+def _collect_activations(
+    model,
+    tokenizer,
+    texts: list[str],
+    target_layers: list[int],
+    device: torch.device,
+    batch_size: int,
+    max_length: int,
+) -> torch.Tensor:
+    """Run texts through the model, capture residual stream at target
+    layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
+    """
+    captures: dict[int, torch.Tensor] = {}
+
+    def make_hook(idx: int):
+        def hook(_mod, _inp, output):
+            hs = output[0] if isinstance(output, tuple) else output
+            captures[idx] = hs.detach()
+        return hook
+
+    layers_module = _find_layers_module(model)
+    handles = [
+        layers_module[idx].register_forward_hook(make_hook(idx))
+        for idx in target_layers
+    ]
+
+    out_rows: list[torch.Tensor] = []
+    try:
+        model.eval()
+        with torch.no_grad():
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i : i + batch_size]
+                tok = tokenizer(
+                    batch,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                ).to(device)
+                captures.clear()
+                model(**tok)
+
+                per_layer = [
+                    _pool_last(captures[idx], tok["attention_mask"])
+                    .to(torch.float32)
+                    .cpu()
+                    for idx in target_layers
+                ]
+                out_rows.append(torch.stack(per_layer, dim=1))
+                # Keep `captures` bound: the hooks close over it, and it
+                # is cleared at the top of each iteration. Only the
+                # tokenized batch needs freeing here.
+                del tok
+                if (i // batch_size) % 10 == 0:
+                    torch.cuda.empty_cache()
+    finally:
+        for h in handles:
+            h.remove()
+
+    return torch.cat(out_rows, dim=0)
+
+
+def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
+    dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
+    list[str],  # all baseline texts (one per scenario), as scenario-agnostic negatives
+]:
+    """Return ``(positives_by_emotion, baselines)``.
+
+    Cross-concept negatives are computed at training time from
+    ``positives_by_emotion`` — each emotion's negative set is the
+    union of all other emotions' positives plus the baseline texts.
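+
+    Example (illustrative emotion and scenario names): with emotions
+    {awe, dread} and a single scenario ``s``, awe's positives are
+    ``stories/awe.txt`` and ``paired/s/awe.txt``; its negatives are
+    ``stories/dread.txt``, ``paired/s/dread.txt``, and
+    ``paired/s/baseline.txt``.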
+ """ + positives: dict[str, list[str]] = {} + for story_path in sorted(stories_dir.glob("*.txt")): + emotion = story_path.stem + positives.setdefault(emotion, []).append( + story_path.read_text().strip() + ) + + baselines: list[str] = [] + if paired_dir is not None and paired_dir.exists(): + for scenario_dir in sorted(paired_dir.iterdir()): + if not scenario_dir.is_dir(): + continue + baseline_path = scenario_dir / "baseline.txt" + if baseline_path.exists(): + baselines.append(baseline_path.read_text().strip()) + for framing_path in sorted(scenario_dir.glob("*.txt")): + if framing_path.stem == "baseline": + continue + emotion = framing_path.stem + positives.setdefault(emotion, []).append( + framing_path.read_text().strip() + ) + + return positives, baselines + + def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True, help="HF model id or path") - ap.add_argument("--training-data-dir", required=True) ap.add_argument( - "--target-layers", required=True, - help="Comma-separated layer indices, e.g. 3,18,33,36", + "--stories-dir", + required=True, + help="Path to amygdala_stories/stories/", + ) + ap.add_argument( + "--paired-dir", + default=None, + help="Path to amygdala_stories/paired/ (optional)", + ) + ap.add_argument( + "--target-layers", + required=True, + help="Comma-separated layer indices, e.g. 40,50,60,70", + ) + ap.add_argument( + "--output-dir", + required=True, + help="Directory to write readout.safetensors + readout.json", ) - ap.add_argument("--output", required=True) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) - ap.add_argument("--batch-size", type=int, default=4) + ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=512) ap.add_argument("--device", default="cuda:0") + ap.add_argument( + "--min-positives", + type=int, + default=1, + help="Skip emotions with fewer positive examples than this", + ) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] - dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[ - args.dtype - ] + dtype = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + }[args.dtype] print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) @@ -170,78 +236,137 @@ def main() -> None: low_cpu_mem_usage=True, ) hidden_dim = model.config.hidden_size - print(f"Model loaded. hidden_dim={hidden_dim}, " - f"n_layers={model.config.num_hidden_layers}") - - manifest_path = Path(args.training_data_dir) / "_manifest.json" - manifest = json.loads(manifest_path.read_text()) - - emotions = sorted(manifest["emotions"].keys()) - print(f"Training {len(emotions)} emotions: {emotions}") - - n_emotions = len(emotions) - n_layers = len(target_layers) - vectors = torch.zeros( - (n_emotions, n_layers, hidden_dim), dtype=torch.float32 + n_model_layers = model.config.num_hidden_layers + print( + f"Model loaded. 
@@ -170,78 +236,137 @@ def main() -> None:
         low_cpu_mem_usage=True,
     )
     hidden_dim = model.config.hidden_size
-    print(f"Model loaded. hidden_dim={hidden_dim}, "
-          f"n_layers={model.config.num_hidden_layers}")
-
-    manifest_path = Path(args.training_data_dir) / "_manifest.json"
-    manifest = json.loads(manifest_path.read_text())
-
-    emotions = sorted(manifest["emotions"].keys())
-    print(f"Training {len(emotions)} emotions: {emotions}")
-
-    n_emotions = len(emotions)
-    n_layers = len(target_layers)
-    vectors = torch.zeros(
-        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
-    )
-    device = torch.device(args.device)
+    n_model_layers = model.config.num_hidden_layers
+    print(
+        f"Model loaded. hidden_dim={hidden_dim}, "
+        f"n_model_layers={n_model_layers}"
+    )
+
+    for layer_idx in target_layers:
+        if layer_idx < 0 or layer_idx >= n_model_layers:
+            raise ValueError(
+                f"target layer {layer_idx} out of range "
+                f"[0, {n_model_layers})"
+            )
+
+    positives_by_emotion, baselines = _load_corpus(
+        Path(args.stories_dir),
+        Path(args.paired_dir) if args.paired_dir else None,
+    )
+    emotions = sorted(
+        e for e, ps in positives_by_emotion.items()
+        if len(ps) >= args.min_positives
+    )
+    if not emotions:
+        raise RuntimeError(
+            f"No emotions with >= {args.min_positives} positive examples"
+        )
+    print(
+        f"Training {len(emotions)} emotions; "
+        f"{len(baselines)} baseline scenarios"
+    )
+
+    # Cache all positive-text activations once so we can reuse them as
+    # negatives for other emotions. Keyed by the text itself to dedup
+    # across emotion lists.
+    device = torch.device(args.device)
+    text_to_emotion: dict[str, str] = {}
+    for emotion, texts in positives_by_emotion.items():
+        for t in texts:
+            text_to_emotion[t] = emotion
+
+    unique_positive_texts = list(text_to_emotion.keys())
+    print(
+        f"Collecting activations for {len(unique_positive_texts)} unique "
+        f"positive texts + {len(baselines)} baselines..."
+    )
+
+    positive_acts = _collect_activations(
+        model, tokenizer, unique_positive_texts, target_layers, device,
+        args.batch_size, args.max_length,
+    )
+    # positive_acts[i] corresponds to unique_positive_texts[i]
+    text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
+
+    baseline_acts = (
+        _collect_activations(
+            model, tokenizer, baselines, target_layers, device,
+            args.batch_size, args.max_length,
+        )
+        if baselines
+        else torch.zeros(0, len(target_layers), hidden_dim)
+    )
+
+    n_concepts = len(emotions)
+    n_layers = len(target_layers)
+
+    # Per-layer output matrices. Shape (n_concepts, hidden_size) each.
+    per_layer_vectors = torch.zeros(
+        (n_layers, n_concepts, hidden_dim), dtype=torch.float32
+    )
 
     for e_idx, emotion in enumerate(emotions):
-        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
-        pos_texts, neg_texts = [], []
-        with open(path) as f:
-            for line in f:
-                ex = json.loads(line)
-                if ex["polarity"] == "positive":
-                    pos_texts.append(ex["text"])
-                else:
-                    neg_texts.append(ex["text"])
-        print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
-              f"{len(pos_texts)} pos / {len(neg_texts)} neg")
-
-        pos_acts = _collect_activations(
-            model, tokenizer, pos_texts, target_layers, device,
-            args.batch_size, args.max_length,
-        )
-        neg_acts = _collect_activations(
-            model, tokenizer, neg_texts, target_layers, device,
-            args.batch_size, args.max_length,
-        )
-
-        # Difference of means per layer
-        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
-        neg_mean = neg_acts.mean(dim=0)
+        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
+        # Negatives: every OTHER emotion's positives + baselines.
+        neg_rows = [
+            i
+            for i, t in enumerate(unique_positive_texts)
+            if text_to_emotion[t] != emotion
+        ]
+
+        pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
+        neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
+        if baseline_acts.shape[0] > 0:
+            neg = torch.cat([neg, baseline_acts], dim=0)
+
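+        # CAA: the readout direction is the difference between the
+        # positive-class and negative-class activation means.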
+        pos_mean = pos.mean(dim=0)  # [n_layers, hidden]
+        neg_mean = neg.mean(dim=0)
         diff = pos_mean - neg_mean
-
         # Normalize per layer so projections are scale-comparable
         norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
         diff = diff / norms
 
-        vectors[e_idx] = diff
-        del pos_acts, neg_acts
-        gc.collect()
-        torch.cuda.empty_cache()
+        # diff[layer] -> per_layer_vectors[layer, e_idx]
+        per_layer_vectors[:, e_idx] = diff
+
+        if e_idx < 5 or e_idx == len(emotions) - 1:
+            print(
+                f"  [{e_idx + 1}/{len(emotions)}] {emotion}: "
+                f"pos={len(pos_rows)} "
+                f"neg={len(neg_rows) + baseline_acts.shape[0]}"
+            )
 
-    # Save in AmygdalaConnector format.
-    # emotion_names as padded uint8 tensor
-    names_bytes = [e.encode("utf-8") for e in emotions]
-    max_len = max(len(b) for b in names_bytes)
-    padded = torch.tensor(
-        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
-        dtype=torch.uint8,
-    )
-
-    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    tensors = {
+        f"layer_{target_layers[l_idx]}.vectors": (
+            per_layer_vectors[l_idx].to(torch.float16)
+        )
+        for l_idx in range(n_layers)
+    }
     safetensors.torch.save_file(
-        {
-            "vectors": vectors.to(torch.float16),
-            "emotion_names": padded,
-            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
-        },
-        args.output,
+        tensors,
+        str(output_dir / "readout.safetensors"),
     )
 
-    print(f"\nWrote steering vectors to {args.output}: "
-          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")
+    manifest = {
+        "concepts": emotions,
+        "layers": target_layers,
+        "hidden_size": hidden_dim,
+        "dtype": "float16",
+    }
+    (output_dir / "readout.json").write_text(
+        json.dumps(manifest, indent=2) + "\n"
+    )
+
+    total_mib = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
+    print(
+        f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
+        f"  {n_concepts} concepts x {n_layers} layers x "
+        f"{hidden_dim} dim (fp16), total {total_mib:.1f} MiB"
+    )
+    del model
+    gc.collect()
+    torch.cuda.empty_cache()
 
 
 if __name__ == "__main__":