training: integrate /train into vLLM process (no separate daemon)
Remove standalone worker.py daemon. Training now runs inside vLLM: - train_router.py: FastAPI router patched into vLLM's build_app() - /train served on same port as /completions, /score - Lazy-loads HF model with vLLM weight views on first request - HOGWILD training: no pause, weights updated in-place The previous architecture had a separate daemon on port 8080 that communicated with vLLM via pause/resume endpoints. This was wrong - training should run in-process, sharing GPU memory directly. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
2f08149fab
commit
7e7e9a4b69
6 changed files with 320 additions and 542 deletions
|
|
@ -22,25 +22,29 @@ The training signal comes from two sources:
|
||||||
│ │
|
│ │
|
||||||
│ ┌──────────────────────────────────────────────┐ │
|
│ ┌──────────────────────────────────────────────┐ │
|
||||||
│ │ Model Weights (54GB, bf16) │ │
|
│ │ Model Weights (54GB, bf16) │ │
|
||||||
│ │ Shared via CUDA IPC │ │
|
│ │ Shared: vLLM inference + HF training │ │
|
||||||
│ └──────────────┬──────────────┬────────────────┘ │
|
│ └──────────────┬──────────────┬────────────────┘ │
|
||||||
│ │ │ │
|
│ │ │ │
|
||||||
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
|
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
|
||||||
│ │ vLLM (inference)│ │ Apollo (training) │ │
|
│ │ vLLM (inference)│ │ HF model (training) │ │
|
||||||
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
|
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
|
||||||
│ │ Serves requests │ │ Optimizer state ~10GB │ │
|
│ │ /completions │ │ Optimizer state ~10GB │ │
|
||||||
│ │ Never paused │ │ Activations ~10GB │ │
|
│ │ /score │ │ Views into vLLM weights │ │
|
||||||
|
│ │ /train ────────┼──┼─► Apollo optimizer │ │
|
||||||
│ └─────────────────┘ └────────────────────────┘ │
|
│ └─────────────────┘ └────────────────────────┘ │
|
||||||
└─────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
Moria B200
|
Single vLLM process serves everything
|
||||||
|
No separate daemon - /train is a vLLM route
|
||||||
|
|
||||||
|
Moria B200 (vLLM)
|
||||||
┌──────────────────┐ ┌──────────────────┐
|
┌──────────────────┐ ┌──────────────────┐
|
||||||
│ Training signal │ HTTP │ Apollo worker │
|
│ Training signal │ HTTP │ /completions │
|
||||||
│ agent │──────────>│ daemon │
|
│ agent │──────────>│ /score │
|
||||||
│ │ │ │
|
│ │ │ /train │
|
||||||
│ Dream loop │ │ Checkpoint sync │
|
│ Dream loop │ │ │
|
||||||
│ (generates │ │ (mmap + diff, │
|
│ (generates │ │ Checkpoint sync │
|
||||||
│ scenarios) │ │ every 10 min) │
|
│ scenarios) │ │ (10 min batched) │
|
||||||
└──────────────────┘ └──────────────────┘
|
└──────────────────┘ └──────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -220,34 +224,30 @@ a few hundred MB.
|
||||||
## Components
|
## Components
|
||||||
|
|
||||||
### Built ✓
|
### Built ✓
|
||||||
- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256)
|
- `optimizer.py` — Apollo optimizer (configurable rank, default 256)
|
||||||
- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking)
|
- `train_router.py` — /train endpoint, runs in vLLM process
|
||||||
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
|
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
|
||||||
- `training_example.py` — tokenization with chat template
|
- `export_hook.py` — vLLM plugin hook for IPC handle export
|
||||||
- `vllm_export_hook.py` — source patch for IPC handle export
|
- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
|
||||||
- `checkpoint/` — Rust tool for mmap + diff checkpoint sync
|
|
||||||
|
|
||||||
### To build
|
### To build
|
||||||
- **Dream loop → training bridge**: connect dream output to Apollo
|
- **Dream loop → training bridge**: connect dream output to /train
|
||||||
- **Training-signal agent**: flags moments in conversation logs
|
- **Training-signal agent**: flags moments in conversation logs
|
||||||
- **Instruction stripping**: remove scaffolding from training examples
|
- **Instruction stripping**: remove scaffolding from training examples
|
||||||
- **Quality monitoring**: track model capability over time
|
- **Quality monitoring**: track model capability over time
|
||||||
- **HF model forward pass integration**: wire into apollo_worker
|
|
||||||
|
|
||||||
## Files
|
## Files
|
||||||
|
|
||||||
```
|
```
|
||||||
training/
|
training/
|
||||||
DESIGN.md — this document
|
DESIGN.md — this document
|
||||||
apollo_mini.py — Apollo optimizer
|
pyproject.toml — package config, vLLM plugin entry point
|
||||||
apollo_worker.py — HTTP training daemon
|
apollo_plugin/
|
||||||
|
__init__.py — plugin registration
|
||||||
|
export_hook.py — patches vLLM to export IPC handles
|
||||||
|
train_router.py — /train endpoint (FastAPI router)
|
||||||
|
optimizer.py — Apollo optimizer
|
||||||
weight_mapping.py — vLLM ↔ HF weight views
|
weight_mapping.py — vLLM ↔ HF weight views
|
||||||
training_example.py — tokenization helpers
|
checkpoint_sync.py — mmap + diff sync to safetensors
|
||||||
export_weights.py — standalone weight export (unused)
|
steering.py — steering vector extraction (experimental)
|
||||||
vllm_export_hook.py — vLLM source patch for IPC export
|
|
||||||
start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch)
|
|
||||||
train.py — standalone training script (alternative)
|
|
||||||
checkpoint/
|
|
||||||
Cargo.toml — Rust checkpoint tool
|
|
||||||
src/main.rs — mmap + diff sync
|
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
"""Apollo training plugin for vLLM.
|
"""Apollo training plugin for vLLM.
|
||||||
|
|
||||||
Enables continuous fine-tuning alongside live inference by:
|
Enables continuous fine-tuning alongside live inference by:
|
||||||
1. Exporting CUDA IPC handles for weight sharing
|
1. Exporting CUDA IPC handles for weight sharing (export_hook)
|
||||||
2. Providing a training worker daemon (/train endpoint)
|
2. Adding /train endpoint to vLLM's HTTP server (train_router)
|
||||||
3. Block-level checkpoint sync to safetensors files
|
3. Block-level checkpoint sync to safetensors files
|
||||||
|
|
||||||
Install: pip install -e /path/to/training
|
Install: pip install -e /path/to/training
|
||||||
|
|
@ -10,8 +10,10 @@ Then vLLM auto-loads via entry point.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .export_hook import _patch_model_runner
|
from .export_hook import _patch_model_runner
|
||||||
|
from .train_router import _patch_api_server
|
||||||
|
|
||||||
|
|
||||||
def register():
|
def register():
|
||||||
"""Called by vLLM's plugin loader on startup."""
|
"""Called by vLLM's plugin loader on startup."""
|
||||||
_patch_model_runner()
|
_patch_model_runner()
|
||||||
|
_patch_api_server()
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,10 @@ def _patch_model_runner():
|
||||||
result = original_load(self, *args, **kwargs)
|
result = original_load(self, *args, **kwargs)
|
||||||
try:
|
try:
|
||||||
export_model_weights(self.model_runner.model)
|
export_model_weights(self.model_runner.model)
|
||||||
|
# Set model path for training router
|
||||||
|
model_path = self.vllm_config.model_config.model
|
||||||
|
from .train_router import set_model_path
|
||||||
|
set_model_path(model_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[apollo] Failed to export weights: {e}")
|
print(f"[apollo] Failed to export weights: {e}")
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
282
training/apollo_plugin/train_router.py
Normal file
282
training/apollo_plugin/train_router.py
Normal file
|
|
@ -0,0 +1,282 @@
|
||||||
|
"""Training endpoint for vLLM - runs Apollo training in-process.
|
||||||
|
|
||||||
|
Patches vLLM's build_app() to add /train route. Training runs HOGWILD
|
||||||
|
style - no pause needed, weights updated in-place while inference continues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from fastapi import APIRouter, FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingSample(BaseModel):
|
||||||
|
context_ids: list[int]
|
||||||
|
continuation_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class TrainRequest(BaseModel):
|
||||||
|
training_data: dict[str, Any] # {"samples": [...], "config": {...}}
|
||||||
|
|
||||||
|
|
||||||
|
class TrainResponse(BaseModel):
|
||||||
|
job_id: str
|
||||||
|
status: str
|
||||||
|
training_samples: int
|
||||||
|
loss_history: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
# Global reference to HF model with vLLM weight views
|
||||||
|
_model: nn.Module | None = None
|
||||||
|
_model_path: str | None = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _load_training_model() -> nn.Module:
|
||||||
|
"""Load HF model with weights pointing to vLLM's GPU memory.
|
||||||
|
|
||||||
|
Uses CUDA IPC handles exported by export_hook to create an HF model
|
||||||
|
whose parameters share GPU memory with vLLM's model.
|
||||||
|
"""
|
||||||
|
from .weight_mapping import load_hf_model_with_vllm_weights
|
||||||
|
from .export_hook import HANDLE_PATH
|
||||||
|
|
||||||
|
handles = torch.load(HANDLE_PATH, weights_only=False)
|
||||||
|
vllm_params = {}
|
||||||
|
for name, info in handles.items():
|
||||||
|
func, args = info['handle']
|
||||||
|
vllm_params[name] = func(*args)
|
||||||
|
|
||||||
|
model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
|
||||||
|
model.train()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_initialized():
|
||||||
|
"""Lazy-initialize the training model on first /train request."""
|
||||||
|
global _model, _initialized
|
||||||
|
|
||||||
|
if _initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
if _model_path is None:
|
||||||
|
raise RuntimeError("Model path not set - export_hook may not have run")
|
||||||
|
|
||||||
|
logger.info("[apollo] Loading HF model with vLLM weight views...")
|
||||||
|
_model = _load_training_model()
|
||||||
|
_initialized = True
|
||||||
|
logger.info("[apollo] Training model ready")
|
||||||
|
|
||||||
|
|
||||||
|
def set_model_path(path: str):
|
||||||
|
"""Set model path for training. Called by export_hook after model load."""
|
||||||
|
global _model_path
|
||||||
|
_model_path = path
|
||||||
|
logger.info(f"[apollo] Model path set: {path}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/train")
|
||||||
|
async def handle_train(request: TrainRequest, raw_request: Request):
|
||||||
|
"""Handle training request - runs Apollo training on provided samples."""
|
||||||
|
global _model
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ensure_initialized()
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": f"Training not available: {e}"},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_data = request.training_data
|
||||||
|
samples = training_data.get("samples", [])
|
||||||
|
config = training_data.get("config", {})
|
||||||
|
|
||||||
|
if not samples:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": "No training samples provided"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
|
||||||
|
|
||||||
|
# Run training
|
||||||
|
loss_history = await run_training(_model, samples, config)
|
||||||
|
|
||||||
|
logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
|
||||||
|
|
||||||
|
# Schedule checkpoint sync (batched, 10 min delay)
|
||||||
|
schedule_checkpoint_sync()
|
||||||
|
|
||||||
|
return JSONResponse(content={
|
||||||
|
"job_id": job_id,
|
||||||
|
"status": "completed",
|
||||||
|
"training_samples": len(samples),
|
||||||
|
"loss_history": loss_history,
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"[apollo] Training failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_training(
|
||||||
|
model: nn.Module,
|
||||||
|
samples: list[dict[str, Any]],
|
||||||
|
config: dict[str, Any],
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run Apollo training on the given samples.
|
||||||
|
|
||||||
|
Each sample has:
|
||||||
|
context_ids: token IDs for frozen context (no gradients)
|
||||||
|
continuation_ids: token IDs for the decision we're training on
|
||||||
|
"""
|
||||||
|
from .optimizer import Apollo
|
||||||
|
|
||||||
|
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
||||||
|
apollo_params, standard_params = [], []
|
||||||
|
for p in model.parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
if p.ndim >= 2 and min(p.shape) >= 256:
|
||||||
|
apollo_params.append(p)
|
||||||
|
else:
|
||||||
|
standard_params.append(p)
|
||||||
|
|
||||||
|
groups = []
|
||||||
|
if apollo_params:
|
||||||
|
groups.append({'params': apollo_params})
|
||||||
|
if standard_params:
|
||||||
|
groups.append({'params': standard_params})
|
||||||
|
|
||||||
|
if not groups:
|
||||||
|
raise ValueError("No trainable parameters found")
|
||||||
|
|
||||||
|
# Apollo settings from request config
|
||||||
|
optimizer = Apollo(
|
||||||
|
groups,
|
||||||
|
lr=config.get('lr', 1e-5),
|
||||||
|
rank=config.get('rank', 256),
|
||||||
|
betas=tuple(config.get('betas', (0.9, 0.999))),
|
||||||
|
eps=config.get('eps', 1e-8),
|
||||||
|
weight_decay=config.get('weight_decay', 0.01),
|
||||||
|
warmup_steps=config.get('warmup_steps', 0),
|
||||||
|
scale=config.get('scale'),
|
||||||
|
proj_refresh=config.get('proj_refresh', 200),
|
||||||
|
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
|
||||||
|
f"{len(standard_params)} standard, "
|
||||||
|
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||||
|
|
||||||
|
loss_history = []
|
||||||
|
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
ctx_ids = sample['context_ids']
|
||||||
|
cont_ids = sample['continuation_ids']
|
||||||
|
all_ids = ctx_ids + cont_ids
|
||||||
|
context_len = len(ctx_ids)
|
||||||
|
|
||||||
|
input_ids = torch.tensor([all_ids], device='cuda:0')
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Context-frozen forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids[:, :context_len], use_cache=True)
|
||||||
|
past_kv = outputs.past_key_values
|
||||||
|
|
||||||
|
# Decision tokens with gradients
|
||||||
|
with torch.enable_grad():
|
||||||
|
outputs = model(
|
||||||
|
input_ids[:, context_len:],
|
||||||
|
past_key_values=past_kv,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Shift: predict next token from each position
|
||||||
|
shift_logits = logits[:, :-1].contiguous()
|
||||||
|
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
||||||
|
|
||||||
|
loss = nn.functional.cross_entropy(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
|
shift_labels.view(-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
loss_val = loss.item()
|
||||||
|
loss_history.append(loss_val)
|
||||||
|
logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
|
||||||
|
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
|
||||||
|
|
||||||
|
return loss_history
|
||||||
|
|
||||||
|
|
||||||
|
# Checkpoint sync scheduling
|
||||||
|
_checkpoint_task = None
|
||||||
|
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_checkpoint_sync():
|
||||||
|
"""Schedule checkpoint sync after delay (batched)."""
|
||||||
|
global _checkpoint_task
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if _checkpoint_task is not None:
|
||||||
|
# Already scheduled
|
||||||
|
return
|
||||||
|
|
||||||
|
async def do_sync():
|
||||||
|
global _checkpoint_task
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
|
||||||
|
if _model_path:
|
||||||
|
from .checkpoint_sync import checkpoint_sync
|
||||||
|
logger.info("[apollo] Starting checkpoint sync...")
|
||||||
|
result = checkpoint_sync(_model_path)
|
||||||
|
logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[apollo] Checkpoint sync failed: {e}")
|
||||||
|
finally:
|
||||||
|
_checkpoint_task = None
|
||||||
|
|
||||||
|
_checkpoint_task = asyncio.create_task(do_sync())
|
||||||
|
logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
|
||||||
|
|
||||||
|
|
||||||
|
def attach_router(app: FastAPI):
|
||||||
|
"""Attach training router to FastAPI app."""
|
||||||
|
app.include_router(router)
|
||||||
|
logger.info("[apollo] Training router attached")
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_api_server():
|
||||||
|
"""Patch vLLM's build_app to include our training router."""
|
||||||
|
from vllm.entrypoints.openai import api_server
|
||||||
|
|
||||||
|
original_build_app = api_server.build_app
|
||||||
|
|
||||||
|
def patched_build_app(*args, **kwargs):
|
||||||
|
app = original_build_app(*args, **kwargs)
|
||||||
|
attach_router(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
api_server.build_app = patched_build_app
|
||||||
|
logger.info("[apollo] API server patched for /train endpoint")
|
||||||
|
|
@ -1,509 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Apollo Mini Training Daemon
|
|
||||||
|
|
||||||
This daemon:
|
|
||||||
1. Listens over HTTPS for training requests from poc-agent
|
|
||||||
2. Pauses vLLM inference
|
|
||||||
3. Runs APOLLO-Mini training with torch.enable_grad()
|
|
||||||
4. Saves checkpoints and training metadata
|
|
||||||
5. Resumes vLLM inference
|
|
||||||
|
|
||||||
Communication protocol:
|
|
||||||
- POST /train: Start a training job
|
|
||||||
- GET /status/{job_id}: Check training status
|
|
||||||
- GET /checkpoints: List available checkpoints
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Dict, Any, List
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger('apollo_worker')
|
|
||||||
|
|
||||||
class TrainingStatus(Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
PAUSING_VLLM = "pausing_vllm"
|
|
||||||
TRAINING = "training"
|
|
||||||
SAVING_CHECKPOINT = "saving_checkpoint"
|
|
||||||
RESUMING_VLLM = "resuming_vllm"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingJob:
|
|
||||||
job_id: str
|
|
||||||
status: TrainingStatus
|
|
||||||
created_at: datetime
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
model_path: Optional[str] = None
|
|
||||||
checkpoint_path: Optional[str] = None
|
|
||||||
training_samples: int = 0
|
|
||||||
loss_history: List[float] = field(default_factory=list)
|
|
||||||
error: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
'job_id': self.job_id,
|
|
||||||
'status': self.status.value,
|
|
||||||
'created_at': self.created_at.isoformat(),
|
|
||||||
'started_at': self.started_at.isoformat() if self.started_at else None,
|
|
||||||
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
|
|
||||||
'model_path': self.model_path,
|
|
||||||
'checkpoint_path': self.checkpoint_path,
|
|
||||||
'training_samples': self.training_samples,
|
|
||||||
'loss_history': self.loss_history,
|
|
||||||
'error': self.error,
|
|
||||||
}
|
|
||||||
|
|
||||||
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
|
|
||||||
|
|
||||||
|
|
||||||
class ApolloWorker:
|
|
||||||
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
|
|
||||||
self.config = self._load_config(config_path)
|
|
||||||
self.jobs: Dict[str, TrainingJob] = {}
|
|
||||||
self.vllm_paused = False
|
|
||||||
self.app = web.Application()
|
|
||||||
self._setup_routes()
|
|
||||||
self._checkpoint_timer: Optional[asyncio.Task] = None
|
|
||||||
|
|
||||||
def _load_config(self, config_path: str) -> Dict[str, Any]:
|
|
||||||
"""Load configuration from file or use defaults."""
|
|
||||||
default_config = {
|
|
||||||
'host': '0.0.0.0',
|
|
||||||
'port': 8080,
|
|
||||||
'vllm_socket': '/tmp/vllm_control.sock',
|
|
||||||
'model_path': '/home/ubuntu/models/Qwen3.5-27B',
|
|
||||||
'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
|
|
||||||
'max_training_samples': 100,
|
|
||||||
'learning_rate': 1e-5,
|
|
||||||
'batch_size': 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.exists(config_path):
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
user_config = json.load(f)
|
|
||||||
default_config.update(user_config)
|
|
||||||
|
|
||||||
Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
|
|
||||||
return default_config
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
"""Setup HTTP routes."""
|
|
||||||
self.app.router.add_post('/train', self.handle_train_request)
|
|
||||||
self.app.router.add_get('/status/{job_id}', self.handle_status_request)
|
|
||||||
self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
|
|
||||||
self.app.router.add_get('/health', self.handle_health_check)
|
|
||||||
|
|
||||||
async def handle_health_check(self, request: web.Request) -> web.Response:
|
|
||||||
"""Health check endpoint."""
|
|
||||||
return web.json_response({
|
|
||||||
'status': 'healthy',
|
|
||||||
'vllm_paused': self.vllm_paused,
|
|
||||||
'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
|
|
||||||
})
|
|
||||||
|
|
||||||
async def handle_train_request(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle training request from poc-agent."""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Validate required fields
|
|
||||||
if 'training_data' not in data:
|
|
||||||
return web.json_response(
|
|
||||||
{'error': 'Missing training_data field'},
|
|
||||||
status=400
|
|
||||||
)
|
|
||||||
|
|
||||||
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
|
|
||||||
job = TrainingJob(
|
|
||||||
job_id=job_id,
|
|
||||||
status=TrainingStatus.PENDING,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
model_path=self.config['model_path']
|
|
||||||
)
|
|
||||||
self.jobs[job_id] = job
|
|
||||||
|
|
||||||
# Start training in background
|
|
||||||
asyncio.create_task(self.execute_training(job, data))
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'job_id': job_id,
|
|
||||||
'status': 'accepted',
|
|
||||||
'message': 'Training job started'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling train request: {e}")
|
|
||||||
return web.json_response(
|
|
||||||
{'error': str(e)},
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_status_request(self, request: web.Request) -> web.Response:
|
|
||||||
"""Get training job status."""
|
|
||||||
job_id = request.match_info['job_id']
|
|
||||||
|
|
||||||
if job_id not in self.jobs:
|
|
||||||
return web.json_response(
|
|
||||||
{'error': 'Job not found'},
|
|
||||||
status=404
|
|
||||||
)
|
|
||||||
|
|
||||||
job = self.jobs[job_id]
|
|
||||||
return web.json_response(job.to_dict())
|
|
||||||
|
|
||||||
async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
|
|
||||||
"""List available checkpoints."""
|
|
||||||
checkpoint_dir = Path(self.config['checkpoint_dir'])
|
|
||||||
checkpoints = []
|
|
||||||
|
|
||||||
if checkpoint_dir.exists():
|
|
||||||
for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
|
|
||||||
checkpoints.append({
|
|
||||||
'filename': checkpoint_file.name,
|
|
||||||
'path': str(checkpoint_file),
|
|
||||||
'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
|
|
||||||
'size': checkpoint_file.stat().st_size
|
|
||||||
})
|
|
||||||
|
|
||||||
return web.json_response({'checkpoints': checkpoints})
|
|
||||||
|
|
||||||
async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
|
|
||||||
"""Execute the training pipeline."""
|
|
||||||
try:
|
|
||||||
logger.info(f"Starting training job {job.job_id}")
|
|
||||||
job.started_at = datetime.now()
|
|
||||||
|
|
||||||
# Step 1: Pause vLLM
|
|
||||||
job.status = TrainingStatus.PAUSING_VLLM
|
|
||||||
logger.info("Pausing vLLM...")
|
|
||||||
await self.pause_vllm()
|
|
||||||
self.vllm_paused = True
|
|
||||||
|
|
||||||
# Step 2: Load model and prepare for training
|
|
||||||
job.status = TrainingStatus.TRAINING
|
|
||||||
logger.info("Loading model and preparing for training...")
|
|
||||||
|
|
||||||
# Load model (this would be the actual Qwen3.5-27B model)
|
|
||||||
# For now, we'll use a placeholder
|
|
||||||
model = await self.load_model_for_training()
|
|
||||||
|
|
||||||
# Step 3: Run APOLLO-Mini training
|
|
||||||
logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
|
|
||||||
|
|
||||||
# Extract training samples
|
|
||||||
samples = training_data['samples']
|
|
||||||
job.training_samples = len(samples)
|
|
||||||
|
|
||||||
# Run training loop
|
|
||||||
loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
|
|
||||||
job.loss_history = loss_history
|
|
||||||
|
|
||||||
# Step 4: Save checkpoint
|
|
||||||
job.status = TrainingStatus.SAVING_CHECKPOINT
|
|
||||||
logger.info("Saving checkpoint...")
|
|
||||||
checkpoint_path = await self.save_checkpoint(model, job)
|
|
||||||
job.checkpoint_path = checkpoint_path
|
|
||||||
|
|
||||||
# Step 5: Resume vLLM
|
|
||||||
job.status = TrainingStatus.RESUMING_VLLM
|
|
||||||
logger.info("Resuming vLLM...")
|
|
||||||
await self.resume_vllm()
|
|
||||||
self.vllm_paused = False
|
|
||||||
|
|
||||||
# Mark job as completed
|
|
||||||
job.status = TrainingStatus.COMPLETED
|
|
||||||
job.completed_at = datetime.now()
|
|
||||||
|
|
||||||
logger.info(f"Training job {job.job_id} completed successfully")
|
|
||||||
|
|
||||||
# Schedule checkpoint sync (batched — won't duplicate if timer pending)
|
|
||||||
self.schedule_checkpoint_sync()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Training job {job.job_id} failed: {e}")
|
|
||||||
job.status = TrainingStatus.FAILED
|
|
||||||
job.error = str(e)
|
|
||||||
job.completed_at = datetime.now()
|
|
||||||
|
|
||||||
# Try to resume vLLM if it was paused
|
|
||||||
if self.vllm_paused:
|
|
||||||
try:
|
|
||||||
await self.resume_vllm()
|
|
||||||
self.vllm_paused = False
|
|
||||||
except Exception as resume_error:
|
|
||||||
logger.error(f"Failed to resume vLLM after training error: {resume_error}")
|
|
||||||
|
|
||||||
async def pause_vllm(self):
|
|
||||||
"""Pause vLLM inference via HTTP API."""
|
|
||||||
import aiohttp as aio
|
|
||||||
url = self.config.get('vllm_url', 'http://localhost:8000')
|
|
||||||
try:
|
|
||||||
async with aio.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/pause_generation",
|
|
||||||
json={"mode": "keep", "clear_cache": False},
|
|
||||||
timeout=aio.ClientTimeout(total=10),
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
logger.info("vLLM paused")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to pause vLLM: {e}")
|
|
||||||
|
|
||||||
async def resume_vllm(self):
|
|
||||||
"""Resume vLLM inference via HTTP API."""
|
|
||||||
import aiohttp as aio
|
|
||||||
url = self.config.get('vllm_url', 'http://localhost:8000')
|
|
||||||
try:
|
|
||||||
async with aio.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/resume_generation",
|
|
||||||
timeout=aio.ClientTimeout(total=10),
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
logger.info("vLLM resumed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to resume vLLM: {e}")
|
|
||||||
|
|
||||||
def schedule_checkpoint_sync(self):
|
|
||||||
"""Schedule a checkpoint sync in 10 minutes, if not already scheduled.
|
|
||||||
|
|
||||||
This batches multiple training runs into a single sync — the timer
|
|
||||||
resets only when no timer is pending.
|
|
||||||
"""
|
|
||||||
if self._checkpoint_timer is not None:
|
|
||||||
logger.debug("Checkpoint sync already scheduled, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
|
|
||||||
logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")
|
|
||||||
|
|
||||||
async def _checkpoint_sync_after_delay(self):
|
|
||||||
"""Wait then sync — the actual timer task."""
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
|
|
||||||
await self._do_checkpoint_sync()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.debug("Checkpoint sync cancelled")
|
|
||||||
finally:
|
|
||||||
self._checkpoint_timer = None
|
|
||||||
|
|
||||||
async def _do_checkpoint_sync(self):
|
|
||||||
"""Execute the checkpoint sync."""
|
|
||||||
try:
|
|
||||||
from apollo_plugin.checkpoint_sync import checkpoint_sync
|
|
||||||
logger.info("Starting checkpoint sync...")
|
|
||||||
result = checkpoint_sync(
|
|
||||||
self.config['model_path'],
|
|
||||||
self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
|
|
||||||
)
|
|
||||||
changed_mb = result['total_changed'] / 1e6
|
|
||||||
logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Checkpoint sync failed: {e}")
|
|
||||||
|
|
||||||
async def load_model_for_training(self) -> nn.Module:
|
|
||||||
"""Load HF model with weights pointing to vLLM's GPU memory.
|
|
||||||
|
|
||||||
Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
|
|
||||||
views (narrowing merged weights into separate q/k/v/z etc.), and
|
|
||||||
constructs the HF model around those views. No weight copying —
|
|
||||||
all parameters share vLLM's GPU memory.
|
|
||||||
"""
|
|
||||||
handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
|
|
||||||
model_path = self.config['model_path']
|
|
||||||
|
|
||||||
# Import vLLM weights via CUDA IPC
|
|
||||||
logger.info(f"Importing vLLM weights from {handle_path}")
|
|
||||||
handles = torch.load(handle_path, weights_only=False)
|
|
||||||
vllm_params = {}
|
|
||||||
for name, info in handles.items():
|
|
||||||
func, args = info['handle']
|
|
||||||
vllm_params[name] = func(*args)
|
|
||||||
logger.info(f"Imported {len(vllm_params)} parameters")
|
|
||||||
|
|
||||||
# Map vLLM merged layout → HF separate layout (views, no copies)
|
|
||||||
from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
|
|
||||||
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
|
|
||||||
logger.info("HF model constructed with vLLM weight views")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
async def run_apollo_training(self, model: nn.Module,
|
|
||||||
samples: List[Dict[str, Any]],
|
|
||||||
config: Dict[str, Any]) -> List[float]:
|
|
||||||
"""Run Apollo-Mini training on conversation decision points.
|
|
||||||
|
|
||||||
Each sample has:
|
|
||||||
context_ids: token IDs for frozen context (no gradients)
|
|
||||||
continuation_ids: token IDs for the decision we're training on
|
|
||||||
"""
|
|
||||||
from apollo_plugin.optimizer import Apollo
|
|
||||||
|
|
||||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
|
||||||
apollo_params, standard_params = [], []
|
|
||||||
for p in model.parameters():
|
|
||||||
if p.requires_grad:
|
|
||||||
if p.ndim >= 2 and min(p.shape) >= 2:
|
|
||||||
apollo_params.append(p)
|
|
||||||
else:
|
|
||||||
standard_params.append(p)
|
|
||||||
|
|
||||||
groups = []
|
|
||||||
if apollo_params:
|
|
||||||
groups.append({'params': apollo_params})
|
|
||||||
if standard_params:
|
|
||||||
groups.append({'params': standard_params})
|
|
||||||
|
|
||||||
# Apollo settings from request config, falling back to server defaults
|
|
||||||
optimizer = Apollo(
|
|
||||||
groups,
|
|
||||||
lr=config.get('lr', self.config.get('learning_rate', 1e-5)),
|
|
||||||
rank=config.get('rank', 256),
|
|
||||||
betas=tuple(config.get('betas', (0.9, 0.999))),
|
|
||||||
eps=config.get('eps', 1e-8),
|
|
||||||
weight_decay=config.get('weight_decay', 0.01),
|
|
||||||
warmup_steps=config.get('warmup_steps', 0),
|
|
||||||
scale=config.get('scale'), # None = auto
|
|
||||||
proj_refresh=config.get('proj_refresh', 200),
|
|
||||||
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
|
||||||
)
|
|
||||||
rank = config.get('rank', 256)
|
|
||||||
lr = config.get('lr', self.config.get('learning_rate', 1e-5))
|
|
||||||
logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
|
|
||||||
f"{len(standard_params)} standard, "
|
|
||||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
|
||||||
|
|
||||||
loss_history = []
|
|
||||||
|
|
||||||
for i, sample in enumerate(samples):
|
|
||||||
# context_ids: frozen (forward only, no gradients)
|
|
||||||
# continuation_ids: the decision we're training on
|
|
||||||
ctx_ids = sample['context_ids']
|
|
||||||
cont_ids = sample['continuation_ids']
|
|
||||||
all_ids = ctx_ids + cont_ids
|
|
||||||
context_len = len(ctx_ids)
|
|
||||||
|
|
||||||
input_ids = torch.tensor([all_ids], device='cuda:0')
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Context-frozen forward pass
|
|
||||||
with torch.no_grad():
|
|
||||||
# Forward through context (no gradients)
|
|
||||||
outputs = model(input_ids[:, :context_len], use_cache=True)
|
|
||||||
past_kv = outputs.past_key_values
|
|
||||||
|
|
||||||
# Decision tokens with gradients
|
|
||||||
with torch.enable_grad():
|
|
||||||
outputs = model(
|
|
||||||
input_ids[:, context_len:],
|
|
||||||
past_key_values=past_kv,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
logits = outputs.logits # [1, cont_len, vocab]
|
|
||||||
|
|
||||||
# Shift: predict next token from each position
|
|
||||||
shift_logits = logits[:, :-1].contiguous()
|
|
||||||
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
|
||||||
|
|
||||||
loss = nn.functional.cross_entropy(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1),
|
|
||||||
)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
loss_val = loss.item()
|
|
||||||
loss_history.append(loss_val)
|
|
||||||
logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
|
|
||||||
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
|
|
||||||
|
|
||||||
logger.info(f"Training done: {len(samples)} examples, "
|
|
||||||
f"final loss={loss_history[-1]:.4f}")
|
|
||||||
return loss_history
|
|
||||||
|
|
||||||
async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
|
|
||||||
"""Save model checkpoint in HuggingFace safetensors format."""
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
checkpoint_dir = Path(self.config['checkpoint_dir'])
|
|
||||||
date_str = datetime.now().strftime('%Y-%m-%d')
|
|
||||||
out_dir = checkpoint_dir / date_str
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Save weights
|
|
||||||
tensors = {name: p.data.contiguous().cpu()
|
|
||||||
for name, p in model.named_parameters()}
|
|
||||||
save_path = out_dir / "model.safetensors"
|
|
||||||
save_file(tensors, str(save_path))
|
|
||||||
|
|
||||||
# Copy config files
|
|
||||||
config_dir = Path(self.config['model_path'])
|
|
||||||
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
|
|
||||||
'special_tokens_map.json']:
|
|
||||||
src = config_dir / f
|
|
||||||
if src.exists():
|
|
||||||
shutil.copy2(src, out_dir / f)
|
|
||||||
|
|
||||||
# Save training metadata
|
|
||||||
meta = {
|
|
||||||
'job_id': job.job_id,
|
|
||||||
'training_samples': job.training_samples,
|
|
||||||
'loss_history': job.loss_history,
|
|
||||||
'timestamp': datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
with open(out_dir / 'training-meta.json', 'w') as f:
|
|
||||||
json.dump(meta, f, indent=2)
|
|
||||||
|
|
||||||
# Update latest symlink
|
|
||||||
latest = checkpoint_dir / 'latest'
|
|
||||||
if latest.is_symlink():
|
|
||||||
latest.unlink()
|
|
||||||
latest.symlink_to(date_str)
|
|
||||||
|
|
||||||
size_gb = save_path.stat().st_size / 1e9
|
|
||||||
logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
|
|
||||||
return str(out_dir)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Run the daemon."""
|
|
||||||
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
|
|
||||||
runner = web.AppRunner(self.app)
|
|
||||||
await runner.setup()
|
|
||||||
site = web.TCPSite(runner, self.config['host'], self.config['port'])
|
|
||||||
await site.start()
|
|
||||||
logger.info("Apollo Worker is running")
|
|
||||||
|
|
||||||
# Keep running
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(3600) # Sleep for an hour
|
|
||||||
|
|
||||||
def main():
|
|
||||||
worker = ApolloWorker()
|
|
||||||
asyncio.run(worker.run())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -20,7 +20,6 @@ dev = ["pytest"]
|
||||||
apollo = "apollo_plugin:register"
|
apollo = "apollo_plugin:register"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
apollo-worker = "apollo_plugin.worker:main"
|
|
||||||
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
|
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue