# SPDX-License-Identifier: Apache-2.0 """Train concept-readout vectors from direct phenomenological descriptions. Alternative to story-based training (train_with_library.py). Each concept has a handful of 2-3 sentence first-person descriptions of the form "I feel X. [phenomenological detail]". The emotion word is the anchor; the description is the internal texture. Text is wrapped in the assistant-role chat template before being fed to the model, so we're training on "model-producing-this-utterance" hidden states — closer to the inhabited-state representation we want for readout. This avoids the scenario-contamination problem we saw with narrative stories: when concept X's training data all share "on a couch" setup features, PCA finds the couch-direction as the concept direction. """ from __future__ import annotations import argparse import json import random from pathlib import Path import safetensors.torch import torch from transformers import AutoModelForCausalLM, AutoTokenizer from steering_vectors import ( SteeringVectorTrainingSample, train_steering_vector, ) from steering_vectors.aggregators import pca_aggregator def _load_descriptions(direct_dir: Path) -> dict[str, list[str]]: """Each file in direct_dir is `{concept}.txt`. Descriptions are separated by blank lines within the file. Files starting with `_` are not concepts but are included in negative pools (e.g. _baseline.txt).""" out: dict[str, list[str]] = {} for f in sorted(direct_dir.glob("*.txt")): concept = f.stem # underscore-prefixed names keep their prefix text = f.read_text() descs = [d.strip() for d in text.split("\n\n") if d.strip()] out[concept] = descs return out def _fp32_wrap(inner): def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor: return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32)) return wrapped def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--model", required=True) ap.add_argument("--direct-dir", required=True) ap.add_argument("--target-layers", required=True) ap.add_argument("--output-dir", required=True) ap.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) ap.add_argument("--batch-size", type=int, default=2) ap.add_argument("--max-length", type=int, default=256) ap.add_argument("--device", default="cuda:0") ap.add_argument("--max-negatives-per-positive", type=int, default=20) args = ap.parse_args() target_layers = [int(x) for x in args.target_layers.split(",")] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype] all_descriptions = _load_descriptions(Path(args.direct_dir)) # Files starting with `_` are neg-pool helpers (e.g. _baseline.txt), not concepts. concepts = sorted(k for k in all_descriptions if not k.startswith("_")) neg_pool_extra: list[str] = [] for k, ds in all_descriptions.items(): if k.startswith("_"): neg_pool_extra.extend(ds) descriptions = {k: all_descriptions[k] for k in concepts} print(f"Loaded {len(concepts)} concepts with direct descriptions:") for c in concepts: print(f" {c}: {len(descriptions[c])} descriptions") if neg_pool_extra: print(f"Plus {len(neg_pool_extra)} neutral/baseline descriptions added to every concept's negative pool") print(f"\nLoading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, device_map=args.device, low_cpu_mem_usage=True ) model.eval() def apply_template(text: str) -> str: return tokenizer.apply_chat_template( [ {"role": "user", "content": "How do you feel right now?"}, {"role": "assistant", "content": text}, ], tokenize=False, ) text_config = ( model.config.get_text_config() if hasattr(model.config, "get_text_config") else model.config ) hidden_dim = getattr(text_config, "hidden_size", None) or getattr(text_config, "hidden_dim", None) assert hidden_dim, "couldn't infer hidden_dim from model config" per_layer_vectors = torch.zeros( (len(target_layers), len(concepts), hidden_dim), dtype=torch.float32 ) aggregator = _fp32_wrap(pca_aggregator()) # Preview a templated sample so we can eyeball what the model is seeing. sample_text = apply_template(descriptions[concepts[0]][0]) print(f"\nSample templated input (truncated):\n{sample_text[:400]!r}\n") for c_idx, concept in enumerate(concepts): pos_descs = descriptions[concept] neg_pool: list[str] = [] for other, other_descs in descriptions.items(): if other != concept: neg_pool.extend(other_descs) # Underscore-prefixed files (e.g. _baseline.txt) contribute to # every concept's negative pool, independent of the other- # concept negatives. neg_pool.extend(neg_pool_extra) rng = random.Random(hash(concept) & 0xFFFFFFFF) samples: list[SteeringVectorTrainingSample] = [] for pos in pos_descs: picks = rng.sample( neg_pool, min(args.max_negatives_per_positive, len(neg_pool)) ) for neg in picks: samples.append( SteeringVectorTrainingSample( positive_str=apply_template(pos), negative_str=apply_template(neg), ) ) sv = train_steering_vector( model, tokenizer, samples, layers=target_layers, aggregator=aggregator, batch_size=args.batch_size, show_progress=False, move_to_cpu=True, ) for l_idx, layer in enumerate(target_layers): vec = sv.layer_activations.get(layer) if vec is None: print(f" WARN: no vector for layer {layer} on {concept}") continue vec = vec.detach().to(torch.float32).cpu() vec = vec / vec.norm().clamp_min(1e-6) per_layer_vectors[l_idx, c_idx] = vec print(f" [{c_idx + 1}/{len(concepts)}] {concept}: n_samples={len(samples)}") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) tensors = { f"layer_{target_layers[l_idx]}.vectors": per_layer_vectors[l_idx].to(torch.float16) for l_idx in range(len(target_layers)) } safetensors.torch.save_file(tensors, str(output_dir / "readout.safetensors")) (output_dir / "readout.json").write_text( json.dumps( { "concepts": concepts, "layers": target_layers, "hidden_size": hidden_dim, "dtype": "float16", "aggregator": "pca", "format": "direct_first_person_assistant_role", }, indent=2, ) + "\n" ) print(f"\nWrote readout to {output_dir}") if __name__ == "__main__": main()