training: rewrite trainer for readout pipeline + story corpus

The old script was written for the AmygdalaConnector's expected
output format (a single [n_emotions, n_target_layers, hidden_dim]
tensor) and the JSONL input produced by extract_training_pairs.py.
Neither matches our current state: the runtime side is now
ReadoutManager loading per-layer safetensors keyed layer_<idx>.vectors,
and the data side is hand-written prose stories under
amygdala_stories/{stories,paired}/.

Changes:

* Input loader reads stories/<emotion>.txt and
  paired/<scenario>/<emotion>.txt directly. Each emotion's positive
  set is {its unpaired story} union {its within-scenario framings};
  its negative set is {all other emotions' positives} union {all
  scenario baselines} (see the sketch after this list).
* Paired scenarios' baseline.txt files become shared negatives
  (scenario-neutral prose that doesn't frame any particular
  emotion), providing anchor points for within-scenario contrasts.
* Output is readout.safetensors with per-layer tensors keyed
  layer_<idx>.vectors, each of shape (n_concepts, hidden_size), plus a
  sidecar readout.json manifest with {concepts, layers, hidden_size,
  dtype} that ReadoutManager.from_file consumes directly.
* Dedup: activations are computed once per unique text (an emotion's
  own positive is another emotion's negative — we'd otherwise do N×
  the forwards needed).
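
A minimal sketch of the set construction and dedup, assuming the
amygdala_stories layout described above (the helper name build_sets is
illustrative, not the script's actual API):

    # Sketch: positives per emotion, shared baselines, text-level dedup.
    from pathlib import Path

    def build_sets(stories_dir: Path, paired_dir: Path):
        positives: dict[str, list[str]] = {}  # emotion -> positive texts
        baselines: list[str] = []             # scenario-neutral negatives
        for p in sorted(stories_dir.glob("*.txt")):
            positives.setdefault(p.stem, []).append(p.read_text().strip())
        for scenario in sorted(d for d in paired_dir.iterdir() if d.is_dir()):
            for p in sorted(scenario.glob("*.txt")):
                text = p.read_text().strip()
                if p.stem == "baseline":
                    baselines.append(text)
                else:
                    positives.setdefault(p.stem, []).append(text)
        # Each unique text gets one forward pass; it is a positive for its
        # own emotion and a negative for every other emotion.
        unique_texts = sorted({t for ts in positives.values() for t in ts})
        negatives = {}
        for emotion, ts in positives.items():
            own = set(ts)
            negatives[emotion] = (
                [t for t in unique_texts if t not in own] + baselines
            )
        return positives, negatives

    positives, negatives = build_sets(
        Path("amygdala_stories/stories"), Path("amygdala_stories/paired")
    )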

Preserved:
* _pool_last (last non-pad residual) — matches how readout is read
  at decode time from the sampler's query-last position.
* register_forward_hook on target layer modules — correct approach
  for transformer blocks.
* _find_layers_module traversal — mirrors ReadoutManager's.
* bf16 + low_cpu_mem_usage model load — sensible for 27B on B200.

Verified locally (CPU, fake activations):
* Loader finds 89 emotions from the current corpus (80 unpaired +
  9 emotions that appear only in paired scenarios) and 6 baselines.
* Per-(layer, concept) vectors are unit-normalized.
* Output reloads cleanly through ReadoutManager.from_file with
  matching concepts / layers / shapes (a standalone check without
  ReadoutManager is sketched below).
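
A minimal sketch of that consistency check, reading the two artifacts
directly with safetensors + json (run from the output directory;
ReadoutManager's own loading path is exercised separately):

    import json
    import safetensors.torch

    with open("readout.json") as f:
        manifest = json.load(f)
    tensors = safetensors.torch.load_file("readout.safetensors")
    for layer in manifest["layers"]:
        vecs = tensors[f"layer_{layer}.vectors"]
        assert vecs.shape == (len(manifest["concepts"]), manifest["hidden_size"])
        # fp16 roundoff: unit norms should hold to within ~1e-2
        assert (vecs.float().norm(dim=-1) - 1.0).abs().max() < 1e-2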

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-18 00:32:50 -04:00
parent 34bd122590
commit 15737dfd92

@@ -1,30 +1,48 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-"""Train amygdala steering vectors via Contrastive Activation Addition.
-
-Reads the per-emotion JSONL files produced by extract_training_pairs.py,
-runs the target model over each example, captures the residual-stream
-hidden state at the configured target layers, and computes
-`mean(positive) - mean(negative)` as the steering direction per layer
-per emotion.
-
-Output: a safetensors file matching the format AmygdalaConnector
-expects:
-    vectors:        [n_emotions, n_target_layers, hidden_dim] fp16
-    emotion_names:  [n_emotions] uint8
-
-Pooling: last-token residual-stream per example (CAA convention:
-the final token has seen the whole context and is where the model's
-"decision" lives). Alternative: mean across all tokens. The LAST
-convention is more common for steering vector work.
+"""Train concept-readout vectors via Contrastive Activation Addition.
+
+Reads the hand-written story corpus at
+``amygdala_stories/{stories,paired}/`` and produces the per-layer
+safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
+loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).
+
+Training data (cross-concept contrast):
+
+    positive for emotion E:
+        stories/E.txt
+        paired/<scenario>/E.txt          (for each scenario that covers E)
+    negative for emotion E:
+        stories/<all other emotions>.txt
+        paired/<scenario>/baseline.txt   (for each scenario)
+
+Within-scenario paired stories are the highest-signal pairs (same
+content, different concept framing); unpaired stories provide bulk
+contrast across the 80 emotions we have written so far.
+
+Pooling: last non-pad token. Matches how readout is consumed at decode
+time (residual read at the sampler's query position).
+
+Output:
+    readout.safetensors
+        layer_<idx>.vectors : fp16 (n_concepts, hidden_size), one per layer
+    readout.json
+        {
+            "concepts": [...],
+            "layers": [...],
+            "hidden_size": int,
+            "dtype": "float16"
+        }
 """
+
+from __future__ import annotations
 
 import argparse
 import gc
 import json
 import os
+from collections import defaultdict
 from pathlib import Path
 
 import safetensors.torch
@@ -39,81 +57,11 @@ def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
     attention_mask: [batch, seq]
     returns: [batch, hidden_dim]
     """
-    # last non-pad token index per row
     last_idx = attention_mask.sum(dim=1) - 1
     batch_idx = torch.arange(hidden.size(0), device=hidden.device)
     return hidden[batch_idx, last_idx]
 
 
-def _collect_activations(
-    model,
-    tokenizer,
-    texts: list[str],
-    target_layers: list[int],
-    device: torch.device,
-    batch_size: int,
-    max_length: int,
-) -> torch.Tensor:
-    """Run texts through the model, capture residual stream at target
-    layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
-    """
-    # Register hooks on the target layers' outputs. We want the
-    # residual stream AFTER each layer, which is the output of the
-    # transformer block (hidden_states[layer_idx+1] in HF land).
-    captures: dict[int, torch.Tensor] = {}
-
-    def make_hook(idx):
-        def hook(_mod, _inp, output):
-            # output is typically (hidden_states, ...) — take the first
-            hs = output[0] if isinstance(output, tuple) else output
-            captures[idx] = hs.detach()
-        return hook
-
-    handles = []
-    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
-    # language_model.model.layers follows the same convention.
-    # Resolve the layer list by walking common paths.
-    layers_module = _find_layers_module(model)
-    for idx in target_layers:
-        handles.append(
-            layers_module[idx].register_forward_hook(make_hook(idx))
-        )
-
-    out_rows: list[torch.Tensor] = []
-    try:
-        model.eval()
-        with torch.no_grad():
-            for i in range(0, len(texts), batch_size):
-                batch = texts[i : i + batch_size]
-                tok = tokenizer(
-                    batch,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=max_length,
-                ).to(device)
-                captures.clear()
-                model(**tok)
-                per_layer = []
-                for idx in target_layers:
-                    hs = captures[idx]  # [batch, seq, hidden]
-                    pooled = _pool_last(hs, tok["attention_mask"])
-                    per_layer.append(pooled.to(torch.float32).cpu())
-                # Stack to [batch, n_layers, hidden_dim]
-                batched = torch.stack(per_layer, dim=1)
-                out_rows.append(batched)
-                del tok, captures
-                if (i // batch_size) % 10 == 0:
-                    torch.cuda.empty_cache()
-    finally:
-        for h in handles:
-            h.remove()
-
-    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]
-
-
 def _find_layers_module(model) -> torch.nn.ModuleList:
     """Walk a few likely paths to find the transformer-block list."""
     candidates = [
@@ -139,25 +87,143 @@ def _find_layers_module(model) -> torch.nn.ModuleList:
     )
 
 
+def _collect_activations(
+    model,
+    tokenizer,
+    texts: list[str],
+    target_layers: list[int],
+    device: torch.device,
+    batch_size: int,
+    max_length: int,
+) -> torch.Tensor:
+    """Run texts through the model, capture residual stream at target
+    layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
+    """
+    captures: dict[int, torch.Tensor] = {}
+
+    def make_hook(idx: int):
+        def hook(_mod, _inp, output):
+            hs = output[0] if isinstance(output, tuple) else output
+            captures[idx] = hs.detach()
+        return hook
+
+    layers_module = _find_layers_module(model)
+    handles = [
+        layers_module[idx].register_forward_hook(make_hook(idx))
+        for idx in target_layers
+    ]
+
+    out_rows: list[torch.Tensor] = []
+    try:
+        model.eval()
+        with torch.no_grad():
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i : i + batch_size]
+                tok = tokenizer(
+                    batch,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                ).to(device)
+                captures.clear()
+                model(**tok)
+                per_layer = [
+                    _pool_last(captures[idx], tok["attention_mask"])
+                    .to(torch.float32)
+                    .cpu()
+                    for idx in target_layers
+                ]
+                out_rows.append(torch.stack(per_layer, dim=1))
+                del tok, captures
+                if (i // batch_size) % 10 == 0:
+                    torch.cuda.empty_cache()
+                captures = {}
+    finally:
+        for h in handles:
+            h.remove()
+
+    return torch.cat(out_rows, dim=0)
+
+
+def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
+    dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
+    list[str],             # all baseline texts (one per scenario), as scenario-agnostic negatives
+]:
+    """Return ``(positives_by_emotion, baselines)``.
+
+    Cross-concept negatives are computed at training time from
+    ``positives_by_emotion``: each emotion's negative set is the
+    union of all other emotions' positives plus the baseline texts.
+    """
+    positives: dict[str, list[str]] = {}
+    for story_path in sorted(stories_dir.glob("*.txt")):
+        emotion = story_path.stem
+        positives.setdefault(emotion, []).append(
+            story_path.read_text().strip()
+        )
+
+    baselines: list[str] = []
+    if paired_dir is not None and paired_dir.exists():
+        for scenario_dir in sorted(paired_dir.iterdir()):
+            if not scenario_dir.is_dir():
+                continue
+            baseline_path = scenario_dir / "baseline.txt"
+            if baseline_path.exists():
+                baselines.append(baseline_path.read_text().strip())
+            for framing_path in sorted(scenario_dir.glob("*.txt")):
+                if framing_path.stem == "baseline":
+                    continue
+                emotion = framing_path.stem
+                positives.setdefault(emotion, []).append(
+                    framing_path.read_text().strip()
+                )
+
+    return positives, baselines
+
+
 def main() -> None:
     ap = argparse.ArgumentParser(description=__doc__)
     ap.add_argument("--model", required=True, help="HF model id or path")
-    ap.add_argument("--training-data-dir", required=True)
     ap.add_argument(
-        "--target-layers", required=True,
-        help="Comma-separated layer indices, e.g. 3,18,33,36",
+        "--stories-dir",
+        required=True,
+        help="Path to amygdala_stories/stories/",
+    )
+    ap.add_argument(
+        "--paired-dir",
+        default=None,
+        help="Path to amygdala_stories/paired/ (optional)",
+    )
+    ap.add_argument(
+        "--target-layers",
+        required=True,
+        help="Comma-separated layer indices, e.g. 40,50,60,70",
+    )
+    ap.add_argument(
+        "--output-dir",
+        required=True,
+        help="Directory to write readout.safetensors + readout.json",
     )
-    ap.add_argument("--output", required=True)
     ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
-    ap.add_argument("--batch-size", type=int, default=4)
+    ap.add_argument("--batch-size", type=int, default=2)
     ap.add_argument("--max-length", type=int, default=512)
     ap.add_argument("--device", default="cuda:0")
+    ap.add_argument(
+        "--min-positives",
+        type=int,
+        default=1,
+        help="Skip emotions with fewer positive examples than this",
+    )
     args = ap.parse_args()
 
     target_layers = [int(x) for x in args.target_layers.split(",")]
-    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
-        args.dtype
-    ]
+    dtype = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16,
+        "fp32": torch.float32,
+    }[args.dtype]
 
     print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
     tokenizer = AutoTokenizer.from_pretrained(args.model)
@@ -170,79 +236,138 @@ def main() -> None:
         low_cpu_mem_usage=True,
     )
     hidden_dim = model.config.hidden_size
-    print(f"Model loaded. hidden_dim={hidden_dim}, "
-          f"n_layers={model.config.num_hidden_layers}")
-
-    manifest_path = Path(args.training_data_dir) / "_manifest.json"
-    manifest = json.loads(manifest_path.read_text())
-    emotions = sorted(manifest["emotions"].keys())
-    print(f"Training {len(emotions)} emotions: {emotions}")
-
-    n_emotions = len(emotions)
-    n_layers = len(target_layers)
-    vectors = torch.zeros(
-        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
-    )
-
+    n_model_layers = model.config.num_hidden_layers
+    print(
+        f"Model loaded. hidden_dim={hidden_dim}, "
+        f"n_model_layers={n_model_layers}"
+    )
+
+    for layer_idx in target_layers:
+        if layer_idx < 0 or layer_idx >= n_model_layers:
+            raise ValueError(
+                f"target layer {layer_idx} out of range "
+                f"[0, {n_model_layers})"
+            )
+
+    positives_by_emotion, baselines = _load_corpus(
+        Path(args.stories_dir),
+        Path(args.paired_dir) if args.paired_dir else None,
+    )
+    emotions = sorted(
+        e for e, ps in positives_by_emotion.items()
+        if len(ps) >= args.min_positives
+    )
+    if not emotions:
+        raise RuntimeError(
+            f"No emotions with >= {args.min_positives} positive examples"
+        )
+    print(
+        f"Training {len(emotions)} emotions; "
+        f"{len(baselines)} baseline scenarios"
+    )
+
+    # Cache all positive-text activations once so we can reuse them as
+    # negatives for other emotions. Keyed by the text itself to dedup
+    # across emotion lists.
     device = torch.device(args.device)
+    text_to_emotion: dict[str, str] = {}
+    for emotion, texts in positives_by_emotion.items():
+        for t in texts:
+            text_to_emotion[t] = emotion
+    unique_positive_texts = list(text_to_emotion.keys())
+
+    print(
+        f"Collecting activations for {len(unique_positive_texts)} unique "
+        f"positive texts + {len(baselines)} baselines..."
+    )
+    positive_acts = _collect_activations(
+        model, tokenizer, unique_positive_texts, target_layers, device,
+        args.batch_size, args.max_length,
+    )
+    # positive_acts[i] corresponds to unique_positive_texts[i]
+    text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
+
+    baseline_acts = (
+        _collect_activations(
+            model, tokenizer, baselines, target_layers, device,
+            args.batch_size, args.max_length,
+        )
+        if baselines
+        else torch.zeros(0, len(target_layers), hidden_dim)
+    )
+
+    n_concepts = len(emotions)
+    n_layers = len(target_layers)
+    # Per-layer output matrices. Shape (n_concepts, hidden_size) each.
+    per_layer_vectors = torch.zeros(
+        (n_layers, n_concepts, hidden_dim), dtype=torch.float32
+    )
 
     for e_idx, emotion in enumerate(emotions):
-        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
-        pos_texts, neg_texts = [], []
-        with open(path) as f:
-            for line in f:
-                ex = json.loads(line)
-                if ex["polarity"] == "positive":
-                    pos_texts.append(ex["text"])
-                else:
-                    neg_texts.append(ex["text"])
-
-        print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
-              f"{len(pos_texts)} pos / {len(neg_texts)} neg")
-
-        pos_acts = _collect_activations(
-            model, tokenizer, pos_texts, target_layers, device,
-            args.batch_size, args.max_length,
-        )
-        neg_acts = _collect_activations(
-            model, tokenizer, neg_texts, target_layers, device,
-            args.batch_size, args.max_length,
-        )
-
-        # Difference of means per layer
-        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
-        neg_mean = neg_acts.mean(dim=0)
+        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
+        # Negatives: every OTHER emotion's positives + baselines.
+        neg_rows = [
+            i
+            for i, t in enumerate(unique_positive_texts)
+            if text_to_emotion[t] != emotion
+        ]
+
+        pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
+        neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
+        if baseline_acts.shape[0] > 0:
+            neg = torch.cat([neg, baseline_acts], dim=0)
+
+        pos_mean = pos.mean(dim=0)  # [n_layers, hidden]
+        neg_mean = neg.mean(dim=0)
         diff = pos_mean - neg_mean
+        # Normalize per layer so projections are scale-comparable
         norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
         diff = diff / norms
-        vectors[e_idx] = diff
-        del pos_acts, neg_acts
+        # diff[layer] -> per_layer_vectors[layer, e_idx]
+        for l_idx in range(n_layers):
+            per_layer_vectors[l_idx, e_idx] = diff[l_idx]
+
+        if e_idx < 5 or e_idx == len(emotions) - 1:
+            print(
+                f"  [{e_idx + 1}/{len(emotions)}] {emotion}: "
+                f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
+            )
 
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    tensors = {
+        f"layer_{target_layers[l_idx]}.vectors": (
+            per_layer_vectors[l_idx].to(torch.float16)
+        )
+        for l_idx in range(n_layers)
+    }
+    safetensors.torch.save_file(
+        tensors,
+        str(output_dir / "readout.safetensors"),
+    )
+
+    manifest = {
+        "concepts": emotions,
+        "layers": target_layers,
+        "hidden_size": hidden_dim,
+        "dtype": "float16",
+    }
+    (output_dir / "readout.json").write_text(
+        json.dumps(manifest, indent=2) + "\n"
+    )
+
+    total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
+    print(
+        f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
+        f"  {n_concepts} concepts x {n_layers} layers x "
+        f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
+    )
+
+    del model
     gc.collect()
     torch.cuda.empty_cache()
 
-    # Save in AmygdalaConnector format.
-    # emotion_names as padded uint8 tensor
-    names_bytes = [e.encode("utf-8") for e in emotions]
-    max_len = max(len(b) for b in names_bytes)
-    padded = torch.tensor(
-        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
-        dtype=torch.uint8,
-    )
-
-    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
-    safetensors.torch.save_file(
-        {
-            "vectors": vectors.to(torch.float16),
-            "emotion_names": padded,
-            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
-        },
-        args.output,
-    )
-    print(f"\nWrote steering vectors to {args.output}: "
-          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")
-
 
 if __name__ == "__main__":
     main()