apollo: default rank 256 — 0.25% compute cost, captures gradient structure across 100+ examples
This commit is contained in:
parent
e1cd4fb0ab
commit
8e7b4a22db
1 changed files with 4 additions and 4 deletions
|
|
@ -8,9 +8,9 @@ For each parameter tensor, maintains:
|
||||||
- projected second moment (v): same shape
|
- projected second moment (v): same shape
|
||||||
- random projection matrix (regenerated from seed)
|
- random projection matrix (regenerated from seed)
|
||||||
|
|
||||||
rank=1 is Apollo-Mini (~50MB state for 27B model).
|
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
|
||||||
rank=2-16 costs proportionally more memory but is still negligible.
|
compute overhead vs forward+backward. Captures gradient structure
|
||||||
Compute cost of projection is <1% of forward+backward.
|
across 100+ behavioral training examples per batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -35,7 +35,7 @@ class Apollo(Optimizer):
|
||||||
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
|
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
|
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8,
|
||||||
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
|
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
|
||||||
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue