training: persist Apollo optimizer state across /train calls

Optimizer state (momentum, variance estimates) now persists between
training sessions:

- Saved to /tmp/apollo_optimizer_state.pt during checkpoint sync
- Restored on next /train call if available
- Preserves training continuity for incremental learning

Previously, each /train call started with fresh optimizer state,
losing accumulated gradient history.
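
For illustration, the pattern is a plain state_dict round-trip (a sketch
with placeholder helper names; the real code lives in train_router as
_get_or_create_optimizer / _save_optimizer_state below):

    import os
    import torch

    STATE_PATH = "/tmp/apollo_optimizer_state.pt"

    def persist(opt: torch.optim.Optimizer) -> None:
        # state_dict() carries the per-parameter buffers
        # (momentum, variance estimates).
        torch.save(opt.state_dict(), STATE_PATH)

    def restore(opt: torch.optim.Optimizer) -> None:
        # No-op on the first ever run; later runs pick up saved buffers.
        if os.path.exists(STATE_PATH):
            opt.load_state_dict(torch.load(STATE_PATH, weights_only=False))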

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-16 00:51:58 -04:00
parent 78fa4b639f
commit 039473d31f
2 changed files with 51 additions and 16 deletions

@@ -215,6 +215,7 @@ a few hundred MB.
 | File | Purpose |
 |------|---------|
 | `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. |
+| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. |
 | `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
 
 ### Moria (client)
@@ -224,11 +225,11 @@ a few hundred MB.
 | `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
 | `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
 
-### In-memory (not persisted)
+### In-memory
 
 | State | Location | Notes |
 |-------|----------|-------|
-| Apollo optimizer state | train_router._model | Created fresh each /train call. ~10GB for rank-256. Not persisted between requests. |
+| Apollo optimizer | train_router._optimizer | ~10GB for rank-256. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. |
 | HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. |
 
 ## Hyperparameters

@@ -39,6 +39,9 @@ class TrainResponse(BaseModel):
 _model: nn.Module | None = None
 _model_path: str | None = None
 _initialized: bool = False
+_optimizer: Any = None  # Persisted Apollo optimizer
+
+OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
 
 
 def _load_training_model() -> nn.Module:
@@ -134,18 +137,14 @@ async def handle_train(request: TrainRequest, raw_request: Request):
     )
 
 
-async def run_training(
-    model: nn.Module,
-    samples: list[dict[str, Any]],
-    config: dict[str, Any],
-) -> list[float]:
-    """Run Apollo training on the given samples.
-
-    Each sample has:
-        context_ids: token IDs for frozen context (no gradients)
-        continuation_ids: token IDs for the decision we're training on
-    """
+def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
+    """Get existing optimizer or create new one. Persists state between calls."""
+    global _optimizer
     from .optimizer import Apollo
+    import os
 
+    if _optimizer is not None:
+        return _optimizer
+
     # Build parameter groups (Apollo for 2D+, standard for small/1D)
     apollo_params, standard_params = [], []
@@ -165,8 +164,8 @@ async def run_training(
     if not groups:
         raise ValueError("No trainable parameters found")
 
-    # Apollo settings from request config
-    optimizer = Apollo(
+    # Create optimizer
+    _optimizer = Apollo(
         groups,
         lr=config.get('lr', 1e-5),
         rank=config.get('rank', 256),
@@ -179,9 +178,42 @@
         norm_growth_limit=config.get('norm_growth_limit', 1.01),
     )
 
+    # Restore state if exists
+    if os.path.exists(OPTIMIZER_STATE_PATH):
+        try:
+            state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
+            _optimizer.load_state_dict(state)
+            logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
+        except Exception as e:
+            logger.warning(f"[apollo] Could not restore optimizer state: {e}")
+
     logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                 f"{len(standard_params)} standard, "
-                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
+                f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
+
+    return _optimizer
+
+
+def _save_optimizer_state():
+    """Save optimizer state for persistence between /train calls."""
+    global _optimizer
+    if _optimizer is not None:
+        torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH)
+        logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
+
+
+async def run_training(
+    model: nn.Module,
+    samples: list[dict[str, Any]],
+    config: dict[str, Any],
+) -> list[float]:
+    """Run Apollo training on the given samples.
+
+    Each sample has:
+        context_ids: token IDs for frozen context (no gradients)
+        continuation_ids: token IDs for the decision we're training on
+    """
+    optimizer = _get_or_create_optimizer(model, config)
 
     loss_history = []
@@ -250,6 +282,8 @@ def schedule_checkpoint_sync():
         if _model_path:
             from .checkpoint_sync import checkpoint_sync
             logger.info("[apollo] Starting checkpoint sync...")
+            # Save optimizer state alongside model weights
+            _save_optimizer_state()
             result = checkpoint_sync(_model_path)
             logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
     except Exception as e:
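
End to end, the lifecycle this commit sets up looks roughly like the
sketch below. It is illustrative only: torch.optim.SGD stands in for
Apollo, get_or_create/checkpoint_sync_hook are stand-ins for the helpers
above, and the demo path avoids clobbering the real state file.

    import os
    import torch
    import torch.nn as nn

    STATE_PATH = "/tmp/optimizer_state_demo.pt"  # stand-in for /tmp/apollo_optimizer_state.pt
    _optimizer = None  # module-level singleton, as in train_router

    def get_or_create(model: nn.Module) -> torch.optim.Optimizer:
        global _optimizer
        if _optimizer is not None:  # same process: reuse the live object
            return _optimizer
        _optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
        if os.path.exists(STATE_PATH):  # fresh process: restore saved buffers
            _optimizer.load_state_dict(torch.load(STATE_PATH, weights_only=False))
        return _optimizer

    def checkpoint_sync_hook() -> None:
        if _optimizer is not None:  # persist alongside the model weights
            torch.save(_optimizer.state_dict(), STATE_PATH)

    model = nn.Linear(8, 8)
    opt = get_or_create(model)          # first /train call
    model(torch.randn(4, 8)).pow(2).mean().backward()
    opt.step()                          # momentum buffers now exist
    checkpoint_sync_hook()              # ...and are written to disk
    assert get_or_create(model) is opt  # next call reuses the same optimizer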