"""Apollo optimizer — configurable-rank gradient scaling.

Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).

The core idea: AdamW's per-element learning-rate scaling is redundant;
channel-wise or tensor-wise scaling is sufficient. Apollo approximates these
scaling factors using a low-rank auxiliary optimizer state built from a pure
random projection.

Default rank=256 (full Apollo). ~10GB state for a 27B model, <0.25% compute
overhead vs. forward+backward. Captures gradient structure across 100+
behavioral training examples per batch.

Key implementation details from the paper:
- Random projection P has i.i.d. N(0, 1/r) entries, so projected row/column
  norms approximate the original ones (Johnson–Lindenstrauss)
- Gradient scale factor α = √(n/r) compensates for the projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""

import math

import torch
from torch.optim import Optimizer


class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling optimizer.

    rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
    rank>1 is full Apollo (channel-wise scaling).

    Parameters with ndim >= 2 whose smaller (matrix-view) dimension is at
    least ``rank`` receive the low-rank Apollo update; tensors with ndim > 2
    (e.g. conv kernels) are treated as ``[shape[0], -1]`` matrices. All other
    parameters (1-D, or too small for ``rank``) fall back to standard Adam.

    Projection seeds are derived from each parameter's position in
    ``param_groups`` (not from object identity), so projections are
    reproducible across runs and identical across data-parallel ranks.

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank (default: 256)
        betas: Adam momentum coefficients (default: (0.9, 0.999))
        eps: numerical stability term (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear lr warmup steps (default: 0)
        scale: gradient scale factor α. Default None = auto √(n/r), where n
            is the smaller matrix dimension. Paper uses √128 for Apollo-Mini.
        proj_refresh: refresh projection matrix every T steps (default: 200)
        norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
            Set to None to disable.
    """

    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.01, warmup_steps=0, scale=None,
                 proj_refresh=200, norm_growth_limit=1.01):
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup_steps=warmup_steps,
                        scale=scale, proj_refresh=proj_refresh,
                        norm_growth_limit=norm_growth_limit)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Running index over *all* parameters (including those without a
        # gradient this step) so the projection seed of a parameter depends
        # only on its position, never on its memory address.
        param_index = -1
        for group in self.param_groups:
            for p in group['params']:
                param_index += 1
                if p.grad is None:
                    continue
                self._update_param(p, group, param_index)
        return loss

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _update_param(self, p, group, param_index):
        """Apply one Apollo (or Adam fallback) update to parameter ``p``."""
        grad = p.grad.float()
        state = self.state[p]
        if len(state) == 0:
            self._init_state(state, grad, group, param_index)

        state['step'] += 1
        step = state['step']
        step_size = group['lr'] * self._warmup_scale(step, group['warmup_steps'])

        if state['has_proj']:
            update = self._apollo_update(state, grad, group, step)
        else:
            update = self._adam_update(state, grad, group, step)

        # Decoupled (AdamW-style) weight decay on the pre-update weights,
        # matching torch.optim.AdamW.
        weight_decay = group['weight_decay']
        if weight_decay > 0:
            p.mul_(1 - step_size * weight_decay)
        p.add_(update.to(p.dtype), alpha=-step_size)

    @staticmethod
    def _warmup_scale(step, warmup_steps):
        """Linear warmup multiplier in (0, 1] for the given 1-based step."""
        if warmup_steps > 0 and step <= warmup_steps:
            return step / warmup_steps
        return 1.0

    @staticmethod
    def _as_matrix(grad):
        """View ``grad`` as 2-D; tensors with ndim > 2 flatten trailing dims."""
        if grad.ndim > 2:
            return grad.reshape(grad.shape[0], -1)
        return grad

    def _init_state(self, state, grad, group, param_index):
        """Create optimizer state for a parameter on its first update."""
        rank = group['rank']
        state['step'] = 0
        # Deterministic, process-independent seed. The large odd multiplier
        # spaces per-parameter seeds apart so that ``seed + step`` (used at
        # each projection refresh) does not collide between parameters.
        state['seed'] = param_index * 1_000_003

        mat = self._as_matrix(grad)
        if grad.ndim >= 2 and min(mat.shape) >= rank:
            m, n = mat.shape
            state['m'] = m
            state['n'] = n
            # Project along the smaller dimension.
            if m <= n:
                state['proj_dim'] = 'left'   # P: [r, m], R = P @ G -> [r, n]
                moment_shape = (rank, n)
            else:
                state['proj_dim'] = 'right'  # P: [r, n], R = G @ P^T -> [m, r]
                moment_shape = (m, rank)
            state['exp_avg'] = torch.zeros(moment_shape, device=grad.device)
            state['exp_avg_sq'] = torch.zeros(moment_shape, device=grad.device)
            state['has_proj'] = True
            state['prev_scaled_norm'] = None
            # Auto scale factor: α = √(smaller_dim / rank)
            if group['scale'] is not None:
                state['alpha'] = group['scale']
            else:
                state['alpha'] = math.sqrt(min(m, n) / rank)
        else:
            # 1-D or small params: standard Adam
            state['exp_avg'] = torch.zeros_like(grad)
            state['exp_avg_sq'] = torch.zeros_like(grad)
            state['has_proj'] = False

    @staticmethod
    def _projection_matrix(state, rank, step, device):
        """Sample the [rank, proj_dim] Gaussian random projection for ``step``.

        Entries are i.i.d. N(0, 1/rank), so E[||P x||^2] = ||x||^2: projected
        norms are unbiased estimates of the originals and α alone restores
        the AdamW update scale.
        """
        gen = torch.Generator(device=device)
        gen.manual_seed(state['seed'] + step)
        dim = state['m'] if state['proj_dim'] == 'left' else state['n']
        P = torch.randn(rank, dim, device=device, generator=gen)
        return P / math.sqrt(rank)

    def _apollo_update(self, state, grad, group, step):
        """Return the Apollo-scaled gradient (same shape as ``grad``)."""
        rank = group['rank']
        beta1, beta2 = group['betas']
        eps = group['eps']
        proj_refresh = group['proj_refresh']

        # Regenerate the projection on the first step and every proj_refresh.
        if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
            state['proj_matrix'] = self._projection_matrix(
                state, rank, step, grad.device)
        P = state['proj_matrix']

        mat = self._as_matrix(grad)
        left = state['proj_dim'] == 'left'
        # Project gradient to low-rank space
        proj_grad = P @ mat if left else mat @ P.t()

        # Adam moments in projected space, with bias correction
        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        exp_avg.mul_(beta1).add_(proj_grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(proj_grad, proj_grad, value=1 - beta2)
        m_hat = exp_avg / (1 - beta1 ** step)
        v_hat = exp_avg_sq / (1 - beta2 ** step)
        adam_update = m_hat / (v_hat.sqrt() + eps)

        alpha = state['alpha']
        if rank == 1:
            # Tensor-wise: single scalar (Apollo-Mini)
            scaling = adam_update.norm() / (proj_grad.norm() + eps)
            scaled = mat * (alpha * scaling)
        elif left:
            # Channels are the n columns: one factor per column,
            # broadcast down the rows.
            s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
            scaled = mat * (alpha * s.unsqueeze(0))
        else:
            # Channels are the m rows: one factor per row,
            # broadcast across the columns.
            s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
            scaled = mat * (alpha * s.unsqueeze(1))

        norm_growth_limit = group['norm_growth_limit']
        if norm_growth_limit is not None:
            scaled = self._limit_norm_growth(state, scaled, norm_growth_limit, eps)

        return scaled.reshape(grad.shape)

    @staticmethod
    def _limit_norm_growth(state, scaled, limit, eps):
        """Norm-growth limiter: cap ||update_t|| at limit * ||update_{t-1}||.

        Computed entirely on-device (no ``.item()`` / Python branch on a
        tensor), so it does not force a host-device sync per parameter.
        """
        current = scaled.norm()
        prev = state['prev_scaled_norm']
        if prev is not None:
            # Older checkpoints stored a Python float; accept both.
            prev = torch.as_tensor(prev, dtype=current.dtype, device=current.device)
            factor = (limit * prev / (current + eps)).clamp(max=1.0)
            scaled = scaled * factor
            current = current * factor
        state['prev_scaled_norm'] = current
        return scaled

    @staticmethod
    def _adam_update(state, grad, group, step):
        """Return the bias-corrected Adam direction for 1-D / small params."""
        beta1, beta2 = group['betas']
        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        m_hat = exp_avg / (1 - beta1 ** step)
        v_hat = exp_avg_sq / (1 - beta2 ** step)
        return m_hat / (v_hat.sqrt() + group['eps'])

    def state_size_bytes(self):
        """Total optimizer state memory in bytes (sum over all state tensors)."""
        total = 0
        for state in self.state.values():
            if isinstance(state, dict):
                for v in state.values():
                    if isinstance(v, torch.Tensor):
                        total += v.nelement() * v.element_size()
        return total