training: move to dedicated subprocess with ZMQ communication
- Add training_worker.py: long-lived subprocess that handles GPU training
work, owns HF model wrapper (views into vLLM GPU memory), Apollo
optimizer, and checkpoint sync
- train_router.py: now forwards /train requests via async ZMQ instead of
running training in-process. Adds /checkpoint and /train/status endpoints
- export_hook.py: store model_path in __metadata__ so training worker can
find it without cross-process communication
- This fixes two bugs:
1. Process boundary issue - model_path was set in worker process but
needed in API server process
2. Blocking event loop - training blocked vLLM's async event loop
Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
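The diff below only shows the API-server side of this change. For orientation, here is a minimal sketch of what the worker side of the protocol could look like, inferred from the router code below ('train' / 'checkpoint' / 'status' messages over a REQ/REP socket at APOLLO_ZMQ_ADDR) and from the old in-process loader (CUDA IPC handles plus a model_path entry under __metadata__). HANDLE_PATH, the reply shapes, and the run_training stub are assumptions for illustration, not the actual training_worker.py:

# Hypothetical sketch of the worker side (not the actual training_worker.py).
# Assumes the handle file written by export_hook maps parameter names to CUDA
# IPC handles and carries the model path under '__metadata__'.
import os
import torch
import zmq

HANDLE_PATH = "/tmp/apollo_weight_handles.pt"  # assumed; export_hook defines the real path


def load_views():
    """Rebuild tensors aliasing vLLM's GPU memory from the exported IPC handles."""
    handles = torch.load(HANDLE_PATH, weights_only=False)
    meta = handles.pop('__metadata__', {})
    params = {}
    for name, info in handles.items():
        func, args = info['handle']   # CUDA IPC rebuild function + its args
        params[name] = func(*args)    # view into vLLM's GPU memory, no copy
    return params, meta.get('model_path')


def run_training(params, model_path, samples, config):
    """Placeholder: the real worker wraps the views in an HF model and runs
    Apollo optimizer steps here (see the commit message above)."""
    raise NotImplementedError


def main():
    addr = os.environ.get('APOLLO_ZMQ_ADDR', 'ipc:///tmp/apollo_training.sock')
    sock = zmq.Context().socket(zmq.REP)  # REP pairs with the router's REQ socket
    sock.bind(addr)

    params, model_path = load_views()     # IPC handles are loaded once, at startup

    while True:
        msg = sock.recv_json()
        try:
            if msg['type'] == 'train':
                losses = run_training(params, model_path, msg['samples'], msg['config'])
                reply = {'status': 'completed',
                         'training_samples': len(msg['samples']),
                         'loss_history': losses}
            elif msg['type'] == 'checkpoint':
                reply = {'status': 'synced'}  # reply shape is up to the worker
            elif msg['type'] == 'status':
                reply = {'status': 'ready', 'model_path': model_path}
            else:
                reply = {'error': f"unknown request type: {msg['type']}"}
        except Exception as e:  # always reply, or the REQ/REP pair deadlocks
            reply = {'error': str(e)}
        sock.send_json(reply)


if __name__ == '__main__':
    main()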
parent 68a2df2185
commit 2c6a5c0f4a
6 changed files with 503 additions and 233 deletions
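Once the patched server is running, the new routes can be exercised directly. A quick smoke test, assuming the default vLLM port 8000; the full TrainRequest schema is defined outside the hunks shown below, so the /train body here is indicative only:

# Smoke test for the new routes; assumes the patched vLLM server on localhost:8000.
import requests

BASE = "http://localhost:8000"

# Worker liveness / status and an explicit checkpoint sync take no body.
print(requests.get(f"{BASE}/train/status").json())
print(requests.post(f"{BASE}/checkpoint").json())

# Training request. The sample fields mirror what the router forwards to the
# worker ('context_ids' / 'continuation_ids'); treat the exact body as indicative.
payload = {
    "samples": [{"context_ids": [1, 2, 3], "continuation_ids": [4, 5]}],
    "config": {"lr": 1e-5},
}
print(requests.post(f"{BASE}/train", json=payload).json())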
@@ -1,16 +1,23 @@
-"""Training endpoint for vLLM - runs Apollo training in-process.
+"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.
 
-Patches vLLM's build_app() to add /train route. Training runs HOGWILD
-style - no pause needed, weights updated in-place while inference continues.
+Patches vLLM's build_app() to add /train route. The actual training runs
+in a dedicated subprocess (training_worker.py) to avoid blocking the
+event loop and to keep training work isolated from vLLM internals.
 """
 
+import asyncio
 import logging
+import os
+import subprocess
+import sys
 from datetime import datetime
+from pathlib import Path
 from typing import Any
 
-import torch
-import torch.nn as nn
-from fastapi import APIRouter, FastAPI, Request
+import zmq
+import zmq.asyncio
+
+from fastapi import APIRouter, FastAPI
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 
@@ -18,10 +25,13 @@ logger = logging.getLogger(__name__)
 
 router = APIRouter()
 
-class TrainingSample(BaseModel):
-    context_ids: list[int]
-    continuation_ids: list[int]
+DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
+
+# Global state for subprocess management
+_worker_process: subprocess.Popen | None = None
+_zmq_context: zmq.asyncio.Context | None = None
+_zmq_socket: zmq.asyncio.Socket | None = None
+_initialized: bool = False
 
 
 class TrainRequest(BaseModel):
@@ -35,64 +45,61 @@ class TrainResponse(BaseModel):
     loss_history: list[float]
 
 
-# Global reference to HF model with vLLM weight views
-_model: nn.Module | None = None
-_model_path: str | None = None
-_initialized: bool = False
-_optimizer: Any = None  # Persisted Apollo optimizer
-
-OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
-DEFAULT_RANK = 64
-
-
-def _load_training_model() -> nn.Module:
-    """Load HF model with weights pointing to vLLM's GPU memory.
-
-    Uses CUDA IPC handles exported by export_hook to create an HF model
-    whose parameters share GPU memory with vLLM's model.
-    """
-    from .weight_mapping import load_hf_model_with_vllm_weights
-    from .export_hook import HANDLE_PATH
-
-    handles = torch.load(HANDLE_PATH, weights_only=False)
-    vllm_params = {}
-    for name, info in handles.items():
-        func, args = info['handle']
-        vllm_params[name] = func(*args)
-
-    model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
-    model.train()
-    return model
+def _start_worker_subprocess():
+    """Start the training worker subprocess."""
+    global _worker_process
+
+    if _worker_process is not None and _worker_process.poll() is None:
+        return  # Still running
+
+    # Start worker as subprocess using script path
+    worker_script = Path(__file__).parent / 'training_worker.py'
+    _worker_process = subprocess.Popen(
+        [sys.executable, str(worker_script)],
+        env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
+    )
+    logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")
+
+    # Give it a moment to bind the socket
+    import time
+    time.sleep(0.5)
 
 
 def _ensure_initialized():
-    """Lazy-initialize the training model on first /train request."""
-    global _model, _initialized
+    """Ensure subprocess is running and ZMQ socket is connected."""
+    global _zmq_context, _zmq_socket, _initialized
 
     if _initialized:
         return
 
-    if _model_path is None:
-        raise RuntimeError("Model path not set - export_hook may not have run")
-
-    logger.info("[apollo] Loading HF model with vLLM weight views...")
-    _model = _load_training_model()
+    # Start worker if needed
+    _start_worker_subprocess()
+
+    # Create async ZMQ context and socket
+    _zmq_context = zmq.asyncio.Context()
+    _zmq_socket = _zmq_context.socket(zmq.REQ)
+    _zmq_socket.connect(DEFAULT_ZMQ_ADDR)
+
+    # Set timeout for recv
+    _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000)  # 5 minute timeout for training
+
     _initialized = True
-    logger.info("[apollo] Training model ready")
+    logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
 
 
-def set_model_path(path: str):
-    """Set model path for training. Called by export_hook after model load."""
-    global _model_path
-    _model_path = path
-    logger.info(f"[apollo] Model path set: {path}")
+async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
+    """Send request to worker and wait for response."""
+    _ensure_initialized()
+
+    # ZMQ async send/recv
+    await _zmq_socket.send_json(request)
+    response = await _zmq_socket.recv_json()
+    return response
 
 
 @router.post("/train")
-async def handle_train(request: TrainRequest, raw_request: Request):
-    """Handle training request - runs Apollo training on provided samples."""
-    global _model
-
+async def handle_train(request: TrainRequest):
+    """Handle training request - forwards to training subprocess."""
     try:
         _ensure_initialized()
     except Exception as e:
@@ -113,193 +120,109 @@ async def handle_train(request: TrainRequest, raw_request: Request):
         )
 
         job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
+        logger.info(f"Starting training job {job_id} with {len(samples)} samples")
 
-        # Run training
-        loss_history = await run_training(_model, samples, config)
-
-        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
-
-        # Schedule checkpoint sync (batched, 10 min delay)
-        schedule_checkpoint_sync()
+        # Forward to worker
+        response = await _send_request({
+            'type': 'train',
+            'samples': samples,
+            'config': config,
+        })
+
+        if 'error' in response:
+            return JSONResponse(
+                content={"error": response['error']},
+                status_code=500,
+            )
+
+        logger.info(
+            f"Training job {job_id} completed, "
+            f"final loss: {response['loss_history'][-1]:.4f}"
+        )
 
         return JSONResponse(content={
             "job_id": job_id,
-            "status": "completed",
-            "training_samples": len(samples),
-            "loss_history": loss_history,
+            "status": response['status'],
+            "training_samples": response['training_samples'],
+            "loss_history": response['loss_history'],
         })
 
+    except zmq.Again:
+        logger.error("Training request timed out")
+        return JSONResponse(
+            content={"error": "Training request timed out"},
+            status_code=504,
+        )
     except Exception as e:
-        logger.exception(f"[apollo] Training failed: {e}")
+        logger.exception(f"Training failed: {e}")
         return JSONResponse(
             content={"error": str(e)},
             status_code=500,
         )
 
 
-def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
-    """Get existing optimizer or create new one. Persists state between calls."""
-    global _optimizer
-    from .optimizer import Apollo
-    import os
-
-    if _optimizer is not None:
-        return _optimizer
-
-    # Build parameter groups (Apollo for 2D+, standard for small/1D)
-    apollo_params, standard_params = [], []
-    for p in model.parameters():
-        if p.requires_grad:
-            if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
-                apollo_params.append(p)
-            else:
-                standard_params.append(p)
-
-    groups = []
-    if apollo_params:
-        groups.append({'params': apollo_params})
-    if standard_params:
-        groups.append({'params': standard_params})
-
-    if not groups:
-        raise ValueError("No trainable parameters found")
-
-    # Create optimizer
-    _optimizer = Apollo(
-        groups,
-        lr=config.get('lr', 1e-5),
-        rank=config.get('rank', DEFAULT_RANK),
-        betas=tuple(config.get('betas', (0.9, 0.999))),
-        eps=config.get('eps', 1e-8),
-        weight_decay=config.get('weight_decay', 0.01),
-        warmup_steps=config.get('warmup_steps', 0),
-        scale=config.get('scale'),
-        proj_refresh=config.get('proj_refresh', 200),
-        norm_growth_limit=config.get('norm_growth_limit', 1.01),
-    )
-
-    # Restore state if exists
-    if os.path.exists(OPTIMIZER_STATE_PATH):
-        try:
-            state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
-            _optimizer.load_state_dict(state)
-            logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
-        except Exception as e:
-            logger.warning(f"[apollo] Could not restore optimizer state: {e}")
-
-    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
-                f"{len(standard_params)} standard, "
-                f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
-
-    return _optimizer
-
-
-def _save_optimizer_state():
-    """Save optimizer state for persistence between /train calls."""
-    global _optimizer
-    if _optimizer is not None:
-        torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH)
-        logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
-
-
-async def run_training(
-    model: nn.Module,
-    samples: list[dict[str, Any]],
-    config: dict[str, Any],
-) -> list[float]:
-    """Run Apollo training on the given samples.
-
-    Each sample has:
-        context_ids: token IDs for frozen context (no gradients)
-        continuation_ids: token IDs for the decision we're training on
-    """
-    optimizer = _get_or_create_optimizer(model, config)
-
-    loss_history = []
-
-    for i, sample in enumerate(samples):
-        ctx_ids = sample['context_ids']
-        cont_ids = sample['continuation_ids']
-        all_ids = ctx_ids + cont_ids
-        context_len = len(ctx_ids)
-
-        input_ids = torch.tensor([all_ids], device='cuda:0')
-
-        optimizer.zero_grad()
-
-        # Context-frozen forward pass
-        with torch.no_grad():
-            outputs = model(input_ids[:, :context_len], use_cache=True)
-            past_kv = outputs.past_key_values
-
-        # Decision tokens with gradients
-        with torch.enable_grad():
-            outputs = model(
-                input_ids[:, context_len:],
-                past_key_values=past_kv,
-                use_cache=False,
-            )
-            logits = outputs.logits
-
-        # Shift: predict next token from each position
-        shift_logits = logits[:, :-1].contiguous()
-        shift_labels = input_ids[:, context_len + 1:].contiguous()
-
-        loss = nn.functional.cross_entropy(
-            shift_logits.view(-1, shift_logits.size(-1)),
-            shift_labels.view(-1),
-        )
-
-        loss.backward()
-        optimizer.step()
-
-        loss_val = loss.item()
-        loss_history.append(loss_val)
-        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
-                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
-
-    return loss_history
-
-
-# Checkpoint sync scheduling
-_checkpoint_task = None
-CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
-
-
-def schedule_checkpoint_sync():
-    """Schedule checkpoint sync after delay (batched)."""
-    global _checkpoint_task
-    import asyncio
-
-    if _checkpoint_task is not None:
-        # Already scheduled
-        return
-
-    async def do_sync():
-        global _checkpoint_task
-        try:
-            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
-            if _model_path:
-                from .checkpoint_sync import checkpoint_sync
-                logger.info("[apollo] Starting checkpoint sync...")
-                # Save optimizer state alongside model weights
-                _save_optimizer_state()
-                result = checkpoint_sync(_model_path)
-                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
-        except Exception as e:
-            logger.error(f"[apollo] Checkpoint sync failed: {e}")
-        finally:
-            _checkpoint_task = None
-
-    _checkpoint_task = asyncio.create_task(do_sync())
-    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
+@router.post("/checkpoint")
+async def handle_checkpoint():
+    """Trigger checkpoint sync to disk."""
+    try:
+        _ensure_initialized()
+    except Exception as e:
+        return JSONResponse(
+            content={"error": f"Training not available: {e}"},
+            status_code=503,
+        )
+
+    try:
+        response = await _send_request({'type': 'checkpoint'})
+
+        if 'error' in response:
+            return JSONResponse(
+                content={"error": response['error']},
+                status_code=500,
+            )
+
+        return JSONResponse(content=response)
+
+    except Exception as e:
+        logger.exception(f"Checkpoint failed: {e}")
+        return JSONResponse(
+            content={"error": str(e)},
+            status_code=500,
+        )
+
+
+@router.get("/train/status")
+async def handle_status():
+    """Get training worker status."""
+    try:
+        _ensure_initialized()
+    except Exception as e:
+        return JSONResponse(
+            content={
+                "status": "unavailable",
+                "error": str(e),
+            },
+            status_code=503,
+        )
+
+    try:
+        response = await _send_request({'type': 'status'})
+        return JSONResponse(content=response)
+    except Exception as e:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "error": str(e),
+            },
+            status_code=500,
+        )
 
 
 def attach_router(app: FastAPI):
     """Attach training router to FastAPI app."""
     app.include_router(router)
-    logger.info("[apollo] Training router attached")
+    logger.info("Training router attached")
 
 
 def _patch_api_server():
@@ -314,4 +237,4 @@ def _patch_api_server():
         return app
 
     api_server.build_app = patched_build_app
-    logger.info("[apollo] API server patched for /train endpoint")
+    logger.info("API server patched for /train endpoint")