"""Training endpoint for vLLM: runs Apollo training in-process.

Patches vLLM's ``build_app()`` to add a ``/train`` route. Training runs
HOGWILD style: there is no pause, and weights are updated in place while
inference continues. The training loop and checkpoint sync run in a worker
thread so they do not block the API server's asyncio event loop.
"""

import asyncio
import logging
import os
import threading
from datetime import datetime
from typing import Any

import torch
import torch.nn as nn
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter()


class TrainingSample(BaseModel):
    """Schema of one training sample (token IDs only)."""

    context_ids: list[int]
    continuation_ids: list[int]


class TrainRequest(BaseModel):
    """Body of a ``/train`` request."""

    training_data: dict[str, Any]  # {"samples": [...], "config": {...}}


class TrainResponse(BaseModel):
    """Mirrors the JSON body ``handle_train`` returns on success."""

    job_id: str
    status: str
    training_samples: int
    loss_history: list[float]


# Global reference to the HF model whose parameters are views of vLLM's weights.
_model: nn.Module | None = None
_model_path: str | None = None
_initialized: bool = False
_optimizer: Any = None  # Persisted Apollo optimizer

# Serializes all off-event-loop users of the model/optimizer (training jobs and
# checkpoint syncs). Two jobs therefore never interleave optimizer steps, and
# optimizer state is never saved mid-step.
_train_lock = threading.Lock()

OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"


def _load_training_model() -> nn.Module:
    """Load an HF model whose weights point into vLLM's GPU memory.

    Uses the CUDA IPC handles exported by export_hook to build an HF model
    whose parameters share GPU memory with vLLM's model.
    """
    from .export_hook import HANDLE_PATH
    from .weight_mapping import load_hf_model_with_vllm_weights

    # NOTE(review): weights_only=False unpickles arbitrary objects. It is
    # needed here because each entry stores a (rebuild_fn, args) pair that is
    # called to reopen the IPC handle. Make sure HANDLE_PATH is writable only
    # by the serving user.
    handles = torch.load(HANDLE_PATH, weights_only=False)

    vllm_params = {}
    for name, info in handles.items():
        func, args = info['handle']
        vllm_params[name] = func(*args)

    model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
    model.train()
    return model


def _ensure_initialized():
    """Lazily initialize the training model on the first /train request.

    Raises:
        RuntimeError: if ``set_model_path`` has not been called yet.
    """
    global _model, _initialized

    if _initialized:
        return

    if _model_path is None:
        raise RuntimeError("Model path not set - export_hook may not have run")

    logger.info("[apollo] Loading HF model with vLLM weight views...")
    _model = _load_training_model()
    _initialized = True
    logger.info("[apollo] Training model ready")


def set_model_path(path: str):
    """Set the model path used for training. Called by export_hook after model load."""
    global _model_path
    _model_path = path
    logger.info(f"[apollo] Model path set: {path}")


def _validate_samples(samples: Any) -> None:
    """Raise ValueError unless every sample yields at least one training target.

    A sample must be a dict whose ``context_ids`` and ``continuation_ids`` are
    lists of ints. The continuation must be non-empty, and the sample must
    contain at least two tokens overall, because the first token of a sequence
    has no preceding position to be predicted from.
    """
    if not isinstance(samples, list):
        raise ValueError("'samples' must be a list")
    for i, sample in enumerate(samples):
        if not isinstance(sample, dict):
            raise ValueError(f"Sample {i} must be an object")
        for key in ('context_ids', 'continuation_ids'):
            ids = sample.get(key)
            if not isinstance(ids, list) or not all(
                isinstance(t, int) and not isinstance(t, bool) for t in ids
            ):
                raise ValueError(f"Sample {i}: '{key}' must be a list of ints")
        ctx_len = len(sample['context_ids'])
        cont_len = len(sample['continuation_ids'])
        if cont_len == 0:
            raise ValueError(f"Sample {i}: 'continuation_ids' is empty")
        if ctx_len + cont_len < 2:
            raise ValueError(
                f"Sample {i}: need at least two tokens to form a training target"
            )


