#!/usr/bin/env python3
"""
Apollo Mini Training Daemon

This daemon:
1. Listens over HTTP for training requests from poc-agent
   (the listener is a plain TCP site; no TLS context is configured here)
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference

Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
- GET /health: Health check
"""

import asyncio
import json
import logging
import os
import sys
import time
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum

import torch
import torch.nn as nn
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')


class TrainingStatus(Enum):
    """Lifecycle states of a training job, in the order they are entered."""

    PENDING = "pending"
    PAUSING_VLLM = "pausing_vllm"
    TRAINING = "training"
    SAVING_CHECKPOINT = "saving_checkpoint"
    RESUMING_VLLM = "resuming_vllm"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class TrainingJob:
    """Bookkeeping record for one training request."""

    job_id: str
    status: TrainingStatus
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    checkpoint_path: Optional[str] = None
    training_samples: int = 0
    loss_history: List[float] = field(default_factory=list)
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable view (enum -> value, datetimes -> ISO strings)."""
        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }


CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes


class ApolloWorker:
    """HTTP daemon that accepts training jobs and runs them one at a time."""

    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        self.config = self._load_config(config_path)
        self.jobs: Dict[str, TrainingJob] = {}
        self.vllm_paused = False
        self.app = web.Application()
        self._setup_routes()
        self._checkpoint_timer: Optional[asyncio.Task] = None
        # Strong references to in-flight job tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()
        # Created lazily inside the running loop (see execute_training) so the
        # lock is never bound to a different loop than the one asyncio.run() uses.
        self._training_lock: Optional[asyncio.Lock] = None

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from file or use defaults.

        Keys present in the JSON file at ``config_path`` override the defaults.
        The checkpoint directory is created as a side effect.
        """
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            # Defaults for keys the worker reads via config.get(); listed here
            # so the effective configuration is visible in one place.
            'vllm_url': 'http://localhost:8000',
            'weight_handles': '/tmp/vllm_weight_handles.pt',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,
            'learning_rate': 1e-5,
            'batch_size': 1,
        }

        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                user_config = json.load(f)
            default_config.update(user_config)

        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Setup HTTP routes."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint.

        ``active_jobs`` counts jobs in any in-progress pipeline stage
        (including checkpoint saving); ``pending_jobs`` counts accepted jobs
        that have not started yet.
        """
        active_states = (
            TrainingStatus.PAUSING_VLLM,
            TrainingStatus.TRAINING,
            TrainingStatus.SAVING_CHECKPOINT,
            TrainingStatus.RESUMING_VLLM,
        )
        statuses = [job.status for job in self.jobs.values()]
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            'active_jobs': sum(1 for s in statuses if s in active_states),
            'pending_jobs': sum(1 for s in statuses if s is TrainingStatus.PENDING),
        })

    @staticmethod
    def _extract_training_payload(data: Any) -> Dict[str, Any]:
        """Validate a /train body and return the dict that carries 'samples'.

        Two layouts are accepted:
        - 'samples' (and optional 'config') at the top level of the body,
          which is what the training pipeline has always read; or
        - 'samples' nested inside the 'training_data' object.

        Raises:
            ValueError: if the body is malformed; the message is safe to
                return to the client.
        """
        if not isinstance(data, dict):
            raise ValueError('Request body must be a JSON object')
        if 'training_data' not in data:
            raise ValueError('Missing training_data field')

        payload = data if 'samples' in data else data['training_data']
        if not isinstance(payload, dict) or 'samples' not in payload:
            raise ValueError('Missing samples field')

        samples = payload['samples']
        if not isinstance(samples, list) or not samples:
            raise ValueError('samples must be a non-empty list')
        for index, sample in enumerate(samples):
            if (not isinstance(sample, dict)
                    or 'context_ids' not in sample
                    or 'continuation_ids' not in sample):
                raise ValueError(
                    f'Sample {index} must be an object with context_ids and continuation_ids'
                )
        return payload

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Handle training request from poc-agent.

        Responds 400 for malformed bodies, 500 for unexpected errors, and
        otherwise registers a job and starts it in the background.
        """
        try:
            try:
                data = await request.json()
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return web.json_response(
                    {'error': f'Invalid JSON body: {e}'},
                    status=400
                )

            # Validate before touching vLLM so a bad request cannot pause
            # inference only to fail a moment later.
            try:
                payload = self._extract_training_payload(data)
            except ValueError as e:
                return web.json_response({'error': str(e)}, status=400)

            # The timestamp has one-second resolution; the random suffix keeps
            # two requests in the same second from sharing (and overwriting) an ID.
            job_id = (
                f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
                f"_{uuid.uuid4().hex[:8]}"
            )

            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job

            # Start training in background, keeping a reference until done.
            task = asyncio.create_task(self.execute_training(job, payload))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })

        except Exception as e:
            logger.error(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )

    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Get training job status (404 if the job ID is unknown)."""
        job_id = request.match_info['job_id']
        job = self.jobs.get(job_id)
        if job is None:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )
        return web.json_response(job.to_dict())

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List available checkpoints, newest first.

        save_checkpoint() writes ``<checkpoint_dir>/<date>/model.safetensors``,
        so each directory holding that file is reported as one checkpoint.
        Flat ``checkpoint_*.pt`` files are still listed as well. Symlinks
        (such as ``latest``) are skipped so a checkpoint is not reported twice.
        """
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        found = []

        if checkpoint_dir.exists():
            candidates = list(checkpoint_dir.glob('checkpoint_*.pt'))
            candidates.extend(
                entry for entry in checkpoint_dir.iterdir()
                if entry.is_dir() and not entry.is_symlink()
                and (entry / 'model.safetensors').is_file()
            )
            for candidate in candidates:
                weights = candidate / 'model.safetensors' if candidate.is_dir() else candidate
                try:
                    info = weights.stat()
                except OSError:
                    # Removed (or unreadable) between listing and stat; skip it.
                    continue
                found.append((info.st_mtime, {
                    'filename': candidate.name,
                    'path': str(candidate),
                    'created_at': datetime.fromtimestamp(info.st_mtime).isoformat(),
                    'size': info.st_size,
                }))
            found.sort(key=lambda item: item[0], reverse=True)

        return web.json_response({'checkpoints': [entry for _, entry in found]})

    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Execute the training pipeline.

        Jobs are serialized: a job stays PENDING until the previous one has
        finished, so one job's "resume vLLM" can never run while another job
        is training. Failures are recorded on the job; nothing is raised.
        """
        if self._training_lock is None:
            self._training_lock = asyncio.Lock()

        async with self._training_lock:
            try:
                logger.info(f"Starting training job {job.job_id}")
                job.started_at = datetime.now()

                # Step 1: Pause vLLM
                job.status = TrainingStatus.PAUSING_VLLM
                logger.info("Pausing vLLM...")
                await self.pause_vllm()
                self.vllm_paused = True

                # Step 2: Load model and prepare for training
                job.status = TrainingStatus.TRAINING
                logger.info("Loading model and preparing for training...")
                model = await self.load_model_for_training()

                # Step 3: Run APOLLO-Mini training
                samples = training_data['samples']
                logger.info(f"Starting APOLLO-Mini training with {len(samples)} samples")
                job.training_samples = len(samples)

                loss_history = await self.run_apollo_training(
                    model, samples, training_data.get('config', {})
                )
                job.loss_history = loss_history

                # Step 4: Save checkpoint
                job.status = TrainingStatus.SAVING_CHECKPOINT
                logger.info("Saving checkpoint...")
                checkpoint_path = await self.save_checkpoint(model, job)
                job.checkpoint_path = checkpoint_path

                # Step 5: Resume vLLM
                job.status = TrainingStatus.RESUMING_VLLM
                logger.info("Resuming vLLM...")
                await self.resume_vllm()
                self.vllm_paused = False

                # Mark job as completed
                job.status = TrainingStatus.COMPLETED
                job.completed_at = datetime.now()
                logger.info(f"Training job {job.job_id} completed successfully")

                # Schedule checkpoint sync (batched — won't duplicate if timer pending)
                self.schedule_checkpoint_sync()

            except Exception as e:
                # logger.exception keeps the traceback, which str(e) alone loses.
                logger.exception(f"Training job {job.job_id} failed: {e}")
                job.status = TrainingStatus.FAILED
                job.error = str(e)
                job.completed_at = datetime.now()

                # Try to resume vLLM if it was paused
                if self.vllm_paused:
                    try:
                        await self.resume_vllm()
                        self.vllm_paused = False
                    except Exception as resume_error:
                        logger.error(f"Failed to resume vLLM after training error: {resume_error}")

    async def pause_vllm(self):
        """Pause vLLM inference via HTTP API.

        Best-effort: a failure is logged as a warning and not raised.
        """
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/pause_generation",
                    json={"mode": "keep", "clear_cache": False},
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Resume vLLM inference via HTTP API.

        Best-effort: a failure is logged as a warning and not raised.
        """
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/resume_generation",
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")

    def schedule_checkpoint_sync(self):
        """Schedule a checkpoint sync in 10 minutes, if not already scheduled.

        This batches multiple training runs into a single sync — the timer
        resets only when no timer is pending.
        """
        if self._checkpoint_timer is not None:
            logger.debug("Checkpoint sync already scheduled, skipping")
            return

        self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
        logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")

    async def _checkpoint_sync_after_delay(self):
        """Wait then sync — the actual timer task.

        Always clears ``_checkpoint_timer`` on exit so a later training run
        can schedule a fresh sync.
        """
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            await self._do_checkpoint_sync()
        except asyncio.CancelledError:
            logger.debug("Checkpoint sync cancelled")
        finally:
            self._checkpoint_timer = None

    async def _do_checkpoint_sync(self):
        """Execute the checkpoint sync; errors are logged, never raised."""
        try:
            from apollo_plugin.checkpoint_sync import checkpoint_sync
            logger.info("Starting checkpoint sync...")
            # NOTE: checkpoint_sync is called synchronously, so the event loop
            # (and the HTTP endpoints) are blocked until it returns.
            result = checkpoint_sync(
                self.config['model_path'],
                self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
            )
            changed_mb = result['total_changed'] / 1e6
            logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
        except Exception as e:
            logger.error(f"Checkpoint sync failed: {e}")

    async def load_model_for_training(self) -> nn.Module:
        """Load HF model with weights pointing to vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
        views (narrowing merged weights into separate q/k/v/z etc.), and
        constructs the HF model around those views. No weight copying —
        all parameters share vLLM's GPU memory.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']

        # Import vLLM weights via CUDA IPC
        logger.info(f"Importing vLLM weights from {handle_path}")
        # SECURITY: weights_only=False unpickles arbitrary objects, and each
        # entry's 'handle' is a (callable, args) pair that is invoked below.
        # The handle file must only ever be writable by a trusted process.
        handles = torch.load(handle_path, weights_only=False)

        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")

        # Map vLLM merged layout → HF separate layout (views, no copies)
        from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")

        return model

    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, Any]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points.

        Each sample has:
            context_ids: token IDs for frozen context (no gradients)
            continuation_ids: token IDs for the decision we're training on

        Samples whose continuation has fewer than two tokens are skipped:
        the shifted next-token target would be empty, the mean cross-entropy
        over it is NaN, and a NaN optimizer step would corrupt the weights.
        A non-finite loss is skipped for the same reason.

        Returns:
            The loss of every step that was actually applied (possibly empty).

        NOTE: the loop runs synchronously inside this coroutine, so the event
        loop is blocked (no HTTP responses) for the duration of training.
        """
        from apollo_plugin.optimizer import Apollo

        # Build parameter groups (Apollo for 2D+, standard for small/1D)
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            if p.ndim >= 2 and min(p.shape) >= 2:
                apollo_params.append(p)
            else:
                standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})

        # Apollo settings from request config, falling back to server defaults
        rank = config.get('rank', 256)
        lr = config.get('lr', self.config.get('learning_rate', 1e-5))
        optimizer = Apollo(
            groups,
            lr=lr,
            rank=rank,
            betas=tuple(config.get('betas', (0.9, 0.999))),
            eps=config.get('eps', 1e-8),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 0),
            scale=config.get('scale'),  # None = auto
            proj_refresh=config.get('proj_refresh', 200),
            norm_growth_limit=config.get('norm_growth_limit', 1.01),
        )
        logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

        loss_history: List[float] = []
        total = len(samples)

        for step, sample in enumerate(samples, start=1):
            # context_ids: frozen (forward only, no gradients)
            # continuation_ids: the decision we're training on
            ctx_ids = list(sample['context_ids'])
            cont_ids = list(sample['continuation_ids'])
            context_len = len(ctx_ids)

            if len(cont_ids) < 2:
                logger.warning(f"Step {step}/{total}: skipped, continuation has "
                               f"{len(cont_ids)} token(s); at least 2 are needed")
                continue

            input_ids = torch.tensor([ctx_ids + cont_ids], device='cuda:0')
            optimizer.zero_grad()

            # Context-frozen forward pass; skipped entirely when there is no
            # context so the model is never called on a zero-length sequence.
            past_kv = None
            if context_len > 0:
                with torch.no_grad():
                    outputs = model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Decision tokens with gradients
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits  # [1, cont_len, vocab]

                # Shift: predict next token from each position
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, context_len + 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

            if not torch.isfinite(loss):
                logger.warning(f"Step {step}/{total}: skipped, non-finite loss "
                               f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
                continue

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {step}/{total}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

        if loss_history:
            logger.info(f"Training done: {len(loss_history)} examples, "
                        f"final loss={loss_history[-1]:.4f}")
        else:
            logger.warning("Training done: no samples were trained on")
        return loss_history

    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model checkpoint in HuggingFace safetensors format.

        Writes ``<checkpoint_dir>/<YYYY-MM-DD>/model.safetensors`` plus copies
        of the tokenizer/config files and a ``training-meta.json``, then
        points the ``latest`` symlink at that directory. A second checkpoint
        on the same day overwrites the first.

        Returns:
            The checkpoint directory path as a string.
        """
        from safetensors.torch import save_file
        import shutil

        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        out_dir = checkpoint_dir / date_str
        out_dir.mkdir(parents=True, exist_ok=True)

        # Save weights
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))

        # Copy config files
        config_dir = Path(self.config['model_path'])
        for filename in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                         'special_tokens_map.json']:
            src = config_dir / filename
            if src.exists():
                shutil.copy2(src, out_dir / filename)

        # Save training metadata
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w', encoding='utf-8') as f:
            json.dump(meta, f, indent=2)

        # Update latest symlink atomically: build the new link under a
        # temporary name, then rename it over 'latest'. Readers never see a
        # missing link, and a stale regular file named 'latest' is replaced
        # instead of making symlink_to() fail.
        latest = checkpoint_dir / 'latest'
        tmp_link = checkpoint_dir / f'.latest.{os.getpid()}.tmp'
        if tmp_link.is_symlink() or tmp_link.exists():
            tmp_link.unlink()
        tmp_link.symlink_to(date_str)
        os.replace(tmp_link, latest)

        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)

    async def run(self):
        """Run the daemon until the coroutine is cancelled."""
        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
        runner = web.AppRunner(self.app)
        await runner.setup()
        try:
            site = web.TCPSite(runner, self.config['host'], self.config['port'])
            await site.start()
            logger.info("Apollo Worker is running")

            # Keep running
            while True:
                await asyncio.sleep(3600)  # Sleep for an hour
        finally:
            # Close the listening socket and open connections on shutdown.
            await runner.cleanup()


def main():
    """Entry point: build the worker and serve forever."""
    worker = ApolloWorker()
    asyncio.run(worker.run())


if __name__ == '__main__':
    main()