consciousness/training/amygdala_training/extract_training_pairs.py

213 lines
6.9 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Extract emotion-labeled training pairs from the PoC memory graph.
Input: a memory graph (via poc-memory CLI or direct sqlite access).
Output: a directory with one JSONL file per emotion:
    output_dir/
        warmth.jsonl
        clarity.jsonl
        recognition.jsonl
        ...
        _manifest.json  # enumerates emotions, example counts, and file paths
Each line of an emotion's JSONL is one labeled example:
    {"text": "...", "polarity": "positive"|"negative",
     "source_key": "<node_key>", "emotion": "<emotion_name>"}
Negative examples are sampled from nodes that DON'T mention the
emotion at all (not ones that mention it with a low score); the
natural contrast is "text with this emotional loading" vs. "text
without this emotional loading." Nodes that mention the emotion
with a score below the positive threshold are excluded from both
sides.
"""
import argparse
import json
import os
import random
import re
import subprocess
from collections import defaultdict
from typing import Iterator
# One emotion tag of the form `name:N`, as found in the trailing
# `warmth:9 clarity:10 …` lines the subconscious agents emit.
# The name is at least two lowercase ASCII letters, optionally with
# inner hyphens (e.g. `self-doubt`); N is a run of digits. The pattern
# itself does not enforce the 0..10 score range.
EMOTION_TAG_RE = re.compile(
    r"\b([a-z][a-z\-]*[a-z])"  # group 1: emotion name
    r":(\d+)\b"                # group 2: numeric score
)
def _run_poc_memory(args: list[str]) -> str:
    """Invoke the ``poc-memory`` CLI with *args* and return its stdout.

    The command runs directly (no shell), so every element of *args* is
    handed to the binary verbatim. stdout/stderr are captured as text;
    a non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    command = ["poc-memory"] + list(args)
    completed = subprocess.run(
        command, capture_output=True, text=True, check=True
    )
    return completed.stdout
def _iter_all_node_keys() -> Iterator[str]:
    """Lazily yield the key of every node in the memory graph.

    The ``poc-memory query`` call happens on first iteration; each
    non-blank output line is yielded with surrounding whitespace removed.
    """
    # NOTE(review): "|" reaches poc-memory as a literal argument (there is
    # no shell); presumably poc-memory parses its own pipeline syntax.
    listing = _run_poc_memory(["query", "*", "|", "select", "key"])
    for raw_line in listing.splitlines():
        key = raw_line.strip()
        if not key:
            continue
        yield key
def _fetch_node_content(key: str) -> str | None:
    """Return the output of ``poc-memory render <key>``.

    A non-zero exit from poc-memory is treated as "node unavailable":
    ``None`` is returned instead of propagating the error.
    """
    try:
        rendered = _run_poc_memory(["render", key])
    except subprocess.CalledProcessError:
        rendered = None
    return rendered
def _emotion_scores(content: str) -> dict[str, int]:
    """Collect ``name:N`` emotion tags from *content*.

    Every EMOTION_TAG_RE match anywhere in the text is considered, not
    only a trailing tag line. Scores outside 0..10 are ignored. When an
    emotion is tagged more than once, its highest score wins.

    Returns:
        Mapping of emotion name -> best score, in first-seen order.
    """
    best: dict[str, int] = {}
    for match in EMOTION_TAG_RE.finditer(content):
        name, digits = match.group(1), match.group(2)
        try:
            value = int(digits)
        except ValueError:
            # int() can still reject a regex-valid digit run, e.g. one
            # exceeding the interpreter's int/str conversion digit limit.
            continue
        if not 0 <= value <= 10:
            continue
        current = best.get(name, 0)
        best[name] = value if value > current else current
    return best
def _node_body(content: str, min_chars: int) -> str | None:
    """Return *content* with emotion tags removed, or None if too short.

    Every EMOTION_TAG_RE match is deleted so the model can't learn to
    read the label directly. A line that held nothing but tags (plus
    whitespace) is dropped entirely, newline included, so no blank
    placeholder line marks where the tags used to be. All other lines,
    including lines that were already blank, keep their exact text and
    line endings. No frontmatter or header stripping is performed.

    Args:
        content: Rendered node text.
        min_chars: Minimum length of the stripped body to be kept.

    Returns:
        The body with leading/trailing whitespace stripped, or None if it
        is shorter than *min_chars* characters.
    """
    kept: list[str] = []
    # keepends=True preserves original line endings for the lines we keep.
    # The tag pattern cannot span a line boundary, so substituting per
    # line is equivalent to substituting over the whole text.
    for line in content.splitlines(keepends=True):
        cleaned = EMOTION_TAG_RE.sub("", line)
        if cleaned != line and not cleaned.strip():
            # The line consisted only of tags: drop it completely.
            continue
        kept.append(cleaned)
    body = "".join(kept).strip()
    if len(body) < min_chars:
        return None
    return body
