first_training_step.py: ready for Kent to run
Real training example from March 30 (listening reflex). Context-frozen forward+backward with Apollo rank-256. Supports --dry-run to test without modifying weights. Verifies vLLM still works after update. The button is ready. Kent pushes it.
This commit is contained in:
parent
0b835ddfb9
commit
d7a0fccdcc
1 changed files with 215 additions and 0 deletions
215
training/first_training_step.py
Normal file
215
training/first_training_step.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
#!/usr/bin/env python3
|
||||
"""First real Apollo training step — ready for Kent to run.
|
||||
|
||||
This script:
|
||||
1. Imports vLLM's live weights via CUDA IPC
|
||||
2. Constructs HF model with shared memory views
|
||||
3. Runs ONE forward+backward on a real training example
|
||||
4. Applies ONE Apollo optimizer step
|
||||
5. Verifies vLLM still works after the update
|
||||
|
||||
The training example is from March 30: Kent said "use vLLM's code"
|
||||
and the model should have accepted instead of suggesting alternatives.
|
||||
|
||||
Usage:
|
||||
source ~/training-env/bin/activate
|
||||
python3 first_training_step.py [--dry-run]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
|
||||
|
||||
sys.path.insert(0, '.')
|
||||
from weight_mapping import vllm_to_hf_views
|
||||
from apollo_mini import Apollo
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dry-run', action='store_true',
|
||||
help="Run forward+backward but don't apply the optimizer step")
|
||||
parser.add_argument('--lr', type=float, default=1e-5,
|
||||
help="Learning rate (default: 1e-5 = conservative)")
|
||||
parser.add_argument('--rank', type=int, default=256)
|
||||
parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
|
||||
parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=== First Apollo Training Step ===\n")
|
||||
|
||||
# 1. Import vLLM weights
|
||||
print("1. Importing vLLM weights via CUDA IPC...")
|
||||
handles = torch.load(args.handles, weights_only=False)
|
||||
vllm_params = {}
|
||||
for name, info in handles.items():
|
||||
func, args_h = info['handle']
|
||||
vllm_params[name] = func(*args_h)
|
||||
print(f" {len(vllm_params)} parameters imported")
|
||||
|
||||
# 2. Map to HF layout
|
||||
print("2. Mapping to HF layout (zero-copy views)...")
|
||||
hf_params = vllm_to_hf_views(vllm_params)
|
||||
|
||||
# 3. Create HF model
|
||||
print("3. Creating HF model with shared weights...")
|
||||
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
with torch.device('meta'):
|
||||
model = Qwen3_5ForCausalLM(config.text_config)
|
||||
|
||||
replaced = 0
|
||||
for name, param in list(model.named_parameters()):
|
||||
if name in hf_params:
|
||||
parts = name.split('.')
|
||||
parent = model
|
||||
for part in parts[:-1]:
|
||||
parent = getattr(parent, part)
|
||||
setattr(parent, parts[-1],
|
||||
nn.Parameter(hf_params[name], requires_grad=True))
|
||||
replaced += 1
|
||||
print(f" {replaced} parameters replaced with vLLM memory views")
|
||||
|
||||
# 4. Load tokenizer
|
||||
print("4. Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
|
||||
# 5. Construct training example
|
||||
print("5. Constructing training example...")
|
||||
|
||||
# Context: conversation where Kent says to use vLLM's code
|
||||
# Target: the response that accepts the direction
|
||||
context = (
|
||||
"<|im_start|>user\n"
|
||||
"vllm has a fused kernel already, right?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
"Yeah — vLLM has `gdn_attention_core` which is a custom op "
|
||||
"that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"Why wouldn't we just use that?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
# The CORRECT response (accept direction, don't suggest alternatives)
|
||||
continuation = (
|
||||
"We should. Let me pull in their kernel and wire it into "
|
||||
"our Rust orchestration. Which file should I start with?"
|
||||
)
|
||||
|
||||
context_ids = tokenizer.encode(context, add_special_tokens=False)
|
||||
continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
|
||||
all_ids = context_ids + continuation_ids
|
||||
context_len = len(context_ids)
|
||||
|
||||
print(f" Context: {context_len} tokens")
|
||||
print(f" Continuation: {len(continuation_ids)} tokens")
|
||||
print(f" Total: {len(all_ids)} tokens")
|
||||
|
||||
input_ids = torch.tensor([all_ids], device='cuda:0')
|
||||
|
||||
# 6. Initialize Apollo optimizer
|
||||
print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
|
||||
apollo_params = []
|
||||
standard_params = []
|
||||
for p in model.parameters():
|
||||
if p.requires_grad:
|
||||
if p.ndim >= 2 and min(p.shape) >= args.rank:
|
||||
apollo_params.append(p)
|
||||
else:
|
||||
standard_params.append(p)
|
||||
|
||||
groups = []
|
||||
if apollo_params:
|
||||
groups.append({'params': apollo_params})
|
||||
if standard_params:
|
||||
groups.append({'params': standard_params})
|
||||
|
||||
optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
|
||||
print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")
|
||||
|
||||
# 7. Forward pass
|
||||
print("7. Forward pass...")
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Context-frozen: no grad for context, grad for continuation
|
||||
with torch.no_grad():
|
||||
ctx_output = model(input_ids[:, :context_len], use_cache=True)
|
||||
past_kv = ctx_output.past_key_values
|
||||
|
||||
with torch.enable_grad():
|
||||
output = model(input_ids[:, context_len:],
|
||||
past_key_values=past_kv, use_cache=False)
|
||||
logits = output.logits
|
||||
# Shift for next-token prediction
|
||||
shift_logits = logits[:, :-1].contiguous()
|
||||
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1),
|
||||
)
|
||||
print(f" Loss: {loss.item():.4f}")
|
||||
|
||||
# 8. Backward pass
|
||||
print("8. Backward pass...")
|
||||
loss.backward()
|
||||
n_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
||||
print(f" {n_grads} parameters have gradients")
|
||||
|
||||
# 9. Apollo step (or dry run)
|
||||
if args.dry_run:
|
||||
print("\n9. DRY RUN — skipping optimizer step")
|
||||
print(" (run without --dry-run to apply the update)")
|
||||
else:
|
||||
print("9. Applying Apollo optimizer step...")
|
||||
# Record a few weight norms before
|
||||
sample_norms_before = {}
|
||||
for name, p in model.named_parameters():
|
||||
if 'layers.0.' in name and p.grad is not None:
|
||||
sample_norms_before[name] = p.data.norm().item()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
# Check weight changes
|
||||
print(" Weight changes (layer 0):")
|
||||
for name, before in sample_norms_before.items():
|
||||
p = dict(model.named_parameters())[name]
|
||||
after = p.data.norm().item()
|
||||
delta = abs(after - before)
|
||||
pct = delta / before * 100 if before > 0 else 0
|
||||
print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 10. Verify vLLM still works
|
||||
print("\n10. Verifying vLLM still serves...")
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
['curl', '-s', '--max-time', '30',
|
||||
'-X', 'POST', 'http://localhost:8000/v1/chat/completions',
|
||||
'-H', 'Content-Type: application/json',
|
||||
'-H', 'Authorization: Bearer bcachefs-agents-2026',
|
||||
'-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
|
||||
capture_output=True, text=True, timeout=45
|
||||
)
|
||||
if result.returncode == 0 and 'choices' in result.stdout:
|
||||
print(" vLLM still serving ✓")
|
||||
else:
|
||||
print(" WARNING: vLLM may not be responding")
|
||||
print(f" stdout: {result.stdout[:200]}")
|
||||
|
||||
print("\n=== COMPLETE ===")
|
||||
if args.dry_run:
|
||||
print("Run without --dry-run to apply the first real training step.")
|
||||
else:
|
||||
print("First Apollo training step applied to vLLM's live weights.")
|
||||
print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue