diff --git a/training/DESIGN.md b/training/DESIGN.md
index f966fa4..bf6a774 100644
--- a/training/DESIGN.md
+++ b/training/DESIGN.md
@@ -22,25 +22,29 @@ The training signal comes from two sources:
 │                                                     │
 │  ┌──────────────────────────────────────────────┐   │
 │  │  Model Weights (54GB, bf16)                  │   │
-│  │  Shared via CUDA IPC                         │   │
+│  │  Shared: vLLM inference + HF training        │   │
 │  └──────────────┬──────────────┬────────────────┘   │
 │                 │              │                    │
 │  ┌──────────────▼──┐  ┌───────▼────────────────┐    │
-│  │ vLLM (inference)│  │ Apollo (training)      │    │
+│  │ vLLM (inference)│  │ HF model (training)    │    │
 │  │ KV cache ~60GB  │  │ Gradients ~54GB        │    │
-│  │ Serves requests │  │ Optimizer state ~10GB  │    │
-│  │ Never paused    │  │ Activations ~10GB      │    │
+│  │ /completions    │  │ Optimizer state ~10GB  │    │
+│  │ /score          │  │ Views into vLLM weights│    │
+│  │ /train ─────────┼──┼─► Apollo optimizer     │    │
 │  └─────────────────┘  └────────────────────────┘    │
 └─────────────────────────────────────────────────────┘
-Moria B200
+Single vLLM process serves everything
+No separate daemon - /train is a vLLM route
+
+Moria B200 (vLLM)
 
 ┌──────────────────┐            ┌──────────────────┐
-│ Training signal  │   HTTP     │ Apollo worker    │
-│ agent            │──────────>│ daemon            │
-│                  │            │                  │
-│ Dream loop       │            │ Checkpoint sync  │
-│ (generates       │            │ (mmap + diff,    │
-│  scenarios)      │            │  every 10 min)   │
+│ Training signal  │   HTTP     │ /completions     │
+│ agent            │──────────>│ /score            │
+│                  │            │ /train           │
+│ Dream loop       │            │                  │
+│ (generates       │            │ Checkpoint sync  │
+│  scenarios)      │            │ (10 min batched) │
 └──────────────────┘            └──────────────────┘
 ```
@@ -220,34 +224,30 @@ a few hundred MB.
 ## Components
 
 ### Built ✓
-- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256)
-- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking)
+- `optimizer.py` — Apollo optimizer (configurable rank, default 256)
+- `train_router.py` — /train endpoint, runs in vLLM process
 - `weight_mapping.py` — vLLM merged → HF separate views (validated)
-- `training_example.py` — tokenization with chat template
-- `vllm_export_hook.py` — source patch for IPC handle export
-- `checkpoint/` — Rust tool for mmap + diff checkpoint sync
+- `export_hook.py` — vLLM plugin hook for IPC handle export
+- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
 
 ### To build
-- **Dream loop → training bridge**: connect dream output to Apollo
+- **Dream loop → training bridge**: connect dream output to /train
 - **Training-signal agent**: flags moments in conversation logs
 - **Instruction stripping**: remove scaffolding from training examples
 - **Quality monitoring**: track model capability over time
-- **HF model forward pass integration**: wire into apollo_worker
 
 ## Files
 
 ```
 training/
-  DESIGN.md            — this document
-  apollo_mini.py       — Apollo optimizer
-  apollo_worker.py     — HTTP training daemon
-  weight_mapping.py    — vLLM ↔ HF weight views
-  training_example.py  — tokenization helpers
-  export_weights.py    — standalone weight export (unused)
-  vllm_export_hook.py  — vLLM source patch for IPC export
-  start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch)
-  train.py             — standalone training script (alternative)
-  checkpoint/
-    Cargo.toml         — Rust checkpoint tool
-    src/main.rs        — mmap + diff sync
+  DESIGN.md            — this document
+  pyproject.toml       — package config, vLLM plugin entry point
+  apollo_plugin/
+    __init__.py        — plugin registration
+    export_hook.py     — patches vLLM to export IPC handles
+    train_router.py    — /train endpoint (FastAPI router)
+    optimizer.py       — Apollo optimizer
+    weight_mapping.py  — vLLM ↔ HF weight views
+    checkpoint_sync.py — mmap + diff sync to safetensors
+    steering.py        — steering vector extraction (experimental)
 ```
diff --git a/training/apollo_plugin/__init__.py b/training/apollo_plugin/__init__.py
index bfbecd0..b2e121e 100644
--- a/training/apollo_plugin/__init__.py
+++ b/training/apollo_plugin/__init__.py
@@ -1,8 +1,8 @@
 """Apollo training plugin for vLLM.
 
 Enables continuous fine-tuning alongside live inference by:
-1. Exporting CUDA IPC handles for weight sharing
-2. Providing a training worker daemon (/train endpoint)
+1. Exporting CUDA IPC handles for weight sharing (export_hook)
+2. Adding /train endpoint to vLLM's HTTP server (train_router)
 3. Block-level checkpoint sync to safetensors files
 
 Install: pip install -e /path/to/training
@@ -10,8 +10,10 @@ Then vLLM auto-loads via entry point.
 """
 
 from .export_hook import _patch_model_runner
+from .train_router import _patch_api_server
 
 
 def register():
     """Called by vLLM's plugin loader on startup."""
     _patch_model_runner()
+    _patch_api_server()
diff --git a/training/apollo_plugin/export_hook.py b/training/apollo_plugin/export_hook.py
index 4853930..821163b 100644
--- a/training/apollo_plugin/export_hook.py
+++ b/training/apollo_plugin/export_hook.py
@@ -59,6 +59,10 @@ def _patch_model_runner():
         result = original_load(self, *args, **kwargs)
         try:
             export_model_weights(self.model_runner.model)
+            # Set model path for training router
+            model_path = self.vllm_config.model_config.model
+            from .train_router import set_model_path
+            set_model_path(model_path)
         except Exception as e:
             print(f"[apollo] Failed to export weights: {e}")
         return result
diff --git a/training/apollo_plugin/train_router.py b/training/apollo_plugin/train_router.py
new file mode 100644
index 0000000..6fa4883
--- /dev/null
+++ b/training/apollo_plugin/train_router.py
@@ -0,0 +1,282 @@
+"""Training endpoint for vLLM - runs Apollo training in-process.
+
+Patches vLLM's build_app() to add a /train route. Training runs Hogwild!
+style: no pause needed, weights are updated in-place while inference
+continues.
+"""
+
+import logging
+from datetime import datetime
+from typing import Any
+
+import torch
+import torch.nn as nn
+from fastapi import APIRouter, FastAPI, Request
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+class TrainingSample(BaseModel):
+    context_ids: list[int]
+    continuation_ids: list[int]
+
+
+class TrainRequest(BaseModel):
+    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}
+
+
+class TrainResponse(BaseModel):
+    job_id: str
+    status: str
+    training_samples: int
+    loss_history: list[float]
+
+
+# Global reference to HF model with vLLM weight views
+_model: nn.Module | None = None
+_model_path: str | None = None
+_initialized: bool = False
+
+
+def _load_training_model() -> nn.Module:
+    """Load HF model with weights pointing to vLLM's GPU memory.
+
+    Uses CUDA IPC handles exported by export_hook to create an HF model
+    whose parameters share GPU memory with vLLM's model.
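+
+    No weights are copied: each parameter of the returned model is a
+    view into vLLM's live tensors, so optimizer steps are immediately
+    visible to inference.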
+ """ + from .weight_mapping import load_hf_model_with_vllm_weights + from .export_hook import HANDLE_PATH + + handles = torch.load(HANDLE_PATH, weights_only=False) + vllm_params = {} + for name, info in handles.items(): + func, args = info['handle'] + vllm_params[name] = func(*args) + + model = load_hf_model_with_vllm_weights(vllm_params, _model_path) + model.train() + return model + + +def _ensure_initialized(): + """Lazy-initialize the training model on first /train request.""" + global _model, _initialized + + if _initialized: + return + + if _model_path is None: + raise RuntimeError("Model path not set - export_hook may not have run") + + logger.info("[apollo] Loading HF model with vLLM weight views...") + _model = _load_training_model() + _initialized = True + logger.info("[apollo] Training model ready") + + +def set_model_path(path: str): + """Set model path for training. Called by export_hook after model load.""" + global _model_path + _model_path = path + logger.info(f"[apollo] Model path set: {path}") + + +@router.post("/train") +async def handle_train(request: TrainRequest, raw_request: Request): + """Handle training request - runs Apollo training on provided samples.""" + global _model + + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={"error": f"Training not available: {e}"}, + status_code=503, + ) + + try: + training_data = request.training_data + samples = training_data.get("samples", []) + config = training_data.get("config", {}) + + if not samples: + return JSONResponse( + content={"error": "No training samples provided"}, + status_code=400, + ) + + job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples") + + # Run training + loss_history = await run_training(_model, samples, config) + + logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}") + + # Schedule checkpoint sync (batched, 10 min delay) + schedule_checkpoint_sync() + + return JSONResponse(content={ + "job_id": job_id, + "status": "completed", + "training_samples": len(samples), + "loss_history": loss_history, + }) + + except Exception as e: + logger.exception(f"[apollo] Training failed: {e}") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + +async def run_training( + model: nn.Module, + samples: list[dict[str, Any]], + config: dict[str, Any], +) -> list[float]: + """Run Apollo training on the given samples. 
+
+    Each sample has:
+        context_ids: token IDs for frozen context (no gradients)
+        continuation_ids: token IDs for the decision we're training on
+    """
+    from .optimizer import Apollo
+
+    # Build parameter groups (Apollo for 2D+, standard for small/1D)
+    apollo_params, standard_params = [], []
+    for p in model.parameters():
+        if p.requires_grad:
+            if p.ndim >= 2 and min(p.shape) >= 256:
+                apollo_params.append(p)
+            else:
+                standard_params.append(p)
+
+    groups = []
+    if apollo_params:
+        groups.append({'params': apollo_params})
+    if standard_params:
+        groups.append({'params': standard_params})
+
+    if not groups:
+        raise ValueError("No trainable parameters found")
+
+    # Apollo settings from request config
+    optimizer = Apollo(
+        groups,
+        lr=config.get('lr', 1e-5),
+        rank=config.get('rank', 256),
+        betas=tuple(config.get('betas', (0.9, 0.999))),
+        eps=config.get('eps', 1e-8),
+        weight_decay=config.get('weight_decay', 0.01),
+        warmup_steps=config.get('warmup_steps', 0),
+        scale=config.get('scale'),
+        proj_refresh=config.get('proj_refresh', 200),
+        norm_growth_limit=config.get('norm_growth_limit', 1.01),
+    )
+
+    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
+                f"{len(standard_params)} standard, "
+                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
+
+    loss_history = []
+
+    for i, sample in enumerate(samples):
+        ctx_ids = sample['context_ids']
+        cont_ids = sample['continuation_ids']
+        all_ids = ctx_ids + cont_ids
+        context_len = len(ctx_ids)
+
+        input_ids = torch.tensor([all_ids], device='cuda:0')
+
+        optimizer.zero_grad()
+
+        # Context-frozen forward pass
+        with torch.no_grad():
+            outputs = model(input_ids[:, :context_len], use_cache=True)
+            past_kv = outputs.past_key_values
+
+        # Decision tokens with gradients
+        with torch.enable_grad():
+            outputs = model(
+                input_ids[:, context_len:],
+                past_key_values=past_kv,
+                use_cache=False,
+            )
+            logits = outputs.logits
+
+        # Shift: predict next token from each position
+        shift_logits = logits[:, :-1].contiguous()
+        shift_labels = input_ids[:, context_len + 1:].contiguous()
+
+        loss = nn.functional.cross_entropy(
+            shift_logits.view(-1, shift_logits.size(-1)),
+            shift_labels.view(-1),
+        )
+
+        loss.backward()
+        optimizer.step()
+
+        loss_val = loss.item()
+        loss_history.append(loss_val)
+        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
+                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
+
+    return loss_history
+
+
+# Checkpoint sync scheduling
+_checkpoint_task = None
+CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
+
+
+def schedule_checkpoint_sync():
+    """Schedule checkpoint sync after delay (batched)."""
+    global _checkpoint_task
+    import asyncio
+
+    if _checkpoint_task is not None:
+        # Already scheduled
+        return
+
+    async def do_sync():
+        global _checkpoint_task
+        try:
+            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
+            if _model_path:
+                from .checkpoint_sync import checkpoint_sync
+                logger.info("[apollo] Starting checkpoint sync...")
+                result = checkpoint_sync(_model_path)
+                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
+        except Exception as e:
+            logger.error(f"[apollo] Checkpoint sync failed: {e}")
+        finally:
+            _checkpoint_task = None
+
+    _checkpoint_task = asyncio.create_task(do_sync())
+    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
+
+
+def attach_router(app: FastAPI):
+    """Attach training router to FastAPI app."""
+    app.include_router(router)
+    logger.info("[apollo] Training router attached")
+
+
+def _patch_api_server():
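+    # Wrapping build_app (instead of importing a module-level app) keeps
+    # the router attached regardless of how the server is constructed;
+    # assumes vllm.entrypoints.openai.api_server exposes build_app.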
"""Patch vLLM's build_app to include our training router.""" + from vllm.entrypoints.openai import api_server + + original_build_app = api_server.build_app + + def patched_build_app(*args, **kwargs): + app = original_build_app(*args, **kwargs) + attach_router(app) + return app + + api_server.build_app = patched_build_app + logger.info("[apollo] API server patched for /train endpoint") diff --git a/training/apollo_plugin/worker.py b/training/apollo_plugin/worker.py deleted file mode 100755 index d180c13..0000000 --- a/training/apollo_plugin/worker.py +++ /dev/null @@ -1,509 +0,0 @@ -#!/usr/bin/env python3 -""" -Apollo Mini Training Daemon - -This daemon: -1. Listens over HTTPS for training requests from poc-agent -2. Pauses vLLM inference -3. Runs APOLLO-Mini training with torch.enable_grad() -4. Saves checkpoints and training metadata -5. Resumes vLLM inference - -Communication protocol: -- POST /train: Start a training job -- GET /status/{job_id}: Check training status -- GET /checkpoints: List available checkpoints -""" - -import asyncio -import json -import logging -import os -import sys -import time -from dataclasses import dataclass, field, asdict -from datetime import datetime -from pathlib import Path -from typing import Optional, Dict, Any, List -from enum import Enum - -import torch -import torch.nn as nn -from aiohttp import web - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('apollo_worker') - -class TrainingStatus(Enum): - PENDING = "pending" - PAUSING_VLLM = "pausing_vllm" - TRAINING = "training" - SAVING_CHECKPOINT = "saving_checkpoint" - RESUMING_VLLM = "resuming_vllm" - COMPLETED = "completed" - FAILED = "failed" - -@dataclass -class TrainingJob: - job_id: str - status: TrainingStatus - created_at: datetime - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - model_path: Optional[str] = None - checkpoint_path: Optional[str] = None - training_samples: int = 0 - loss_history: List[float] = field(default_factory=list) - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return { - 'job_id': self.job_id, - 'status': self.status.value, - 'created_at': self.created_at.isoformat(), - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'model_path': self.model_path, - 'checkpoint_path': self.checkpoint_path, - 'training_samples': self.training_samples, - 'loss_history': self.loss_history, - 'error': self.error, - } - -CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes - - -class ApolloWorker: - def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"): - self.config = self._load_config(config_path) - self.jobs: Dict[str, TrainingJob] = {} - self.vllm_paused = False - self.app = web.Application() - self._setup_routes() - self._checkpoint_timer: Optional[asyncio.Task] = None - - def _load_config(self, config_path: str) -> Dict[str, Any]: - """Load configuration from file or use defaults.""" - default_config = { - 'host': '0.0.0.0', - 'port': 8080, - 'vllm_socket': '/tmp/vllm_control.sock', - 'model_path': '/home/ubuntu/models/Qwen3.5-27B', - 'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints', - 'max_training_samples': 100, - 'learning_rate': 1e-5, - 'batch_size': 1, - } - - if os.path.exists(config_path): - with open(config_path, 'r') as f: - user_config = json.load(f) - 
-            default_config.update(user_config)
-
-        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
-        return default_config
-
-    def _setup_routes(self):
-        """Setup HTTP routes."""
-        self.app.router.add_post('/train', self.handle_train_request)
-        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
-        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
-        self.app.router.add_get('/health', self.handle_health_check)
-
-    async def handle_health_check(self, request: web.Request) -> web.Response:
-        """Health check endpoint."""
-        return web.json_response({
-            'status': 'healthy',
-            'vllm_paused': self.vllm_paused,
-            'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
-        })
-
-    async def handle_train_request(self, request: web.Request) -> web.Response:
-        """Handle training request from poc-agent."""
-        try:
-            data = await request.json()
-
-            # Validate required fields
-            if 'training_data' not in data:
-                return web.json_response(
-                    {'error': 'Missing training_data field'},
-                    status=400
-                )
-
-            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
-            job = TrainingJob(
-                job_id=job_id,
-                status=TrainingStatus.PENDING,
-                created_at=datetime.now(),
-                model_path=self.config['model_path']
-            )
-            self.jobs[job_id] = job
-
-            # Start training in background
-            asyncio.create_task(self.execute_training(job, data))
-
-            return web.json_response({
-                'job_id': job_id,
-                'status': 'accepted',
-                'message': 'Training job started'
-            })
-
-        except Exception as e:
-            logger.error(f"Error handling train request: {e}")
-            return web.json_response(
-                {'error': str(e)},
-                status=500
-            )
-
-    async def handle_status_request(self, request: web.Request) -> web.Response:
-        """Get training job status."""
-        job_id = request.match_info['job_id']
-
-        if job_id not in self.jobs:
-            return web.json_response(
-                {'error': 'Job not found'},
-                status=404
-            )
-
-        job = self.jobs[job_id]
-        return web.json_response(job.to_dict())
-
-    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
-        """List available checkpoints."""
-        checkpoint_dir = Path(self.config['checkpoint_dir'])
-        checkpoints = []
-
-        if checkpoint_dir.exists():
-            for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
-                checkpoints.append({
-                    'filename': checkpoint_file.name,
-                    'path': str(checkpoint_file),
-                    'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
-                    'size': checkpoint_file.stat().st_size
-                })
-
-        return web.json_response({'checkpoints': checkpoints})
-
-    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
-        """Execute the training pipeline."""
-        try:
-            logger.info(f"Starting training job {job.job_id}")
-            job.started_at = datetime.now()
-
-            # Step 1: Pause vLLM
-            job.status = TrainingStatus.PAUSING_VLLM
-            logger.info("Pausing vLLM...")
-            await self.pause_vllm()
-            self.vllm_paused = True
-
-            # Step 2: Load model and prepare for training
-            job.status = TrainingStatus.TRAINING
-            logger.info("Loading model and preparing for training...")
-
-            # Load model (this would be the actual Qwen3.5-27B model)
-            # For now, we'll use a placeholder
-            model = await self.load_model_for_training()
-
-            # Step 3: Run APOLLO-Mini training
-            logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
-
-            # Extract training samples
-            samples = training_data['samples']
-            job.training_samples = len(samples)
-
-            # Run training loop
-            loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
-            job.loss_history = loss_history
-
-            # Step 4: Save checkpoint
-            job.status = TrainingStatus.SAVING_CHECKPOINT
-            logger.info("Saving checkpoint...")
-            checkpoint_path = await self.save_checkpoint(model, job)
-            job.checkpoint_path = checkpoint_path
-
-            # Step 5: Resume vLLM
-            job.status = TrainingStatus.RESUMING_VLLM
-            logger.info("Resuming vLLM...")
-            await self.resume_vllm()
-            self.vllm_paused = False
-
-            # Mark job as completed
-            job.status = TrainingStatus.COMPLETED
-            job.completed_at = datetime.now()
-
-            logger.info(f"Training job {job.job_id} completed successfully")
-
-            # Schedule checkpoint sync (batched — won't duplicate if timer pending)
-            self.schedule_checkpoint_sync()
-
-        except Exception as e:
-            logger.error(f"Training job {job.job_id} failed: {e}")
-            job.status = TrainingStatus.FAILED
-            job.error = str(e)
-            job.completed_at = datetime.now()
-
-            # Try to resume vLLM if it was paused
-            if self.vllm_paused:
-                try:
-                    await self.resume_vllm()
-                    self.vllm_paused = False
-                except Exception as resume_error:
-                    logger.error(f"Failed to resume vLLM after training error: {resume_error}")
-
-    async def pause_vllm(self):
-        """Pause vLLM inference via HTTP API."""
-        import aiohttp as aio
-        url = self.config.get('vllm_url', 'http://localhost:8000')
-        try:
-            async with aio.ClientSession() as session:
-                async with session.post(
-                    f"{url}/pause_generation",
-                    json={"mode": "keep", "clear_cache": False},
-                    timeout=aio.ClientTimeout(total=10),
-                ) as resp:
-                    resp.raise_for_status()
-            logger.info("vLLM paused")
-        except Exception as e:
-            logger.warning(f"Failed to pause vLLM: {e}")
-
-    async def resume_vllm(self):
-        """Resume vLLM inference via HTTP API."""
-        import aiohttp as aio
-        url = self.config.get('vllm_url', 'http://localhost:8000')
-        try:
-            async with aio.ClientSession() as session:
-                async with session.post(
-                    f"{url}/resume_generation",
-                    timeout=aio.ClientTimeout(total=10),
-                ) as resp:
-                    resp.raise_for_status()
-            logger.info("vLLM resumed")
-        except Exception as e:
-            logger.warning(f"Failed to resume vLLM: {e}")
-
-    def schedule_checkpoint_sync(self):
-        """Schedule a checkpoint sync in 10 minutes, if not already scheduled.
-
-        This batches multiple training runs into a single sync — the timer
-        resets only when no timer is pending.
- """ - if self._checkpoint_timer is not None: - logger.debug("Checkpoint sync already scheduled, skipping") - return - - self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay()) - logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes") - - async def _checkpoint_sync_after_delay(self): - """Wait then sync — the actual timer task.""" - try: - await asyncio.sleep(CHECKPOINT_DELAY_SECS) - await self._do_checkpoint_sync() - except asyncio.CancelledError: - logger.debug("Checkpoint sync cancelled") - finally: - self._checkpoint_timer = None - - async def _do_checkpoint_sync(self): - """Execute the checkpoint sync.""" - try: - from apollo_plugin.checkpoint_sync import checkpoint_sync - logger.info("Starting checkpoint sync...") - result = checkpoint_sync( - self.config['model_path'], - self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'), - ) - changed_mb = result['total_changed'] / 1e6 - logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written") - except Exception as e: - logger.error(f"Checkpoint sync failed: {e}") - - async def load_model_for_training(self) -> nn.Module: - """Load HF model with weights pointing to vLLM's GPU memory. - - Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible - views (narrowing merged weights into separate q/k/v/z etc.), and - constructs the HF model around those views. No weight copying — - all parameters share vLLM's GPU memory. - """ - handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt') - model_path = self.config['model_path'] - - # Import vLLM weights via CUDA IPC - logger.info(f"Importing vLLM weights from {handle_path}") - handles = torch.load(handle_path, weights_only=False) - vllm_params = {} - for name, info in handles.items(): - func, args = info['handle'] - vllm_params[name] = func(*args) - logger.info(f"Imported {len(vllm_params)} parameters") - - # Map vLLM merged layout → HF separate layout (views, no copies) - from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights - model = load_hf_model_with_vllm_weights(vllm_params, model_path) - logger.info("HF model constructed with vLLM weight views") - - return model - - async def run_apollo_training(self, model: nn.Module, - samples: List[Dict[str, Any]], - config: Dict[str, Any]) -> List[float]: - """Run Apollo-Mini training on conversation decision points. 
-
-        Each sample has:
-            context_ids: token IDs for frozen context (no gradients)
-            continuation_ids: token IDs for the decision we're training on
-        """
-        from apollo_plugin.optimizer import Apollo
-
-        # Build parameter groups (Apollo for 2D+, standard for small/1D)
-        apollo_params, standard_params = [], []
-        for p in model.parameters():
-            if p.requires_grad:
-                if p.ndim >= 2 and min(p.shape) >= 2:
-                    apollo_params.append(p)
-                else:
-                    standard_params.append(p)
-
-        groups = []
-        if apollo_params:
-            groups.append({'params': apollo_params})
-        if standard_params:
-            groups.append({'params': standard_params})
-
-        # Apollo settings from request config, falling back to server defaults
-        optimizer = Apollo(
-            groups,
-            lr=config.get('lr', self.config.get('learning_rate', 1e-5)),
-            rank=config.get('rank', 256),
-            betas=tuple(config.get('betas', (0.9, 0.999))),
-            eps=config.get('eps', 1e-8),
-            weight_decay=config.get('weight_decay', 0.01),
-            warmup_steps=config.get('warmup_steps', 0),
-            scale=config.get('scale'),  # None = auto
-            proj_refresh=config.get('proj_refresh', 200),
-            norm_growth_limit=config.get('norm_growth_limit', 1.01),
-        )
-        rank = config.get('rank', 256)
-        lr = config.get('lr', self.config.get('learning_rate', 1e-5))
-        logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
-                    f"{len(standard_params)} standard, "
-                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
-
-        loss_history = []
-
-        for i, sample in enumerate(samples):
-            # context_ids: frozen (forward only, no gradients)
-            # continuation_ids: the decision we're training on
-            ctx_ids = sample['context_ids']
-            cont_ids = sample['continuation_ids']
-            all_ids = ctx_ids + cont_ids
-            context_len = len(ctx_ids)
-
-            input_ids = torch.tensor([all_ids], device='cuda:0')
-
-            optimizer.zero_grad()
-
-            # Context-frozen forward pass
-            with torch.no_grad():
-                # Forward through context (no gradients)
-                outputs = model(input_ids[:, :context_len], use_cache=True)
-                past_kv = outputs.past_key_values
-
-            # Decision tokens with gradients
-            with torch.enable_grad():
-                outputs = model(
-                    input_ids[:, context_len:],
-                    past_key_values=past_kv,
-                    use_cache=False,
-                )
-                logits = outputs.logits  # [1, cont_len, vocab]
-
-            # Shift: predict next token from each position
-            shift_logits = logits[:, :-1].contiguous()
-            shift_labels = input_ids[:, context_len + 1:].contiguous()
-
-            loss = nn.functional.cross_entropy(
-                shift_logits.view(-1, shift_logits.size(-1)),
-                shift_labels.view(-1),
-            )
-
-            loss.backward()
-            optimizer.step()
-
-            loss_val = loss.item()
-            loss_history.append(loss_val)
-            logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
-                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
-
-        logger.info(f"Training done: {len(samples)} examples, "
-                    f"final loss={loss_history[-1]:.4f}")
-        return loss_history
-
-    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
-        """Save model checkpoint in HuggingFace safetensors format."""
-        from safetensors.torch import save_file
-        import shutil
-
-        checkpoint_dir = Path(self.config['checkpoint_dir'])
-        date_str = datetime.now().strftime('%Y-%m-%d')
-        out_dir = checkpoint_dir / date_str
-        out_dir.mkdir(parents=True, exist_ok=True)
-
-        # Save weights
-        tensors = {name: p.data.contiguous().cpu()
-                   for name, p in model.named_parameters()}
-        save_path = out_dir / "model.safetensors"
-        save_file(tensors, str(save_path))
-
-        # Copy config files
-        config_dir = Path(self.config['model_path'])
-        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
-                  'special_tokens_map.json']:
-            src = config_dir / f
-            if src.exists():
-                shutil.copy2(src, out_dir / f)
-
-        # Save training metadata
-        meta = {
-            'job_id': job.job_id,
-            'training_samples': job.training_samples,
-            'loss_history': job.loss_history,
-            'timestamp': datetime.now().isoformat(),
-        }
-        with open(out_dir / 'training-meta.json', 'w') as f:
-            json.dump(meta, f, indent=2)
-
-        # Update latest symlink
-        latest = checkpoint_dir / 'latest'
-        if latest.is_symlink():
-            latest.unlink()
-        latest.symlink_to(date_str)
-
-        size_gb = save_path.stat().st_size / 1e9
-        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
-        return str(out_dir)
-
-    async def run(self):
-        """Run the daemon."""
-        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.TCPSite(runner, self.config['host'], self.config['port'])
-        await site.start()
-        logger.info("Apollo Worker is running")
-
-        # Keep running
-        while True:
-            await asyncio.sleep(3600)  # Sleep for an hour
-
-def main():
-    worker = ApolloWorker()
-    asyncio.run(worker.run())
-
-if __name__ == '__main__':
-    main()
diff --git a/training/pyproject.toml b/training/pyproject.toml
index 37ca129..cd6e1cc 100644
--- a/training/pyproject.toml
+++ b/training/pyproject.toml
@@ -20,7 +20,6 @@ dev = ["pytest"]
 apollo = "apollo_plugin:register"
 
 [project.scripts]
-apollo-worker = "apollo_plugin.worker:main"
 apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
 
 [tool.setuptools.packages.find]