# Remove standalone worker.py daemon. Training now runs inside vLLM:
# - train_router.py: FastAPI router patched into vLLM's build_app()
# - /train served on same port as /completions, /score
# - Lazy-loads HF model with vLLM weight views on first request
# - HOGWILD training: no pause, weights updated in-place
#
# The previous architecture had a separate daemon on port 8080 that
# communicated with vLLM via pause/resume endpoints. This was wrong -
# training should run in-process, sharing GPU memory directly.
#
# Co-Authored-By: Proof of Concept <poc@bcachefs.org>
"""Training endpoint for vLLM - runs Apollo training in-process.

Patches vLLM's build_app() to add /train route. Training runs HOGWILD
style - no pause needed, weights updated in-place while inference continues.
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from fastapi import APIRouter, FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
class TrainingSample(BaseModel):
    """One training example: a frozen context plus the continuation to train on.

    NOTE(review): handle_train currently consumes samples as raw dicts from
    TrainRequest.training_data rather than parsing them into this model —
    confirm whether this schema should be wired into the request body.
    """

    # Token IDs of the prompt/context; processed without gradients.
    context_ids: list[int]
    # Token IDs of the continuation whose likelihood is optimized.
    continuation_ids: list[int]
|
|
|
|
|
|
class TrainRequest(BaseModel):
    """Request body for POST /train."""

    # Free-form payload: {"samples": [...], "config": {...}} — samples is a
    # list of dicts with context_ids/continuation_ids, config holds optimizer
    # hyperparameters (see run_training).
    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}
|
|
|
|
|
|
class TrainResponse(BaseModel):
    """Shape of a successful /train response.

    NOTE(review): handle_train builds its JSONResponse dict by hand and does
    not declare this as a response_model — the fields below mirror that dict;
    keep them in sync (or wire this model into the route).
    """

    # Identifier of the form "job_YYYYMMDD_HHMMSS".
    job_id: str
    # "completed" on success.
    status: str
    # Number of samples trained on.
    training_samples: int
    # Per-sample loss values, in processing order.
    loss_history: list[float]
|
|
|
|
|
|
# Global reference to HF model with vLLM weight views.
# _model is built lazily by _ensure_initialized() on the first /train request;
# its parameters alias vLLM's GPU memory (see _load_training_model).
_model: nn.Module | None = None
# Filesystem path of the base HF model; set by export_hook via set_model_path().
_model_path: str | None = None
# True once _model has been constructed; guards against re-initialization.
_initialized: bool = False
|
|
|
|
|
|
def _load_training_model() -> nn.Module:
    """Build an HF model whose parameters alias vLLM's GPU memory.

    Rebuilds CUDA IPC tensors from the handle file that export_hook wrote,
    then wraps them in an HF model so autograd updates hit the very same
    weights vLLM is serving from.
    """
    from .weight_mapping import load_hf_model_with_vllm_weights
    from .export_hook import HANDLE_PATH

    # Each entry's 'handle' is a (rebuild_fn, args) pair produced by CUDA IPC
    # export; calling it reconstructs a tensor view over vLLM's memory.
    handle_map = torch.load(HANDLE_PATH, weights_only=False)
    shared_weights = {
        name: entry['handle'][0](*entry['handle'][1])
        for name, entry in handle_map.items()
    }

    hf_model = load_hf_model_with_vllm_weights(shared_weights, _model_path)
    hf_model.train()
    return hf_model
|
|
|
|
|
|
def _ensure_initialized():
    """Lazily build the training model on the first /train request (idempotent).

    Raises:
        RuntimeError: if export_hook never recorded the model path.
    """
    global _model, _initialized

    if not _initialized:
        if _model_path is None:
            raise RuntimeError("Model path not set - export_hook may not have run")

        logger.info("[apollo] Loading HF model with vLLM weight views...")
        _model = _load_training_model()
        _initialized = True
        logger.info("[apollo] Training model ready")
|
|
|
|
|
|
def set_model_path(path: str):
    """Remember where the base HF model lives.

    export_hook calls this once vLLM finishes loading, so that the lazy
    model build in _ensure_initialized() knows which model to construct.
    """
    global _model_path
    _model_path = path
    logger.info(f"[apollo] Model path set: {path}")
|
|
|
|
|
|
@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
    """Handle training request - runs Apollo training on provided samples.

    Lazily initializes the shared-weight training model on first call, then
    runs one HOGWILD pass over the submitted samples while inference continues.

    Responses (JSON):
        503 - training model could not be initialized
        400 - no samples in the request
        500 - training raised
        200 - {job_id, status, training_samples, loss_history}

    NOTE(review): raw_request is unused — confirm whether it can be dropped
    from the signature.
    """
    # _model is only read here; _ensure_initialized() is what populates it.
    global _model

    try:
        _ensure_initialized()
    except Exception as e:
        # Model not loadable (e.g. export_hook never ran) — service unavailable.
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})

        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )

        # Timestamp-based job id (second resolution — concurrent jobs within
        # the same second would collide; acceptable for logging purposes).
        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")

        # Run training
        loss_history = await run_training(_model, samples, config)

        # samples is non-empty, so loss_history has at least one entry.
        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")

        # Schedule checkpoint sync (batched, 10 min delay)
        schedule_checkpoint_sync()

        return JSONResponse(content={
            "job_id": job_id,
            "status": "completed",
            "training_samples": len(samples),
            "loss_history": loss_history,
        })

    except Exception as e:
        logger.exception(f"[apollo] Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )
|
|
|
|
|
|
async def run_training(
    model: nn.Module,
    samples: list[dict[str, Any]],
    config: dict[str, Any],
) -> list[float]:
    """Run Apollo training on the given samples.

    Each sample has:
        context_ids: token IDs for frozen context (no gradients)
        continuation_ids: token IDs for the decision we're training on

    Args:
        model: HF model (weights shared with vLLM), already in train() mode.
        samples: list of sample dicts as described above.
        config: optimizer hyperparameters (lr, rank, betas, eps, weight_decay,
            warmup_steps, scale, proj_refresh, norm_growth_limit).

    Returns:
        Per-sample loss values, in processing order.

    Raises:
        ValueError: if the model has no trainable parameters.
    """
    from .optimizer import Apollo

    # Build parameter groups: Apollo low-rank updates for large 2D+ weight
    # matrices, standard updates for biases/norms/small tensors.
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if p.requires_grad:
            if p.ndim >= 2 and min(p.shape) >= 256:
                apollo_params.append(p)
            else:
                standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})

    if not groups:
        raise ValueError("No trainable parameters found")

    # Apollo settings from request config
    optimizer = Apollo(
        groups,
        lr=config.get('lr', 1e-5),
        rank=config.get('rank', 256),
        betas=tuple(config.get('betas', (0.9, 0.999))),
        eps=config.get('eps', 1e-8),
        weight_decay=config.get('weight_decay', 0.01),
        warmup_steps=config.get('warmup_steps', 0),
        # NOTE(review): yields None when 'scale' is absent — confirm Apollo
        # treats None as "use default" rather than a literal scale value.
        scale=config.get('scale'),
        proj_refresh=config.get('proj_refresh', 200),
        norm_growth_limit=config.get('norm_growth_limit', 1.01),
    )

    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                f"{len(standard_params)} standard, "
                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

    # Fix: train on whichever device the model actually lives on instead of
    # hard-coding 'cuda:0' (which broke when the model was placed elsewhere).
    device = next(model.parameters()).device

    loss_history = []

    for i, sample in enumerate(samples):
        ctx_ids = sample['context_ids']
        cont_ids = sample['continuation_ids']
        all_ids = ctx_ids + cont_ids
        context_len = len(ctx_ids)

        input_ids = torch.tensor([all_ids], device=device)

        optimizer.zero_grad()

        # Context-frozen forward pass: build the KV cache without gradients.
        with torch.no_grad():
            outputs = model(input_ids[:, :context_len], use_cache=True)
            past_kv = outputs.past_key_values

        # Decision tokens with gradients, reusing the frozen KV cache.
        with torch.enable_grad():
            outputs = model(
                input_ids[:, context_len:],
                past_key_values=past_kv,
                use_cache=False,
            )
            logits = outputs.logits

        # Shift: logit at position t predicts token t+1.
        # NOTE(review): the logit that predicts the FIRST continuation token
        # comes from the no-grad context pass, so that token is not supervised
        # here — confirm this is intended. Also, a single-token continuation
        # yields empty shift tensors; confirm callers never send those.
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = input_ids[:, context_len + 1:].contiguous()

        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        loss_history.append(loss_val)
        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

    return loss_history
|
|
|
|
|
|
# Checkpoint sync scheduling
# Pending asyncio.Task for the delayed sync, or None when nothing is scheduled.
_checkpoint_task = None
# Delay before a scheduled sync runs; all training jobs that finish within
# this window share a single sync.
CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
|
|
|
|
|
|
def schedule_checkpoint_sync():
    """Schedule a delayed checkpoint sync, batching repeated requests.

    If a sync is already pending, this is a no-op — the pending task will
    pick up all weight changes made in the meantime. Must be called from a
    running event loop (asyncio.create_task).
    """
    global _checkpoint_task
    import asyncio

    if _checkpoint_task is not None:
        # A sync is already pending; this request piggybacks on it.
        return

    async def _sync_after_delay():
        global _checkpoint_task
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            if _model_path:
                from .checkpoint_sync import checkpoint_sync
                logger.info("[apollo] Starting checkpoint sync...")
                result = checkpoint_sync(_model_path)
                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
        except Exception as e:
            logger.error(f"[apollo] Checkpoint sync failed: {e}")
        finally:
            # Clear the handle so the next training job can schedule anew.
            _checkpoint_task = None

    _checkpoint_task = asyncio.create_task(_sync_after_delay())
    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
|
|
|
|
|
|
def attach_router(app: FastAPI):
    """Mount the module-level /train router onto an existing FastAPI app."""
    app.include_router(router)
    logger.info("[apollo] Training router attached")
|
|
|
|
|
|
def _patch_api_server():
    """Monkey-patch vLLM's build_app so every app it creates carries /train."""
    from vllm.entrypoints.openai import api_server

    # Keep a reference to the real factory so the wrapper can delegate to it.
    wrapped_build_app = api_server.build_app

    def build_app_with_train(*args, **kwargs):
        new_app = wrapped_build_app(*args, **kwargs)
        attach_router(new_app)
        return new_app

    api_server.build_app = build_app_with_train
    logger.info("[apollo] API server patched for /train endpoint")
|