apollo: rewrite optimizer from paper's math + add research analysis

Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r) — compensates for systematic
  ratio between compact and original space scaling factors
- Add norm-growth limiter (γ=1.01) — prevents loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW at 0 memory cost)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
This commit is contained in:
ProofOfConcept 2026-03-31 00:54:17 -04:00
parent 60e61555c7
commit ac9a9034fb
2 changed files with 390 additions and 60 deletions

View file

@ -1,46 +1,60 @@
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
"""Apollo optimizer — configurable-rank gradient scaling.
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).
For each parameter tensor, maintains:
- projected first moment (m): [m, rank] or [rank, n]
- projected second moment (v): same shape
- random projection matrix (regenerated from seed)
The core idea: AdamW's per-element learning rate scaling is redundant.
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch.
Key implementation details from the paper:
- Gradient scale factor α = (n/r) compensates for projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""
import math
import torch
from torch.optim import Optimizer
class Apollo(Optimizer):
"""Apollo: configurable-rank tensor-wise gradient scaling.
"""Apollo: configurable-rank gradient scaling optimizer.
rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
Higher ranks cost proportionally more memory but may improve
training quality for fine-grained behavioral fine-tuning.
rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
rank>1 is full Apollo (channel-wise scaling).
Args:
params: model parameters
lr: learning rate (default: 1e-4)
rank: projection rank (default: 1 = Apollo-Mini)
betas: coefficients for moment estimates (default: (0.9, 0.999))
eps: term for numerical stability (default: 1e-8)
rank: projection rank (default: 256)
betas: Adam momentum coefficients (default: (0.9, 0.999))
eps: numerical stability term (default: 1e-8)
weight_decay: decoupled weight decay (default: 0.01)
warmup_steps: linear warmup steps (default: 0)
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
warmup_steps: linear lr warmup steps (default: 0)
scale: gradient scale factor α. Default None = auto (n/r).
Paper uses 128 for Apollo-Mini.
proj_refresh: refresh projection matrix every T steps (default: 200)
norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
Set to None to disable.
"""
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
eps=1e-8, weight_decay=0.01, warmup_steps=0,
scale=None, proj_refresh=200, norm_growth_limit=1.01):
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
weight_decay=weight_decay,
warmup_steps=warmup_steps,
scale_type=scale_type)
scale=scale,
proj_refresh=proj_refresh,
norm_growth_limit=norm_growth_limit)
super().__init__(params, defaults)
@torch.no_grad()
@ -55,6 +69,9 @@ class Apollo(Optimizer):
beta1, beta2 = group['betas']
eps = group['eps']
weight_decay = group['weight_decay']
rank = group['rank']
proj_refresh = group['proj_refresh']
norm_growth_limit = group['norm_growth_limit']
for p in group['params']:
if p.grad is None:
@ -66,58 +83,75 @@ class Apollo(Optimizer):
# Initialize state
if len(state) == 0:
state['step'] = 0
state['seed'] = id(p) # deterministic per-param seed
state['seed'] = id(p) % (2**31)
# Determine projection dimension
rank = group['rank']
if grad.ndim >= 2 and min(grad.shape) >= rank:
if grad.shape[0] >= grad.shape[1]:
state['proj_dim'] = 'right'
moment_shape = (grad.shape[0], rank)
else:
state['proj_dim'] = 'left'
# Determine projection dimension (project along smaller dim)
if grad.shape[0] <= grad.shape[1]:
state['proj_dim'] = 'left' # P: [r, m], R = P @ G → [r, n]
state['m'] = grad.shape[0]
state['n'] = grad.shape[1]
moment_shape = (rank, grad.shape[1])
else:
state['proj_dim'] = 'right' # P: [r, n], R = G @ P^T → [m, r]
state['m'] = grad.shape[0]
state['n'] = grad.shape[1]
moment_shape = (grad.shape[0], rank)
state['exp_avg'] = torch.zeros(moment_shape,
device=p.device)
state['exp_avg_sq'] = torch.zeros(moment_shape,
device=p.device)
state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
state['has_proj'] = True
state['rank'] = rank
state['prev_scaled_norm'] = None
# Auto scale factor: α = √(smaller_dim / rank)
smaller_dim = min(grad.shape)
if group['scale'] is not None:
state['alpha'] = group['scale']
else:
state['alpha'] = math.sqrt(smaller_dim / rank)
else:
# 1D params (biases, norms): use standard Adam
# 1D or small params: standard Adam
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)
state['has_proj'] = False
state['step'] += 1
step = state['step']
# Learning rate warmup
if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']:
lr_scale = state['step'] / group['warmup_steps']
if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
lr_scale = step / group['warmup_steps']
else:
lr_scale = 1.0
if state['has_proj']:
rank = state['rank']
alpha = state['alpha']
# Generate deterministic random projection matrix
gen = torch.Generator(device=p.device)
gen.manual_seed(state['seed'] + state['step'])
# Generate projection matrix (refresh every proj_refresh steps)
if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
gen = torch.Generator(device=p.device)
gen.manual_seed(state['seed'] + step)
# Project gradient to low-rank
if state['proj_dim'] == 'right':
proj_mat = torch.randn(grad.shape[1], rank,
device=p.device,
generator=gen)
proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
proj_grad = grad @ proj_mat # [m, rank]
if state['proj_dim'] == 'left':
# P: [rank, m], normalized rows
P = torch.randn(rank, state['m'],
device=p.device, generator=gen)
P = P / (P.norm(dim=1, keepdim=True) + eps)
state['proj_matrix'] = P
else:
# P: [rank, n], normalized rows
P = torch.randn(rank, state['n'],
device=p.device, generator=gen)
P = P / (P.norm(dim=1, keepdim=True) + eps)
state['proj_matrix'] = P
P = state['proj_matrix']
# Project gradient to low-rank space
if state['proj_dim'] == 'left':
proj_grad = P @ grad # [rank, n]
else:
proj_mat = torch.randn(rank, grad.shape[0],
device=p.device,
generator=gen)
proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
proj_grad = proj_mat @ grad # [rank, n]
proj_grad = grad @ P.t() # [m, rank]
# Update moments in projected space
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
@ -125,29 +159,52 @@ class Apollo(Optimizer):
proj_grad, proj_grad, value=1 - beta2)
# Bias correction
bc1 = 1 - beta1 ** state['step']
bc2 = 1 - beta2 ** state['step']
bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2
# Adam update in projected space
adam_update = m_hat / (v_hat.sqrt() + eps)
# Tensor-wise scaling factor
scaling = adam_update.norm() / (proj_grad.norm() + eps)
# Compute scaling factor
if rank == 1:
# Tensor-wise: single scalar (Apollo-Mini)
scaling = adam_update.norm() / (proj_grad.norm() + eps)
scaled_grad = grad * (alpha * scaling)
else:
# Channel-wise: one factor per channel
if state['proj_dim'] == 'left':
# Channels are columns: scale along dim 1
s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(0))
else:
# Channels are rows: scale along dim 1
s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(1))
# Apply to full gradient
# Norm-growth limiter (equation 4)
if norm_growth_limit is not None:
current_norm = scaled_grad.norm()
if state['prev_scaled_norm'] is not None:
prev_norm = state['prev_scaled_norm']
if current_norm > norm_growth_limit * prev_norm:
scaled_grad = scaled_grad * (
norm_growth_limit * prev_norm / (current_norm + eps))
state['prev_scaled_norm'] = scaled_grad.norm().item()
# Apply update
step_size = lr * lr_scale
p.add_(grad.to(p.dtype) * (-step_size * scaling))
p.add_(scaled_grad.to(p.dtype), alpha=-step_size)
else:
# Standard Adam for 1D params
# Standard Adam for 1D / small params
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
state['exp_avg_sq'].mul_(beta2).addcmul_(
grad, grad, value=1 - beta2)
bc1 = 1 - beta1 ** state['step']
bc2 = 1 - beta2 ** state['step']
bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2