#!/usr/bin/env python3
"""Nightly training process for Apollo-Mini fine-tuning.

Imports vLLM's model weights via CUDA IPC, runs context-frozen training
on flagged conversation segments, saves updated checkpoint.

Usage:
    python3 train.py \
        --weights /tmp/vllm_weight_handles.pt \
        --examples training-examples.jsonl \
        --checkpoint-dir checkpoints/ \
        --lr 1e-5
"""

import argparse
import json
import os
import shutil
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import save_file

from apollo_mini import ApolloMini

# Device used for training inputs unless overridden with --device.
DEFAULT_DEVICE = 'cuda:0'

# Files copied next to the weights when --config-path is given.
CONFIG_FILES = (
    'config.json',
    'tokenizer.json',
    'tokenizer_config.json',
    'special_tokens_map.json',
    'generation_config.json',
)


def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
    """Import weight tensors from CUDA IPC handles.

    The file at ``handle_path`` is expected to hold a mapping of parameter
    name to an info dict whose ``'handle'`` entry is a ``(func, args)``
    pair; calling ``func(*args)`` yields the tensor.

    Args:
        handle_path: Path to the exported handle file.

    Returns:
        Mapping of parameter name to the rebuilt tensor.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects, and the
    # rebuild callables stored in the file are then invoked. Only point
    # this at a handle file produced by a trusted local process.
    handles = torch.load(handle_path, weights_only=False)

    params = {}
    for name, info in handles.items():
        func, args = info['handle']
        params[name] = func(*args)
    return params


def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
    """Split parameters into Apollo-Mini and standard groups.

    Apollo-Mini needs 2D+ matrices with min dimension >= 2.
    Small tensors (norms, biases, conv1d 3D weights) use standard Adam.

    Every tensor is switched to ``requires_grad=True`` in place. A tensor
    goes to the 'apollo' group when ``ndim >= 2`` and its smallest
    dimension is at least 2; everything else goes to 'standard'.

    Args:
        params: Mapping of parameter name to tensor.

    Returns:
        List of optimizer param-group dicts (``'params'`` and ``'name'``
        keys). A group is omitted when it would be empty.
    """
    apollo_params = []
    standard_params = []

    for p in params.values():
        p.requires_grad_(True)
        if p.ndim >= 2 and min(p.shape) >= 2:
            apollo_params.append(p)
        else:
            standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({
            'params': apollo_params,
            'name': 'apollo',
        })
    if standard_params:
        groups.append({
            'params': standard_params,
            'name': 'standard',
        })

    n_apollo = sum(p.nelement() for p in apollo_params)
    n_standard = sum(p.nelement() for p in standard_params)
    print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")

    return groups


def forward_pass(params, input_ids, context_len, device):
    """Run context-frozen forward pass.

    Args:
        params: dict of name -> tensor (shared with vLLM)
        input_ids: full sequence [1, seq_len]
        context_len: number of context tokens (no gradient)
        device: CUDA device

    Returns:
        logits for decision tokens, target ids for loss

    Raises:
        NotImplementedError: Always, until the forward model exists.
    """
    # TODO: Build proper forward model matching vLLM's weight layout.
    # For now this is a placeholder — the real implementation needs
    # to replicate vLLM's model architecture (merged projections,
    # GDN recurrence, full attention, MLP) using the shared weights.
    raise NotImplementedError(
        "Forward model not yet implemented. "
        "Need to build a model that matches vLLM's merged weight layout "
        "(MergedColumnParallelLinear for qkvz/ba/gate_up, "
        "RowParallelLinear for out_proj/down) and computes the same "
        "forward pass with autograd enabled."
    )


def save_checkpoint(params: dict[str, torch.Tensor],
                    checkpoint_dir: str,
                    config_path: Optional[str] = None):
    """Save model checkpoint in HuggingFace safetensors format.

    Writes all tensors to a single ``model.safetensors`` file under
    ``<checkpoint_dir>/<YYYY-MM-DD>/`` (local date), optionally copies the
    model config/tokenizer files next to it, and repoints the ``latest``
    symlink in ``checkpoint_dir`` at the dated directory.

    A second run on the same date reuses the same dated directory; no
    previous checkpoint is archived, and weights are not sharded.

    Args:
        params: Mapping of parameter name to tensor.
        checkpoint_dir: Root directory for dated checkpoint directories.
        config_path: Directory holding config/tokenizer files to copy, or
            None to skip copying. Files that do not exist are skipped.

    Returns:
        The dated checkpoint directory, as a string.
    """
    date_str = datetime.now().strftime("%Y-%m-%d")
    out_dir = Path(checkpoint_dir) / date_str
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save all weights in a single safetensors file for now.
    # TODO: split across shards matching HF model index for large models.
    tensors = {
        name: param.detach().contiguous().cpu()
        for name, param in params.items()
    }

    save_path = out_dir / "model.safetensors"
    save_file(tensors, str(save_path))
    print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")

    # Copy config files if provided
    if config_path:
        config_dir = Path(config_path)
        for filename in CONFIG_FILES:
            src = config_dir / filename
            if src.exists():
                shutil.copy2(src, out_dir / filename)

    # Update latest symlink. The target is relative (just the date), so
    # the checkpoint root can be moved without breaking the link.
    latest = Path(checkpoint_dir) / "latest"
    if latest.is_symlink():
        latest.unlink()
    latest.symlink_to(date_str)
    print(f"Updated {latest} -> {date_str}")

    return str(out_dir)


def train_step(params, example, optimizer, device, log_entries):
    """Run one training step on a single example.

    Args:
        params: dict of name -> tensor
        example: dict with 'input_ids' and 'context_len' (and optionally
            'id', used only for logging). Targets come from
            forward_pass; 'target_ids' is not read here.
        optimizer: ApolloMini instance
        device: CUDA device
        log_entries: list to append log dicts to

    Returns:
        loss value
    """
    optimizer.zero_grad()

    input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
    context_len = example['context_len']

    # Forward pass (context frozen, decision tokens with grad)
    logits, targets = forward_pass(params, input_ids, context_len, device)

    # Cross-entropy loss on decision tokens
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        targets.view(-1),
    )

    # Backward
    loss.backward()

    # Compute gradient stats before optimizer step: global L2 norm over
    # every parameter that received a gradient.
    squared_norms = sum(
        (p.grad.norm().item() ** 2
         for p in params.values() if p.grad is not None),
        0.0,
    )
    total_grad_norm = squared_norms ** 0.5

    # Optimizer step
    optimizer.step()

    # Read the loss back once; each .item() forces a device sync.
    loss_value = loss.item()

    # Log
    log_entries.append({
        'example_id': example.get('id', 'unknown'),
        'loss': loss_value,
        'grad_norm': total_grad_norm,
        'timestamp': datetime.now().isoformat(),
    })

    return loss_value


def _load_examples(path: str) -> list[dict]:
    """Read training examples from a JSONL file, skipping blank lines.

    Args:
        path: Path to a file with one JSON document per line.

    Returns:
        The parsed examples, in file order.

    Raises:
        ValueError: If a non-blank line is not valid JSON; the message
            names the file and line number.
    """
    examples = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{path}:{lineno}: invalid JSON in training example: {err}"
                ) from err
    return examples


def main():
    """Command-line entry point: import weights, train, save checkpoint and log."""
    parser = argparse.ArgumentParser(description="Apollo-Mini training")
    parser.add_argument("--weights", required=True,
                        help="Path to exported weight IPC handles")
    parser.add_argument("--examples", required=True,
                        help="Path to training examples JSONL")
    parser.add_argument("--checkpoint-dir", default="checkpoints",
                        help="Directory for saving checkpoints")
    parser.add_argument("--config-path", default=None,
                        help="Path to model config files (for checkpoint)")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=10,
                        help="Learning rate warmup steps")
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--device", default=DEFAULT_DEVICE,
                        help="Device for training inputs")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load weights and validate, don't train")
    args = parser.parse_args()

    print("Apollo-Mini Training")
    print(f" weights: {args.weights}")
    print(f" examples: {args.examples}")
    print(f" lr: {args.lr}")
    print()

    # Import weights
    print("Importing weights via CUDA IPC...")
    params = import_weights(args.weights)
    print(f" {len(params)} parameters imported")

    # Make parameter groups
    param_groups = make_param_groups(params)

    # Initialize optimizer
    optimizer = ApolloMini(param_groups, lr=args.lr,
                           weight_decay=args.weight_decay,
                           warmup_steps=args.warmup_steps)
    print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    if args.dry_run:
        print("\nDry run — weights imported and validated successfully.")
        return

    # Load training examples
    examples = _load_examples(args.examples)
    print(f" {len(examples)} training examples")

    # Training loop
    log_entries = []
    print("\nTraining...")
    t0 = time.time()

    for i, example in enumerate(examples):
        loss = train_step(params, example, optimizer, args.device, log_entries)
        print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")

    elapsed = time.time() - t0
    print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
    print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    # Save checkpoint
    print("\nSaving checkpoint...")
    out_dir = save_checkpoint(params, args.checkpoint_dir, args.config_path)

    # Save training log next to the weights. Use the directory that
    # save_checkpoint actually created rather than recomputing today's
    # date, which could differ if the run crosses midnight.
    log_path = Path(out_dir) / "training-log.jsonl"
    with open(log_path, 'w', encoding="utf-8") as f:
        f.writelines(json.dumps(entry) + '\n' for entry in log_entries)
    print(f"Training log: {log_path}")


if __name__ == '__main__':
    main()