training: persist Apollo optimizer state across /train calls
Optimizer state (momentum, variance estimates) now persists between training sessions:

- Saved to /tmp/apollo_optimizer_state.pt during checkpoint sync
- Restored on the next /train call if available
- Preserves training continuity for incremental learning

Previously each /train call started with fresh optimizer state, losing the accumulated gradient history.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
78fa4b639f
commit
039473d31f
2 changed files with 51 additions and 16 deletions
|
|
@ -215,6 +215,7 @@ a few hundred MB.
|
|||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by train_router to construct HF model with vLLM weight views. |
|
||||
| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync, restored on next /train call. Preserves training continuity across sessions. |
|
||||
| `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
|
||||
|
||||
### Moria (client)
|
||||
|
|
@ -224,11 +225,11 @@ a few hundred MB.
|
|||
| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
|
||||
| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
|
||||
|
||||
### In-memory (not persisted)
|
||||
### In-memory
|
||||
|
||||
| State | Location | Notes |
|
||||
|-------|----------|-------|
|
||||
| Apollo optimizer state | train_router._model | Created fresh each /train call. ~10GB for rank-256. Not persisted between requests. |
|
||||
| Apollo optimizer | train_router._optimizer | ~10GB for rank-256. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync. |
|
||||
| HF model with vLLM views | train_router._model | Lazy-loaded on first /train. Parameters point to vLLM's GPU memory. |
|
||||
|
||||
## Hyperparameters
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ class TrainResponse(BaseModel):
|
|||
# Module-level training state shared across /train requests.
_model: nn.Module | None = None  # HF model with vLLM weight views; lazy-loaded on first /train
_model_path: str | None = None  # filesystem path of the model dir being trained — presumably set on first /train; confirm
_initialized: bool = False  # NOTE(review): looks like a one-time-setup guard — confirm against the init path
_optimizer: Any = None  # Persisted Apollo optimizer

# File where the Apollo optimizer state_dict is saved during checkpoint sync
# and restored on the next /train call, so gradient history survives sessions.
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
|
||||
|
||||
|
||||
def _load_training_model() -> nn.Module:
|
||||
|
|
@ -134,18 +137,14 @@ async def handle_train(request: TrainRequest, raw_request: Request):
|
|||
)
|
||||
|
||||
|
||||
async def run_training(
|
||||
model: nn.Module,
|
||||
samples: list[dict[str, Any]],
|
||||
config: dict[str, Any],
|
||||
) -> list[float]:
|
||||
"""Run Apollo training on the given samples.
|
||||
|
||||
Each sample has:
|
||||
context_ids: token IDs for frozen context (no gradients)
|
||||
continuation_ids: token IDs for the decision we're training on
|
||||
"""
|
||||
def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
|
||||
"""Get existing optimizer or create new one. Persists state between calls."""
|
||||
global _optimizer
|
||||
from .optimizer import Apollo
|
||||
import os
|
||||
|
||||
if _optimizer is not None:
|
||||
return _optimizer
|
||||
|
||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
||||
apollo_params, standard_params = [], []
|
||||
|
|
@ -165,8 +164,8 @@ async def run_training(
|
|||
if not groups:
|
||||
raise ValueError("No trainable parameters found")
|
||||
|
||||
# Apollo settings from request config
|
||||
optimizer = Apollo(
|
||||
# Create optimizer
|
||||
_optimizer = Apollo(
|
||||
groups,
|
||||
lr=config.get('lr', 1e-5),
|
||||
rank=config.get('rank', 256),
|
||||
|
|
@ -179,9 +178,42 @@ async def run_training(
|
|||
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
||||
)
|
||||
|
||||
# Restore state if exists
|
||||
if os.path.exists(OPTIMIZER_STATE_PATH):
|
||||
try:
|
||||
state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
|
||||
_optimizer.load_state_dict(state)
|
||||
logger.info(f"[apollo] Restored optimizer state from {OPTIMIZER_STATE_PATH}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[apollo] Could not restore optimizer state: {e}")
|
||||
|
||||
logger.info(f"[apollo] Optimizer: {len(apollo_params)} apollo params, "
|
||||
f"{len(standard_params)} standard, "
|
||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||
f"state={_optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||
|
||||
return _optimizer
|
||||
|
||||
|
||||
def _save_optimizer_state():
    """Save optimizer state for persistence between /train calls.

    Writes ``_optimizer.state_dict()`` to ``OPTIMIZER_STATE_PATH``. Best-effort:
    failures are logged as warnings rather than raised, mirroring the restore
    path in ``_get_or_create_optimizer``. No-op if no optimizer exists yet, so
    a previously saved state file is never clobbered by an empty one.
    """
    global _optimizer
    if _optimizer is None:
        return
    import os

    tmp_path = OPTIMIZER_STATE_PATH + ".tmp"
    try:
        # Write to a temp file and atomically replace the target, so a crash
        # mid-save cannot leave a truncated state file for the next restore.
        torch.save(_optimizer.state_dict(), tmp_path)
        os.replace(tmp_path, OPTIMIZER_STATE_PATH)
        logger.info(f"[apollo] Saved optimizer state to {OPTIMIZER_STATE_PATH}")
    except Exception as e:
        logger.warning(f"[apollo] Could not save optimizer state: {e}")
|
||||
|
||||
|
||||
async def run_training(
|
||||
model: nn.Module,
|
||||
samples: list[dict[str, Any]],
|
||||
config: dict[str, Any],
|
||||
) -> list[float]:
|
||||
"""Run Apollo training on the given samples.
|
||||
|
||||
Each sample has:
|
||||
context_ids: token IDs for frozen context (no gradients)
|
||||
continuation_ids: token IDs for the decision we're training on
|
||||
"""
|
||||
optimizer = _get_or_create_optimizer(model, config)
|
||||
|
||||
loss_history = []
|
||||
|
||||
|
|
@ -250,6 +282,8 @@ def schedule_checkpoint_sync():
|
|||
if _model_path:
|
||||
from .checkpoint_sync import checkpoint_sync
|
||||
logger.info("[apollo] Starting checkpoint sync...")
|
||||
# Save optimizer state alongside model weights
|
||||
_save_optimizer_state()
|
||||
result = checkpoint_sync(_model_path)
|
||||
logger.info(f"[apollo] Checkpoint sync: {result['total_changed']/1e6:.2f} MB")
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue