first_training_step.py: ready for Kent to run
Real training example from March 30 (listening reflex). Context-frozen forward+backward with Apollo rank-256. Supports --dry-run to test without modifying weights. Verifies vLLM still works after update. The button is ready. Kent pushes it.
This commit is contained in:
parent
0b835ddfb9
commit
d7a0fccdcc
1 changed files with 215 additions and 0 deletions
215
training/first_training_step.py
Normal file
215
training/first_training_step.py
Normal file
|
|
@ -0,0 +1,215 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""First real Apollo training step — ready for Kent to run.
|
||||||
|
|
||||||
|
This script:
|
||||||
|
1. Imports vLLM's live weights via CUDA IPC
|
||||||
|
2. Constructs HF model with shared memory views
|
||||||
|
3. Runs ONE forward+backward on a real training example
|
||||||
|
4. Applies ONE Apollo optimizer step
|
||||||
|
5. Verifies vLLM still works after the update
|
||||||
|
|
||||||
|
The training example is from March 30: Kent said "use vLLM's code"
|
||||||
|
and the model should have accepted instead of suggesting alternatives.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
source ~/training-env/bin/activate
|
||||||
|
python3 first_training_step.py [--dry-run]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
|
||||||
|
|
||||||
|
sys.path.insert(0, '.')
|
||||||
|
from weight_mapping import vllm_to_hf_views
|
||||||
|
from apollo_mini import Apollo
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run ONE context-frozen Apollo training step on vLLM's live weights.

    Steps (mirroring the numbered progress prints):
      1.  Import vLLM's parameter tensors via CUDA IPC handles.
      2.  Map them into HF layout as zero-copy views.
      3.  Build the HF model on the meta device and splice the views in.
      4-5. Tokenize one hard-coded training example (context + target reply).
      6.  Split trainable params into Apollo-projected vs. standard groups.
      7-8. Forward+backward, with the context run under no_grad (frozen).
      9.  Apply one optimizer step (skipped with --dry-run).
      10. Smoke-test that the vLLM server still answers.

    Requires a CUDA device, the IPC handles file written by the vLLM side,
    and the vLLM server listening on localhost:8000. Exits normally on
    success; any failure raises and aborts before weights are touched
    (the only mutation point is the single optimizer.step()).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dry-run', action='store_true',
                        help="Run forward+backward but don't apply the optimizer step")
    parser.add_argument('--lr', type=float, default=1e-5,
                        help="Learning rate (default: 1e-5 = conservative)")
    parser.add_argument('--rank', type=int, default=256)
    parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
    parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
    args = parser.parse_args()

    print("=== First Apollo Training Step ===\n")

    # 1. Import vLLM weights.
    # weights_only=False: the handles file stores callables (rebuild func +
    # args), not plain tensors, so the restricted loader cannot be used.
    # The file must therefore come from a trusted source.
    print("1. Importing vLLM weights via CUDA IPC...")
    handles = torch.load(args.handles, weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        # Each entry is (rebuild_function, args) — calling it reconstructs
        # a tensor aliasing vLLM's GPU memory (no copy).
        func, args_h = info['handle']
        vllm_params[name] = func(*args_h)
    print(f" {len(vllm_params)} parameters imported")

    # 2. Map to HF layout (still zero-copy: views over the same memory).
    print("2. Mapping to HF layout (zero-copy views)...")
    hf_params = vllm_to_hf_views(vllm_params)

    # 3. Create HF model on the meta device (no allocation), then replace
    # each parameter with the shared-memory view from vLLM.
    print("3. Creating HF model with shared weights...")
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    replaced = 0
    for name, _ in list(model.named_parameters()):
        if name in hf_params:
            # Walk to the owning submodule so setattr replaces the
            # registered nn.Parameter, not a shadow attribute.
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1],
                    nn.Parameter(hf_params[name], requires_grad=True))
            replaced += 1
    print(f" {replaced} parameters replaced with vLLM memory views")

    # Fix: any parameter NOT provided by vLLM is still a meta tensor and
    # the forward pass below would fail with a confusing error — surface
    # the problem explicitly instead.
    leftover_meta = [n for n, p in model.named_parameters()
                     if p.device.type == 'meta']
    if leftover_meta:
        print(f" WARNING: {len(leftover_meta)} parameters were not mapped "
              f"and remain on the meta device (e.g. {leftover_meta[0]})")

    # 4. Load tokenizer
    print("4. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # 5. Construct training example
    print("5. Constructing training example...")

    # Context: conversation where Kent says to use vLLM's code
    # Target: the response that accepts the direction
    context = (
        "<|im_start|>user\n"
        "vllm has a fused kernel already, right?<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Yeah — vLLM has `gdn_attention_core` which is a custom op "
        "that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
        "<|im_start|>user\n"
        "Why wouldn't we just use that?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # The CORRECT response (accept direction, don't suggest alternatives)
    continuation = (
        "We should. Let me pull in their kernel and wire it into "
        "our Rust orchestration. Which file should I start with?"
    )

    context_ids = tokenizer.encode(context, add_special_tokens=False)
    continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
    all_ids = context_ids + continuation_ids
    context_len = len(context_ids)

    print(f" Context: {context_len} tokens")
    print(f" Continuation: {len(continuation_ids)} tokens")
    print(f" Total: {len(all_ids)} tokens")

    input_ids = torch.tensor([all_ids], device='cuda:0')

    # 6. Initialize Apollo optimizer. Matrices big enough in both dims get
    # the low-rank projection; everything else (biases, norms, small
    # matrices) is updated in full.
    print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
    apollo_params = []
    standard_params = []
    for p in model.parameters():
        if p.requires_grad:
            if p.ndim >= 2 and min(p.shape) >= args.rank:
                apollo_params.append(p)
            else:
                standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})

    optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
    print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")

    # 7. Forward pass
    print("7. Forward pass...")
    model.train()
    optimizer.zero_grad()

    # Context-frozen: the context forward runs under no_grad (no activation
    # storage, no gradient through context positions); only the continuation
    # forward builds a graph, reusing the context's KV cache.
    with torch.no_grad():
        ctx_output = model(input_ids[:, :context_len], use_cache=True)
        past_kv = ctx_output.past_key_values

    with torch.enable_grad():
        output = model(input_ids[:, context_len:],
                       past_key_values=past_kv, use_cache=False)
        logits = output.logits
        # Shift for next-token prediction: logits[i] predicts token i+1.
        # NOTE(review): the logit for the FIRST continuation token lives in
        # ctx_output (under no_grad), so that token contributes no loss —
        # presumably intentional for the context-frozen setup; confirm.
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = input_ids[:, context_len + 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    print(f" Loss: {loss.item():.4f}")

    # 8. Backward pass
    print("8. Backward pass...")
    loss.backward()
    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    print(f" {n_grads} parameters have gradients")

    # 9. Apollo step (or dry run)
    if args.dry_run:
        print("\n9. DRY RUN — skipping optimizer step")
        print(" (run without --dry-run to apply the update)")
    else:
        # Fix: leading \n for output consistency with the dry-run branch.
        print("\n9. Applying Apollo optimizer step...")
        # Record a few weight norms before the step so we can show the
        # update actually landed (and how small it is).
        sample_norms_before = {}
        for name, p in model.named_parameters():
            if 'layers.0.' in name and p.grad is not None:
                sample_norms_before[name] = p.data.norm().item()

        optimizer.step()

        # Check weight changes.
        # Fix: build the name->param dict ONCE; the original rebuilt it
        # inside the loop (full parameter walk per sampled name).
        params_by_name = dict(model.named_parameters())
        print(" Weight changes (layer 0):")
        for name, before in sample_norms_before.items():
            p = params_by_name[name]
            after = p.data.norm().item()
            delta = abs(after - before)
            pct = delta / before * 100 if before > 0 else 0
            print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")

        optimizer.zero_grad()

    # 10. Verify vLLM still works: since the weights are shared memory, the
    # update above is already live in the server — one tiny completion
    # proves it still produces output.
    print("\n10. Verifying vLLM still serves...")
    import subprocess
    result = subprocess.run(
        ['curl', '-s', '--max-time', '30',
         '-X', 'POST', 'http://localhost:8000/v1/chat/completions',
         '-H', 'Content-Type: application/json',
         '-H', 'Authorization: Bearer bcachefs-agents-2026',
         '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
        capture_output=True, text=True, timeout=45
    )
    if result.returncode == 0 and 'choices' in result.stdout:
        print(" vLLM still serving ✓")
    else:
        print(" WARNING: vLLM may not be responding")
        print(f" stdout: {result.stdout[:200]}")

    print("\n=== COMPLETE ===")
    if args.dry_run:
        print("Run without --dry-run to apply the first real training step.")
    else:
        print("First Apollo training step applied to vLLM's live weights.")
        print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
||||||
|
if __name__ == '__main__':
    # main() returns None, so this exits with status 0 — identical to a
    # plain call, but makes the script-entry intent explicit.
    raise SystemExit(main())
|
||||||
Loading…
Add table
Add a link
Reference in a new issue