apollo-mini training system: initial implementation

Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC
shared weight memory between vLLM and the training process:

- apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality)
- apollo_worker.py: HTTP daemon coordinating training with vLLM
- weight_mapping.py: vLLM merged → HF separate layout (zero-copy views)
- training_example.py: tokenization with chat template
- export_weights.py: CUDA IPC handle export from vLLM
- train.py: standalone training script (alternative to daemon)
- DESIGN.md: architecture and protocol documentation

Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200).
Apollo-Mini rank-1 projection + scaling + in-place update confirmed.

Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
This commit is contained in:
ProofOfConcept 2026-03-30 22:02:37 -04:00
parent 13453606ae
commit c5d7d8cb5d
7 changed files with 1484 additions and 0 deletions

269
training/train.py Normal file
View file

@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""Nightly training process for Apollo-Mini fine-tuning.
Imports vLLM's model weights via CUDA IPC, runs context-frozen
training on flagged conversation segments, saves updated checkpoint.
Usage:
python3 train.py \
--weights /tmp/vllm_weight_handles.pt \
--examples training-examples.jsonl \
--checkpoint-dir checkpoints/ \
--lr 1e-5
"""
import argparse
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path
import torch
from safetensors.torch import save_file
from apollo_mini import ApolloMini
def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
    """Import weight tensors from CUDA IPC handles.

    The handle file maps parameter name -> {'handle': (rebuild_fn, args)};
    each tensor is reconstructed by calling rebuild_fn(*args), which maps
    the exporter's GPU memory into this process (zero-copy — presumably the
    tuples come from export_weights.py; verify against that producer).

    Args:
        handle_path: path to the torch.save()'d handle file.

    Returns:
        dict mapping parameter name -> tensor backed by the shared memory.
    """
    # SECURITY: weights_only=False unpickles and later *calls* arbitrary
    # callables from the file. Only load handle files produced by a trusted
    # local export step — never untrusted input.
    handles = torch.load(handle_path, weights_only=False)
    params: dict[str, torch.Tensor] = {}
    for name, info in handles.items():
        rebuild_fn, rebuild_args = info['handle']
        params[name] = rebuild_fn(*rebuild_args)
    return params
def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
    """Split parameters into Apollo-Mini and standard optimizer groups.

    Apollo-Mini takes matrix-like tensors: ndim >= 2 with every dimension
    at least 2. Everything else (norms, biases, min-dim-1 matrices) goes to
    the standard group.

    NOTE(review): the original comment said 3D conv1d weights use standard
    Adam, but a 3D tensor whose smallest dim is >= 2 actually lands in the
    apollo group under this condition — confirm intent.

    Side effect: enables requires_grad on every tensor in *params*.
    """
    # Turn on autograd for all shared tensors up front.
    for tensor in params.values():
        tensor.requires_grad_(True)

    def _wants_apollo(t: torch.Tensor) -> bool:
        # Rank-1 projection needs a genuinely 2D-or-bigger matrix.
        return t.ndim >= 2 and min(t.shape) >= 2

    apollo_params = [t for t in params.values() if _wants_apollo(t)]
    standard_params = [t for t in params.values() if not _wants_apollo(t)]

    # Only emit non-empty groups, apollo first.
    groups = [
        {'params': members, 'name': label}
        for label, members in (('apollo', apollo_params),
                               ('standard', standard_params))
        if members
    ]

    n_apollo = sum(t.nelement() for t in apollo_params)
    n_standard = sum(t.nelement() for t in standard_params)
    print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")
    return groups
def forward_pass(params, input_ids, context_len, device):
    """Run context-frozen forward pass.

    Args:
        params: dict of name -> tensor (shared with vLLM)
        input_ids: full sequence [1, seq_len]
        context_len: number of context tokens (no gradient)
        device: CUDA device

    Returns:
        logits for decision tokens, target ids for loss

    Raises:
        NotImplementedError: always — this is a placeholder until a forward
            model matching vLLM's merged weight layout exists.
    """
    # TODO: Build proper forward model matching vLLM's weight layout.
    # The real implementation must replicate vLLM's model architecture
    # (merged projections, GDN recurrence, full attention, MLP) using the
    # shared weights, with autograd enabled.
    message = (
        "Forward model not yet implemented. "
        "Need to build a model that matches vLLM's merged weight layout "
        "(MergedColumnParallelLinear for qkvz/ba/gate_up, "
        "RowParallelLinear for out_proj/down) and computes the same "
        "forward pass with autograd enabled."
    )
    raise NotImplementedError(message)
def save_checkpoint(params: dict[str, torch.Tensor],
                    checkpoint_dir: str,
                    config_path: str | None = None):
    """Save model checkpoint in HuggingFace safetensors format.

    Saves all weights into <checkpoint_dir>/<YYYY-MM-DD>/model.safetensors,
    optionally copies HF config/tokenizer files alongside, and repoints the
    'latest' symlink at the new dated directory.

    Args:
        params: name -> tensor mapping; data is copied to CPU before saving.
        checkpoint_dir: root directory for dated checkpoint subdirectories.
        config_path: optional directory holding HF config files to copy.

    Returns:
        str path of the dated checkpoint directory.
    """
    date_str = datetime.now().strftime("%Y-%m-%d")
    out_dir = Path(checkpoint_dir) / date_str
    out_dir.mkdir(parents=True, exist_ok=True)
    # Save all weights in a single safetensors file for now.
    # TODO: split across shards matching HF model index for large models.
    # .contiguous().cpu() detaches the saved data from the shared GPU memory.
    tensors = {name: p.data.contiguous().cpu() for name, p in params.items()}
    save_path = out_dir / "model.safetensors"
    save_file(tensors, str(save_path))
    print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")
    # Copy config files if provided (best-effort: missing files are skipped).
    if config_path:
        import shutil
        config_dir = Path(config_path)
        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                  'special_tokens_map.json', 'generation_config.json']:
            src = config_dir / f
            if src.exists():
                shutil.copy2(src, out_dir / f)
    # Update 'latest' symlink. Fix: also remove a stale *regular file* at
    # this path — the original only unlinked symlinks, so symlink_to() would
    # raise FileExistsError if a plain file was ever left there.
    latest = Path(checkpoint_dir) / "latest"
    if latest.is_symlink() or latest.is_file():
        latest.unlink()
    latest.symlink_to(date_str)
    print(f"Updated {latest} -> {date_str}")
    return str(out_dir)
def train_step(params, example, optimizer, device, log_entries):
    """Run one training step on a single example.

    Args:
        params: dict of name -> tensor
        example: dict with 'input_ids', 'context_len', 'target_ids'
        optimizer: ApolloMini instance
        device: CUDA device
        log_entries: list to append log dicts to

    Returns:
        loss value
    """
    optimizer.zero_grad()
    input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)

    # Forward: context tokens frozen, decision tokens carry gradient.
    logits, targets = forward_pass(params, input_ids,
                                   example['context_len'], device)

    # Cross-entropy over the decision tokens.
    flat_logits = logits.view(-1, logits.shape[-1])
    loss = torch.nn.functional.cross_entropy(flat_logits, targets.view(-1))
    loss.backward()

    # Global gradient norm: sqrt of the summed squared per-tensor norms,
    # captured before the optimizer step consumes the gradients.
    grad_norm = sum(
        p.grad.norm().item() ** 2
        for p in params.values()
        if p.grad is not None
    ) ** 0.5

    optimizer.step()

    log_entries.append({
        'example_id': example.get('id', 'unknown'),
        'loss': loss.item(),
        'grad_norm': grad_norm,
        'timestamp': datetime.now().isoformat(),
    })
    return loss.item()
def main():
    """CLI entry point: import shared weights, train, save checkpoint.

    Steps:
      1. Map vLLM's weights into this process via CUDA IPC handles.
      2. Partition parameters into Apollo-Mini vs. standard groups.
      3. One optimizer step per example (single pass, batch size 1).
      4. Save a dated safetensors checkpoint plus a JSONL training log.
    """
    parser = argparse.ArgumentParser(description="Apollo-Mini training")
    parser.add_argument("--weights", required=True,
                        help="Path to exported weight IPC handles")
    parser.add_argument("--examples", required=True,
                        help="Path to training examples JSONL")
    parser.add_argument("--checkpoint-dir", default="checkpoints",
                        help="Directory for saving checkpoints")
    parser.add_argument("--config-path", default=None,
                        help="Path to model config files (for checkpoint)")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=10,
                        help="Learning rate warmup steps")
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--dry-run", action="store_true",
                        help="Load weights and validate, don't train")
    args = parser.parse_args()

    # Fixed: these were f-strings with no placeholders.
    print("Apollo-Mini Training")
    print(f" weights: {args.weights}")
    print(f" examples: {args.examples}")
    print(f" lr: {args.lr}")
    print()

    # Import weights shared with the vLLM process.
    print("Importing weights via CUDA IPC...")
    params = import_weights(args.weights)
    print(f" {len(params)} parameters imported")

    # Group parameters and build the optimizer.
    param_groups = make_param_groups(params)
    optimizer = ApolloMini(param_groups, lr=args.lr,
                           weight_decay=args.weight_decay,
                           warmup_steps=args.warmup_steps)
    print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    if args.dry_run:
        print("\nDry run — weights imported and validated successfully.")
        return

    # Load training examples. Fix: skip blank lines so a trailing newline
    # (or padding lines) in the JSONL no longer crashes json.loads.
    with open(args.examples) as f:
        examples = [json.loads(line) for line in f if line.strip()]
    print(f" {len(examples)} training examples")

    # Training loop: single pass, one step per example.
    log_entries = []
    print("\nTraining...")
    t0 = time.time()
    for i, example in enumerate(examples):
        loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
        print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")
    elapsed = time.time() - t0
    print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
    print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    # Save checkpoint (creates the dated directory the log goes into).
    print("\nSaving checkpoint...")
    save_checkpoint(params, args.checkpoint_dir, args.config_path)

    # Save per-example training log next to the checkpoint.
    date_str = datetime.now().strftime("%Y-%m-%d")
    log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl"
    with open(log_path, 'w') as f:
        for entry in log_entries:
            f.write(json.dumps(entry) + '\n')
    print(f"Training log: {log_path}")


if __name__ == '__main__':
    main()