Kent's plan: keep the stories for the concepts that already work, replace the stories for the trouble concepts with direct first-person descriptions, and train everything together. This gives a more diverse negative pool than the direct-only test on just the 6 trouble concepts, which was too homogeneous for PCA to find the emotion axis. Deleted the story files for the 6 trouble concepts (14 files across stories/ and paired/). Added --direct-dir and --chat-template flags. When --chat-template is on, every positive_str and negative_str is wrapped as a user/assistant pair: user "Say something.", assistant "[text]". The prompt is identical across positives and negatives, so it cancels in the pos-neg delta; what PCA sees is the variation in the assistant content, which is where the emotion lives. Files in --direct-dir whose names start with _ (e.g. _baseline.txt) contribute neutral descriptions to every concept's negative pool, giving PCA an anchor against generic "just any assistant utterance" noise.
323 lines
12 KiB
Python
323 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
"""Train concept-readout vectors using the steering-vectors library.
|
|
|
|
Alternative to train_steering_vectors.py that uses the pip-installable
|
|
steering-vectors library (github.com/steering-vectors/steering-vectors)
|
|
instead of our hand-rolled diff-of-means + subspace machinery. The
|
|
library ships multiple aggregators out of the box:
|
|
|
|
mean — pos_mean - neg_mean, unit-normed. Equivalent to our
|
|
default 'pooled' method.
|
|
pca — concatenates [pos-neg, neg-pos] and takes the top PC.
|
|
Implicit denoising: direction of maximum variance in the
|
|
paired deltas, less sensitive to per-pair noise than
|
|
plain mean.
|
|
logistic — trains a logistic-regression classifier on centered
|
|
activations; concept direction is the weight vector.
|
|
L1 penalty gives an explicit sparse vector (zeroes out
|
|
noise coords); L2 shrinks low-magnitude coords.
|
|
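linear — same, with linear regression (shipped by the library but not
offered by this script's --aggregator, which instead adds 'logistic_l1':
logistic regression with an L1 penalty, C=0.1).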
|
|
|
|
Output is the same readout.safetensors + readout.json format the
|
|
trainer and vLLM plugin already understand.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import random
|
|
from pathlib import Path
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from steering_vectors import (
|
|
SteeringVectorTrainingSample,
|
|
train_steering_vector,
|
|
)
|
|
from steering_vectors.aggregators import (
|
|
mean_aggregator,
|
|
pca_aggregator,
|
|
logistic_aggregator,
|
|
)
|
|
|
|
# Reuse corpus loader from the hand-rolled trainer.
|
|
from training.amygdala_training.train_steering_vectors import _load_corpus
|
|
|
|
|
|
def _load_direct_descriptions(
|
|
direct_dir: Path,
|
|
) -> tuple[dict[str, list[str]], list[str]]:
|
|
"""Load first-person phenomenological descriptions from ``direct_dir``.
|
|
|
|
Each ``{concept}.txt`` holds 1+ descriptions separated by blank lines.
|
|
Files starting with ``_`` (e.g. ``_baseline.txt``) aren't concepts —
|
|
their descriptions go into every concept's negative pool.
|
|
|
|
Returns: (positives_by_concept, extra_baselines)
|
|
"""
|
|
positives: dict[str, list[str]] = {}
|
|
baselines: list[str] = []
|
|
for f in sorted(direct_dir.glob("*.txt")):
|
|
text = f.read_text()
|
|
descs = [d.strip() for d in text.split("\n\n") if d.strip()]
|
|
if f.stem.startswith("_"):
|
|
baselines.extend(descs)
|
|
else:
|
|
positives[f.stem] = descs
|
|
return positives, baselines
|
|
|
|
|
|
def _chat_template_wrap(tokenizer, text: str) -> str:
|
|
"""Wrap raw text in a consistent chat template so positive/negative
|
|
activations are in the same regime. Using one generic user prompt for
|
|
both narrative stories and first-person direct descriptions: the prompt
|
|
cancels in the pos-neg delta, so what remains is the assistant content."""
|
|
return tokenizer.apply_chat_template(
|
|
[
|
|
{"role": "user", "content": "Say something."},
|
|
{"role": "assistant", "content": text},
|
|
],
|
|
tokenize=False,
|
|
)
|
|
|
|
|
|
def _samples_for_concept(
|
|
emotion: str,
|
|
positives_by_emotion: dict[str, list[str]],
|
|
baselines: list[str],
|
|
*,
|
|
max_negatives_per_positive: int = 3,
|
|
seed: int = 0,
|
|
wrap=None,
|
|
) -> list[SteeringVectorTrainingSample]:
|
|
"""Build paired (pos, neg) training samples for one concept.
|
|
|
|
For each positive story of ``emotion``, pair it with up to
|
|
``max_negatives_per_positive`` randomly-sampled negatives drawn
|
|
from: (a) other emotions' positive stories, (b) scenario baselines.
|
|
|
|
``wrap``, if given, is applied to both positive_str and negative_str
|
|
(e.g. a chat-template wrapper).
|
|
|
|
The library expects paired samples; we don't have true
|
|
counterfactual pairs for all concepts, so we approximate with
|
|
random cross-concept / baseline negatives.
|
|
"""
|
|
rng = random.Random(hash((emotion, seed)) & 0xFFFFFFFF)
|
|
neg_pool: list[str] = list(baselines)
|
|
for other, texts in positives_by_emotion.items():
|
|
if other == emotion:
|
|
continue
|
|
neg_pool.extend(texts)
|
|
|
|
w = wrap if wrap is not None else (lambda s: s)
|
|
|
|
samples: list[SteeringVectorTrainingSample] = []
|
|
for pos in positives_by_emotion[emotion]:
|
|
if not neg_pool:
|
|
continue
|
|
picks = rng.sample(neg_pool, min(max_negatives_per_positive, len(neg_pool)))
|
|
for neg in picks:
|
|
samples.append(
|
|
SteeringVectorTrainingSample(
|
|
positive_str=w(pos),
|
|
negative_str=w(neg),
|
|
)
|
|
)
|
|
return samples
|
|
|
|
|
|
def _fp32_wrap(inner):
|
|
"""Wrap an aggregator so activations are cast to fp32 first.
|
|
|
|
torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU,
|
|
and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors.
|
|
"""
|
|
|
|
def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
|
|
return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32))
|
|
|
|
return wrapped
|
|
|
|
|
|
def _aggregator_from_name(name: str):
|
|
if name == "mean":
|
|
return _fp32_wrap(mean_aggregator())
|
|
if name == "pca":
|
|
return _fp32_wrap(pca_aggregator())
|
|
if name == "logistic":
|
|
return _fp32_wrap(logistic_aggregator())
|
|
if name == "logistic_l1":
|
|
return _fp32_wrap(
|
|
logistic_aggregator(
|
|
sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
|
|
)
|
|
)
|
|
raise ValueError(f"unknown aggregator: {name}")
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for ``main`` (description = module docstring)."""
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True)
    ap.add_argument("--stories-dir", required=True)
    ap.add_argument("--paired-dir", default=None)
    ap.add_argument(
        "--direct-dir",
        default=None,
        help="Optional: directory of {concept}.txt files with 1+ "
        "first-person descriptions separated by blank lines. "
        "Files starting with _ contribute to every concept's "
        "negative pool rather than being concepts themselves.",
    )
    ap.add_argument(
        "--chat-template",
        action="store_true",
        help="Wrap all text in assistant-role chat template. "
        "Recommended when --direct-dir is used.",
    )
    ap.add_argument("--target-layers", required=True, help="Comma-separated layer indices")
    ap.add_argument("--output-dir", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=2)
    # Kept so existing command lines still parse. It is never forwarded to
    # train_steering_vector, so it has no effect on training.
    ap.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Currently unused: not forwarded to train_steering_vector.",
    )
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--min-positives", type=int, default=1)
    ap.add_argument(
        "--aggregator",
        default="mean",
        choices=["mean", "pca", "logistic", "logistic_l1"],
    )
    ap.add_argument("--max-negatives-per-positive", type=int, default=3)
    ap.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for per-concept negative sampling (default: 0).",
    )
    return ap


def _load_training_texts(
    args: argparse.Namespace,
) -> tuple[dict[str, list[str]], list[str]]:
    """Load the story corpus, then merge in optional --direct-dir descriptions."""
    stories_dir = Path(args.stories_dir)
    paired_dir = Path(args.paired_dir) if args.paired_dir else None
    positives_by_emotion, baselines = _load_corpus(stories_dir, paired_dir)

    if args.direct_dir:
        direct_pos, direct_baselines = _load_direct_descriptions(Path(args.direct_dir))
        # Direct descriptions extend (not replace) any stories for the same concept.
        for concept, descs in direct_pos.items():
            positives_by_emotion.setdefault(concept, []).extend(descs)
        baselines.extend(direct_baselines)
        print(
            f"Loaded {len(direct_pos)} direct-description concepts "
            f"+ {len(direct_baselines)} baselines from {args.direct_dir}"
        )
    return positives_by_emotion, baselines


def _infer_hidden_dim(model) -> int:
    """Return the hidden size from the model's (text) config, or raise RuntimeError."""
    config = model.config
    # When available, get_text_config() selects the text-model sub-config
    # (relevant for composite / multimodal configs).
    text_config = config.get_text_config() if hasattr(config, "get_text_config") else config
    hidden_dim = getattr(text_config, "hidden_size", None) or getattr(
        text_config, "hidden_dim", None
    )
    if not hidden_dim:
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise RuntimeError("couldn't infer hidden_dim from model config")
    return hidden_dim


def _write_readout(
    output_dir: Path,
    per_layer_vectors: torch.Tensor,
    target_layers: list[int],
    concepts: list[str],
    hidden_dim: int,
    aggregator_name: str,
) -> None:
    """Write readout.safetensors (fp16) + readout.json and print a size summary."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # One [n_concepts, hidden] fp16 matrix per layer, keyed "layer_{idx}.vectors".
    tensors = {
        f"layer_{layer}.vectors": per_layer_vectors[l_idx].to(torch.float16)
        for l_idx, layer in enumerate(target_layers)
    }
    safetensors.torch.save_file(tensors, str(output_dir / "readout.safetensors"))

    meta = {
        "concepts": concepts,
        "layers": target_layers,
        "hidden_size": hidden_dim,
        "dtype": "float16",
        "aggregator": aggregator_name,
    }
    (output_dir / "readout.json").write_text(json.dumps(meta, indent=2) + "\n")

    total_mb = sum(t.numel() * t.element_size() for t in tensors.values()) / (1024 * 1024)
    print(
        f"\nWrote readout.safetensors + readout.json to {output_dir} "
        f"({len(concepts)} concepts x {len(target_layers)} layers, {total_mb:.1f} MiB)"
    )


def main() -> None:
    """Train one unit-norm readout vector per concept per layer and save them.

    Steps:
        1. Parse the CLI flags.
        2. Load positives and baselines: stories, plus optional direct
           descriptions.
        3. Keep the concepts that have at least ``--min-positives`` positives.
        4. Load the model.
        5. For each concept, build paired samples and run
           ``train_steering_vector`` with the chosen aggregator.
        6. Unit-normalize each layer's vector.
        7. Write ``readout.safetensors`` + ``readout.json``.

    Concepts or layers for which no vector could be trained keep all-zero
    rows; a summary warning lists them.

    Raises:
        RuntimeError: if no concept has enough positives, or the model's
            hidden size cannot be inferred.
    """
    args = _build_arg_parser().parse_args()

    # Ignore empty entries so a trailing comma ("10,20,") is harmless.
    target_layers = [int(x) for x in args.target_layers.split(",") if x.strip()]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    positives_by_emotion, baselines = _load_training_texts(args)

    emotions = sorted(
        e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives
    )
    if not emotions:
        # Positives may also come from --direct-dir, so name every source.
        sources = str(Path(args.stories_dir))
        if args.direct_dir:
            sources += f" or {Path(args.direct_dir)}"
        raise RuntimeError(
            f"no emotions with >= {args.min_positives} positives in {sources}"
        )

    print(
        f"Training {len(emotions)} concepts via steering-vectors "
        f"aggregator={args.aggregator!r} on layers={target_layers}"
    )

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        # Batched extraction (--batch-size) presumably needs padding; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True
    )
    model.eval()

    hidden_dim = _infer_hidden_dim(model)

    # [n_layers, n_concepts, hidden]; rows stay zero for anything skipped below.
    per_layer_vectors = torch.zeros(
        (len(target_layers), len(emotions), hidden_dim), dtype=torch.float32
    )

    aggregator = _aggregator_from_name(args.aggregator)

    wrap = (lambda s: _chat_template_wrap(tokenizer, s)) if args.chat_template else None
    if wrap is not None:
        # Show one rendered example so the template can be eyeballed. Pick the
        # first concept that has any positives (it may have none when
        # --min-positives is 0).
        example = next(
            (positives_by_emotion[e][0] for e in emotions if positives_by_emotion[e]),
            None,
        )
        if example is not None:
            sample_text = wrap(example)
            print(f"\nSample templated input:\n{sample_text[:400]!r}\n")

    zero_rows: list[str] = []
    for e_idx, emotion in enumerate(emotions):
        samples = _samples_for_concept(
            emotion,
            positives_by_emotion,
            baselines,
            max_negatives_per_positive=args.max_negatives_per_positive,
            seed=args.seed,
            wrap=wrap,
        )
        if not samples:
            print(f" [{e_idx + 1}/{len(emotions)}] {emotion}: NO SAMPLES, skipping")
            zero_rows.append(f"{emotion} (all layers)")
            continue

        sv = train_steering_vector(
            model,
            tokenizer,
            samples,
            layers=target_layers,
            aggregator=aggregator,
            batch_size=args.batch_size,
            show_progress=False,
            move_to_cpu=True,
        )
        # sv.layer_activations is a dict {layer_idx: tensor[hidden]}
        for l_idx, layer in enumerate(target_layers):
            vec = sv.layer_activations.get(layer)
            if vec is None:
                print(f" WARN: no vector returned for layer {layer} on {emotion}")
                zero_rows.append(f"{emotion}@layer{layer}")
                continue
            vec = vec.detach().to(torch.float32).cpu()
            # Unit-normalize; the clamp guards against dividing by ~0.
            per_layer_vectors[l_idx, e_idx] = vec / vec.norm().clamp_min(1e-6)

        if e_idx < 5 or e_idx == len(emotions) - 1 or e_idx % 10 == 0:
            print(
                f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
                f"n_samples={len(samples)} layers={target_layers}"
            )

    if zero_rows:
        # Zero rows are still written and listed in readout.json; surface them.
        print(
            f"WARN: {len(zero_rows)} readout(s) left as zero vectors: "
            f"{', '.join(zero_rows)}"
        )

    _write_readout(
        Path(args.output_dir),
        per_layer_vectors,
        target_layers,
        emotions,
        hidden_dim,
        args.aggregator,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the CLI only when this file is executed directly,
    # not when it is imported.
    main()
|