consciousness/training/first_training_step.py

216 lines
7.7 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""First real Apollo training step — ready for Kent to run.
This script:
1. Imports vLLM's live weights via CUDA IPC
2. Constructs HF model with shared memory views
3. Runs ONE forward+backward on a real training example
4. Applies ONE Apollo optimizer step
5. Verifies vLLM still works after the update
The training example is from March 30: Kent said "use vLLM's code"
and the model should have accepted instead of suggesting alternatives.
Usage:
source ~/training-env/bin/activate
python3 first_training_step.py [--dry-run]
"""
import argparse
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
from apollo_mini import Apollo
def _import_vllm_params(handles_path):
    """Rebuild vLLM's parameter tensors from a saved CUDA IPC handle file.

    Each entry maps a parameter name to ``{'handle': (rebuild, args)}``;
    ``rebuild(*args)`` returns the tensor for that parameter.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, so loading
    the file can execute code. Pickle is needed here because the handles
    carry rebuild callables, but the file must only be writable by trusted
    users -- the default path lives in world-writable /tmp.
    """
    handles = torch.load(handles_path, weights_only=False)
    params = {}
    for name, info in handles.items():
        rebuild, rebuild_args = info['handle']
        params[name] = rebuild(*rebuild_args)
    return params


def _build_shared_model(model_path, hf_params):
    """Instantiate the HF model on the meta device and swap in shared weights.

    Returns ``(model, replaced)`` where ``replaced`` counts the parameters
    taken from ``hf_params``.

    Raises:
        SystemExit: if any parameter is still on the meta device afterwards
            (no counterpart in ``hf_params``, or a weight tie that the
            replacement broke). The forward pass would otherwise fail later
            with an opaque device error.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    replaced = 0
    # Snapshot the names first: setattr below mutates the module tree.
    for name in [n for n, _ in model.named_parameters()]:
        if name not in hf_params:
            continue
        module_path, _, attr = name.rpartition('.')
        setattr(model.get_submodule(module_path), attr,
                nn.Parameter(hf_params[name], requires_grad=True))
        replaced += 1

    missing = [n for n, p in model.named_parameters() if p.is_meta]
    if missing:
        raise SystemExit(
            f"{len(missing)} parameters have no vLLM counterpart and are "
            f"still on the meta device, e.g. {missing[:5]}")
    meta_buffers = [n for n, b in model.named_buffers() if b.is_meta]
    if meta_buffers:
        # Buffers (e.g. precomputed tables) are not shared from vLLM; any that
        # the forward pass actually reads will fail there.
        print(f" WARNING: {len(meta_buffers)} buffers still on meta device, "
              f"e.g. {meta_buffers[:3]}")
    return model, replaced


def _split_param_groups(model, rank):
    """Partition trainable parameters into ``(projected, standard)`` lists.

    Parameters with at least two dimensions whose smallest dimension is
    ``>= rank`` are reported as "projected"; everything else is "standard".
    """
    projected, standard = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2 and min(p.shape) >= rank:
            projected.append(p)
        else:
            standard.append(p)
    return projected, standard


def _continuation_loss(model, input_ids, context_len):
    """Cross-entropy over the continuation tokens only, with the context frozen.

    The context minus its last token runs under ``no_grad`` to build a cache.
    The last context token plus the continuation then run with gradients.
    Starting the grad pass one token early is what puts the prediction of
    the FIRST continuation token into the loss. Splitting exactly at
    ``context_len`` would leave that prediction in the no-grad pass, where it
    is never trained.

    Requires ``context_len >= 2`` and at least one continuation token.
    """
    prefix_len = context_len - 1
    with torch.no_grad():
        prefix_out = model(input_ids[:, :prefix_len], use_cache=True)
    # NOTE(review): confirm the Qwen3.5 HF implementation consumes the
    # prefilled cache in its linear-attention (GDN) layers for a multi-token
    # step; if it only does so for single-token decode, those layers would
    # process the continuation without its context.
    with torch.enable_grad():
        output = model(input_ids[:, prefix_len:],
                       past_key_values=prefix_out.past_key_values,
                       use_cache=False)
        # logits[:, j] predicts input_ids[:, prefix_len + j + 1]; the final
        # position predicts past the end of the sequence and is dropped.
        # Upcast so the loss is not computed in reduced precision.
        pred_logits = output.logits[:, :-1].float()
        labels = input_ids[:, context_len:]
        return F.cross_entropy(
            pred_logits.reshape(-1, pred_logits.size(-1)),
            labels.reshape(-1),
        )


def _step_and_report(model, optimizer):
    """Apply one optimizer step and print how layer-0 weight norms moved.

    Comparing norms avoids cloning the weights, but it is only a proxy: an
    update that changes a weight without changing its norm reads as ~0%.
    """
    watched = [(name, p, p.data.norm().item())
               for name, p in model.named_parameters()
               if 'layers.0.' in name and p.grad is not None]
    optimizer.step()
    print(" Weight changes (layer 0):")
    for name, p, before in watched:
        after = p.data.norm().item()
        pct = abs(after - before) / before * 100 if before > 0 else 0
        print(f" {name}: {before:.6f} -> {after:.6f} ({pct:.4f}%)")
    optimizer.zero_grad()


