training: integrate /train into vLLM process (no separate daemon)

Remove standalone worker.py daemon. Training now runs inside vLLM:

- train_router.py: FastAPI router patched into vLLM's build_app()
- /train served on same port as /completions, /score
- Lazy-loads HF model with vLLM weight views on first request
- HOGWILD training: no pause, weights updated in-place

The previous architecture had a separate daemon on port 8080 that
communicated with vLLM via pause/resume endpoints. This was wrong -
training should run in-process, sharing GPU memory directly.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-16 00:48:05 -04:00
parent 2f08149fab
commit 7e7e9a4b69
6 changed files with 320 additions and 542 deletions

View file

@ -22,25 +22,29 @@ The training signal comes from two sources:
│ │
│ ┌──────────────────────────────────────────────┐ │
│ │ Model Weights (54GB, bf16) │ │
│ │ Shared via CUDA IPC │ │
│ │ Shared: vLLM inference + HF training │ │
│ └──────────────┬──────────────┬────────────────┘ │
│ │ │ │
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
│ │ vLLM (inference)│ │ Apollo (training) │ │
│ │ vLLM (inference)│ │ HF model (training) │ │
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
│ │ Serves requests │ │ Optimizer state ~10GB │ │
│ │ Never paused │ │ Activations ~10GB │ │
│ │ /completions │ │ Optimizer state ~10GB │ │
│ │ /score │ │ Views into vLLM weights │ │
│ │ /train ────────┼──┼─► Apollo optimizer │ │
│ └─────────────────┘ └────────────────────────┘ │
└─────────────────────────────────────────────────────┘
Moria B200
Single vLLM process serves everything
No separate daemon - /train is a vLLM route
Moria B200 (vLLM)
┌──────────────────┐ ┌──────────────────┐
│ Training signal │ HTTP │ Apollo worker │
│ agent │──────────>│ daemon │
│ │ │
│ Dream loop │ │ Checkpoint sync
│ (generates │ │ (mmap + diff,
│ scenarios) │ │ every 10 min)
│ Training signal │ HTTP │ /completions
│ agent │──────────>│ /score
│ │ │ /train
│ Dream loop │ │
│ (generates │ │ Checkpoint sync
│ scenarios) │ │ (10 min batched)
└──────────────────┘ └──────────────────┘
```
@ -220,34 +224,30 @@ a few hundred MB.
## Components
### Built ✓
- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256)
- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking)
- `optimizer.py` — Apollo optimizer (configurable rank, default 256)
- `train_router.py` — /train endpoint, runs in vLLM process
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
- `training_example.py` — tokenization with chat template
- `vllm_export_hook.py` — source patch for IPC handle export
- `checkpoint/` — Rust tool for mmap + diff checkpoint sync
- `export_hook.py` — vLLM plugin hook for IPC handle export
- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
### To build
- **Dream loop → training bridge**: connect dream output to Apollo
- **Dream loop → training bridge**: connect dream output to /train
- **Training-signal agent**: flags moments in conversation logs
- **Instruction stripping**: remove scaffolding from training examples
- **Quality monitoring**: track model capability over time
- **HF model forward pass integration**: wire into apollo_worker
## Files
```
training/
DESIGN.md — this document
apollo_mini.py — Apollo optimizer
apollo_worker.py — HTTP training daemon
weight_mapping.py — vLLM ↔ HF weight views
training_example.py — tokenization helpers
export_weights.py — standalone weight export (unused)
vllm_export_hook.py — vLLM source patch for IPC export
start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch)
train.py — standalone training script (alternative)
checkpoint/
Cargo.toml — Rust checkpoint tool
src/main.rs — mmap + diff sync
DESIGN.md — this document
pyproject.toml — package config, vLLM plugin entry point
apollo_plugin/
__init__.py — plugin registration
export_hook.py — patches vLLM to export IPC handles
train_router.py — /train endpoint (FastAPI router)
optimizer.py — Apollo optimizer
weight_mapping.py — vLLM ↔ HF weight views
checkpoint_sync.py — mmap + diff sync to safetensors
steering.py — steering vector extraction (experimental)
```

View file

@ -1,8 +1,8 @@
"""Apollo training plugin for vLLM.
Enables continuous fine-tuning alongside live inference by:
1. Exporting CUDA IPC handles for weight sharing
2. Providing a training worker daemon (/train endpoint)
1. Exporting CUDA IPC handles for weight sharing (export_hook)
2. Adding /train endpoint to vLLM's HTTP server (train_router)
3. Block-level checkpoint sync to safetensors files
Install: pip install -e /path/to/training
@ -10,8 +10,10 @@ Then vLLM auto-loads via entry point.
"""
from .export_hook import _patch_model_runner
from .train_router import _patch_api_server
def register():
"""Called by vLLM's plugin loader on startup."""
_patch_model_runner()
_patch_api_server()

View file

@ -59,6 +59,10 @@ def _patch_model_runner():
result = original_load(self, *args, **kwargs)
try:
export_model_weights(self.model_runner.model)
# Set model path for training router
model_path = self.vllm_config.model_config.model
from .train_router import set_model_path
set_model_path(model_path)
except Exception as e:
print(f"[apollo] Failed to export weights: {e}")
return result

View file

@ -0,0 +1,282 @@
"""Training endpoint for vLLM - runs Apollo training in-process.
Patches vLLM's build_app() to add /train route. Training runs HOGWILD
style - no pause needed, weights updated in-place while inference continues.
"""
import logging
from datetime import datetime
from typing import Any
import torch
import torch.nn as nn
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter()
class TrainingSample(BaseModel):
context_ids: list[int]
continuation_ids: list[int]
class TrainRequest(BaseModel):
training_data: dict[str, Any] # {"samples": [...], "config": {...}}
class TrainResponse(BaseModel):
job_id: str
status: str
training_samples: int
loss_history: list[float]
# Global reference to HF model with vLLM weight views
_model: nn.Module | None = None
_model_path: str | None = None
_initialized: bool = False
def _load_training_model() -> nn.Module:
"""Load HF model with weights pointing to vLLM's GPU memory.
Uses CUDA IPC handles exported by export_hook to create an HF model
whose parameters share GPU memory with vLLM's model.
"""
from .weight_mapping import load_hf_model_with_vllm_weights
from .export_hook import HANDLE_PATH
handles = torch.load(HANDLE_PATH, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
model = load_hf_model_with_vllm_weights(vllm_params, _model_path)
model.train()
return model
def _ensure_initialized():
"""Lazy-initialize the training model on first /train request."""
global _model, _initialized
if _initialized:
return
if _model_path is None:
raise RuntimeError("Model path not set - export_hook may not have run")
logger.info("[apollo] Loading HF model with vLLM weight views...")
_model = _load_training_model()
_initialized = True
logger.info("[apollo] Training model ready")
def set_model_path(path: str):
"""Set model path for training. Called by export_hook after model load."""
global _model_path
_model_path = path
logger.info(f"[apollo] Model path set: {path}")
@router.post("/train")
async def handle_train(request: TrainRequest, raw_request: Request):
"""Handle training request - runs Apollo training on provided samples."""
global _model
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={"error": f"Training not available: {e}"},
status_code=503,
)
try:
training_data = request.training_data
samples = training_data.get("samples", [])
config = training_data.get("config", {})
if not samples:
return JSONResponse(
content={"error": "No training samples provided"},
status_code=400,
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info(f"[apollo] Starting training job {job_id} with {len(samples)} samples")
# Run training
loss_history = await run_training(_model, samples, config)
logger.info(f"[apollo] Training job {job_id} completed, final loss: {loss_history[-1]:.4f}")
# Schedule checkpoint sync (batched, 10 min delay)
schedule_checkpoint_sync()
return JSONResponse(content={
"job_id": job_id,
"status": "completed",
"training_samples": len(samples),
"loss_history": loss_history,
})
except Exception as e:
logger.exception(f"[apollo] Training failed: {e}")
return JSONResponse(
content={"error": str(e)},
status_code=500,
)
async def run_training(
model: nn.Module,
samples: list[dict[str, Any]],
config: dict[str, Any],
) -> list[float]:
"""Run Apollo training on the given samples.
Each sample has:
context_ids: token IDs for frozen context (no gradients)
continuation_ids: token IDs for the decision we're training on
"""
from .optimizer import Apollo
# Build parameter groups (Apollo for 2D+, standard for small/1D)
apollo_params, standard_params = [], []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= 256:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
if not groups:
raise ValueError("No trainable parameters found")
# Apollo settings from request config
optimizer = Apollo(
groups,
lr=config.get('lr', 1e-5),
rank=config.get('rank', 256),
betas=tuple(config.get('betas', (0.9, 0.999))),
eps=config.get('eps', 1e-8),
weight_decay=config.get('weight_decay', 0.01),
warmup_steps=config.get('warmup_steps', 0),
scale=config.get('scale'),
proj_refresh=config.get('proj_refresh', 200),
norm_growth_limit=config.get('norm_growth_limit', 1.01),
)
logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
loss_history = []
for i, sample in enumerate(samples):
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
outputs = model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(f"[apollo] Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
return loss_history
# Checkpoint sync scheduling
_checkpoint_task = None
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
def schedule_checkpoint_sync():
"""Schedule checkpoint sync after delay (batched)."""
global _checkpoint_task
import asyncio
if _checkpoint_task is not None:
# Already scheduled
return
async def do_sync():
global _checkpoint_task
try:
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
if _model_path:
from .checkpoint_sync import checkpoint_sync
logger.info("[apollo] Starting checkpoint sync...")
result = checkpoint_sync(_model_path)
logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
except Exception as e:
logger.error(f"[apollo] Checkpoint sync failed: {e}")
finally:
_checkpoint_task = None
_checkpoint_task = asyncio.create_task(do_sync())
logger.info(f"[apollo] Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS//60} min")
def attach_router(app: FastAPI):
"""Attach training router to FastAPI app."""
app.include_router(router)
logger.info("[apollo] Training router attached")
def _patch_api_server():
"""Patch vLLM's build_app to include our training router."""
from vllm.entrypoints.openai import api_server
original_build_app = api_server.build_app
def patched_build_app(*args, **kwargs):
app = original_build_app(*args, **kwargs)
attach_router(app)
return app
api_server.build_app = patched_build_app
logger.info("[apollo] API server patched for /train endpoint")

View file

@ -1,509 +0,0 @@
#!/usr/bin/env python3
"""
Apollo Mini Training Daemon
This daemon:
1. Listens over HTTPS for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference
Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
"""
import asyncio
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum
import torch
import torch.nn as nn
from aiohttp import web
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')
class TrainingStatus(Enum):
PENDING = "pending"
PAUSING_VLLM = "pausing_vllm"
TRAINING = "training"
SAVING_CHECKPOINT = "saving_checkpoint"
RESUMING_VLLM = "resuming_vllm"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class TrainingJob:
job_id: str
status: TrainingStatus
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
model_path: Optional[str] = None
checkpoint_path: Optional[str] = None
training_samples: int = 0
loss_history: List[float] = field(default_factory=list)
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
'job_id': self.job_id,
'status': self.status.value,
'created_at': self.created_at.isoformat(),
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'model_path': self.model_path,
'checkpoint_path': self.checkpoint_path,
'training_samples': self.training_samples,
'loss_history': self.loss_history,
'error': self.error,
}
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
class ApolloWorker:
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
self.config = self._load_config(config_path)
self.jobs: Dict[str, TrainingJob] = {}
self.vllm_paused = False
self.app = web.Application()
self._setup_routes()
self._checkpoint_timer: Optional[asyncio.Task] = None
def _load_config(self, config_path: str) -> Dict[str, Any]:
"""Load configuration from file or use defaults."""
default_config = {
'host': '0.0.0.0',
'port': 8080,
'vllm_socket': '/tmp/vllm_control.sock',
'model_path': '/home/ubuntu/models/Qwen3.5-27B',
'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
'max_training_samples': 100,
'learning_rate': 1e-5,
'batch_size': 1,
}
if os.path.exists(config_path):
with open(config_path, 'r') as f:
user_config = json.load(f)
default_config.update(user_config)
Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
return default_config
def _setup_routes(self):
"""Setup HTTP routes."""
self.app.router.add_post('/train', self.handle_train_request)
self.app.router.add_get('/status/{job_id}', self.handle_status_request)
self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
self.app.router.add_get('/health', self.handle_health_check)
async def handle_health_check(self, request: web.Request) -> web.Response:
"""Health check endpoint."""
return web.json_response({
'status': 'healthy',
'vllm_paused': self.vllm_paused,
'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
})
async def handle_train_request(self, request: web.Request) -> web.Response:
"""Handle training request from poc-agent."""
try:
data = await request.json()
# Validate required fields
if 'training_data' not in data:
return web.json_response(
{'error': 'Missing training_data field'},
status=400
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
job = TrainingJob(
job_id=job_id,
status=TrainingStatus.PENDING,
created_at=datetime.now(),
model_path=self.config['model_path']
)
self.jobs[job_id] = job
# Start training in background
asyncio.create_task(self.execute_training(job, data))
return web.json_response({
'job_id': job_id,
'status': 'accepted',
'message': 'Training job started'
})
except Exception as e:
logger.error(f"Error handling train request: {e}")
return web.json_response(
{'error': str(e)},
status=500
)
async def handle_status_request(self, request: web.Request) -> web.Response:
"""Get training job status."""
job_id = request.match_info['job_id']
if job_id not in self.jobs:
return web.json_response(
{'error': 'Job not found'},
status=404
)
job = self.jobs[job_id]
return web.json_response(job.to_dict())
async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
"""List available checkpoints."""
checkpoint_dir = Path(self.config['checkpoint_dir'])
checkpoints = []
if checkpoint_dir.exists():
for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
checkpoints.append({
'filename': checkpoint_file.name,
'path': str(checkpoint_file),
'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
'size': checkpoint_file.stat().st_size
})
return web.json_response({'checkpoints': checkpoints})
async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
"""Execute the training pipeline."""
try:
logger.info(f"Starting training job {job.job_id}")
job.started_at = datetime.now()
# Step 1: Pause vLLM
job.status = TrainingStatus.PAUSING_VLLM
logger.info("Pausing vLLM...")
await self.pause_vllm()
self.vllm_paused = True
# Step 2: Load model and prepare for training
job.status = TrainingStatus.TRAINING
logger.info("Loading model and preparing for training...")
# Load model (this would be the actual Qwen3.5-27B model)
# For now, we'll use a placeholder
model = await self.load_model_for_training()
# Step 3: Run APOLLO-Mini training
logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
# Extract training samples
samples = training_data['samples']
job.training_samples = len(samples)
# Run training loop
loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
job.loss_history = loss_history
# Step 4: Save checkpoint
job.status = TrainingStatus.SAVING_CHECKPOINT
logger.info("Saving checkpoint...")
checkpoint_path = await self.save_checkpoint(model, job)
job.checkpoint_path = checkpoint_path
# Step 5: Resume vLLM
job.status = TrainingStatus.RESUMING_VLLM
logger.info("Resuming vLLM...")
await self.resume_vllm()
self.vllm_paused = False
# Mark job as completed
job.status = TrainingStatus.COMPLETED
job.completed_at = datetime.now()
logger.info(f"Training job {job.job_id} completed successfully")
# Schedule checkpoint sync (batched — won't duplicate if timer pending)
self.schedule_checkpoint_sync()
except Exception as e:
logger.error(f"Training job {job.job_id} failed: {e}")
job.status = TrainingStatus.FAILED
job.error = str(e)
job.completed_at = datetime.now()
# Try to resume vLLM if it was paused
if self.vllm_paused:
try:
await self.resume_vllm()
self.vllm_paused = False
except Exception as resume_error:
logger.error(f"Failed to resume vLLM after training error: {resume_error}")
async def pause_vllm(self):
"""Pause vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/pause_generation",
json={"mode": "keep", "clear_cache": False},
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM paused")
except Exception as e:
logger.warning(f"Failed to pause vLLM: {e}")
async def resume_vllm(self):
"""Resume vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/resume_generation",
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM resumed")
except Exception as e:
logger.warning(f"Failed to resume vLLM: {e}")
def schedule_checkpoint_sync(self):
"""Schedule a checkpoint sync in 10 minutes, if not already scheduled.
This batches multiple training runs into a single sync the timer
resets only when no timer is pending.
"""
if self._checkpoint_timer is not None:
logger.debug("Checkpoint sync already scheduled, skipping")
return
self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")
async def _checkpoint_sync_after_delay(self):
"""Wait then sync — the actual timer task."""
try:
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
await self._do_checkpoint_sync()
except asyncio.CancelledError:
logger.debug("Checkpoint sync cancelled")
finally:
self._checkpoint_timer = None
async def _do_checkpoint_sync(self):
"""Execute the checkpoint sync."""
try:
from apollo_plugin.checkpoint_sync import checkpoint_sync
logger.info("Starting checkpoint sync...")
result = checkpoint_sync(
self.config['model_path'],
self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
)
changed_mb = result['total_changed'] / 1e6
logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
except Exception as e:
logger.error(f"Checkpoint sync failed: {e}")
async def load_model_for_training(self) -> nn.Module:
"""Load HF model with weights pointing to vLLM's GPU memory.
Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
views (narrowing merged weights into separate q/k/v/z etc.), and
constructs the HF model around those views. No weight copying
all parameters share vLLM's GPU memory.
"""
handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
model_path = self.config['model_path']
# Import vLLM weights via CUDA IPC
logger.info(f"Importing vLLM weights from {handle_path}")
handles = torch.load(handle_path, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
logger.info(f"Imported {len(vllm_params)} parameters")
# Map vLLM merged layout → HF separate layout (views, no copies)
from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
logger.info("HF model constructed with vLLM weight views")
return model
async def run_apollo_training(self, model: nn.Module,
samples: List[Dict[str, Any]],
config: Dict[str, Any]) -> List[float]:
"""Run Apollo-Mini training on conversation decision points.
Each sample has:
context_ids: token IDs for frozen context (no gradients)
continuation_ids: token IDs for the decision we're training on
"""
from apollo_plugin.optimizer import Apollo
# Build parameter groups (Apollo for 2D+, standard for small/1D)
apollo_params, standard_params = [], []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= 2:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
# Apollo settings from request config, falling back to server defaults
optimizer = Apollo(
groups,
lr=config.get('lr', self.config.get('learning_rate', 1e-5)),
rank=config.get('rank', 256),
betas=tuple(config.get('betas', (0.9, 0.999))),
eps=config.get('eps', 1e-8),
weight_decay=config.get('weight_decay', 0.01),
warmup_steps=config.get('warmup_steps', 0),
scale=config.get('scale'), # None = auto
proj_refresh=config.get('proj_refresh', 200),
norm_growth_limit=config.get('norm_growth_limit', 1.01),
)
rank = config.get('rank', 256)
lr = config.get('lr', self.config.get('learning_rate', 1e-5))
logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
loss_history = []
for i, sample in enumerate(samples):
# context_ids: frozen (forward only, no gradients)
# continuation_ids: the decision we're training on
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
# Forward through context (no gradients)
outputs = model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits # [1, cont_len, vocab]
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
logger.info(f"Training done: {len(samples)} examples, "
f"final loss={loss_history[-1]:.4f}")
return loss_history
async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
"""Save model checkpoint in HuggingFace safetensors format."""
from safetensors.torch import save_file
import shutil
checkpoint_dir = Path(self.config['checkpoint_dir'])
date_str = datetime.now().strftime('%Y-%m-%d')
out_dir = checkpoint_dir / date_str
out_dir.mkdir(parents=True, exist_ok=True)
# Save weights
tensors = {name: p.data.contiguous().cpu()
for name, p in model.named_parameters()}
save_path = out_dir / "model.safetensors"
save_file(tensors, str(save_path))
# Copy config files
config_dir = Path(self.config['model_path'])
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
'special_tokens_map.json']:
src = config_dir / f
if src.exists():
shutil.copy2(src, out_dir / f)
# Save training metadata
meta = {
'job_id': job.job_id,
'training_samples': job.training_samples,
'loss_history': job.loss_history,
'timestamp': datetime.now().isoformat(),
}
with open(out_dir / 'training-meta.json', 'w') as f:
json.dump(meta, f, indent=2)
# Update latest symlink
latest = checkpoint_dir / 'latest'
if latest.is_symlink():
latest.unlink()
latest.symlink_to(date_str)
size_gb = save_path.stat().st_size / 1e9
logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
return str(out_dir)
async def run(self):
"""Run the daemon."""
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, self.config['host'], self.config['port'])
await site.start()
logger.info("Apollo Worker is running")
# Keep running
while True:
await asyncio.sleep(3600) # Sleep for an hour
def main():
worker = ApolloWorker()
asyncio.run(worker.run())
if __name__ == '__main__':
main()

View file

@ -20,7 +20,6 @@ dev = ["pytest"]
apollo = "apollo_plugin:register"
[project.scripts]
apollo-worker = "apollo_plugin.worker:main"
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
[tool.setuptools.packages.find]