training: add preflight checks + progress logging to trainer

Review pass before running on b200. 27B model + 100+ story corpus
means any misconfiguration costs real time; better to fail before
model load and give visible progress during forwards.

* Pre-load-model validation: stories-dir and paired-dir exist,
  corpus has >= min_positives emotions.
* Per-batch progress log every 5 batches with elapsed + ETA.
* Relative depth printed for target layers (e.g. "layer 40 (51%)").
* Skip empty .txt files with a warning rather than feeding the
  tokenizer an empty string.
* Assert non-empty strings in _collect_activations.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-18 00:40:32 -04:00
parent 15737dfd92
commit 047da10123

View file

@ -95,10 +95,18 @@ def _collect_activations(
device: torch.device,
batch_size: int,
max_length: int,
*,
label: str = "",
) -> torch.Tensor:
"""Run texts through the model, capture residual stream at target
layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
"""
import time
assert all(isinstance(t, str) and t for t in texts), (
f"_collect_activations: empty or non-string text in {label!r}"
)
captures: dict[int, torch.Tensor] = {}
def make_hook(idx: int):
@ -114,10 +122,12 @@ def _collect_activations(
]
out_rows: list[torch.Tensor] = []
n_batches = (len(texts) + batch_size - 1) // batch_size
start = time.time()
try:
model.eval()
with torch.no_grad():
for i in range(0, len(texts), batch_size):
for b_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
tok = tokenizer(
batch,
@ -137,8 +147,17 @@ def _collect_activations(
]
out_rows.append(torch.stack(per_layer, dim=1))
del tok, captures
if (i // batch_size) % 10 == 0:
if b_idx % 10 == 0:
torch.cuda.empty_cache()
if b_idx % 5 == 0 or b_idx == n_batches - 1:
elapsed = time.time() - start
rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
print(
f" [{label}] batch {b_idx + 1}/{n_batches} "
f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
flush=True,
)
captures = {}
finally:
for h in handles:
@ -156,13 +175,24 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
Cross-concept negatives are computed at training time from
``positives_by_emotion`` each emotion's negative set is the
union of all other emotions' positives plus the baseline texts.
Empty .txt files are skipped with a warning.
"""
def _read_nonempty(path: Path) -> str | None:
text = path.read_text().strip()
if not text:
print(
f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
)
return None
return text
positives: dict[str, list[str]] = {}
for story_path in sorted(stories_dir.glob("*.txt")):
text = _read_nonempty(story_path)
if text is None:
continue
emotion = story_path.stem
positives.setdefault(emotion, []).append(
story_path.read_text().strip()
)
positives.setdefault(emotion, []).append(text)
baselines: list[str] = []
if paired_dir is not None and paired_dir.exists():
@ -171,14 +201,17 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
continue
baseline_path = scenario_dir / "baseline.txt"
if baseline_path.exists():
baselines.append(baseline_path.read_text().strip())
text = _read_nonempty(baseline_path)
if text is not None:
baselines.append(text)
for framing_path in sorted(scenario_dir.glob("*.txt")):
if framing_path.stem == "baseline":
continue
text = _read_nonempty(framing_path)
if text is None:
continue
emotion = framing_path.stem
positives.setdefault(emotion, []).append(
framing_path.read_text().strip()
)
positives.setdefault(emotion, []).append(text)
return positives, baselines
@ -225,6 +258,38 @@ def main() -> None:
"fp32": torch.float32,
}[args.dtype]
# Preflight: corpus dirs exist before we pay the cost of loading a 27B model
stories_dir = Path(args.stories_dir)
if not stories_dir.is_dir():
raise FileNotFoundError(
f"--stories-dir {stories_dir!s} does not exist or is not a dir"
)
if args.paired_dir is not None:
pd = Path(args.paired_dir)
if not pd.is_dir():
raise FileNotFoundError(
f"--paired-dir {pd!s} does not exist or is not a dir"
)
# Quick corpus pre-scan so failures show up before we load the model.
positives_preview, baselines_preview = _load_corpus(
stories_dir,
Path(args.paired_dir) if args.paired_dir else None,
)
n_emotions_preview = sum(
1 for ps in positives_preview.values()
if len(ps) >= args.min_positives
)
if n_emotions_preview == 0:
raise RuntimeError(
f"corpus has 0 emotions with >= {args.min_positives} positive "
f"examples. Check {stories_dir} — is it the right directory?"
)
print(
f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
f"{args.min_positives}), {len(baselines_preview)} baselines"
)
print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token_id is None:
@ -235,11 +300,20 @@ def main() -> None:
device_map=args.device,
low_cpu_mem_usage=True,
)
hidden_dim = model.config.hidden_size
n_model_layers = model.config.num_hidden_layers
# Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
# dimensions under a text_config subobject. get_text_config()
# returns that sub-config when present, else the top-level config.
text_config = (
model.config.get_text_config()
if hasattr(model.config, "get_text_config")
else model.config
)
hidden_dim = text_config.hidden_size
n_model_layers = text_config.num_hidden_layers
print(
f"Model loaded. hidden_dim={hidden_dim}, "
f"n_model_layers={n_model_layers}"
f"n_model_layers={n_model_layers} "
f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
)
for layer_idx in target_layers:
@ -248,6 +322,13 @@ def main() -> None:
f"target layer {layer_idx} out of range "
f"[0, {n_model_layers})"
)
print(
"Target layers (relative depth): "
+ ", ".join(
f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
for l in target_layers
)
)
positives_by_emotion, baselines = _load_corpus(
Path(args.stories_dir),
@ -283,7 +364,7 @@ def main() -> None:
positive_acts = _collect_activations(
model, tokenizer, unique_positive_texts, target_layers, device,
args.batch_size, args.max_length,
args.batch_size, args.max_length, label="positives",
)
# positive_acts[i] corresponds to unique_positive_texts[i]
text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
@ -291,7 +372,7 @@ def main() -> None:
baseline_acts = (
_collect_activations(
model, tokenizer, baselines, target_layers, device,
args.batch_size, args.max_length,
args.batch_size, args.max_length, label="baselines",
)
if baselines
else torch.zeros(0, len(target_layers), hidden_dim)