# SPDX-License-Identifier: Apache-2.0 """Train concept-readout vectors using the steering-vectors library. Alternative to train_steering_vectors.py that uses the pip-installable steering-vectors library (github.com/steering-vectors/steering-vectors) instead of our hand-rolled diff-of-means + subspace machinery. The library ships multiple aggregators out of the box: mean — pos_mean - neg_mean, unit-normed. Equivalent to our default 'pooled' method. pca — concatenates [pos-neg, neg-pos] and takes the top PC. Implicit denoising: direction of maximum variance in the paired deltas, less sensitive to per-pair noise than plain mean. logistic — trains a logistic-regression classifier on centered activations; concept direction is the weight vector. L1 penalty gives an explicit sparse vector (zeroes out noise coords); L2 shrinks low-magnitude coords. linear — same, with linear regression. Output is the same readout.safetensors + readout.json format the trainer and vLLM plugin already understand. """ from __future__ import annotations import argparse import json import random from pathlib import Path import safetensors.torch import torch from transformers import AutoModelForCausalLM, AutoTokenizer from steering_vectors import ( SteeringVectorTrainingSample, train_steering_vector, ) from steering_vectors.aggregators import ( mean_aggregator, pca_aggregator, logistic_aggregator, ) # Reuse corpus loader from the hand-rolled trainer. from training.amygdala_training.train_steering_vectors import _load_corpus def _load_direct_descriptions( direct_dir: Path, ) -> tuple[dict[str, list[str]], list[str]]: """Load first-person phenomenological descriptions from ``direct_dir``. Each ``{concept}.txt`` holds 1+ descriptions separated by blank lines. Files starting with ``_`` (e.g. ``_baseline.txt``) aren't concepts — their descriptions go into every concept's negative pool. Returns: (positives_by_concept, extra_baselines) """ positives: dict[str, list[str]] = {} baselines: list[str] = [] for f in sorted(direct_dir.glob("*.txt")): text = f.read_text() descs = [d.strip() for d in text.split("\n\n") if d.strip()] if f.stem.startswith("_"): baselines.extend(descs) else: positives[f.stem] = descs return positives, baselines def _chat_template_wrap(tokenizer, text: str) -> str: """Wrap raw text in a consistent chat template so positive/negative activations are in the same regime. Using one generic user prompt for both narrative stories and first-person direct descriptions: the prompt cancels in the pos-neg delta, so what remains is the assistant content.""" return tokenizer.apply_chat_template( [ {"role": "user", "content": "Say something."}, {"role": "assistant", "content": text}, ], tokenize=False, ) def _samples_for_concept( emotion: str, positives_by_emotion: dict[str, list[str]], baselines: list[str], *, max_negatives_per_positive: int = 3, seed: int = 0, wrap=None, ) -> list[SteeringVectorTrainingSample]: """Build paired (pos, neg) training samples for one concept. For each positive story of ``emotion``, pair it with up to ``max_negatives_per_positive`` randomly-sampled negatives drawn from: (a) other emotions' positive stories, (b) scenario baselines. ``wrap``, if given, is applied to both positive_str and negative_str (e.g. a chat-template wrapper). The library expects paired samples; we don't have true counterfactual pairs for all concepts, so we approximate with random cross-concept / baseline negatives. """ rng = random.Random(hash((emotion, seed)) & 0xFFFFFFFF) neg_pool: list[str] = list(baselines) for other, texts in positives_by_emotion.items(): if other == emotion: continue neg_pool.extend(texts) w = wrap if wrap is not None else (lambda s: s) samples: list[SteeringVectorTrainingSample] = [] for pos in positives_by_emotion[emotion]: if not neg_pool: continue picks = rng.sample(neg_pool, min(max_negatives_per_positive, len(neg_pool))) for neg in picks: samples.append( SteeringVectorTrainingSample( positive_str=w(pos), negative_str=w(neg), ) ) return samples def _fp32_wrap(inner): """Wrap an aggregator so activations are cast to fp32 first. torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU, and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors. """ def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor: return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32)) return wrapped def _pca_with_spectrum(spectrum_log: dict, concept_key: list[str]): """PCA aggregator that also records the eigenvalue spectrum of the pos-neg deltas under ``concept_key[0]`` in ``spectrum_log``. The key is passed by reference (a 1-element list) so we can rebind it per concept without recreating the aggregator closure.""" @torch.no_grad() def agg(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor: pos = pos_acts.to(torch.float32) neg = neg_acts.to(torch.float32) deltas = pos - neg # Uncentered PCA: concatenate deltas and -deltas (library convention). X = torch.cat([deltas, -deltas]) # Eigenvalues via SVD: sigma^2 are the variances along each PC. # torch.linalg.svd returns U, S, Vh where columns of Vh.T are PCs. _, s, vh = torch.linalg.svd(X, full_matrices=False) variances = (s ** 2) total = variances.sum().item() var_list = variances.tolist() first_pc_ratio = var_list[0] / total if total > 0 else 0.0 # Participation ratio over the FULL spectrum — includes noise tail. eff_dim_full = (total ** 2) / float((variances ** 2).sum().item() or 1.0) # Signal/noise split: find smallest k with cumulative variance ≥ 0.9, # then compute PR over just those top-k eigenvalues. If PCA denoising # is clean, eff_dim_signal should ≈ k_signal (the retained dims carry # roughly equal variance, with the noise tail dropped). cum = 0.0 k_signal = len(var_list) for i, v in enumerate(var_list): cum += v if cum / total >= 0.9: k_signal = i + 1 break top_vars = variances[:k_signal] top_total = top_vars.sum().item() eff_dim_signal = (top_total ** 2) / float((top_vars ** 2).sum().item() or 1.0) spectrum_log[concept_key[0]] = { "first_pc_ratio": round(first_pc_ratio, 4), "effective_dim_full": round(eff_dim_full, 3), "k_signal_at_90pct": k_signal, "effective_dim_signal": round(eff_dim_signal, 3), "top10_eigenvalues": [round(v, 4) for v in var_list[:10]], "total_variance": round(total, 4), } # Top-1 PC vec = vh[0] # Sign-flip so the direction aligns with most deltas (library convention). sign = torch.sign(torch.mean(deltas @ vec)) return sign * vec return agg def _aggregator_from_name(name: str): if name == "mean": return _fp32_wrap(mean_aggregator()) if name == "pca": return _fp32_wrap(pca_aggregator()) if name == "logistic": return _fp32_wrap(logistic_aggregator()) if name == "logistic_l1": return _fp32_wrap( logistic_aggregator( sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1} ) ) raise ValueError(f"unknown aggregator: {name}") def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True) ap.add_argument("--stories-dir", required=True) ap.add_argument("--paired-dir", default=None) ap.add_argument("--direct-dir", default=None, help="Optional: directory of {concept}.txt files with 1+ " "first-person descriptions separated by blank lines. " "Files starting with _ contribute to every concept's " "negative pool rather than being concepts themselves.") ap.add_argument("--chat-template", action="store_true", help="Wrap all text in assistant-role chat template. " "Recommended when --direct-dir is used.") ap.add_argument("--target-layers", required=True, help="Comma-separated layer indices") ap.add_argument("--output-dir", required=True) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=512) ap.add_argument("--device", default="cuda:0") ap.add_argument("--min-positives", type=int, default=1) ap.add_argument( "--aggregator", default="mean", choices=["mean", "pca", "logistic", "logistic_l1"], ) ap.add_argument("--max-negatives-per-positive", type=int, default=3) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[ args.dtype ] stories_dir = Path(args.stories_dir) paired_dir = Path(args.paired_dir) if args.paired_dir else None positives_by_emotion, baselines = _load_corpus(stories_dir, paired_dir) if args.direct_dir: direct_pos, direct_baselines = _load_direct_descriptions(Path(args.direct_dir)) for concept, descs in direct_pos.items(): positives_by_emotion.setdefault(concept, []).extend(descs) baselines.extend(direct_baselines) print( f"Loaded {len(direct_pos)} direct-description concepts " f"+ {len(direct_baselines)} baselines from {args.direct_dir}" ) emotions = sorted( e for e, ps in positives_by_emotion.items() if len(ps) >= args.min_positives ) if not emotions: raise RuntimeError( f"no emotions with >= {args.min_positives} positives in {stories_dir}" ) print( f"Training {len(emotions)} concepts via steering-vectors " f"aggregator={args.aggregator!r} on layers={target_layers}" ) print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True ) model.eval() text_config = ( model.config.get_text_config() if hasattr(model.config, "get_text_config") else model.config ) hidden_dim = getattr(text_config, "hidden_size", None) or getattr( text_config, "hidden_dim", None ) assert hidden_dim, "couldn't infer hidden_dim from model config" # Per-layer output: [n_concepts, hidden] per_layer_vectors = torch.zeros( (len(target_layers), len(emotions), hidden_dim), dtype=torch.float32 ) # Optional spectrum-logging aggregator (only for --aggregator pca). spectrum_log: dict = {} concept_key = [""] if args.aggregator == "pca": aggregator = _pca_with_spectrum(spectrum_log, concept_key) else: aggregator = _aggregator_from_name(args.aggregator) wrap = (lambda s: _chat_template_wrap(tokenizer, s)) if args.chat_template else None if args.chat_template: sample_text = wrap(positives_by_emotion[emotions[0]][0]) print(f"\nSample templated input:\n{sample_text[:400]!r}\n") for e_idx, emotion in enumerate(emotions): samples = _samples_for_concept( emotion, positives_by_emotion, baselines, max_negatives_per_positive=args.max_negatives_per_positive, wrap=wrap, ) if not samples: print(f" [{e_idx + 1}/{len(emotions)}] {emotion}: NO SAMPLES, skipping") continue concept_key[0] = emotion # tell the aggregator which concept is being trained sv = train_steering_vector( model, tokenizer, samples, layers=target_layers, aggregator=aggregator, batch_size=args.batch_size, show_progress=False, move_to_cpu=True, ) # sv.layer_activations is a dict {layer_idx: tensor[hidden]} for l_idx, layer in enumerate(target_layers): vec = sv.layer_activations.get(layer) if vec is None: print(f" WARN: no vector returned for layer {layer} on {emotion}") continue vec = vec.detach().to(torch.float32).cpu() vec = vec / vec.norm().clamp_min(1e-6) per_layer_vectors[l_idx, e_idx] = vec if e_idx < 5 or e_idx == len(emotions) - 1 or e_idx % 10 == 0: print( f" [{e_idx + 1}/{len(emotions)}] {emotion}: " f"n_samples={len(samples)} layers={target_layers}" ) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) tensors = { f"layer_{target_layers[l_idx]}.vectors": per_layer_vectors[l_idx].to( torch.float16 ) for l_idx in range(len(target_layers)) } safetensors.torch.save_file(tensors, str(output_dir / "readout.safetensors")) (output_dir / "readout.json").write_text( json.dumps( { "concepts": emotions, "layers": target_layers, "hidden_size": hidden_dim, "dtype": "float16", "aggregator": args.aggregator, }, indent=2, ) + "\n" ) if spectrum_log: (output_dir / "spectrum.json").write_text(json.dumps(spectrum_log, indent=2) + "\n") print("\n=== eigenvalue spectrum per concept ===") print( " concept first_pc k_90pct " "eff_dim_signal eff_dim_full (signal/k ratio)" ) items = sorted(spectrum_log.items(), key=lambda kv: -kv[1]["first_pc_ratio"]) for concept, stats in items: k = stats["k_signal_at_90pct"] eff_sig = stats["effective_dim_signal"] ratio = eff_sig / k if k else 0.0 print( f" {concept:22s} " f"{stats['first_pc_ratio']:>8.3f} " f"{k:>7d} " f"{eff_sig:>14.2f} " f"{stats['effective_dim_full']:>12.2f} " f"({ratio:.2f})" ) total_mb = sum(t.numel() * 2 for t in tensors.values()) / (1024 * 1024) print( f"\nWrote readout.safetensors + readout.json to {output_dir} " f"({len(emotions)} concepts x {len(target_layers)} layers, {total_mb:.1f} MiB)" ) if __name__ == "__main__": main()