#!/usr/bin/env python3
"""
Apollo Mini Training Daemon

This daemon:
1. Listens over HTTP for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference

Training jobs run one at a time; further requests stay in the ``pending``
state until the running job finishes.

Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
- GET /health: Report daemon health and the number of active jobs
"""

import asyncio
import json
import logging
import os
import sys
import time
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List, Set
from enum import Enum

import torch
import torch.nn as nn
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')


class TrainingStatus(Enum):
    """Lifecycle states of a training job; COMPLETED and FAILED are terminal."""
    PENDING = "pending"
    PAUSING_VLLM = "pausing_vllm"
    TRAINING = "training"
    SAVING_CHECKPOINT = "saving_checkpoint"
    RESUMING_VLLM = "resuming_vllm"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class TrainingJob:
    """Mutable record of one training request, reported via GET /status."""
    job_id: str
    status: TrainingStatus
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    checkpoint_path: Optional[str] = None  # directory written by save_checkpoint
    training_samples: int = 0  # number of samples submitted with the request
    loss_history: List[float] = field(default_factory=list)  # one loss per trained sample
    error: Optional[str] = None  # str() of the exception when status is FAILED

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (enum as value, datetimes as ISO strings)."""
        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }


class ApolloWorker:
    """aiohttp daemon that fine-tunes, in place, the weights vLLM is serving.

    Requests are recorded as ``TrainingJob`` objects and executed one at a
    time in background tasks.
    """

    # Statuses of a job that is currently executing (from the start of the
    # vLLM pause to the end of the vLLM resume).
    ACTIVE_STATUSES = frozenset({
        TrainingStatus.PAUSING_VLLM,
        TrainingStatus.TRAINING,
        TrainingStatus.SAVING_CHECKPOINT,
        TrainingStatus.RESUMING_VLLM,
    })

    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        self.config = self._load_config(config_path)
        self.jobs: Dict[str, TrainingJob] = {}
        self.vllm_paused = False
        # asyncio keeps only weak references to tasks, so fire-and-forget
        # training tasks are held here until they finish; otherwise they may
        # be garbage-collected mid-run.
        self._background_tasks: Set[asyncio.Task] = set()
        # Serializes jobs so two trainings never pause/resume vLLM or update
        # the shared weights concurrently. Created lazily inside the running
        # event loop (older Pythons bind a Lock to a loop at construction).
        self._training_lock: Optional[asyncio.Lock] = None
        self.app = web.Application()
        self._setup_routes()

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from file or use defaults.

        Keys in the JSON file at *config_path* (if it exists) override the
        defaults below. Optional keys read elsewhere with their own defaults:
        ``vllm_url`` and ``weight_handles``. Ensures ``checkpoint_dir`` exists.
        """
        # NOTE(review): 'vllm_socket', 'max_training_samples' and 'batch_size'
        # are not read anywhere in this module (vLLM is reached via 'vllm_url').
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,
            'learning_rate': 1e-5,
            'batch_size': 1,
        }

        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                user_config = json.load(f)
            default_config.update(user_config)

        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Register the HTTP routes on ``self.app``."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint: vLLM pause flag and count of executing jobs."""
        active_jobs = sum(
            1 for job in self.jobs.values() if job.status in self.ACTIVE_STATUSES
        )
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            'active_jobs': active_jobs,
        })

    @staticmethod
    def _extract_training_payload(data: Dict[str, Any]) -> Dict[str, Any]:
        """Pick the dict holding ``samples`` (and optionally ``config``).

        Prefers the top level of the request body (the layout
        ``execute_training`` has always read); falls back to a nested
        ``training_data`` object that contains ``samples``.
        """
        if 'samples' in data:
            return data
        nested = data.get('training_data')
        if isinstance(nested, dict) and 'samples' in nested:
            return nested
        return data

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Validate a training request from poc-agent and start it in the background.

        The body must be a JSON object with a ``training_data`` field or a
        top-level ``samples`` list. Samples are taken from the top level when
        present, otherwise from ``training_data`` when it is an object with a
        ``samples`` key. Each sample is a dict read for ``context`` and
        ``continuation``; an optional ``config`` dict may set ``learning_rate``.

        Returns 200 with the new ``job_id``, 400 for malformed requests, or
        500 on unexpected errors.
        """
        try:
            data = await request.json()
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            return web.json_response(
                {'error': f'Invalid JSON body: {e}'},
                status=400
            )

        try:
            if not isinstance(data, dict):
                return web.json_response(
                    {'error': 'Request body must be a JSON object'},
                    status=400
                )

            # Validate required fields
            if 'training_data' not in data and 'samples' not in data:
                return web.json_response(
                    {'error': 'Missing training_data field'},
                    status=400
                )

            payload = self._extract_training_payload(data)
            samples = payload.get('samples')
            # Reject up front: a bad sample list would otherwise only fail in
            # the background task, after vLLM has already been paused.
            if (not isinstance(samples, list) or not samples
                    or not all(isinstance(s, dict) for s in samples)):
                return web.json_response(
                    {'error': "'samples' must be a non-empty list of objects"},
                    status=400
                )

            # Random suffix keeps ids unique for requests within the same
            # second (a pid suffix is constant for the whole process).
            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job

            # Start training in background, keeping a strong reference.
            task = asyncio.create_task(self.execute_training(job, payload))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })

        except Exception as e:
            logger.exception(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )

    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Get training job status (404 if the job id is unknown)."""
        job_id = request.match_info['job_id']
        job = self.jobs.get(job_id)

        if job is None:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )

        return web.json_response(job.to_dict())

    @staticmethod
    def _describe_checkpoint(path: Path):
        """Return ``(mtime, info)`` for one checkpoint entry, or None if *path* isn't one.

        Recognizes the directories written by ``save_checkpoint`` (containing
        ``model.safetensors``; size is the sum of the files inside) and legacy
        ``checkpoint_*.pt`` files.
        """
        if path.is_symlink():
            return None  # e.g. the 'latest' alias; its target is listed itself

        weights = path / 'model.safetensors'
        if path.is_dir() and weights.is_file():
            mtime = weights.stat().st_mtime
            size = sum(p.stat().st_size for p in path.iterdir() if p.is_file())
        elif path.is_file() and path.match('checkpoint_*.pt'):
            st = path.stat()
            mtime, size = st.st_mtime, st.st_size
        else:
            return None

        return mtime, {
            'filename': path.name,
            'path': str(path),
            'created_at': datetime.fromtimestamp(mtime).isoformat(),
            'size': size,
        }

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List available checkpoints, newest first."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        found = []

        if checkpoint_dir.exists():
            for path in checkpoint_dir.iterdir():
                try:
                    entry = self._describe_checkpoint(path)
                except OSError:
                    continue  # removed or unreadable while we were listing
                if entry is not None:
                    found.append(entry)

        found.sort(key=lambda entry: entry[0], reverse=True)
        return web.json_response({'checkpoints': [info for _, info in found]})

    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Execute the training pipeline: pause vLLM, train, checkpoint, resume vLLM.

        Jobs are serialized on ``self._training_lock``; a job waiting for the
        lock stays PENDING. Failures are recorded on *job* (status FAILED,
        ``error`` set) rather than raised, and vLLM is resumed on every exit
        path if it was marked paused.

        Args:
            job: The job record to update as the pipeline progresses.
            training_data: Dict with a ``samples`` list and an optional
                ``config`` dict.
        """
        if self._training_lock is None:
            self._training_lock = asyncio.Lock()

        async with self._training_lock:
            try:
                logger.info(f"Starting training job {job.job_id}")
                job.started_at = datetime.now()

                # Step 1: Pause vLLM
                job.status = TrainingStatus.PAUSING_VLLM
                logger.info("Pausing vLLM...")
                await self.pause_vllm()
                self.vllm_paused = True

                # Step 2: Load model (HF views over vLLM's GPU weights)
                job.status = TrainingStatus.TRAINING
                logger.info("Loading model and preparing for training...")
                model = await self.load_model_for_training()

                # Step 3: Run APOLLO-Mini training
                samples = training_data['samples']
                job.training_samples = len(samples)
                logger.info(f"Starting APOLLO-Mini training with {len(samples)} samples")
                job.loss_history = await self.run_apollo_training(
                    model, samples, training_data.get('config') or {})

                # Step 4: Save checkpoint
                job.status = TrainingStatus.SAVING_CHECKPOINT
                logger.info("Saving checkpoint...")
                job.checkpoint_path = await self.save_checkpoint(model, job)

                # Step 5: Resume vLLM
                job.status = TrainingStatus.RESUMING_VLLM
                logger.info("Resuming vLLM...")
                await self.resume_vllm()
                self.vllm_paused = False

                # Mark job as completed
                job.status = TrainingStatus.COMPLETED
                job.completed_at = datetime.now()
                logger.info(f"Training job {job.job_id} completed successfully")

            except Exception as e:
                logger.exception(f"Training job {job.job_id} failed: {e}")
                job.status = TrainingStatus.FAILED
                job.error = str(e)
                job.completed_at = datetime.now()

            finally:
                # Runs on failure and on task cancellation too: never leave
                # vLLM paused because a job died.
                if self.vllm_paused:
                    try:
                        await self.resume_vllm()
                        self.vllm_paused = False
                    except Exception as resume_error:
                        logger.error(f"Failed to resume vLLM after training error: {resume_error}")

    async def _post_to_vllm(self, endpoint: str, payload: Optional[Dict[str, Any]] = None) -> None:
        """POST *payload* (JSON, or no body if None) to ``{vllm_url}/{endpoint}``.

        ``vllm_url`` comes from the config (default ``http://localhost:8000``).
        The request times out after 10 seconds; a non-2xx status raises.
        """
        import aiohttp as aio

        url = self.config.get('vllm_url', 'http://localhost:8000')
        async with aio.ClientSession() as session:
            async with session.post(
                f"{url}/{endpoint}",
                json=payload,
                timeout=aio.ClientTimeout(total=10),
            ) as resp:
                resp.raise_for_status()

    async def pause_vllm(self):
        """Pause vLLM inference via HTTP API (best-effort; failures are only logged).

        NOTE(review): because a failed pause is swallowed, execute_training
        continues and trains while vLLM may still be serving the same weights.
        """
        try:
            await self._post_to_vllm(
                'pause_generation', {"mode": "keep", "clear_cache": False})
            logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Resume vLLM inference via HTTP API (best-effort; failures are only logged)."""
        try:
            await self._post_to_vllm('resume_generation')
            logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")

    async def load_model_for_training(self) -> nn.Module:
        """Load HF model with weights pointing to vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
        views (narrowing merged weights into separate q/k/v/z etc.), and
        constructs the HF model around those views. No weight copying —
        all parameters share vLLM's GPU memory.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']

        # Import vLLM weights via CUDA IPC
        logger.info(f"Importing vLLM weights from {handle_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects (needed
        # for the rebuild callables). The default path is in world-writable
        # /tmp, so any local user who can write that file can run code here.
        handles = torch.load(handle_path, weights_only=False)

        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")

        # Map vLLM merged layout → HF separate layout (views, no copies)
        from weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")

        return model

    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, str]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points.

        For each sample, the context is run without gradients (its KV cache is
        reused) and next-token cross-entropy is minimized on the continuation
        tokens only, with one optimizer step per sample.

        Args:
            model: HF causal LM whose parameters are trained in place.
            samples: Dicts read for ``context`` and ``continuation`` strings.
            config: Per-request overrides; ``learning_rate`` is read.

        Returns:
            One loss value per trained sample (samples with no continuation
            token to predict are skipped). Empty if *samples* is empty.
        """
        if not samples:
            logger.warning("No training samples supplied; skipping training")
            return []

        from apollo_mini import ApolloMini
        from transformers import AutoTokenizer

        config = config or {}
        lr = config.get('learning_rate', self.config['learning_rate'])
        tokenizer = AutoTokenizer.from_pretrained(
            self.config['model_path'], trust_remote_code=True)

        # Build parameter groups (Apollo for 2D+, standard for small/1D)
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            if p.ndim >= 2 and min(p.shape) >= 2:
                apollo_params.append(p)
            else:
                standard_params.append(p)

        groups = [{'params': params}
                  for params in (apollo_params, standard_params) if params]

        optimizer = ApolloMini(groups, lr=lr)
        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

        # Put inputs on the model's device rather than assuming cuda:0.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda:0')

        loss_history = []
        total = len(samples)
        for step, sample in enumerate(samples, start=1):
            context = sample.get('context', '')
            continuation = sample.get('continuation', '')

            # Tokenize
            ctx_ids = tokenizer.encode(context, add_special_tokens=True)
            cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)

            # The logit predicting token t comes from position t-1, so the
            # gradient pass must start at the LAST context token; otherwise the
            # first continuation (decision) token is never scored.
            prefix_len = max(context_len - 1, 0)
            n_targets = len(all_ids) - prefix_len - 1
            if n_targets <= 0:
                logger.warning(f"Step {step}/{total}: skipping sample with no "
                               f"continuation tokens to train on")
                continue

            input_ids = torch.tensor([all_ids], device=device)

            optimizer.zero_grad()

            # Context-frozen forward pass (no gradients), cached as past_kv.
            past_kv = None
            if prefix_len > 0:
                with torch.no_grad():
                    outputs = model(input_ids[:, :prefix_len], use_cache=True)
                    past_kv = outputs.past_key_values

            # Last context token + decision tokens with gradients
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, prefix_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits  # [1, len(all_ids) - prefix_len, vocab]

                # Shift: position p predicts token p + 1
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, prefix_len + 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {step}/{total}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

        # Release gradient buffers so they don't stay resident during the
        # checkpoint save and after vLLM resumes.
        for p in model.parameters():
            p.grad = None

        if loss_history:
            logger.info(f"Training done: {len(loss_history)} examples, "
                        f"final loss={loss_history[-1]:.4f}")
        else:
            logger.warning("Training done: no sample had continuation tokens to train on")
        return loss_history

    @staticmethod
    def _update_latest_symlink(checkpoint_dir: Path, target_name: str) -> None:
        """Atomically repoint ``checkpoint_dir/latest`` at *target_name* (relative).

        A temporary symlink is renamed over ``latest`` so readers never see it
        missing.
        """
        latest = checkpoint_dir / 'latest'
        tmp = checkpoint_dir / '.latest.tmp'
        if tmp.is_symlink() or tmp.exists():
            tmp.unlink()
        tmp.symlink_to(target_name)
        os.replace(tmp, latest)

    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model checkpoint in HuggingFace safetensors format.

        Writes ``<checkpoint_dir>/<YYYY-MM-DD>_<job_id>/`` containing
        ``model.safetensors``, copies of the base model's config/tokenizer
        files (when present), and ``training-meta.json``; then points the
        ``latest`` symlink at it.

        Returns:
            The checkpoint directory path as a string.
        """
        # NOTE(review): the save runs synchronously and blocks the event loop,
        # so HTTP endpoints do not respond while it runs.
        from safetensors.torch import save_file
        import shutil

        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        # One directory per job: a date-only name let a second job on the same
        # day overwrite the first job's weights and metadata.
        dir_name = f"{date_str}_{job.job_id}"
        out_dir = checkpoint_dir / dir_name
        out_dir.mkdir(parents=True, exist_ok=True)

        # Save weights (copied to CPU; each tensor gets its own storage)
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))
        del tensors  # drop the host copy before continuing

        # Copy config files
        config_dir = Path(self.config['model_path'])
        for fname in ('config.json', 'tokenizer.json', 'tokenizer_config.json',
                      'special_tokens_map.json'):
            src = config_dir / fname
            if src.exists():
                shutil.copy2(src, out_dir / fname)

        # Save training metadata
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w', encoding='utf-8') as f:
            json.dump(meta, f, indent=2)

        # Update latest symlink
        self._update_latest_symlink(checkpoint_dir, dir_name)

        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)

    async def run(self):
        """Run the daemon: serve HTTP until cancelled, then shut the server down."""
        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")

        runner = web.AppRunner(self.app)
        await runner.setup()
        try:
            site = web.TCPSite(runner, self.config['host'], self.config['port'])
            await site.start()

            logger.info("Apollo Worker is running")

            # Keep running
            while True:
                await asyncio.sleep(3600)  # Sleep for an hour
        finally:
            await runner.cleanup()


def main():
    """Entry point: start the worker with the default config path."""
    worker = ApolloWorker()
    asyncio.run(worker.run())


if __name__ == '__main__':
    main()