#!/usr/bin/env python3
"""First real Apollo training step — ready for Kent to run.

This script:
1. Imports vLLM's live weights via CUDA IPC
2. Constructs HF model with shared memory views
3. Runs ONE forward+backward on a real training example
4. Applies ONE Apollo optimizer step
5. Verifies vLLM still works after the update

The training example is from March 30: Kent said "use vLLM's code" and
the model should have accepted instead of suggesting alternatives.

Usage:
    source ~/training-env/bin/activate
    python3 first_training_step.py [--dry-run]
"""
import argparse
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM

sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
from apollo_mini import Apollo


def _parse_args():
    """Parse the CLI flags for the training step."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dry-run', action='store_true',
                        help="Run forward+backward but don't apply the optimizer step")
    parser.add_argument('--lr', type=float, default=1e-5,
                        help="Learning rate (default: 1e-5 = conservative)")
    parser.add_argument('--rank', type=int, default=256)
    parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
    parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
    return parser.parse_args()


def _import_vllm_weights(handles_path):
    """Rebuild vLLM's live parameter tensors from pickled CUDA IPC handles.

    Returns a dict of name -> CUDA tensor sharing memory with the serving
    process. NOTE(review): ``weights_only=False`` runs arbitrary unpickling
    code — only ever point this at a locally produced, trusted handle file.
    """
    handles = torch.load(handles_path, weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        # Each handle is a (rebuild_fn, rebuild_args) pair as produced by
        # torch's IPC reduction machinery; calling it re-opens the shared
        # CUDA allocation in this process.
        rebuild_fn, rebuild_args = info['handle']
        vllm_params[name] = rebuild_fn(*rebuild_args)
    print(f" {len(vllm_params)} parameters imported")
    return vllm_params


def _build_shared_model(model_path, hf_params):
    """Instantiate the HF model on the meta device, then graft in the vLLM
    memory views so a training step writes directly into the live serving
    weights (zero-copy)."""
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    replaced = 0
    for name, _param in list(model.named_parameters()):
        if name not in hf_params:
            continue
        # Walk to the owning module and swap the placeholder parameter for
        # a Parameter wrapping the shared vLLM storage.
        *parent_parts, leaf = name.split('.')
        module = model
        for part in parent_parts:
            module = getattr(module, part)
        setattr(module, leaf, nn.Parameter(hf_params[name], requires_grad=True))
        replaced += 1
    print(f" {replaced} parameters replaced with vLLM memory views")

    # Fail loudly, not mid-forward: any parameter that was not mapped is
    # still a meta tensor and would otherwise surface as an obscure error
    # deep inside the forward pass.
    leftover = [n for n, p in model.named_parameters() if p.device.type == 'meta']
    if leftover:
        print(f" WARNING: {len(leftover)} parameters still on meta device "
              f"(e.g. {leftover[0]}) — forward pass will fail")
    return model


def _build_example(tokenizer):
    """Tokenize the March 30 training example.

    Returns (input_ids on cuda:0, context_len). The context is the frozen
    conversation prefix; the continuation is the response we are training
    toward (accept Kent's direction instead of suggesting alternatives).
    """
    context = (
        "<|im_start|>user\n"
        "vllm has a fused kernel already, right?<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Yeah — vLLM has `gdn_attention_core` which is a custom op "
        "that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
        "<|im_start|>user\n"
        "Why wouldn't we just use that?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    # The CORRECT response (accept direction, don't suggest alternatives)
    continuation = (
        "We should. Let me pull in their kernel and wire it into "
        "our Rust orchestration. Which file should I start with?"
    )

    context_ids = tokenizer.encode(context, add_special_tokens=False)
    continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
    all_ids = context_ids + continuation_ids
    context_len = len(context_ids)
    print(f" Context: {context_len} tokens")
    print(f" Continuation: {len(continuation_ids)} tokens")
    print(f" Total: {len(all_ids)} tokens")
    return torch.tensor([all_ids], device='cuda:0'), context_len


def _make_optimizer(model, lr, rank):
    """Split parameters into Apollo-projected (matrices big enough for the
    low-rank projection) vs. standard groups and build the optimizer."""
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2 and min(p.shape) >= rank:
            apollo_params.append(p)
        else:
            standard_params.append(p)

    groups = [{'params': g} for g in (apollo_params, standard_params) if g]
    optimizer = Apollo(groups, lr=lr, rank=rank)
    print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")
    return optimizer


def _forward_backward(model, input_ids, context_len):
    """Context-frozen forward + backward; returns the scalar loss tensor.

    The context prefix is run under ``no_grad`` and only its KV cache is
    reused, so no gradient flows into the context computation. NOTE: as a
    consequence the FIRST continuation token is unsupervised — the logit
    that predicts it belongs to the no-grad context pass.
    """
    with torch.no_grad():
        ctx_output = model(input_ids[:, :context_len], use_cache=True)
        past_kv = ctx_output.past_key_values

    with torch.enable_grad():
        output = model(input_ids[:, context_len:],
                       past_key_values=past_kv,
                       use_cache=False)
    logits = output.logits

    # Shift for next-token prediction: logit i (position context_len+i)
    # predicts token context_len+i+1, so drop the last logit and the first
    # continuation token from the labels.
    shift_logits = logits[:, :-1].contiguous()
    shift_labels = input_ids[:, context_len + 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    print(f" Loss: {loss.item():.4f}")

    print("8. Backward pass...")
    loss.backward()
    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    print(f" {n_grads} parameters have gradients")
    return loss


def _apply_step(model, optimizer):
    """Apply one optimizer step and report layer-0 weight-norm deltas."""
    # Record a few weight norms before the update for a sanity report.
    sample_norms_before = {
        name: p.data.norm().item()
        for name, p in model.named_parameters()
        if 'layers.0.' in name and p.grad is not None
    }

    optimizer.step()

    # Build the name->param map ONCE (the original rebuilt it per
    # iteration, which is quadratic in model size).
    params_by_name = dict(model.named_parameters())
    print(" Weight changes (layer 0):")
    for name, before in sample_norms_before.items():
        after = params_by_name[name].data.norm().item()
        delta = abs(after - before)
        pct = delta / before * 100 if before > 0 else 0
        print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")

    optimizer.zero_grad()


def _verify_vllm():
    """Smoke-test that the co-resident vLLM server still answers after the
    in-place weight update."""
    import subprocess
    result = subprocess.run(
        ['curl', '-s', '--max-time', '30', '-X', 'POST',
         'http://localhost:8000/v1/chat/completions',
         '-H', 'Content-Type: application/json',
         '-H', 'Authorization: Bearer bcachefs-agents-2026',
         '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
        capture_output=True, text=True, timeout=45
    )
    if result.returncode == 0 and 'choices' in result.stdout:
        print(" vLLM still serving ✓")
    else:
        print(" WARNING: vLLM may not be responding")
        print(f" stdout: {result.stdout[:200]}")


def main():
    args = _parse_args()
    print("=== First Apollo Training Step ===\n")

    # 1. Import vLLM weights
    print("1. Importing vLLM weights via CUDA IPC...")
    vllm_params = _import_vllm_weights(args.handles)

    # 2. Map to HF layout
    print("2. Mapping to HF layout (zero-copy views)...")
    hf_params = vllm_to_hf_views(vllm_params)

    # 3. Create HF model backed by the shared vLLM storage
    print("3. Creating HF model with shared weights...")
    model = _build_shared_model(args.model_path, hf_params)

    # 4. Load tokenizer
    print("4. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # 5. Construct training example
    print("5. Constructing training example...")
    input_ids, context_len = _build_example(tokenizer)

    # 6. Initialize Apollo optimizer
    print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
    optimizer = _make_optimizer(model, args.lr, args.rank)

    # 7+8. Forward and backward pass
    print("7. Forward pass...")
    model.train()
    optimizer.zero_grad()
    _forward_backward(model, input_ids, context_len)

    # 9. Apollo step (or dry run)
    if args.dry_run:
        print("\n9. DRY RUN — skipping optimizer step")
        print(" (run without --dry-run to apply the update)")
    else:
        print("9. Applying Apollo optimizer step...")
        _apply_step(model, optimizer)

    # 10. Verify vLLM still works
    print("\n10. Verifying vLLM still serves...")
    _verify_vllm()

    print("\n=== COMPLETE ===")
    if args.dry_run:
        print("Run without --dry-run to apply the first real training step.")
    else:
        print("First Apollo training step applied to vLLM's live weights.")
        print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")


if __name__ == '__main__':
    main()