training: use rank 64, define as single constant

- DEFAULT_RANK = 64 in train_router.py
- All references use the constant, not magic numbers
- ~2.5GB optimizer state instead of ~10GB (back-of-the-envelope check below)
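
A sketch of where those figures come from (the layer shapes are illustrative,
not read from the real model): Apollo keeps two Adam moments of the rank-r
projected gradient per projected matrix, so state scales linearly with rank,
and (64/256) * ~10GB ≈ ~2.5GB.

    def apollo_state_gb(layer_shapes, rank, bytes_per_elem=4):
        # Two moments (exp_avg, exp_avg_sq) of the (rank, min_side) projected
        # gradient per 2-D parameter; the random projection itself is assumed
        # regenerable from a seed and is not counted here.
        total = 0
        for n, m in layer_shapes:
            total += 2 * rank * min(n, m) * bytes_per_elem
        return total / 1e9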

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-16 00:53:48 -04:00
parent 039473d31f
commit 68a2df2185
3 changed files with 16 additions and 16 deletions


@@ -8,9 +8,9 @@ Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
 these scaling factors using a low-rank auxiliary optimizer state based on
 pure random projection.
 
-Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
-compute overhead vs forward+backward. Captures gradient structure
-across 100+ behavioral training examples per batch.
+Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead
+vs forward+backward. Sufficient for behavioral training with diverse
+examples.
 
 Key implementation details from the paper:
 - Gradient scale factor α = (n/r) compensates for projection ratio
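
For orientation, a minimal sketch of the pure random projection the docstring
describes, with the α = n/r compensation applied afterwards (names, shapes,
and the 1/√rank normalization are illustrative assumptions, not this file's
actual code):

    import torch

    def project(grad: torch.Tensor, rank: int, seed: int) -> torch.Tensor:
        # Seeded Gaussian projection: (n, m) gradient -> (rank, m) auxiliary
        # gradient, whose Adam moments are the only extra optimizer state.
        n = grad.shape[0]
        gen = torch.Generator(device=grad.device).manual_seed(seed)
        proj = torch.randn(rank, n, generator=gen, device=grad.device,
                           dtype=grad.dtype) / rank ** 0.5
        return proj @ grad

    grad = torch.randn(4096, 4096)
    low_rank = project(grad, rank=64, seed=0)  # shape (64, 4096)
    alpha = grad.shape[0] / 64                 # α = n/r per the docstring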
@@ -34,7 +34,7 @@ class Apollo(Optimizer):
     Args:
         params: model parameters
         lr: learning rate (default: 1e-4)
-        rank: projection rank (default: 256)
+        rank: projection rank (default: 64)
         betas: Adam momentum coefficients (default: (0.9, 0.999))
         eps: numerical stability term (default: 1e-8)
         weight_decay: decoupled weight decay (default: 0.01)
@@ -46,7 +46,7 @@ class Apollo(Optimizer):
         Set to None to disable.
     """
 
-    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
+    def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999),
                  eps=1e-8, weight_decay=0.01, warmup_steps=0,
                  scale=None, proj_refresh=200, norm_growth_limit=1.01):
         defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
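
A hypothetical usage sketch of the constructor above (the import path and
model are assumed for illustration):

    import torch.nn as nn
    from apollo import Apollo  # module path assumed

    model = nn.Linear(4096, 4096)
    opt = Apollo(model.parameters(), lr=1e-4)        # rank now defaults to 64
    opt_full = Apollo(model.parameters(), rank=256)  # opt back into full Apollo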

train_router.py

@@ -42,6 +42,7 @@ _initialized: bool = False
 _optimizer: Any = None  # Persisted Apollo optimizer
 OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
+DEFAULT_RANK = 64
 
 
 def _load_training_model() -> nn.Module:
@@ -150,7 +151,7 @@ def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
     apollo_params, standard_params = [], []
     for p in model.parameters():
         if p.requires_grad:
-            if p.ndim >= 2 and min(p.shape) >= 256:
+            if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
                 apollo_params.append(p)
             else:
                 standard_params.append(p)
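
The min(p.shape) >= DEFAULT_RANK guard routes sub-rank tensors (biases, norm
weights) to the standard optimizer, since a rank-64 projection only saves
memory when both dimensions exceed the rank. A sketch with made-up shapes:

    import torch.nn as nn

    DEFAULT_RANK = 64  # mirrors the constant added above
    model = nn.Sequential(nn.Linear(4096, 4096), nn.LayerNorm(4096))
    for name, p in model.named_parameters():
        apollo = p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK
        print(name, tuple(p.shape), "apollo" if apollo else "standard")
    # 0.weight (4096, 4096) apollo
    # 0.bias (4096,) standard
    # 1.weight (4096,) standard
    # 1.bias (4096,) standard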
@@ -168,7 +169,7 @@ def _get_or_create_optimizer(model: nn.Module, config: dict[str, Any]):
     _optimizer = Apollo(
         groups,
         lr=config.get('lr', 1e-5),
-        rank=config.get('rank', 256),
+        rank=config.get('rank', DEFAULT_RANK),
         betas=tuple(config.get('betas', (0.9, 0.999))),
         eps=config.get('eps', 1e-8),
         weight_decay=config.get('weight_decay', 0.01),
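
With this change, a config that omits 'rank' falls back to DEFAULT_RANK while
an explicit value still wins. Hypothetical config dicts:

    config = {"lr": 2e-5}                    # rank -> DEFAULT_RANK (64)
    config_full = {"lr": 2e-5, "rank": 256}  # explicit override, full Apollo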