consciousness/training/apollo_plugin/optimizer.py
Kent Overstreet 68a2df2185 training: use rank 64, define as single constant
- DEFAULT_RANK = 64 in train_router.py
- All references use the constant, not magic numbers
- ~2.5GB optimizer state instead of ~10GB

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-16 02:04:26 -04:00

229 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Apollo optimizer — configurable-rank gradient scaling.
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).
The core idea: AdamW's per-element learning rate scaling is redundant.
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.
Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead
vs forward+backward. Sufficient for behavioral training with diverse
examples.
Key implementation details from the paper:
- Gradient scale factor α = √(n/r) compensates for projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""
import math
import torch
from torch.optim import Optimizer
class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling optimizer.

    rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
    rank>1 is full Apollo (channel-wise scaling).

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank (default: 64)
        betas: Adam momentum coefficients (default: (0.9, 0.999))
        eps: numerical stability term (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear lr warmup steps (default: 0)
        scale: gradient scale factor α. Default None = auto √(n/r).
            Paper uses √128 for Apollo-Mini.
        proj_refresh: refresh projection matrix every T steps (default: 200)
        norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
            Set to None to disable.
    """

    def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, warmup_steps=0,
                 scale=None, proj_refresh=200, norm_growth_limit=1.01):
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps,
                        scale=scale,
                        proj_refresh=proj_refresh,
                        norm_growth_limit=norm_growth_limit)
        super().__init__(params, defaults)

    @staticmethod
    def _fresh_projection(state, rank, eps, device):
        """Draw a new row-normalized random projection matrix P: [rank, d].

        Seeded from (per-param seed + current step) so the matrix is
        deterministic for a given step but differs at each refresh.
        d is the dimension being projected away: m for 'left', n for 'right'.
        """
        d = state['m'] if state['proj_dim'] == 'left' else state['n']
        gen = torch.Generator(device=device)
        gen.manual_seed(state['seed'] + state['step'])
        P = torch.randn(rank, d, device=device, generator=gen)
        return P / (P.norm(dim=1, keepdim=True) + eps)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The closure's loss, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            rank = group['rank']
            proj_refresh = group['proj_refresh']
            norm_growth_limit = group['norm_growth_limit']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                state = self.state[p]

                # ---- Lazy per-parameter state init ----------------------
                if len(state) == 0:
                    state['step'] = 0
                    state['seed'] = id(p) % (2**31)
                    # BUGFIX: the projected path assumes an exactly 2-D
                    # gradient — `P @ grad` / `grad @ P.t()` on an ndim>2
                    # tensor broadcasts as a batched matmul whose result
                    # cannot be accumulated into the 2-D moment buffers.
                    # The old check was `grad.ndim >= 2`, which sent conv
                    # kernels etc. down the projection path and crashed;
                    # those now fall back to standard Adam below.
                    if grad.ndim == 2 and min(grad.shape) >= rank:
                        state['m'], state['n'] = grad.shape
                        # Project along the smaller dimension.
                        if grad.shape[0] <= grad.shape[1]:
                            state['proj_dim'] = 'left'   # R = P @ G → [r, n]
                            moment_shape = (rank, grad.shape[1])
                        else:
                            state['proj_dim'] = 'right'  # R = G @ P^T → [m, r]
                            moment_shape = (grad.shape[0], rank)
                        state['exp_avg'] = torch.zeros(moment_shape,
                                                       device=p.device)
                        state['exp_avg_sq'] = torch.zeros(moment_shape,
                                                          device=p.device)
                        state['has_proj'] = True
                        state['prev_scaled_norm'] = None
                        # α = √(smaller_dim / rank) compensates for the
                        # projection ratio, unless overridden by `scale`.
                        if group['scale'] is not None:
                            state['alpha'] = group['scale']
                        else:
                            state['alpha'] = math.sqrt(min(grad.shape) / rank)
                    else:
                        # 1-D, small, or higher-rank params: standard Adam.
                        state['exp_avg'] = torch.zeros_like(grad)
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                        state['has_proj'] = False

                state['step'] += 1
                step = state['step']

                # Linear learning-rate warmup.
                if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
                    lr_scale = step / group['warmup_steps']
                else:
                    lr_scale = 1.0

                if state['has_proj']:
                    alpha = state['alpha']

                    # Refresh projection every `proj_refresh` steps (and
                    # generate it on the very first step).
                    if step == 1 or (proj_refresh > 0
                                     and step % proj_refresh == 0):
                        state['proj_matrix'] = self._fresh_projection(
                            state, rank, eps, p.device)
                    P = state['proj_matrix']

                    # Project the gradient into the low-rank space.
                    if state['proj_dim'] == 'left':
                        proj_grad = P @ grad        # [rank, n]
                    else:
                        proj_grad = grad @ P.t()    # [m, rank]

                    # Adam moments live entirely in the projected space —
                    # this is the memory saving: 2 * rank * dim, not 2*m*n.
                    state['exp_avg'].mul_(beta1).add_(proj_grad,
                                                      alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        proj_grad, proj_grad, value=1 - beta2)
                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2
                    adam_update = m_hat / (v_hat.sqrt() + eps)

                    # Scaling factor: ratio of the Adam update's norm to the
                    # raw projected gradient's norm, applied to the full
                    # (unprojected) gradient.
                    if rank == 1:
                        # Tensor-wise: one scalar (Apollo-Mini).
                        scaling = adam_update.norm() / (proj_grad.norm() + eps)
                        scaled_grad = grad * (alpha * scaling)
                    elif state['proj_dim'] == 'left':
                        # Channel-wise; channels are columns of G.
                        s = (adam_update.norm(dim=0)
                             / (proj_grad.norm(dim=0) + eps))
                        scaled_grad = grad * (alpha * s.unsqueeze(0))
                    else:
                        # Channel-wise; channels are rows of G.
                        s = (adam_update.norm(dim=1)
                             / (proj_grad.norm(dim=1) + eps))
                        scaled_grad = grad * (alpha * s.unsqueeze(1))

                    # Norm-growth limiter (paper eq. 4): cap step-to-step
                    # growth of the scaled gradient norm at γ to prevent
                    # early-training instability.
                    if norm_growth_limit is not None:
                        current_norm = scaled_grad.norm()
                        prev_norm = state['prev_scaled_norm']
                        if (prev_norm is not None
                                and current_norm > norm_growth_limit * prev_norm):
                            scaled_grad = scaled_grad * (
                                norm_growth_limit * prev_norm
                                / (current_norm + eps))
                        state['prev_scaled_norm'] = scaled_grad.norm().item()

                    p.add_(scaled_grad.to(p.dtype), alpha=-lr * lr_scale)
                else:
                    # Standard Adam for non-projected parameters.
                    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        grad, grad, value=1 - beta2)
                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    update = (state['exp_avg'] / bc1) / (
                        (state['exp_avg_sq'] / bc2).sqrt() + eps)
                    p.add_(update.to(p.dtype), alpha=-lr * lr_scale)

                # Decoupled (AdamW-style) weight decay, applied directly to
                # the weights rather than folded into the gradient.
                if weight_decay > 0:
                    p.add_(p, alpha=-lr * lr_scale * weight_decay)

        return loss

    def state_size_bytes(self):
        """Total optimizer state memory in bytes (tensor entries only)."""
        total = 0
        for state in self.state.values():
            if isinstance(state, dict):
                for v in state.values():
                    if isinstance(v, torch.Tensor):
                        total += v.nelement() * v.element_size()
        return total