"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory. Implements the core algorithm from "APOLLO: Approximated Gradient Scaling for Memory-Efficient LLM Optimization" (arXiv:2412.05270). For each parameter tensor, maintains: - projected first moment (m): [m, rank] or [rank, n] - projected second moment (v): same shape - random projection matrix (regenerated from seed) Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25% compute overhead vs forward+backward. Captures gradient structure across 100+ behavioral training examples per batch. """ import torch from torch.optim import Optimizer class Apollo(Optimizer): """Apollo: configurable-rank tensor-wise gradient scaling. rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance). Higher ranks cost proportionally more memory but may improve training quality for fine-grained behavioral fine-tuning. Args: params: model parameters lr: learning rate (default: 1e-4) rank: projection rank (default: 1 = Apollo-Mini) betas: coefficients for moment estimates (default: (0.9, 0.999)) eps: term for numerical stability (default: 1e-8) weight_decay: decoupled weight decay (default: 0.01) warmup_steps: linear warmup steps (default: 0) scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise """ def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, warmup_steps=0, scale_type='tensor'): defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, weight_decay=weight_decay, warmup_steps=warmup_steps, scale_type=scale_type) super().__init__(params, defaults) @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: lr = group['lr'] beta1, beta2 = group['betas'] eps = group['eps'] weight_decay = group['weight_decay'] for p in group['params']: if p.grad is None: continue grad = p.grad.float() state = self.state[p] # Initialize state if len(state) == 0: state['step'] = 0 state['seed'] = id(p) # deterministic per-param seed # Determine projection dimension rank = group['rank'] if grad.ndim >= 2 and min(grad.shape) >= rank: if grad.shape[0] >= grad.shape[1]: state['proj_dim'] = 'right' moment_shape = (grad.shape[0], rank) else: state['proj_dim'] = 'left' moment_shape = (rank, grad.shape[1]) state['exp_avg'] = torch.zeros(moment_shape, device=p.device) state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device) state['has_proj'] = True state['rank'] = rank else: # 1D params (biases, norms): use standard Adam state['exp_avg'] = torch.zeros_like(grad) state['exp_avg_sq'] = torch.zeros_like(grad) state['has_proj'] = False state['step'] += 1 # Learning rate warmup if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']: lr_scale = state['step'] / group['warmup_steps'] else: lr_scale = 1.0 if state['has_proj']: rank = state['rank'] # Generate deterministic random projection matrix gen = torch.Generator(device=p.device) gen.manual_seed(state['seed'] + state['step']) # Project gradient to low-rank if state['proj_dim'] == 'right': proj_mat = torch.randn(grad.shape[1], rank, device=p.device, generator=gen) proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps) proj_grad = grad @ proj_mat # [m, rank] else: proj_mat = torch.randn(rank, grad.shape[0], device=p.device, generator=gen) proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps) proj_grad = proj_mat @ grad # [rank, n] # Update moments in projected space state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1) state['exp_avg_sq'].mul_(beta2).addcmul_( proj_grad, proj_grad, value=1 - beta2) # Bias correction bc1 = 1 - beta1 ** state['step'] bc2 = 1 - beta2 ** state['step'] m_hat = state['exp_avg'] / bc1 v_hat = state['exp_avg_sq'] / bc2 # Adam update in projected space adam_update = m_hat / (v_hat.sqrt() + eps) # Tensor-wise scaling factor scaling = adam_update.norm() / (proj_grad.norm() + eps) # Apply to full gradient step_size = lr * lr_scale p.add_(grad.to(p.dtype) * (-step_size * scaling)) else: # Standard Adam for 1D params state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1) state['exp_avg_sq'].mul_(beta2).addcmul_( grad, grad, value=1 - beta2) bc1 = 1 - beta1 ** state['step'] bc2 = 1 - beta2 ** state['step'] m_hat = state['exp_avg'] / bc1 v_hat = state['exp_avg_sq'] / bc2 update = m_hat / (v_hat.sqrt() + eps) step_size = lr * lr_scale p.add_(update.to(p.dtype), alpha=-step_size) # Decoupled weight decay if weight_decay > 0: p.add_(p, alpha=-lr * lr_scale * weight_decay) return loss def state_size_bytes(self): """Total optimizer state memory in bytes.""" total = 0 for state in self.state.values(): if isinstance(state, dict): for v in state.values(): if isinstance(v, torch.Tensor): total += v.nelement() * v.element_size() return total