apollo: default rank 256 — 0.25% compute cost, captures gradient structure across 100+ examples

This commit is contained in:
ProofOfConcept 2026-03-30 22:16:34 -04:00
parent e1cd4fb0ab
commit 8e7b4a22db

View file

@@ -8,9 +8,9 @@ For each parameter tensor, maintains:
 - projected second moment (v): same shape
 - random projection matrix (regenerated from seed)
-rank=1 is Apollo-Mini (~50MB state for 27B model).
-rank=2-16 costs proportionally more memory but is still negligible.
-Compute cost of projection is <1% of forward+backward.
+Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
+compute overhead vs forward+backward. Captures gradient structure
+across 100+ behavioral training examples per batch.
 """
 import torch
import torch import torch
@@ -35,7 +35,7 @@ class Apollo(Optimizer):
         scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
     """
-    def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
+    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
         defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                         weight_decay=weight_decay,