training: move to dedicated subprocess with ZMQ communication
- Add training_worker.py: long-lived subprocess that handles GPU training
work, owns HF model wrapper (views into vLLM GPU memory), Apollo
optimizer, and checkpoint sync
- train_router.py: now forwards /train requests via async ZMQ instead of
running training in-process. Adds /checkpoint and /train/status endpoints
- export_hook.py: store model_path in __metadata__ so the training worker
can find it without cross-process communication
- This fixes two bugs:
1. Process boundary issue - model_path was set in worker process but
needed in API server process
2. Blocking event loop - training blocked vLLM's async event loop
Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
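
For orientation, the API-server side of this exchange (the train_router.py change, whose diff is not shown here) reduces to forwarding the request body over a REQ socket without blocking vLLM's event loop. A minimal sketch of that pattern using pyzmq's asyncio context follows; the payload keys ('type', 'samples', 'config') mirror the worker below, and everything else is illustrative rather than the actual router code.

# Illustrative only - not the actual train_router.py from this commit.
import zmq
import zmq.asyncio

ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"   # must match DEFAULT_ZMQ_ADDR below
_ctx = zmq.asyncio.Context.instance()

async def forward_train_request(samples: list[dict], config: dict) -> dict:
    """Send one training request to the worker and await its reply."""
    sock = _ctx.socket(zmq.REQ)
    sock.connect(ZMQ_ADDR)
    try:
        await sock.send_json({'type': 'train', 'samples': samples, 'config': config})
        # e.g. {'status': 'completed', 'training_samples': N, 'loss_history': [...]}
        return await sock.recv_json()
    finally:
        sock.close()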
parent 68a2df2185
commit 2c6a5c0f4a

6 changed files with 503 additions and 233 deletions

training/apollo_plugin/training_worker.py  (new file, +323 lines)
@@ -0,0 +1,323 @@
"""Training subprocess - handles Apollo training and checkpoint sync.

Long-lived process that:
1. Loads IPC handles from vLLM's exported weights
2. Creates HF model with views into vLLM's GPU memory
3. Handles training requests via ZMQ
4. Handles checkpoint sync requests
5. Persists Apollo optimizer state between calls

Communicates with the API server's /train endpoint via ZMQ REP socket.
"""

import logging
import os
import signal
import sys
from pathlib import Path
from typing import Any

# Handle running as script vs module
if __name__ == '__main__' and __package__ is None:
    # Running as script - add parent to path for imports
    sys.path.insert(0, str(Path(__file__).parent.parent))
    __package__ = 'apollo_plugin'

import torch
import torch.nn as nn
import zmq

from .checkpoint_sync import checkpoint_sync
from .optimizer import Apollo
from .weight_mapping import load_hf_model_with_vllm_weights

logger = logging.getLogger(__name__)

DEFAULT_RANK = 64
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"


class TrainingWorker:
    """Long-lived training worker process."""

    def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
        self.zmq_addr = zmq_addr
        self.model: nn.Module | None = None
        self.optimizer: Apollo | None = None
        self.model_path: str | None = None
        self._running = True

    def _create_model_wrapper(self) -> nn.Module:
        """Create HF model wrapper with views into vLLM's GPU memory."""
        if not os.path.exists(HANDLE_PATH):
            raise FileNotFoundError(
                f"Weight handles not found: {HANDLE_PATH}. "
                "Is vLLM running with the export hook?"
            )

        handles = torch.load(HANDLE_PATH, weights_only=False)

        # Extract metadata
        metadata = handles.pop('__metadata__', {})
        self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
        if not self.model_path:
            raise ValueError(
                "Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
            )

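        # Each remaining entry is assumed to hold a {'handle': (rebuild_fn, args)}
        # pair written by the export hook (e.g. torch's CUDA IPC tensor reduction);
        # calling rebuild_fn(*args) maps the same GPU allocation into this process
        # instead of copying it.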
        # Reconstruct tensors from IPC handles
        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)

        model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
        model.train()
        return model

    def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
        """Get existing optimizer or create new one."""
        if self.optimizer is not None:
            return self.optimizer

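        # A rank-DEFAULT_RANK projection needs both dimensions to be at least that
        # large, so 1-D tensors and small matrices fall through to the plain group.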
        # Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
        apollo_params, standard_params = [], []
        for p in self.model.parameters():
            if p.requires_grad:
                if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
                    apollo_params.append(p)
                else:
                    standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})

        if not groups:
            raise ValueError("No trainable parameters found")

        self.optimizer = Apollo(
            groups,
            lr=config.get('lr', 1e-5),
            rank=config.get('rank', DEFAULT_RANK),
            betas=tuple(config.get('betas', (0.9, 0.999))),
            eps=config.get('eps', 1e-8),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 0),
            scale=config.get('scale'),
            proj_refresh=config.get('proj_refresh', 200),
            norm_growth_limit=config.get('norm_growth_limit', 1.01),
        )

        # Restore state if exists
        if os.path.exists(OPTIMIZER_STATE_PATH):
            try:
                state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
                self.optimizer.load_state_dict(state)
                logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
            except Exception as e:
                logger.warning(f"Could not restore optimizer state: {e}")

        logger.info(
            f"Optimizer: {len(apollo_params)} apollo params, "
            f"{len(standard_params)} standard, "
            f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
        )

        return self.optimizer

    def _save_optimizer_state(self):
        """Save optimizer state for persistence."""
        if self.optimizer is not None:
            torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH)
            logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")

    def _run_training(
        self,
        samples: list[dict[str, Any]],
        config: dict[str, Any],
    ) -> list[float]:
        """Run Apollo training on the given samples."""
        optimizer = self._get_or_create_optimizer(config)

        loss_history = []

        for i, sample in enumerate(samples):
            ctx_ids = sample['context_ids']
            cont_ids = sample['continuation_ids']
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)

            input_ids = torch.tensor([all_ids], device='cuda:0')

            optimizer.zero_grad()

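            # Only the continuation ("decision") tokens get gradients: the context
            # prefix runs under no_grad purely to build the KV cache, so activation
            # memory scales with the continuation length rather than the full prompt.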
            # Context-frozen forward pass
            with torch.no_grad():
                outputs = self.model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Decision tokens with gradients
            with torch.enable_grad():
                outputs = self.model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits

            # Shift: predict next token from each position
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = input_ids[:, context_len + 1:].contiguous()

            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(
                f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
            )

        return loss_history

    def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a training request."""
        samples = request.get('samples', [])
        config = request.get('config', {})

        if not samples:
            return {'error': 'No training samples provided'}

        try:
            loss_history = self._run_training(samples, config)
            return {
                'status': 'completed',
                'training_samples': len(samples),
                'loss_history': loss_history,
            }
        except Exception as e:
            logger.exception(f"Training failed: {e}")
            return {'error': str(e)}

    def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a checkpoint sync request."""
        if not self.model_path:
            return {'error': 'Model path not set'}

        try:
            self._save_optimizer_state()
            result = checkpoint_sync(self.model_path)
            return {
                'status': 'completed',
                'total_changed': result['total_changed'],
                'files_changed': result['files_changed'],
            }
        except Exception as e:
            logger.exception(f"Checkpoint sync failed: {e}")
            return {'error': str(e)}

    def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a status request."""
        return {
            'status': 'ready',
            'model_loaded': self.model is not None,
            'optimizer_loaded': self.optimizer is not None,
            'model_path': self.model_path,
            'optimizer_state_mb': (
                self.optimizer.state_size_bytes() / 1e6
                if self.optimizer else 0
            ),
        }

    def run(self):
        """Main loop - listen for requests and handle them."""
        # Set up signal handlers
        def handle_signal(signum, frame):
            logger.info(f"Received signal {signum}, shutting down...")
            self._running = False

        signal.signal(signal.SIGTERM, handle_signal)
        signal.signal(signal.SIGINT, handle_signal)

        # Set up ZMQ socket first so API server can connect
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(self.zmq_addr)
        logger.info(f"Training worker listening on {self.zmq_addr}")

        # Create HF model wrapper with views into vLLM's GPU memory
        logger.info("Connecting to vLLM weights via IPC handles...")
        try:
            self.model = self._create_model_wrapper()
            logger.info("HF model wrapper ready (views into vLLM GPU memory)")
        except Exception as e:
            logger.error(f"Failed to connect to vLLM weights: {e}")
            logger.info("Will retry on first training request")

        # Set socket timeout so we can check _running flag
        socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout

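        # A REP socket must alternate recv/send, so every path below that receives
        # a message replies with exactly one send_json() before the next recv.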
        while self._running:
            try:
                message = socket.recv_json()
            except zmq.Again:
                # Timeout, check _running and continue
                continue

            request_type = message.get('type', 'train')
            logger.info(f"Received {request_type} request")

            # Ensure model is loaded
            if self.model is None and request_type != 'status':
                try:
                    self.model = self._create_model_wrapper()
                except Exception as e:
                    socket.send_json({'error': f'Model not loaded: {e}'})
                    continue

            # Dispatch request
            if request_type == 'train':
                response = self._handle_train(message)
            elif request_type == 'checkpoint':
                response = self._handle_checkpoint(message)
            elif request_type == 'status':
                response = self._handle_status(message)
            else:
                response = {'error': f'Unknown request type: {request_type}'}

            socket.send_json(response)

        # Cleanup
        logger.info("Saving optimizer state before shutdown...")
        self._save_optimizer_state()
        socket.close()
        context.term()
        logger.info("Training worker shut down")


def main():
    """Entry point for running as a subprocess."""
    logging.basicConfig(
        level=logging.INFO,
        format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
        datefmt='%H:%M:%S',
    )

    zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
    worker = TrainingWorker(zmq_addr)
    worker.run()


if __name__ == '__main__':
    main()
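
The worker is meant to be started once and left running. A plausible way to launch it from the serving process and poll its status endpoint is sketched below; the module path, request keys, and the APOLLO_ZMQ_ADDR / APOLLO_MODEL_PATH variables come from the file above, while the launcher itself is only an illustration, not code from this commit.

# Illustrative launcher - not part of this commit.
import os
import subprocess
import sys
import zmq

def start_training_worker(zmq_addr: str, model_path: str) -> subprocess.Popen:
    """Spawn the long-lived worker; it binds the REP socket before loading weights."""
    env = dict(os.environ, APOLLO_ZMQ_ADDR=zmq_addr, APOLLO_MODEL_PATH=model_path)
    return subprocess.Popen(
        [sys.executable, '-m', 'apollo_plugin.training_worker'],
        env=env,
    )

def worker_status(zmq_addr: str, timeout_ms: int = 2000) -> dict:
    """One-shot status poll over a fresh REQ socket."""
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
    sock.connect(zmq_addr)
    try:
        sock.send_json({'type': 'status'})
        return sock.recv_json()
    finally:
        sock.close()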