# consciousness/training/apollo_plugin/train_router.py
"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.
Patches vLLM's build_app() to add /train route. The actual training runs
in a dedicated subprocess (training_worker.py) to avoid blocking the
event loop and to keep training work isolated from vLLM internals.
"""
import asyncio
import logging
import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any

import zmq
import zmq.asyncio
from fastapi import APIRouter, FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Module-level logger and the router that attach_router() registers on the app.
logger = logging.getLogger(__name__)
router = APIRouter()

# IPC endpoint the training worker binds; passed to the worker subprocess
# via the APOLLO_ZMQ_ADDR environment variable.
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"

# Global state for subprocess management — created lazily on first request
# by _ensure_initialized(), never torn down here.
_worker_process: subprocess.Popen | None = None
_zmq_context: zmq.asyncio.Context | None = None
_zmq_socket: zmq.asyncio.Socket | None = None
_initialized: bool = False
class TrainRequest(BaseModel):
    """Request body for POST /train."""

    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}
class TrainResponse(BaseModel):
    """Shape of a successful POST /train response.

    NOTE(review): the handlers build plain JSONResponse dicts rather than
    instantiating this model — it documents the contract only.
    """

    job_id: str  # e.g. "job_20240101_120000" (timestamp-derived)
    status: str  # status string reported by the worker
    training_samples: int  # number of samples the worker trained on
    loss_history: list[float]  # per-step losses from the worker
def _start_worker_subprocess():
    """Start the training worker subprocess if it is not already running.

    Launches training_worker.py (next to this file) with the ZMQ address
    handed over via the APOLLO_ZMQ_ADDR environment variable, then waits
    briefly for it to bind its socket.

    Raises:
        RuntimeError: if the worker process exits immediately after launch
            (e.g. missing script or import error) — failing fast here beats
            letting the first request hang until the ZMQ recv timeout.
    """
    global _worker_process
    if _worker_process is not None and _worker_process.poll() is None:
        return  # Still running
    # Start worker as subprocess using script path
    worker_script = Path(__file__).parent / 'training_worker.py'
    _worker_process = subprocess.Popen(
        [sys.executable, str(worker_script)],
        env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
    )
    logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")
    # Give it a moment to bind the socket
    time.sleep(0.5)
    # If the worker already died, surface that now instead of connecting a
    # REQ socket to nothing.
    returncode = _worker_process.poll()
    if returncode is not None:
        _worker_process = None
        raise RuntimeError(
            f"Training worker exited immediately with code {returncode}"
        )
def _ensure_initialized():
    """Ensure the worker subprocess is running and the ZMQ socket is connected.

    Idempotent: after the first successful call subsequent calls return
    immediately. Not guarded against concurrent first-time initialization;
    presumably the single event loop serializes the handlers that call it —
    TODO confirm.
    """
    global _zmq_context, _zmq_socket, _initialized
    if _initialized:
        return
    # Start worker if needed
    _start_worker_subprocess()
    # Create async ZMQ context and socket
    _zmq_context = zmq.asyncio.Context()
    _zmq_socket = _zmq_context.socket(zmq.REQ)
    # A plain REQ socket becomes unusable after a recv timeout (it must see
    # a reply before it may send again). REQ_RELAXED lets us resend after a
    # timeout and REQ_CORRELATE drops stale replies from the old request.
    _zmq_socket.setsockopt(zmq.REQ_RELAXED, 1)
    _zmq_socket.setsockopt(zmq.REQ_CORRELATE, 1)
    # Don't block interpreter shutdown on unsent messages.
    _zmq_socket.setsockopt(zmq.LINGER, 0)
    _zmq_socket.connect(DEFAULT_ZMQ_ADDR)
    # Set timeout for recv
    _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000)  # 5 minute timeout for training
    _initialized = True
    logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
    """Send a request dict to the worker and await its JSON reply.

    On a recv timeout the REQ socket is left in a broken state (it cannot
    send again until it receives a reply), so the socket is discarded and
    ``_initialized`` reset, forcing a reconnect on the next request instead
    of failing every subsequent call.

    Raises:
        zmq.Again: if the worker does not reply within the recv timeout.
    """
    global _initialized
    _ensure_initialized()
    # ZMQ async send/recv — strict REQ/REP: exactly one reply per request.
    await _zmq_socket.send_json(request)
    try:
        response = await _zmq_socket.recv_json()
    except zmq.Again:
        # Tear down the stuck socket; _ensure_initialized() rebuilds it.
        _zmq_socket.close(linger=0)
        _initialized = False
        raise
    return response
@router.post("/train")
async def handle_train(request: TrainRequest):
    """Handle training request - forwards to training subprocess.

    Expects ``request.training_data`` shaped as
    ``{"samples": [...], "config": {...}}``.

    Returns:
        200 with job_id/status/training_samples/loss_history on success;
        400 if no samples were provided; 503 if the worker cannot be
        started; 504 on ZMQ recv timeout; 500 on any other failure or a
        worker-reported error.
    """
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )
    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})
        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )
        # Job id is timestamp-derived; second resolution — concurrent jobs
        # within the same second would collide, but requests are serialized
        # over a single REQ socket.
        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"Starting training job {job_id} with {len(samples)} samples")
        # Forward to worker
        response = await _send_request({
            'type': 'train',
            'samples': samples,
            'config': config,
        })
        if 'error' in response:
            return JSONResponse(
                content={"error": response['error']},
                status_code=500,
            )
        # Guard the log line: an empty loss history would raise IndexError
        # and turn a successful job into a 500.
        loss_history = response['loss_history']
        final_loss = f"{loss_history[-1]:.4f}" if loss_history else "n/a"
        logger.info(
            f"Training job {job_id} completed, final loss: {final_loss}"
        )
        return JSONResponse(content={
            "job_id": job_id,
            "status": response['status'],
            "training_samples": response['training_samples'],
            "loss_history": loss_history,
        })
    except zmq.Again:
        logger.error("Training request timed out")
        return JSONResponse(
            content={"error": "Training request timed out"},
            status_code=504,
        )
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )
@router.post("/checkpoint")
async def handle_checkpoint():
    """Ask the training worker to sync its checkpoint to disk."""
    try:
        _ensure_initialized()
    except Exception as init_err:
        return JSONResponse(
            content={"error": f"Training not available: {init_err}"},
            status_code=503,
        )
    try:
        reply = await _send_request({'type': 'checkpoint'})
        if 'error' in reply:
            return JSONResponse(
                content={"error": reply['error']}, status_code=500,
            )
        return JSONResponse(content=reply)
    except Exception as err:
        logger.exception(f"Checkpoint failed: {err}")
        return JSONResponse(content={"error": str(err)}, status_code=500)
@router.get("/train/status")
async def handle_status():
    """Report the training worker's current status."""
    try:
        _ensure_initialized()
    except Exception as init_err:
        return JSONResponse(
            content={"status": "unavailable", "error": str(init_err)},
            status_code=503,
        )
    try:
        worker_reply = await _send_request({'type': 'status'})
        return JSONResponse(content=worker_reply)
    except Exception as err:
        return JSONResponse(
            content={"status": "error", "error": str(err)},
            status_code=500,
        )
def attach_router(app: FastAPI):
    """Attach training router to FastAPI app.

    Registers /train, /checkpoint and /train/status on *app*. Called by
    _patch_api_server()'s wrapped build_app.
    """
    app.include_router(router)
    logger.info("Training router attached")
def _patch_api_server():
    """Monkey-patch vLLM's build_app so the built app carries our routes."""
    from vllm.entrypoints.openai import api_server

    unpatched_build_app = api_server.build_app

    def build_app_with_training(*args, **kwargs):
        # Build the stock vLLM app, then bolt the training routes onto it.
        application = unpatched_build_app(*args, **kwargs)
        attach_router(application)
        return application

    api_server.build_app = build_app_with_training
    logger.info("API server patched for /train endpoint")