diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py index 21e5ed1..d06a35a 100644 --- a/training/amygdala_training/train_steering_vectors.py +++ b/training/amygdala_training/train_steering_vectors.py @@ -95,10 +95,18 @@ def _collect_activations( device: torch.device, batch_size: int, max_length: int, + *, + label: str = "", ) -> torch.Tensor: """Run texts through the model, capture residual stream at target layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU. """ + import time + + assert all(isinstance(t, str) and t for t in texts), ( + f"_collect_activations: empty or non-string text in {label!r}" + ) + captures: dict[int, torch.Tensor] = {} def make_hook(idx: int): @@ -114,10 +122,12 @@ def _collect_activations( ] out_rows: list[torch.Tensor] = [] + n_batches = (len(texts) + batch_size - 1) // batch_size + start = time.time() try: model.eval() with torch.no_grad(): - for i in range(0, len(texts), batch_size): + for b_idx, i in enumerate(range(0, len(texts), batch_size)): batch = texts[i : i + batch_size] tok = tokenizer( batch, @@ -137,8 +147,17 @@ def _collect_activations( ] out_rows.append(torch.stack(per_layer, dim=1)) del tok, captures - if (i // batch_size) % 10 == 0: + if b_idx % 10 == 0: torch.cuda.empty_cache() + if b_idx % 5 == 0 or b_idx == n_batches - 1: + elapsed = time.time() - start + rate = (b_idx + 1) / elapsed if elapsed > 0 else 0 + eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0 + print( + f" [{label}] batch {b_idx + 1}/{n_batches} " + f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", + flush=True, + ) captures = {} finally: for h in handles: @@ -156,13 +175,24 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ Cross-concept negatives are computed at training time from ``positives_by_emotion`` — each emotion's negative set is the union of all other emotions' positives plus the baseline texts. + Empty .txt files are skipped with a warning. """ + def _read_nonempty(path: Path) -> str | None: + text = path.read_text().strip() + if not text: + print( + f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}" + ) + return None + return text + positives: dict[str, list[str]] = {} for story_path in sorted(stories_dir.glob("*.txt")): + text = _read_nonempty(story_path) + if text is None: + continue emotion = story_path.stem - positives.setdefault(emotion, []).append( - story_path.read_text().strip() - ) + positives.setdefault(emotion, []).append(text) baselines: list[str] = [] if paired_dir is not None and paired_dir.exists(): @@ -171,14 +201,17 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ continue baseline_path = scenario_dir / "baseline.txt" if baseline_path.exists(): - baselines.append(baseline_path.read_text().strip()) + text = _read_nonempty(baseline_path) + if text is not None: + baselines.append(text) for framing_path in sorted(scenario_dir.glob("*.txt")): if framing_path.stem == "baseline": continue + text = _read_nonempty(framing_path) + if text is None: + continue emotion = framing_path.stem - positives.setdefault(emotion, []).append( - framing_path.read_text().strip() - ) + positives.setdefault(emotion, []).append(text) return positives, baselines @@ -225,6 +258,38 @@ def main() -> None: "fp32": torch.float32, }[args.dtype] + # Preflight: corpus dirs exist before we pay the cost of loading a 27B model + stories_dir = Path(args.stories_dir) + if not stories_dir.is_dir(): + raise FileNotFoundError( + f"--stories-dir {stories_dir!s} does not exist or is not a dir" + ) + if args.paired_dir is not None: + pd = Path(args.paired_dir) + if not pd.is_dir(): + raise FileNotFoundError( + f"--paired-dir {pd!s} does not exist or is not a dir" + ) + + # Quick corpus pre-scan so failures show up before we load the model. + positives_preview, baselines_preview = _load_corpus( + stories_dir, + Path(args.paired_dir) if args.paired_dir else None, + ) + n_emotions_preview = sum( + 1 for ps in positives_preview.values() + if len(ps) >= args.min_positives + ) + if n_emotions_preview == 0: + raise RuntimeError( + f"corpus has 0 emotions with >= {args.min_positives} positive " + f"examples. Check {stories_dir} — is it the right directory?" + ) + print( + f"Corpus preflight: {n_emotions_preview} emotions (min_positives=" + f"{args.min_positives}), {len(baselines_preview)} baselines" + ) + print(f"Loading {args.model} ({args.dtype}) on {args.device}...") tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token_id is None: @@ -235,11 +300,20 @@ def main() -> None: device_map=args.device, low_cpu_mem_usage=True, ) - hidden_dim = model.config.hidden_size - n_model_layers = model.config.num_hidden_layers + # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model + # dimensions under a text_config subobject. get_text_config() + # returns that sub-config when present, else the top-level config. + text_config = ( + model.config.get_text_config() + if hasattr(model.config, "get_text_config") + else model.config + ) + hidden_dim = text_config.hidden_size + n_model_layers = text_config.num_hidden_layers print( f"Model loaded. hidden_dim={hidden_dim}, " - f"n_model_layers={n_model_layers}" + f"n_model_layers={n_model_layers} " + f"(text_config.model_type={getattr(text_config, 'model_type', '?')})" ) for layer_idx in target_layers: @@ -248,6 +322,13 @@ def main() -> None: f"target layer {layer_idx} out of range " f"[0, {n_model_layers})" ) + print( + "Target layers (relative depth): " + + ", ".join( + f"{l} ({100 * l / (n_model_layers - 1):.0f}%)" + for l in target_layers + ) + ) positives_by_emotion, baselines = _load_corpus( Path(args.stories_dir), @@ -283,7 +364,7 @@ def main() -> None: positive_acts = _collect_activations( model, tokenizer, unique_positive_texts, target_layers, device, - args.batch_size, args.max_length, + args.batch_size, args.max_length, label="positives", ) # positive_acts[i] corresponds to unique_positive_texts[i] text_to_row = {t: i for i, t in enumerate(unique_positive_texts)} @@ -291,7 +372,7 @@ def main() -> None: baseline_acts = ( _collect_activations( model, tokenizer, baselines, target_layers, device, - args.batch_size, args.max_length, + args.batch_size, args.max_length, label="baselines", ) if baselines else torch.zeros(0, len(target_layers), hidden_dim)