# consciousness/training/apollo_worker.py
#!/usr/bin/env python3
"""
Apollo Mini Training Daemon
This daemon:
1. Listens over HTTP for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference
Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
"""
import asyncio
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum
import torch
import torch.nn as nn
from aiohttp import web
# Module-wide logging: timestamped records, INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('apollo_worker')
class TrainingStatus(Enum):
    """Lifecycle states of a training job, in rough execution order.

    The string values are what ``TrainingJob.to_dict`` serializes into
    ``/status/{job_id}`` responses, so they are part of the wire protocol.
    """
    PENDING = "pending"                      # accepted, not yet started
    PAUSING_VLLM = "pausing_vllm"            # asking vLLM to halt generation
    TRAINING = "training"                    # model loaded, optimizer stepping
    SAVING_CHECKPOINT = "saving_checkpoint"  # writing safetensors to disk
    RESUMING_VLLM = "resuming_vllm"          # asking vLLM to serve again
    COMPLETED = "completed"                  # terminal: success
    FAILED = "failed"                        # terminal: error stored on the job
@dataclass
class TrainingJob:
    """Mutable record of one training job, kept in the worker's in-memory registry."""
    job_id: str
    status: TrainingStatus
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    checkpoint_path: Optional[str] = None
    training_samples: int = 0
    loss_history: List[float] = field(default_factory=list)
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for JSON responses: enum -> value, datetimes -> ISO 8601 (None passes through)."""
        def _iso(ts: Optional[datetime]) -> Optional[str]:
            # Optional timestamps serialize to None until they are set.
            return ts.isoformat() if ts else None

        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': _iso(self.started_at),
            'completed_at': _iso(self.completed_at),
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }
class ApolloWorker:
    """aiohttp daemon that runs Apollo-Mini training against the weights a
    colocated vLLM server is serving.

    One job at a time is assumed: a job pauses vLLM generation, trains in-place
    on vLLM's GPU weight tensors (imported via CUDA IPC, no copies), saves a
    safetensors checkpoint, then resumes vLLM. Job state lives only in memory
    (``self.jobs``) and is lost on restart.
    """

    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        # Defaults merged with an optional user JSON file; see _load_config.
        self.config = self._load_config(config_path)
        # In-memory job registry keyed by job_id; never pruned.
        self.jobs: Dict[str, TrainingJob] = {}
        # Tracks whether WE paused vLLM, so the error path knows to resume it.
        self.vllm_paused = False
        self.app = web.Application()
        self._setup_routes()

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from file or use defaults.

        A missing config file is silently tolerated (defaults only). Keys in
        the user file override defaults via dict.update. Also ensures the
        checkpoint directory exists as a side effect.
        """
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,  # NOTE(review): never enforced anywhere in this file
            'learning_rate': 1e-5,
            'batch_size': 1,
        }
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                user_config = json.load(f)
            default_config.update(user_config)
        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Setup HTTP routes (see module docstring for the protocol)."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint: reports pause flag and count of in-flight jobs."""
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            # "Active" = any non-terminal, non-pending phase.
            'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
        })

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Handle a POST /train request from poc-agent.

        Body must contain 'training_data'. Returns 202-style acceptance
        immediately (HTTP 200); the job itself runs as a background task and
        is tracked via /status/{job_id}.
        """
        try:
            data = await request.json()
            # Validate required fields
            if 'training_data' not in data:
                return web.json_response(
                    {'error': 'Missing training_data field'},
                    status=400
                )
            # NOTE(review): second-resolution timestamp + pid — two requests in
            # the same second to this (single) process would collide.
            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job
            # Start training in background; the task is not awaited or stored,
            # errors are recorded on the job object by execute_training itself.
            asyncio.create_task(self.execute_training(job, data))
            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })
        except Exception as e:
            logger.error(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )

    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Get training job status (404 for unknown job ids)."""
        job_id = request.match_info['job_id']
        if job_id not in self.jobs:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )
        job = self.jobs[job_id]
        return web.json_response(job.to_dict())

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List checkpoint_*.pt files in checkpoint_dir, newest first.

        NOTE(review): save_checkpoint below writes dated directories with
        model.safetensors, not checkpoint_*.pt files, so this listing will not
        see them — confirm which producer this glob was written for.
        """
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        checkpoints = []
        if checkpoint_dir.exists():
            for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
                checkpoints.append({
                    'filename': checkpoint_file.name,
                    'path': str(checkpoint_file),
                    'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
                    'size': checkpoint_file.stat().st_size
                })
        return web.json_response({'checkpoints': checkpoints})

    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Execute the full pipeline: pause vLLM -> train -> checkpoint -> resume.

        Runs as a fire-and-forget task. All failures are caught, recorded on
        the job (status=FAILED, error message), and a best-effort vLLM resume
        is attempted.
        """
        try:
            logger.info(f"Starting training job {job.job_id}")
            job.started_at = datetime.now()
            # Step 1: Pause vLLM
            job.status = TrainingStatus.PAUSING_VLLM
            logger.info("Pausing vLLM...")
            # NOTE(review): pause_vllm swallows its own errors, so training
            # proceeds (and vllm_paused is set) even if the pause failed.
            await self.pause_vllm()
            self.vllm_paused = True
            # Step 2: Load model and prepare for training
            job.status = TrainingStatus.TRAINING
            logger.info("Loading model and preparing for training...")
            # Builds an HF model whose parameters are views of vLLM's GPU memory.
            model = await self.load_model_for_training()
            # Step 3: Run APOLLO-Mini training
            logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
            # Extract training samples. 'samples' is required here even though
            # only 'training_data' was validated at the HTTP layer; a missing
            # key surfaces as a FAILED job via the except below.
            samples = training_data['samples']
            job.training_samples = len(samples)
            # Run training loop
            loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
            job.loss_history = loss_history
            # Step 4: Save checkpoint
            job.status = TrainingStatus.SAVING_CHECKPOINT
            logger.info("Saving checkpoint...")
            checkpoint_path = await self.save_checkpoint(model, job)
            job.checkpoint_path = checkpoint_path
            # Step 5: Resume vLLM
            job.status = TrainingStatus.RESUMING_VLLM
            logger.info("Resuming vLLM...")
            await self.resume_vllm()
            self.vllm_paused = False
            # Mark job as completed
            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.now()
            logger.info(f"Training job {job.job_id} completed successfully")
        except Exception as e:
            logger.error(f"Training job {job.job_id} failed: {e}")
            job.status = TrainingStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.now()
            # Try to resume vLLM if it was paused
            if self.vllm_paused:
                try:
                    await self.resume_vllm()
                    self.vllm_paused = False
                except Exception as resume_error:
                    logger.error(f"Failed to resume vLLM after training error: {resume_error}")

    async def pause_vllm(self):
        """Pause vLLM inference via HTTP API.

        Best-effort: failures are logged as warnings, never raised, so callers
        cannot distinguish a paused server from an unreachable one.
        """
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/pause_generation",
                    # "keep"/clear_cache semantics belong to the vLLM control
                    # endpoint — presumably "keep requests queued, keep KV
                    # cache"; confirm against the serving-side handler.
                    json={"mode": "keep", "clear_cache": False},
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Resume vLLM inference via HTTP API (best-effort, mirrors pause_vllm)."""
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/resume_generation",
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")

    async def load_model_for_training(self) -> nn.Module:
        """Load an HF model whose weights point at vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC handles, then asks the
        project-local ``weight_mapping`` module to build an HF-compatible
        model around views of those tensors (narrowing merged weights into
        separate q/k/v etc.). No weight copying — all parameters share
        vLLM's GPU memory, so optimizer steps mutate the serving weights.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']
        # Import vLLM weights via CUDA IPC
        logger.info(f"Importing vLLM weights from {handle_path}")
        # NOTE(review): weights_only=False unpickles arbitrary callables from
        # this file; acceptable only because the handle file is produced by a
        # trusted local exporter — confirm that assumption holds.
        handles = torch.load(handle_path, weights_only=False)
        vllm_params = {}
        for name, info in handles.items():
            # Each entry's 'handle' is assumed to be a (rebuild_fn, args) pair
            # as emitted by torch's CUDA IPC reducer — TODO confirm with the
            # exporter that writes handle_path.
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")
        # Map vLLM merged layout -> HF separate layout (views, no copies)
        from weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")
        return model

    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, str]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points.

        Each sample is {'context': ..., 'continuation': ...}; the context is
        run under no_grad (frozen KV cache) and loss is taken only over the
        continuation tokens. Returns the per-step loss history.

        NOTE(review): runs synchronously on the event loop despite being
        async — HTTP handlers are blocked for the duration of training.
        NOTE(review): a single-token continuation yields empty shift tensors
        (cross_entropy over zero elements); an empty samples list makes the
        final loss_history[-1] log line raise — confirm callers exclude both.
        """
        from apollo_mini import ApolloMini
        from transformers import AutoTokenizer
        # Per-request learning rate overrides the daemon-wide default.
        lr = config.get('learning_rate', self.config['learning_rate'])
        tokenizer = AutoTokenizer.from_pretrained(
            self.config['model_path'], trust_remote_code=True)
        # Build parameter groups (Apollo for 2D+, standard for small/1D)
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if p.requires_grad:
                if p.ndim >= 2 and min(p.shape) >= 2:
                    apollo_params.append(p)
                else:
                    standard_params.append(p)
        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})
        optimizer = ApolloMini(groups, lr=lr)
        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
        loss_history = []
        for i, sample in enumerate(samples):
            context = sample.get('context', '')
            continuation = sample.get('continuation', '')
            # Tokenize
            ctx_ids = tokenizer.encode(context, add_special_tokens=True)
            cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)
            # batch_size is effectively 1 regardless of config['batch_size'].
            input_ids = torch.tensor([all_ids], device='cuda:0')
            optimizer.zero_grad()
            # Context-frozen forward pass
            with torch.no_grad():
                # Forward through context (no gradients): the KV cache built
                # here carries no graph, so context keys/values stay frozen.
                outputs = model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values
            # Decision tokens with gradients
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
            logits = outputs.logits  # [1, cont_len, vocab]
            # Shift: position j predicts token j+1, so drop the last logit and
            # start labels one past the context boundary.
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = input_ids[:, context_len + 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            loss.backward()
            optimizer.step()
            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
        logger.info(f"Training done: {len(samples)} examples, "
                    f"final loss={loss_history[-1]:.4f}")
        return loss_history

    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model checkpoint in HuggingFace safetensors format.

        Writes checkpoint_dir/YYYY-MM-DD/{model.safetensors, tokenizer files,
        training-meta.json} and repoints the 'latest' symlink. Returns the
        output directory path. Same-day jobs overwrite each other.
        """
        from safetensors.torch import save_file
        import shutil
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        out_dir = checkpoint_dir / date_str
        out_dir.mkdir(parents=True, exist_ok=True)
        # Save weights: pull the shared-GPU-memory views back to CPU,
        # contiguous, as safetensors requires.
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))
        # Copy config files so the directory loads as a standalone HF model.
        config_dir = Path(self.config['model_path'])
        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                  'special_tokens_map.json']:
            src = config_dir / f
            if src.exists():
                shutil.copy2(src, out_dir / f)
        # Save training metadata
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w') as f:
            json.dump(meta, f, indent=2)
        # Update latest symlink.
        # NOTE(review): only a symlink is unlinked — a regular file or dir
        # named 'latest' would make symlink_to raise FileExistsError.
        latest = checkpoint_dir / 'latest'
        if latest.is_symlink():
            latest.unlink()
        latest.symlink_to(date_str)
        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)

    async def run(self):
        """Run the daemon: start the aiohttp site and idle forever."""
        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.config['host'], self.config['port'])
        await site.start()
        logger.info("Apollo Worker is running")
        # Keep running; requests are served by the aiohttp runner, this
        # coroutine only has to stay alive.
        while True:
            await asyncio.sleep(3600)  # Sleep for an hour
def main():
    """Entry point: build a worker from the default config path and serve forever."""
    asyncio.run(ApolloWorker().run())


if __name__ == '__main__':
    main()