# (key, tag-stripped body, emotion -> best score) for one retained node.
_NodeRecord = tuple[str, str, dict[str, int]]


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options used by main()."""
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--output-dir", required=True)
    ap.add_argument(
        "--min-positive-score", type=int, default=8,
        help="Emotion score >= this counts as positive",
    )
    ap.add_argument(
        "--min-positive-examples", type=int, default=10,
        help="Only emit emotions with at least this many positive nodes",
    )
    ap.add_argument(
        "--min-content-chars", type=int, default=40,
        help="Skip nodes shorter than this after stripping tags",
    )
    ap.add_argument(
        "--max-examples-per-emotion", type=int, default=500,
        help="Cap examples per polarity for balanced training",
    )
    # Accepted so existing invocations keep working, but its value is not
    # read anywhere in this script.
    ap.add_argument(
        "--max-negative-pool-multiplier", type=float, default=5.0,
        help="Currently unused; accepted for backward compatibility",
    )
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()
    # A zero/negative cap would make the slicing below drop or wrap
    # examples in surprising ways, so reject it up front.
    if args.max_examples_per_emotion < 1:
        ap.error("--max-examples-per-emotion must be >= 1")
    return args


def _scan_nodes(min_content_chars: int) -> list[_NodeRecord]:
    """Pass 1: load every renderable, long-enough node from the graph.

    Nodes whose render fails, or whose tag-stripped body is shorter than
    *min_content_chars*, are skipped.
    """
    print("Pass 1/2: scanning memory graph...")
    nodes: list[_NodeRecord] = []
    for i, key in enumerate(_iter_all_node_keys()):
        if i % 500 == 0:
            print(f" {i} nodes scanned...")
        content = _fetch_node_content(key)
        if content is None:
            continue
        body = _node_body(content, min_content_chars)
        if body is None:
            continue
        nodes.append((key, body, _emotion_scores(content)))
    print(f" {len(nodes)} nodes retained after filters.")
    return nodes


def _select_emotions(
    nodes: list[_NodeRecord],
    min_positive_score: int,
    min_positive_examples: int,
) -> list[str]:
    """Return emotions with enough positive nodes, most frequent first."""
    emotion_counts: dict[str, int] = defaultdict(int)
    for _, _, scores in nodes:
        for name, s in scores.items():
            if s >= min_positive_score:
                emotion_counts[name] += 1
    # sorted() is stable, so emotions with equal counts keep first-seen order.
    emotions = sorted(
        (e for e, n in emotion_counts.items() if n >= min_positive_examples),
        key=lambda e: -emotion_counts[e],
    )
    print(
        f" {len(emotions)} emotions with >={min_positive_examples} "
        "positive examples."
    )
    return emotions


def _write_emotion_file(
    path: str,
    emotion: str,
    positives: list[tuple[str, str]],
    negatives: list[tuple[str, str]],
) -> None:
    """Write one JSONL record per (key, body) example, positives first."""
    with open(path, "w", encoding="utf-8") as f:
        for polarity, examples in (
            ("positive", positives),
            ("negative", negatives),
        ):
            for key, body in examples:
                f.write(json.dumps({
                    "text": body,
                    "polarity": polarity,
                    "source_key": key,
                    "emotion": emotion,
                }) + "\n")


def main() -> None:
    """Extract balanced per-emotion JSONL training files plus a manifest.

    Pass 1 renders every graph node and parses its emotion tags. Pass 2
    writes, for each emotion with enough positives, up to
    --max-examples-per-emotion positive nodes (score >=
    --min-positive-score) and an equal number of negative nodes (nodes
    that never mention the emotion). Sampling is seeded by --seed.
    """
    args = _parse_args()
    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    all_nodes = _scan_nodes(args.min_content_chars)
    emotions = _select_emotions(
        all_nodes, args.min_positive_score, args.min_positive_examples
    )

    print("Pass 2/2: assembling per-emotion pools...")
    manifest: dict[str, dict] = {}
    for emotion in emotions:
        positives = [
            (k, body) for k, body, s in all_nodes
            if s.get(emotion, 0) >= args.min_positive_score
        ]
        # Negative pool: nodes that don't mention this emotion at all, so
        # low-score mentions end up on neither side.
        negative_pool = [
            (k, body) for k, body, s in all_nodes if emotion not in s
        ]
        # Both pools are shuffled for every emotion, in this order, so the
        # sample for a given --seed is reproducible.
        random.shuffle(positives)
        random.shuffle(negative_pool)
        positives = positives[: args.max_examples_per_emotion]
        n_neg = min(
            len(positives),
            len(negative_pool),
            args.max_examples_per_emotion,
        )
        negatives = negative_pool[:n_neg]
        if not positives or not negatives:
            print(f" {emotion}: skipped (no positive or negative examples)")
            continue

        # Emotion names match [a-z-]+, so they are safe as file names.
        out_path = os.path.join(args.output_dir, f"{emotion}.jsonl")
        _write_emotion_file(out_path, emotion, positives, negatives)
        manifest[emotion] = {
            "n_positive": len(positives),
            "n_negative": len(negatives),
            "path": out_path,
        }
        print(f" {emotion}: {len(positives)} pos / {len(negatives)} neg")

    manifest_path = os.path.join(args.output_dir, "_manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump({
            "emotions": manifest,
            "source_nodes": len(all_nodes),
            "min_positive_score": args.min_positive_score,
            "min_positive_examples": args.min_positive_examples,
        }, f, indent=2)
    print(f"\nWrote {len(manifest)} emotion files to {args.output_dir}")
    print(f"Manifest: {manifest_path}")
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()