diff --git a/training/DESIGN.md b/training/DESIGN.md index 556525f..2df4e6d 100644 --- a/training/DESIGN.md +++ b/training/DESIGN.md @@ -26,25 +26,37 @@ The training signal comes from two sources: │ └──────────────┬──────────────┬────────────────┘ │ │ │ │ │ │ ┌──────────────▼──┐ ┌───────▼────────────────┐ │ -│ │ vLLM (inference)│ │ HF model (training) │ │ -│ │ KV cache ~60GB │ │ Gradients ~54GB │ │ -│ │ /completions │ │ Optimizer state ~10GB │ │ -│ │ /score │ │ Views into vLLM weights │ │ -│ │ /train ────────┼──┼─► Apollo optimizer │ │ -│ └─────────────────┘ └────────────────────────┘ │ +│ │ vLLM (inference)│ │ Training subprocess │ │ +│ │ KV cache ~60GB │ │ HF model wrapper │ │ +│ │ /completions │ │ Apollo optimizer ~2.5GB │ │ +│ │ /score │ │ Checkpoint sync │ │ +│ └────────┬────────┘ └───────────▲─────────────┘ │ +│ │ │ │ +│ │ ZMQ IPC │ │ +│ └───────────────────────┘ │ └─────────────────────────────────────────────────────┘ - Single vLLM process serves everything - No separate daemon - /train is a vLLM route +Process Architecture: +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │ +│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │ +│ │ │ │ │ │ +│ export_hook.py │ │ /completions │ │ HF model views │ +│ exports IPC │ │ /score │ │ Apollo optimizer│ +│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ + └──── IPC handles file ──────────────────┘ + /tmp/vllm_weight_handles.pt Moria B200 (vLLM) ┌──────────────────┐ ┌──────────────────┐ │ Training signal │ HTTP │ /completions │ │ agent │──────────>│ /score │ │ │ │ /train │ -│ Dream loop │ │ │ -│ (generates │ │ Checkpoint sync │ -│ scenarios) │ │ (10 min batched) │ +│ Dream loop │ │ /checkpoint │ +│ (generates │ │ /train/status │ +│ scenarios) │ │ │ └──────────────────┘ └──────────────────┘ ``` @@ -213,8 +225,9 @@ a few hundred MB. 
| File | Purpose | |------|---------| -| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. | -| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. | +| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Also stores a `__metadata__` entry (model_path, num_params). | +| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored when a restarted training_worker first creates its optimizer (on its first train request). Preserves training continuity across sessions. | +| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train, /checkpoint and /train/status endpoints) and training_worker subprocess. | | `/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. | ### Moria (client) @@ -224,12 +237,13 @@ a few hundred MB. | `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. | | `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. | -### In-memory +### In-memory (training_worker subprocess) | State | Location | Notes | |-------|----------|-------| -| Apollo optimizer | train_router._optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. | -| HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. | +| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64.
Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. | +| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles; if that load fails it is retried on the next non-status request. Parameters point to vLLM's GPU memory. | +| ZMQ socket | local `socket` in TrainingWorker.run() | REP socket bound to `/tmp/apollo_training.sock`. | ## Hyperparameters @@ -248,7 +262,8 @@ a few hundred MB. ### Built ✓ - `optimizer.py` — Apollo optimizer (configurable rank) -- `train_router.py` — /train endpoint, runs in vLLM process +- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ +- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync) - `weight_mapping.py` — vLLM merged → HF separate views (validated) - `export_hook.py` — vLLM plugin hook for IPC handle export - `checkpoint_sync.py` — mmap + diff checkpoint sync (Python) @@ -267,8 +282,9 @@ training/ pyproject.toml — package config, vLLM plugin entry point apollo_plugin/ __init__.py — plugin registration - export_hook.py — patches vLLM to export IPC handles - train_router.py — /train endpoint (FastAPI router) + export_hook.py — patches vLLM worker to export IPC handles + train_router.py — /train endpoint, forwards to worker via ZMQ + training_worker.py — training subprocess (HF model, Apollo, checkpoint) optimizer.py — Apollo optimizer weight_mapping.py — vLLM ↔ HF weight views checkpoint_sync.py — mmap + diff sync to safetensors diff --git a/training/apollo_plugin/checkpoint_sync.py b/training/apollo_plugin/checkpoint_sync.py index eff93cc..c2d7b2f 100644 --- a/training/apollo_plugin/checkpoint_sync.py +++ b/training/apollo_plugin/checkpoint_sync.py @@ -260,6 +260,9 @@ def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]: """ handles = torch.load(handles_path, weights_only=False) + # Skip metadata entry + handles.pop('__metadata__', None) + weights = {} for name, info in handles.items(): func, args = info['handle'] diff --git a/training/apollo_plugin/export_hook.py
b/training/apollo_plugin/export_hook.py index 821163b..e0ff6fc 100644 --- a/training/apollo_plugin/export_hook.py +++ b/training/apollo_plugin/export_hook.py @@ -20,7 +20,7 @@ from pathlib import Path HANDLE_PATH = "/tmp/vllm_weight_handles.pt" -def export_model_weights(model): +def export_model_weights(model, model_path: str | None = None): """Export CUDA IPC handles for all model parameters.""" from torch.multiprocessing.reductions import reduce_tensor @@ -38,6 +38,12 @@ def export_model_weights(model): } total_bytes += param.nelement() * param.element_size() + # Include metadata for training worker + handles['__metadata__'] = { + 'model_path': model_path, + 'num_params': len(handles), + } + torch.save(handles, HANDLE_PATH) print(f"[apollo] Exported {len(handles)} weight handles " f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}") @@ -58,11 +64,8 @@ def _patch_model_runner(): def patched_load(self, *args, **kwargs): result = original_load(self, *args, **kwargs) try: - export_model_weights(self.model_runner.model) - # Set model path for training router model_path = self.vllm_config.model_config.model - from .train_router import set_model_path - set_model_path(model_path) + export_model_weights(self.model_runner.model, model_path) except Exception as e: print(f"[apollo] Failed to export weights: {e}") return result diff --git a/training/apollo_plugin/train_router.py b/training/apollo_plugin/train_router.py index 3a35119..d6f90b4 100644 --- a/training/apollo_plugin/train_router.py +++ b/training/apollo_plugin/train_router.py @@ -1,16 +1,23 @@ -"""Training endpoint for vLLM - runs Apollo training in-process. +"""Training endpoint for vLLM - forwards to training subprocess via ZMQ. -Patches vLLM's build_app() to add /train route. Training runs HOGWILD -style - no pause needed, weights updated in-place while inference continues. +Patches vLLM's build_app() to add /train route. 
The actual training runs +in a dedicated subprocess (training_worker.py) to avoid blocking the +event loop and to keep training work isolated from vLLM internals. """ +import asyncio import logging +import os +import subprocess +import sys from datetime import datetime +from pathlib import Path from typing import Any -import torch -import torch.nn as nn -from fastapi import APIRouter, FastAPI, Request +import zmq +import zmq.asyncio + +from fastapi import APIRouter, FastAPI from fastapi.responses import JSONResponse from pydantic import BaseModel @@ -18,10 +25,13 @@ logger = logging.getLogger(__name__) router = APIRouter() +DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" -class TrainingSample(BaseModel): - context_ids: list[int] - continuation_ids: list[int] +# Global state for subprocess management +_worker_process: subprocess.Popen | None = None +_zmq_context: zmq.asyncio.Context | None = None +_zmq_socket: zmq.asyncio.Socket | None = None +_initialized: bool = False class TrainRequest(BaseModel): @@ -35,64 +45,61 @@ class TrainResponse(BaseModel): loss_history: list[float] -# Global reference to HF model with vLLM weight views -_model: nn.Module | None = None -_model_path: str | None = None -_initialized: bool = False -_optimizer: Any = None # Persisted Apollo optimizer +def _start_worker_subprocess(): + """Start the training worker subprocess.""" + global _worker_process -OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt" -DEFAULT_RANK = 64 + if _worker_process is not None and _worker_process.poll() is None: + return # Still running + # Start worker as subprocess using script path + worker_script = Path(__file__).parent / 'training_worker.py' + _worker_process = subprocess.Popen( + [sys.executable, str(worker_script)], + env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR}, + ) + logger.info(f"Started training worker subprocess (pid={_worker_process.pid})") -def _load_training_model() -> nn.Module: - """Load HF model with weights pointing to vLLM's 
GPU memory. - - Uses CUDA IPC handles exported by export_hook to create an HF model - whose parameters share GPU memory with vLLM's model. - """ - from .weight_mapping import load_hf_model_with_vllm_weights - from .export_hook import HANDLE_PATH - - handles = torch.load(HANDLE_PATH, weights_only=False) - vllm_params = {} - for name, info in handles.items(): - func, args = info['handle'] - vllm_params[name] = func(*args) - - model = load_hf_model_with_vllm_weights(vllm_params, _model_path) - model.train() - return model + # Give it a moment to bind the socket + import time + time.sleep(0.5) def _ensure_initialized(): - """Lazy-initialize the training model on first /train request.""" - global _model, _initialized + """Ensure subprocess is running and ZMQ socket is connected.""" + global _zmq_context, _zmq_socket, _initialized if _initialized: return - if _model_path is None: - raise RuntimeError("Model path not set - export_hook may not have run") + # Start worker if needed + _start_worker_subprocess() + + # Create async ZMQ context and socket + _zmq_context = zmq.asyncio.Context() + _zmq_socket = _zmq_context.socket(zmq.REQ) + _zmq_socket.connect(DEFAULT_ZMQ_ADDR) + + # Set timeout for recv + _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training - logger.info("[apollo] Loading HF model with vLLM weight views...") - _model = _load_training_model() _initialized = True - logger.info("[apollo] Training model ready") + logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}") -def set_model_path(path: str): - """Set model path for training. 
Called by export_hook after model load.""" - global _model_path - _model_path = path - logger.info(f"[apollo] Model path set: {path}") +async def _send_request(request: dict[str, Any]) -> dict[str, Any]: + """Send request to worker and wait for response.""" + _ensure_initialized() + + # ZMQ async send/recv + await _zmq_socket.send_json(request) + response = await _zmq_socket.recv_json() + return response @router.post("/train") -async def handle_train(request: TrainRequest, raw_request: Request): - """Handle training request - runs Apollo training on provided samples.""" - global _model - +async def handle_train(request: TrainRequest): + """Handle training request - forwards to training subprocess.""" try: _ensure_initialized() except Exception as e: @@ -113,193 +120,109 @@ async def handle_train(request: TrainRequest, raw_request: Request): ) job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples") + logger.info(f"Starting training job {job_id} with {len(samples)} samples") - # Run training - loss_history = await run_training(_model, samples, config) + # Forward to worker + response = await _send_request({ + 'type': 'train', + 'samples': samples, + 'config': config, + }) - logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}") + if 'error' in response: + return JSONResponse( + content={"error": response['error']}, + status_code=500, + ) - # Schedule checkpoint sync (batched, 10 min delay) - schedule_checkpoint_sync() + logger.info( + f"Training job {job_id} completed, " + f"final loss: {response['loss_history'][-1]:.4f}" + ) return JSONResponse(content={ "job_id": job_id, - "status": "completed", - "training_samples": len(samples), - "loss_history": loss_history, + "status": response['status'], + "training_samples": response['training_samples'], + "loss_history": response['loss_history'], }) + except zmq.Again: + logger.error("Training 
request timed out") + return JSONResponse( + content={"error": "Training request timed out"}, + status_code=504, + ) except Exception as e: - logger.exception(f"[apollo] Training failed: {e}") + logger.exception(f"Training failed: {e}") return JSONResponse( content={"error": str(e)}, status_code=500, ) -def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]): - """Get existing optimizer or create new one. Persists state between calls.""" - global _optimizer - from .optimizer import Apollo - import os +@router.post("/checkpoint") +async def handle_checkpoint(): + """Trigger checkpoint sync to disk.""" + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={"error": f"Training not available: {e}"}, + status_code=503, + ) - if _optimizer is not None: - return _optimizer + try: + response = await _send_request({'type': 'checkpoint'}) - # Build parameter groups (Apollo for 2D+, standard for small/1D) - apollo_params, standard_params = [], [] - for p in model.parameters(): - if p.requires_grad: - if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({'params': apollo_params}) - if standard_params: - groups.append({'params': standard_params}) - - if not groups: - raise ValueError("No trainable parameters found") - - # Create optimizer - _optimizer = Apollo( - groups, - lr=config.get('lr', 1e-5), - rank=config.get('rank', DEFAULT_RANK), - betas=tuple(config.get('betas', (0.9, 0.999))), - eps=config.get('eps', 1e-8), - weight_decay=config.get('weight_decay', 0.01), - warmup_steps=config.get('warmup_steps', 0), - scale=config.get('scale'), - proj_refresh=config.get('proj_refresh', 200), - norm_growth_limit=config.get('norm_growth_limit', 1.01), - ) - - # Restore state if exists - if os.path.exists(OPTIMIZER_STATE_PATH): - try: - state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False) - 
_optimizer.load_state_dict(state) - logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}") - except Exception as e: - logger.warning(f"[apollo] Could not restore optimizer state: {e}") - - logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, " - f"{len(standard_params)} standard, " - f"state={_optimizer.state_size_bytes()/1e6:.1f}MB") - - return _optimizer - - -def _save_optimizer_state(): - """Save optimizer state for persistence between /train calls.""" - global _optimizer - if _optimizer is not None: - torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH) - logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}") - - -async def run_training( - model: nn.Module, - samples: list[dict[str, Any]], - config: dict[str, Any], -) -> list[float]: - """Run Apollo training on the given samples. - - Each sample has: - context_ids: token IDs for frozen context (no gradients) - continuation_ids: token IDs for the decision we're training on - """ - optimizer = _get_or_create_optimizer(model, config) - - loss_history = [] - - for i, sample in enumerate(samples): - ctx_ids = sample['context_ids'] - cont_ids = sample['continuation_ids'] - all_ids = ctx_ids + cont_ids - context_len = len(ctx_ids) - - input_ids = torch.tensor([all_ids], device='cuda:0') - - optimizer.zero_grad() - - # Context-frozen forward pass - with torch.no_grad(): - outputs = model(input_ids[:, :context_len], use_cache=True) - past_kv = outputs.past_key_values - - # Decision tokens with gradients - with torch.enable_grad(): - outputs = model( - input_ids[:, context_len:], - past_key_values=past_kv, - use_cache=False, - ) - logits = outputs.logits - - # Shift: predict next token from each position - shift_logits = logits[:, :-1].contiguous() - shift_labels = input_ids[:, context_len + 1:].contiguous() - - loss = nn.functional.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), + if 'error' in response: + return 
JSONResponse( + content={"error": response['error']}, + status_code=500, ) - loss.backward() - optimizer.step() + return JSONResponse(content=response) - loss_val = loss.item() - loss_history.append(loss_val) - logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} " - f"(ctx={context_len}, cont={len(cont_ids)} tokens)") - - return loss_history + except Exception as e: + logger.exception(f"Checkpoint failed: {e}") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) -# Checkpoint sync scheduling -_checkpoint_task = None -CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes +@router.get("/train/status") +async def handle_status(): + """Get training worker status.""" + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={ + "status": "unavailable", + "error": str(e), + }, + status_code=503, + ) + try: + response = await _send_request({'type': 'status'}) + return JSONResponse(content=response) -def schedule_checkpoint_sync(): - """Schedule checkpoint sync after delay (batched).""" - global _checkpoint_task - import asyncio - - if _checkpoint_task is not None: - # Already scheduled - return - - async def do_sync(): - global _checkpoint_task - try: - await asyncio.sleep(CHECKPOINT_DELAY_SECS) - if _model_path: - from .checkpoint_sync import checkpoint_sync - logger.info("[apollo] Starting checkpoint sync...") - # Save optimizer state alongside model weights - _save_optimizer_state() - result = checkpoint_sync(_model_path) - logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB") - except Exception as e: - logger.error(f"[apollo] Checkpoint sync failed: {e}") - finally: - _checkpoint_task = None - - _checkpoint_task = asyncio.create_task(do_sync()) - logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min") + except Exception as e: + return JSONResponse( + content={ + "status": "error", + "error": str(e), + }, + status_code=500, + ) def attach_router(app: 
FastAPI): """Attach training router to FastAPI app.""" app.include_router(router) - logger.info("[apollo] Training router attached") + logger.info("Training router attached") def _patch_api_server(): @@ -314,4 +237,4 @@ def _patch_api_server(): return app api_server.build_app = patched_build_app - logger.info("[apollo] API server patched for /train endpoint") + logger.info("API server patched for /train endpoint") diff --git a/training/apollo_plugin/training_worker.py b/training/apollo_plugin/training_worker.py new file mode 100644 index 0000000..f8b8c23 --- /dev/null +++ b/training/apollo_plugin/training_worker.py @@ -0,0 +1,323 @@ +"""Training subprocess - handles Apollo training and checkpoint sync. + +Long-lived process that: +1. Loads IPC handles from vLLM's exported weights +2. Creates HF model with views into vLLM's GPU memory +3. Handles training requests via ZMQ +4. Handles checkpoint sync requests +5. Persists Apollo optimizer state between calls + +Communicates with the API server's /train endpoint via ZMQ REP socket. 
+""" + +import logging +import os +import signal +import sys +from pathlib import Path +from typing import Any + +# Handle running as script vs module +if __name__ == '__main__' and __package__ is None: + # Running as script - add parent to path for imports + sys.path.insert(0, str(Path(__file__).parent.parent)) + __package__ = 'apollo_plugin' + +import torch +import torch.nn as nn +import zmq + +from .checkpoint_sync import checkpoint_sync +from .optimizer import Apollo +from .weight_mapping import load_hf_model_with_vllm_weights + +logger = logging.getLogger(__name__) + +DEFAULT_RANK = 64 +DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" +HANDLE_PATH = "/tmp/vllm_weight_handles.pt" +OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt" + + +class TrainingWorker: + """Long-lived training worker process.""" + + def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR): + self.zmq_addr = zmq_addr + self.model: nn.Module | None = None + self.optimizer: Apollo | None = None + self.model_path: str | None = None + self._running = True + + def _create_model_wrapper(self) -> nn.Module: + """Create HF model wrapper with views into vLLM's GPU memory.""" + if not os.path.exists(HANDLE_PATH): + raise FileNotFoundError( + f"Weight handles not found: {HANDLE_PATH}. " + "Is vLLM running with the export hook?" 
+ ) + + handles = torch.load(HANDLE_PATH, weights_only=False) + + # Extract metadata + metadata = handles.pop('__metadata__', {}) + self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH') + if not self.model_path: + raise ValueError( + "Model path not found in handles metadata or APOLLO_MODEL_PATH env var" + ) + + # Reconstruct tensors from IPC handles + vllm_params = {} + for name, info in handles.items(): + func, args = info['handle'] + vllm_params[name] = func(*args) + + model = load_hf_model_with_vllm_weights(vllm_params, self.model_path) + model.train() + return model + + def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo: + """Get existing optimizer or create new one.""" + if self.optimizer is not None: + return self.optimizer + + # Build parameter groups (Apollo for 2D+, standard Adam for small/1D) + apollo_params, standard_params = [], [] + for p in self.model.parameters(): + if p.requires_grad: + if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK: + apollo_params.append(p) + else: + standard_params.append(p) + + groups = [] + if apollo_params: + groups.append({'params': apollo_params}) + if standard_params: + groups.append({'params': standard_params}) + + if not groups: + raise ValueError("No trainable parameters found") + + self.optimizer = Apollo( + groups, + lr=config.get('lr', 1e-5), + rank=config.get('rank', DEFAULT_RANK), + betas=tuple(config.get('betas', (0.9, 0.999))), + eps=config.get('eps', 1e-8), + weight_decay=config.get('weight_decay', 0.01), + warmup_steps=config.get('warmup_steps', 0), + scale=config.get('scale'), + proj_refresh=config.get('proj_refresh', 200), + norm_growth_limit=config.get('norm_growth_limit', 1.01), + ) + + # Restore state if exists + if os.path.exists(OPTIMIZER_STATE_PATH): + try: + state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False) + self.optimizer.load_state_dict(state) + logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}") + except Exception as e: + 
logger.warning(f"Could not restore optimizer state: {e}") + + logger.info( + f"Optimizer: {len(apollo_params)} apollo params, " + f"{len(standard_params)} standard, " + f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB" + ) + + return self.optimizer + + def _save_optimizer_state(self): + """Save optimizer state for persistence.""" + if self.optimizer is not None: + torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH) + logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}") + + def _run_training( + self, + samples: list[dict[str, Any]], + config: dict[str, Any], + ) -> list[float]: + """Run Apollo training on the given samples.""" + optimizer = self._get_or_create_optimizer(config) + + loss_history = [] + + for i, sample in enumerate(samples): + ctx_ids = sample['context_ids'] + cont_ids = sample['continuation_ids'] + all_ids = ctx_ids + cont_ids + context_len = len(ctx_ids) + + input_ids = torch.tensor([all_ids], device='cuda:0') + + optimizer.zero_grad() + + # Context-frozen forward pass + with torch.no_grad(): + outputs = self.model(input_ids[:, :context_len], use_cache=True) + past_kv = outputs.past_key_values + + # Decision tokens with gradients + with torch.enable_grad(): + outputs = self.model( + input_ids[:, context_len:], + past_key_values=past_kv, + use_cache=False, + ) + logits = outputs.logits + + # Shift: predict next token from each position + shift_logits = logits[:, :-1].contiguous() + shift_labels = input_ids[:, context_len + 1:].contiguous() + + loss = nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + loss.backward() + optimizer.step() + + loss_val = loss.item() + loss_history.append(loss_val) + logger.info( + f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} " + f"(ctx={context_len}, cont={len(cont_ids)} tokens)" + ) + + return loss_history + + def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a training request.""" + samples = 
request.get('samples', []) + config = request.get('config', {}) + + if not samples: + return {'error': 'No training samples provided'} + + try: + loss_history = self._run_training(samples, config) + return { + 'status': 'completed', + 'training_samples': len(samples), + 'loss_history': loss_history, + } + except Exception as e: + logger.exception(f"Training failed: {e}") + return {'error': str(e)} + + def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a checkpoint sync request.""" + if not self.model_path: + return {'error': 'Model path not set'} + + try: + self._save_optimizer_state() + result = checkpoint_sync(self.model_path) + return { + 'status': 'completed', + 'total_changed': result['total_changed'], + 'files_changed': result['files_changed'], + } + except Exception as e: + logger.exception(f"Checkpoint sync failed: {e}") + return {'error': str(e)} + + def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a status request.""" + return { + 'status': 'ready', + 'model_loaded': self.model is not None, + 'optimizer_loaded': self.optimizer is not None, + 'model_path': self.model_path, + 'optimizer_state_mb': ( + self.optimizer.state_size_bytes() / 1e6 + if self.optimizer else 0 + ), + } + + def run(self): + """Main loop - listen for requests and handle them.""" + # Set up signal handlers + def handle_signal(signum, frame): + logger.info(f"Received signal {signum}, shutting down...") + self._running = False + + signal.signal(signal.SIGTERM, handle_signal) + signal.signal(signal.SIGINT, handle_signal) + + # Set up ZMQ socket first so API server can connect + context = zmq.Context() + socket = context.socket(zmq.REP) + socket.bind(self.zmq_addr) + logger.info(f"Training worker listening on {self.zmq_addr}") + + # Create HF model wrapper with views into vLLM's GPU memory + logger.info("Connecting to vLLM weights via IPC handles...") + try: + self.model = self._create_model_wrapper() + logger.info("HF model 
wrapper ready (views into vLLM GPU memory)") + except Exception as e: + logger.error(f"Failed to connect to vLLM weights: {e}") + logger.info("Will retry on first training request") + + # Set socket timeout so we can check _running flag + socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout + + while self._running: + try: + message = socket.recv_json() + except zmq.Again: + # Timeout, check _running and continue + continue + + request_type = message.get('type', 'train') + logger.info(f"Received {request_type} request") + + # Ensure model is loaded + if self.model is None and request_type != 'status': + try: + self.model = self._create_model_wrapper() + except Exception as e: + socket.send_json({'error': f'Model not loaded: {e}'}) + continue + + # Dispatch request + if request_type == 'train': + response = self._handle_train(message) + elif request_type == 'checkpoint': + response = self._handle_checkpoint(message) + elif request_type == 'status': + response = self._handle_status(message) + else: + response = {'error': f'Unknown request type: {request_type}'} + + socket.send_json(response) + + # Cleanup + logger.info("Saving optimizer state before shutdown...") + self._save_optimizer_state() + socket.close() + context.term() + logger.info("Training worker shut down") + + +def main(): + """Entry point for running as a subprocess.""" + logging.basicConfig( + level=logging.INFO, + format='[apollo-worker] %(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S', + ) + + zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR) + worker = TrainingWorker(zmq_addr) + worker.run() + + +if __name__ == '__main__': + main() diff --git a/training/pyproject.toml b/training/pyproject.toml index cd6e1cc..7cf0581 100644 --- a/training/pyproject.toml +++ b/training/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "torch", "aiohttp", "safetensors", + "pyzmq", ] [project.optional-dependencies] @@ -21,6 +22,7 @@ apollo = "apollo_plugin:register" [project.scripts] 
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main" +apollo-worker = "apollo_plugin.training_worker:main" [tool.setuptools.packages.find] where = ["."]