training: add preflight checks + progress logging to trainer

Review pass before running on the b200. A 27B model plus a 100+ story
corpus means any misconfiguration costs real time; better to fail before
the model loads and to show visible progress during the forward passes.

* Pre-load-model validation: stories-dir and paired-dir exist, and the
  corpus has at least one emotion with >= min_positives positive examples.
* Per-batch progress log every 5 batches with elapsed + ETA (the ETA
  and relative-depth arithmetic are sketched after this list).
* Relative depth printed for target layers (e.g. "layer 40 (51%)").
* Skip empty .txt files with a warning rather than feeding the
  tokenizer an empty string.
* Assert non-empty strings in _collect_activations.
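
Both estimates are simple arithmetic; a standalone sketch with made-up
numbers (40 batches, 10 done in 120s; a hypothetical 79-layer text config
chosen so layer 40 lands near 51% -- not a measured layer count):

    # ETA: observed rate so far, applied to the batches left.
    n_batches, b_idx, elapsed = 40, 9, 120.0
    rate = (b_idx + 1) / elapsed                      # 0.083 batches/s
    eta = (n_batches - b_idx - 1) / rate              # ~360 s remaining

    # Relative depth: layer index over the last layer index of the text config.
    n_model_layers, layer = 79, 40                    # 79 layers is assumed, not measured
    depth_pct = 100 * layer / (n_model_layers - 1)    # ~51%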

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-18 00:40:32 -04:00
parent 15737dfd92
commit 047da10123


@@ -95,10 +95,18 @@ def _collect_activations(
     device: torch.device,
     batch_size: int,
     max_length: int,
+    *,
+    label: str = "",
 ) -> torch.Tensor:
     """Run texts through the model, capture residual stream at target
     layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
     """
+    import time
+
+    assert all(isinstance(t, str) and t for t in texts), (
+        f"_collect_activations: empty or non-string text in {label!r}"
+    )
+
     captures: dict[int, torch.Tensor] = {}
 
     def make_hook(idx: int):
@@ -114,10 +122,12 @@ def _collect_activations(
     ]
 
     out_rows: list[torch.Tensor] = []
+    n_batches = (len(texts) + batch_size - 1) // batch_size
+    start = time.time()
     try:
         model.eval()
         with torch.no_grad():
-            for i in range(0, len(texts), batch_size):
+            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                 batch = texts[i : i + batch_size]
                 tok = tokenizer(
                     batch,
@@ -137,8 +147,17 @@ def _collect_activations(
                 ]
                 out_rows.append(torch.stack(per_layer, dim=1))
                 del tok, captures
-                if (i // batch_size) % 10 == 0:
+                if b_idx % 10 == 0:
                     torch.cuda.empty_cache()
+                if b_idx % 5 == 0 or b_idx == n_batches - 1:
+                    elapsed = time.time() - start
+                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
+                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
+                    print(
+                        f"  [{label}] batch {b_idx + 1}/{n_batches} "
+                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
+                        flush=True,
+                    )
                 captures = {}
     finally:
         for h in handles:
@@ -156,13 +175,24 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
     Cross-concept negatives are computed at training time from
     ``positives_by_emotion`` each emotion's negative set is the
     union of all other emotions' positives plus the baseline texts.
+    Empty .txt files are skipped with a warning.
     """
 
+    def _read_nonempty(path: Path) -> str | None:
+        text = path.read_text().strip()
+        if not text:
+            print(
+                f"  WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
+            )
+            return None
+        return text
+
     positives: dict[str, list[str]] = {}
     for story_path in sorted(stories_dir.glob("*.txt")):
+        text = _read_nonempty(story_path)
+        if text is None:
+            continue
         emotion = story_path.stem
-        positives.setdefault(emotion, []).append(
-            story_path.read_text().strip()
-        )
+        positives.setdefault(emotion, []).append(text)
 
     baselines: list[str] = []
@@ -171,14 +201,17 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
                 continue
             baseline_path = scenario_dir / "baseline.txt"
             if baseline_path.exists():
-                baselines.append(baseline_path.read_text().strip())
+                text = _read_nonempty(baseline_path)
+                if text is not None:
+                    baselines.append(text)
 
             for framing_path in sorted(scenario_dir.glob("*.txt")):
                 if framing_path.stem == "baseline":
                     continue
+                text = _read_nonempty(framing_path)
+                if text is None:
+                    continue
                 emotion = framing_path.stem
-                positives.setdefault(emotion, []).append(
-                    framing_path.read_text().strip()
-                )
+                positives.setdefault(emotion, []).append(text)
 
     return positives, baselines
@@ -225,6 +258,38 @@ def main() -> None:
         "fp32": torch.float32,
     }[args.dtype]
 
+    # Preflight: corpus dirs exist before we pay the cost of loading a 27B model
+    stories_dir = Path(args.stories_dir)
+    if not stories_dir.is_dir():
+        raise FileNotFoundError(
+            f"--stories-dir {stories_dir!s} does not exist or is not a dir"
+        )
+    if args.paired_dir is not None:
+        pd = Path(args.paired_dir)
+        if not pd.is_dir():
+            raise FileNotFoundError(
+                f"--paired-dir {pd!s} does not exist or is not a dir"
+            )
+
+    # Quick corpus pre-scan so failures show up before we load the model.
+    positives_preview, baselines_preview = _load_corpus(
+        stories_dir,
+        Path(args.paired_dir) if args.paired_dir else None,
+    )
+    n_emotions_preview = sum(
+        1 for ps in positives_preview.values()
+        if len(ps) >= args.min_positives
+    )
+    if n_emotions_preview == 0:
+        raise RuntimeError(
+            f"corpus has 0 emotions with >= {args.min_positives} positive "
+            f"examples. Check {stories_dir} — is it the right directory?"
+        )
+    print(
+        f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
+        f"{args.min_positives}), {len(baselines_preview)} baselines"
+    )
+
     print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token_id is None:
@@ -235,11 +300,20 @@ def main() -> None:
         device_map=args.device,
         low_cpu_mem_usage=True,
     )
-    hidden_dim = model.config.hidden_size
-    n_model_layers = model.config.num_hidden_layers
+    # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
+    # dimensions under a text_config subobject. get_text_config()
+    # returns that sub-config when present, else the top-level config.
+    text_config = (
+        model.config.get_text_config()
+        if hasattr(model.config, "get_text_config")
+        else model.config
+    )
+    hidden_dim = text_config.hidden_size
+    n_model_layers = text_config.num_hidden_layers
     print(
         f"Model loaded. hidden_dim={hidden_dim}, "
         f"n_model_layers={n_model_layers} "
+        f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
     )
 
     for layer_idx in target_layers:
@@ -248,6 +322,13 @@ def main() -> None:
                 f"target layer {layer_idx} out of range "
                 f"[0, {n_model_layers})"
             )
+    print(
+        "Target layers (relative depth): "
+        + ", ".join(
+            f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
+            for l in target_layers
+        )
+    )
 
     positives_by_emotion, baselines = _load_corpus(
         Path(args.stories_dir),
@@ -283,7 +364,7 @@ def main() -> None:
 
     positive_acts = _collect_activations(
         model, tokenizer, unique_positive_texts, target_layers, device,
-        args.batch_size, args.max_length,
+        args.batch_size, args.max_length, label="positives",
    )
     # positive_acts[i] corresponds to unique_positive_texts[i]
     text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
@@ -291,7 +372,7 @@ def main() -> None:
     baseline_acts = (
         _collect_activations(
             model, tokenizer, baselines, target_layers, device,
-            args.batch_size, args.max_length,
+            args.batch_size, args.max_length, label="baselines",
         )
         if baselines
         else torch.zeros(0, len(target_layers), hidden_dim)