"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.

Patches vLLM's build_app() to add /train route. The actual training runs in a
dedicated subprocess (training_worker.py) to avoid blocking the event loop and
to keep training work isolated from vLLM internals.
"""

import asyncio
import logging
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import zmq
import zmq.asyncio
from fastapi import APIRouter, FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter()

DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"

# Receive timeout for one request/reply round trip (5 minutes, sized for a
# training job). pyzmq's asyncio sockets raise zmq.Again when it expires.
_RECV_TIMEOUT_MS = 300_000

# Global state for subprocess management
_worker_process: subprocess.Popen | None = None
_zmq_context: zmq.asyncio.Context | None = None
_zmq_socket: zmq.asyncio.Socket | None = None
_initialized: bool = False

# A REQ socket must strictly alternate send -> recv. Route handlers run
# concurrently on the event loop, so every round trip is serialized through
# this lock; otherwise a second request sent while one is pending fails with
# EFSM. Consequence: a status call issued during training waits for it.
_request_lock = asyncio.Lock()


class TrainRequest(BaseModel):
    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}


class TrainResponse(BaseModel):
    job_id: str
    status: str
    training_samples: int
    loss_history: list[float]


def _start_worker_subprocess() -> None:
    """Start the training worker subprocess unless one is already running.

    The worker is launched as ``training_worker.py`` next to this file, with
    ``APOLLO_ZMQ_ADDR`` set to ``DEFAULT_ZMQ_ADDR`` in its environment.
    """
    global _worker_process

    if _worker_process is not None and _worker_process.poll() is None:
        return  # Still running

    worker_script = Path(__file__).parent / 'training_worker.py'
    _worker_process = subprocess.Popen(
        [sys.executable, str(worker_script)],
        env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
    )
    logger.info("Started training worker subprocess (pid=%d)", _worker_process.pid)
    # No startup sleep: ZMQ connect() is asynchronous and, by default
    # (ZMQ_IMMEDIATE=0), outgoing messages are queued until the worker binds,
    # so the first request simply waits (bounded by the receive timeout).
    # A blocking time.sleep() here would stall the server's event loop.


def _ensure_initialized() -> None:
    """Ensure the worker subprocess is running and a REQ socket is connected.

    Idempotent: does nothing once initialized. Exceptions from starting the
    subprocess or creating the socket propagate to the caller.
    """
    global _zmq_context, _zmq_socket, _initialized

    if _initialized:
        return

    _start_worker_subprocess()

    # The context is reused across socket resets; only the socket is replaced.
    if _zmq_context is None:
        _zmq_context = zmq.asyncio.Context()
    sock = _zmq_context.socket(zmq.REQ)
    sock.setsockopt(zmq.RCVTIMEO, _RECV_TIMEOUT_MS)
    sock.connect(DEFAULT_ZMQ_ADDR)

    _zmq_socket = sock
    _initialized = True
    logger.info("Connected to training worker at %s", DEFAULT_ZMQ_ADDR)


def _reset_socket(sock: zmq.asyncio.Socket | None = None) -> None:
    """Close *sock* (default: the current socket) without lingering.

    If *sock* is the current module socket, also clear it and mark the client
    uninitialized so the next request reconnects.
    """
    global _zmq_socket, _initialized

    target = _zmq_socket if sock is None else sock
    if target is not None and not target.closed:
        target.close(linger=0)
    if target is _zmq_socket:
        _zmq_socket = None
        _initialized = False


def _worker_exited() -> bool:
    """Return True if a worker was started and has since terminated."""
    return _worker_process is not None and _worker_process.poll() is not None


async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
    """Send *request* to the worker and return its decoded JSON reply.

    Round trips are serialized by ``_request_lock``. A worker that has exited
    is restarted before sending.

    Raises:
        zmq.Again: no reply within ``_RECV_TIMEOUT_MS``.
        zmq.ZMQError: other socket failures.
    """
    async with _request_lock:
        if _initialized and _worker_exited():
            logger.warning(
                "Training worker exited (code=%s); restarting",
                _worker_process.returncode,
            )
            _reset_socket()
        _ensure_initialized()

        sock = _zmq_socket
        try:
            await sock.send_json(request)
            return await sock.recv_json()
        except BaseException:
            # A timeout, socket error or cancellation leaves the REQ socket
            # waiting for a reply nobody will read; any further send would fail
            # with EFSM. Drop it so the next request reconnects, then re-raise.
            _reset_socket(sock)
            raise


@router.post("/train")
async def handle_train(request: TrainRequest):
    """Handle training request - forwards to training subprocess.

    Responses: 400 when no samples are provided, 503 when the worker cannot be
    started, 504 when the worker does not reply in time, 500 for worker-reported
    or unexpected errors, otherwise the job summary.
    """
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})

        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )

        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info("Starting training job %s with %d samples", job_id, len(samples))

        # Forward to worker
        response = await _send_request({
            'type': 'train',
            'samples': samples,
            'config': config,
        })

        if 'error' in response:
            return JSONResponse(
                content={"error": response['error']},
                status_code=500,
            )

        loss_history = response['loss_history']
        # Don't let the log line turn a completed job with an empty history
        # into a 500 via IndexError.
        final_loss = f"{loss_history[-1]:.4f}" if loss_history else "n/a"
        logger.info("Training job %s completed, final loss: %s", job_id, final_loss)

        return JSONResponse(content={
            "job_id": job_id,
            "status": response['status'],
            "training_samples": response['training_samples'],
            "loss_history": loss_history,
        })

    except zmq.Again:
        logger.error("Training request timed out")
        return JSONResponse(
            content={"error": "Training request timed out"},
            status_code=504,
        )
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )


@router.post("/checkpoint")
async def handle_checkpoint():
    """Trigger checkpoint sync to disk.

    Returns the worker's reply as-is on success; 503 / 504 / 500 on
    unavailability, timeout, or error respectively.
    """
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    try:
        response = await _send_request({'type': 'checkpoint'})

        if 'error' in response:
            return JSONResponse(
                content={"error": response['error']},
                status_code=500,
            )

        return JSONResponse(content=response)

    except zmq.Again:
        logger.error("Checkpoint request timed out")
        return JSONResponse(
            content={"error": "Checkpoint request timed out"},
            status_code=504,
        )
    except Exception as e:
        logger.exception(f"Checkpoint failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )


@router.get("/train/status")
async def handle_status():
    """Get training worker status (worker reply forwarded verbatim)."""
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={
                "status": "unavailable",
                "error": str(e),
            },
            status_code=503,
        )

    try:
        response = await _send_request({'type': 'status'})
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(
            content={
                "status": "error",
                "error": str(e),
            },
            status_code=500,
        )


def attach_router(app: FastAPI):
    """Attach training router to FastAPI app."""
    app.include_router(router)
    logger.info("Training router attached")


def _patch_api_server():
    """Patch vLLM's build_app to include our training router."""
    from vllm.entrypoints.openai import api_server

    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        app = original_build_app(*args, **kwargs)
        attach_router(app)
        return app

    api_server.build_app = patched_build_app
    logger.info("API server patched for /train endpoint")