- Add training_worker.py: long-lived subprocess that handles GPU training
work, owns HF model wrapper (views into vLLM GPU memory), Apollo
optimizer, and checkpoint sync
- train_router.py: now forwards /train requests via async ZMQ instead of
running training in-process. Adds /checkpoint and /train/status endpoints
- export_hook.py: store model_path in __metadata__ so training worker can
find it without cross-process communication
- This fixes two bugs:
1. Process boundary issue - model_path was set in worker process but
needed in API server process
2. Blocking event loop - training blocked vLLM's async event loop
Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
323 lines
11 KiB
Python
323 lines
11 KiB
Python
"""Training subprocess - handles Apollo training and checkpoint sync.
|
|
|
|
Long-lived process that:
|
|
1. Loads IPC handles from vLLM's exported weights
|
|
2. Creates HF model with views into vLLM's GPU memory
|
|
3. Handles training requests via ZMQ
|
|
4. Handles checkpoint sync requests
|
|
5. Persists Apollo optimizer state between calls
|
|
|
|
Communicates with the API server's /train endpoint via ZMQ REP socket.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
# Allow this file to be executed directly as a script as well as imported
# as a module of its package.
if __package__ is None and __name__ == '__main__':
    # Direct script execution: there is no package context yet, so put the
    # directory two levels above this file on the import path and declare
    # the package name; the relative imports further down depend on both.
    sys.path.insert(0, str(Path(__file__).parent.parent))
    __package__ = 'apollo_plugin'
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import zmq
|
|
|
|
from .checkpoint_sync import checkpoint_sync
|
|
from .optimizer import Apollo
|
|
from .weight_mapping import load_hf_model_with_vllm_weights
|
|
|
|
# Module-level logger; main() configures the root handler and format.
logger = logging.getLogger(__name__)

# Default for the Apollo optimizer's ``rank`` argument. Also used as the
# minimum size every dimension of a 2D+ parameter must have to be placed in
# the first ("apollo") parameter group.
DEFAULT_RANK = 64
# Default ZMQ endpoint the worker binds to; main() lets the APOLLO_ZMQ_ADDR
# environment variable override it.
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
# File loaded with torch.load() to obtain the per-weight IPC handles plus an
# optional '__metadata__' entry.
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
# File the optimizer state_dict is saved to and restored from.
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
|
|
|
|
|
|
class TrainingWorker:
    """Long-lived training worker process.

    Builds an HF model from the weight handles stored at ``HANDLE_PATH``,
    keeps a single Apollo optimizer alive across requests, and serves
    ``train`` / ``checkpoint`` / ``status`` requests over a ZMQ REP socket.
    Every received request gets exactly one JSON reply: either a result
    dict or ``{'error': <message>}``.
    """

    def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
        """Remember the endpoint; model and optimizer are created lazily.

        Args:
            zmq_addr: ZMQ endpoint that ``run()`` binds the REP socket to.
        """
        self.zmq_addr = zmq_addr
        self.model: nn.Module | None = None
        self.optimizer: Apollo | None = None
        self.model_path: str | None = None
        # Cleared by the signal handler installed in run() to stop the loop.
        self._running = True

    def _create_model_wrapper(self) -> nn.Module:
        """Create HF model wrapper with views into vLLM's GPU memory.

        Also sets ``self.model_path`` from the handle file's
        ``'__metadata__'`` entry, falling back to the ``APOLLO_MODEL_PATH``
        environment variable.

        Returns:
            The model returned by ``load_hf_model_with_vllm_weights``,
            switched to train mode.

        Raises:
            FileNotFoundError: ``HANDLE_PATH`` does not exist.
            ValueError: no model path in the metadata or the environment.
        """
        if not os.path.exists(HANDLE_PATH):
            raise FileNotFoundError(
                f"Weight handles not found: {HANDLE_PATH}. "
                "Is vLLM running with the export hook?"
            )

        # SECURITY: weights_only=False unpickles arbitrary objects, and the
        # (func, args) pairs read from this file are *called* below. The file
        # lives under /tmp, so it must only ever be written by a trusted
        # process.
        handles = torch.load(HANDLE_PATH, weights_only=False)

        # Extract metadata (removed so that only weight entries remain)
        metadata = handles.pop('__metadata__', {})
        self.model_path = (
            metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
        )
        if not self.model_path:
            raise ValueError(
                "Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
            )

        # Reconstruct tensors from IPC handles: each entry stores a
        # (rebuild function, arguments) pair under 'handle'.
        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)

        model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
        model.train()
        return model

    def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
        """Get existing optimizer or create new one.

        The optimizer is built once, from the first ``config`` seen; later
        calls return the cached instance and ignore ``config``. On creation,
        state saved at ``OPTIMIZER_STATE_PATH`` is restored if present
        (a failure to restore is logged and otherwise ignored).

        Args:
            config: optional hyper-parameters (``lr``, ``rank``, ``betas``,
                ``eps``, ``weight_decay``, ``warmup_steps``, ``scale``,
                ``proj_refresh``, ``norm_growth_limit``).

        Raises:
            ValueError: the model has no parameter with ``requires_grad``.
        """
        if self.optimizer is not None:
            return self.optimizer

        # Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
        # NOTE(review): the size threshold is DEFAULT_RANK even when
        # config['rank'] differs, and both groups are handed to Apollo with
        # no distinguishing options -- confirm Apollo tells them apart itself.
        # The grouping is left untouched here because it fixes the parameter
        # order that a previously saved state_dict is matched against.
        apollo_params, standard_params = [], []
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
                apollo_params.append(p)
            else:
                standard_params.append(p)

        groups = [
            {'params': params}
            for params in (apollo_params, standard_params)
            if params
        ]
        if not groups:
            raise ValueError("No trainable parameters found")

        self.optimizer = Apollo(
            groups,
            lr=config.get('lr', 1e-5),
            rank=config.get('rank', DEFAULT_RANK),
            betas=tuple(config.get('betas', (0.9, 0.999))),
            eps=config.get('eps', 1e-8),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 0),
            scale=config.get('scale'),
            proj_refresh=config.get('proj_refresh', 200),
            norm_growth_limit=config.get('norm_growth_limit', 1.01),
        )

        # Restore state if exists
        if os.path.exists(OPTIMIZER_STATE_PATH):
            try:
                # SECURITY: full unpickle of a file under /tmp; see the note
                # in _create_model_wrapper.
                state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
                self.optimizer.load_state_dict(state)
                logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
            except Exception as e:
                # Best effort: continue with a fresh optimizer state.
                logger.warning(f"Could not restore optimizer state: {e}")

        logger.info(
            f"Optimizer: {len(apollo_params)} apollo params, "
            f"{len(standard_params)} standard, "
            f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
        )

        return self.optimizer

    def _save_optimizer_state(self):
        """Save optimizer state for persistence.

        Does nothing when no optimizer has been created. The state is
        written to a temporary file next to ``OPTIMIZER_STATE_PATH`` and then
        moved into place, so an interrupted save never leaves a truncated
        file to be restored on the next start.
        """
        if self.optimizer is None:
            return

        tmp_path = f"{OPTIMIZER_STATE_PATH}.{os.getpid()}.tmp"
        try:
            torch.save(self.optimizer.state_dict(), tmp_path)
            os.replace(tmp_path, OPTIMIZER_STATE_PATH)
        finally:
            # Only still present if the save or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")

    @staticmethod
    def _validate_samples(samples: list[dict[str, Any]]) -> None:
        """Reject malformed samples before any optimizer step is taken.

        Raises:
            ValueError: a sample is not a dict, lacks a list under
                ``'context_ids'`` / ``'continuation_ids'``, or has fewer
                than two continuation tokens (which would leave no
                prediction target after the one-token shift).
        """
        for idx, sample in enumerate(samples):
            if not isinstance(sample, dict):
                raise ValueError(
                    f"Sample {idx}: expected an object, "
                    f"got {type(sample).__name__}"
                )
            for key in ('context_ids', 'continuation_ids'):
                if not isinstance(sample.get(key), list):
                    raise ValueError(
                        f"Sample {idx}: '{key}' must be a list of token ids"
                    )
            if len(sample['continuation_ids']) < 2:
                raise ValueError(
                    f"Sample {idx}: 'continuation_ids' needs at least 2 tokens"
                )

    def _run_training(
        self,
        samples: list[dict[str, Any]],
        config: dict[str, Any],
    ) -> list[float]:
        """Run Apollo training on the given samples.

        One optimizer step is taken per sample. The context tokens are run
        under ``torch.no_grad()`` to obtain ``past_key_values``; only the
        forward pass over the continuation tokens carries gradients. The
        loss is the cross-entropy between the continuation logits (all but
        the last position) and the continuation tokens from the second one
        onwards, so the first continuation token is never a target.

        Args:
            samples: dicts with ``'context_ids'`` and ``'continuation_ids'``
                token-id lists.
            config: optimizer hyper-parameters, used only when the optimizer
                does not exist yet.

        Returns:
            The loss value of every step, in sample order.

        Raises:
            ValueError: a sample is malformed (checked for all samples
                before the first step, so no weights change in that case).
            RuntimeError: a step produced a non-finite loss.
        """
        self._validate_samples(samples)
        optimizer = self._get_or_create_optimizer(config)

        # Put inputs wherever the model's weights are instead of assuming
        # a fixed device index.
        device = next(self.model.parameters()).device
        total = len(samples)
        loss_history = []

        try:
            for i, sample in enumerate(samples):
                ctx_ids = sample['context_ids']
                cont_ids = sample['continuation_ids']
                context_len = len(ctx_ids)

                input_ids = torch.tensor([ctx_ids + cont_ids], device=device)

                optimizer.zero_grad()

                # Context-frozen forward pass (skipped for an empty context,
                # where there is no prefix to cache)
                past_kv = None
                if context_len:
                    with torch.no_grad():
                        ctx_out = self.model(
                            input_ids[:, :context_len], use_cache=True
                        )
                    past_kv = ctx_out.past_key_values

                # Decision tokens with gradients
                with torch.enable_grad():
                    outputs = self.model(
                        input_ids[:, context_len:],
                        past_key_values=past_kv,
                        use_cache=False,
                    )
                    logits = outputs.logits

                    # Shift: predict next token from each position
                    shift_logits = logits[:, :-1].contiguous()
                    shift_labels = input_ids[:, context_len + 1:].contiguous()

                    loss = nn.functional.cross_entropy(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                    )

                loss_val = loss.item()
                if not bool(torch.isfinite(loss)):
                    # Stepping on a NaN/inf loss would write non-finite
                    # values into the weights the model was built from.
                    raise RuntimeError(
                        f"Non-finite loss ({loss_val}) at step {i+1}/{total}; "
                        "optimizer step skipped"
                    )

                loss.backward()
                optimizer.step()

                loss_history.append(loss_val)
                logger.info(
                    f"Step {i+1}/{total}: loss={loss_val:.4f} "
                    f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
                )
        finally:
            # Do not leave the last step's (or a failed step's) gradients
            # populated while the worker sits idle between requests.
            optimizer.zero_grad()

        return loss_history

    def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a training request.

        Args:
            request: may carry ``'samples'`` (list of sample dicts) and
                ``'config'`` (optimizer hyper-parameters).

        Returns:
            ``{'status': 'completed', 'training_samples', 'loss_history'}``
            on success, otherwise ``{'error': <message>}``.
        """
        samples = request.get('samples') or []
        # An explicit null config is treated like a missing one.
        config = request.get('config') or {}

        if not samples:
            return {'error': 'No training samples provided'}

        try:
            loss_history = self._run_training(samples, config)
        except Exception as e:
            logger.exception(f"Training failed: {e}")
            return {'error': str(e)}

        return {
            'status': 'completed',
            'training_samples': len(samples),
            'loss_history': loss_history,
        }

    def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a checkpoint sync request.

        Saves the optimizer state, then calls ``checkpoint_sync`` with the
        model path.

        Returns:
            ``{'status': 'completed', 'total_changed', 'files_changed'}`` on
            success, otherwise ``{'error': <message>}``.
        """
        if not self.model_path:
            return {'error': 'Model path not set'}

        try:
            self._save_optimizer_state()
            result = checkpoint_sync(self.model_path)
            return {
                'status': 'completed',
                'total_changed': result['total_changed'],
                'files_changed': result['files_changed'],
            }
        except Exception as e:
            logger.exception(f"Checkpoint sync failed: {e}")
            return {'error': str(e)}

    def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a status request.

        Returns:
            Whether model and optimizer exist, the model path, and the
            optimizer state size in MB (0 without an optimizer).
        """
        optimizer_mb = (
            self.optimizer.state_size_bytes() / 1e6
            if self.optimizer is not None
            else 0
        )
        return {
            'status': 'ready',
            'model_loaded': self.model is not None,
            'optimizer_loaded': self.optimizer is not None,
            'model_path': self.model_path,
            'optimizer_state_mb': optimizer_mb,
        }

    def _dispatch(self, message: Any) -> dict[str, Any]:
        """Route one decoded request to its handler and return the reply.

        The request kind is read from ``message['type']`` (default
        ``'train'``). For every kind except ``'status'`` the model is
        loaded first if it is not available yet.
        """
        if not isinstance(message, dict):
            return {'error': 'Request must be a JSON object'}

        request_type = message.get('type', 'train')
        logger.info(f"Received {request_type} request")

        # Ensure model is loaded
        if self.model is None and request_type != 'status':
            try:
                self.model = self._create_model_wrapper()
            except Exception as e:
                return {'error': f'Model not loaded: {e}'}

        # Dispatch request
        if request_type == 'train':
            return self._handle_train(message)
        if request_type == 'checkpoint':
            return self._handle_checkpoint(message)
        if request_type == 'status':
            return self._handle_status(message)
        return {'error': f'Unknown request type: {request_type}'}

    def run(self):
        """Main loop - listen for requests and handle them.

        Installs SIGTERM/SIGINT handlers that stop the loop, binds the REP
        socket, tries to load the model once up front (a failure is logged
        and retried on the first request that needs the model), then serves
        requests until asked to stop. On the way out -- including when an
        exception escapes -- the optimizer state is saved and the socket
        and context are released.
        """
        # Set up signal handlers
        def handle_signal(signum, frame):
            logger.info(f"Received signal {signum}, shutting down...")
            self._running = False

        signal.signal(signal.SIGTERM, handle_signal)
        signal.signal(signal.SIGINT, handle_signal)

        # Set up ZMQ socket first so API server can connect
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        try:
            socket.bind(self.zmq_addr)
            logger.info(f"Training worker listening on {self.zmq_addr}")

            # Create HF model wrapper with views into vLLM's GPU memory
            logger.info("Connecting to vLLM weights via IPC handles...")
            try:
                self.model = self._create_model_wrapper()
                logger.info("HF model wrapper ready (views into vLLM GPU memory)")
            except Exception as e:
                logger.error(f"Failed to connect to vLLM weights: {e}")
                logger.info("Will retry on first training request")

            # Set socket timeout so we can check _running flag
            socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout

            while self._running:
                try:
                    message = socket.recv_json()
                except zmq.Again:
                    # Timeout, check _running and continue
                    continue
                except ValueError as e:
                    # A frame arrived but was not valid JSON. A REP socket
                    # must answer before it can receive again.
                    logger.warning(f"Discarding malformed request: {e}")
                    socket.send_json({'error': f'Invalid JSON request: {e}'})
                    continue

                try:
                    response = self._dispatch(message)
                except Exception as e:
                    # Handlers catch their own errors; this is a last line
                    # of defence so the requester always gets a reply.
                    logger.exception(f"Request handling failed: {e}")
                    response = {'error': str(e)}

                try:
                    socket.send_json(response)
                except (TypeError, ValueError) as e:
                    # Serialization fails before anything is sent, so the
                    # socket is still waiting for this reply.
                    logger.exception(f"Could not serialize response: {e}")
                    socket.send_json(
                        {'error': f'Could not serialize response: {e}'}
                    )
        finally:
            # Cleanup
            logger.info("Saving optimizer state before shutdown...")
            try:
                self._save_optimizer_state()
            except Exception as e:
                logger.exception(f"Could not save optimizer state: {e}")
            # Bounded linger: give a just-sent reply a moment to flush, but
            # never let context.term() block forever on a vanished peer.
            socket.close(linger=1000)
            context.term()
            logger.info("Training worker shut down")
|
|
|
|
|
|
def main():
    """Entry point for running as a subprocess.

    Configures INFO-level logging with an ``[apollo-worker]`` prefix, then
    builds a TrainingWorker bound to the address given by the
    APOLLO_ZMQ_ADDR environment variable (DEFAULT_ZMQ_ADDR when unset) and
    runs it until its main loop returns.
    """
    log_settings = {
        'level': logging.INFO,
        'format': '[apollo-worker] %(asctime)s %(levelname)s %(message)s',
        'datefmt': '%H:%M:%S',
    }
    logging.basicConfig(**log_settings)

    TrainingWorker(os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)).run()
|
|
|
|
|
|
# Start the worker only when this file is the program entry point; importing
# the module must not start it.
if __name__ == '__main__':
    main()
|