def _check_vllm_serving(url, served_model, api_key, timeout=30):
    """POST a 4-token chat completion to vLLM.

    Returns ``(ok, detail)``. ``ok`` is True when the response body mentions
    ``choices``. ``detail`` is the response body, or the error text if the
    request failed. Network errors are reported rather than raised, because
    this runs after the weights may already have been modified.
    """
    import http.client
    import json
    import urllib.request

    payload = json.dumps({
        "model": served_model,
        "messages": [{"role": "user", "content": "Hi"}],
        "max_tokens": 4,
    }).encode('utf-8')
    request = urllib.request.Request(url, data=payload, method='POST', headers={
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
    })
    try:
        with urllib.request.urlopen(request, timeout=timeout) as resp:
            body = resp.read().decode('utf-8', errors='replace')
    except (OSError, http.client.HTTPException) as exc:
        # URLError/HTTPError/timeouts are all OSError subclasses.
        return False, str(exc)
    return 'choices' in body, body


def main():
    """Run one context-frozen Apollo step directly on vLLM's live weights.

    Console output is numbered by stage. ``--dry-run`` stops after the
    backward pass, without touching the weights. The vLLM health check runs
    in both modes.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--dry-run', action='store_true',
                        help="Run forward+backward but don't apply the optimizer step")
    parser.add_argument('--lr', type=float, default=1e-5,
                        help="Learning rate (default: 1e-5 = conservative)")
    parser.add_argument('--rank', type=int, default=256)
    parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
    parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
    parser.add_argument('--vllm-url',
                        default='http://localhost:8000/v1/chat/completions',
                        help="vLLM chat-completions endpoint for the health check")
    parser.add_argument('--served-model', default='Qwen/Qwen3.5-27B',
                        help="Model name vLLM serves under (health-check request)")
    args = parser.parse_args()

    print("=== First Apollo Training Step ===\n")

    print("1. Importing vLLM weights via CUDA IPC...")
    vllm_params = _import_vllm_params(args.handles)
    print(f" {len(vllm_params)} parameters imported")

    print("2. Mapping to HF layout (zero-copy views)...")
    hf_params = vllm_to_hf_views(vllm_params)

    print("3. Creating HF model with shared weights...")
    model, replaced = _build_shared_model(args.model_path, hf_params)
    print(f" {replaced} parameters replaced with vLLM memory views")

    print("4. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print("5. Constructing training example...")
    # Context: conversation where Kent says to use vLLM's code.
    # Target: the response that accepts the direction.
    context = (
        "<|im_start|>user\n"
        "vllm has a fused kernel already, right?<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Yeah — vLLM has `gdn_attention_core` which is a custom op "
        "that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
        "<|im_start|>user\n"
        "Why wouldn't we just use that?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    # The CORRECT response (accept direction, don't suggest alternatives).
    continuation = (
        "We should. Let me pull in their kernel and wire it into "
        "our Rust orchestration. Which file should I start with?"
    )
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
    all_ids = context_ids + continuation_ids
    context_len = len(context_ids)
    if context_len < 2 or not continuation_ids:
        raise SystemExit("Need at least 2 context tokens and 1 continuation token")
    print(f" Context: {context_len} tokens")
    print(f" Continuation: {len(continuation_ids)} tokens")
    print(f" Total: {len(all_ids)} tokens")
    # Put the inputs wherever the shared embedding weights live.
    device = model.get_input_embeddings().weight.device
    input_ids = torch.tensor([all_ids], device=device)

    print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
    apollo_params, standard_params = _split_param_groups(model, args.rank)
    groups = [{'params': ps} for ps in (apollo_params, standard_params) if ps]
    optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
    print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")

    print("7. Forward pass...")
    model.train()
    optimizer.zero_grad()
    loss = _continuation_loss(model, input_ids, context_len)
    print(f" Loss: {loss.item():.4f}")

    print("8. Backward pass...")
    loss.backward()
    n_grads = sum(p.grad is not None for p in model.parameters())
    print(f" {n_grads} parameters have gradients")

    if args.dry_run:
        print("\n9. DRY RUN — skipping optimizer step")
        print(" (run without --dry-run to apply the update)")
    else:
        print("9. Applying Apollo optimizer step...")
        _step_and_report(model, optimizer)

    print("\n10. Verifying vLLM still serves...")
    # NOTE(review): the default token is a credential committed to source;
    # prefer setting VLLM_API_KEY in the environment.
    api_key = os.environ.get('VLLM_API_KEY', 'bcachefs-agents-2026')
    ok, detail = _check_vllm_serving(args.vllm_url, args.served_model, api_key)
    if ok:
        print(" vLLM still serving ✓")
    else:
        print(" WARNING: vLLM may not be responding")
        print(f" stdout: {detail[:200]}")

    print("\n=== COMPLETE ===")
    if args.dry_run:
        print("Run without --dry-run to apply the first real training step.")
    else:
        print("First Apollo training step applied to vLLM's live weights.")
        print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")


if __name__ == '__main__':
    main()