apollo-mini training system: initial implementation
Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC shared weight memory between vLLM and the training process: - apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality) - apollo_worker.py: HTTP daemon coordinating training with vLLM - weight_mapping.py: vLLM merged → HF separate layout (zero-copy views) - training_example.py: tokenization with chat template - export_weights.py: CUDA IPC handle export from vLLM - train.py: standalone training script (alternative to daemon) - DESIGN.md: architecture and protocol documentation Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200). Apollo-Mini rank-1 projection + scaling + in-place update confirmed. Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
This commit is contained in:
parent
13453606ae
commit
c5d7d8cb5d
7 changed files with 1484 additions and 0 deletions
453
training/apollo_worker.py
Executable file
453
training/apollo_worker.py
Executable file
|
|
@ -0,0 +1,453 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apollo Mini Training Daemon
|
||||
|
||||
This daemon:
|
||||
1. Listens over HTTPS for training requests from poc-agent
|
||||
2. Pauses vLLM inference
|
||||
3. Runs APOLLO-Mini training with torch.enable_grad()
|
||||
4. Saves checkpoints and training metadata
|
||||
5. Resumes vLLM inference
|
||||
|
||||
Communication protocol:
|
||||
- POST /train: Start a training job
|
||||
- GET /status/{job_id}: Check training status
|
||||
- GET /checkpoints: List available checkpoints
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from aiohttp import web
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('apollo_worker')
|
||||
|
||||
class TrainingStatus(Enum):
    """Lifecycle states of a training job, in rough chronological order."""

    PENDING = "pending"
    PAUSING_VLLM = "pausing_vllm"
    TRAINING = "training"
    SAVING_CHECKPOINT = "saving_checkpoint"
    RESUMING_VLLM = "resuming_vllm"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
@dataclass
class TrainingJob:
    """Mutable in-memory record of one training job's lifecycle and results."""

    job_id: str
    status: TrainingStatus
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    checkpoint_path: Optional[str] = None
    training_samples: int = 0
    loss_history: List[float] = field(default_factory=list)
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for JSON responses: enum -> value, datetimes -> ISO-8601."""
        def iso(ts: Optional[datetime]) -> Optional[str]:
            return ts.isoformat() if ts is not None else None

        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': iso(self.created_at),
            'started_at': iso(self.started_at),
            'completed_at': iso(self.completed_at),
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }
|
||||
|
||||
class ApolloWorker:
    """HTTP daemon coordinating APOLLO-Mini fine-tuning with a vLLM server.

    Per job: pause vLLM generation, attach to vLLM's GPU weights via CUDA IPC
    (zero-copy views), run Apollo-Mini training, write a safetensors
    checkpoint, resume generation.  Jobs execute as background asyncio tasks
    and are tracked only in memory, so all job state is lost on restart.
    """

    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        self.config = self._load_config(config_path)
        self.jobs: Dict[str, TrainingJob] = {}
        self.vllm_paused = False
        # BUG FIX: asyncio.create_task keeps only a weak reference to the
        # task; without a strong reference a training job can be garbage
        # collected mid-run.  Hold references until each task finishes.
        self._background_tasks: set = set()
        self.app = web.Application()
        self._setup_routes()

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load JSON config from *config_path*, overlaying it on defaults.

        Also ensures the checkpoint directory exists.  Missing config file is
        not an error: the defaults below are used as-is.
        """
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,
            'learning_rate': 1e-5,
            'batch_size': 1,
        }

        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                user_config = json.load(f)
            default_config.update(user_config)

        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Register HTTP routes on the aiohttp application."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check: report the pause flag and count of in-flight jobs."""
        busy = {TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM,
                TrainingStatus.RESUMING_VLLM}
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            'active_jobs': sum(1 for j in self.jobs.values() if j.status in busy),
        })

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Accept a training request and start the job in the background.

        Returns 202-style JSON immediately ({'job_id', 'status': 'accepted'});
        progress is polled via GET /status/{job_id}.
        """
        try:
            data = await request.json()

            if 'training_data' not in data:
                return web.json_response(
                    {'error': 'Missing training_data field'},
                    status=400
                )

            # BUG FIX: the original validated data['training_data'] but the
            # trainer then read 'samples' from the request ROOT, crashing the
            # background task when samples were nested under training_data.
            # Accept both layouts and reject empty sample lists up front.
            td = data.get('training_data')
            payload = td if isinstance(td, dict) and 'samples' in td else data
            if not isinstance(payload.get('samples'), list) or not payload['samples']:
                return web.json_response(
                    {'error': 'training_data must include a non-empty samples list'},
                    status=400
                )

            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job

            # Run training in the background; keep a strong task reference
            # (see __init__) and drop it automatically on completion.
            task = asyncio.create_task(self.execute_training(job, payload))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })

        except Exception as e:
            logger.error(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )

    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Return the serialized state of one job, or 404 if unknown."""
        job_id = request.match_info['job_id']

        job = self.jobs.get(job_id)
        if job is None:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )
        return web.json_response(job.to_dict())

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List available checkpoints, newest first."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        checkpoints = []

        if checkpoint_dir.exists():
            # BUG FIX: save_checkpoint() writes dated directories containing
            # model.safetensors, but this endpoint only globbed
            # 'checkpoint_*.pt' and so never listed anything it saved.  List
            # both layouts; skip the 'latest' symlink to avoid duplicates.
            entries = [p for p in checkpoint_dir.iterdir()
                       if not p.is_symlink() and (p / 'model.safetensors').is_file()]
            entries.extend(checkpoint_dir.glob('checkpoint_*.pt'))
            for entry in sorted(entries, key=lambda p: p.stat().st_mtime,
                                reverse=True):
                st = entry.stat()  # stat once instead of three times
                payload = entry / 'model.safetensors' if entry.is_dir() else entry
                checkpoints.append({
                    'filename': entry.name,
                    'path': str(entry),
                    'created_at': datetime.fromtimestamp(st.st_mtime).isoformat(),
                    'size': payload.stat().st_size
                })

        return web.json_response({'checkpoints': checkpoints})

    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Run the full training pipeline for *job* (background task).

        Steps: pause vLLM -> attach to its weights -> Apollo-Mini training ->
        save checkpoint -> resume vLLM.  On any failure the job is marked
        FAILED and a best-effort attempt is made to resume vLLM.
        """
        try:
            logger.info(f"Starting training job {job.job_id}")
            job.started_at = datetime.now()

            # Step 1: stop generation so training has the GPU to itself.
            job.status = TrainingStatus.PAUSING_VLLM
            logger.info("Pausing vLLM...")
            await self.pause_vllm()
            self.vllm_paused = True

            # Step 2: build an HF model over vLLM's GPU memory (no copies).
            job.status = TrainingStatus.TRAINING
            logger.info("Loading model and preparing for training...")
            model = await self.load_model_for_training()

            # Step 3: Apollo-Mini training over the decision points.
            samples = training_data['samples']
            job.training_samples = len(samples)
            logger.info(f"Starting APOLLO-Mini training with {len(samples)} samples")

            job.loss_history = await self.run_apollo_training(
                model, samples, training_data.get('config', {}))

            # Step 4: persist the updated weights.
            job.status = TrainingStatus.SAVING_CHECKPOINT
            logger.info("Saving checkpoint...")
            job.checkpoint_path = await self.save_checkpoint(model, job)

            # Step 5: hand the GPU back to inference.
            job.status = TrainingStatus.RESUMING_VLLM
            logger.info("Resuming vLLM...")
            await self.resume_vllm()
            self.vllm_paused = False

            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.now()
            logger.info(f"Training job {job.job_id} completed successfully")

        except Exception as e:
            # logger.exception keeps the traceback, which str(e) discards.
            logger.exception(f"Training job {job.job_id} failed: {e}")
            job.status = TrainingStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.now()

            # Best effort: never leave inference paused after a failure.
            if self.vllm_paused:
                try:
                    await self.resume_vllm()
                    self.vllm_paused = False
                except Exception as resume_error:
                    logger.error(f"Failed to resume vLLM after training error: {resume_error}")

    async def pause_vllm(self):
        """Ask the vLLM server to pause generation (best effort).

        Failures are logged and swallowed deliberately, so a missing or busy
        control endpoint does not abort the whole job.
        """
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/pause_generation",
                    json={"mode": "keep", "clear_cache": False},
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Ask the vLLM server to resume generation (best effort, see pause_vllm)."""
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/resume_generation",
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")

    async def load_model_for_training(self) -> nn.Module:
        """Load an HF model whose weights point at vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC, then lets
        weight_mapping build HF-compatible views (narrowing merged weights
        into separate q/k/v etc.) and construct the HF model around them.
        No weight copying — all parameters share vLLM's GPU memory.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']

        # Import vLLM weights via CUDA IPC.
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects; the
        # handle file must only ever come from our own export_weights.py.
        logger.info(f"Importing vLLM weights from {handle_path}")
        handles = torch.load(handle_path, weights_only=False)
        vllm_params = {}
        for name, info in handles.items():
            # Each handle is a (rebuild_function, args) pair produced by the
            # CUDA IPC export; calling it materializes the shared tensor.
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")

        # Map vLLM merged layout -> HF separate layout (views, no copies).
        from weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")

        return model

    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, str]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points.

        Each sample is {'context': ..., 'continuation': ...}: the context is
        forwarded without gradients (reusing its KV cache) and the loss is
        next-token cross-entropy over the continuation only.

        Returns the per-step loss history (skipped samples contribute nothing).
        """
        from apollo_mini import ApolloMini
        from transformers import AutoTokenizer

        lr = config.get('learning_rate', self.config['learning_rate'])
        tokenizer = AutoTokenizer.from_pretrained(
            self.config['model_path'], trust_remote_code=True)

        # Parameter groups: Apollo's rank-1 projection needs a genuine 2D
        # matrix; biases, norms and degenerate shapes go to the standard path.
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if p.requires_grad:
                if p.ndim >= 2 and min(p.shape) >= 2:
                    apollo_params.append(p)
                else:
                    standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})

        optimizer = ApolloMini(groups, lr=lr)
        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

        # BUG FIX: device was hard-coded to 'cuda:0'; follow the model.
        device = next(model.parameters()).device

        loss_history = []

        for i, sample in enumerate(samples):
            context = sample.get('context', '')
            continuation = sample.get('continuation', '')

            # Tokenize
            ctx_ids = tokenizer.encode(context, add_special_tokens=True)
            cont_ids = tokenizer.encode(continuation, add_special_tokens=False)

            # BUG FIX: a continuation of fewer than two tokens yields empty
            # shifted tensors and a NaN loss, corrupting the weights on
            # backward.  Skip such samples instead.
            if len(cont_ids) < 2:
                logger.warning(f"Step {i+1}/{len(samples)}: skipping sample "
                               f"with {len(cont_ids)} continuation token(s)")
                continue

            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)
            input_ids = torch.tensor([all_ids], device=device)

            optimizer.zero_grad()

            # Context pass: no gradients, just populate the KV cache.
            with torch.no_grad():
                outputs = model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Continuation pass: gradients flow through the decision tokens.
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits  # [1, cont_len, vocab]

                # Position t predicts token t+1: drop the last logit and the
                # first continuation token to align predictions with labels.
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, context_len + 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

        # BUG FIX: loss_history[-1] raised IndexError when every sample was
        # skipped (or samples was empty).
        final = f"{loss_history[-1]:.4f}" if loss_history else "n/a"
        logger.info(f"Training done: {len(samples)} examples, "
                    f"final loss={final}")
        return loss_history

    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model weights + tokenizer configs as a dated HF checkpoint.

        Layout: {checkpoint_dir}/{YYYY-MM-DD}/model.safetensors plus copied
        config/tokenizer files, training-meta.json, and a 'latest' symlink.
        NOTE(review): multiple jobs on the same day overwrite one directory.

        Returns the checkpoint directory path as a string.
        """
        from safetensors.torch import save_file
        import shutil

        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        out_dir = checkpoint_dir / date_str
        out_dir.mkdir(parents=True, exist_ok=True)

        # Weights: materialize each (possibly non-contiguous GPU view) on CPU.
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))

        # Copy config files so the directory loads as a standalone HF model.
        config_dir = Path(self.config['model_path'])
        for fname in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                      'special_tokens_map.json']:
            src = config_dir / fname
            if src.exists():
                shutil.copy2(src, out_dir / fname)

        # Save training metadata alongside the weights.
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w') as f:
            json.dump(meta, f, indent=2)

        # Repoint 'latest' at the newest dated dir (symlink_to fails if the
        # link already exists, so unlink a stale one first).
        latest = checkpoint_dir / 'latest'
        if latest.is_symlink():
            latest.unlink()
        latest.symlink_to(date_str)

        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)

    async def run(self):
        """Start the HTTP server and serve until the process is killed."""
        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.config['host'], self.config['port'])
        await site.start()
        logger.info("Apollo Worker is running")

        # Park this coroutine forever: an Event that is never set blocks
        # without the hourly wakeups of a sleep loop.
        await asyncio.Event().wait()
|
||||
|
||||
def main():
    """Entry point: construct the worker and drive its event loop."""
    asyncio.run(ApolloWorker().run())


if __name__ == '__main__':
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue