training: integrate /train into vLLM process (no separate daemon)
Remove standalone worker.py daemon. Training now runs inside vLLM:

- train_router.py: FastAPI router patched into vLLM's build_app()
- /train served on same port as /completions, /score
- Lazy-loads HF model with vLLM weight views on first request
- HOGWILD training: no pause, weights updated in-place

The previous architecture had a separate daemon on port 8080 that
communicated with vLLM via pause/resume endpoints. This was wrong:
training should run in-process, sharing GPU memory directly.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
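A minimal client sketch of what a /train call looks like (not part of the commit; it assumes vLLM's default port 8000, and the payload shape mirrors TrainRequest and TrainingSample in train_router.py below):

# Hypothetical client call; token IDs are illustrative.
import requests

resp = requests.post(
    "http://localhost:8000/train",  # same server that serves /completions
    json={
        "training_data": {
            "samples": [
                {
                    "context_ids": [101, 2023, 2003],        # frozen context
                    "continuation_ids": [1996, 3247, 1012],  # decision tokens
                }
            ],
            "config": {"lr": 1e-5, "rank": 256},
        }
    },
)
resp.raise_for_status()
print(resp.json()["loss_history"])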
parent 2f08149fab
commit 7e7e9a4b69

6 changed files with 320 additions and 542 deletions

training/apollo_plugin/train_router.py (new file, 282 additions)
@@ -0,0 +1,282 @@
"""Training endpoint for vLLM - runs Apollo training in-process.
|
||||
|
||||
Patches vLLM's build_app() to add /train route. Training runs HOGWILD
|
||||
style - no pause needed, weights updated in-place while inference continues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class TrainingSample(BaseModel):
|
||||
context_ids: list[int]
|
||||
continuation_ids: list[int]
|
||||
|
||||
|
||||
class TrainRequest(BaseModel):
|
||||
training_data: dict[str, Any] # {"samples": [...], "config": {...}}
|
||||
|
||||
|
||||
class TrainResponse(BaseModel):
|
||||
job_id: str
|
||||
status: str
|
||||
training_samples: int
|
||||
loss_history: list[float]
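# Note: TrainingSample and TrainResponse document the wire format only;
# handle_train below validates loosely and builds its JSON reply by hand.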


# Global reference to HF model with vLLM weight views
_model: nn.Module | None = None
_model_path: str | None = None
_initialized: bool = False


def _load_training_model() -> nn.Module:
    """Load HF model with weights pointing to vLLM's GPU memory.

    Uses CUDA IPC handles exported by export_hook to create an HF model
    whose parameters share GPU memory with vLLM's model.
    """
    from .weight_mapping import load_hf_model_with_vllm_weights
    from .export_hook import HANDLE_PATH

    handles = torch.load(HANDLE_PATH, weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        func, args = info['handle']
        vllm_params[name] = func(*args)

    model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
    model.train()
    return model
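
# Producer-side sketch (assumed; the real export lives in export_hook,
# which this commit does not show): torch.multiprocessing.reductions.
# reduce_tensor() returns the (rebuild_cuda_tensor, args) pair, with the
# CUDA IPC handle inside args, that info['handle'] unpacks above:
#
#   from torch.multiprocessing.reductions import reduce_tensor
#   handles = {name: {'handle': reduce_tensor(p.data)}
#              for name, p in vllm_model.named_parameters()}
#   torch.save(handles, HANDLE_PATH)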


def _ensure_initialized():
    """Lazy-initialize the training model on first /train request."""
    global _model, _initialized

    if _initialized:
        return

    if _model_path is None:
        raise RuntimeError("Model path not set - export_hook may not have run")

    logger.info("[apollo] Loading HF model with vLLM weight views...")
    _model = _load_training_model()
    _initialized = True
    logger.info("[apollo] Training model ready")


def set_model_path(path: str):
    """Set model path for training. Called by export_hook after model load."""
    global _model_path
    _model_path = path
    logger.info(f"[apollo] Model path set: {path}")


@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
    """Handle training request - runs Apollo training on provided samples."""
    global _model

    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    try:
        training_data = request.training_data
        samples = training_data.get("samples", [])
        config = training_data.get("config", {})

        if not samples:
            return JSONResponse(
                content={"error": "No training samples provided"},
                status_code=400,
            )

        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")

        # Run training
        loss_history = await run_training(_model, samples, config)

        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")

        # Schedule checkpoint sync (batched, 10 min delay)
        schedule_checkpoint_sync()

        return JSONResponse(content={
            "job_id": job_id,
            "status": "completed",
            "training_samples": len(samples),
            "loss_history": loss_history,
        })

    except Exception as e:
        logger.exception(f"[apollo] Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )


async def run_training(
    model: nn.Module,
    samples: list[dict[str, Any]],
    config: dict[str, Any],
) -> list[float]:
    """Run Apollo training on the given samples.

    Each sample has:
        context_ids: token IDs for frozen context (no gradients)
        continuation_ids: token IDs for the decision we're training on
    """
    from .optimizer import Apollo

    # Build parameter groups (Apollo for 2D+, standard for small/1D)
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if p.requires_grad:
            if p.ndim >= 2 and min(p.shape) >= 256:
                apollo_params.append(p)
            else:
                standard_params.append(p)
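    # Illustrative split, assuming a typical 4096-wide transformer:
    # q_proj.weight [4096, 4096] goes to apollo_params (low-rank projected
    # gradient state); input_layernorm.weight [4096] and biases go to
    # standard_params. The 256 floor matches the default rank=256 below,
    # since a rank-r projection needs both matrix dims >= r to pay off.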

    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})

    if not groups:
        raise ValueError("No trainable parameters found")

    # Apollo settings from request config
    optimizer = Apollo(
        groups,
        lr=config.get('lr', 1e-5),
        rank=config.get('rank', 256),
        betas=tuple(config.get('betas', (0.9, 0.999))),
        eps=config.get('eps', 1e-8),
        weight_decay=config.get('weight_decay', 0.01),
        warmup_steps=config.get('warmup_steps', 0),
        scale=config.get('scale'),
        proj_refresh=config.get('proj_refresh', 200),
        norm_growth_limit=config.get('norm_growth_limit', 1.01),
    )

    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                f"{len(standard_params)} standard, "
                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

    loss_history = []

    for i, sample in enumerate(samples):
        ctx_ids = sample['context_ids']
        cont_ids = sample['continuation_ids']
        all_ids = ctx_ids + cont_ids
        context_len = len(ctx_ids)

        input_ids = torch.tensor([all_ids], device='cuda:0')

        optimizer.zero_grad()

        # Context-frozen forward pass
        with torch.no_grad():
            outputs = model(input_ids[:, :context_len], use_cache=True)
            past_kv = outputs.past_key_values

        # Decision tokens with gradients
        with torch.enable_grad():
            outputs = model(
                input_ids[:, context_len:],
                past_key_values=past_kv,
                use_cache=False,
            )
            logits = outputs.logits

        # Shift: predict next token from each position
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = input_ids[:, context_len + 1:].contiguous()

        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
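        # Worked example of the shift (illustrative): ctx_ids = [a, b],
        # cont_ids = [c, d, e]. The continuation pass sees [c, d, e] on top
        # of the cached [a, b]:
        #   logits at c predict d = input_ids[:, 3]
        #   logits at d predict e = input_ids[:, 4]
        #   logits at e predict past the sequence, dropped by [:, :-1]
        # Note c itself is never a target: the logit that would predict it
        # comes from the no-grad context pass.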

        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        loss_history.append(loss_val)
        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

    return loss_history


# Checkpoint sync scheduling
_checkpoint_task = None
CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes


def schedule_checkpoint_sync():
    """Schedule checkpoint sync after delay (batched)."""
    global _checkpoint_task
    import asyncio

    if _checkpoint_task is not None:
        # Already scheduled
        return

    async def do_sync():
        global _checkpoint_task
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            if _model_path:
                from .checkpoint_sync import checkpoint_sync
                logger.info("[apollo] Starting checkpoint sync...")
                result = checkpoint_sync(_model_path)
                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
        except Exception as e:
            logger.error(f"[apollo] Checkpoint sync failed: {e}")
        finally:
            _checkpoint_task = None

    _checkpoint_task = asyncio.create_task(do_sync())
    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")


def attach_router(app: FastAPI):
    """Attach training router to FastAPI app."""
    app.include_router(router)
    logger.info("[apollo] Training router attached")


def _patch_api_server():
    """Patch vLLM's build_app to include our training router."""
    from vllm.entrypoints.openai import api_server

    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        app = original_build_app(*args, **kwargs)
        attach_router(app)
        return app

    api_server.build_app = patched_build_app
    logger.info("[apollo] API server patched for /train endpoint")