training: move to dedicated subprocess with ZMQ communication

- Add training_worker.py: long-lived subprocess that handles GPU training
  work, owns HF model wrapper (views into vLLM GPU memory), Apollo
  optimizer, and checkpoint sync

- train_router.py: now forwards /train requests via async ZMQ instead of
  running training in-process. Adds /checkpoint and /train/status endpoints

- export_hook.py: store model_path in __metadata__ so training worker can
  find it without cross-process communication

- This fixes two bugs:
  1. Process boundary issue - model_path was set in worker process but
     needed in API server process
  2. Blocking event loop - training blocked vLLM's async event loop

Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
ProofOfConcept 2026-04-16 02:01:59 -04:00 committed by Kent Overstreet
parent 68a2df2185
commit 2c6a5c0f4a
6 changed files with 503 additions and 233 deletions

View file

@ -26,25 +26,37 @@ The training signal comes from two sources:
│ └──────────────┬──────────────┬────────────────┘ │
│ │ │ │
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
│ │ vLLM (inference)│ │ HF model (training) │ │
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
│ │ /completions │ │ Optimizer state ~10GB │ │
│ │ /score │ │ Views into vLLM weights │ │
│ │ /train ────────┼──┼─► Apollo optimizer │ │
│ └─────────────────┘ └────────────────────────┘ │
│ │ vLLM (inference)│ │ Training subprocess │ │
│ │ KV cache ~60GB │ │ HF model wrapper │ │
│ │ /completions │ │ Apollo optimizer ~2.5GB │ │
│ │ /score │ │ Checkpoint sync │ │
│ └────────┬────────┘ └───────────▲─────────────┘ │
│ │ │ │
│ │ ZMQ IPC │ │
│ └───────────────────────┘ │
└─────────────────────────────────────────────────────┘
Single vLLM process serves everything
No separate daemon - /train is a vLLM route
Process Architecture:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │
│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │
│ │ │ │ │ │
│ export_hook.py │ │ /completions │ │ HF model views │
│ exports IPC │ │ /score │ │ Apollo optimizer│
│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
└──── IPC handles file ──────────────────┘
/tmp/vllm_weight_handles.pt
Moria B200 (vLLM)
┌──────────────────┐ ┌──────────────────┐
│ Training signal │ HTTP │ /completions │
│ agent │──────────>│ /score │
│ │ │ /train │
│ Dream loop │ │ │
│ (generates │ │ Checkpoint sync │
│ scenarios) │ │ (10 min batched) │
│ Dream loop │ │ /checkpoint
│ (generates │ │ /train/status
│ scenarios) │ │
└──────────────────┘ └──────────────────┘
```
@ -213,8 +225,9 @@ a few hundred MB.
| File | Purpose |
|------|---------|
| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. |
| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. |
| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). |
| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. |
| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. |
| `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
### Moria (client)
@ -224,12 +237,13 @@ a few hundred MB.
| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
### In-memory
### In-memory (training_worker subprocess)
| State | Location | Notes |
|-------|----------|-------|
| Apollo optimizer | train_router._optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. |
| HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. |
| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. |
| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. |
| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. |
## Hyperparameters
@ -248,7 +262,8 @@ a few hundred MB.
### Built ✓
- `optimizer.py` — Apollo optimizer (configurable rank)
- `train_router.py` — /train endpoint, runs in vLLM process
- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ
- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync)
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
- `export_hook.py` — vLLM plugin hook for IPC handle export
- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
@ -267,8 +282,9 @@ training/
pyproject.toml — package config, vLLM plugin entry point
apollo_plugin/
__init__.py — plugin registration
export_hook.py — patches vLLM to export IPC handles
train_router.py — /train endpoint (FastAPI router)
export_hook.py — patches vLLM worker to export IPC handles
train_router.py — /train endpoint, forwards to worker via ZMQ
training_worker.py — training subprocess (HF model, Apollo, checkpoint)
optimizer.py — Apollo optimizer
weight_mapping.py — vLLM ↔ HF weight views
checkpoint_sync.py — mmap + diff sync to safetensors

View file

