# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train amygdala steering vectors via Contrastive Activation Addition (CAA).

Reads the per-emotion JSONL files produced by extract_training_pairs.py,
runs the target model over each example, captures the residual-stream
hidden state at the configured target layers, and computes
`mean(positive) - mean(negative)` as the steering direction per layer per
emotion.

Output: a safetensors file matching the format AmygdalaConnector expects:
    vectors:       [n_emotions, n_target_layers, hidden_dim] fp16
    emotion_names: [n_emotions, max_name_len] uint8, UTF-8, NUL-padded
    target_layers: [n_target_layers] int32

Pooling: last-token residual stream per example (the CAA convention: the
final token has seen the whole context and is where the model's "decision"
lives). The alternative is a mean across all tokens, but the last-token
convention is more common in steering-vector work.
"""

import argparse
import gc
import json
import os
from pathlib import Path

import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    Assumes right padding; main() forces tokenizer.padding_side = "right"
    to guarantee this.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    # Index of the last non-pad token per row.
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
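

# The module docstring mentions mean pooling across all tokens as the
# alternative to the last-token convention. A minimal sketch of that
# variant is kept here for reference only; nothing below calls it, and
# it is not part of the CAA pipeline.
def _pool_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over non-pad tokens per example (illustrative variant).

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch, seq, 1]
    summed = (hidden * mask).sum(dim=1)  # pad positions contribute zero
    counts = mask.sum(dim=1).clamp_min(1.0)  # guard against all-pad rows
    return summed / counts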


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> torch.Tensor:
    """Run texts through the model, capture the residual stream at the
    target layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
    """
    # Register hooks on the target layers' outputs. We want the
    # residual stream AFTER each layer, which is the output of the
    # transformer block (hidden_states[layer_idx + 1] in HF land).
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx):
        def hook(_mod, _inp, output):
            # output is typically (hidden_states, ...); take the first.
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()

        return hook

    handles = []
    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
    # language_model.model.layers follows the same convention.
    # Resolve the layer list by walking common paths.
    layers_module = _find_layers_module(model)
    for idx in target_layers:
        handles.append(
            layers_module[idx].register_forward_hook(make_hook(idx))
        )

    out_rows: list[torch.Tensor] = []
    try:
        model.eval()
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = []
                for idx in target_layers:
                    hs = captures[idx]  # [batch, seq, hidden]
                    pooled = _pool_last(hs, tok["attention_mask"])
                    per_layer.append(pooled.to(torch.float32).cpu())

                # Stack to [batch, n_layers, hidden_dim].
                batched = torch.stack(per_layer, dim=1)
                out_rows.append(batched)

                # Only drop `tok` here. Deleting `captures` as well would
                # unbind the closure cell the hooks write into and crash
                # on the next batch; it is cleared at the top of the loop
                # instead.
                del tok
                if (i // batch_size) % 10 == 0:
                    torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )
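

# Illustrative only: one way a consumer such as AmygdalaConnector could
# invert the name packing performed at save time in main(). The actual
# loader API is not shown in this file, so this helper is an assumption
# about the reader side; the script itself never calls it.
def _decode_emotion_names(padded: torch.Tensor) -> list[str]:
    """Strip NUL padding from an `emotion_names` uint8 tensor.

    padded: [n_emotions, max_name_len] uint8
    returns: n_emotions UTF-8 strings
    """
    return [
        bytes(row.tolist()).rstrip(b"\x00").decode("utf-8") for row in padded
    ]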


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument("--training-data-dir", required=True)
    ap.add_argument(
        "--target-layers",
        required=True,
        help="Comma-separated layer indices, e.g. 3,18,33,36",
    )
    ap.add_argument("--output", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # _pool_last indexes the last non-pad token assuming right padding;
    # some chat tokenizers default to left padding, so force it here.
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    hidden_dim = model.config.hidden_size
    print(f"Model loaded. hidden_dim={hidden_dim}, "
          f"n_layers={model.config.num_hidden_layers}")

    manifest_path = Path(args.training_data_dir) / "_manifest.json"
    manifest = json.loads(manifest_path.read_text())
    emotions = sorted(manifest["emotions"].keys())
    print(f"Training {len(emotions)} emotions: {emotions}")

    n_emotions = len(emotions)
    n_layers = len(target_layers)
    vectors = torch.zeros(
        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
    )
    device = torch.device(args.device)

    for e_idx, emotion in enumerate(emotions):
        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
        pos_texts, neg_texts = [], []
        with open(path) as f:
            for line in f:
                ex = json.loads(line)
                if ex["polarity"] == "positive":
                    pos_texts.append(ex["text"])
                else:
                    neg_texts.append(ex["text"])
        print(f"[{e_idx + 1}/{n_emotions}] {emotion}: "
              f"{len(pos_texts)} pos / {len(neg_texts)} neg")

        pos_acts = _collect_activations(
            model, tokenizer, pos_texts, target_layers, device,
            args.batch_size, args.max_length,
        )
        neg_acts = _collect_activations(
            model, tokenizer, neg_texts, target_layers, device,
            args.batch_size, args.max_length,
        )

        # Difference of means per layer: the CAA steering direction.
        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
        neg_mean = neg_acts.mean(dim=0)
        diff = pos_mean - neg_mean
        # Normalize per layer so projections are scale-comparable.
        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        diff = diff / norms
        vectors[e_idx] = diff

        del pos_acts, neg_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Save in AmygdalaConnector format: emotion_names is stored as a
    # NUL-padded uint8 tensor, one row per emotion.
    names_bytes = [e.encode("utf-8") for e in emotions]
    max_len = max(len(b) for b in names_bytes)
    padded = torch.tensor(
        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
        dtype=torch.uint8,
    )
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    safetensors.torch.save_file(
        {
            "vectors": vectors.to(torch.float16),
            "emotion_names": padded,
            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
        },
        args.output,
    )
    print(f"\nWrote steering vectors to {args.output}: "
          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")


if __name__ == "__main__":
    main()
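

# Example invocation plus a quick sanity check of the output file. The
# script name, model id, and paths are placeholders, not tested defaults;
# the layer indices follow the --target-layers help example above.
#
#   python train_steering_vectors.py \
#       --model <hf-model-id-or-path> \
#       --training-data-dir <dir-with-per-emotion-jsonl> \
#       --target-layers 3,18,33,36 \
#       --output amygdala_vectors.safetensors
#
#   >>> import safetensors.torch
#   >>> t = safetensors.torch.load_file("amygdala_vectors.safetensors")
#   >>> t["vectors"].shape    # [n_emotions, n_target_layers, hidden_dim]
#   >>> [bytes(r.tolist()).rstrip(b"\x00").decode() for r in t["emotion_names"]]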