training: move to dedicated subprocess with ZMQ communication
- Add training_worker.py: long-lived subprocess that handles GPU training
work, owns HF model wrapper (views into vLLM GPU memory), Apollo
optimizer, and checkpoint sync
- train_router.py: now forwards /train requests via async ZMQ instead of
running training in-process. Adds /checkpoint and /train/status endpoints
- export_hook.py: store model_path in __metadata__ so training worker can
find it without cross-process communication
- This fixes two bugs:
1. Process boundary issue - model_path was set in worker process but
needed in API server process
2. Blocking event loop - training blocked vLLM's async event loop
Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
parent 68a2df2185 · commit 2c6a5c0f4a
6 changed files with 503 additions and 233 deletions

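The REQ/REP wire protocol between the /train endpoint and the worker is easiest to see from the client side. A minimal sketch, not project code: the message shapes are taken from the training_worker.py handlers added in this commit, and the token IDs are placeholders.

```python
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("ipc:///tmp/apollo_training.sock")

# 'train' request: one sample = frozen context IDs + continuation IDs
sock.send_json({
    "type": "train",
    "samples": [{"context_ids": [1, 2, 3], "continuation_ids": [4, 5]}],
    "config": {"lr": 1e-5},
})
# reply: {'status': 'completed', 'training_samples': 1, 'loss_history': [...]}
# or {'error': '...'} on failure
print(sock.recv_json())

# 'status' and 'checkpoint' follow the same lockstep send/recv pattern
sock.send_json({"type": "status"})
print(sock.recv_json())   # {'status': 'ready', 'model_loaded': True, ...}
```
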
@@ -26,25 +26,37 @@ The training signal comes from two sources:
 │ └──────────────┬──────────────┬────────────────┘   │
 │                │              │                    │
 │ ┌──────────────▼──┐   ┌───────▼────────────────┐   │
-│ │ vLLM (inference)│   │ HF model (training)    │   │
-│ │ KV cache ~60GB  │   │ Gradients ~54GB        │   │
-│ │ /completions    │   │ Optimizer state ~10GB  │   │
-│ │ /score          │   │ Views into vLLM weights│   │
-│ │ /train ─────────┼───┼─► Apollo optimizer     │   │
-│ └─────────────────┘   └────────────────────────┘   │
+│ │ vLLM (inference)│   │ Training subprocess    │   │
+│ │ KV cache ~60GB  │   │ HF model wrapper       │   │
+│ │ /completions    │   │ Apollo optimizer ~2.5GB│   │
+│ │ /score          │   │ Checkpoint sync        │   │
+│ └────────┬────────┘   └───────────▲────────────┘   │
+│          │                        │                │
+│          │        ZMQ IPC        │                 │
+│          └────────────────────────┘                │
 └──────────────────────────────────────────────────────┘

-Single vLLM process serves everything
-No separate daemon - /train is a vLLM route
+Process Architecture:
+
+┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐
+│ vLLM Worker     │  │ vLLM API Server │  │ Training Worker │
+│ (GPU inference) │  │ (HTTP routes)   │  │ (GPU training)  │
+│                 │  │                 │  │                 │
+│ export_hook.py  │  │ /completions    │  │ HF model views  │
+│ exports IPC     │  │ /score          │  │ Apollo optimizer│
+│ handles on load │  │ /train ─────────┼──► ZMQ REP socket │
+└─────────────────┘  └─────────────────┘  └─────────────────┘
+         │                                          │
+         └────────── IPC handles file ──────────────┘
+              /tmp/vllm_weight_handles.pt
+
   Moria                        B200 (vLLM)
 ┌──────────────────┐           ┌──────────────────┐
 │ Training signal  │   HTTP    │ /completions     │
 │ agent            │──────────>│ /score           │
 │                  │           │ /train           │
-│ Dream loop       │           │                  │
-│ (generates       │           │ Checkpoint sync  │
-│ scenarios)       │           │ (10 min batched) │
+│ Dream loop       │           │ /checkpoint      │
+│ (generates       │           │ /train/status    │
+│ scenarios)       │           │                  │
 └──────────────────┘           └──────────────────┘
 ```

@@ -213,8 +225,9 @@ a few hundred MB.

 | File | Purpose |
 |------|---------|
-| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. |
-| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. |
+| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). |
+| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. |
+| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. |
 | `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |

 ### Moria (client)
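The table's note that checkpoint_sync updates the safetensors in place is the key to why a sync moves a few hundred MB rather than the full model. A rough sketch of the diff side of that idea, using a hypothetical helper rather than the project's checkpoint_sync.py: compare current parameters against the on-disk shard and count only what changed.

```python
# Illustrative only: reports how many bytes differ between a model's
# parameters and one safetensors shard on disk. The real checkpoint_sync
# additionally rewrites just those changed tensors via mmap.
import torch
from safetensors import safe_open

def changed_bytes(model: torch.nn.Module, shard_path: str) -> int:
    params = dict(model.named_parameters())
    total = 0
    with safe_open(shard_path, framework="pt") as f:
        for name in f.keys():
            if name not in params:
                continue
            on_disk = f.get_tensor(name)
            if not torch.equal(on_disk, params[name].detach().cpu()):
                total += on_disk.nelement() * on_disk.element_size()
    return total
```
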
@@ -224,12 +237,13 @@ a few hundred MB.
 | `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
 | `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |

-### In-memory
+### In-memory (training_worker subprocess)

 | State | Location | Notes |
 |-------|----------|-------|
-| Apollo optimizer | train_router._optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. |
-| HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. |
+| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. |
+| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. |
+| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. |

 ## Hyperparameters

@@ -248,7 +262,8 @@

 ### Built ✓
 - `optimizer.py` — Apollo optimizer (configurable rank)
-- `train_router.py` — /train endpoint, runs in vLLM process
+- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ
+- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync)
 - `weight_mapping.py` — vLLM merged → HF separate views (validated)
 - `export_hook.py` — vLLM plugin hook for IPC handle export
 - `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)

@@ -267,8 +282,9 @@ training/
   pyproject.toml       — package config, vLLM plugin entry point
   apollo_plugin/
     __init__.py        — plugin registration
-    export_hook.py     — patches vLLM to export IPC handles
-    train_router.py    — /train endpoint (FastAPI router)
+    export_hook.py     — patches vLLM worker to export IPC handles
+    train_router.py    — /train endpoint, forwards to worker via ZMQ
+    training_worker.py — training subprocess (HF model, Apollo, checkpoint)
     optimizer.py       — Apollo optimizer
     weight_mapping.py  — vLLM ↔ HF weight views
     checkpoint_sync.py — mmap + diff sync to safetensors

@@ -260,6 +260,9 @@ def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
     """
     handles = torch.load(handles_path, weights_only=False)

+    # Skip metadata entry
+    handles.pop('__metadata__', None)
+
     weights = {}
     for name, info in handles.items():
         func, args = info['handle']
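The loader above, and the export hook below, lean on torch's CUDA IPC machinery: `reduce_tensor()` yields a `(rebuild_fn, args)` pair, and calling `rebuild_fn(*args)` in another process produces a tensor that aliases the same GPU memory rather than copying it. A self-contained sketch of the round trip; the file path is hypothetical, and it needs Linux plus a CUDA device.

```python
import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor

def consumer(path):
    handles = torch.load(path, weights_only=False)
    func, args = handles["w"]["handle"]
    view = func(*args)      # aliases the exporter's GPU memory
    view.fill_(1.0)         # visible to the exporter immediately

if __name__ == "__main__":
    mp.set_start_method("spawn")
    param = torch.zeros(8, device="cuda:0")
    # same structure the export hook writes: {name: {'handle': (fn, args)}}
    torch.save({"w": {"handle": reduce_tensor(param)}}, "/tmp/demo_handles.pt")
    p = mp.Process(target=consumer, args=("/tmp/demo_handles.pt",))
    p.start()
    p.join()
    print(param)            # all ones: the child wrote through the view
```
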
training/apollo_plugin/export_hook.py

@@ -20,7 +20,7 @@ from pathlib import Path
 HANDLE_PATH = "/tmp/vllm_weight_handles.pt"


-def export_model_weights(model):
+def export_model_weights(model, model_path: str | None = None):
     """Export CUDA IPC handles for all model parameters."""
     from torch.multiprocessing.reductions import reduce_tensor

@@ -38,6 +38,12 @@ def export_model_weights(model):
         }
         total_bytes += param.nelement() * param.element_size()

+    # Include metadata for training worker
+    handles['__metadata__'] = {
+        'model_path': model_path,
+        'num_params': len(handles),
+    }
+
     torch.save(handles, HANDLE_PATH)
     print(f"[apollo] Exported {len(handles)} weight handles "
           f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")

@@ -58,11 +64,8 @@ def _patch_model_runner():
     def patched_load(self, *args, **kwargs):
         result = original_load(self, *args, **kwargs)
         try:
-            export_model_weights(self.model_runner.model)
-            # Set model path for training router
-            model_path = self.vllm_config.model_config.model
-            from .train_router import set_model_path
-            set_model_path(model_path)
+            model_path = self.vllm_config.model_config.model
+            export_model_weights(self.model_runner.model, model_path)
         except Exception as e:
             print(f"[apollo] Failed to export weights: {e}")
         return result

training/apollo_plugin/train_router.py

@@ -1,16 +1,23 @@
-"""Training endpoint for vLLM - runs Apollo training in-process.
+"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.

-Patches vLLM's build_app() to add /train route. Training runs HOGWILD
-style - no pause needed, weights updated in-place while inference continues.
+Patches vLLM's build_app() to add /train route. The actual training runs
+in a dedicated subprocess (training_worker.py) to avoid blocking the
+event loop and to keep training work isolated from vLLM internals.
 """

+import asyncio
 import logging
+import os
+import subprocess
+import sys
 from datetime import datetime
+from pathlib import Path
 from typing import Any

-import torch
-import torch.nn as nn
-from fastapi import APIRouter, FastAPI, Request
+import zmq
+import zmq.asyncio
+
+from fastapi import APIRouter, FastAPI
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel

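The docstring's point about the event loop is worth making concrete. A toy sketch, unrelated to the project code: any long synchronous call inside an async route starves every other coroutine on the loop, while awaiting (as the ZMQ forward below does) yields control back.

```python
import asyncio
import time

async def blocking_handler():
    time.sleep(5)            # stands in for a synchronous GPU training loop

async def forwarding_handler():
    await asyncio.sleep(5)   # stands in for awaiting the worker's ZMQ reply

async def heartbeat():
    for _ in range(3):
        print("inference still responsive")
        await asyncio.sleep(1)

async def main():
    # With blocking_handler here instead, heartbeat prints nothing for 5s,
    # which is exactly how in-process training stalled vLLM's /completions.
    await asyncio.gather(forwarding_handler(), heartbeat())

asyncio.run(main())
```
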
@@ -18,10 +25,13 @@ logger = logging.getLogger(__name__)

 router = APIRouter()

+DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
+
-class TrainingSample(BaseModel):
-    context_ids: list[int]
-    continuation_ids: list[int]
+# Global state for subprocess management
+_worker_process: subprocess.Popen | None = None
+_zmq_context: zmq.asyncio.Context | None = None
+_zmq_socket: zmq.asyncio.Socket | None = None
+_initialized: bool = False


 class TrainRequest(BaseModel):

@@ -35,64 +45,61 @@ class TrainResponse(BaseModel):
     loss_history: list[float]


-# Global reference to HF model with vLLM weight views
-_model: nn.Module | None = None
-_model_path: str | None = None
-_initialized: bool = False
-_optimizer: Any = None  # Persisted Apollo optimizer
-
-OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
-DEFAULT_RANK = 64
+def _start_worker_subprocess():
+    """Start the training worker subprocess."""
+    global _worker_process
+
+    if _worker_process is not None and _worker_process.poll() is None:
+        return  # Still running
+
+    # Start worker as subprocess using script path
+    worker_script = Path(__file__).parent / 'training_worker.py'
+    _worker_process = subprocess.Popen(
+        [sys.executable, str(worker_script)],
+        env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
+    )
+    logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")

-def _load_training_model() -> nn.Module:
-    """Load HF model with weights pointing to vLLM's GPU memory.
-
-    Uses CUDA IPC handles exported by export_hook to create an HF model
-    whose parameters share GPU memory with vLLM's model.
-    """
-    from .weight_mapping import load_hf_model_with_vllm_weights
-    from .export_hook import HANDLE_PATH
-
-    handles = torch.load(HANDLE_PATH, weights_only=False)
-    vllm_params = {}
-    for name, info in handles.items():
-        func, args = info['handle']
-        vllm_params[name] = func(*args)
-
-    model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
-    model.train()
-    return model
+    # Give it a moment to bind the socket
+    import time
+    time.sleep(0.5)


 def _ensure_initialized():
-    """Lazy-initialize the training model on first /train request."""
-    global _model, _initialized
+    """Ensure subprocess is running and ZMQ socket is connected."""
+    global _zmq_context, _zmq_socket, _initialized

     if _initialized:
         return

-    if _model_path is None:
-        raise RuntimeError("Model path not set - export_hook may not have run")
+    # Start worker if needed
+    _start_worker_subprocess()
+
+    # Create async ZMQ context and socket
+    _zmq_context = zmq.asyncio.Context()
+    _zmq_socket = _zmq_context.socket(zmq.REQ)
+    _zmq_socket.connect(DEFAULT_ZMQ_ADDR)
+
+    # Set timeout for recv
+    _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000)  # 5 minute timeout for training

-    logger.info("[apollo] Loading HF model with vLLM weight views...")
-    _model = _load_training_model()
     _initialized = True
-    logger.info("[apollo] Training model ready")
+    logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")


-def set_model_path(path: str):
-    """Set model path for training. Called by export_hook after model load."""
-    global _model_path
-    _model_path = path
-    logger.info(f"[apollo] Model path set: {path}")
+async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
+    """Send request to worker and wait for response."""
+    _ensure_initialized()
+
+    # ZMQ async send/recv
+    await _zmq_socket.send_json(request)
+    response = await _zmq_socket.recv_json()
+    return response


 @router.post("/train")
-async def handle_train(request: TrainRequest, raw_request: Request):
-    """Handle training request - runs Apollo training on provided samples."""
-    global _model
-
+async def handle_train(request: TrainRequest):
+    """Handle training request - forwards to training subprocess."""
     try:
         _ensure_initialized()
     except Exception as e:

@@ -113,193 +120,109 @@ async def handle_train(request: TrainRequest, raw_request: Request):
         )

         job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-        logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
+        logger.info(f"Starting training job {job_id} with {len(samples)} samples")

-        # Run training
-        loss_history = await run_training(_model, samples, config)
+        # Forward to worker
+        response = await _send_request({
+            'type': 'train',
+            'samples': samples,
+            'config': config,
+        })

-        logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
+        if 'error' in response:
+            return JSONResponse(
+                content={"error": response['error']},
+                status_code=500,
+            )

-        # Schedule checkpoint sync (batched, 10 min delay)
-        schedule_checkpoint_sync()
+        logger.info(
+            f"Training job {job_id} completed, "
+            f"final loss: {response['loss_history'][-1]:.4f}"
+        )

         return JSONResponse(content={
             "job_id": job_id,
-            "status": "completed",
-            "training_samples": len(samples),
-            "loss_history": loss_history,
+            "status": response['status'],
+            "training_samples": response['training_samples'],
+            "loss_history": response['loss_history'],
         })

+    except zmq.Again:
+        logger.error("Training request timed out")
+        return JSONResponse(
+            content={"error": "Training request timed out"},
+            status_code=504,
+        )
     except Exception as e:
-        logger.exception(f"[apollo] Training failed: {e}")
+        logger.exception(f"Training failed: {e}")
         return JSONResponse(
             content={"error": str(e)},
             status_code=500,
         )


-def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
-    """Get existing optimizer or create new one. Persists state between calls."""
-    global _optimizer
-    from .optimizer import Apollo
-    import os
-
-    if _optimizer is not None:
-        return _optimizer
-
-    # Build parameter groups (Apollo for 2D+, standard for small/1D)
-    apollo_params, standard_params = [], []
-    for p in model.parameters():
-        if p.requires_grad:
-            if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
-                apollo_params.append(p)
-            else:
-                standard_params.append(p)
-
-    groups = []
-    if apollo_params:
-        groups.append({'params': apollo_params})
-    if standard_params:
-        groups.append({'params': standard_params})
-
-    if not groups:
-        raise ValueError("No trainable parameters found")
-
-    # Create optimizer
-    _optimizer = Apollo(
-        groups,
-        lr=config.get('lr', 1e-5),
-        rank=config.get('rank', DEFAULT_RANK),
-        betas=tuple(config.get('betas', (0.9, 0.999))),
-        eps=config.get('eps', 1e-8),
-        weight_decay=config.get('weight_decay', 0.01),
-        warmup_steps=config.get('warmup_steps', 0),
-        scale=config.get('scale'),
-        proj_refresh=config.get('proj_refresh', 200),
-        norm_growth_limit=config.get('norm_growth_limit', 1.01),
-    )
-
-    # Restore state if exists
-    if os.path.exists(OPTIMIZER_STATE_PATH):
-        try:
-            state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
-            _optimizer.load_state_dict(state)
-            logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
-        except Exception as e:
-            logger.warning(f"[apollo] Could not restore optimizer state: {e}")
-
-    logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
-                f"{len(standard_params)} standard, "
-                f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
-
-    return _optimizer
-
-
-def _save_optimizer_state():
-    """Save optimizer state for persistence between /train calls."""
-    global _optimizer
-    if _optimizer is not None:
-        torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH)
-        logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
-
-
-async def run_training(
-    model: nn.Module,
-    samples: list[dict[str, Any]],
-    config: dict[str, Any],
-) -> list[float]:
-    """Run Apollo training on the given samples.
-
-    Each sample has:
-        context_ids: token IDs for frozen context (no gradients)
-        continuation_ids: token IDs for the decision we're training on
-    """
-    optimizer = _get_or_create_optimizer(model, config)
-
-    loss_history = []
-
-    for i, sample in enumerate(samples):
-        ctx_ids = sample['context_ids']
-        cont_ids = sample['continuation_ids']
-        all_ids = ctx_ids + cont_ids
-        context_len = len(ctx_ids)
-
-        input_ids = torch.tensor([all_ids], device='cuda:0')
-
-        optimizer.zero_grad()
-
-        # Context-frozen forward pass
-        with torch.no_grad():
-            outputs = model(input_ids[:, :context_len], use_cache=True)
-            past_kv = outputs.past_key_values
-
-        # Decision tokens with gradients
-        with torch.enable_grad():
-            outputs = model(
-                input_ids[:, context_len:],
-                past_key_values=past_kv,
-                use_cache=False,
-            )
-            logits = outputs.logits
-
-        # Shift: predict next token from each position
-        shift_logits = logits[:, :-1].contiguous()
-        shift_labels = input_ids[:, context_len + 1:].contiguous()
-
-        loss = nn.functional.cross_entropy(
-            shift_logits.view(-1, shift_logits.size(-1)),
-            shift_labels.view(-1),
-        )
-
-        loss.backward()
-        optimizer.step()
-
-        loss_val = loss.item()
-        loss_history.append(loss_val)
-        logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
-                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
-
-    return loss_history
-
-
-# Checkpoint sync scheduling
-_checkpoint_task = None
-CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
-
-
-def schedule_checkpoint_sync():
-    """Schedule checkpoint sync after delay (batched)."""
-    global _checkpoint_task
-    import asyncio
-
-    if _checkpoint_task is not None:
-        # Already scheduled
-        return
-
-    async def do_sync():
-        global _checkpoint_task
-        try:
-            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
-            if _model_path:
-                from .checkpoint_sync import checkpoint_sync
-                logger.info("[apollo] Starting checkpoint sync...")
-                # Save optimizer state alongside model weights
-                _save_optimizer_state()
-                result = checkpoint_sync(_model_path)
-                logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
-        except Exception as e:
-            logger.error(f"[apollo] Checkpoint sync failed: {e}")
-        finally:
-            _checkpoint_task = None
-
-    _checkpoint_task = asyncio.create_task(do_sync())
-    logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
+@router.post("/checkpoint")
+async def handle_checkpoint():
+    """Trigger checkpoint sync to disk."""
+    try:
+        _ensure_initialized()
+    except Exception as e:
+        return JSONResponse(
+            content={"error": f"Training not available: {e}"},
+            status_code=503,
+        )
+
+    try:
+        response = await _send_request({'type': 'checkpoint'})
+
+        if 'error' in response:
+            return JSONResponse(
+                content={"error": response['error']},
+                status_code=500,
+            )
+
+        return JSONResponse(content=response)
+
+    except Exception as e:
+        logger.exception(f"Checkpoint failed: {e}")
+        return JSONResponse(
+            content={"error": str(e)},
+            status_code=500,
+        )
+
+
+@router.get("/train/status")
+async def handle_status():
+    """Get training worker status."""
+    try:
+        _ensure_initialized()
+    except Exception as e:
+        return JSONResponse(
+            content={
+                "status": "unavailable",
+                "error": str(e),
+            },
+            status_code=503,
+        )
+
+    try:
+        response = await _send_request({'type': 'status'})
+        return JSONResponse(content=response)
+
+    except Exception as e:
+        return JSONResponse(
+            content={
+                "status": "error",
+                "error": str(e),
+            },
+            status_code=500,
+        )


 def attach_router(app: FastAPI):
     """Attach training router to FastAPI app."""
     app.include_router(router)
-    logger.info("[apollo] Training router attached")
+    logger.info("Training router attached")


 def _patch_api_server():

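From the client side, the three endpoints look like this. A hedged usage sketch: the exact TrainRequest fields are elided in the hunk, so the `samples`/`config` payload shape is inferred from the forwarding code, and the server address plus the use of `requests` (not a project dependency) are illustrative.

```python
import requests

BASE = "http://localhost:8000"   # assumed vLLM server address

r = requests.post(f"{BASE}/train", json={
    "samples": [{"context_ids": [1, 2, 3], "continuation_ids": [4, 5]}],
    "config": {"lr": 1e-5, "rank": 64},
})
print(r.json())   # {'job_id': ..., 'status': 'completed', 'loss_history': [...]}

print(requests.get(f"{BASE}/train/status").json())
print(requests.post(f"{BASE}/checkpoint").json())
```
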
@@ -314,4 +237,4 @@ def _patch_api_server():
         return app

     api_server.build_app = patched_build_app
-    logger.info("[apollo] API server patched for /train endpoint")
+    logger.info("API server patched for /train endpoint")

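Only the tail of `_patch_api_server()` is shown above. Its full shape is most likely the standard wrap-and-replace pattern sketched here; the vLLM import path and the wrapped signature are assumptions, not confirmed by the hunk.

```python
# Sketch, under stated assumptions: vllm.entrypoints.openai.api_server
# exposes build_app(...) -> FastAPI, which we wrap to bolt on our routes.
from vllm.entrypoints.openai import api_server

def _patch_api_server_sketch():
    original_build_app = api_server.build_app

    def patched_build_app(*args, **kwargs):
        app = original_build_app(*args, **kwargs)  # vLLM builds its app
        attach_router(app)   # attach_router() from train_router.py above
        return app

    api_server.build_app = patched_build_app
```
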
training/apollo_plugin/training_worker.py — new file (323 lines)

@@ -0,0 +1,323 @@
+"""Training subprocess - handles Apollo training and checkpoint sync.
+
+Long-lived process that:
+1. Loads IPC handles from vLLM's exported weights
+2. Creates HF model with views into vLLM's GPU memory
+3. Handles training requests via ZMQ
+4. Handles checkpoint sync requests
+5. Persists Apollo optimizer state between calls
+
+Communicates with the API server's /train endpoint via ZMQ REP socket.
+"""
+
+import logging
+import os
+import signal
+import sys
+from pathlib import Path
+from typing import Any
+
+# Handle running as script vs module
+if __name__ == '__main__' and __package__ is None:
+    # Running as script - add parent to path for imports
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    __package__ = 'apollo_plugin'
+
+import torch
+import torch.nn as nn
+import zmq
+
+from .checkpoint_sync import checkpoint_sync
+from .optimizer import Apollo
+from .weight_mapping import load_hf_model_with_vllm_weights
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_RANK = 64
+DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
+HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
+OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
+
+
+class TrainingWorker:
+    """Long-lived training worker process."""
+
+    def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
+        self.zmq_addr = zmq_addr
+        self.model: nn.Module | None = None
+        self.optimizer: Apollo | None = None
+        self.model_path: str | None = None
+        self._running = True
+
+    def _create_model_wrapper(self) -> nn.Module:
+        """Create HF model wrapper with views into vLLM's GPU memory."""
+        if not os.path.exists(HANDLE_PATH):
+            raise FileNotFoundError(
+                f"Weight handles not found: {HANDLE_PATH}. "
+                "Is vLLM running with the export hook?"
+            )
+
+        handles = torch.load(HANDLE_PATH, weights_only=False)
+
+        # Extract metadata
+        metadata = handles.pop('__metadata__', {})
+        self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
+        if not self.model_path:
+            raise ValueError(
+                "Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
+            )
+
+        # Reconstruct tensors from IPC handles
+        vllm_params = {}
+        for name, info in handles.items():
+            func, args = info['handle']
+            vllm_params[name] = func(*args)
+
+        model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
+        model.train()
+        return model
+
+    def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
+        """Get existing optimizer or create new one."""
+        if self.optimizer is not None:
+            return self.optimizer
+
+        # Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
+        apollo_params, standard_params = [], []
+        for p in self.model.parameters():
+            if p.requires_grad:
+                if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
+                    apollo_params.append(p)
+                else:
+                    standard_params.append(p)
+
+        groups = []
+        if apollo_params:
+            groups.append({'params': apollo_params})
+        if standard_params:
+            groups.append({'params': standard_params})
+
+        if not groups:
+            raise ValueError("No trainable parameters found")
+
+        self.optimizer = Apollo(
+            groups,
+            lr=config.get('lr', 1e-5),
+            rank=config.get('rank', DEFAULT_RANK),
+            betas=tuple(config.get('betas', (0.9, 0.999))),
+            eps=config.get('eps', 1e-8),
+            weight_decay=config.get('weight_decay', 0.01),
+            warmup_steps=config.get('warmup_steps', 0),
+            scale=config.get('scale'),
+            proj_refresh=config.get('proj_refresh', 200),
+            norm_growth_limit=config.get('norm_growth_limit', 1.01),
+        )
+
+        # Restore state if exists
+        if os.path.exists(OPTIMIZER_STATE_PATH):
+            try:
+                state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
+                self.optimizer.load_state_dict(state)
+                logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
+            except Exception as e:
+                logger.warning(f"Could not restore optimizer state: {e}")
+
+        logger.info(
+            f"Optimizer: {len(apollo_params)} apollo params, "
+            f"{len(standard_params)} standard, "
+            f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
+        )
+
+        return self.optimizer
+
+    def _save_optimizer_state(self):
+        """Save optimizer state for persistence."""
+        if self.optimizer is not None:
+            torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH)
+            logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")
+
+    def _run_training(
+        self,
+        samples: list[dict[str, Any]],
+        config: dict[str, Any],
+    ) -> list[float]:
+        """Run Apollo training on the given samples."""
+        optimizer = self._get_or_create_optimizer(config)
+
+        loss_history = []
+
+        for i, sample in enumerate(samples):
+            ctx_ids = sample['context_ids']
+            cont_ids = sample['continuation_ids']
+            all_ids = ctx_ids + cont_ids
+            context_len = len(ctx_ids)
+
+            input_ids = torch.tensor([all_ids], device='cuda:0')
+
+            optimizer.zero_grad()
+
+            # Context-frozen forward pass
+            with torch.no_grad():
+                outputs = self.model(input_ids[:, :context_len], use_cache=True)
+                past_kv = outputs.past_key_values
+
+            # Decision tokens with gradients
+            with torch.enable_grad():
+                outputs = self.model(
+                    input_ids[:, context_len:],
+                    past_key_values=past_kv,
+                    use_cache=False,
+                )
+                logits = outputs.logits
+
+            # Shift: predict next token from each position
+            shift_logits = logits[:, :-1].contiguous()
+            shift_labels = input_ids[:, context_len + 1:].contiguous()
+
+            loss = nn.functional.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1),
+            )
+
+            loss.backward()
+            optimizer.step()
+
+            loss_val = loss.item()
+            loss_history.append(loss_val)
+            logger.info(
+                f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
+                f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
+            )
+
+        return loss_history
+
+    def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
+        """Handle a training request."""
+        samples = request.get('samples', [])
+        config = request.get('config', {})
+
+        if not samples:
+            return {'error': 'No training samples provided'}
+
+        try:
+            loss_history = self._run_training(samples, config)
+            return {
+                'status': 'completed',
+                'training_samples': len(samples),
+                'loss_history': loss_history,
+            }
+        except Exception as e:
+            logger.exception(f"Training failed: {e}")
+            return {'error': str(e)}
+
+    def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
+        """Handle a checkpoint sync request."""
+        if not self.model_path:
+            return {'error': 'Model path not set'}
+
+        try:
+            self._save_optimizer_state()
+            result = checkpoint_sync(self.model_path)
+            return {
+                'status': 'completed',
+                'total_changed': result['total_changed'],
+                'files_changed': result['files_changed'],
+            }
+        except Exception as e:
+            logger.exception(f"Checkpoint sync failed: {e}")
+            return {'error': str(e)}
+
+    def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
+        """Handle a status request."""
+        return {
+            'status': 'ready',
+            'model_loaded': self.model is not None,
+            'optimizer_loaded': self.optimizer is not None,
+            'model_path': self.model_path,
+            'optimizer_state_mb': (
+                self.optimizer.state_size_bytes() / 1e6
+                if self.optimizer else 0
+            ),
+        }
+
+    def run(self):
+        """Main loop - listen for requests and handle them."""
+        # Set up signal handlers
+        def handle_signal(signum, frame):
+            logger.info(f"Received signal {signum}, shutting down...")
+            self._running = False
+
+        signal.signal(signal.SIGTERM, handle_signal)
+        signal.signal(signal.SIGINT, handle_signal)
+
+        # Set up ZMQ socket first so API server can connect
+        context = zmq.Context()
+        socket = context.socket(zmq.REP)
+        socket.bind(self.zmq_addr)
+        logger.info(f"Training worker listening on {self.zmq_addr}")
+
+        # Create HF model wrapper with views into vLLM's GPU memory
+        logger.info("Connecting to vLLM weights via IPC handles...")
+        try:
+            self.model = self._create_model_wrapper()
+            logger.info("HF model wrapper ready (views into vLLM GPU memory)")
+        except Exception as e:
+            logger.error(f"Failed to connect to vLLM weights: {e}")
+            logger.info("Will retry on first training request")
+
+        # Set socket timeout so we can check _running flag
+        socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout
+
+        while self._running:
+            try:
+                message = socket.recv_json()
+            except zmq.Again:
+                # Timeout, check _running and continue
+                continue
+
+            request_type = message.get('type', 'train')
+            logger.info(f"Received {request_type} request")
+
+            # Ensure model is loaded
+            if self.model is None and request_type != 'status':
+                try:
+                    self.model = self._create_model_wrapper()
+                except Exception as e:
+                    socket.send_json({'error': f'Model not loaded: {e}'})
+                    continue
+
+            # Dispatch request
+            if request_type == 'train':
+                response = self._handle_train(message)
+            elif request_type == 'checkpoint':
+                response = self._handle_checkpoint(message)
+            elif request_type == 'status':
+                response = self._handle_status(message)
+            else:
+                response = {'error': f'Unknown request type: {request_type}'}
+
+            socket.send_json(response)
+
+        # Cleanup
+        logger.info("Saving optimizer state before shutdown...")
+        self._save_optimizer_state()
+        socket.close()
+        context.term()
+        logger.info("Training worker shut down")
+
+
+def main():
+    """Entry point for running as a subprocess."""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
+        datefmt='%H:%M:%S',
+    )
+
+    zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
+    worker = TrainingWorker(zmq_addr)
+    worker.run()
+
+
+if __name__ == '__main__':
+    main()

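The heart of the worker is the context-frozen step in `_run_training()`: the context gets one gradient-free forward pass that only fills the KV cache, and gradients flow only through the continuation tokens. The same technique as a standalone sketch on a small public model (the model choice is illustrative, not the project's, and it downloads from the HF hub):

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.train()

ctx_ids = tok("The agent decided to", return_tensors="pt").input_ids
cont_ids = tok(" write it down", return_tensors="pt").input_ids
input_ids = torch.cat([ctx_ids, cont_ids], dim=1)
context_len = ctx_ids.size(1)

# 1. Context forward pass without gradients, keeping the KV cache
with torch.no_grad():
    past_kv = model(input_ids[:, :context_len], use_cache=True).past_key_values

# 2. Continuation forward pass with gradients, reusing the cache
logits = model(input_ids[:, context_len:], past_key_values=past_kv,
               use_cache=False).logits

# 3. Shifted cross-entropy: each position predicts the next token
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                       shift_labels.view(-1))
loss.backward()   # gradients exist only for the continuation positions
```
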
training/pyproject.toml

@@ -11,6 +11,7 @@ dependencies = [
     "torch",
     "aiohttp",
     "safetensors",
+    "pyzmq",
 ]

 [project.optional-dependencies]

@@ -21,6 +22,7 @@ apollo = "apollo_plugin:register"

 [project.scripts]
 apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
+apollo-worker = "apollo_plugin.training_worker:main"

 [tool.setuptools.packages.find]
 where = ["."]