training: move amygdala training scripts out of vllm plugin
The fynnsu-based vllm/plugins/amygdala/ scaffold was superseded by the
readout infrastructure landed as vllm commit d3e74edf8500
(vllm/model_executor/layers/readout.py +
vllm/v1/worker/readout_manager.py). The training code remained useful,
so it moved here rather than being deleted.

train_steering_vectors.py: CAA diff-of-means trainer that produces
the [n_concepts, hidden_size] per-layer projection matrices the
runner loads via VLLM_READOUT_VECTORS.

extract_training_pairs.py: memory graph -> JSONL converter using
per-emotion score thresholds from the subconscious agents' tag lines.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
ec7568c726
commit
34bd122590
4 changed files with 545 additions and 0 deletions
79
training/amygdala_training/README.md
Normal file
@@ -0,0 +1,79 @@
# Amygdala Readout Vector Training

Training pipeline that produces the safetensors file the vLLM
ReadoutManager loads at runtime (see
`vllm/vllm/v1/worker/readout_manager.py`). Produces per-hooked-layer
`[n_concepts, hidden_size]` projection matrices keyed as
`layer_<idx>.vectors` — the directions the runner projects residual
activations onto during each forward pass.

## Overview

Two scripts, run in sequence:

1. **`extract_training_pairs.py`** — turns the memory graph into a
   directory of (emotion, polarity, text) training examples.
   Positive examples are memory nodes where the emotion scored
   ≥ a threshold; negative examples are nodes where it's absent or
   low. Emotion tags come from the trailing `warmth:9 clarity:10 …`
   lines the subconscious agents emit.

2. **`train_steering_vectors.py`** — for each emotion, runs the
   target model over the positive and negative examples, captures
   residual-stream activations at the configured target layers, and
   computes `mean(positive) - mean(negative)` as the steering
   direction. Normalizes per-layer to unit length and saves the
   whole `[E, L, H]` matrix.

The output file is passed to vLLM via `VLLM_READOUT_VECTORS` together
with a `VLLM_READOUT_MANIFEST` JSON listing concepts and hooked layer
indices.
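A hedged sketch of that wiring, assuming the manifest is passed inline as JSON (it may equally be a file path) and assuming the field names `concepts` and `layers`, which are not specified here:

```python
import json
import os

# Hypothetical manifest shape; field names are assumptions, not taken
# from the ReadoutManager source.
manifest = {
    "concepts": ["warmth", "clarity", "recognition"],
    "layers": [3, 18, 33, 36],
}
os.environ["VLLM_READOUT_MANIFEST"] = json.dumps(manifest)
os.environ["VLLM_READOUT_VECTORS"] = "/path/to/amygdala_vectors.safetensors"
```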

## Method

This is Contrastive Activation Addition (CAA, Rimsky et al.) applied
to naturally-occurring emotion labels rather than hand-crafted
contrast pairs. The signal we're recovering is "what direction in the
residual stream corresponds to the model processing
text-with-emotion-E vs. text-without". Because our training data was
generated by the very model we're instrumenting (past-self's journal
entries, digest nodes, pattern nodes), the signal should be unusually
clean — the emotion labels and the text are already causally linked
through a single model's forward pass.
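The diff-of-means step itself is tiny. A toy sketch on synthetic activations, with shapes following the `[n_layers, hidden]` convention the trainer uses:

```python
import torch

torch.manual_seed(0)
n_pos, n_neg, n_layers, hidden = 32, 32, 2, 16

# Stand-ins for pooled residual-stream activations.
pos_acts = torch.randn(n_pos, n_layers, hidden) + 0.5   # emotion present
neg_acts = torch.randn(n_neg, n_layers, hidden)         # emotion absent

# Difference of means, then per-layer unit normalization.
direction = pos_acts.mean(dim=0) - neg_acts.mean(dim=0)  # [n_layers, hidden]
direction = direction / direction.norm(dim=-1, keepdim=True)

# Each layer's vector is unit length, so projections are scale-comparable.
assert torch.allclose(direction.norm(dim=-1), torch.ones(n_layers), atol=1e-5)
```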

## Usage (design — not yet runnable)

```
# Step 1: memory graph → training data
python -m training.amygdala_training.extract_training_pairs \
    --memory-mcp-url http://localhost:7777 \
    --output-dir /tmp/amygdala_training_data \
    --min-positive-score 8 \
    --max-negative-mentions 0 \
    --min-content-chars 40 \
    --max-examples-per-emotion 500

# Step 2: training data → steering vectors
python -m training.amygdala_training.train_steering_vectors \
    --model Qwen/Qwen3.5-27B \
    --training-data-dir /tmp/amygdala_training_data \
    --target-layers 3,18,33,36 \
    --output /path/to/amygdala_vectors.safetensors \
    --dtype bf16 \
    --batch-size 4
```

## Open questions

- **Emotion selection**: enumerating which ~200 emotions to cover.
  Could be "most-common tags in the graph" (data-driven) or "from
  core-personality / pattern nodes" (human-curated). Probably both.
- **Layer selection**: middle-to-late layers (~60–80% of depth)
  usually hold abstract semantic representations best; experiment
  with which layers give the cleanest linear separation per emotion.
- **Cross-talk**: if two emotions are highly co-occurring (warmth +
  love, frustration + tiredness), their vectors will be close; that's
  fine as long as we don't pretend they're independent axes.
- **Generalization**: vectors trained on our memory graph may not
  generalize to out-of-distribution text. Check by applying them to
  held-out conversation data and eyeballing the projections.
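Two of these checks are cheap to sketch on synthetic data: the candidate-layer window, and cross-talk measured as pairwise cosine similarity between trained vectors (the vector values here are random placeholders, not trained directions):

```python
import torch

# Candidate layers at ~60-80% of model depth; depth is illustrative.
n_model_layers = 48
candidates = [i for i in range(n_model_layers)
              if 0.6 <= i / n_model_layers <= 0.8]

# Cross-talk check: pairwise cosine similarity between unit-normalized
# emotion vectors for one layer ([n_emotions, hidden]).
torch.manual_seed(0)
vectors = torch.randn(5, 64)
vectors = vectors / vectors.norm(dim=-1, keepdim=True)
cosine = vectors @ vectors.T             # [n_emotions, n_emotions]
# Large off-diagonal entries flag highly co-occurring emotion pairs.
```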
6
training/amygdala_training/__init__.py
Normal file
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Training utilities for amygdala steering vectors.

See README.md in this directory for overall design.
"""
212
training/amygdala_training/extract_training_pairs.py
Normal file
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Extract emotion-labeled training pairs from the PoC memory graph.

Input: a memory graph (via poc-memory CLI or direct sqlite access).
Output: a directory with one JSONL file per emotion:

    output_dir/
        warmth.jsonl
        clarity.jsonl
        recognition.jsonl
        ...
        _manifest.json   # enumerates emotions + counts

Each line of an emotion's JSONL is one labeled example:

    {"text": "...", "polarity": "positive"|"negative",
     "source_key": "<node_key>", "emotion": "warmth"}

Negative examples are sampled from nodes that DON'T mention the
emotion at all (not ones that mention it with a low score) — the
natural contrast is "text with this emotional loading" vs. "text
without this emotional loading." Low-score nodes are excluded
from both sides.
"""

import argparse
import json
import os
import random
import re
import subprocess
from collections import defaultdict
from typing import Iterator


# Emotion tag format: `word:N` where N is 0..10. Matches the trailing
# `warmth:9 clarity:10 …` lines the subconscious agents emit.
EMOTION_TAG_RE = re.compile(r"\b([a-z][a-z\-]*[a-z]):(\d+)\b")


def _run_poc_memory(args: list[str]) -> str:
    """Run `poc-memory` and return stdout."""
    result = subprocess.run(
        ["poc-memory", *args],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout


def _iter_all_node_keys() -> Iterator[str]:
    """Yield every node key in the graph."""
    # The "|" is passed as a literal argument (no shell is involved);
    # it is presumably part of poc-memory's own query pipeline syntax.
    out = _run_poc_memory(["query", "*", "|", "select", "key"])
    for line in out.splitlines():
        line = line.strip()
        if line:
            yield line


def _fetch_node_content(key: str) -> str | None:
    """Load a node's rendered content, or None if unavailable."""
    try:
        return _run_poc_memory(["render", key])
    except subprocess.CalledProcessError:
        return None


def _emotion_scores(content: str) -> dict[str, int]:
    """Parse trailing `warmth:9 clarity:10 …` style tags.

    Returns the highest score seen for each emotion — multiple
    tag lines in one node get max'd.
    """
    out: dict[str, int] = {}
    for name, score in EMOTION_TAG_RE.findall(content):
        try:
            s = int(score)
        except ValueError:
            continue
        if 0 <= s <= 10:
            out[name] = max(out.get(name, 0), s)
    return out


def _node_body(content: str, min_chars: int) -> str | None:
    """Strip emotion-tag lines and return the body text for training."""
    # Drop the emotion-tag lines themselves so the model doesn't
    # learn to read the label directly.
    stripped = EMOTION_TAG_RE.sub("", content)
    stripped = stripped.strip()
    if len(stripped) < min_chars:
        return None
    return stripped


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--output-dir", required=True)
    ap.add_argument(
        "--min-positive-score", type=int, default=8,
        help="Emotion score >= this counts as positive",
    )
    ap.add_argument(
        "--min-content-chars", type=int, default=40,
        help="Skip nodes shorter than this after stripping tags",
    )
    ap.add_argument(
        "--max-examples-per-emotion", type=int, default=500,
        help="Cap examples per polarity for balanced training",
    )
    ap.add_argument(
        "--max-negative-pool-multiplier", type=float, default=5.0,
        help="How many negative candidates to consider per positive",
    )
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # First pass: collect every node's (key, body, emotion_scores).
    print("Pass 1/2: scanning memory graph...")
    all_nodes: list[tuple[str, str, dict[str, int]]] = []
    for i, key in enumerate(_iter_all_node_keys()):
        if i % 500 == 0:
            print(f" {i} nodes scanned...")
        content = _fetch_node_content(key)
        if content is None:
            continue
        scores = _emotion_scores(content)
        body = _node_body(content, args.min_content_chars)
        if body is None:
            continue
        all_nodes.append((key, body, scores))
    print(f" {len(all_nodes)} nodes retained after filters.")

    # Which emotions have enough positive examples to be worth training?
    emotion_counts: dict[str, int] = defaultdict(int)
    for _, _, scores in all_nodes:
        for name, s in scores.items():
            if s >= args.min_positive_score:
                emotion_counts[name] += 1
    emotions = sorted(
        (e for e, n in emotion_counts.items() if n >= 10),
        key=lambda e: -emotion_counts[e],
    )
    print(f" {len(emotions)} emotions with >=10 positive examples.")

    # Second pass: per emotion, build positive + negative pools.
    print("Pass 2/2: assembling per-emotion pools...")
    manifest: dict[str, dict] = {}
    for emotion in emotions:
        positives = [
            (k, body) for k, body, s in all_nodes
            if s.get(emotion, 0) >= args.min_positive_score
        ]
        # Negative pool: nodes that don't mention this emotion at all.
        negative_pool = [
            (k, body) for k, body, s in all_nodes if emotion not in s
        ]
        random.shuffle(positives)
        random.shuffle(negative_pool)
        positives = positives[: args.max_examples_per_emotion]
        # Consider at most multiplier * len(positives) negative candidates.
        pool_cap = int(len(positives) * args.max_negative_pool_multiplier)
        negative_pool = negative_pool[:pool_cap]
        n_neg = min(
            len(positives),
            len(negative_pool),
            int(args.max_examples_per_emotion),
        )
        negatives = negative_pool[:n_neg]

        if not positives or not negatives:
            continue

        out_path = os.path.join(args.output_dir, f"{emotion}.jsonl")
        with open(out_path, "w") as f:
            for key, body in positives:
                f.write(json.dumps({
                    "text": body,
                    "polarity": "positive",
                    "source_key": key,
                    "emotion": emotion,
                }) + "\n")
            for key, body in negatives:
                f.write(json.dumps({
                    "text": body,
                    "polarity": "negative",
                    "source_key": key,
                    "emotion": emotion,
                }) + "\n")
        manifest[emotion] = {
            "n_positive": len(positives),
            "n_negative": len(negatives),
            "path": out_path,
        }
        print(f" {emotion}: {len(positives)} pos / {len(negatives)} neg")

    with open(
        os.path.join(args.output_dir, "_manifest.json"), "w"
    ) as f:
        json.dump({
            "emotions": manifest,
            "source_nodes": len(all_nodes),
            "min_positive_score": args.min_positive_score,
        }, f, indent=2)

    print(f"\nWrote {len(manifest)} emotion files to {args.output_dir}")
    print(f"Manifest: {os.path.join(args.output_dir, '_manifest.json')}")


if __name__ == "__main__":
    main()
248
training/amygdala_training/train_steering_vectors.py
Normal file
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Train amygdala steering vectors via Contrastive Activation Addition.

Reads the per-emotion JSONL files produced by extract_training_pairs.py,
runs the target model over each example, captures the residual-stream
hidden state at the configured target layers, and computes
`mean(positive) - mean(negative)` as the steering direction per layer
per emotion.

Output: a safetensors file matching the format AmygdalaConnector
expects:

    vectors:       [n_emotions, n_target_layers, hidden_dim] fp16
    emotion_names: [n_emotions, max_name_len] uint8 (NUL-padded)
    target_layers: [n_target_layers] int32

Pooling: last-token residual-stream hidden state per example (the CAA
convention — the final token has seen the whole context and is where
the model's "decision" lives). An alternative is mean-pooling across
all tokens; the last-token convention is more common in steering
vector work.
"""

import argparse
import gc
import json
import os
from pathlib import Path

import safetensors.torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _pool_last(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the last non-pad token's hidden state per example.

    hidden:         [batch, seq, hidden_dim]
    attention_mask: [batch, seq]
    returns:        [batch, hidden_dim]

    Assumes right padding (the usual tokenizer default), so the last
    non-pad index is the number of real tokens minus one.
    """
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]


def _collect_activations(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
) -> torch.Tensor:
    """Run texts through the model, capture residual stream at target
    layers, return [n_texts, n_target_layers, hidden_dim] fp32 on CPU.
    """
    # Register hooks on the target layers' outputs. We want the
    # residual stream AFTER each layer, which is the output of the
    # transformer block (hidden_states[layer_idx+1] in HF land).
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx):
        def hook(_mod, _inp, output):
            # output is typically (hidden_states, ...) — take the first
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    handles = []
    # Transformers' LlamaModel.layers is a ModuleList; Qwen3.5's
    # language_model.model.layers follows the same convention.
    # Resolve the layer list by walking common paths.
    layers_module = _find_layers_module(model)
    for idx in target_layers:
        handles.append(
            layers_module[idx].register_forward_hook(make_hook(idx))
        )

    out_rows: list[torch.Tensor] = []
    try:
        model.eval()
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = []
                for idx in target_layers:
                    hs = captures[idx]  # [batch, seq, hidden]
                    pooled = _pool_last(hs, tok["attention_mask"])
                    per_layer.append(pooled.to(torch.float32).cpu())
                # Stack to [batch, n_layers, hidden_dim]
                batched = torch.stack(per_layer, dim=1)
                out_rows.append(batched)

                # Drop GPU references. Clear `captures` rather than
                # deleting it — the hooks close over the dict, so the
                # name must stay bound.
                captures.clear()
                del tok
                if (i // batch_size) % 10 == 0:
                    torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0)  # [n_texts, n_layers, hidden]


def _find_layers_module(model) -> torch.nn.ModuleList:
    """Walk a few likely paths to find the transformer-block list."""
    candidates = [
        "model.layers",
        "model.model.layers",
        "model.language_model.layers",
        "model.language_model.model.layers",
        "language_model.model.layers",
        "transformer.h",
    ]
    for path in candidates:
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, torch.nn.ModuleList):
            return obj
    raise RuntimeError(
        f"Couldn't find transformer layer list. Tried: {candidates}"
    )


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model", required=True, help="HF model id or path")
    ap.add_argument("--training-data-dir", required=True)
    ap.add_argument(
        "--target-layers", required=True,
        help="Comma-separated layer indices, e.g. 3,18,33,36",
    )
    ap.add_argument("--output", required=True)
    ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=512)
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()

    target_layers = [int(x) for x in args.target_layers.split(",")]
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.dtype
    ]

    print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map=args.device,
        low_cpu_mem_usage=True,
    )
    hidden_dim = model.config.hidden_size
    print(f"Model loaded. hidden_dim={hidden_dim}, "
          f"n_layers={model.config.num_hidden_layers}")

    manifest_path = Path(args.training_data_dir) / "_manifest.json"
    manifest = json.loads(manifest_path.read_text())

    emotions = sorted(manifest["emotions"].keys())
    print(f"Training {len(emotions)} emotions: {emotions}")

    n_emotions = len(emotions)
    n_layers = len(target_layers)
    vectors = torch.zeros(
        (n_emotions, n_layers, hidden_dim), dtype=torch.float32
    )
    device = torch.device(args.device)

    for e_idx, emotion in enumerate(emotions):
        path = Path(args.training_data_dir) / f"{emotion}.jsonl"
        pos_texts, neg_texts = [], []
        with open(path) as f:
            for line in f:
                ex = json.loads(line)
                if ex["polarity"] == "positive":
                    pos_texts.append(ex["text"])
                else:
                    neg_texts.append(ex["text"])
        print(f"[{e_idx+1}/{n_emotions}] {emotion}: "
              f"{len(pos_texts)} pos / {len(neg_texts)} neg")

        pos_acts = _collect_activations(
            model, tokenizer, pos_texts, target_layers, device,
            args.batch_size, args.max_length,
        )
        neg_acts = _collect_activations(
            model, tokenizer, neg_texts, target_layers, device,
            args.batch_size, args.max_length,
        )

        # Difference of means per layer
        pos_mean = pos_acts.mean(dim=0)  # [n_layers, hidden]
        neg_mean = neg_acts.mean(dim=0)
        diff = pos_mean - neg_mean

        # Normalize per layer so projections are scale-comparable
        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        diff = diff / norms

        vectors[e_idx] = diff
        del pos_acts, neg_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Save in AmygdalaConnector format.
    # emotion_names as a NUL-padded uint8 tensor
    names_bytes = [e.encode("utf-8") for e in emotions]
    max_len = max(len(b) for b in names_bytes)
    padded = torch.tensor(
        [list(b.ljust(max_len, b"\x00")) for b in names_bytes],
        dtype=torch.uint8,
    )

    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    safetensors.torch.save_file(
        {
            "vectors": vectors.to(torch.float16),
            "emotion_names": padded,
            "target_layers": torch.tensor(target_layers, dtype=torch.int32),
        },
        args.output,
    )
    print(f"\nWrote steering vectors to {args.output}: "
          f"{n_emotions} emotions x {n_layers} layers x {hidden_dim} dim (fp16)")


if __name__ == "__main__":
    main()