From 039473d31f49024c341f8d03e92a80112a3a4bdd Mon Sep 17 00:00:00 2001
From: Kent Overstreet
Date: Thu, 16 Apr 2026 00:51:58 -0400
Subject: [PATCH] training: persist Apollo optimizer state across /train calls

Optimizer state (momentum, variance estimates) now persists between
training sessions:

- Saved to /tmp/apollo_optimizer_state.pt during checkpoint sync
- Restored on next /train call if available
- Preserves training continuity for incremental learning

Previously each /train call started with fresh optimizer state, losing
accumulated gradient history.

Co-Authored-By: Proof of Concept
---
 training/DESIGN.md                     |  5 ++-
 training/apollo_plugin/train_router.py | 62 ++++++++++++++++++++------
 2 files changed, 51 insertions(+), 16 deletions(-)

diff --git a/training/DESIGN.md b/training/DESIGN.md
index 00ca499..5b7fe30 100644
--- a/training/DESIGN.md
+++ b/training/DESIGN.md
@@ -215,6 +215,7 @@ a few hundred MB.
 | File | Purpose |
 |------|---------|
 | `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. |
+| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. |
 | `/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
 
 ### Moria (client)
@@ -224,11 +225,11 @@ a few hundred MB.
 | `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
 | `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
 
-### In-memory (not persisted)
+### In-memory
 
 | State | Location | Notes |
 |-------|----------|-------|
-| Apollo optimizer state | train_router._model | Created fresh each /train call. ~10GB for rank-256. Not persisted between requests. |
+| Apollo optimizer | train_router._optimizer | ~10GB for rank-256. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. |
 | HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. |
 
 ## Hyperparameters
diff --git a/training/apollo_plugin/train_router.py b/training/apollo_plugin/train_router.py
index 6fa4883..4857162 100644
--- a/training/apollo_plugin/train_router.py
+++ b/training/apollo_plugin/train_router.py
@@ -39,6 +39,9 @@ class TrainResponse(BaseModel):
 _model: nn.Module | None = None
 _model_path: str | None = None
 _initialized: bool = False
+_optimizer: Any = None  # Persisted Apollo optimizer
+
+OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
 
 
 def _load_training_model() -> nn.Module:
@@ -134,18 +137,14 @@ async def handle_train(request: TrainRequest, raw_request: Request):
     )
 
 
-async def run_training(
-    model: nn.Module,
-    samples: list[dict[str, Any]],
-    config: dict[str, Any],
-) -> list[float]:
-    """Run Apollo training on the given samples.
-
-    Each sample has:
-      context_ids: token IDs for frozen context (no gradients)
-      continuation_ids: token IDs for the decision we're training on
-    """
+def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
+    """Get existing optimizer or create new one. Persists state between calls."""
+    global _optimizer
     from .optimizer import Apollo
+    import os
+
+    if _optimizer is not None:
+        return _optimizer
 
     # Build parameter groups (Apollo for 2D+, standard for small/1D)
     apollo_params, standard_params = [], []
@@ -165,8 +164,8 @@
     if not groups:
         raise ValueError("No trainable parameters found")
 
-    # Apollo settings from request config
-    optimizer = Apollo(
+    # Create optimizer
+    _optimizer = Apollo(
         groups,
         lr=config.get('lr', 1e-5),
         rank=config.get('rank', 256),
@@ -179,9 +178,42 @@
         norm_growth_limit=config.get('norm_growth_limit', 1.01),
     )
 
+    # Restore state if exists
+    if os.path.exists(OPTIMIZER_STATE_PATH):
+        try:
+            state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
+            _optimizer.load_state_dict(state)
+            logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
+        except Exception as e:
+            logger.warning(f"[apollo] Could not restore optimizer state: {e}")
+
     logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
                 f"{len(standard_params)} standard, "
-                f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
+                f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
+
+    return _optimizer
+
+
+def _save_optimizer_state():
+    """Save optimizer state for persistence between /train calls."""
+    global _optimizer
+    if _optimizer is not None:
+        torch.save(_optimizer.state_dict(), OPTIMIZER_STATE_PATH)
+        logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
+
+
+async def run_training(
+    model: nn.Module,
+    samples: list[dict[str, Any]],
+    config: dict[str, Any],
+) -> list[float]:
+    """Run Apollo training on the given samples.
+
+    Each sample has:
+      context_ids: token IDs for frozen context (no gradients)
+      continuation_ids: token IDs for the decision we're training on
+    """
+    optimizer = _get_or_create_optimizer(model, config)
 
     loss_history = []
 
@@ -250,6 +282,8 @@ def schedule_checkpoint_sync():
         if _model_path:
             from .checkpoint_sync import checkpoint_sync
             logger.info("[apollo] Starting checkpoint sync...")
+            # Save optimizer state alongside model weights
+            _save_optimizer_state()
             result = checkpoint_sync(_model_path)
             logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
     except Exception as e:
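
Note for reviewers, not part of the patch: a minimal standalone sketch of the
lazy create-or-restore pattern this change introduces. torch.optim.SGD,
STATE_PATH, get_or_create, and save_state are stand-ins for the Apollo
optimizer and the train_router helpers above, so the round trip can be run
without the plugin.

# Standalone sketch of the persist/restore round trip (assumed names).
import os

import torch
import torch.nn as nn

STATE_PATH = "/tmp/sketch_optimizer_state.pt"
_optimizer = None  # module-level cache, analogous to train_router._optimizer


def get_or_create(model: nn.Module):
    """Return the cached optimizer, creating and restoring it on first call."""
    global _optimizer
    if _optimizer is not None:
        return _optimizer
    _optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
    if os.path.exists(STATE_PATH):
        # Restores momentum buffers accumulated by a previous run.
        _optimizer.load_state_dict(torch.load(STATE_PATH, weights_only=False))
    return _optimizer


def save_state():
    """Persist accumulated gradient history for the next training call."""
    if _optimizer is not None:
        torch.save(_optimizer.state_dict(), STATE_PATH)


if __name__ == "__main__":
    model = nn.Linear(4, 2)
    opt = get_or_create(model)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    save_state()  # the next process's get_or_create() picks this back up

As in the patch, creation is keyed on the module-level cache rather than the
state file, so a live process keeps its in-memory optimizer and only a fresh
process pays the cost of loading state from disk.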