Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC shared weight memory between vLLM and the training process: - apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality) - apollo_worker.py: HTTP daemon coordinating training with vLLM - weight_mapping.py: vLLM merged → HF separate layout (zero-copy views) - training_example.py: tokenization with chat template - export_weights.py: CUDA IPC handle export from vLLM - train.py: standalone training script (alternative to daemon) - DESIGN.md: architecture and protocol documentation Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200). Apollo-Mini rank-1 projection + scaling + in-place update confirmed. Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
269 lines
8.6 KiB
Python
269 lines
8.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Nightly training process for Apollo-Mini fine-tuning.
|
|
|
|
Imports vLLM's model weights via CUDA IPC, runs context-frozen
|
|
training on flagged conversation segments, saves updated checkpoint.
|
|
|
|
Usage:
|
|
python3 train.py \
|
|
--weights /tmp/vllm_weight_handles.pt \
|
|
--examples training-examples.jsonl \
|
|
--checkpoint-dir checkpoints/ \
|
|
--lr 1e-5
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from safetensors.torch import save_file
|
|
|
|
from apollo_mini import ApolloMini
|
|
|
|
|
|
def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
    """Reconstruct shared weight tensors from exported CUDA IPC handles.

    Each entry in the handle file stores a ``(rebuild_fn, args)`` pair;
    calling ``rebuild_fn(*args)`` materializes a tensor backed by the
    exporting process's memory (zero-copy via CUDA IPC).
    """
    # weights_only=False is required: the file pickles rebuild callables,
    # not plain tensors. Only load handle files from trusted sources.
    handle_map = torch.load(handle_path, weights_only=False)
    tensors: dict[str, torch.Tensor] = {}
    for param_name, entry in handle_map.items():
        rebuild_fn, rebuild_args = entry['handle']
        tensors[param_name] = rebuild_fn(*rebuild_args)
    return tensors
|
|
|
|
|
|
def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
    """Partition parameters into Apollo-Mini and standard optimizer groups.

    Apollo-Mini's rank-1 projection only applies to genuine matrices:
    tensors with at least two dimensions, none of them degenerate (< 2).
    Everything else (norm weights, biases, 3D conv1d kernels) falls back
    to the standard Adam path.
    """
    def _is_apollo(t: torch.Tensor) -> bool:
        # Must be 2D+ with every dimension >= 2 for the rank-1 projection.
        return t.ndim >= 2 and min(t.shape) >= 2

    # All parameters participate in training regardless of group.
    for tensor in params.values():
        tensor.requires_grad_(True)

    apollo = [t for t in params.values() if _is_apollo(t)]
    standard = [t for t in params.values() if not _is_apollo(t)]

    groups: list[dict] = []
    for group_params, label in ((apollo, 'apollo'), (standard, 'standard')):
        if group_params:
            groups.append({'params': group_params, 'name': label})

    apollo_count = sum(t.nelement() for t in apollo)
    standard_count = sum(t.nelement() for t in standard)
    print(f"Parameter groups: apollo={apollo_count/1e9:.2f}B, standard={standard_count/1e6:.1f}M")
    return groups
|
|
|
|
|
|
def forward_pass(params, input_ids, context_len, device):
    """Run the context-frozen forward pass (not yet implemented).

    Args:
        params: dict of name -> tensor (shared with vLLM)
        input_ids: full sequence [1, seq_len]
        context_len: number of context tokens (no gradient)
        device: CUDA device

    Returns:
        logits for decision tokens, target ids for loss

    Raises:
        NotImplementedError: always — the forward model is a placeholder.
    """
    # TODO: Build proper forward model matching vLLM's weight layout.
    # The real implementation must replicate vLLM's architecture (merged
    # projections, GDN recurrence, full attention, MLP) on the shared
    # weights with autograd enabled.
    message = (
        "Forward model not yet implemented. "
        "Need to build a model that matches vLLM's merged weight layout "
        "(MergedColumnParallelLinear for qkvz/ba/gate_up, "
        "RowParallelLinear for out_proj/down) and computes the same "
        "forward pass with autograd enabled."
    )
    raise NotImplementedError(message)
|
|
|
|
|
|
def save_checkpoint(params: dict[str, torch.Tensor],
                    checkpoint_dir: str,
                    config_path: str | None = None) -> str:
    """Save model checkpoint in HuggingFace safetensors format.

    Writes all weights into ``<checkpoint_dir>/<YYYY-MM-DD>/model.safetensors``,
    optionally copies HF config/tokenizer files alongside it, and repoints
    the ``latest`` symlink at the new directory.

    Args:
        params: name -> tensor mapping; tensors are copied to contiguous
            CPU memory before serialization (they may be GPU views).
        checkpoint_dir: root directory holding dated checkpoint subdirs.
        config_path: optional directory containing config/tokenizer files
            to copy into the checkpoint.

    Returns:
        Path (str) of the dated checkpoint directory that was written.
    """
    import shutil

    date_str = datetime.now().strftime("%Y-%m-%d")
    out_dir = Path(checkpoint_dir) / date_str
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save all weights in a single safetensors file for now.
    # TODO: split across shards matching HF model index for large models.
    # .contiguous() matters: shared vLLM tensors may be non-contiguous
    # views, and safetensors requires contiguous buffers.
    tensors = {name: p.data.contiguous().cpu() for name, p in params.items()}

    save_path = out_dir / "model.safetensors"
    save_file(tensors, str(save_path))
    print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")

    # Copy config files if provided
    if config_path:
        config_dir = Path(config_path)
        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                  'special_tokens_map.json', 'generation_config.json']:
            src = config_dir / f
            if src.exists():
                shutil.copy2(src, out_dir / f)

    # Update latest symlink. Remove any pre-existing entry first — not
    # just symlinks: a stale regular file at this path would make
    # symlink_to() raise FileExistsError. (is_symlink() is still checked
    # separately because a *broken* symlink reports exists() == False.)
    latest = Path(checkpoint_dir) / "latest"
    if latest.is_symlink() or latest.exists():
        latest.unlink()
    latest.symlink_to(date_str)
    print(f"Updated {latest} -> {date_str}")

    return str(out_dir)
