"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.

Patches vLLM's build_app() to add /train route. The actual training runs
in a dedicated subprocess (training_worker.py) to avoid blocking the
event loop and to keep training work isolated from vLLM internals.
"""
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
import asyncio
|
2026-04-16 00:48:05 -04:00
|
|
|
import logging
|
2026-04-16 02:01:59 -04:00
|
|
|
import os
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
2026-04-16 00:48:05 -04:00
|
|
|
from datetime import datetime
|
2026-04-16 02:01:59 -04:00
|
|
|
from pathlib import Path
|
2026-04-16 00:48:05 -04:00
|
|
|
from typing import Any
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
import zmq
|
|
|
|
|
import zmq.asyncio
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, FastAPI
|
2026-04-16 00:48:05 -04:00
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module for log routing.
logger = logging.getLogger(__name__)

# Router holding the /train, /checkpoint and /train/status endpoints;
# attached to the vLLM FastAPI app via attach_router().
router = APIRouter()

# IPC endpoint the training worker binds and this module connects to.
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"

# Global state for subprocess management
_worker_process: subprocess.Popen | None = None  # handle to training_worker.py
_zmq_context: zmq.asyncio.Context | None = None  # async ZMQ context (lazy)
_zmq_socket: zmq.asyncio.Socket | None = None    # REQ socket to the worker
_initialized: bool = False                       # True once socket is connected
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainRequest(BaseModel):
    """Request body for POST /train."""

    # Expected shape: {"samples": [...], "config": {...}}
    training_data: dict[str, Any]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainResponse(BaseModel):
    """Response body for POST /train."""

    job_id: str                 # server-generated job identifier
    status: str                 # status string reported by the worker
    training_samples: int       # number of samples used for the job
    loss_history: list[float]   # per-step loss values from the worker
|
|
|
|
|
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
def _start_worker_subprocess():
    """Start (or restart) the training worker subprocess.

    No-op when a previously started worker is still alive. Raises
    RuntimeError if the worker exits immediately after launch, so callers
    get a clear startup error instead of a ZMQ timeout much later.
    """
    global _worker_process

    # poll() is None while the process is still alive.
    if _worker_process is not None and _worker_process.poll() is None:
        return  # Still running

    # Start worker as subprocess using script path (list argv, no shell).
    worker_script = Path(__file__).parent / 'training_worker.py'
    _worker_process = subprocess.Popen(
        [sys.executable, str(worker_script)],
        env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
    )
    logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")

    # Give it a moment to bind the socket.
    # NOTE(review): fixed sleep is a heuristic; a readiness handshake over
    # ZMQ would be more robust — confirm worker startup time stays < 0.5s.
    import time
    time.sleep(0.5)

    # Fail fast if the worker died during startup (e.g. import error),
    # instead of silently leaving a dead peer behind the socket.
    if _worker_process.poll() is not None:
        rc = _worker_process.returncode
        _worker_process = None
        raise RuntimeError(
            f"Training worker exited immediately after start (returncode={rc})"
        )
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_initialized():
    """Ensure subprocess is running and ZMQ socket is connected.

    Idempotent: after the first successful call this is a cheap flag
    check. Raises if the worker subprocess cannot be started.
    """
    global _zmq_context, _zmq_socket, _initialized

    if _initialized:
        return

    # Start worker if needed
    _start_worker_subprocess()

    # Create async ZMQ context and socket
    _zmq_context = zmq.asyncio.Context()
    _zmq_socket = _zmq_context.socket(zmq.REQ)

    # LINGER=0: don't block process teardown on undelivered messages.
    _zmq_socket.setsockopt(zmq.LINGER, 0)
    # Receive timeout: 5 minutes, since a training job is a single
    # round-trip and may legitimately take a while.
    _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000)
    # Send timeout: without this a send to a dead/wedged worker blocks
    # forever inside the event loop; 10s surfaces it as zmq.Again instead.
    _zmq_socket.setsockopt(zmq.SNDTIMEO, 10000)

    _zmq_socket.connect(DEFAULT_ZMQ_ADDR)

    _initialized = True
    logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
def _reset_connection():
    """Tear down the ZMQ socket/context so the next request reconnects.

    A REQ socket that timed out waiting for a reply is stuck in its
    strict send/recv state machine and cannot be reused; it must be
    closed and recreated.
    """
    global _zmq_context, _zmq_socket, _initialized
    if _zmq_socket is not None:
        _zmq_socket.close(linger=0)
        _zmq_socket = None
    if _zmq_context is not None:
        _zmq_context.term()
        _zmq_context = None
    _initialized = False


async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
    """Send request to worker and wait for response.

    Raises zmq.Again on timeout; the connection is reset first so the
    next call reconnects with a fresh REQ socket instead of failing on
    the stuck one.
    """
    _ensure_initialized()

    try:
        # ZMQ async send/recv (strict REQ/REP lockstep)
        await _zmq_socket.send_json(request)
        response = await _zmq_socket.recv_json()
        return response
    except zmq.Again:
        # Without this reset, every later request would fail on the
        # broken REQ state machine.
        _reset_connection()
        raise
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
@router.post("/train")
async def handle_train(request: TrainRequest):
    """Handle training request - forwards to training subprocess.

    Returns 503 if the worker can't be started, 400 for an empty sample
    list, 504 on worker timeout, 500 on worker-side errors, otherwise a
    JSON body with job_id / status / training_samples / loss_history.
    """
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})

        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )

        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"Starting training job {job_id} with {len(samples)} samples")

        # Forward to worker
        response = await _send_request({
            'type': 'train',
            'samples': samples,
            'config': config,
        })

        if 'error' in response:
            return JSONResponse(
                content={"error": response['error']},
                status_code=500,
            )

        # Guard: a missing or empty loss history must not turn a
        # successful job into a 500 via KeyError/IndexError in logging.
        loss_history = response.get('loss_history', [])
        if loss_history:
            logger.info(
                f"Training job {job_id} completed, "
                f"final loss: {loss_history[-1]:.4f}"
            )
        else:
            logger.info(f"Training job {job_id} completed (no loss history)")

        # Use .get with defaults so a malformed worker response degrades
        # gracefully instead of raising KeyError (-> generic 500).
        return JSONResponse(content={
            "job_id": job_id,
            "status": response.get('status', 'unknown'),
            "training_samples": response.get('training_samples', len(samples)),
            "loss_history": loss_history,
        })

    except zmq.Again:
        logger.error("Training request timed out")
        return JSONResponse(
            content={"error": "Training request timed out"},
            status_code=504,
        )
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )
|
|
|
|
|
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
@router.post("/checkpoint")
async def handle_checkpoint():
    """Trigger checkpoint sync to disk."""
    # Guard clause: without a running worker there is nothing to checkpoint.
    try:
        _ensure_initialized()
    except Exception as exc:
        return JSONResponse(
            content={"error": f"Training not available: {exc}"},
            status_code=503,
        )

    try:
        reply = await _send_request({'type': 'checkpoint'})

        # Worker-reported failure -> 500 with the worker's message.
        if 'error' in reply:
            return JSONResponse(
                content={"error": reply['error']},
                status_code=500,
            )

        return JSONResponse(content=reply)

    except Exception as exc:
        logger.exception(f"Checkpoint failed: {exc}")
        return JSONResponse(
            content={"error": str(exc)},
            status_code=500,
        )
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
2026-04-16 02:01:59 -04:00
|
|
|
@router.get("/train/status")
async def handle_status():
    """Get training worker status."""
    # Guard clause: report "unavailable" when the worker can't be reached.
    try:
        _ensure_initialized()
    except Exception as exc:
        return JSONResponse(
            content={"status": "unavailable", "error": str(exc)},
            status_code=503,
        )

    try:
        reply = await _send_request({'type': 'status'})
    except Exception as exc:
        return JSONResponse(
            content={"status": "error", "error": str(exc)},
            status_code=500,
        )

    return JSONResponse(content=reply)
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def attach_router(app: FastAPI):
    """Register the training routes (/train, /checkpoint, /train/status) on *app*."""
    app.include_router(router)
    logger.info("Training router attached")
|
2026-04-16 00:48:05 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _patch_api_server():
    """Patch vLLM's build_app to include our training router.

    Idempotent: calling this more than once leaves the patch in place
    instead of wrapping build_app (and attaching the router) twice.
    """
    from vllm.entrypoints.openai import api_server

    # Already patched? Re-wrapping would attach the router twice per app.
    if getattr(api_server.build_app, '_apollo_train_patched', False):
        return

    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        # Delegate to vLLM, then bolt our routes onto the resulting app.
        app = original_build_app(*args, **kwargs)
        attach_router(app)
        return app

    # Sentinel used by the idempotence check above.
    patched_build_app._apollo_train_patched = True

    api_server.build_app = patched_build_app
    logger.info("API server patched for /train endpoint")
|