@ -260,6 +260,9 @@ def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
"""
handles = torch.load(handles_path, weights_only=False)
# Skip metadata entry
handles.pop('__metadata__', None)
weights = {}
for name, info in handles.items():
func, args = info['handle']

View file

@ -20,7 +20,7 @@ from pathlib import Path
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
def export_model_weights(model):
def export_model_weights(model, model_path: str | None = None):
"""Export CUDA IPC handles for all model parameters."""
from torch.multiprocessing.reductions import reduce_tensor
@ -38,6 +38,12 @@ def export_model_weights(model):
}
total_bytes += param.nelement() * param.element_size()
# Include metadata for training worker
handles['__metadata__'] = {
'model_path': model_path,
'num_params': len(handles),
}
torch.save(handles, HANDLE_PATH)
print(f"[apollo] Exported {len(handles)} weight handles "
f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
@ -58,11 +64,8 @@ def _patch_model_runner():
def patched_load(self, *args, **kwargs):
result = original_load(self, *args, **kwargs)
try:
export_model_weights(self.model_runner.model)
# Set model path for training router
model_path = self.vllm_config.model_config.model
from .train_router import set_model_path
set_model_path(model_path)
export_model_weights(self.model_runner.model, model_path)
except Exception as e:
print(f"[apollo] Failed to export weights: {e}")
return result

View file

@ -1,16 +1,23 @@
"""Training endpoint for vLLM - runs Apollo training in-process.
"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.
Patches vLLM's build_app() to add /train route. Training runs HOGWILD
style - no pause needed, weights updated in-place while inference continues.
Patches vLLM's build_app() to add /train route. The actual training runs
in a dedicated subprocess (training_worker.py) to avoid blocking the
event loop and to keep training work isolated from vLLM internals.
"""
import asyncio
import logging
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
from fastapi import APIRouter, FastAPI, Request
import zmq
import zmq.asyncio
from fastapi import APIRouter, FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
@ -18,10 +25,13 @@ logger = logging.getLogger(__name__)
router = APIRouter()
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
class TrainingSample(BaseModel):
context_ids: list[int]
continuation_ids: list[int]
# Global state for subprocess management
_worker_process: subprocess.Popen | None = None
_zmq_context: zmq.asyncio.Context | None = None
_zmq_socket: zmq.asyncio.Socket | None = None
_initialized: bool = False
class TrainRequest(BaseModel):
@ -35,64 +45,61 @@ class TrainResponse(BaseModel):
loss_history: list[float]
# Global reference to HF model with vLLM weight views
_model: nn.Module | None = None
_model_path: str | None = None
_initialized: bool = False
_optimizer: Any = None # Persisted Apollo optimizer
def _start_worker_subprocess():
"""Start the training worker subprocess."""
global _worker_process
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
DEFAULT_RANK = 64
if _worker_process is not None and _worker_process.poll() is None:
return # Still running
# Start worker as subprocess using script path
worker_script = Path(__file__).parent / 'training_worker.py'
_worker_process = subprocess.Popen(
[sys.executable, str(worker_script)],
env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
)
logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")
def _load_training_model() -> nn.Module:
"""Load HF model with weights pointing to vLLM's GPU memory.
Uses CUDA IPC handles exported by export_hook to create an HF model
whose parameters share GPU memory with vLLM's model.
"""
from .weight_mapping import load_hf_model_with_vllm_weights
from .export_hook import HANDLE_PATH
handles = torch.load(HANDLE_PATH, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
model.train()
return model
# Give it a moment to bind the socket
import time
time.sleep(0.5)
def _ensure_initialized():
"""Lazy-initialize the training model on first /train request."""
global _model, _initialized
"""Ensure subprocess is running and ZMQ socket is connected."""
global _zmq_context, _zmq_socket, _initialized
if _initialized:
return
if _model_path is None:
raise RuntimeError("Model path not set - export_hook may not have run")
# Start worker if needed
_start_worker_subprocess()
# Create async ZMQ context and socket
_zmq_context = zmq.asyncio.Context()
_zmq_socket = _zmq_context.socket(zmq.REQ)
_zmq_socket.connect(DEFAULT_ZMQ_ADDR)
# Set timeout for recv
_zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training
logger.info("[apollo] Loading HF model with vLLM weight views...")
_model = _load_training_model()
_initialized = True
logger.info("[apollo] Training model ready")
logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
def set_model_path(path: str):
"""Set model path for training. Called by export_hook after model load."""
global _model_path
_model_path = path
logger.info(f"[apollo] Model path set: {path}")
async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
"""Send request to worker and wait for response."""
_ensure_initialized()
# ZMQ async send/recv
await _zmq_socket.send_json(request)
response = await _zmq_socket.recv_json()
return response
@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
"""Handle training request - runs Apollo training on provided samples."""
global _model
async def handle_train(request: TrainRequest):
"""Handle training request - forwards to training subprocess."""
try:
_ensure_initialized()
except Exception as e:
@ -113,193 +120,109 @@ async def handle_train(request: TrainRequest, raw_request: Request):
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
logger.info(f"Starting training job {job_id} with {len(samples)} samples")
# Run training
loss_history = await run_training(_model, samples, config)
# Forward to worker
response = await _send_request({
'type': 'train',
'samples': samples,
'config': config,
})
logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
if 'error' in response:
return JSONResponse(
content={"error": response['error']},
status_code=500,
)
# Schedule checkpoint sync (batched, 10 min delay)
schedule_checkpoint_sync()
logger.info(
f"Training job {job_id} completed, "
f"final loss: {response['loss_history'][-1]:.4f}"
)
return JSONResponse(content={
"job_id": job_id,
"status": "completed",
"training_samples": len(samples),
"loss_history": loss_history,
"status": response['status'],
"training_samples": response['training_samples'],
"loss_history": response['loss_history'],
})
except zmq.Again:
logger.error("Training request timed out")
return JSONResponse(
content={"error": "Training request timed out"},
status_code=504,
)
except Exception as e:
logger.exception(f"[apollo] Training failed: {e}")
logger.exception(f"Training failed: {e}")
return JSONResponse(
content={"error": str(e)},
status_code=500,
)
def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
"""Get existing optimizer or create new one. Persists state between calls."""
global _optimizer
from .optimizer import Apollo
import os
@router.post("/checkpoint")
async def handle_checkpoint():
"""Trigger checkpoint sync to disk."""
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={"error": f"Training not available: {e}"},
status_code=503,
)
if _optimizer is not None:
return _optimizer
try:
response = await _send_request({'type': 'checkpoint'})
# Build parameter groups (Apollo for 2D+, standard for small/1D)
apollo_params, standard_params = [], []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
if not groups:
raise ValueError("No trainable parameters found")
# Create optimizer
_optimizer = Apollo(
groups,
lr=config.get('lr', 1e-5),
rank=config.get('rank', DEFAULT_RANK),
betas=tuple(config.get('betas', (0.9, 0.999))),
eps=config.get('eps', 1e-8),
weight_decay=config.get('weight_decay', 0.01),
warmup_steps=config.get('warmup_steps', 0),
scale=config.get('scale'),
proj_refresh=config.get('proj_refresh', 200),
norm_growth_limit=config.get('norm_growth_limit', 1.01),
)
# Restore state if exists
if os.path.exists(OPTIMIZER_STATE_PATH):
try:
state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
_optimizer.load_state_dict(state)
logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
except Exception as e:
logger.warning(f"[apollo] Could not restore optimizer state: {e}")
logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
return _optimizer
def _save_optimizer_state():
"""Save optimizer state for persistence between /train calls."""
global _optimizer
if _optimizer is not None:
torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH)
logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
async def run_training(
model: nn.Module,
samples: list[dict[str, Any]],
config: dict[str, Any],
) -> list[float]:
"""Run Apollo training on the given samples.
Each sample has:
context_ids: token IDs for frozen context (no gradients)
continuation_ids: token IDs for the decision we're training on
"""
optimizer = _get_or_create_optimizer(model, config)
loss_history = []
for i, sample in enumerate(samples):
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
outputs = model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
if 'error' in response:
return JSONResponse(
content={"error": response['error']},
status_code=500,
)
loss.backward()
optimizer.step()
return JSONResponse(content=response)
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
return loss_history
except Exception as e:
logger.exception(f"Checkpoint failed: {e}")
return JSONResponse(
content={"error": str(e)},
status_code=500,
)
# Checkpoint sync scheduling
_checkpoint_task = None
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
@router.get("/train/status")
async def handle_status():
"""Get training worker status."""
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={
"status": "unavailable",
"error": str(e),
},
status_code=503,
)
try:
response = await _send_request({'type': 'status'})
return JSONResponse(content=response)
def schedule_checkpoint_sync():
"""Schedule checkpoint sync after delay (batched)."""
global _checkpoint_task
import asyncio
if _checkpoint_task is not None:
# Already scheduled
return
async def do_sync():
global _checkpoint_task
try:
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
if _model_path:
from .checkpoint_sync import checkpoint_sync
logger.info("[apollo] Starting checkpoint sync...")
# Save optimizer state alongside model weights
_save_optimizer_state()
result = checkpoint_sync(_model_path)
logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
except Exception as e:
logger.error(f"[apollo] Checkpoint sync failed: {e}")
finally:
_checkpoint_task = None
_checkpoint_task = asyncio.create_task(do_sync())
logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
except Exception as e:
return JSONResponse(
content={
"status": "error",
"error": str(e),
},
status_code=500,
)
def attach_router(app: FastAPI):
"""Attach training router to FastAPI app."""
app.include_router(router)
logger.info("[apollo] Training router attached")
logger.info("Training router attached")
def _patch_api_server():
@ -314,4 +237,4 @@ def _patch_api_server():
return app
api_server.build_app = patched_build_app
logger.info("[apollo] API server patched for /train endpoint")
logger.info("API server patched for /train endpoint")

View file

@ -0,0 +1,323 @@
"""Training subprocess - handles Apollo training and checkpoint sync.
Long-lived process that:
1. Loads IPC handles from vLLM's exported weights
2. Creates HF model with views into vLLM's GPU memory
3. Handles training requests via ZMQ
4. Handles checkpoint sync requests
5. Persists Apollo optimizer state between calls
Communicates with the API server's /train endpoint via ZMQ REP socket.
"""
import logging
import os
import signal
import sys
from pathlib import Path
from typing import Any
# Handle running as script vs module
if __name__ == '__main__' and __package__ is None:
# Running as script - add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
__package__ = 'apollo_plugin'
import torch
import torch.nn as nn
import zmq
from .checkpoint_sync import checkpoint_sync
from .optimizer import Apollo
from .weight_mapping import load_hf_model_with_vllm_weights
logger = logging.getLogger(__name__)
DEFAULT_RANK = 64
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
class TrainingWorker:
"""Long-lived training worker process."""
def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
self.zmq_addr = zmq_addr
self.model: nn.Module | None = None
self.optimizer: Apollo | None = None
self.model_path: str | None = None
self._running = True
def _create_model_wrapper(self) -> nn.Module:
"""Create HF model wrapper with views into vLLM's GPU memory."""
if not os.path.exists(HANDLE_PATH):
raise FileNotFoundError(
f"Weight handles not found: {HANDLE_PATH}. "
"Is vLLM running with the export hook?"
)
handles = torch.load(HANDLE_PATH, weights_only=False)
# Extract metadata
metadata = handles.pop('__metadata__', {})
self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
if not self.model_path:
raise ValueError(
"Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
)
# Reconstruct tensors from IPC handles
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
model.train()
return model
def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
"""Get existing optimizer or create new one."""
if self.optimizer is not None:
return self.optimizer
# Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
apollo_params, standard_params = [], []
for p in self.model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
if not groups:
raise ValueError("No trainable parameters found")
self.optimizer = Apollo(
groups,
lr=config.get('lr', 1e-5),
rank=config.get('rank', DEFAULT_RANK),
betas=tuple(config.get('betas', (0.9, 0.999))),
eps=config.get('eps', 1e-8),
weight_decay=config.get('weight_decay', 0.01),
warmup_steps=config.get('warmup_steps', 0),
scale=config.get('scale'),
proj_refresh=config.get('proj_refresh', 200),
norm_growth_limit=config.get('norm_growth_limit', 1.01),
)
# Restore state if exists
if os.path.exists(OPTIMIZER_STATE_PATH):
try:
state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
self.optimizer.load_state_dict(state)
logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
except Exception as e:
logger.warning(f"Could not restore optimizer state: {e}")
logger.info(
f"Optimizer: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
)
return self.optimizer
def _save_optimizer_state(self):
"""Save optimizer state for persistence."""
if self.optimizer is not None:
torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH)
logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")
def _run_training(
self,
samples: list[dict[str, Any]],
config: dict[str, Any],
) -> list[float]:
"""Run Apollo training on the given samples."""
optimizer = self._get_or_create_optimizer(config)
loss_history = []
for i, sample in enumerate(samples):
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
outputs = self.model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = self.model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(
f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
)
return loss_history
def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a training request."""
samples = request.get('samples', [])
config = request.get('config', {})
if not samples:
return {'error': 'No training samples provided'}
try:
loss_history = self._run_training(samples, config)
return {
'status': 'completed',
'training_samples': len(samples),
'loss_history': loss_history,
}
except Exception as e:
logger.exception(f"Training failed: {e}")
return {'error': str(e)}
def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a checkpoint sync request."""
if not self.model_path:
return {'error': 'Model path not set'}
try:
self._save_optimizer_state()
result = checkpoint_sync(self.model_path)
return {
'status': 'completed',
'total_changed': result['total_changed'],
'files_changed': result['files_changed'],
}
except Exception as e:
logger.exception(f"Checkpoint sync failed: {e}")
return {'error': str(e)}
def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a status request."""
return {
'status': 'ready',
'model_loaded': self.model is not None,
'optimizer_loaded': self.optimizer is not None,
'model_path': self.model_path,
'optimizer_state_mb': (
self.optimizer.state_size_bytes() / 1e6
if self.optimizer else 0
),
}
def run(self):
"""Main loop - listen for requests and handle them."""
# Set up signal handlers
def handle_signal(signum, frame):
logger.info(f"Received signal {signum}, shutting down...")
self._running = False
signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)
# Set up ZMQ socket first so API server can connect
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(self.zmq_addr)
logger.info(f"Training worker listening on {self.zmq_addr}")
# Create HF model wrapper with views into vLLM's GPU memory
logger.info("Connecting to vLLM weights via IPC handles...")
try:
self.model = self._create_model_wrapper()
logger.info("HF model wrapper ready (views into vLLM GPU memory)")
except Exception as e:
logger.error(f"Failed to connect to vLLM weights: {e}")
logger.info("Will retry on first training request")
# Set socket timeout so we can check _running flag
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
while self._running:
try:
message = socket.recv_json()
except zmq.Again:
# Timeout, check _running and continue
continue
request_type = message.get('type', 'train')
logger.info(f"Received {request_type} request")
# Ensure model is loaded
if self.model is None and request_type != 'status':
try:
self.model = self._create_model_wrapper()
except Exception as e:
socket.send_json({'error': f'Model not loaded: {e}'})
continue
# Dispatch request
if request_type == 'train':
response = self._handle_train(message)
elif request_type == 'checkpoint':
response = self._handle_checkpoint(message)
elif request_type == 'status':
response = self._handle_status(message)
else:
response = {'error': f'Unknown request type: {request_type}'}
socket.send_json(response)
# Cleanup
logger.info("Saving optimizer state before shutdown...")
self._save_optimizer_state()
socket.close()
context.term()
logger.info("Training worker shut down")
def main():
"""Entry point for running as a subprocess."""
logging.basicConfig(
level=logging.INFO,
format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
)
zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
worker = TrainingWorker(zmq_addr)
worker.run()
if __name__ == '__main__':
main()

View file

@ -11,6 +11,7 @@ dependencies = [
"torch",
"aiohttp",
"safetensors",
"pyzmq",
]
[project.optional-dependencies]
@ -21,6 +22,7 @@ apollo = "apollo_plugin:register"
[project.scripts]
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
apollo-worker = "apollo_plugin.training_worker:main"
[tool.setuptools.packages.find]
where = ["."]