training: add preflight checks + progress logging to trainer
Review pass before running on b200. A 27B model plus a 100+ story corpus
means any misconfiguration costs real time; better to fail before model
load and to give visible progress during forwards.

* Pre-load-model validation: stories-dir and paired-dir exist, corpus has
  >= min_positives emotions.
* Per-batch progress log every 5 batches with elapsed + ETA.
* Relative depth printed for target layers (e.g. "layer 40 (51%)").
* Skip empty .txt files with a warning rather than feeding the tokenizer
  an empty string.
* Assert non-empty strings in _collect_activations.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
parent 15737dfd92
commit 047da10123
1 changed file with 95 additions and 14 deletions
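A hypothetical invocation, for orientation only. The script name and every
value below are placeholders; the flags themselves correspond to the args.*
attributes handled in main() in the diff:

    python trainer.py \
        --stories-dir data/stories \
        --paired-dir data/paired \
        --min-positives 4 \
        --batch-size 8 \
        --max-length 512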
@@ -95,10 +95,18 @@ def _collect_activations(
     device: torch.device,
     batch_size: int,
     max_length: int,
+    *,
+    label: str = "",
 ) -> torch.Tensor:
     """Run texts through the model, capture residual stream at target
     layers, return ``[n_texts, n_target_layers, hidden_dim]`` fp32 on CPU.
     """
+    import time
+
+    assert all(isinstance(t, str) and t for t in texts), (
+        f"_collect_activations: empty or non-string text in {label!r}"
+    )
+
     captures: dict[int, torch.Tensor] = {}
 
     def make_hook(idx: int):
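The body of make_hook sits above this hunk and is not part of the change; a
minimal sketch of what such a residual-stream capture hook typically looks
like, assuming decoder layers at model.model.layers and a tuple output whose
first element is the hidden states (the last-token slice is an assumption
too):

    def make_hook(idx: int):
        def hook(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            # [batch, seq, hidden] -> keep the last token's residual
            captures[idx] = hidden[:, -1, :].detach()
        return hook

    handles = [
        model.model.layers[i].register_forward_hook(make_hook(i))
        for i in target_layers
    ]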
@@ -114,10 +122,12 @@ def _collect_activations(
     ]
 
     out_rows: list[torch.Tensor] = []
+    n_batches = (len(texts) + batch_size - 1) // batch_size
+    start = time.time()
     try:
         model.eval()
         with torch.no_grad():
-            for i in range(0, len(texts), batch_size):
+            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                 batch = texts[i : i + batch_size]
                 tok = tokenizer(
                     batch,
@@ -137,8 +147,17 @@ def _collect_activations(
                 ]
                 out_rows.append(torch.stack(per_layer, dim=1))
                 del tok, captures
-                if (i // batch_size) % 10 == 0:
+                if b_idx % 10 == 0:
                     torch.cuda.empty_cache()
+                if b_idx % 5 == 0 or b_idx == n_batches - 1:
+                    elapsed = time.time() - start
+                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
+                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
+                    print(
+                        f" [{label}] batch {b_idx + 1}/{n_batches} "
+                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
+                        flush=True,
+                    )
                 captures = {}
     finally:
         for h in handles:
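Plugging illustrative numbers into the new format string: at b_idx = 5 (the
sixth batch) of n_batches = 40 with 120 s elapsed, rate = 6/120 = 0.05
batches/s and eta = 34/0.05 = 680 s, so the line reads:

    [positives] batch 6/40 (120s elapsed, ~680s remaining)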
@@ -156,13 +175,24 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
     Cross-concept negatives are computed at training time from
     ``positives_by_emotion`` — each emotion's negative set is the
     union of all other emotions' positives plus the baseline texts.
+    Empty .txt files are skipped with a warning.
     """
+    def _read_nonempty(path: Path) -> str | None:
+        text = path.read_text().strip()
+        if not text:
+            print(
+                f" WARN: skipping empty story file {path.relative_to(path.parents[1]) if len(path.parents) >= 2 else path}"
+            )
+            return None
+        return text
+
     positives: dict[str, list[str]] = {}
     for story_path in sorted(stories_dir.glob("*.txt")):
+        text = _read_nonempty(story_path)
+        if text is None:
+            continue
         emotion = story_path.stem
-        positives.setdefault(emotion, []).append(
-            story_path.read_text().strip()
-        )
+        positives.setdefault(emotion, []).append(text)
 
     baselines: list[str] = []
     if paired_dir is not None and paired_dir.exists():
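The WARN message trims the path via pathlib so only the last two components
print. A worked example with an illustrative path:

    from pathlib import Path

    p = Path("data/stories/joy.txt")
    list(p.parents)              # [Path('data/stories'), Path('data'), Path('.')]
    p.relative_to(p.parents[1])  # Path('stories/joy.txt')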
@@ -171,14 +201,17 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
             continue
         baseline_path = scenario_dir / "baseline.txt"
         if baseline_path.exists():
-            baselines.append(baseline_path.read_text().strip())
+            text = _read_nonempty(baseline_path)
+            if text is not None:
+                baselines.append(text)
         for framing_path in sorted(scenario_dir.glob("*.txt")):
             if framing_path.stem == "baseline":
                 continue
+            text = _read_nonempty(framing_path)
+            if text is None:
+                continue
             emotion = framing_path.stem
-            positives.setdefault(emotion, []).append(
-                framing_path.read_text().strip()
-            )
+            positives.setdefault(emotion, []).append(text)
 
     return positives, baselines
@@ -225,6 +258,38 @@ def main() -> None:
         "fp32": torch.float32,
     }[args.dtype]
 
+    # Preflight: corpus dirs exist before we pay the cost of loading a 27B model
+    stories_dir = Path(args.stories_dir)
+    if not stories_dir.is_dir():
+        raise FileNotFoundError(
+            f"--stories-dir {stories_dir!s} does not exist or is not a dir"
+        )
+    if args.paired_dir is not None:
+        pd = Path(args.paired_dir)
+        if not pd.is_dir():
+            raise FileNotFoundError(
+                f"--paired-dir {pd!s} does not exist or is not a dir"
+            )
+
+    # Quick corpus pre-scan so failures show up before we load the model.
+    positives_preview, baselines_preview = _load_corpus(
+        stories_dir,
+        Path(args.paired_dir) if args.paired_dir else None,
+    )
+    n_emotions_preview = sum(
+        1 for ps in positives_preview.values()
+        if len(ps) >= args.min_positives
+    )
+    if n_emotions_preview == 0:
+        raise RuntimeError(
+            f"corpus has 0 emotions with >= {args.min_positives} positive "
+            f"examples. Check {stories_dir} — is it the right directory?"
+        )
+    print(
+        f"Corpus preflight: {n_emotions_preview} emotions (min_positives="
+        f"{args.min_positives}), {len(baselines_preview)} baselines"
+    )
+
     print(f"Loading {args.model} ({args.dtype}) on {args.device}...")
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token_id is None:
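On a healthy corpus the pre-scan prints a single summary line before the slow
model load; with illustrative numbers substituted into the f-string above:

    Corpus preflight: 12 emotions (min_positives=4), 57 baselines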
@@ -235,11 +300,20 @@ def main() -> None:
         device_map=args.device,
         low_cpu_mem_usage=True,
     )
-    hidden_dim = model.config.hidden_size
-    n_model_layers = model.config.num_hidden_layers
+    # Multimodal configs (Qwen3.5-27B, etc.) nest the text-model
+    # dimensions under a text_config subobject. get_text_config()
+    # returns that sub-config when present, else the top-level config.
+    text_config = (
+        model.config.get_text_config()
+        if hasattr(model.config, "get_text_config")
+        else model.config
+    )
+    hidden_dim = text_config.hidden_size
+    n_model_layers = text_config.num_hidden_layers
     print(
         f"Model loaded. hidden_dim={hidden_dim}, "
         f"n_model_layers={n_model_layers} "
+        f"(text_config.model_type={getattr(text_config, 'model_type', '?')})"
     )
 
     for layer_idx in target_layers:
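The same dimensions can be sanity-checked without loading 27B of weights by
pulling the config alone. A sketch assuming a recent transformers release
(get_text_config() is absent on older versions, hence the same hasattr guard;
model_id stands in for whatever --model points at):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(model_id)  # model_id: placeholder
    text_cfg = cfg.get_text_config() if hasattr(cfg, "get_text_config") else cfg
    print(text_cfg.hidden_size, text_cfg.num_hidden_layers)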
@@ -248,6 +322,13 @@ def main() -> None:
                 f"target layer {layer_idx} out of range "
                 f"[0, {n_model_layers})"
             )
+    print(
+        "Target layers (relative depth): "
+        + ", ".join(
+            f"{l} ({100 * l / (n_model_layers - 1):.0f}%)"
+            for l in target_layers
+        )
+    )
 
     positives_by_emotion, baselines = _load_corpus(
         Path(args.stories_dir),
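The denominator is n_model_layers - 1, so the last layer prints as 100%. The
commit message's "layer 40 (51%)" example is consistent with an 80-layer
model (an assumption back-derived from that figure):

    n_model_layers = 80  # assumption: implied by "layer 40 (51%)"
    for l in (10, 40, 79):
        print(f"{l} ({100 * l / (n_model_layers - 1):.0f}%)")
    # 10 (13%), 40 (51%), 79 (100%)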
@@ -283,7 +364,7 @@ def main() -> None:
 
     positive_acts = _collect_activations(
         model, tokenizer, unique_positive_texts, target_layers, device,
-        args.batch_size, args.max_length,
+        args.batch_size, args.max_length, label="positives",
     )
     # positive_acts[i] corresponds to unique_positive_texts[i]
     text_to_row = {t: i for i, t in enumerate(unique_positive_texts)}
@@ -291,7 +372,7 @@ def main() -> None:
     baseline_acts = (
         _collect_activations(
             model, tokenizer, baselines, target_layers, device,
-            args.batch_size, args.max_length,
+            args.batch_size, args.max_length, label="baselines",
         )
         if baselines
         else torch.zeros(0, len(target_layers), hidden_dim)