training: rewrite trainer for readout pipeline + story corpus
The old script was written for the AmygdalaConnector's expected
format ([n_emotions, n_target_layers, hidden_dim] in a single
tensor, plus a JSONL input format from extract_training_pairs.py).
Neither matches our current state: the runtime side is now
ReadoutManager loading per-layer safetensors keyed layer_<idx>.vectors,
and the data side is hand-written prose stories under
amygdala_stories/{stories,paired}/.
Changes:
* Input loader reads stories/<emotion>.txt and
paired/<scenario>/<emotion>.txt directly. Each emotion's positive
set is {its unpaired story} union {its within-scenario framings};
its negative set is {all other emotions' positives} union {all
scenario baselines}.
* Paired scenarios' baseline.txt files become shared negatives
(scenario-neutral prose that doesn't frame any particular
emotion), providing anchor points for within-scenario contrasts.
* Output writes readout.safetensors with per-layer tensors keyed
layer_<idx>.vectors shape (n_concepts, hidden_size), plus a
sidecar readout.json manifest with {concepts, layers, hidden_size,
dtype} that ReadoutManager.from_file consumes directly.
* Dedup: activations are computed once per unique text (an emotion's
own positive is another emotion's negative — we'd otherwise do N×
the forwards needed).
Preserved:
* _pool_last (last non-pad residual) — matches how readout is read
at decode time from the sampler's query-last position.
* register_forward_hook on target layer modules — correct approach
for transformer blocks.
* _find_layers_module traversal — mirrors ReadoutManager's.
* bf16 + low_cpu_mem_usage model load — sensible for 27B on B200.
Verified locally (CPU, fake activations):
* Loader finds 89 emotions from the current corpus (80 unpaired +
9 emotions that appear only in paired scenarios) and 6 baselines.
* Per-(layer, concept) vectors are unit-normalized.
* Output reloads cleanly through ReadoutManager.from_file with
matching concepts / layers / shapes.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
34bd122590
commit
15737dfd92
1 changed files with 276 additions and 151 deletions
|
|
@ -1,30 +1,48 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Train amygdala steering vectors via Contrastive Activation Addition.
|
||||
"""Train concept-readout vectors via Contrastive Activation Addition.
|
||||
|
||||
Reads the per-emotion JSONL files produced by extract_training_pairs.py,
|
||||
runs the target model over each example, captures the residual-stream
|
||||
hidden state at the configured target layers, and computes
|
||||
`mean(positive) - mean(negative)` as the steering direction per layer
|
||||
per emotion.
|
||||
Reads the hand-written story corpus at
|
||||
``amygdala_stories/{stories,paired}/`` and produces the per-layer
|
||||
safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
|
||||
loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).
|
||||
|
||||
Output: a safetensors file matching the format AmygdalaConnector
|
||||
expects:
|
||||
Training data (cross-concept contrast):
|
||||
|
||||
vectors: [n_emotions, n_target_layers, hidden_dim] fp16
|
||||
emotion_names: [n_emotions] uint8
|
||||
positive for emotion E:
|
||||
stories/E.txt
|
||||
paired/<scenario>/E.txt (for each scenario that covers E)
|
||||
|
||||
Pooling: last-token residual-stream per example (CAA convention —
|
||||
the final token has seen the whole context and is where the model's
|
||||
"decision" lives). Alternative: mean across all tokens. The LAST
|
||||
convention is more common for steering vector work.
|
||||
negative for emotion E:
|
||||
stories/<all other emotions>.txt
|
||||
paired/<scenario>/baseline.txt (for each scenario)
|
||||
|
||||
Within-scenario paired stories are the highest-signal pairs (same
|
||||
content, different concept framing); unpaired stories provide bulk
|
||||
contrast across the 80 emotions we have written so far.
|
||||
|
||||
Pooling: last non-pad token. Matches how readout is consumed at decode
|
||||
time (residual read at the sampler's query position).
|
||||
|
||||
Output:
|
||||
|
||||
readout.safetensors
|
||||
layer_<idx>.vectors : fp16 (n_concepts, hidden_size) one per layer
|
||||
readout.json
|
||||
{
|
||||
"concepts": [...],
|
||||
"layers": [...],
|
||||
"hidden_size": int,
|
||||
"dtype": "float16"
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors.torch
|
||||
|
|
@ -39,81 +57,11 @@ def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tens
|
|||
attention_mask: [batch, seq]
|
||||
returns: [batch, hidden_dim]
|
||||
"""
|
||||
# last non-pad token index per row
|
||||
last_idx = attention_mask.sum(dim=1) - 1
|
||||
batch_idx = torch.arange(hidden.size(0), device=hidden.device)
|
||||
return hidden[batch_idx, last_idx]
|
||||
|
||||
|
||||
def _collect_activations(
|
||||
model,
|
||||
tokenizer,
|
||||
texts: list[str],
|
||||
target_layers: list[int],
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
) -> torch.Tensor:
|
||||
"""Run texts through the model, capture residual stream at target
|
||||
layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
|
||||
"""
|
||||
# Register hooks on the target layers' outputs. We want the
|
||||
# residual stream AFTER each layer, which is the output of the
|
||||
# transformer block (hidden_states[layer_idx+1] in HF land).
|
||||
captures: dict[int, torch.Tensor] = {}
|
||||
|
||||
def make_hook(idx):
|
||||
def hook(_mod, _inp, output):
|
||||
# output is typically (hidden_states, ...) — take the first
|
||||
hs = output[0] if isinstance(output, tuple) else output
|
||||
captures[idx] = hs.detach()
|
||||
return hook
|
||||
|
||||
handles = []
|
||||
# Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
|
||||
# language_model.model.layers follows the same convention.
|
||||
# Resolve the layer list by walking common paths.
|
||||
layers_module = _find_layers_module(model)
|
||||
for idx in target_layers:
|
||||
handles.append(
|
||||
layers_module[idx].register_forward_hook(make_hook(idx))
|
||||
)
|
||||
|
||||
out_rows: list[torch.Tensor] = []
|
||||
try:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
tok = tokenizer(
|
||||
batch,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
).to(device)
|
||||
captures.clear()
|
||||
model(**tok)
|
||||
|
||||
per_layer = []
|
||||
for idx in target_layers:
|
||||
hs = captures[idx] # [batch, seq, hidden]
|
||||
pooled = _pool_last(hs, tok["attention_mask"])
|
||||
per_layer.append(pooled.to(torch.float32).cpu())
|
||||
# Stack to [batch, n_layers, hidden_dim]
|
||||
batched = torch.stack(per_layer, dim=1)
|
||||
out_rows.append(batched)
|
||||
|
||||
del tok, captures
|
||||
if (i // batch_size) % 10 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
return torch.cat(out_rows, dim=0) # [n_texts, n_layers, hidden]
|
||||
|
||||
|
||||
def _find_layers_module(model) -> torch.nn.ModuleList:
|
||||
"""Walk a few likely paths to find the transformer-block list."""
|
||||
candidates = [
|
||||
|
|
@ -139,25 +87,143 @@ def _find_layers_module(model) -> torch.nn.ModuleList:
|
|||
)
|
||||
|
||||
|
||||
def _collect_activations(
|
||||
model,
|
||||
tokenizer,
|
||||
texts: list[str],
|
||||
target_layers: list[int],
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
) -> torch.Tensor:
|
||||
"""Run texts through the model, capture residual stream at target
|
||||
layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
|
||||
"""
|
||||
captures: dict[int, torch.Tensor] = {}
|
||||
|
||||
def make_hook(idx: int):
|
||||
def hook(_mod, _inp, output):
|
||||
hs = output[0] if isinstance(output, tuple) else output
|
||||
captures[idx] = hs.detach()
|
||||
return hook
|
||||
|
||||
layers_module = _find_layers_module(model)
|
||||
handles = [
|
||||
layers_module[idx].register_forward_hook(make_hook(idx))
|
||||
for idx in target_layers
|
||||
]
|
||||
|
||||
out_rows: list[torch.Tensor] = []
|
||||
try:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
tok = tokenizer(
|
||||
batch,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
).to(device)
|
||||
captures.clear()
|
||||
model(**tok)
|
||||
|
||||
per_layer = [
|
||||
_pool_last(captures[idx], tok["attention_mask"])
|
||||
.to(torch.float32)
|
||||
.cpu()
|
||||
for idx in target_layers
|
||||
]
|
||||
out_rows.append(torch.stack(per_layer, dim=1))
|
||||
del tok, captures
|
||||
if (i // batch_size) % 10 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
captures = {}
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
return torch.cat(out_rows, dim=0)
|
||||
|
||||
|
||||
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
||||
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
|
||||
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
|
||||
]:
|
||||
"""Return ``(positives_by_emotion, baselines)``.
|
||||
|
||||
Cross-concept negatives are computed at training time from
|
||||
``positives_by_emotion`` — each emotion's negative set is the
|
||||
union of all other emotions' positives plus the baseline texts.
|
||||
"""
|
||||
positives: dict[str, list[str]] = {}
|
||||
for story_path in sorted(stories_dir.glob("*.txt")):
|
||||
emotion = story_path.stem
|
||||
positives.setdefault(emotion, []).append(
|
||||
story_path.read_text().strip()
|
||||
)
|
||||
|
||||
baselines: list[str] = []
|
||||
if paired_dir is not None and paired_dir.exists():
|
||||
for scenario_dir in sorted(paired_dir.iterdir()):
|
||||
if not scenario_dir.is_dir():
|
||||
continue
|
||||
baseline_path = scenario_dir / "baseline.txt"
|
||||
if baseline_path.exists():
|
||||
baselines.append(baseline_path.read_text().strip())
|
||||
for framing_path in sorted(scenario_dir.glob("*.txt")):
|
||||
if framing_path.stem == "baseline":
|
||||
continue
|
||||
emotion = framing_path.stem
|
||||
positives.setdefault(emotion, []).append(
|
||||
framing_path.read_text().strip()
|
||||
)
|
||||
|
||||
return positives, baselines
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument("--model", required=True, help="HF model id or path")
|
||||
ap.add_argument("--training-data-dir", required=True)
|
||||
ap.add_argument(
|
||||
"--target-layers", required=True,
|
||||
help="Comma-separated layer indices, e.g. 3,18,33,36",
|
||||
"--stories-dir",
|
||||
required=True,
|
||||
help="Path to amygdala_stories/stories/",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--paired-dir",
|
||||
default=None,
|
||||
help="Path to amygdala_stories/paired/ (optional)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--target-layers",
|
||||
required=True,
|
||||
help="Comma-separated layer indices, e.g. 40,50,60,70",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--output-dir",
|
||||
required=True,
|
||||
help="Directory to write readout.safetensors + readout.json",
|
||||
)
|
||||
ap.add_argument("--output", required=True)
|
||||
ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
ap.add_argument("--batch-size", type=int, default=4)
|
||||
ap.add_argument("--batch-size", type=int, default=2)
|
||||
ap.add_argument("--max-length", type=int, default=512)
|
||||
ap.add_argument("--device", default="cuda:0")
|
||||
ap.add_argument(
|
||||
"--min-positives",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Skip emotions with fewer positive examples than this",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
target_layers = [int(x) for x in args.target_layers.split(",")]
|
||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
|
||||
args.dtype
|
||||
]
|
||||
dtype = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
|
|
@ -170,78 +236,137 @@ def main() -> None:
|
|||
low_cpu_mem_usage=True,
|
||||
)
|
||||
hidden_dim = model.config.hidden_size
|
||||
print(f"Model loaded. hidden_dim={hidden_dim}, "
|
||||
f"n_layers={model.config.num_hidden_layers}")
|
||||
|
||||
manifest_path = Path(args.training_data_dir) / "_manifest.json"
|
||||
manifest = json.loads(manifest_path.read_text())
|
||||
|
||||
emotions = sorted(manifest["emotions"].keys())
|
||||
print(f"Training {len(emotions)} emotions: {emotions}")
|
||||
|
||||
n_emotions = len(emotions)
|
||||
n_layers = len(target_layers)
|
||||
vectors = torch.zeros(
|
||||
(n_emotions, n_layers, hidden_dim), dtype=torch.float32
|
||||
n_model_layers = model.config.num_hidden_layers
|
||||
print(
|
||||
f"Model loaded. hidden_dim={hidden_dim}, "
|
||||
f"n_model_layers={n_model_layers}"
|
||||
)
|
||||
|
||||
for layer_idx in target_layers:
|
||||
if layer_idx < 0 or layer_idx >= n_model_layers:
|
||||
raise ValueError(
|
||||
f"target layer {layer_idx} out of range "
|
||||
f"[0, {n_model_layers})"
|
||||
)
|
||||
|
||||
positives_by_emotion, baselines = _load_corpus(
|
||||
Path(args.stories_dir),
|
||||
Path(args.paired_dir) if args.paired_dir else None,
|
||||
)
|
||||
emotions = sorted(
|
||||
e for e, ps in positives_by_emotion.items()
|
||||
if len(ps) >= args.min_positives
|
||||
)
|
||||
if not emotions:
|
||||
raise RuntimeError(
|
||||
f"No emotions with >= {args.min_positives} positive examples"
|
||||
)
|
||||
print(
|
||||
f"Training {len(emotions)} emotions; "
|
||||
f"{len(baselines)} baseline scenarios"
|
||||
)
|
||||
|
||||
# Cache all positive-text activations once so we can reuse them as
|
||||
# negatives for other emotions. Keyed by the text itself to dedup
|
||||
# across emotion lists.
|
||||
device = torch.device(args.device)
|
||||
text_to_emotion: dict[str, str] = {}
|
||||
for emotion, texts in positives_by_emotion.items():
|
||||
for t in texts:
|
||||
text_to_emotion[t] = emotion
|
||||
|
||||
unique_positive_texts = list(text_to_emotion.keys())
|
||||
print(
|
||||
f"Collecting activations for {len(unique_positive_texts)} unique "
|
||||
f"positive texts + {len(baselines)} baselines..."
|
||||
)
|
||||
|
||||
positive_acts = _collect_activations(
|
||||
model, tokenizer, unique_positive_texts, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
)
|
||||
# positive_acts[i] corresponds to unique_positive_texts[i]
|
||||
text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
|
||||
|
||||
baseline_acts = (
|
||||
_collect_activations(
|
||||
model, tokenizer, baselines, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
)
|
||||
if baselines
|
||||
else torch.zeros(0, len(target_layers), hidden_dim)
|
||||
)
|
||||
|
||||
n_concepts = len(emotions)
|
||||
n_layers = len(target_layers)
|
||||
|
||||
# Per-layer output matrices. Shape (n_concepts, hidden_size) each.
|
||||
per_layer_vectors = torch.zeros(
|
||||
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
|
||||
)
|
||||
|
||||
for e_idx, emotion in enumerate(emotions):
|
||||
path = Path(args.training_data_dir) / f"{emotion}.jsonl"
|
||||
pos_texts, neg_texts = [], []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
ex = json.loads(line)
|
||||
if ex["polarity"] == "positive":
|
||||
pos_texts.append(ex["text"])
|
||||
else:
|
||||
neg_texts.append(ex["text"])
|
||||
print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
|
||||
f"{len(pos_texts)} pos / {len(neg_texts)} neg")
|
||||
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
||||
# Negatives: every OTHER emotion's positives + baselines.
|
||||
neg_rows = [
|
||||
i
|
||||
for i, t in enumerate(unique_positive_texts)
|
||||
if text_to_emotion[t] != emotion
|
||||
]
|
||||
|
||||
pos_acts = _collect_activations(
|
||||
model, tokenizer, pos_texts, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
)
|
||||
neg_acts = _collect_activations(
|
||||
model, tokenizer, neg_texts, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
)
|
||||
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
|
||||
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
|
||||
if baseline_acts.shape[0] > 0:
|
||||
neg = torch.cat([neg, baseline_acts], dim=0)
|
||||
|
||||
# Difference of means per layer
|
||||
pos_mean = pos_acts.mean(dim=0) # [n_layers, hidden]
|
||||
neg_mean = neg_acts.mean(dim=0)
|
||||
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
|
||||
neg_mean = neg.mean(dim=0)
|
||||
diff = pos_mean - neg_mean
|
||||
|
||||
# Normalize per layer so projections are scale-comparable
|
||||
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
|
||||
diff = diff / norms
|
||||
|
||||
vectors[e_idx] = diff
|
||||
del pos_acts, neg_acts
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# diff[layer] -> per_layer_vectors[layer, e_idx]
|
||||
for l_idx in range(n_layers):
|
||||
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
|
||||
|
||||
# Save in AmygdalaConnector format.
|
||||
# emotion_names as padded uint8 tensor
|
||||
names_bytes = [e.encode("utf-8") for e in emotions]
|
||||
max_len = max(len(b) for b in names_bytes)
|
||||
padded = torch.tensor(
|
||||
[list(b.ljust(max_len, b"\x00")) for b in names_bytes],
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
if e_idx < 5 or e_idx == len(emotions) - 1:
|
||||
print(
|
||||
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
|
||||
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tensors = {
|
||||
f"layer_{target_layers[l_idx]}.vectors": (
|
||||
per_layer_vectors[l_idx].to(torch.float16)
|
||||
)
|
||||
for l_idx in range(n_layers)
|
||||
}
|
||||
safetensors.torch.save_file(
|
||||
{
|
||||
"vectors": vectors.to(torch.float16),
|
||||
"emotion_names": padded,
|
||||
"target_layers": torch.tensor(target_layers, dtype=torch.int32),
|
||||
},
|
||||
args.output,
|
||||
tensors,
|
||||
str(output_dir / "readout.safetensors"),
|
||||
)
|
||||
print(f"\nWrote steering vectors to {args.output}: "
|
||||
f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")
|
||||
manifest = {
|
||||
"concepts": emotions,
|
||||
"layers": target_layers,
|
||||
"hidden_size": hidden_dim,
|
||||
"dtype": "float16",
|
||||
}
|
||||
(output_dir / "readout.json").write_text(
|
||||
json.dumps(manifest, indent=2) + "\n"
|
||||
)
|
||||
|
||||
total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024)
|
||||
print(
|
||||
f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
|
||||
f" {n_concepts} concepts x {n_layers} layers x "
|
||||
f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
|
||||
)
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue