From e1cd4fb0abdba26b3bd48817909d9aa03635142a Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Mon, 30 Mar 2026 22:06:31 -0400 Subject: [PATCH] apollo: make rank configurable (default 1 = Mini, higher ranks for experimentation) --- training/apollo_mini.py | 60 +++++++++++++++++++++++---------------- training/apollo_worker.py | 5 ++-- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/training/apollo_mini.py b/training/apollo_mini.py index d299202..f86d1bb 100644 --- a/training/apollo_mini.py +++ b/training/apollo_mini.py @@ -1,38 +1,46 @@ -"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory. +"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory. Implements the core algorithm from "APOLLO: Approximated Gradient Scaling for Memory-Efficient LLM Optimization" (arXiv:2412.05270). For each parameter tensor, maintains: - - rank-1 projected first moment (m): [m, 1] or [1, n] - - rank-1 projected second moment (v): same shape - - fixed random projection vector (regenerated from seed) + - projected first moment (m): [m, rank] or [rank, n] + - projected second moment (v): same shape + - random projection matrix (regenerated from seed) -Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW). +rank=1 is Apollo-Mini (~50MB state for 27B model). +rank=2-16 costs proportionally more memory but is still negligible. +Compute cost of projection is <1% of forward+backward. """ import torch from torch.optim import Optimizer -class ApolloMini(Optimizer): - """Apollo-Mini: rank-1 tensor-wise gradient scaling. +class Apollo(Optimizer): + """Apollo: configurable-rank tensor-wise gradient scaling. + + rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance). + Higher ranks cost proportionally more memory but may improve + training quality for fine-grained behavioral fine-tuning. Args: params: model parameters lr: learning rate (default: 1e-4) + rank: projection rank (default: 1 = Apollo-Mini) betas: coefficients for moment estimates (default: (0.9, 0.999)) eps: term for numerical stability (default: 1e-8) weight_decay: decoupled weight decay (default: 0.01) warmup_steps: linear warmup steps (default: 0) - scale: scaling factor for projection (default: 128) + scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise """ - def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.01, warmup_steps=0, scale=128): - defaults = dict(lr=lr, betas=betas, eps=eps, + def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0.01, warmup_steps=0, scale_type='tensor'): + defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, weight_decay=weight_decay, - warmup_steps=warmup_steps, scale=scale) + warmup_steps=warmup_steps, + scale_type=scale_type) super().__init__(params, defaults) @torch.no_grad() @@ -61,21 +69,21 @@ class ApolloMini(Optimizer): state['seed'] = id(p) # deterministic per-param seed # Determine projection dimension - if grad.ndim >= 2: + rank = group['rank'] + if grad.ndim >= 2 and min(grad.shape) >= rank: if grad.shape[0] >= grad.shape[1]: - proj_shape = (grad.shape[1], 1) state['proj_dim'] = 'right' - moment_shape = (grad.shape[0], 1) + moment_shape = (grad.shape[0], rank) else: - proj_shape = (1, grad.shape[0]) state['proj_dim'] = 'left' - moment_shape = (1, grad.shape[1]) + moment_shape = (rank, grad.shape[1]) state['exp_avg'] = torch.zeros(moment_shape, device=p.device) state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device) state['has_proj'] = True + state['rank'] = rank else: # 1D params (biases, norms): use standard Adam state['exp_avg'] = torch.zeros_like(grad) @@ -91,23 +99,25 @@ class ApolloMini(Optimizer): lr_scale = 1.0 if state['has_proj']: - # Generate deterministic random projection vector + rank = state['rank'] + + # Generate deterministic random projection matrix gen = torch.Generator(device=p.device) gen.manual_seed(state['seed'] + state['step']) - # Project gradient to rank-1 + # Project gradient to low-rank if state['proj_dim'] == 'right': - proj_vec = torch.randn(grad.shape[1], 1, + proj_mat = torch.randn(grad.shape[1], rank, device=p.device, generator=gen) - proj_vec = proj_vec / (proj_vec.norm() + eps) - proj_grad = grad @ proj_vec # [m, 1] + proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps) + proj_grad = grad @ proj_mat # [m, rank] else: - proj_vec = torch.randn(1, grad.shape[0], + proj_mat = torch.randn(rank, grad.shape[0], device=p.device, generator=gen) - proj_vec = proj_vec / (proj_vec.norm() + eps) - proj_grad = proj_vec @ grad # [1, n] + proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps) + proj_grad = proj_mat @ grad # [rank, n] # Update moments in projected space state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1) diff --git a/training/apollo_worker.py b/training/apollo_worker.py index 1a7e622..d46fb55 100755 --- a/training/apollo_worker.py +++ b/training/apollo_worker.py @@ -309,7 +309,7 @@ class ApolloWorker: samples: List[Dict[str, str]], config: Dict[str, Any]) -> List[float]: """Run Apollo-Mini training on conversation decision points.""" - from apollo_mini import ApolloMini + from apollo_mini import Apollo from transformers import AutoTokenizer lr = config.get('learning_rate', self.config['learning_rate']) @@ -331,7 +331,8 @@ class ApolloWorker: if standard_params: groups.append({'params': standard_params}) - optimizer = ApolloMini(groups, lr=lr) + rank = config.get('apollo_rank', 1) + optimizer = Apollo(groups, lr=lr, rank=rank) logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, " f"{len(standard_params)} standard, " f"state={optimizer.state_size_bytes()/1e6:.1f}MB")