diff --git a/training/amygdala_training/train_direct.py b/training/amygdala_training/train_direct.py index 8749e37..2ad2a30 100644 --- a/training/amygdala_training/train_direct.py +++ b/training/amygdala_training/train_direct.py @@ -35,12 +35,11 @@ from steering_vectors.aggregators import pca_aggregator def _load_descriptions(direct_dir: Path) -> dict[str, list[str]]: """Each file in direct_dir is `{concept}.txt`. Descriptions are - separated by blank lines within the file.""" + separated by blank lines within the file. Files starting with `_` + are not concepts but are included in negative pools (e.g. _baseline.txt).""" out: dict[str, list[str]] = {} for f in sorted(direct_dir.glob("*.txt")): - if f.name.startswith("_"): - continue - concept = f.stem + concept = f.stem # underscore-prefixed names keep their prefix text = f.read_text() descs = [d.strip() for d in text.split("\n\n") if d.strip()] out[concept] = descs @@ -69,11 +68,19 @@ def main() -> None: target_layers = [int(x) for x in args.target_layers.split(",")] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype] - descriptions = _load_descriptions(Path(args.direct_dir)) - concepts = sorted(descriptions.keys()) + all_descriptions = _load_descriptions(Path(args.direct_dir)) + # Files starting with `_` are neg-pool helpers (e.g. _baseline.txt), not concepts. + concepts = sorted(k for k in all_descriptions if not k.startswith("_")) + neg_pool_extra: list[str] = [] + for k, ds in all_descriptions.items(): + if k.startswith("_"): + neg_pool_extra.extend(ds) + descriptions = {k: all_descriptions[k] for k in concepts} print(f"Loaded {len(concepts)} concepts with direct descriptions:") for c in concepts: print(f" {c}: {len(descriptions[c])} descriptions") + if neg_pool_extra: + print(f"Plus {len(neg_pool_extra)} neutral/baseline descriptions added to every concept's negative pool") print(f"\nLoading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) @@ -117,6 +124,10 @@ def main() -> None: for other, other_descs in descriptions.items(): if other != concept: neg_pool.extend(other_descs) + # Underscore-prefixed files (e.g. _baseline.txt) contribute to + # every concept's negative pool, independent of the other- + # concept negatives. + neg_pool.extend(neg_pool_extra) rng = random.Random(hash(concept) & 0xFFFFFFFF) samples: list[SteeringVectorTrainingSample] = []