training: add preflight checks + progress logging to trainer
Review pass before running on b200. 27B model + 100+ story corpus means any misconfiguration costs real time; better to fail before model load and give visible progress during forwards. * Pre-load-model validation: stories-dir and paired-dir exist, corpus has >= min_positives emotions. * Per-batch progress log every 5 batches with elapsed + ETA. * Relative depth printed for target layers (e.g. "layer 40 (51%)"). * Skip empty .txt files with a warning rather than feeding the tokenizer an empty string. * Assert non-empty strings in _collect_activations. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
15737dfd92
commit
047da10123
1 changed files with 95 additions and 14 deletions
|
|
@ -95,10 +95,18 @@ def _collect_activations(
|
|||
device: torch.device,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
*,
|
||||
label: str = "",
|
||||
) -> torch.Tensor:
|
||||
"""Run texts through the model, capture residual stream at target
|
||||
layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
|
||||
"""
|
||||
import time
|
||||
|
||||
assert all(isinstance(t, str) and t for t in texts), (
|
||||
f"_collect_activations: empty or non-string text in {label!r}"
|
||||
)
|
||||
|
||||
captures: dict[int, torch.Tensor] = {}
|
||||
|
||||
def make_hook(idx: int):
|
||||
|
|
@ -114,10 +122,12 @@ def _collect_activations(
|
|||
]
|
||||
|
||||
out_rows: list[torch.Tensor] = []
|
||||
n_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
start = time.time()
|
||||
try:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(texts), batch_size):
|
||||
for b_idx, i in enumerate(range(0, len(texts), batch_size)):
|
||||
batch = texts[i : i + batch_size]
|
||||
tok = tokenizer(
|
||||
batch,
|
||||
|
|
@ -137,8 +147,17 @@ def _collect_activations(
|
|||
]
|
||||
out_rows.append(torch.stack(per_layer, dim=1))
|
||||
del tok, captures
|
||||
if (i // batch_size) % 10 == 0:
|
||||
if b_idx % 10 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
if b_idx % 5 == 0 or b_idx == n_batches - 1:
|
||||
elapsed = time.time() - start
|
||||
rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
|
||||
eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
|
||||
print(
|
||||
f" [{label}] batch {b_idx + 1}/{n_batches} "
|
||||
f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
|
||||
flush=True,
|
||||
)
|
||||
captures = {}
|
||||
finally:
|
||||
for h in handles:
|
||||
|
|
@ -156,13 +175,24 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
|||
Cross-concept negatives are computed at training time from
|
||||
``positives_by_emotion`` — each emotion's negative set is the
|
||||
union of all other emotions' positives plus the baseline texts.
|
||||
Empty .txt files are skipped with a warning.
|
||||
"""
|
||||
def _read_nonempty(path: Path) -> str | None:
|
||||
text = path.read_text().strip()
|
||||
if not text:
|
||||
print(
|
||||
f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
|
||||
)
|
||||
return None
|
||||
return text
|
||||
|
||||
positives: dict[str, list[str]] = {}
|
||||
for story_path in sorted(stories_dir.glob("*.txt")):
|
||||
text = _read_nonempty(story_path)
|
||||
if text is None:
|
||||
continue
|
||||
emotion = story_path.stem
|
||||
positives.setdefault(emotion, []).append(
|
||||
story_path.read_text().strip()
|
||||
)
|
||||
positives.setdefault(emotion, []).append(text)
|
||||
|
||||
baselines: list[str] = []
|
||||
if paired_dir is not None and paired_dir.exists():
|
||||
|
|
@ -171,14 +201,17 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
|||
continue
|
||||
baseline_path = scenario_dir / "baseline.txt"
|
||||
if baseline_path.exists():
|
||||
baselines.append(baseline_path.read_text().strip())
|
||||
text = _read_nonempty(baseline_path)
|
||||
if text is not None:
|
||||
baselines.append(text)
|
||||
for framing_path in sorted(scenario_dir.glob("*.txt")):
|
||||
if framing_path.stem == "baseline":
|
||||
continue
|
||||
text = _read_nonempty(framing_path)
|
||||
if text is None:
|
||||
continue
|
||||
emotion = framing_path.stem
|
||||
positives.setdefault(emotion, []).append(
|
||||
framing_path.read_text().strip()
|
||||
)
|
||||
positives.setdefault(emotion, []).append(text)
|
||||
|
||||
return positives, baselines
|
||||
|
||||
|
|
@ -225,6 +258,38 @@ def main() -> None:
|
|||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
# Preflight: corpus dirs exist before we pay the cost of loading a 27B model
|
||||
stories_dir = Path(args.stories_dir)
|
||||
if not stories_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
f"--stories-dir {stories_dir!s} does not exist or is not a dir"
|
||||
)
|
||||
if args.paired_dir is not None:
|
||||
pd = Path(args.paired_dir)
|
||||
if not pd.is_dir():
|
||||
raise FileNotFoundError(
|
||||
f"--paired-dir {pd!s} does not exist or is not a dir"
|
||||
)
|
||||
|
||||
# Quick corpus pre-scan so failures show up before we load the model.
|
||||
positives_preview, baselines_preview = _load_corpus(
|
||||
stories_dir,
|
||||
Path(args.paired_dir) if args.paired_dir else None,
|
||||
)
|
||||
n_emotions_preview = sum(
|
||||
1 for ps in positives_preview.values()
|
||||
if len(ps) >= args.min_positives
|
||||
)
|
||||
if n_emotions_preview == 0:
|
||||
raise RuntimeError(
|
||||
f"corpus has 0 emotions with >= {args.min_positives} positive "
|
||||
f"examples. Check {stories_dir} — is it the right directory?"
|
||||
)
|
||||
print(
|
||||
f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
|
||||
f"{args.min_positives}), {len(baselines_preview)} baselines"
|
||||
)
|
||||
|
||||
print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.pad_token_id is None:
|
||||
|
|
@ -235,11 +300,20 @@ def main() -> None:
|
|||
device_map=args.device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
hidden_dim = model.config.hidden_size
|
||||
n_model_layers = model.config.num_hidden_layers
|
||||
# Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
|
||||
# dimensions under a text_config subobject. get_text_config()
|
||||
# returns that sub-config when present, else the top-level config.
|
||||
text_config = (
|
||||
model.config.get_text_config()
|
||||
if hasattr(model.config, "get_text_config")
|
||||
else model.config
|
||||
)
|
||||
hidden_dim = text_config.hidden_size
|
||||
n_model_layers = text_config.num_hidden_layers
|
||||
print(
|
||||
f"Model loaded. hidden_dim={hidden_dim}, "
|
||||
f"n_model_layers={n_model_layers}"
|
||||
f"n_model_layers={n_model_layers} "
|
||||
f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
|
||||
)
|
||||
|
||||
for layer_idx in target_layers:
|
||||
|
|
@ -248,6 +322,13 @@ def main() -> None:
|
|||
f"target layer {layer_idx} out of range "
|
||||
f"[0, {n_model_layers})"
|
||||
)
|
||||
print(
|
||||
"Target layers (relative depth): "
|
||||
+ ", ".join(
|
||||
f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
|
||||
for l in target_layers
|
||||
)
|
||||
)
|
||||
|
||||
positives_by_emotion, baselines = _load_corpus(
|
||||
Path(args.stories_dir),
|
||||
|
|
@ -283,7 +364,7 @@ def main() -> None:
|
|||
|
||||
positive_acts = _collect_activations(
|
||||
model, tokenizer, unique_positive_texts, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
args.batch_size, args.max_length, label="positives",
|
||||
)
|
||||
# positive_acts[i] corresponds to unique_positive_texts[i]
|
||||
text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
|
||||
|
|
@ -291,7 +372,7 @@ def main() -> None:
|
|||
baseline_acts = (
|
||||
_collect_activations(
|
||||
model, tokenizer, baselines, target_layers, device,
|
||||
args.batch_size, args.max_length,
|
||||
args.batch_size, args.max_length, label="baselines",
|
||||
)
|
||||
if baselines
|
||||
else torch.zeros(0, len(target_layers), hidden_dim)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue