training: integrate /train into vLLM process (no separate daemon)

Remove standalone worker.py daemon. Training now runs inside vLLM:

- train_router.py: FastAPI router patched into vLLM's build_app()
- /train served on same port as /completions, /score
- Lazy-loads HF model with vLLM weight views on first request
- HOGWILD training: no pause, weights updated in-place

The previous architecture had a separate daemon on port 8080 that
communicated with vLLM via pause/resume endpoints. This was wrong -
training should run in-process, sharing GPU memory directly.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-16 00:48:05 -04:00
parent 2f08149fab
commit 7e7e9a4b69
6 changed files with 320 additions and 542 deletions

View file

@@ -0,0 +1,282 @@
"""Training endpoint for vLLM - runs Apollo training in-process.
Patches vLLM's build_app() to add /train route. Training runs HOGWILD
style - no pause needed, weights updated in-place while inference continues.
"""
import logging
from datetime import datetime
from typing import Any
import torch
import torch.nn as nn
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Module-level logger; all messages are tagged "[apollo]" by convention below.
logger = logging.getLogger(__name__)
# Router carrying the /train route; mounted onto vLLM's app by attach_router().
router = APIRouter()
class TrainingSample(BaseModel):
    """Schema for one training example: frozen context plus continuation.

    NOTE(review): handle_train/run_training currently consume samples as raw
    dicts carrying these same keys; this model is not used to parse them.
    """
    context_ids: list[int]  # token IDs for the frozen context (no gradients)
    continuation_ids: list[int]  # token IDs the training loss is computed over
class TrainRequest(BaseModel):
    """Body of POST /train."""
    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}
class TrainResponse(BaseModel):
    """Shape of a successful /train response.

    NOTE(review): handle_train builds its JSONResponse dict by hand using
    these same keys; in the visible code this model is declarative only.
    """
    job_id: str  # timestamp-derived identifier ("job_YYYYMMDD_HHMMSS")
    status: str  # "completed" on success
    training_samples: int  # number of samples trained on
    loss_history: list[float]  # per-step losses returned by run_training
# Global reference to HF model with vLLM weight views.
# Populated lazily by _ensure_initialized() on the first /train request.
_model: nn.Module | None = None  # HF model sharing GPU memory with vLLM
_model_path: str | None = None  # set by set_model_path() from export_hook
_initialized: bool = False  # True once _load_training_model() has succeeded
def _load_training_model() -> nn.Module:
    """Build an HF model whose parameters alias vLLM's GPU memory.

    Reconstructs CUDA tensors from the IPC handles persisted by
    export_hook, then hands the resulting tensor table to the
    weight-mapping loader so the HF model trains directly on the
    weights vLLM is serving from.

    Returns:
        The HF model, already switched into train() mode.
    """
    from .weight_mapping import load_hf_model_with_vllm_weights
    from .export_hook import HANDLE_PATH

    handle_table = torch.load(HANDLE_PATH, weights_only=False)
    # Each entry stores a (rebuild_fn, rebuild_args) pair; invoking it
    # recreates a tensor view over vLLM's GPU allocation.
    shared_weights = {
        name: entry['handle'][0](*entry['handle'][1])
        for name, entry in handle_table.items()
    }
    model = load_hf_model_with_vllm_weights(shared_weights, _model_path)
    model.train()
    return model
def _ensure_initialized():
    """Idempotently set up the training model on the first /train request.

    Raises:
        RuntimeError: if export_hook never recorded a model path.
    """
    global _model, _initialized
    if not _initialized:
        if _model_path is None:
            raise RuntimeError("Model path not set - export_hook may not have run")
        logger.info("[apollo] Loading HF model with vLLM weight views...")
        _model = _load_training_model()
        _initialized = True
        logger.info("[apollo] Training model ready")
def set_model_path(path: str):
    """Record the HF model path so _load_training_model can use it later.

    Invoked by export_hook once vLLM has loaded its model.
    """
    global _model_path
    # Only stored here; the model itself is loaded lazily on first /train.
    _model_path = path
    logger.info(f"[apollo] Model path set: {path}")
@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
    """Handle training request - runs Apollo training on provided samples.

    Lazily initializes the in-process training model, then trains on
    request.training_data["samples"] with request.training_data["config"].
    Samples are consumed as raw dicts (see run_training for keys).

    Returns:
        200 with {job_id, status, training_samples, loss_history} on success;
        400 when no samples are supplied; 503 when initialization fails;
        500 on any training error.

    NOTE(review): run_training contains no awaits of I/O, so awaiting it
    here blocks the event loop for the whole training run (intentional
    HOGWILD design per the module docstring, but worth confirming).
    """
    global _model
    try:
        _ensure_initialized()
    except Exception as e:
        # Initialization failure is reported as service-unavailable, not 500.
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )
    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})
        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )
        # Job IDs are second-resolution timestamps; concurrent jobs within
        # the same second would share an ID (IDs are only used in logs here).
        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
        # Run training
        loss_history = await run_training(_model, samples, config)
        # samples is non-empty here, so loss_history has at least one entry.
        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
        # Schedule checkpoint sync (batched, 10 min delay)
        schedule_checkpoint_sync()
        # Payload mirrors the TrainResponse model's fields.
        return JSONResponse(content={
            "job_id": job_id,
            "status": "completed",
            "training_samples": len(samples),
            "loss_history": loss_history,
        })
    except Exception as e:
        logger.exception(f"[apollo] Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )
async def run_training(
    model: nn.Module,
    samples: list[dict[str, Any]],
    config: dict[str, Any],
) -> list[float]:
    """Run Apollo training on the given samples.

    Each sample has:
        context_ids: token IDs for frozen context (no gradients)
        continuation_ids: token IDs for the decision we're training on

    Returns:
        One loss value per sample, in order.

    Raises:
        ValueError: if the model exposes no trainable parameters.

    NOTE(review): despite being async, this function never awaits; the
    whole loop runs synchronously on the event loop.
    """
    from .optimizer import Apollo
    # Build parameter groups (Apollo for 2D+, standard for small/1D)
    # The >= 256 cutoff presumably matches the default projection rank
    # below so small matrices skip low-rank projection — TODO confirm.
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if p.requires_grad:
            if p.ndim >= 2 and min(p.shape) >= 256:
                apollo_params.append(p)
            else:
                standard_params.append(p)
    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})
    if not groups:
        raise ValueError("No trainable parameters found")
    # Apollo settings from request config
    # NOTE(review): config.get('scale') has no fallback, so a missing key
    # passes scale=None to Apollo — verify the optimizer accepts that.
    optimizer = Apollo(
        groups,
        lr=config.get('lr', 1e-5),
        rank=config.get('rank', 256),
        betas=tuple(config.get('betas', (0.9, 0.999))),
        eps=config.get('eps', 1e-8),
        weight_decay=config.get('weight_decay', 0.01),
        warmup_steps=config.get('warmup_steps', 0),
        scale=config.get('scale'),
        proj_refresh=config.get('proj_refresh', 200),
        norm_growth_limit=config.get('norm_growth_limit', 1.01),
    )
    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                f"{len(standard_params)} standard, "
                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
    loss_history = []
    # One optimizer step per sample (batch size 1, no gradient accumulation).
    for i, sample in enumerate(samples):
        ctx_ids = sample['context_ids']
        cont_ids = sample['continuation_ids']
        all_ids = ctx_ids + cont_ids
        context_len = len(ctx_ids)
        # Device is hard-coded; single-GPU deployment assumed — TODO confirm.
        input_ids = torch.tensor([all_ids], device='cuda:0')
        optimizer.zero_grad()
        # Context-frozen forward pass: KV cache built without gradients.
        with torch.no_grad():
            outputs = model(input_ids[:, :context_len], use_cache=True)
            past_kv = outputs.past_key_values
        # Decision tokens with gradients
        with torch.enable_grad():
            outputs = model(
                input_ids[:, context_len:],
                past_key_values=past_kv,
                use_cache=False,
            )
            logits = outputs.logits
            # Shift: predict next token from each position.
            # NOTE(review): the first continuation token (index context_len)
            # never appears as a label — the logit predicting it was computed
            # in the no_grad context pass, so it receives no loss. Confirm
            # this off-by-one is intended.
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = input_ids[:, context_len + 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        loss.backward()
        # HOGWILD: weights are updated in-place while inference may be running.
        optimizer.step()
        loss_val = loss.item()
        loss_history.append(loss_val)
        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
    return loss_history
# Checkpoint sync scheduling
_checkpoint_task = None  # pending asyncio.Task for the deferred sync, or None
CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
def schedule_checkpoint_sync():
    """Arrange a checkpoint sync CHECKPOINT_DELAY_SECS from now.

    Calls made while a sync is already pending are no-ops, so a burst of
    training jobs collapses into one deferred sync. Must be called from a
    running event loop (e.g. inside a request handler).
    """
    global _checkpoint_task
    import asyncio

    if _checkpoint_task is not None:
        return  # a sync is already pending; this job batches into it

    async def _delayed_sync():
        global _checkpoint_task
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            if _model_path:
                from .checkpoint_sync import checkpoint_sync
                logger.info("[apollo] Starting checkpoint sync...")
                result = checkpoint_sync(_model_path)
                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
        except Exception as e:
            logger.error(f"[apollo] Checkpoint sync failed: {e}")
        finally:
            # Clear the slot so the next training job can schedule again.
            _checkpoint_task = None

    # Holding the Task in a global also keeps it from being garbage-collected.
    _checkpoint_task = asyncio.create_task(_delayed_sync())
    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
def attach_router(app: FastAPI):
    """Mount the /train router onto an existing FastAPI application."""
    # The module-level router carries the POST /train route.
    app.include_router(router)
    logger.info("[apollo] Training router attached")
def _patch_api_server():
    """Patch vLLM's build_app so every app it builds gets the /train route.

    Idempotent: a marker attribute on the patched function prevents
    stacking wrappers if this is called more than once (the original code
    would re-wrap on every call, attaching the router to each built app
    multiple times and registering duplicate /train routes).
    """
    from vllm.entrypoints.openai import api_server

    if getattr(api_server.build_app, '_apollo_patched', False):
        return  # already patched; avoid double-wrapping

    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        # Build the stock vLLM app, then bolt on the training router.
        app = original_build_app(*args, **kwargs)
        attach_router(app)
        return app

    patched_build_app._apollo_patched = True
    api_server.build_app = patched_build_app
    logger.info("[apollo] API server patched for /train endpoint")