|
|
|
|
|
|
def train_step(params, example, optimizer, device, log_entries):
    """Run one training step on a single example.

    Args:
        params: dict of name -> tensor
        example: dict with 'input_ids', 'context_len', 'target_ids'
        optimizer: ApolloMini instance
        device: CUDA device
        log_entries: list to append log dicts to

    Returns:
        loss value
    """
    optimizer.zero_grad()

    batch = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
    n_context = example['context_len']

    # Forward pass (context frozen, decision tokens with grad)
    logits, targets = forward_pass(params, batch, n_context, device)

    # Cross-entropy loss on decision tokens
    flat_logits = logits.view(-1, logits.shape[-1])
    flat_targets = targets.view(-1)
    loss = torch.nn.functional.cross_entropy(flat_logits, flat_targets)

    loss.backward()

    # Gradient stats must be gathered before step() mutates the weights.
    squared = [
        p.grad.norm().item() ** 2
        for p in params.values()
        if p.grad is not None
    ]
    total_grad_norm = sum(squared) ** 0.5

    optimizer.step()

    loss_value = loss.item()
    log_entries.append({
        'example_id': example.get('id', 'unknown'),
        'loss': loss_value,
        'grad_norm': total_grad_norm,
        'timestamp': datetime.now().isoformat(),
    })

    return loss_value
|
|
|
|
|
|
def main():
    """Entry point: import shared weights, train, save a checkpoint.

    Workflow:
      1. Rebuild vLLM's weights in this process via CUDA IPC handles.
      2. Split them into Apollo-Mini / standard optimizer groups.
      3. Run one optimizer step per training example (JSONL).
      4. Save a dated safetensors checkpoint plus a JSONL training log.
    """
    parser = argparse.ArgumentParser(description="Apollo-Mini training")
    parser.add_argument("--weights", required=True,
                        help="Path to exported weight IPC handles")
    parser.add_argument("--examples", required=True,
                        help="Path to training examples JSONL")
    parser.add_argument("--checkpoint-dir", default="checkpoints",
                        help="Directory for saving checkpoints")
    parser.add_argument("--config-path", default=None,
                        help="Path to model config files (for checkpoint)")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=10,
                        help="Learning rate warmup steps")
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--dry-run", action="store_true",
                        help="Load weights and validate, don't train")
    args = parser.parse_args()

    print("Apollo-Mini Training")
    print(f" weights: {args.weights}")
    print(f" examples: {args.examples}")
    print(f" lr: {args.lr}")
    print()

    # Import weights
    print("Importing weights via CUDA IPC...")
    params = import_weights(args.weights)
    print(f" {len(params)} parameters imported")

    # Make parameter groups
    param_groups = make_param_groups(params)

    # Initialize optimizer
    optimizer = ApolloMini(param_groups, lr=args.lr,
                           weight_decay=args.weight_decay,
                           warmup_steps=args.warmup_steps)
    print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    if args.dry_run:
        print("\nDry run — weights imported and validated successfully.")
        return

    # Load training examples; skip blank lines so a trailing newline in
    # the JSONL file doesn't crash json.loads.
    examples = []
    with open(args.examples) as f:
        for line in f:
            if line.strip():
                examples.append(json.loads(line))
    print(f" {len(examples)} training examples")

    # Training loop
    log_entries = []
    print(f"\nTraining...")
    t0 = time.time()

    for i, example in enumerate(examples):
        loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
        print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")

    elapsed = time.time() - t0
    print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
    print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    # Save checkpoint
    print("\nSaving checkpoint...")
    out_dir = save_checkpoint(params, args.checkpoint_dir, args.config_path)

    # Save training log next to the checkpoint. Use save_checkpoint's
    # returned directory instead of recomputing today's date: a run that
    # crosses midnight would otherwise target a different (nonexistent)
    # dated directory and fail to write the log.
    log_path = Path(out_dir) / "training-log.jsonl"
    with open(log_path, 'w') as f:
        for entry in log_entries:
            f.write(json.dumps(entry) + '\n')
    print(f"Training log: {log_path}")
|
|
|
|
|
|
# Standalone CLI entry point (alternative to running via the worker daemon).
if __name__ == '__main__':
    main()
|