consciousness/training/apollo_mini.py
ProofOfConcept ac9a9034fb apollo: rewrite optimizer from paper's math + add research analysis
Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r) — compensates for systematic
  ratio between compact and original space scaling factors
- Add norm-growth limiter (γ=1.01) — prevents loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW at 0 memory cost)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
2026-03-31 00:54:17 -04:00

229 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Apollo optimizer — configurable-rank gradient scaling.
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).
The core idea: AdamW's per-element learning rate scaling is redundant.
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch.
Key implementation details from the paper:
- Gradient scale factor α = √(n/r), where n is the smaller dimension of the
  weight matrix and r the rank, compensates for the projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""
import math
import torch
from torch.optim import Optimizer
class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling optimizer.

    rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
    rank>1 is full Apollo (channel-wise scaling).

    Only exactly two-dimensional parameters whose smaller dimension is at
    least ``rank`` take the projected Apollo path. Every other parameter
    (1D, 3D+, or too small) is updated with standard Adam moments.

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank, must be >= 1 (default: 256)
        betas: Adam momentum coefficients (default: (0.9, 0.999))
        eps: numerical stability term (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear lr warmup steps (default: 0)
        scale: gradient scale factor α. Default None = auto
            √(smaller_dim / rank).
        proj_refresh: refresh projection matrix every T steps (default: 200).
            0 or None means the matrix drawn on the first step is kept.
        norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
            Set to None to disable.

    Raises:
        ValueError: if lr, rank, betas, eps or weight_decay are out of range.
    """

    # Stride between the seeds of consecutive parameters, so that
    # ``seed + step`` values of neighbouring parameters do not coincide.
    _SEED_STRIDE = 1_000_003

    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, warmup_steps=0,
                 scale=None, proj_refresh=200, norm_growth_limit=1.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if rank < 1:
            # rank == 0 would otherwise surface later as a ZeroDivisionError
            # when the automatic scale factor is computed.
            raise ValueError(f"Invalid rank: {rank}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps,
                        scale=scale,
                        proj_refresh=proj_refresh,
                        norm_growth_limit=norm_growth_limit)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Position of the parameter across all groups. Parameters without a
        # gradient are counted too, so the index (and the seed derived from
        # it) does not depend on which parameters received gradients.
        param_index = -1
        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            warmup_steps = group['warmup_steps']
            for p in group['params']:
                param_index += 1
                if p.grad is None:
                    continue
                grad = p.grad.float()
                state = self.state[p]
                if len(state) == 0:
                    self._apollo_init_state(state, p, grad, group, param_index)

                state['step'] += 1
                step = state['step']

                # Learning rate warmup (linear, counted per parameter)
                if warmup_steps > 0 and step <= warmup_steps:
                    lr_scale = step / warmup_steps
                else:
                    lr_scale = 1.0
                step_size = lr * lr_scale

                if state['has_proj']:
                    update = self._apollo_direction(state, p, grad, group)
                else:
                    update = self._adam_direction(state, grad, group)
                p.add_(update.to(p.dtype), alpha=-step_size)

                # Decoupled weight decay, applied after the gradient update
                if weight_decay > 0:
                    p.add_(p, alpha=-step_size * weight_decay)
        return loss

    def _apollo_init_state(self, state, p, grad, group, param_index):
        """Create the per-parameter state on the first step with a gradient."""
        rank = group['rank']
        state['step'] = 0
        # Seed from the parameter's position instead of id(p): a memory
        # address differs between runs and between processes, which made the
        # projections non-reproducible and not identical across replicas.
        state['seed'] = (param_index * self._SEED_STRIDE) % (2 ** 31)

        # The projection below is a plain matrix product, so it is only
        # defined for exactly 2D gradients; higher-rank tensors use Adam.
        if grad.ndim == 2 and min(grad.shape) >= rank:
            rows, cols = grad.shape
            state['m'] = rows
            state['n'] = cols
            # Project along the smaller dimension
            if rows <= cols:
                state['proj_dim'] = 'left'   # P: [r, m], R = P @ G -> [r, n]
                moment_shape = (rank, cols)
            else:
                state['proj_dim'] = 'right'  # P: [r, n], R = G @ P^T -> [m, r]
                moment_shape = (rows, rank)
            state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
            state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
            state['has_proj'] = True
            state['prev_scaled_norm'] = None
            # Auto scale factor: α = √(smaller_dim / rank)
            if group['scale'] is not None:
                state['alpha'] = group['scale']
            else:
                state['alpha'] = math.sqrt(min(rows, cols) / rank)
        else:
            # 1D, 3D+ or small params: standard Adam
            state['exp_avg'] = torch.zeros_like(grad)
            state['exp_avg_sq'] = torch.zeros_like(grad)
            state['has_proj'] = False

    def _apollo_projection(self, state, p, group):
        """Return the projection matrix, redrawing it when a refresh is due."""
        step = state['step']
        proj_refresh = group['proj_refresh']
        # None or 0 disables periodic refresh; the first-step draw is kept.
        refresh_due = (bool(proj_refresh) and proj_refresh > 0
                       and step % proj_refresh == 0)
        if step == 1 or refresh_due:
            gen = torch.Generator(device=p.device)
            gen.manual_seed(state['seed'] + step)
            in_dim = state['m'] if state['proj_dim'] == 'left' else state['n']
            P = torch.randn(group['rank'], in_dim,
                            device=p.device, generator=gen)
            # Rows are normalised to unit length.
            # NOTE(review): the paper appears to sample P ~ N(0, 1/r) without
            # row normalisation — confirm α = √(dim/r) is still the intended
            # compensation for this scaling.
            state['proj_matrix'] = P / (P.norm(dim=1, keepdim=True)
                                        + group['eps'])
        return state['proj_matrix']

    def _apollo_direction(self, state, p, grad, group):
        """Return the Apollo-scaled gradient for a projected 2D parameter."""
        beta1, beta2 = group['betas']
        eps = group['eps']
        norm_growth_limit = group['norm_growth_limit']
        step = state['step']
        alpha = state['alpha']
        left = state['proj_dim'] == 'left'

        P = self._apollo_projection(state, p, group)

        # Project gradient to low-rank space
        proj_grad = P @ grad if left else grad @ P.t()

        # Update moments in projected space
        state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(
            proj_grad, proj_grad, value=1 - beta2)

        # Bias-corrected Adam update in projected space
        m_hat = state['exp_avg'] / (1 - beta1 ** step)
        v_hat = state['exp_avg_sq'] / (1 - beta2 ** step)
        adam_update = m_hat / (v_hat.sqrt() + eps)

        # Scaling factor: ratio of Adam-update norm to raw projected norm
        if group['rank'] == 1:
            # Tensor-wise: single scalar (Apollo-Mini)
            scaling = adam_update.norm() / (proj_grad.norm() + eps)
            scaled_grad = grad * (alpha * scaling)
        elif left:
            # Channels are columns: one factor per column
            s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
            scaled_grad = grad * (alpha * s.unsqueeze(0))
        else:
            # Channels are rows: one factor per row
            s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
            scaled_grad = grad * (alpha * s.unsqueeze(1))

        # Norm-growth limiter (equation 4)
        if norm_growth_limit is not None:
            prev_norm = state['prev_scaled_norm']
            # A previous norm of exactly zero (e.g. after an all-zero
            # gradient) must not be used as a bound: γ * 0 would clamp every
            # later update to zero and the parameter would stop training.
            if prev_norm is not None and prev_norm > 0:
                current_norm = scaled_grad.norm()
                max_norm = norm_growth_limit * prev_norm
                if current_norm > max_norm:
                    scaled_grad = scaled_grad * (
                        max_norm / (current_norm + eps))
            state['prev_scaled_norm'] = scaled_grad.norm().item()
        return scaled_grad

    @staticmethod
    def _adam_direction(state, grad, group):
        """Return the bias-corrected Adam update for a non-projected param."""
        beta1, beta2 = group['betas']
        step = state['step']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        m_hat = state['exp_avg'] / (1 - beta1 ** step)
        v_hat = state['exp_avg_sq'] / (1 - beta2 ** step)
        return m_hat / (v_hat.sqrt() + group['eps'])

    def state_size_bytes(self):
        """Total optimizer state memory in bytes.

        Sums ``nelement() * element_size()`` over every tensor stored in the
        per-parameter state; non-tensor entries are not counted.
        """
        total = 0
        for state in self.state.values():
            if not isinstance(state, dict):
                continue
            total += sum(v.nelement() * v.element_size()
                         for v in state.values()
                         if isinstance(v, torch.Tensor))
        return total