consciousness/training/apollo_plugin/training_worker.py

324 lines
11 KiB
Python
Raw Permalink Normal View History

"""Training subprocess - handles Apollo training and checkpoint sync.
Long-lived process that:
1. Loads IPC handles from vLLM's exported weights
2. Creates HF model with views into vLLM's GPU memory
3. Handles training requests via ZMQ
4. Handles checkpoint sync requests
5. Persists Apollo optimizer state between calls
Communicates with the API server's /train endpoint via ZMQ REP socket.
"""
import logging
import os
import signal
import sys
from pathlib import Path
from typing import Any
# Handle running as script vs module.
# When this file is executed directly it has no package context
# (__package__ is None), so the relative imports below would fail.
if __name__ == '__main__' and __package__ is None:
    # Running as script - add the directory above this file's own directory
    # to the front of sys.path so the package can be found by name.
    sys.path.insert(0, str(Path(__file__).parent.parent))
    # Declare the package name so the relative imports below resolve.
    __package__ = 'apollo_plugin'
import torch
import torch.nn as nn
import zmq
from .checkpoint_sync import checkpoint_sync
from .optimizer import Apollo
from .weight_mapping import load_hf_model_with_vllm_weights
# Module-level logger; main() configures the logging level and format.
logger = logging.getLogger(__name__)
# Default Apollo rank, used when a training request's config has no 'rank'.
DEFAULT_RANK = 64
# Default ZMQ endpoint the worker binds to; main() lets the APOLLO_ZMQ_ADDR
# environment variable override it.
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
# File the weight handles (and their '__metadata__' entry) are loaded from.
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
# File the Apollo optimizer state is saved to and restored from.
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
class TrainingWorker:
    """Long-lived training worker process.

    Serves ``train``, ``checkpoint`` and ``status`` requests over a ZMQ REP
    socket. The model wrapper and the Apollo optimizer are created lazily and
    then kept for the lifetime of the process.
    """

    def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
        """Remember the bind address; nothing is loaded until ``run()``.

        Args:
            zmq_addr: ZMQ endpoint the REP socket is bound to.
        """
        self.zmq_addr = zmq_addr
        self.model: nn.Module | None = None
        self.optimizer: Apollo | None = None
        self.model_path: str | None = None
        # Cleared by the signal handler installed in run() to stop the loop.
        self._running = True

    def _create_model_wrapper(self) -> nn.Module:
        """Create HF model wrapper with views into vLLM's GPU memory.

        Loads the handle file at ``HANDLE_PATH``, rebuilds every tensor by
        calling its stored ``(func, args)`` pair and passes the result to
        ``load_hf_model_with_vllm_weights``. As a side effect this sets
        ``self.model_path`` from the handle metadata or, failing that, from
        the ``APOLLO_MODEL_PATH`` environment variable.

        Returns:
            The model, switched to training mode.

        Raises:
            FileNotFoundError: ``HANDLE_PATH`` does not exist.
            ValueError: no model path in the metadata or the environment.
        """
        if not os.path.exists(HANDLE_PATH):
            raise FileNotFoundError(
                f"Weight handles not found: {HANDLE_PATH}. "
                "Is vLLM running with the export hook?"
            )
        # NOTE(security): weights_only=False unpickles arbitrary objects and
        # the file lives in /tmp. This is only safe when the file can be
        # written exclusively by the trusted exporting process.
        handles = torch.load(HANDLE_PATH, weights_only=False)

        # Extract metadata
        metadata = handles.pop('__metadata__', {})
        self.model_path = (
            metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
        )
        if not self.model_path:
            raise ValueError(
                "Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
            )

        # Reconstruct tensors from IPC handles
        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)

        model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
        model.train()
        return model

    def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
        """Get existing optimizer or create new one.

        The optimizer is built once and cached on the instance, so ``config``
        only takes effect on the first call; later calls return the cached
        optimizer unchanged.

        Args:
            config: Optional hyper-parameters (``lr``, ``rank``, ``betas``,
                ``eps``, ``weight_decay``, ``warmup_steps``, ``scale``,
                ``proj_refresh``, ``norm_growth_limit``).

        Raises:
            ValueError: the model has no trainable parameters.
        """
        if self.optimizer is not None:
            return self.optimizer

        # Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
        # NOTE(review): the size threshold is DEFAULT_RANK even when
        # config['rank'] differs - confirm this matches what Apollo expects.
        apollo_params, standard_params = [], []
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
                apollo_params.append(p)
            else:
                standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})
        if not groups:
            raise ValueError("No trainable parameters found")

        self.optimizer = Apollo(
            groups,
            lr=config.get('lr', 1e-5),
            rank=config.get('rank', DEFAULT_RANK),
            betas=tuple(config.get('betas', (0.9, 0.999))),
            eps=config.get('eps', 1e-8),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 0),
            scale=config.get('scale'),
            proj_refresh=config.get('proj_refresh', 200),
            norm_growth_limit=config.get('norm_growth_limit', 1.01),
        )

        # Restore state if exists. Best effort: an unreadable or incompatible
        # state file is logged and training starts from a fresh state.
        if os.path.exists(OPTIMIZER_STATE_PATH):
            try:
                # NOTE(security): weights_only=False unpickles arbitrary
                # objects from a file in /tmp; see _create_model_wrapper.
                state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
                self.optimizer.load_state_dict(state)
                logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
            except Exception as e:
                logger.warning(f"Could not restore optimizer state: {e}")

        logger.info(
            f"Optimizer: {len(apollo_params)} apollo params, "
            f"{len(standard_params)} standard, "
            f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
        )
        return self.optimizer

    def _save_optimizer_state(self):
        """Save optimizer state for persistence (no-op without an optimizer).

        The state is written to a temporary file next to the target and then
        moved into place with ``os.replace``, so an interrupted save cannot
        leave a truncated state file behind.
        """
        if self.optimizer is None:
            return
        tmp_path = f"{OPTIMIZER_STATE_PATH}.tmp.{os.getpid()}"
        try:
            torch.save(self.optimizer.state_dict(), tmp_path)
            os.replace(tmp_path, OPTIMIZER_STATE_PATH)
        except Exception:
            # Do not leave a partial temp file behind; re-raise for callers.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")

    @staticmethod
    def _validate_samples(samples: list[dict[str, Any]]) -> None:
        """Reject malformed samples before any optimizer step is taken.

        Raises:
            ValueError: a sample is not a dict, lacks list-valued
                ``context_ids`` / ``continuation_ids``, or has fewer than two
                continuation tokens (the loss would have no targets).
        """
        for idx, sample in enumerate(samples):
            if not isinstance(sample, dict):
                raise ValueError(
                    f"Sample {idx}: expected an object, got {type(sample).__name__}"
                )
            for key in ('context_ids', 'continuation_ids'):
                if not isinstance(sample.get(key), list):
                    raise ValueError(
                        f"Sample {idx}: '{key}' must be a list of token ids"
                    )
            if len(sample['continuation_ids']) < 2:
                raise ValueError(
                    f"Sample {idx}: 'continuation_ids' needs at least 2 tokens "
                    "to produce a training target"
                )

    def _run_training(
        self,
        samples: list[dict[str, Any]],
        config: dict[str, Any],
    ) -> list[float]:
        """Run Apollo training on the given samples.

        One optimizer step is taken per sample. The context tokens are run
        without gradients to build the KV cache; only the continuation tokens
        are run with gradients and contribute to the loss.

        Args:
            samples: Dicts with ``context_ids`` and ``continuation_ids``
                (lists of token ids).
            config: Optimizer settings, see ``_get_or_create_optimizer``.

        Returns:
            The loss of every step, in sample order.

        Raises:
            ValueError: a sample is malformed (checked for all samples up
                front, so no partial training happens in that case).
        """
        self._validate_samples(samples)
        optimizer = self._get_or_create_optimizer(config)
        loss_history = []

        for i, sample in enumerate(samples):
            ctx_ids = sample['context_ids']
            cont_ids = sample['continuation_ids']
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)
            input_ids = torch.tensor([all_ids], device='cuda:0')

            optimizer.zero_grad()

            # Context-frozen forward pass. Skipped for an empty context, where
            # there is nothing to cache and the model would get a zero-length
            # input.
            past_kv = None
            if context_len > 0:
                with torch.no_grad():
                    outputs = self.model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Decision tokens with gradients
            with torch.enable_grad():
                outputs = self.model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits

                # Shift: predict next token from each position
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, context_len + 1:].contiguous()
                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                loss.backward()

            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(
                f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
            )

        return loss_history

    def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a training request.

        Returns a ``completed`` response with the loss history, or an
        ``error`` response; exceptions from training are logged, not raised.
        """
        # "or" also covers an explicit JSON null for either field.
        samples = request.get('samples') or []
        config = request.get('config') or {}
        if not samples:
            return {'error': 'No training samples provided'}
        try:
            loss_history = self._run_training(samples, config)
            return {
                'status': 'completed',
                'training_samples': len(samples),
                'loss_history': loss_history,
            }
        except Exception as e:
            logger.exception(f"Training failed: {e}")
            return {'error': str(e)}

    def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a checkpoint sync request.

        Saves the optimizer state, then runs ``checkpoint_sync`` on the model
        path. Failures are logged and reported as an ``error`` response.
        """
        if not self.model_path:
            return {'error': 'Model path not set'}
        try:
            self._save_optimizer_state()
            result = checkpoint_sync(self.model_path)
            return {
                'status': 'completed',
                'total_changed': result['total_changed'],
                'files_changed': result['files_changed'],
            }
        except Exception as e:
            logger.exception(f"Checkpoint sync failed: {e}")
            return {'error': str(e)}

    def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
        """Handle a status request by reporting what is currently loaded."""
        return {
            'status': 'ready',
            'model_loaded': self.model is not None,
            'optimizer_loaded': self.optimizer is not None,
            'model_path': self.model_path,
            'optimizer_state_mb': (
                self.optimizer.state_size_bytes() / 1e6
                if self.optimizer is not None else 0
            ),
        }

    def _dispatch(self, message: Any) -> dict[str, Any]:
        """Turn one decoded request into exactly one response dict.

        Never raises: the REP socket must answer every request it received,
        so any failure is converted into an ``error`` response.
        """
        if not isinstance(message, dict):
            return {'error': 'Request must be a JSON object'}

        request_type = message.get('type', 'train')
        logger.info(f"Received {request_type} request")

        # Ensure model is loaded
        if self.model is None and request_type != 'status':
            try:
                self.model = self._create_model_wrapper()
            except Exception as e:
                return {'error': f'Model not loaded: {e}'}

        # Dispatch request
        try:
            if request_type == 'train':
                return self._handle_train(message)
            if request_type == 'checkpoint':
                return self._handle_checkpoint(message)
            if request_type == 'status':
                return self._handle_status(message)
            return {'error': f'Unknown request type: {request_type}'}
        except Exception as e:
            # Top-level boundary: keep the worker alive and the socket in sync.
            logger.exception(f"Request handling failed: {e}")
            return {'error': str(e)}

    def run(self):
        """Main loop - listen for requests and handle them.

        Installs SIGTERM/SIGINT handlers, binds the REP socket, tries to load
        the model, then serves requests until a signal clears ``_running``.
        The optimizer state is saved and the socket closed on every exit
        path, including unexpected exceptions.
        """
        # Set up signal handlers
        def handle_signal(signum, frame):
            logger.info(f"Received signal {signum}, shutting down...")
            self._running = False

        signal.signal(signal.SIGTERM, handle_signal)
        signal.signal(signal.SIGINT, handle_signal)

        # Set up ZMQ socket first so API server can connect
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        try:
            socket.bind(self.zmq_addr)
            logger.info(f"Training worker listening on {self.zmq_addr}")

            # Create HF model wrapper with views into vLLM's GPU memory
            logger.info("Connecting to vLLM weights via IPC handles...")
            try:
                self.model = self._create_model_wrapper()
                logger.info("HF model wrapper ready (views into vLLM GPU memory)")
            except Exception as e:
                logger.error(f"Failed to connect to vLLM weights: {e}")
                logger.info("Will retry on first training request")

            # Set socket timeout so we can check _running flag
            socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout

            while self._running:
                try:
                    message = socket.recv_json()
                except zmq.Again:
                    # Timeout, check _running and continue
                    continue
                except ValueError as e:
                    # The frame was received but is not valid JSON. A REP
                    # socket has to reply before it can receive again.
                    logger.warning(f"Discarding malformed request: {e}")
                    socket.send_json({'error': f'Invalid JSON request: {e}'})
                    continue

                socket.send_json(self._dispatch(message))
        finally:
            # Cleanup
            try:
                logger.info("Saving optimizer state before shutdown...")
                self._save_optimizer_state()
            finally:
                socket.close()
                context.term()
                logger.info("Training worker shut down")
def main():
    """Configure logging and run a TrainingWorker until it shuts down.

    The bind address is read from the APOLLO_ZMQ_ADDR environment variable
    and falls back to DEFAULT_ZMQ_ADDR when the variable is not set.
    """
    logging.basicConfig(
        format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
        datefmt='%H:%M:%S',
        level=logging.INFO,
    )
    address = os.getenv('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
    TrainingWorker(address).run()
# Start the worker only when this file is executed as the main program,
# not when it is imported.
if __name__ == '__main__':
    main()