#!/usr/bin/env python3
"""First real Apollo training step — ready for Kent to run.

This script:
1. Imports vLLM's live weights via CUDA IPC
2. Constructs HF model with shared memory views
3. Runs ONE forward+backward on a real training example
4. Applies ONE Apollo optimizer step
5. Verifies vLLM still works after the update

The training example is from March 30: Kent said "use vLLM's code"
and the model should have accepted instead of suggesting alternatives.

Usage:
    source ~/training-env/bin/activate
    python3 first_training_step.py [--dry-run]
"""

import argparse
import subprocess
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM

sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
from apollo_mini import Apollo


def _import_vllm_weights(handles_path):
    """Rebuild vLLM's live CUDA tensors from serialized IPC handles.

    Each handle entry stores a ``(rebuild_func, args)`` pair; calling it
    yields a tensor that aliases vLLM's GPU memory (zero-copy).

    Returns a dict of vLLM parameter name -> shared CUDA tensor.
    """
    handles = torch.load(handles_path, weights_only=False)
    params = {}
    for name, info in handles.items():
        rebuild, rebuild_args = info['handle']
        params[name] = rebuild(*rebuild_args)
    return params


def _build_model(model_path, hf_params):
    """Create the HF model skeleton on the meta device and splice in the
    shared vLLM weight views as trainable parameters.

    Returns ``(model, replaced_count)``.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    # NOTE(review): only parameters present in hf_params are rematerialized.
    # Any parameter or buffer left on the meta device (e.g. rotary inv_freq,
    # if the model registers one) would fail at forward time — presumably
    # vllm_to_hf_views covers everything the forward touches; confirm before
    # the first real run.
    replaced = 0
    for name, _ in list(model.named_parameters()):
        if name not in hf_params:
            continue
        *path, leaf = name.split('.')
        parent = model
        for part in path:
            parent = getattr(parent, part)
        setattr(parent, leaf,
                nn.Parameter(hf_params[name], requires_grad=True))
        replaced += 1
    return model, replaced


def _build_training_example(tokenizer):
    """Tokenize the March 30 example.

    Returns ``(input_ids, context_len)``: a [1, T] CUDA tensor of token ids
    and the number of leading tokens that form the frozen context.
    """
    # Context: conversation where Kent says to use vLLM's code
    # Target: the response that accepts the direction
    context = (
        "<|im_start|>user\n"
        "vllm has a fused kernel already, right?<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Yeah — vLLM has `gdn_attention_core` which is a custom op "
        "that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
        "<|im_start|>user\n"
        "Why wouldn't we just use that?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # The CORRECT response (accept direction, don't suggest alternatives)
    continuation = (
        "We should. Let me pull in their kernel and wire it into "
        "our Rust orchestration. Which file should I start with?"
    )

    context_ids = tokenizer.encode(context, add_special_tokens=False)
    continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
    all_ids = context_ids + continuation_ids

    print(f"   Context: {len(context_ids)} tokens")
    print(f"   Continuation: {len(continuation_ids)} tokens")
    print(f"   Total: {len(all_ids)} tokens")

    input_ids = torch.tensor([all_ids], device='cuda:0')
    return input_ids, len(context_ids)


def _build_optimizer(model, lr, rank):
    """Partition trainable params into Apollo-projected (matrices large
    enough for a rank-``rank`` projection) vs. standard groups, then
    construct the Apollo optimizer over both groups."""
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2 and min(p.shape) >= rank:
            apollo_params.append(p)
        else:
            standard_params.append(p)

    # Skip empty groups so the optimizer never sees a zero-param group.
    groups = [{'params': g} for g in (apollo_params, standard_params) if g]
    optimizer = Apollo(groups, lr=lr, rank=rank)
    print(f"   Apollo: {len(apollo_params)} projected, "
          f"{len(standard_params)} standard")
    return optimizer


def _verify_vllm_serving():
    """Best-effort health check that vLLM still answers after the update.

    Never raises: by this point the weight update has already been applied,
    so a hung endpoint should produce a warning, not a crash.
    """
    try:
        result = subprocess.run(
            ['curl', '-s', '--max-time', '30',
             '-X', 'POST', 'http://localhost:8000/v1/chat/completions',
             '-H', 'Content-Type: application/json',
             '-H', 'Authorization: Bearer bcachefs-agents-2026',
             '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
            capture_output=True, text=True, timeout=45,
        )
    except subprocess.TimeoutExpired:
        # FIX: the original let TimeoutExpired propagate uncaught.
        print("   WARNING: vLLM health check timed out")
        return
    if result.returncode == 0 and 'choices' in result.stdout:
        print("   vLLM still serving ✓")
    else:
        print("   WARNING: vLLM may not be responding")
        print(f"   stdout: {result.stdout[:200]}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dry-run', action='store_true',
                        help="Run forward+backward but don't apply the optimizer step")
    parser.add_argument('--lr', type=float, default=1e-5,
                        help="Learning rate (default: 1e-5 = conservative)")
    parser.add_argument('--rank', type=int, default=256)
    parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
    parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
    args = parser.parse_args()

    print("=== First Apollo Training Step ===\n")

    # 1. Import vLLM weights
    print("1. Importing vLLM weights via CUDA IPC...")
    vllm_params = _import_vllm_weights(args.handles)
    print(f"   {len(vllm_params)} parameters imported")

    # 2. Map to HF layout
    print("2. Mapping to HF layout (zero-copy views)...")
    hf_params = vllm_to_hf_views(vllm_params)

    # 3. Create HF model
    print("3. Creating HF model with shared weights...")
    model, replaced = _build_model(args.model_path, hf_params)
    print(f"   {replaced} parameters replaced with vLLM memory views")

    # 4. Load tokenizer
    print("4. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path,
                                              trust_remote_code=True)

    # 5. Construct training example
    print("5. Constructing training example...")
    input_ids, context_len = _build_training_example(tokenizer)

    # 6. Initialize Apollo optimizer
    print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
    optimizer = _build_optimizer(model, args.lr, args.rank)

    # 7. Forward pass — context-frozen: no grad through the context prefill,
    # gradients only through the continuation.
    print("7. Forward pass...")
    model.train()
    optimizer.zero_grad()

    # FIX: the original prefilled all context_len tokens under no_grad, so
    # the logit predicting the FIRST continuation token was produced without
    # grad and that token (the crucial acceptance token) received no loss.
    # Prefill context_len-1 tokens and start the grad-enabled pass at the
    # last context token, so every continuation token is supervised.
    with torch.no_grad():
        ctx_output = model(input_ids[:, :context_len - 1], use_cache=True)
        past_kv = ctx_output.past_key_values

    output = model(input_ids[:, context_len - 1:],
                   past_key_values=past_kv, use_cache=False)
    logits = output.logits
    # logits[:, i] predicts input_ids[:, context_len + i]; drop the final
    # position (nothing left to predict) and align labels accordingly.
    shift_logits = logits[:, :-1].contiguous()
    shift_labels = input_ids[:, context_len:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    print(f"   Loss: {loss.item():.4f}")

    # 8. Backward pass
    print("8. Backward pass...")
    loss.backward()
    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    print(f"   {n_grads} parameters have gradients")

    # 9. Apollo step (or dry run)
    if args.dry_run:
        print("\n9. DRY RUN — skipping optimizer step")
        print("   (run without --dry-run to apply the update)")
    else:
        print("9. Applying Apollo optimizer step...")
        # Hoisted: build the name->param map once. The original rebuilt
        # dict(model.named_parameters()) inside the reporting loop below,
        # which is O(n^2) in parameter count.
        named = dict(model.named_parameters())
        # Record a few weight norms before the step (layer 0 sample).
        sample_norms_before = {
            name: p.data.norm().item()
            for name, p in named.items()
            if 'layers.0.' in name and p.grad is not None
        }

        optimizer.step()

        # Check weight changes (same Parameter objects, updated in place).
        print("   Weight changes (layer 0):")
        for name, before in sample_norms_before.items():
            after = named[name].data.norm().item()
            delta = abs(after - before)
            pct = delta / before * 100 if before > 0 else 0
            print(f"   {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")

        optimizer.zero_grad()

    # 10. Verify vLLM still works
    print("\n10. Verifying vLLM still serves...")
    _verify_vllm_serving()

    print("\n=== COMPLETE ===")
    if args.dry_run:
        print("Run without --dry-run to apply the first real training step.")
    else:
        print("First Apollo training step applied to vLLM's live weights.")
        print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")


if __name__ == '__main__':
    main()