# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train amygdala steering vectors via Contrastive Activation Addition.

Reads the per-emotion JSONL files produced by extract_training_pairs.py,
runs the target model over each example, captures the residual-stream
hidden state at the configured target layers, and computes
`mean(positive) - mean(negative)` as the steering direction per layer,
per emotion.

Output: a safetensors file matching the format AmygdalaConnector expects:

    vectors: [n_emotions, n_target_layers, hidden_dim] fp16
    emotion_names: [n_emotions, max_name_len] uint8 (zero-padded UTF-8)
    target_layers: [n_target_layers] int32

Pooling: last-token residual-stream state per example (CAA convention —
the final token has seen the whole context and is where the model's
"decision" lives). An alternative is the mean across all tokens, but the
last-token convention is more common in steering-vector work.
"""
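
# Example invocation (illustrative; the script filename, model id, and paths
# are placeholders, while the flags match the argparse setup in main() below):
#
#   python train_steering_vectors.py \
#       --model <hf-model-id-or-path> \
#       --training-data-dir data/amygdala_pairs \
#       --target-layers 3,18,33,36 \
#       --output amygdala_vectors.safetensors \
#       --dtype bf16 --batch-size 4 --max-length 512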

import argparse
import gc
import json
import os
from pathlib import Path

import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    # Index of the last non-pad token per row (assumes right padding).
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
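
# Illustrative shapes for _pool_last (not executed): with right padding and
# attention_mask [[1, 1, 1, 0], [1, 1, 1, 1]], last_idx is [2, 3], so the
# returned rows are hidden[0, 2, :] and hidden[1, 3, :].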


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> torch.Tensor:
    """Run texts through the model, capture the residual stream at the target
    layers, and return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
    """
    # Register hooks on the target layers' outputs. We want the
    # residual stream AFTER each layer, which is the output of the
    # transformer block (hidden_states[layer_idx + 1] in HF land).
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx):
        def hook(_mod, _inp, output):
            # output is typically (hidden_states, ...) — take the first
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()

        return hook

    handles = []
    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
    # language_model.model.layers follows the same convention.
    # Resolve the layer list by walking common paths.
    layers_module = _find_layers_module(model)
    for idx in target_layers:
        handles.append(
            layers_module[idx].register_forward_hook(make_hook(idx))
        )

    out_rows: list[torch.Tensor] = []
    try:
        model.eval()
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = []
                for idx in target_layers:
                    hs = captures[idx]  # [batch, seq, hidden]
                    pooled = _pool_last(hs, tok["attention_mask"])
                    per_layer.append(pooled.to(torch.float32).cpu())
                # Stack to [batch, n_layers, hidden_dim]
                batched = torch.stack(per_layer, dim=1)
                out_rows.append(batched)

                # Keep `captures` alive: the hooks close over it and it is
                # cleared and reused on the next batch; deleting it here
                # would raise on the following iteration.
                del tok
                if (i // batch_size) % 10 == 0:
                    torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument("--training-data-dir", required=True)
    ap.add_argument(
        "--target-layers", required=True,
        help="Comma-separated layer indices, e.g. 3,18,33,36",
    )
    ap.add_argument("--output", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # _pool_last assumes right padding; make that explicit regardless of the
    # tokenizer's default.
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    hidden_dim = model.config.hidden_size
    print(f"Model loaded. hidden_dim={hidden_dim}, "
          f"n_layers={model.config.num_hidden_layers}")

    manifest_path = Path(args.training_data_dir) / "_manifest.json"
    manifest = json.loads(manifest_path.read_text())

    emotions = sorted(manifest["emotions"].keys())
    print(f"Training {len(emotions)} emotions: {emotions}")

    n_emotions = len(emotions)
    n_layers = len(target_layers)
    vectors = torch.zeros(
        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
    )
    device = torch.device(args.device)

    for e_idx, emotion in enumerate(emotions):
        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
        pos_texts, neg_texts = [], []
        with open(path) as f:
            for line in f:
                ex = json.loads(line)
                if ex["polarity"] == "positive":
                    pos_texts.append(ex["text"])
                else:
                    neg_texts.append(ex["text"])
        print(f"[{e_idx + 1}/{n_emotions}] {emotion}: "
              f"{len(pos_texts)} pos / {len(neg_texts)} neg")

        pos_acts = _collect_activations(
            model, tokenizer, pos_texts, target_layers, device,
            args.batch_size, args.max_length,
        )
        neg_acts = _collect_activations(
            model, tokenizer, neg_texts, target_layers, device,
            args.batch_size, args.max_length,
        )

        # Difference of means per layer
        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
        neg_mean = neg_acts.mean(dim=0)
        diff = pos_mean - neg_mean

        # Normalize per layer so projections are scale-comparable
        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        diff = diff / norms

        vectors[e_idx] = diff
        del pos_acts, neg_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Save in AmygdalaConnector format.
    # emotion_names as a zero-padded uint8 tensor
    names_bytes = [e.encode("utf-8") for e in emotions]
    max_len = max(len(b) for b in names_bytes)
    padded = torch.tensor(
        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
        dtype=torch.uint8,
    )
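
    # A consumer can recover the names by stripping the padding, e.g.
    # bytes(row.tolist()).rstrip(b"\x00").decode("utf-8") for each row.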

    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    safetensors.torch.save_file(
        {
            "vectors": vectors.to(torch.float16),
            "emotion_names": padded,
            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
        },
        args.output,
    )
    print(f"\nWrote steering vectors to {args.output}: "
          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")


if __name__ == "__main__":
    main()