diff --git a/training/amygdala_training/train_with_library.py b/training/amygdala_training/train_with_library.py new file mode 100644 index 0000000..a349310 --- /dev/null +++ b/training/amygdala_training/train_with_library.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Train concept-readout vectors using the steering-vectors library. + +Alternative to train_steering_vectors.py that uses the pip-installable +steering-vectors library (github.com/steering-vectors/steering-vectors) +instead of our hand-rolled diff-of-means + subspace machinery. The +library ships multiple aggregators out of the box: + + mean — pos_mean - neg_mean, unit-normed. Equivalent to our + default 'pooled' method. + pca — concatenates [pos-neg, neg-pos] and takes the top PC. + Implicit denoising: direction of maximum variance in the + paired deltas, less sensitive to per-pair noise than + plain mean. + logistic — trains a logistic-regression classifier on centered + activations; concept direction is the weight vector. + L1 penalty gives an explicit sparse vector (zeroes out + noise coords); L2 shrinks low-magnitude coords. + linear — same, with linear regression. + +Output is the same readout.safetensors + readout.json format the +trainer and vLLM plugin already understand. +""" + +from __future__ import annotations + +import argparse +import json +import random +from pathlib import Path + +import safetensors.torch +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from steering_vectors import ( + SteeringVectorTrainingSample, + train_steering_vector, +) +from steering_vectors.aggregators import ( + mean_aggregator, + pca_aggregator, + logistic_aggregator, + linear_aggregator, +) + +# Reuse corpus loader from the hand-rolled trainer. +from training.amygdala_training.train_steering_vectors import _load_corpus + + +def _samples_for_concept( + emotion: str, + positives_by_emotion: dict[str, list[str]], + baselines: list[str], + *, + max_negatives_per_positive: int = 3, + seed: int = 0, +) -> list[SteeringVectorTrainingSample]: + """Build paired (pos, neg) training samples for one concept. + + For each positive story of ``emotion``, pair it with up to + ``max_negatives_per_positive`` randomly-sampled negatives drawn + from: (a) other emotions' positive stories, (b) scenario baselines. + + The library expects paired samples; we don't have true + counterfactual pairs for all concepts, so we approximate with + random cross-concept / baseline negatives. + """ + rng = random.Random(hash((emotion, seed)) & 0xFFFFFFFF) + neg_pool: list[str] = list(baselines) + for other, texts in positives_by_emotion.items(): + if other == emotion: + continue + neg_pool.extend(texts) + + samples: list[SteeringVectorTrainingSample] = [] + for pos in positives_by_emotion[emotion]: + if not neg_pool: + continue + picks = rng.sample(neg_pool, min(max_negatives_per_positive, len(neg_pool))) + for neg in picks: + samples.append( + SteeringVectorTrainingSample(positive_str=pos, negative_str=neg) + ) + return samples + + +def _aggregator_from_name(name: str): + if name == "mean": + return mean_aggregator() + if name == "pca": + return pca_aggregator() + if name == "logistic": + return logistic_aggregator() + if name == "logistic_l1": + return logistic_aggregator( + sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1} + ) + if name == "linear": + return linear_aggregator() + raise ValueError(f"unknown aggregator: {name}") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--model", required=True) + ap.add_argument("--stories-dir", required=True) + ap.add_argument("--paired-dir", default=None) + ap.add_argument("--target-layers", required=True, help="Comma-separated layer indices") + ap.add_argument("--output-dir", required=True) + ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) + ap.add_argument("--batch-size", type=int, default=2) + ap.add_argument("--max-length", type=int, default=512) + ap.add_argument("--device", default="cuda:0") + ap.add_argument("--min-positives", type=int, default=1) + ap.add_argument( + "--aggregator", + default="mean", + choices=["mean", "pca", "logistic", "logistic_l1", "linear"], + ) + ap.add_argument("--max-negatives-per-positive", type=int, default=3) + args = ap.parse_args() + + target_layers = [int(x) for x in args.target_layers.split(",")] + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[ + args.dtype + ] + + stories_dir = Path(args.stories_dir) + paired_dir = Path(args.paired_dir) if args.paired_dir else None + positives_by_emotion, baselines = _load_corpus(stories_dir, paired_dir) + + emotions = sorted( + e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives + ) + if not emotions: + raise RuntimeError( + f"no emotions with >= {args.min_positives} positives in {stories_dir}" + ) + + print( + f"Training {len(emotions)} concepts via steering-vectors " + f"aggregator={args.aggregator!r} on layers={target_layers}" + ) + + print(f"Loading {args.model} ({args.dtype}) on {args.device}...") + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True + ) + model.eval() + + text_config = ( + model.config.get_text_config() + if hasattr(model.config, "get_text_config") + else model.config + ) + hidden_dim = getattr(text_config, "hidden_size", None) or getattr( + text_config, "hidden_dim", None + ) + assert hidden_dim, "couldn't infer hidden_dim from model config" + + # Per-layer output: [n_concepts, hidden] + per_layer_vectors = torch.zeros( + (len(target_layers), len(emotions), hidden_dim), dtype=torch.float32 + ) + + aggregator = _aggregator_from_name(args.aggregator) + + for e_idx, emotion in enumerate(emotions): + samples = _samples_for_concept( + emotion, + positives_by_emotion, + baselines, + max_negatives_per_positive=args.max_negatives_per_positive, + ) + if not samples: + print(f" [{e_idx + 1}/{len(emotions)}] {emotion}: NO SAMPLES, skipping") + continue + + sv = train_steering_vector( + model, + tokenizer, + samples, + layers=target_layers, + aggregator=aggregator, + batch_size=args.batch_size, + show_progress=False, + ) + # sv.layer_activations is a dict {layer_idx: tensor[hidden]} + for l_idx, layer in enumerate(target_layers): + vec = sv.layer_activations.get(layer) + if vec is None: + print(f" WARN: no vector returned for layer {layer} on {emotion}") + continue + vec = vec.detach().to(torch.float32).cpu() + vec = vec / vec.norm().clamp_min(1e-6) + per_layer_vectors[l_idx, e_idx] = vec + + if e_idx < 5 or e_idx == len(emotions) - 1 or e_idx % 10 == 0: + print( + f" [{e_idx + 1}/{len(emotions)}] {emotion}: " + f"n_samples={len(samples)} layers={target_layers}" + ) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + tensors = { + f"layer_{target_layers[l_idx]}.vectors": per_layer_vectors[l_idx].to( + torch.float16 + ) + for l_idx in range(len(target_layers)) + } + safetensors.torch.save_file(tensors, str(output_dir / "readout.safetensors")) + (output_dir / "readout.json").write_text( + json.dumps( + { + "concepts": emotions, + "layers": target_layers, + "hidden_size": hidden_dim, + "dtype": "float16", + "aggregator": args.aggregator, + }, + indent=2, + ) + + "\n" + ) + + total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024) + print( + f"\nWrote readout.safetensors + readout.json to {output_dir} " + f"({len(emotions)} concepts x {len(target_layers)} layers, {total_mb:.1f} MiB)" + ) + + +if __name__ == "__main__": + main()