def _split_for_training(
    ctx_ids: list[int], cont_ids: list[int]
) -> tuple[list[int], list[int]]:
    """Split a sample into a frozen prefix and a gradient-carrying segment.

    The last context token is moved into the gradient segment because its
    logits predict the first continuation token. Without it, the first
    decision token would never be trained, and a one-token continuation would
    have no targets at all. The training targets are ``grad_ids[1:]``.
    """
    if not ctx_ids:
        return [], list(cont_ids)
    return list(ctx_ids[:-1]), [ctx_ids[-1], *cont_ids]


@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
    """Handle a training request by running Apollo training on the provided samples.

    Responses:
        503: the training model could not be initialized.
        400: the samples are missing or malformed, or the config is not an object.
        500: training itself failed.
    """
    try:
        _ensure_initialized()
    except Exception as e:
        return JSONResponse(
            content={"error": f"Training not available: {e}"},
            status_code=503,
        )

    training_data = request.training_data
    samples = training_data.get("samples", [])
    config = training_data.get("config", {})

    if not samples:
        return JSONResponse(
            content={"error": "No training samples provided"},
            status_code=400,
        )
    if not isinstance(config, dict):
        return JSONResponse(
            content={"error": "'config' must be an object"},
            status_code=400,
        )
    try:
        _validate_samples(samples)
    except ValueError as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)

    try:
        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")

        # Run training (off the event loop; see run_training)
        loss_history = await run_training(_model, samples, config)

        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")

        # Schedule checkpoint sync (batched, 10 min delay)
        schedule_checkpoint_sync()

        return JSONResponse(content={
            "job_id": job_id,
            "status": "completed",
            "training_samples": len(samples),
            "loss_history": loss_history,
        })

    except Exception as e:
        logger.exception(f"[apollo] Training failed: {e}")
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )


def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
    """Return the persisted optimizer, creating it on first use.

    ``config`` is read only when the optimizer is first created. Later calls
    return the existing optimizer unchanged, so their config is ignored.

    Raises:
        ValueError: if the model has no parameters with ``requires_grad``.
    """
    global _optimizer
    from .optimizer import Apollo

    if _optimizer is not None:
        return _optimizer

    # Build parameter groups (Apollo for large 2D+, standard for small/1D)
    apollo_params, standard_params = [], []
    for p in model.parameters():
        if p.requires_grad:
            if p.ndim >= 2 and min(p.shape) >= 256:
                apollo_params.append(p)
            else:
                standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})

    if not groups:
        raise ValueError("No trainable parameters found")

    # Create optimizer
    _optimizer = Apollo(
        groups,
        lr=config.get('lr', 1e-5),
        rank=config.get('rank', 256),
        betas=tuple(config.get('betas', (0.9, 0.999))),
        eps=config.get('eps', 1e-8),
        weight_decay=config.get('weight_decay', 0.01),
        warmup_steps=config.get('warmup_steps', 0),
        scale=config.get('scale'),
        proj_refresh=config.get('proj_refresh', 200),
        norm_growth_limit=config.get('norm_growth_limit', 1.01),
    )

    # Restore saved state if it exists. Restoration is best-effort: a failure
    # is logged and training starts from fresh state.
    if os.path.exists(OPTIMIZER_STATE_PATH):
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects, and
            # /tmp is usually world-writable, so a planted file could execute
            # code here. Consider a private directory, or weights_only=True if
            # Apollo's state dict holds only tensors and plain containers.
            state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
            _optimizer.load_state_dict(state)
            logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
        except Exception as e:
            logger.warning(f"[apollo] Could not restore optimizer state: {e}")

    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                f"{len(standard_params)} standard, "
                f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")

    return _optimizer


def _save_optimizer_state():
    """Save optimizer state so it persists between /train calls.

    The state is written to a temporary file and then renamed into place, so a
    crash mid-write cannot leave a truncated state file behind. Does nothing if
    no optimizer has been created yet.
    """
    if _optimizer is None:
        return
    tmp_path = f"{OPTIMIZER_STATE_PATH}.tmp"
    torch.save(_optimizer.state_dict(), tmp_path)
    os.replace(tmp_path, OPTIMIZER_STATE_PATH)
    logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")


def _train_step(
    model: nn.Module,
    optimizer: Any,
    sample: dict[str, Any],
) -> float:
    """Run one forward/backward/optimizer step on a single sample and return its loss."""
    prefix_ids, grad_ids = _split_for_training(
        sample['context_ids'], sample['continuation_ids']
    )

    optimizer.zero_grad()

    # Context-frozen forward pass: build the KV cache without gradients.
    past_kv = None
    if prefix_ids:
        prefix = torch.tensor([prefix_ids], device='cuda:0')
        with torch.no_grad():
            past_kv = model(prefix, use_cache=True).past_key_values

    # Last context token + decision tokens, with gradients.
    grad_input = torch.tensor([grad_ids], device='cuda:0')
    with torch.enable_grad():
        logits = model(
            grad_input,
            past_key_values=past_kv,
            use_cache=False,
        ).logits

        # Shift: the logits at position j predict grad_ids[j + 1]
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = grad_input[:, 1:].contiguous()

        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

    loss.backward()
    optimizer.step()
    return loss.item()


def _run_training_locked(
    model: nn.Module,
    samples: list[dict[str, Any]],
    config: dict[str, Any],
) -> list[float]:
    """Run the training loop synchronously while holding ``_train_lock``."""
    with _train_lock:
        optimizer = _get_or_create_optimizer(model, config)
        loss_history: list[float] = []
        try:
            for i, sample in enumerate(samples):
                loss_val = _train_step(model, optimizer, sample)
                loss_history.append(loss_val)
                logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                            f"(ctx={len(sample['context_ids'])}, "
                            f"cont={len(sample['continuation_ids'])} tokens)")
        finally:
            # Free the gradient buffers (as large as the trainable weights) so
            # they do not stay resident on the GPU between jobs.
            model.zero_grad(set_to_none=True)
        return loss_history


async def run_training(
    model: nn.Module,
    samples: list[dict[str, Any]],
    config: dict[str, Any],
) -> list[float]:
    """Run Apollo training on the given samples and return one loss per sample.

    Each sample has:
        context_ids: token IDs for frozen context (no gradients)
        continuation_ids: token IDs for the decision we're training on

    A sample's loss is the mean cross-entropy over every continuation token
    that has a preceding token. The loop runs in a worker thread, serialized by
    ``_train_lock``, so the event loop keeps serving requests during training.

    Raises:
        ValueError: if a sample is malformed or yields no training targets.
    """
    _validate_samples(samples)
    return await asyncio.to_thread(_run_training_locked, model, samples, config)


# Checkpoint sync scheduling
_checkpoint_task = None
CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes


def _sync_checkpoint_locked(model_path: str) -> Any:
    """Save optimizer state and sync the checkpoint while holding ``_train_lock``."""
    from .checkpoint_sync import checkpoint_sync

    with _train_lock:
        # Save optimizer state alongside model weights
        _save_optimizer_state()
        return checkpoint_sync(model_path)


def schedule_checkpoint_sync():
    """Schedule a checkpoint sync after a delay.

    Calls made while a sync is already pending are no-ops, so multiple
    training jobs are batched into one sync.
    """
    global _checkpoint_task

    if _checkpoint_task is not None:
        # Already scheduled
        return

    async def do_sync():
        global _checkpoint_task
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            model_path = _model_path
            if model_path:
                logger.info("[apollo] Starting checkpoint sync...")
                # Saving state and syncing weights is blocking work; keep it
                # off the event loop.
                result = await asyncio.to_thread(_sync_checkpoint_locked, model_path)
                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
        except Exception as e:
            logger.exception(f"[apollo] Checkpoint sync failed: {e}")
        finally:
            _checkpoint_task = None

    _checkpoint_task = asyncio.create_task(do_sync())
    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")


def attach_router(app: FastAPI):
    """Attach the training router to a FastAPI app."""
    app.include_router(router)
    logger.info("[apollo] Training router attached")


def _patch_api_server():
    """Patch vLLM's build_app so every app it builds includes the training router."""
    from vllm.entrypoints.openai import api_server

    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        app = original_build_app(*args, **kwargs)
        attach_router(app)
        return app

    api_server.build_app = patched_build_app
    logger.info("[apollo] API server patched for /train endpoint")