Review pass before running on a B200. With a 27B model and a 100+-story corpus, any misconfiguration costs real time, so fail before the model loads and show visible progress during forward passes. * Pre-load validation: stories-dir and paired-dir exist, and the corpus has at least one emotion with >= min_positives positive examples. * Per-batch progress log every 5 batches with elapsed time + ETA. * Relative depth printed for target layers (e.g. "layer 40 (51%)"). * Skip empty .txt files with a warning rather than feeding the tokenizer an empty string. * Assert non-empty strings in _collect_activations. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
454 lines
15 KiB
Python
454 lines
15 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Train concept-readout vectors via Contrastive Activation Addition.
|
|
|
|
Reads the hand-written story corpus at
|
|
``amygdala_stories/{stories,paired}/`` and produces the per-layer
|
|
safetensors file + sidecar JSON manifest that vLLM's ReadoutManager
|
|
loads at startup (``VLLM_READOUT_VECTORS`` / ``VLLM_READOUT_MANIFEST``).
|
|
|
|
Training data (cross-concept contrast):
|
|
|
|
positive for emotion E:
|
|
stories/E.txt
|
|
paired/<scenario>/E.txt (for each scenario that covers E)
|
|
|
|
negative for emotion E:
|
|
stories/<all other emotions>.txt
|
|
paired/<scenario>/baseline.txt (for each scenario)
|
|
|
|
Within-scenario paired stories are the highest-signal pairs (same
|
|
content, different concept framing); unpaired stories provide bulk
|
|
contrast across the 80 emotions we have written so far.
|
|
|
|
Pooling: last non-pad token. Matches how readout is consumed at decode
|
|
time (residual read at the sampler's query position).
|
|
|
|
Output:
|
|
|
|
readout.safetensors
|
|
layer_<idx>.vectors : fp16 (n_concepts, hidden_size) one per layer
|
|
readout.json
|
|
{
|
|
"concepts": [...],
|
|
"layers": [...],
|
|
"hidden_size": int,
|
|
"dtype": "float16"
|
|
}
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import gc
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
"""Pick the last non-pad token's hidden state per example.
|
|
|
|
hidden: [batch, seq, hidden_dim]
|
|
attention_mask: [batch, seq]
|
|
returns: [batch, hidden_dim]
|
|
"""
|
|
last_idx = attention_mask.sum(dim=1) - 1
|
|
batch_idx = torch.arange(hidden.size(0), device=hidden.device)
|
|
return hidden[batch_idx, last_idx]
|
|
|
|
|
|
def _find_layers_module(model) -> torch.nn.ModuleList:
|
|
"""Walk a few likely paths to find the transformer-block list."""
|
|
candidates = [
|
|
"model.layers",
|
|
"model.model.layers",
|
|
"model.language_model.layers",
|
|
"model.language_model.model.layers",
|
|
"language_model.model.layers",
|
|
"transformer.h",
|
|
]
|
|
for path in candidates:
|
|
obj = model
|
|
ok = True
|
|
for part in path.split("."):
|
|
if not hasattr(obj, part):
|
|
ok = False
|
|
break
|
|
obj = getattr(obj, part)
|
|
if ok and isinstance(obj, torch.nn.ModuleList):
|
|
return obj
|
|
raise RuntimeError(
|
|
f"Couldn't find transformer layer list. Tried: {candidates}"
|
|
)
|
|
|
|
|
|
def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> torch.Tensor:
    """Run texts through the model, capture residual stream at target
    layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.

    A forward hook is registered on each block in ``target_layers``. The hook
    records the block's output hidden states. Each example is then reduced to
    one vector per layer with ``_pool_last`` (last attended token). Row ``i``
    of the result corresponds to ``texts[i]``.

    Args:
        model: causal LM whose block list ``_find_layers_module`` can find.
        tokenizer: tokenizer for ``model``; must have a pad token set.
        texts: non-empty strings, processed in order.
        target_layers: indices into the block list.
        device: device the tokenized batch is moved to.
        batch_size: texts per forward pass (>= 1).
        max_length: inputs are truncated to this many tokens.
        label: tag used in progress and error messages.

    Raises:
        ValueError: ``texts`` is empty, contains an empty or non-string
            entry, or ``batch_size`` < 1.
        RuntimeError: a hooked layer did not fire during a forward pass.
    """
    import time

    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``. An empty list would otherwise only fail at torch.cat.
    if not texts:
        raise ValueError(f"_collect_activations: no texts given for {label!r}")
    if not all(isinstance(t, str) and t for t in texts):
        raise ValueError(
            f"_collect_activations: empty or non-string text in {label!r}"
        )
    if batch_size < 1:
        raise ValueError(
            f"_collect_activations: batch_size must be >= 1, got {batch_size}"
        )

    # The forward hooks below write into this dict. It is only ever cleared,
    # never rebound, so every hook closure keeps writing to the same object.
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            # Some decoder blocks return a tuple whose first element is the
            # hidden states; others return the tensor directly.
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()

        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                # Only the hooked hidden states are needed, so skip building
                # a KV cache that would be thrown away.
                model(**tok, use_cache=False)

                missing = [idx for idx in target_layers if idx not in captures]
                if missing:
                    raise RuntimeError(
                        f"_collect_activations: forward hooks on layers "
                        f"{missing} did not fire for {label!r}"
                    )
                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in target_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))

                # Drop this batch's device tensors before the next forward.
                del tok
                captures.clear()
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()

                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for h in handles:
            h.remove()
        # Release any hidden states still held if a forward pass raised.
        captures.clear()

    return torch.cat(out_rows, dim=0)
|
|
|
|
|
|
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
|
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
|
|
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
|
|
]:
|
|
"""Return ``(positives_by_emotion, baselines)``.
|
|
|
|
Cross-concept negatives are computed at training time from
|
|
``positives_by_emotion`` — each emotion's negative set is the
|
|
union of all other emotions' positives plus the baseline texts.
|
|
Empty .txt files are skipped with a warning.
|
|
"""
|
|
def _read_nonempty(path: Path) -> str | None:
|
|
text = path.read_text().strip()
|
|
if not text:
|
|
print(
|
|
f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
|
|
)
|
|
return None
|
|
return text
|
|
|
|
positives: dict[str, list[str]] = {}
|
|
for story_path in sorted(stories_dir.glob("*.txt")):
|
|
text = _read_nonempty(story_path)
|
|
if text is None:
|
|
continue
|
|
emotion = story_path.stem
|
|
positives.setdefault(emotion, []).append(text)
|
|
|
|
baselines: list[str] = []
|
|
if paired_dir is not None and paired_dir.exists():
|
|
for scenario_dir in sorted(paired_dir.iterdir()):
|
|
if not scenario_dir.is_dir():
|
|
continue
|
|
baseline_path = scenario_dir / "baseline.txt"
|
|
if baseline_path.exists():
|
|
text = _read_nonempty(baseline_path)
|
|
if text is not None:
|
|
baselines.append(text)
|
|
for framing_path in sorted(scenario_dir.glob("*.txt")):
|
|
if framing_path.stem == "baseline":
|
|
continue
|
|
text = _read_nonempty(framing_path)
|
|
if text is None:
|
|
continue
|
|
emotion = framing_path.stem
|
|
positives.setdefault(emotion, []).append(text)
|
|
|
|
return positives, baselines
|
|
|
|
|
|
def _text_config(config):
    """Return the text-model sub-config of ``config``.

    Multimodal configs (Qwen3.5-27B, etc.) nest the text-model dimensions
    under a text_config subobject. get_text_config() returns that
    sub-config when present, else the top-level config.
    """
    if hasattr(config, "get_text_config"):
        return config.get_text_config()
    return config


def _parse_target_layers(spec: str) -> list[int]:
    """Parse ``--target-layers`` (e.g. "40,50,60") into a list of unique ints.

    Blank entries (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: a non-integer entry, no layers at all, or a duplicate.
    """
    try:
        layers = [int(part) for part in spec.split(",") if part.strip()]
    except ValueError as err:
        raise ValueError(
            f"--target-layers {spec!r}: expected comma-separated integers"
        ) from err
    if not layers:
        raise ValueError(f"--target-layers {spec!r} names no layers")
    if len(set(layers)) != len(layers):
        # A duplicate would write the same safetensors key twice while the
        # manifest lists the layer twice.
        raise ValueError(f"--target-layers {spec!r} contains duplicate layers")
    return layers


def _check_target_layers(target_layers: list[int], n_model_layers: int) -> None:
    """Raise ValueError if any target layer is outside ``[0, n_model_layers)``."""
    for layer_idx in target_layers:
        if layer_idx < 0 or layer_idx >= n_model_layers:
            raise ValueError(
                f"target layer {layer_idx} out of range "
                f"[0, {n_model_layers})"
            )


def _contrast_vectors(
    emotions: list[str],
    positives_by_emotion: dict[str, list[str]],
    unique_texts: list[str],
    text_to_emotions: dict[str, set[str]],
    positive_acts: torch.Tensor,
    baseline_acts: torch.Tensor,
) -> torch.Tensor:
    """Compute unit-norm CAA directions, shape ``[n_layers, n_concepts, hidden]``.

    For each emotion, the direction is mean(positives) - mean(negatives),
    normalized per layer. Negatives are every cached text NOT labelled with
    that emotion, plus all baselines. ``positive_acts[i]`` must correspond to
    ``unique_texts[i]``.

    Raises:
        RuntimeError: an emotion has no negative examples. Its negative mean
            would be NaN.
    """
    text_to_row = {t: i for i, t in enumerate(unique_texts)}
    _, n_layers, hidden_dim = positive_acts.shape
    vectors = torch.zeros((n_layers, len(emotions), hidden_dim), dtype=torch.float32)

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        # A text shared by several emotions counts as positive for each of
        # them, and as negative for none of them.
        neg_rows = [
            i for i, t in enumerate(unique_texts)
            if emotion not in text_to_emotions[t]
        ]
        # index_select with an explicit long tensor also handles empty lists.
        pos = positive_acts.index_select(0, torch.tensor(pos_rows, dtype=torch.long))
        neg = positive_acts.index_select(0, torch.tensor(neg_rows, dtype=torch.long))
        if baseline_acts.shape[0] > 0:
            neg = torch.cat([neg, baseline_acts], dim=0)
        if neg.shape[0] == 0:
            raise RuntimeError(
                f"emotion {emotion!r} has no negative examples (no other "
                f"emotion's texts and no baselines to contrast against)"
            )

        diff = pos.mean(dim=0) - neg.mean(dim=0)  # [n_layers, hidden]
        diff = diff / diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        vectors[:, e_idx] = diff

        if e_idx < 5 or e_idx == len(emotions) - 1:
            print(
                f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
                f"pos={len(pos_rows)} neg={neg.shape[0]}"
            )
    return vectors


