"""Apollo optimizer — configurable-rank gradient scaling.

Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).

The core idea: AdamW's per-element learning rate scaling is redundant.
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.

Default rank=256 (full Apollo). ~10GB state for a 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch.

Key implementation details from the paper:

- Gradient scale factor α = √(n/r) compensates for the projection ratio
- Norm-growth limiter (γ=1.01) prevents early-training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""
|
|
|
|
|
|
|
2026-03-31 00:54:17 -04:00
|
|
|
|
import math
|
|
|
|
|
|
|
2026-03-30 22:02:37 -04:00
|
|
|
|
import torch
|
|
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-30 22:06:31 -04:00
|
|
|
|
class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling optimizer.

    rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
    rank>1 is full Apollo (channel-wise scaling).

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank (default: 256)
        betas: Adam momentum coefficients (default: (0.9, 0.999))
        eps: numerical stability term (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear lr warmup steps (default: 0)
        scale: gradient scale factor α. Default None = auto √(n/r).
            Paper uses √128 for Apollo-Mini.
        proj_refresh: refresh projection matrix every T steps (default: 200)
        norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
            Set to None to disable.
    """

    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, warmup_steps=0,
                 scale=None, proj_refresh=200, norm_growth_limit=1.01):
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps,
                        scale=scale,
                        proj_refresh=proj_refresh,
                        norm_growth_limit=norm_growth_limit)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (executed under ``enable_grad``).

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            rank = group['rank']
            proj_refresh = group['proj_refresh']
            norm_growth_limit = group['norm_growth_limit']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Work in float32 regardless of parameter dtype.
                grad = p.grad.float()
                state = self.state[p]

                # ---- Initialize state on first encounter of this param ----
                if len(state) == 0:
                    state['step'] = 0
                    # Per-parameter seed so each tensor draws an independent
                    # projection. NOTE(review): id(p) is not stable across
                    # runs/processes, so projections are not reproducible
                    # run-to-run — confirm this is acceptable.
                    state['seed'] = id(p) % (2**31)

                    # BUGFIX: was `grad.ndim >= 2`. The projection path below
                    # indexes only shape[0]/shape[1] and performs strictly 2-D
                    # matmuls (`P @ grad`, `grad @ P.t()`), so a 3-D+ tensor
                    # (e.g. a conv weight, which always qualifies at rank=1)
                    # would either raise a matmul shape error or broadcast to
                    # a shape that mismatches the (rank, n) moment buffers.
                    # Only true matrices take the low-rank path; everything
                    # else falls back to dense Adam.
                    if grad.ndim == 2 and min(grad.shape) >= rank:
                        # Project along the smaller dim so moments are
                        # (rank × larger_dim).
                        if grad.shape[0] <= grad.shape[1]:
                            state['proj_dim'] = 'left'   # P: [r, m], R = P @ G → [r, n]
                            state['m'] = grad.shape[0]
                            state['n'] = grad.shape[1]
                            moment_shape = (rank, grad.shape[1])
                        else:
                            state['proj_dim'] = 'right'  # P: [r, n], R = G @ P^T → [m, r]
                            state['m'] = grad.shape[0]
                            state['n'] = grad.shape[1]
                            moment_shape = (grad.shape[0], rank)

                        state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
                        state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
                        state['has_proj'] = True
                        state['prev_scaled_norm'] = None

                        # Auto scale factor: α = √(smaller_dim / rank)
                        smaller_dim = min(grad.shape)
                        if group['scale'] is not None:
                            state['alpha'] = group['scale']
                        else:
                            state['alpha'] = math.sqrt(smaller_dim / rank)
                    else:
                        # 1D or small params: standard Adam
                        state['exp_avg'] = torch.zeros_like(grad)
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                        state['has_proj'] = False

                state['step'] += 1
                step = state['step']

                # ---- Linear learning-rate warmup ----
                if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
                    lr_scale = step / group['warmup_steps']
                else:
                    lr_scale = 1.0

                if state['has_proj']:
                    alpha = state['alpha']

                    # ---- (Re)generate projection matrix ----
                    # A fresh random projection every `proj_refresh` steps;
                    # the projected moments are carried across refreshes
                    # (pure random projection needs no state migration).
                    if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
                        gen = torch.Generator(device=p.device)
                        # Seed depends on step so each refresh draws a new P.
                        gen.manual_seed(state['seed'] + step)

                        if state['proj_dim'] == 'left':
                            # P: [rank, m], normalized rows
                            P = torch.randn(rank, state['m'],
                                            device=p.device, generator=gen)
                            P = P / (P.norm(dim=1, keepdim=True) + eps)
                            state['proj_matrix'] = P
                        else:
                            # P: [rank, n], normalized rows
                            P = torch.randn(rank, state['n'],
                                            device=p.device, generator=gen)
                            P = P / (P.norm(dim=1, keepdim=True) + eps)
                            state['proj_matrix'] = P

                    P = state['proj_matrix']

                    # ---- Project gradient to low-rank space ----
                    if state['proj_dim'] == 'left':
                        proj_grad = P @ grad      # [rank, n]
                    else:
                        proj_grad = grad @ P.t()  # [m, rank]

                    # ---- Adam moments, maintained in projected space ----
                    state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        proj_grad, proj_grad, value=1 - beta2)

                    # Bias correction
                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    # Adam update in projected space
                    adam_update = m_hat / (v_hat.sqrt() + eps)

                    # ---- Scaling factor: ratio of Adam-update norm to raw
                    # projected-gradient norm, applied to the full-rank grad.
                    if rank == 1:
                        # Tensor-wise: single scalar (Apollo-Mini)
                        scaling = adam_update.norm() / (proj_grad.norm() + eps)
                        scaled_grad = grad * (alpha * scaling)
                    else:
                        # Channel-wise: one factor per channel along the
                        # un-projected dimension.
                        if state['proj_dim'] == 'left':
                            # Channels are columns: reduce over the rank dim
                            s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
                            scaled_grad = grad * (alpha * s.unsqueeze(0))
                        else:
                            # Channels are rows: reduce over the rank dim
                            s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
                            scaled_grad = grad * (alpha * s.unsqueeze(1))

                    # ---- Norm-growth limiter (paper eq. 4): cap this step's
                    # scaled-grad norm at γ × previous step's norm.
                    if norm_growth_limit is not None:
                        current_norm = scaled_grad.norm()
                        if state['prev_scaled_norm'] is not None:
                            prev_norm = state['prev_scaled_norm']
                            if current_norm > norm_growth_limit * prev_norm:
                                scaled_grad = scaled_grad * (
                                    norm_growth_limit * prev_norm / (current_norm + eps))
                        state['prev_scaled_norm'] = scaled_grad.norm().item()

                    # ---- Apply update ----
                    step_size = lr * lr_scale
                    p.add_(scaled_grad.to(p.dtype), alpha=-step_size)

                else:
                    # ---- Standard Adam for 1D / small / non-matrix params ----
                    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        grad, grad, value=1 - beta2)

                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    update = m_hat / (v_hat.sqrt() + eps)
                    step_size = lr * lr_scale
                    p.add_(update.to(p.dtype), alpha=-step_size)

                # ---- Decoupled (AdamW-style) weight decay ----
                if weight_decay > 0:
                    p.add_(p, alpha=-lr * lr_scale * weight_decay)

        return loss

    def state_size_bytes(self):
        """Total optimizer state memory in bytes."""
        total = 0
        for state in self.state.values():
            if isinstance(state, dict):
                for v in state.values():
                    if isinstance(v, torch.Tensor):
                        total += v.nelement() * v.element_size()
        return total
|