diff --git a/training/apollo_mini.py b/training/apollo_mini.py index f86d1bb..61c3e44 100644 --- a/training/apollo_mini.py +++ b/training/apollo_mini.py @@ -8,9 +8,9 @@ For each parameter tensor, maintains: - projected second moment (v): same shape - random projection matrix (regenerated from seed) -rank=1 is Apollo-Mini (~50MB state for 27B model). -rank=2-16 costs proportionally more memory but is still negligible. -Compute cost of projection is <1% of forward+backward. +Default rank=256 (full Apollo). ~12.8GB state for 27B model (256x the +rank=1 footprint); projection compute also scales with rank — measure the +actual overhead at this setting. Captures gradient structure across 100+ behavioral training examples per batch. """ import torch @@ -35,7 +35,7 @@ class Apollo(Optimizer): scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise """ - def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8, + def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, warmup_steps=0, scale_type='tensor'): defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, weight_decay=weight_decay,