consciousness/training/amygdala_training/train_steering_vectors.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train amygdala steering vectors via Contrastive Activation Addition.
Reads the per-emotion JSONL files produced by extract_training_pairs.py,
runs the target model over each example, captures the residual-stream
hidden state at the configured target layers, and computes
`mean(positive) - mean(negative)` as the steering direction per layer
per emotion.
Output: a safetensors file matching the format AmygdalaConnector
expects:
vectors: [n_emotions, n_target_layers, hidden_dim] fp16
emotion_names: [n_emotions] uint8
Pooling: last-token residual-stream per example (CAA convention
the final token has seen the whole context and is where the model's
"decision" lives). Alternative: mean across all tokens. The LAST
convention is more common for steering vector work.
"""
import argparse
import gc
import json
import os
from collections import defaultdict
from pathlib import Path
import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    # Index of the last non-pad token in each row.
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
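
# Note: _pool_last assumes right-padding (the common default padding_side for
# HF tokenizers). Illustrative shapes: with hidden of shape [2, 7, H] and
# attention-mask row sums of 7 and 5, it returns hidden[0, 6] and hidden[1, 4]
# stacked into [2, H].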


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> torch.Tensor:
    """Run texts through the model, capture the residual stream at the
    target layers, and return [n_texts, n_target_layers, hidden_dim] fp32
    on CPU.
    """
    # Register hooks on the target layers' outputs. We want the residual
    # stream AFTER each layer, which is the output of the transformer block
    # (hidden_states[layer_idx + 1] in HF land).
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx):
        def hook(_mod, _inp, output):
            # output is typically (hidden_states, ...); take the first element
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    handles = []
    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
    # language_model.model.layers follows the same convention.
    # Resolve the layer list by walking common paths.
    layers_module = _find_layers_module(model)
    for idx in target_layers:
        handles.append(
            layers_module[idx].register_forward_hook(make_hook(idx))
        )

    out_rows: list[torch.Tensor] = []
    try:
        model.eval()
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                per_layer = []
                for idx in target_layers:
                    hs = captures[idx]  # [batch, seq, hidden]
                    pooled = _pool_last(hs, tok["attention_mask"])
                    per_layer.append(pooled.to(torch.float32).cpu())
                # Stack to [batch, n_layers, hidden_dim]
                batched = torch.stack(per_layer, dim=1)
                out_rows.append(batched)
                # Release GPU references. Clear (do not delete) `captures`:
                # the hook closures still need the dict on the next batch.
                del tok
                captures.clear()
                if (i // batch_size) % 10 == 0:
                    torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]
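
# Note: an alternative here would be running the model with
# output_hidden_states=True and indexing hidden_states[idx + 1]; the forward
# hooks used above only keep the target layers instead of every layer's output.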


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )
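
# For a plain LlamaForCausalLM, for instance, this resolves at "model.layers";
# multimodal wrappers that nest a language_model submodule fall through to one
# of the longer candidate paths.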


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument("--training-data-dir", required=True)
    ap.add_argument(
        "--target-layers", required=True,
        help="Comma-separated layer indices, e.g. 3,18,33,36",
    )
    ap.add_argument("--output", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    hidden_dim = model.config.hidden_size
    print(f"Model loaded. hidden_dim={hidden_dim}, "
          f"n_layers={model.config.num_hidden_layers}")
    manifest_path = Path(args.training_data_dir) / "_manifest.json"
    manifest = json.loads(manifest_path.read_text())
    emotions = sorted(manifest["emotions"].keys())
    print(f"Training {len(emotions)} emotions: {emotions}")

    n_emotions = len(emotions)
    n_layers = len(target_layers)
    vectors = torch.zeros(
        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
    )
    device = torch.device(args.device)

    for e_idx, emotion in enumerate(emotions):
        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
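        # Each JSONL row is expected to carry at least these fields
        # (illustrative values):
        #   {"text": "...", "polarity": "positive"}   # or "negative"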
        pos_texts, neg_texts = [], []
        with open(path) as f:
            for line in f:
                ex = json.loads(line)
                if ex["polarity"] == "positive":
                    pos_texts.append(ex["text"])
                else:
                    neg_texts.append(ex["text"])
        print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
              f"{len(pos_texts)} pos / {len(neg_texts)} neg")

        pos_acts = _collect_activations(
            model, tokenizer, pos_texts, target_layers, device,
            args.batch_size, args.max_length,
        )
        neg_acts = _collect_activations(
            model, tokenizer, neg_texts, target_layers, device,
            args.batch_size, args.max_length,
        )

        # Difference of means per layer
        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
        neg_mean = neg_acts.mean(dim=0)
        diff = pos_mean - neg_mean
        # Normalize per layer so projections are scale-comparable
        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        diff = diff / norms
        vectors[e_idx] = diff
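
        # In CAA terms, for emotion e and target layer l this stores
        #   v_{e,l} = (mean_i h_l(x_i+) - mean_j h_l(x_j-)) / ||.||_2,
        # i.e. a unit-norm difference of class means; any steering scale can
        # then be applied downstream rather than baked into the vector.
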
        del pos_acts, neg_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Save in AmygdalaConnector format.
    # emotion_names as a padded uint8 tensor
    names_bytes = [e.encode("utf-8") for e in emotions]
    max_len = max(len(b) for b in names_bytes)
    padded = torch.tensor(
        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
        dtype=torch.uint8,
    )
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    safetensors.torch.save_file(
        {
            "vectors": vectors.to(torch.float16),
            "emotion_names": padded,
            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
        },
        args.output,
    )
    print(f"\nWrote steering vectors to {args.output}: "
          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")


if __name__ == "__main__":
    main()
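
# To sanity-check the output afterwards (illustrative; key names match the
# save_file call above):
#   data = safetensors.torch.load_file("amygdala_steering.safetensors")
#   names = [bytes(row).rstrip(b"\x00").decode("utf-8")
#            for row in data["emotion_names"].tolist()]
#   assert data["vectors"].shape[:2] == (len(names), len(data["target_layers"]))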