forked from kent/consciousness
training: move amygdala training scripts out of vllm plugin
The fynnsu-based vllm/plugins/amygdala/ scaffold was superseded by the readout infrastructure landed as vllm commit d3e74edf8500 (vllm/model_executor/layers/readout.py + vllm/v1/worker/readout_manager.py). The training code remained useful, so it moves here rather than being deleted.

train_steering_vectors.py: CAA diff-of-means trainer that produces the [n_concepts, hidden_size] per-layer projection matrices the runner loads via VLLM_READOUT_VECTORS.

extract_training_pairs.py: memory graph -> JSONL converter using per-emotion score thresholds from the subconscious agents' tag lines.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
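
For reference, the on-disk contract between the two scripts as read by train_steering_vectors.py (the field names come from its reader code; the concrete emotion names and text are illustrative only):

    # <training-data-dir>/_manifest.json (emotion names are the keys under "emotions")
    {"emotions": {"fear": {}, "joy": {}}}

    # <training-data-dir>/<emotion>.jsonl, one example per line
    {"polarity": "positive", "text": "..."}
    {"polarity": "negative", "text": "..."}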
parent ec7568c726
commit 34bd122590
4 changed files with 545 additions and 0 deletions
training/amygdala_training/train_steering_vectors.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train amygdala steering vectors via Contrastive Activation Addition.

Reads the per-emotion JSONL files produced by extract_training_pairs.py,
runs the target model over each example, captures the residual-stream
hidden state at the configured target layers, and computes
`mean(positive) - mean(negative)` as the steering direction per layer
per emotion.

Output: a safetensors file matching the format AmygdalaConnector
expects:

    vectors: [n_emotions, n_target_layers, hidden_dim] fp16
    emotion_names: [n_emotions, max_name_len] uint8 (zero-padded)
    target_layers: [n_target_layers] int32

Pooling: last-token residual-stream per example (CAA convention —
the final token has seen the whole context and is where the model's
"decision" lives). Alternative: mean across all tokens (sketched below
as _pool_mean). The last-token convention is more common for steering
vector work.
"""

import argparse
import gc
import json
import os
from pathlib import Path

import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    hidden: [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns: [batch, hidden_dim]
    """
    # last non-pad token index per row (assumes right-padded batches)
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
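

# Sketch of the mean-pooling alternative named in the module docstring;
# nothing below calls it, and the CLI keeps the last-token convention.
# Unlike _pool_last, the masked mean makes no assumption about padding
# side, since it averages only over non-pad positions.
def _pool_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean of non-pad hidden states per example: [batch, hidden_dim]."""
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch, seq, 1]
    total = (hidden * mask).sum(dim=1)
    count = mask.sum(dim=1).clamp_min(1.0)
    return total / count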


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> torch.Tensor:
    """Run texts through the model, capture residual stream at target
    layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
    """
    # Register hooks on the target layers' outputs. We want the
    # residual stream AFTER each layer, which is the output of the
    # transformer block (hidden_states[layer_idx+1] in HF land).
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx):
        def hook(_mod, _inp, output):
            # output is typically (hidden_states, ...) — take the first
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    handles = []
    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
    # language_model.model.layers follows the same convention.
    # Resolve the layer list by walking common paths.
    layers_module = _find_layers_module(model)
    for idx in target_layers:
        handles.append(
            layers_module[idx].register_forward_hook(make_hook(idx))
        )

    out_rows: list[torch.Tensor] = []
    try:
        model.eval()
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = []
                for idx in target_layers:
                    hs = captures[idx]  # [batch, seq, hidden]
                    pooled = _pool_last(hs, tok["attention_mask"])
                    per_layer.append(pooled.to(torch.float32).cpu())
                # Stack to [batch, n_layers, hidden_dim]
                batched = torch.stack(per_layer, dim=1)
                out_rows.append(batched)

                # Drop only the batch tensors; `captures` must stay bound,
                # since the forward hooks close over it and the next batch
                # clears and refills it.
                del tok
                if (i // batch_size) % 10 == 0:
                    torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument("--training-data-dir", required=True)
    ap.add_argument(
        "--target-layers", required=True,
        help="Comma-separated layer indices, e.g. 3,18,33,36",
    )
    ap.add_argument("--output", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    hidden_dim = model.config.hidden_size
    print(f"Model loaded. hidden_dim={hidden_dim}, "
          f"n_layers={model.config.num_hidden_layers}")

    manifest_path = Path(args.training_data_dir) / "_manifest.json"
    manifest = json.loads(manifest_path.read_text())

    emotions = sorted(manifest["emotions"].keys())
    print(f"Training {len(emotions)} emotions: {emotions}")

    n_emotions = len(emotions)
    n_layers = len(target_layers)
    vectors = torch.zeros(
        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
    )
    device = torch.device(args.device)

    for e_idx, emotion in enumerate(emotions):
        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
        pos_texts, neg_texts = [], []
        with open(path) as f:
            for line in f:
                ex = json.loads(line)
                if ex["polarity"] == "positive":
                    pos_texts.append(ex["text"])
                else:
                    neg_texts.append(ex["text"])
        print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
              f"{len(pos_texts)} pos / {len(neg_texts)} neg")

        pos_acts = _collect_activations(
            model, tokenizer, pos_texts, target_layers, device,
            args.batch_size, args.max_length,
        )
        neg_acts = _collect_activations(
            model, tokenizer, neg_texts, target_layers, device,
            args.batch_size, args.max_length,
        )

        # Difference of means per layer
        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
        neg_mean = neg_acts.mean(dim=0)
        diff = pos_mean - neg_mean

        # Normalize per layer so projections are scale-comparable
        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        diff = diff / norms

        vectors[e_idx] = diff
        del pos_acts, neg_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Save in AmygdalaConnector format.
    # emotion_names as padded uint8 tensor
    names_bytes = [e.encode("utf-8") for e in emotions]
    max_len = max(len(b) for b in names_bytes)
    padded = torch.tensor(
        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
        dtype=torch.uint8,
    )

    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    safetensors.torch.save_file(
        {
            "vectors": vectors.to(torch.float16),
            "emotion_names": padded,
            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
        },
        args.output,
    )
    print(f"\nWrote steering vectors to {args.output}: "
          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")


if __name__ == "__main__":
    main()
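
A quick sanity check of the artifact (a minimal sketch; the filename steering_vectors.safetensors is illustrative, but the keys are exactly what save_file() writes above):

    import safetensors.torch

    t = safetensors.torch.load_file("steering_vectors.safetensors")
    # vectors: [n_emotions, n_layers, hidden_dim] fp16; target_layers: int32
    # emotion_names is a zero-padded uint8 matrix; strip padding to decode.
    names = [bytes(row.tolist()).rstrip(b"\x00").decode("utf-8")
             for row in t["emotion_names"]]
    print(names, t["vectors"].shape, t["target_layers"].tolist())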