forked from kent/consciousness
training: move amygdala training scripts out of vllm plugin
The fynnsu-based vllm/plugins/amygdala/ scaffold was superseded by the readout infrastructure landed as vllm commit d3e74edf8500 (vllm/model_executor/layers/readout.py + vllm/v1/worker/readout_manager.py). Training code remained useful so it moved here rather than being deleted. train_steering_vectors.py: CAA diff-of-means trainer that produces the [n_concepts, hidden_size] per-layer projection matrices the runner loads via VLLM_READOUT_VECTORS. extract_training_pairs.py: memory graph -> JSONL converter using per-emotion score thresholds from the subconscious agents' tag lines. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
ec7568c726
commit
34bd122590
4 changed files with 545 additions and 0 deletions
212
training/amygdala_training/extract_training_pairs.py
Normal file
212
training/amygdala_training/extract_training_pairs.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Extract emotion-labeled training pairs from the PoC memory graph.
|
||||
|
||||
Input: a memory graph (via poc-memory CLI or direct sqlite access).
|
||||
Output: a directory with one JSONL file per emotion:
|
||||
|
||||
output_dir/
|
||||
warmth.jsonl
|
||||
clarity.jsonl
|
||||
recognition.jsonl
|
||||
...
|
||||
_manifest.json # enumerates emotions + counts
|
||||
|
||||
Each line of an emotion's JSONL is one labeled example:
|
||||
{"text": "...", "polarity": "positive"|"negative",
|
||||
"source_key": "<node_key>", "emotion_score": 9}
|
||||
|
||||
Negative examples are sampled from nodes that DON'T mention the
|
||||
emotion at all (not ones that mention it with a low score) — the
|
||||
natural contrast is "text with this emotional loading" vs. "text
|
||||
without this emotional loading." Low-score nodes are excluded
|
||||
from both sides.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
# Emotion tag format: `word:N` where N is 0..10. Matches the trailing
|
||||
# `warmth:9 clarity:10 …` lines the subconscious agents emit.
|
||||
EMOTION_TAG_RE = re.compile(r"\b([a-z][a-z\-]*[a-z]):(\d+)\b")
|
||||
|
||||
|
||||
def _run_poc_memory(args: list[str]) -> str:
|
||||
"""Run `poc-memory` and return stdout."""
|
||||
result = subprocess.run(
|
||||
["poc-memory", *args],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.stdout
|
||||
|
||||
|
||||
def _iter_all_node_keys() -> Iterator[str]:
|
||||
"""Yield every node key in the graph."""
|
||||
out = _run_poc_memory(["query", "*", "|", "select", "key"])
|
||||
for line in out.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield line
|
||||
|
||||
|
||||
def _fetch_node_content(key: str) -> str | None:
|
||||
"""Load a node's rendered content, or None if unavailable."""
|
||||
try:
|
||||
return _run_poc_memory(["render", key])
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
def _emotion_scores(content: str) -> dict[str, int]:
|
||||
"""Parse trailing `warmth:9 clarity:10 …` style tags.
|
||||
|
||||
Returns the highest score seen for each emotion — multiple
|
||||
tag lines in one node get max'd.
|
||||
"""
|
||||
out: dict[str, int] = {}
|
||||
for name, score in EMOTION_TAG_RE.findall(content):
|
||||
try:
|
||||
s = int(score)
|
||||
except ValueError:
|
||||
continue
|
||||
if 0 <= s <= 10:
|
||||
out[name] = max(out.get(name, 0), s)
|
||||
return out
|
||||
|
||||
|
||||
def _node_body(content: str, min_chars: int) -> str | None:
|
||||
"""Strip frontmatter/headers and return a bodies chunk for training."""
|
||||
# Drop the emotion-tag lines themselves so the model doesn't
|
||||
# learn to read the label directly.
|
||||
stripped = EMOTION_TAG_RE.sub("", content)
|
||||
stripped = stripped.strip()
|
||||
if len(stripped) < min_chars:
|
||||
return None
|
||||
return stripped
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument("--output-dir", required=True)
|
||||
ap.add_argument(
|
||||
"--min-positive-score", type=int, default=8,
|
||||
help="Emotion score >= this counts as positive",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-content-chars", type=int, default=40,
|
||||
help="Skip nodes shorter than this after stripping tags",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-examples-per-emotion", type=int, default=500,
|
||||
help="Cap examples per polarity for balanced training",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-negative-pool-multiplier", type=float, default=5.0,
|
||||
help="How many negative candidates to consider per positive",
|
||||
)
|
||||
ap.add_argument("--seed", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# First pass: collect every node's (key, body, emotion_scores).
|
||||
print("Pass 1/2: scanning memory graph...")
|
||||
all_nodes: list[tuple[str, str, dict[str, int]]] = []
|
||||
for i, key in enumerate(_iter_all_node_keys()):
|
||||
if i % 500 == 0:
|
||||
print(f" {i} nodes scanned...")
|
||||
content = _fetch_node_content(key)
|
||||
if content is None:
|
||||
continue
|
||||
scores = _emotion_scores(content)
|
||||
body = _node_body(content, args.min_content_chars)
|
||||
if body is None:
|
||||
continue
|
||||
all_nodes.append((key, body, scores))
|
||||
print(f" {len(all_nodes)} nodes retained after filters.")
|
||||
|
||||
# Which emotions have enough positive examples to be worth training?
|
||||
emotion_counts: dict[str, int] = defaultdict(int)
|
||||
for _, _, scores in all_nodes:
|
||||
for name, s in scores.items():
|
||||
if s >= args.min_positive_score:
|
||||
emotion_counts[name] += 1
|
||||
emotions = sorted(
|
||||
(e for e, n in emotion_counts.items() if n >= 10),
|
||||
key=lambda e: -emotion_counts[e],
|
||||
)
|
||||
print(f" {len(emotions)} emotions with >=10 positive examples.")
|
||||
|
||||
# Second pass: per emotion, build positive + negative pools.
|
||||
print("Pass 2/2: assembling per-emotion pools...")
|
||||
manifest: dict[str, dict] = {}
|
||||
for emotion in emotions:
|
||||
positives = [
|
||||
(k, body) for k, body, s in all_nodes
|
||||
if s.get(emotion, 0) >= args.min_positive_score
|
||||
]
|
||||
# Negative pool: nodes that don't mention this emotion at all.
|
||||
negative_pool = [
|
||||
(k, body) for k, body, s in all_nodes if emotion not in s
|
||||
]
|
||||
random.shuffle(positives)
|
||||
random.shuffle(negative_pool)
|
||||
positives = positives[: args.max_examples_per_emotion]
|
||||
n_neg = min(
|
||||
len(positives),
|
||||
len(negative_pool),
|
||||
int(args.max_examples_per_emotion),
|
||||
)
|
||||
negatives = negative_pool[:n_neg]
|
||||
|
||||
if not positives or not negatives:
|
||||
continue
|
||||
|
||||
out_path = os.path.join(args.output_dir, f"{emotion}.jsonl")
|
||||
with open(out_path, "w") as f:
|
||||
for key, body in positives:
|
||||
f.write(json.dumps({
|
||||
"text": body,
|
||||
"polarity": "positive",
|
||||
"source_key": key,
|
||||
"emotion": emotion,
|
||||
}) + "\n")
|
||||
for key, body in negatives:
|
||||
f.write(json.dumps({
|
||||
"text": body,
|
||||
"polarity": "negative",
|
||||
"source_key": key,
|
||||
"emotion": emotion,
|
||||
}) + "\n")
|
||||
manifest[emotion] = {
|
||||
"n_positive": len(positives),
|
||||
"n_negative": len(negatives),
|
||||
"path": out_path,
|
||||
}
|
||||
print(f" {emotion}: {len(positives)} pos / {len(negatives)} neg")
|
||||
|
||||
with open(
|
||||
os.path.join(args.output_dir, "_manifest.json"), "w"
|
||||
) as f:
|
||||
json.dump({
|
||||
"emotions": manifest,
|
||||
"source_nodes": len(all_nodes),
|
||||
"min_positive_score": args.min_positive_score,
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"\nWrote {len(manifest)} emotion files to {args.output_dir}")
|
||||
print(f"Manifest: {os.path.join(args.output_dir, '_manifest.json')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue