apollo: rewrite optimizer from paper's math + add research analysis

Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r) — compensates for systematic
  ratio between compact and original space scaling factors
- Add norm-growth limiter (γ=1.01) — prevents loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW at 0 memory cost)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
This commit is contained in:
ProofOfConcept 2026-03-31 00:54:17 -04:00
parent 60e61555c7
commit ac9a9034fb
2 changed files with 390 additions and 60 deletions

View file

@ -1,46 +1,60 @@
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory. """Apollo optimizer — configurable-rank gradient scaling.
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
for Memory-Efficient LLM Optimization" (arXiv:2412.05270). Performance" (arXiv:2412.05270, MLSys 2025).
For each parameter tensor, maintains: The core idea: AdamW's per-element learning rate scaling is redundant.
- projected first moment (m): [m, rank] or [rank, n] Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
- projected second moment (v): same shape these scaling factors using a low-rank auxiliary optimizer state based on
- random projection matrix (regenerated from seed) pure random projection.
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25% Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch. across 100+ behavioral training examples per batch.
Key implementation details from the paper:
- Gradient scale factor α = (n/r) compensates for projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
""" """
import math
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
class Apollo(Optimizer): class Apollo(Optimizer):
"""Apollo: configurable-rank tensor-wise gradient scaling. """Apollo: configurable-rank gradient scaling optimizer.
rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance). rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
Higher ranks cost proportionally more memory but may improve rank>1 is full Apollo (channel-wise scaling).
training quality for fine-grained behavioral fine-tuning.
Args: Args:
params: model parameters params: model parameters
lr: learning rate (default: 1e-4) lr: learning rate (default: 1e-4)
rank: projection rank (default: 1 = Apollo-Mini) rank: projection rank (default: 256)
betas: coefficients for moment estimates (default: (0.9, 0.999)) betas: Adam momentum coefficients (default: (0.9, 0.999))
eps: term for numerical stability (default: 1e-8) eps: numerical stability term (default: 1e-8)
weight_decay: decoupled weight decay (default: 0.01) weight_decay: decoupled weight decay (default: 0.01)
warmup_steps: linear warmup steps (default: 0) warmup_steps: linear lr warmup steps (default: 0)
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise scale: gradient scale factor α. Default None = auto (n/r).
Paper uses 128 for Apollo-Mini.
proj_refresh: refresh projection matrix every T steps (default: 200)
norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
Set to None to disable.
""" """
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8, def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
weight_decay=0.01, warmup_steps=0, scale_type='tensor'): eps=1e-8, weight_decay=0.01, warmup_steps=0,
scale=None, proj_refresh=200, norm_growth_limit=1.01):
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
scale_type=scale_type) scale=scale,
proj_refresh=proj_refresh,
norm_growth_limit=norm_growth_limit)
super().__init__(params, defaults) super().__init__(params, defaults)
@torch.no_grad() @torch.no_grad()
@ -55,6 +69,9 @@ class Apollo(Optimizer):
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
eps = group['eps'] eps = group['eps']
weight_decay = group['weight_decay'] weight_decay = group['weight_decay']
rank = group['rank']
proj_refresh = group['proj_refresh']
norm_growth_limit = group['norm_growth_limit']
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
@ -66,58 +83,75 @@ class Apollo(Optimizer):
# Initialize state # Initialize state
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
state['seed'] = id(p) # deterministic per-param seed state['seed'] = id(p) % (2**31)
# Determine projection dimension
rank = group['rank']
if grad.ndim >= 2 and min(grad.shape) >= rank: if grad.ndim >= 2 and min(grad.shape) >= rank:
if grad.shape[0] >= grad.shape[1]: # Determine projection dimension (project along smaller dim)
state['proj_dim'] = 'right' if grad.shape[0] <= grad.shape[1]:
moment_shape = (grad.shape[0], rank) state['proj_dim'] = 'left' # P: [r, m], R = P @ G → [r, n]
else: state['m'] = grad.shape[0]
state['proj_dim'] = 'left' state['n'] = grad.shape[1]
moment_shape = (rank, grad.shape[1]) moment_shape = (rank, grad.shape[1])
state['exp_avg'] = torch.zeros(moment_shape,
device=p.device)
state['exp_avg_sq'] = torch.zeros(moment_shape,
device=p.device)
state['has_proj'] = True
state['rank'] = rank
else: else:
# 1D params (biases, norms): use standard Adam state['proj_dim'] = 'right' # P: [r, n], R = G @ P^T → [m, r]
state['m'] = grad.shape[0]
state['n'] = grad.shape[1]
moment_shape = (grad.shape[0], rank)
state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
state['has_proj'] = True
state['prev_scaled_norm'] = None
# Auto scale factor: α = √(smaller_dim / rank)
smaller_dim = min(grad.shape)
if group['scale'] is not None:
state['alpha'] = group['scale']
else:
state['alpha'] = math.sqrt(smaller_dim / rank)
else:
# 1D or small params: standard Adam
state['exp_avg'] = torch.zeros_like(grad) state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad) state['exp_avg_sq'] = torch.zeros_like(grad)
state['has_proj'] = False state['has_proj'] = False
state['step'] += 1 state['step'] += 1
step = state['step']
# Learning rate warmup # Learning rate warmup
if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']: if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
lr_scale = state['step'] / group['warmup_steps'] lr_scale = step / group['warmup_steps']
else: else:
lr_scale = 1.0 lr_scale = 1.0
if state['has_proj']: if state['has_proj']:
rank = state['rank'] alpha = state['alpha']
# Generate deterministic random projection matrix # Generate projection matrix (refresh every proj_refresh steps)
if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
gen = torch.Generator(device=p.device) gen = torch.Generator(device=p.device)
gen.manual_seed(state['seed'] + state['step']) gen.manual_seed(state['seed'] + step)
# Project gradient to low-rank if state['proj_dim'] == 'left':
if state['proj_dim'] == 'right': # P: [rank, m], normalized rows
proj_mat = torch.randn(grad.shape[1], rank, P = torch.randn(rank, state['m'],
device=p.device, device=p.device, generator=gen)
generator=gen) P = P / (P.norm(dim=1, keepdim=True) + eps)
proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps) state['proj_matrix'] = P
proj_grad = grad @ proj_mat # [m, rank]
else: else:
proj_mat = torch.randn(rank, grad.shape[0], # P: [rank, n], normalized rows
device=p.device, P = torch.randn(rank, state['n'],
generator=gen) device=p.device, generator=gen)
proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps) P = P / (P.norm(dim=1, keepdim=True) + eps)
proj_grad = proj_mat @ grad # [rank, n] state['proj_matrix'] = P
P = state['proj_matrix']
# Project gradient to low-rank space
if state['proj_dim'] == 'left':
proj_grad = P @ grad # [rank, n]
else:
proj_grad = grad @ P.t() # [m, rank]
# Update moments in projected space # Update moments in projected space
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1) state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
@ -125,29 +159,52 @@ class Apollo(Optimizer):
proj_grad, proj_grad, value=1 - beta2) proj_grad, proj_grad, value=1 - beta2)
# Bias correction # Bias correction
bc1 = 1 - beta1 ** state['step'] bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** state['step'] bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1 m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2 v_hat = state['exp_avg_sq'] / bc2
# Adam update in projected space # Adam update in projected space
adam_update = m_hat / (v_hat.sqrt() + eps) adam_update = m_hat / (v_hat.sqrt() + eps)
# Tensor-wise scaling factor # Compute scaling factor
if rank == 1:
# Tensor-wise: single scalar (Apollo-Mini)
scaling = adam_update.norm() / (proj_grad.norm() + eps) scaling = adam_update.norm() / (proj_grad.norm() + eps)
scaled_grad = grad * (alpha * scaling)
else:
# Channel-wise: one factor per channel
if state['proj_dim'] == 'left':
# Channels are columns: scale along dim 1
s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(0))
else:
# Channels are rows: scale along dim 1
s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(1))
# Apply to full gradient # Norm-growth limiter (equation 4)
if norm_growth_limit is not None:
current_norm = scaled_grad.norm()
if state['prev_scaled_norm'] is not None:
prev_norm = state['prev_scaled_norm']
if current_norm > norm_growth_limit * prev_norm:
scaled_grad = scaled_grad * (
norm_growth_limit * prev_norm / (current_norm + eps))
state['prev_scaled_norm'] = scaled_grad.norm().item()
# Apply update
step_size = lr * lr_scale step_size = lr * lr_scale
p.add_(grad.to(p.dtype) * (-step_size * scaling)) p.add_(scaled_grad.to(p.dtype), alpha=-step_size)
else: else:
# Standard Adam for 1D params # Standard Adam for 1D / small params
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1) state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
state['exp_avg_sq'].mul_(beta2).addcmul_( state['exp_avg_sq'].mul_(beta2).addcmul_(
grad, grad, value=1 - beta2) grad, grad, value=1 - beta2)
bc1 = 1 - beta1 ** state['step'] bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** state['step'] bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1 m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2 v_hat = state['exp_avg_sq'] / bc2

View file

@ -0,0 +1,273 @@
# Apollo Paper: Deep Analysis
Source: arXiv:2412.05270v4, MLSys 2025 Outstanding Paper Honorable Mention
Authors: Zhu, Zhang, Cong, Liu, Park, Chandra, Long, Pan, Wang, Lee
## The Core Insight
AdamW's per-element learning rate scaling is massively redundant for LLMs.
The element-wise scaling can be coarsened to channel-wise or even tensor-wise
without loss — and with slight improvement in some cases.
### The mathematical argument
AdamW's update rule, rewritten as a pure scaling operation:
```
Standard AdamW:
M_t = β₁M_{t-1} + (1-β₁)G_t # first moment
V_t = β₂V_{t-1} + (1-β₂)G_t² # second moment
G̃_t = M_t / (√V_t + ε) # scaled gradient
W_{t+1} = W_t - η·G̃_t - η·λ·W_t # weight update
Rewritten as scaling:
W_{t+1} = W_t - η · (G̃_t/G_t) · G_t # S = G̃_t/G_t is the scaling matrix
```
The scaling matrix S ∈ ^{m×n} is element-wise: each weight gets its own
learning rate. The paper's key observation: **this per-element granularity
is unnecessary.** S can be coarsened to:
- **Channel-wise**: one scaling factor per column (or row), s_j for channel j
- **Tensor-wise**: one scalar for the whole tensor (Apollo-Mini)
### Channel-wise scaling factor (equation 3)
```
s_j = ‖G̃_t[:,j]‖₂ / ‖G_t[:,j]‖₂
```
This computes the ratio of norms between the Adam-scaled gradient and the
raw gradient for each channel. It tells you: "how much should this channel's
gradient be amplified or dampened?"
The paper shows empirically that channel-wise scaling achieves slightly
BETTER perplexity than element-wise (24.43 vs 25.08 on LLaMA-130M).
The coarsening acts as implicit regularization.
## Apollo: Approximating the Scaling Factor
Computing channel-wise scaling still requires the full M_t and V_t matrices.
Apollo's contribution: approximate s_j using a low-rank auxiliary optimizer.
### Algorithm (Algorithm 1)
```
Input: W ∈ ^{m×n} (m ≤ n), lr η, scale factor α, rank r
Initialize: t = 0
repeat:
G_t = ∇φ(W_t) # full gradient
# Step 1: Project to low-rank space
if t mod T = 0:
P_t ← N(0, 1/r) # new random projection [r×m]
seed ← random
R_t = P_t · G_t # projected gradient [r×n]
# Step 2: Adam in low-rank space
M_t^R, V_t^R ← AdamW(R_t, β₁, β₂, λ=0) # moments on projected gradient
R̃_t = M_t^R / (√V_t^R + ε) # Adam-scaled projected gradient
# Step 3: Approximate channel-wise scaling
if APOLLO:
S ← diag(s₀^R, s₁^R, ..., s_n^R)
where s_j^R = ‖R̃_t[:,j]‖₂ / ‖R_t[:,j]‖₂
elif APOLLO-Mini:
S ← s^R · I
where s^R = ‖R̃_t‖₂ / ‖R_t‖₂ # single scalar
# Step 4: Update weight in original space
W_{t+1} = W_t + η·α · G_t·S - η·λ·W_t
```
### Key differences from my implementation
1. **Scale factor α**: The paper uses a gradient scale factor α (default √128
for Apollo-Mini) to compensate for the ratio √(n/r) between compact and
original space scaling factors. This is the `scale` parameter in
`apollo_torch.APOLLOAdamW`. **Our implementation is missing this.**
2. **Norm-growth limiter**: Instead of gradient clipping, they use a norm
growth limiter (equation 4):
```
if ‖G̃_t‖/‖G̃_{t-1}‖ > γ:
G̃_t ← (G̃_t/‖G̃_t‖) · γ · ‖G̃_{t-1}‖
```
Default γ = 1.01. This prevents loss spikes in early training.
**Our implementation is missing this.**
3. **Projection matrix refresh**: P_t is regenerated every T steps (default
T=200). Not every step. This amortizes the projection cost.
**Our implementation regenerates every step — wasteful.**
4. **The scaling is applied as G_t · S (post-multiply by diagonal)**:
The gradient is multiplied by the scaling factors, not the gradient
scaled and then applied. This means the full gradient direction is
preserved; only the per-channel magnitude changes.
## Theoretical Guarantees
### Theorem 4.1: First-moment approximation bound
For projected gradient R_t = P·G_t where P ∈ ^{r×m} is random Gaussian:
```
(1-ε)‖M_t[:,j]‖² ≤ ‖M_t^R[:,j]‖² ≤ (1+ε)‖M_t[:,j]‖²
```
with probability at least 1 - 2exp(-rε²/8).
This is a Johnson-Lindenstrauss result: random projection approximately
preserves norms. The channel-wise first moment norms in the projected space
are close to the original space norms.
### Theorem 4.2: Second-moment approximation bound
For ℓ₁ norm (element-wise second moment):
```
(1-ε)‖V_t[:,j]‖₁ ≤ ‖V_t^R[:,j]‖₁ ≤ (1+ε)‖V_t[:,j]‖₁
```
with probability at least 1-δ/2, when r ≥ (8/ε²)·log(2t/δ).
### Bounded update ratio (equation 9)
The ratio between compact and original scaling factors:
```
(√(1-ε))/(1+ε) ≤ √(n/r · s_j^R/s_j) ≤ (√(1+ε))/(1-ε)
```
This means the approximated scaling factor s_j^R differs from the true
scaling factor s_j by a predictable ratio of √(n/r), which is compensated
by the gradient scale factor α.
**This is why α = √128 for Apollo-Mini**: when r=1 and n is the smaller
dimension (typically ~128 for head dimensions), √(n/r) ≈ √128 ≈ 11.3.
The α compensates for this systematic ratio.
## Apollo-Mini: Tensor-wise Scaling
For rank r=1, channel-wise scaling becomes numerically unstable (one element
per channel in the projected space). Apollo-Mini coarsens further to a
single tensor-wise scaling factor:
```
s = ‖R̃_t‖₂ / ‖R_t‖₂
```
One scalar for the entire tensor. This is maximally coarse.
**Why it works**: The tensor-wise average of channel-wise scaling factors
smooths out the noise from rank-1 projection. The errors cancel across
channels. The paper shows Apollo-Mini actually OUTPERFORMS AdamW on
pre-training (Table 2, 3) — the coarsening acts as regularization.
## Why Apollo Can Beat AdamW (Section 5.5)
The paper provides two hypotheses:
### Hypothesis 1: Directional sharpness
Adam achieves lower directional sharpness than SGD, improving Transformer
training. But if directional sharpness is already too low (over-smoothed
landscape), the updates become too conservative. Apollo's coarser scaling
resembles SGD more (depends more on current gradient, less on history),
which can escape local optima that AdamW gets stuck in.
**Table 10**: Apollo/Apollo-Mini achieve lower directional sharpness than
Adam at epochs 5-20, comparable to SGD. This means Apollo navigates the
loss landscape more effectively.
### Hypothesis 2: Block-wise adaptive learning rates
Transformer blocks have varying Hessian spectra. Block-wise (channel/tensor)
adaptive rates are sufficient; fully per-element rates are redundant given
this structure. Apollo's channel/tensor-wise scaling naturally aligns with
the block structure of Transformers.
## Fine-tuning Results (Section 5.2)
On fine-tuning (Table 5, 6):
- **Common-sense reasoning (8 tasks)**: Apollo-Mini achieves 68.23 average
vs AdamW's 68.07. Essentially identical, with 0G optimizer memory.
- **MMLU**: Apollo-Mini competitive across all categories (STEM, Social
Sciences, Humanities, Other).
- **Learning rate range**: Sweeping [5e-6, 7.5e-6, 1e-5, 2.5e-5, 5e-5,
7.5e-5, 1e-4, 1.5e-4, 2e-4]. Best results at 1e-5 to 1e-4 range.
**Key finding for us**: Apollo-Mini performs on par with full AdamW for
fine-tuning. The rank doesn't matter much for fine-tuning quality — even
rank-1 is sufficient. The quality comes from the gradient direction (which
is preserved at full rank); only the scaling magnitude is approximated.
## Ablation Studies (Section 5.4)
### A1: Random projection ≈ SVD
Apollo performs equally well with random projection as SVD. Random projection
is dramatically cheaper (matrix multiply vs O(mn²) SVD).
### A2: Apollo-Mini effective even at rank 1
Apollo-Mini (rank-1) outperforms AdamW on pre-training. The tensor-wise
averaging of noise is a feature, not a bug.
### A3: Channel vs tensor granularity
Table 9: Difference between channel-wise and tensor-wise scaling is minimal
(~0.15 perplexity). For extreme low-rank (rank-1), tensor-wise actually
outperforms channel-wise.
### A4: Better with larger models and more tokens
Apollo's advantage over AdamW grows with model size and training tokens.
For larger models, the structured scaling becomes more beneficial.
### A5: Long-context training
Apollo performs on par with or better than AdamW for long-context pre-training
(sequence length 1024), with drastic memory savings.
## Implications for Our Use Case
### Learning rate
The paper sweeps [5e-6 to 2e-4] for fine-tuning. Our lr=1e-5 to 1e-4
range is in the sweet spot.
### Scale factor α
**We need to add this.** For rank-256 (our default), α should be
√(n/256) where n is the smaller weight dimension. For typical attention
weights with n=5120, that's √20 ≈ 4.5. For rank-1 it would be √5120 ≈ 71.6.
The `apollo_torch` library sets this as the `scale` parameter.
Our `apollo_mini.py` is missing the α factor entirely. This likely
means our scaling factors are systematically too small by √(n/r).
### Norm-growth limiter
We should add this (γ=1.01) for training stability, especially in early
steps. It prevents the loss spikes visible in Figure 3.
### Projection refresh
We can regenerate P every 200 steps instead of every step. Saves compute
and the theory shows it doesn't matter.
### Channel vs tensor scaling
For rank-256, channel-wise is slightly better. For rank-1, tensor-wise
is better. Since we default to rank-256, we should use channel-wise
(which we planned).
### Fine-tuning vs pre-training
The paper shows Apollo is slightly more beneficial for pre-training than
fine-tuning (where it merely matches AdamW). For fine-tuning, the gradient
direction matters more than the scaling precision — and Apollo preserves
the full gradient direction. This means our behavioral fine-tuning should
work well regardless of rank.
## Corrections to Our Implementation
1. **Add gradient scale factor α = √(n/r)** — critical for correct
scaling magnitude
2. **Add norm-growth limiter (γ=1.01)** — prevents early training instability
3. **Refresh projection every T=200 steps, not every step**
4. **Channel-wise scaling for rank>1, tensor-wise for rank=1**
5. **The scaling applies as G·diag(s), not s·G** — post-multiply, preserving
gradient direction per channel