def main() -> None:
    """CLI entry point: validate inputs, load the model, train, and write outputs.

    Everything checkable without the model weights is checked before the
    expensive model load: corpus layout and contents, the target-layer list
    and its range (from the model config), and that the output directory can
    be created.
    """
    from transformers import AutoConfig

    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument(
        "--stories-dir",
        required=True,
        help="Path to amygdala_stories/stories/",
    )
    ap.add_argument(
        "--paired-dir",
        default=None,
        help="Path to amygdala_stories/paired/ (optional)",
    )
    ap.add_argument(
        "--target-layers",
        required=True,
        help="Comma-separated layer indices, e.g. 40,50,60,70",
    )
    ap.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write readout.safetensors + readout.json",
    )
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument(
        "--min-positives",
        type=int,
        default=1,
        help="Skip emotions with fewer positive examples than this",
    )
    args = ap.parse_args()

    if args.batch_size < 1:
        ap.error(f"--batch-size must be >= 1, got {args.batch_size}")
    if args.max_length < 1:
        ap.error(f"--max-length must be >= 1, got {args.max_length}")

    target_layers = _parse_target_layers(args.target_layers)
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    # Preflight: corpus dirs exist before we pay the cost of loading a 27B model
    stories_dir = Path(args.stories_dir)
    if not stories_dir.is_dir():
        raise FileNotFoundError(
            f"--stories-dir {stories_dir!s} does not exist or is not a dir"
        )
    paired_dir = Path(args.paired_dir) if args.paired_dir else None
    if paired_dir is not None and not paired_dir.is_dir():
        raise FileNotFoundError(
            f"--paired-dir {paired_dir!s} does not exist or is not a dir"
        )

    # Load the corpus once, before the model, so failures show up early.
    positives_by_emotion, baselines = _load_corpus(stories_dir, paired_dir)
    emotions = sorted(
        e for e, ps in positives_by_emotion.items()
        if len(ps) >= args.min_positives
    )
    if not emotions:
        raise RuntimeError(
            f"corpus has 0 emotions with >= {args.min_positives} positive "
            f"examples. Check {stories_dir} — is it the right directory?"
        )
    if not baselines and len(positives_by_emotion) < 2:
        # Every negative set would be empty, and the mean would be NaN.
        raise RuntimeError(
            "no negatives available: need a second emotion or a paired "
            "baseline to contrast against"
        )
    print(
        f"Corpus preflight: {len(emotions)} emotions (min_positives="
        f"{args.min_positives}), {len(baselines)} baselines"
    )

    # The layer range needs only the config, not the weights.
    preflight_config = _text_config(AutoConfig.from_pretrained(args.model))
    n_config_layers = preflight_config.num_hidden_layers
    _check_target_layers(target_layers, n_config_layers)
    depth_denom = max(n_config_layers - 1, 1)  # avoid /0 on 1-layer models
    print(
        "Target layers (relative depth): "
        + ", ".join(
            f"{layer} ({100 * layer / depth_denom:.0f}%)"
            for layer in target_layers
        )
    )

    # Fail on an unwritable output location before the expensive work.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right padding keeps each example's real tokens at positions 0..n-1,
    # matching the position ids a plain forward uses by default.
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    text_config = _text_config(model.config)
    hidden_dim = text_config.hidden_size
    n_model_layers = text_config.num_hidden_layers
    print(
        f"Model loaded. hidden_dim={hidden_dim}, "
        f"n_model_layers={n_model_layers} "
        f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
    )
    _check_target_layers(target_layers, n_model_layers)

    print(
        f"Training {len(emotions)} emotions; "
        f"{len(baselines)} baseline scenarios"
    )

    # Compute each distinct positive text's activations once. They are
    # reused as negatives for the other emotions. Each text records every
    # emotion it was written for.
    device = torch.device(args.device)
    text_to_emotions: dict[str, set[str]] = {}
    for emotion, texts in positives_by_emotion.items():
        for t in texts:
            text_to_emotions.setdefault(t, set()).add(emotion)
    unique_positive_texts = list(text_to_emotions)
    print(
        f"Collecting activations for {len(unique_positive_texts)} unique "
        f"positive texts + {len(baselines)} baselines..."
    )

    positive_acts = _collect_activations(
        model, tokenizer, unique_positive_texts, target_layers, device,
        args.batch_size, args.max_length, label="positives",
    )
    baseline_acts = (
        _collect_activations(
            model, tokenizer, baselines, target_layers, device,
            args.batch_size, args.max_length, label="baselines",
        )
        if baselines
        else torch.zeros(0, len(target_layers), hidden_dim)
    )

    per_layer_vectors = _contrast_vectors(
        emotions,
        positives_by_emotion,
        unique_positive_texts,
        text_to_emotions,
        positive_acts,
        baseline_acts,
    )
    n_concepts = len(emotions)
    n_layers = len(target_layers)

    tensors = {
        f"layer_{layer}.vectors": per_layer_vectors[l_idx].to(torch.float16)
        for l_idx, layer in enumerate(target_layers)
    }
    safetensors.torch.save_file(
        tensors,
        str(output_dir / "readout.safetensors"),
    )
    manifest = {
        "concepts": emotions,
        "layers": target_layers,
        "hidden_size": hidden_dim,
        "dtype": "float16",
    }
    (output_dir / "readout.json").write_text(
        json.dumps(manifest, indent=2) + "\n"
    )

    total_mb = sum(t.numel() * t.element_size() for t in tensors.values()) / (1024 * 1024)
    print(
        f"\nWrote readout.safetensors + readout.json to {output_dir}\n"
        f" {n_concepts} concepts x {n_layers} layers x "
        f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
    )
    del model
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: importing this module does not run main().
    main()
|