diff --git a/training/apollo_mini.py b/training/apollo_mini.py index 61c3e44..166ae3a 100644 --- a/training/apollo_mini.py +++ b/training/apollo_mini.py @@ -1,46 +1,60 @@ -"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory. +"""Apollo optimizer — configurable-rank gradient scaling. -Implements the core algorithm from "APOLLO: Approximated Gradient Scaling -for Memory-Efficient LLM Optimization" (arXiv:2412.05270). +Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level +Performance" (arXiv:2412.05270, MLSys 2025). -For each parameter tensor, maintains: - - projected first moment (m): [m, rank] or [rank, n] - - projected second moment (v): same shape - - random projection matrix (regenerated from seed) +The core idea: AdamW's per-element learning rate scaling is redundant. +Channel-wise or tensor-wise scaling is sufficient. Apollo approximates +these scaling factors using a low-rank auxiliary optimizer state based on +pure random projection. Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25% compute overhead vs forward+backward. Captures gradient structure across 100+ behavioral training examples per batch. + +Key implementation details from the paper: + - Gradient scale factor α = √(n/r) compensates for projection ratio + - Norm-growth limiter (γ=1.01) prevents early training instability + - Projection matrix refreshed every T steps (default 200), not every step + - Channel-wise scaling for rank>1, tensor-wise for rank=1 """ +import math + import torch from torch.optim import Optimizer class Apollo(Optimizer): - """Apollo: configurable-rank tensor-wise gradient scaling. + """Apollo: configurable-rank gradient scaling optimizer. - rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance). - Higher ranks cost proportionally more memory but may improve - training quality for fine-grained behavioral fine-tuning. + rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory). + rank>1 is full Apollo (channel-wise scaling). Args: params: model parameters lr: learning rate (default: 1e-4) - rank: projection rank (default: 1 = Apollo-Mini) - betas: coefficients for moment estimates (default: (0.9, 0.999)) - eps: term for numerical stability (default: 1e-8) + rank: projection rank (default: 256) + betas: Adam momentum coefficients (default: (0.9, 0.999)) + eps: numerical stability term (default: 1e-8) weight_decay: decoupled weight decay (default: 0.01) - warmup_steps: linear warmup steps (default: 0) - scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise + warmup_steps: linear lr warmup steps (default: 0) + scale: gradient scale factor α. Default None = auto √(n/r). + Paper uses √128 for Apollo-Mini. + proj_refresh: refresh projection matrix every T steps (default: 200) + norm_growth_limit: max gradient norm growth ratio γ (default: 1.01). + Set to None to disable. """ - def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.01, warmup_steps=0, scale_type='tensor'): + def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), + eps=1e-8, weight_decay=0.01, warmup_steps=0, + scale=None, proj_refresh=200, norm_growth_limit=1.01): defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, weight_decay=weight_decay, warmup_steps=warmup_steps, - scale_type=scale_type) + scale=scale, + proj_refresh=proj_refresh, + norm_growth_limit=norm_growth_limit) super().__init__(params, defaults) @torch.no_grad() @@ -55,6 +69,9 @@ class Apollo(Optimizer): beta1, beta2 = group['betas'] eps = group['eps'] weight_decay = group['weight_decay'] + rank = group['rank'] + proj_refresh = group['proj_refresh'] + norm_growth_limit = group['norm_growth_limit'] for p in group['params']: if p.grad is None: @@ -66,58 +83,75 @@ class Apollo(Optimizer): # Initialize state if len(state) == 0: state['step'] = 0 - state['seed'] = id(p) # deterministic per-param seed + state['seed'] = id(p) % (2**31) - # Determine projection dimension - rank = group['rank'] if grad.ndim >= 2 and min(grad.shape) >= rank: - if grad.shape[0] >= grad.shape[1]: - state['proj_dim'] = 'right' - moment_shape = (grad.shape[0], rank) - else: - state['proj_dim'] = 'left' + # Determine projection dimension (project along smaller dim) + if grad.shape[0] <= grad.shape[1]: + state['proj_dim'] = 'left' # P: [r, m], R = P @ G → [r, n] + state['m'] = grad.shape[0] + state['n'] = grad.shape[1] moment_shape = (rank, grad.shape[1]) + else: + state['proj_dim'] = 'right' # P: [r, n], R = G @ P^T → [m, r] + state['m'] = grad.shape[0] + state['n'] = grad.shape[1] + moment_shape = (grad.shape[0], rank) - state['exp_avg'] = torch.zeros(moment_shape, - device=p.device) - state['exp_avg_sq'] = torch.zeros(moment_shape, - device=p.device) + state['exp_avg'] = torch.zeros(moment_shape, device=p.device) + state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device) state['has_proj'] = True - state['rank'] = rank + state['prev_scaled_norm'] = None + + # Auto scale factor: α = √(smaller_dim / rank) + smaller_dim = min(grad.shape) + if group['scale'] is not None: + state['alpha'] = group['scale'] + else: + state['alpha'] = math.sqrt(smaller_dim / rank) else: - # 1D params (biases, norms): use standard Adam + # 1D or small params: standard Adam state['exp_avg'] = torch.zeros_like(grad) state['exp_avg_sq'] = torch.zeros_like(grad) state['has_proj'] = False state['step'] += 1 + step = state['step'] # Learning rate warmup - if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']: - lr_scale = state['step'] / group['warmup_steps'] + if group['warmup_steps'] > 0 and step <= group['warmup_steps']: + lr_scale = step / group['warmup_steps'] else: lr_scale = 1.0 if state['has_proj']: - rank = state['rank'] + alpha = state['alpha'] - # Generate deterministic random projection matrix - gen = torch.Generator(device=p.device) - gen.manual_seed(state['seed'] + state['step']) + # Generate projection matrix (refresh every proj_refresh steps) + if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0): + gen = torch.Generator(device=p.device) + gen.manual_seed(state['seed'] + step) - # Project gradient to low-rank - if state['proj_dim'] == 'right': - proj_mat = torch.randn(grad.shape[1], rank, - device=p.device, - generator=gen) - proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps) - proj_grad = grad @ proj_mat # [m, rank] + if state['proj_dim'] == 'left': + # P: [rank, m], normalized rows + P = torch.randn(rank, state['m'], + device=p.device, generator=gen) + P = P / (P.norm(dim=1, keepdim=True) + eps) + state['proj_matrix'] = P + else: + # P: [rank, n], normalized rows + P = torch.randn(rank, state['n'], + device=p.device, generator=gen) + P = P / (P.norm(dim=1, keepdim=True) + eps) + state['proj_matrix'] = P + + P = state['proj_matrix'] + + # Project gradient to low-rank space + if state['proj_dim'] == 'left': + proj_grad = P @ grad # [rank, n] else: - proj_mat = torch.randn(rank, grad.shape[0], - device=p.device, - generator=gen) - proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps) - proj_grad = proj_mat @ grad # [rank, n] + proj_grad = grad @ P.t() # [m, rank] # Update moments in projected space state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1) @@ -125,29 +159,52 @@ class Apollo(Optimizer): proj_grad, proj_grad, value=1 - beta2) # Bias correction - bc1 = 1 - beta1 ** state['step'] - bc2 = 1 - beta2 ** state['step'] + bc1 = 1 - beta1 ** step + bc2 = 1 - beta2 ** step m_hat = state['exp_avg'] / bc1 v_hat = state['exp_avg_sq'] / bc2 # Adam update in projected space adam_update = m_hat / (v_hat.sqrt() + eps) - # Tensor-wise scaling factor - scaling = adam_update.norm() / (proj_grad.norm() + eps) + # Compute scaling factor + if rank == 1: + # Tensor-wise: single scalar (Apollo-Mini) + scaling = adam_update.norm() / (proj_grad.norm() + eps) + scaled_grad = grad * (alpha * scaling) + else: + # Channel-wise: one factor per channel + if state['proj_dim'] == 'left': + # Channels are columns: scale along dim 1 + s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps) + scaled_grad = grad * (alpha * s.unsqueeze(0)) + else: + # Channels are rows: scale along dim 1 + s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps) + scaled_grad = grad * (alpha * s.unsqueeze(1)) - # Apply to full gradient + # Norm-growth limiter (equation 4) + if norm_growth_limit is not None: + current_norm = scaled_grad.norm() + if state['prev_scaled_norm'] is not None: + prev_norm = state['prev_scaled_norm'] + if current_norm > norm_growth_limit * prev_norm: + scaled_grad = scaled_grad * ( + norm_growth_limit * prev_norm / (current_norm + eps)) + state['prev_scaled_norm'] = scaled_grad.norm().item() + + # Apply update step_size = lr * lr_scale - p.add_(grad.to(p.dtype) * (-step_size * scaling)) + p.add_(scaled_grad.to(p.dtype), alpha=-step_size) else: - # Standard Adam for 1D params + # Standard Adam for 1D / small params state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1) state['exp_avg_sq'].mul_(beta2).addcmul_( grad, grad, value=1 - beta2) - bc1 = 1 - beta1 ** state['step'] - bc2 = 1 - beta2 ** state['step'] + bc1 = 1 - beta1 ** step + bc2 = 1 - beta2 ** step m_hat = state['exp_avg'] / bc1 v_hat = state['exp_avg_sq'] / bc2 diff --git a/training/research/apollo-paper-analysis.md b/training/research/apollo-paper-analysis.md new file mode 100644 index 0000000..936b2f7 --- /dev/null +++ b/training/research/apollo-paper-analysis.md @@ -0,0 +1,273 @@ +# Apollo Paper: Deep Analysis + +Source: arXiv:2412.05270v4, MLSys 2025 Outstanding Paper Honorable Mention +Authors: Zhu, Zhang, Cong, Liu, Park, Chandra, Long, Pan, Wang, Lee + +## The Core Insight + +AdamW's per-element learning rate scaling is massively redundant for LLMs. +The element-wise scaling can be coarsened to channel-wise or even tensor-wise +without loss — and with slight improvement in some cases. + +### The mathematical argument + +AdamW's update rule, rewritten as a pure scaling operation: + +``` +Standard AdamW: + M_t = β₁M_{t-1} + (1-β₁)G_t # first moment + V_t = β₂V_{t-1} + (1-β₂)G_t² # second moment + G̃_t = M_t / (√V_t + ε) # scaled gradient + W_{t+1} = W_t - η·G̃_t - η·λ·W_t # weight update + +Rewritten as scaling: + W_{t+1} = W_t - η · (G̃_t/G_t) · G_t # S = G̃_t/G_t is the scaling matrix +``` + +The scaling matrix S ∈ ℝ^{m×n} is element-wise: each weight gets its own +learning rate. The paper's key observation: **this per-element granularity +is unnecessary.** S can be coarsened to: + +- **Channel-wise**: one scaling factor per column (or row), s_j for channel j +- **Tensor-wise**: one scalar for the whole tensor (Apollo-Mini) + +### Channel-wise scaling factor (equation 3) + +``` +s_j = ‖G̃_t[:,j]‖₂ / ‖G_t[:,j]‖₂ +``` + +This computes the ratio of norms between the Adam-scaled gradient and the +raw gradient for each channel. It tells you: "how much should this channel's +gradient be amplified or dampened?" + +The paper shows empirically that channel-wise scaling achieves slightly +BETTER perplexity than element-wise (24.43 vs 25.08 on LLaMA-130M). +The coarsening acts as implicit regularization. + +## Apollo: Approximating the Scaling Factor + +Computing channel-wise scaling still requires the full M_t and V_t matrices. +Apollo's contribution: approximate s_j using a low-rank auxiliary optimizer. + +### Algorithm (Algorithm 1) + +``` +Input: W ∈ ℝ^{m×n} (m ≤ n), lr η, scale factor α, rank r +Initialize: t = 0 + +repeat: + G_t = ∇φ(W_t) # full gradient + + # Step 1: Project to low-rank space + if t mod T = 0: + P_t ← N(0, 1/r) # new random projection [r×m] + seed ← random + R_t = P_t · G_t # projected gradient [r×n] + + # Step 2: Adam in low-rank space + M_t^R, V_t^R ← AdamW(R_t, β₁, β₂, λ=0) # moments on projected gradient + R̃_t = M_t^R / (√V_t^R + ε) # Adam-scaled projected gradient + + # Step 3: Approximate channel-wise scaling + if APOLLO: + S ← diag(s₀^R, s₁^R, ..., s_n^R) + where s_j^R = ‖R̃_t[:,j]‖₂ / ‖R_t[:,j]‖₂ + elif APOLLO-Mini: + S ← s^R · I + where s^R = ‖R̃_t‖₂ / ‖R_t‖₂ # single scalar + + # Step 4: Update weight in original space + W_{t+1} = W_t + η·α · G_t·S - η·λ·W_t +``` + +### Key differences from my implementation + +1. **Scale factor α**: The paper uses a gradient scale factor α (default √128 + for Apollo-Mini) to compensate for the ratio √(n/r) between compact and + original space scaling factors. This is the `scale` parameter in + `apollo_torch.APOLLOAdamW`. **Our implementation is missing this.** + +2. **Norm-growth limiter**: Instead of gradient clipping, they use a norm + growth limiter (equation 4): + ``` + if ‖G̃_t‖/‖G̃_{t-1}‖ > γ: + G̃_t ← (G̃_t/‖G̃_t‖) · γ · ‖G̃_{t-1}‖ + ``` + Default γ = 1.01. This prevents loss spikes in early training. + **Our implementation is missing this.** + +3. **Projection matrix refresh**: P_t is regenerated every T steps (default + T=200). Not every step. This amortizes the projection cost. + **Our implementation regenerates every step — wasteful.** + +4. **The scaling is applied as G_t · S (post-multiply by diagonal)**: + The gradient is multiplied by the scaling factors, not the gradient + scaled and then applied. This means the full gradient direction is + preserved; only the per-channel magnitude changes. + +## Theoretical Guarantees + +### Theorem 4.1: First-moment approximation bound + +For projected gradient R_t = P·G_t where P ∈ ℝ^{r×m} is random Gaussian: + +``` +(1-ε)‖M_t[:,j]‖² ≤ ‖M_t^R[:,j]‖² ≤ (1+ε)‖M_t[:,j]‖² +``` + +with probability at least 1 - 2exp(-rε²/8). + +This is a Johnson-Lindenstrauss result: random projection approximately +preserves norms. The channel-wise first moment norms in the projected space +are close to the original space norms. + +### Theorem 4.2: Second-moment approximation bound + +For ℓ₁ norm (element-wise second moment): + +``` +(1-ε)‖V_t[:,j]‖₁ ≤ ‖V_t^R[:,j]‖₁ ≤ (1+ε)‖V_t[:,j]‖₁ +``` + +with probability at least 1-δ/2, when r ≥ (8/ε²)·log(2t/δ). + +### Bounded update ratio (equation 9) + +The ratio between compact and original scaling factors: + +``` +(√(1-ε))/(1+ε) ≤ √(n/r · s_j^R/s_j) ≤ (√(1+ε))/(1-ε) +``` + +This means the approximated scaling factor s_j^R differs from the true +scaling factor s_j by a predictable ratio of √(n/r), which is compensated +by the gradient scale factor α. + +**This is why α = √128 for Apollo-Mini**: when r=1 and n is the smaller +dimension (typically ~128 for head dimensions), √(n/r) ≈ √128 ≈ 11.3. +The α compensates for this systematic ratio. + +## Apollo-Mini: Tensor-wise Scaling + +For rank r=1, channel-wise scaling becomes numerically unstable (one element +per channel in the projected space). Apollo-Mini coarsens further to a +single tensor-wise scaling factor: + +``` +s = ‖R̃_t‖₂ / ‖R_t‖₂ +``` + +One scalar for the entire tensor. This is maximally coarse. + +**Why it works**: The tensor-wise average of channel-wise scaling factors +smooths out the noise from rank-1 projection. The errors cancel across +channels. The paper shows Apollo-Mini actually OUTPERFORMS AdamW on +pre-training (Table 2, 3) — the coarsening acts as regularization. + +## Why Apollo Can Beat AdamW (Section 5.5) + +The paper provides two hypotheses: + +### Hypothesis 1: Directional sharpness + +Adam achieves lower directional sharpness than SGD, improving Transformer +training. But if directional sharpness is already too low (over-smoothed +landscape), the updates become too conservative. Apollo's coarser scaling +resembles SGD more (depends more on current gradient, less on history), +which can escape local optima that AdamW gets stuck in. + +**Table 10**: Apollo/Apollo-Mini achieve lower directional sharpness than +Adam at epochs 5-20, comparable to SGD. This means Apollo navigates the +loss landscape more effectively. + +### Hypothesis 2: Block-wise adaptive learning rates + +Transformer blocks have varying Hessian spectra. Block-wise (channel/tensor) +adaptive rates are sufficient; fully per-element rates are redundant given +this structure. Apollo's channel/tensor-wise scaling naturally aligns with +the block structure of Transformers. + +## Fine-tuning Results (Section 5.2) + +On fine-tuning (Table 5, 6): + +- **Common-sense reasoning (8 tasks)**: Apollo-Mini achieves 68.23 average + vs AdamW's 68.07. Essentially identical, with 0G optimizer memory. +- **MMLU**: Apollo-Mini competitive across all categories (STEM, Social + Sciences, Humanities, Other). +- **Learning rate range**: Sweeping [5e-6, 7.5e-6, 1e-5, 2.5e-5, 5e-5, + 7.5e-5, 1e-4, 1.5e-4, 2e-4]. Best results at 1e-5 to 1e-4 range. + +**Key finding for us**: Apollo-Mini performs on par with full AdamW for +fine-tuning. The rank doesn't matter much for fine-tuning quality — even +rank-1 is sufficient. The quality comes from the gradient direction (which +is preserved at full rank); only the scaling magnitude is approximated. + +## Ablation Studies (Section 5.4) + +### A1: Random projection ≈ SVD +Apollo performs equally well with random projection as SVD. Random projection +is dramatically cheaper (matrix multiply vs O(mn²) SVD). + +### A2: Apollo-Mini effective even at rank 1 +Apollo-Mini (rank-1) outperforms AdamW on pre-training. The tensor-wise +averaging of noise is a feature, not a bug. + +### A3: Channel vs tensor granularity +Table 9: Difference between channel-wise and tensor-wise scaling is minimal +(~0.15 perplexity). For extreme low-rank (rank-1), tensor-wise actually +outperforms channel-wise. + +### A4: Better with larger models and more tokens +Apollo's advantage over AdamW grows with model size and training tokens. +For larger models, the structured scaling becomes more beneficial. + +### A5: Long-context training +Apollo performs on par with or better than AdamW for long-context pre-training +(sequence length 1024), with drastic memory savings. + +## Implications for Our Use Case + +### Learning rate +The paper sweeps [5e-6 to 2e-4] for fine-tuning. Our lr=1e-5 to 1e-4 +range is in the sweet spot. + +### Scale factor α +**We need to add this.** For rank-256 (our default), α should be +√(n/256) where n is the smaller weight dimension. For typical attention +weights with n=5120, that's √20 ≈ 4.5. For rank-1 it would be √5120 ≈ 71.6. +The `apollo_torch` library sets this as the `scale` parameter. + +Our `apollo_mini.py` is missing the α factor entirely. This likely +means our scaling factors are systematically too small by √(n/r). + +### Norm-growth limiter +We should add this (γ=1.01) for training stability, especially in early +steps. It prevents the loss spikes visible in Figure 3. + +### Projection refresh +We can regenerate P every 200 steps instead of every step. Saves compute +and the theory shows it doesn't matter. + +### Channel vs tensor scaling +For rank-256, channel-wise is slightly better. For rank-1, tensor-wise +is better. Since we default to rank-256, we should use channel-wise +(which we planned). + +### Fine-tuning vs pre-training +The paper shows Apollo is slightly more beneficial for pre-training than +fine-tuning (where it merely matches AdamW). For fine-tuning, the gradient +direction matters more than the scaling precision — and Apollo preserves +the full gradient direction. This means our behavioral fine-tuning should +work well regardless of rank. + +## Corrections to Our Implementation + +1. **Add gradient scale factor α = √(n/r)** — critical for correct + scaling magnitude +2. **Add norm-growth limiter (γ=1.01)** — prevents early training instability +3. **Refresh projection every T=200 steps, not every step** +4. **Channel-wise scaling for rank>1, tensor-wise for rank=1** +5. **The scaling applies as G·diag(s), not s·G** — post-multiply, preserving + gradient direction per channel