apollo: rewrite optimizer from paper's math + add research analysis

Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r): compensates for the systematic
  ratio between compact and original space scaling factors
- Add norm-growth limiter (γ=1.01): prevents loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW at 0 memory cost)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
parent
60e61555c7
commit
ac9a9034fb
2 changed files with 390 additions and 60 deletions
|
|
@ -1,46 +1,60 @@
|
||||||
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
|
"""Apollo optimizer — configurable-rank gradient scaling.
|
||||||
|
|
||||||
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
|
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
|
||||||
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
|
Performance" (arXiv:2412.05270, MLSys 2025).
|
||||||
|
|
||||||
For each parameter tensor, maintains:
|
The core idea: AdamW's per-element learning rate scaling is redundant.
|
||||||
- projected first moment (m): [m, rank] or [rank, n]
|
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
|
||||||
- projected second moment (v): same shape
|
these scaling factors using a low-rank auxiliary optimizer state based on
|
||||||
- random projection matrix (regenerated from seed)
|
pure random projection.
|
||||||
|
|
||||||
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
|
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
|
||||||
compute overhead vs forward+backward. Captures gradient structure
|
compute overhead vs forward+backward. Captures gradient structure
|
||||||
across 100+ behavioral training examples per batch.
|
across 100+ behavioral training examples per batch.
|
||||||
|
|
||||||
|
Key implementation details from the paper:
|
||||||
|
- Gradient scale factor α = √(n/r) compensates for projection ratio
|
||||||
|
- Norm-growth limiter (γ=1.01) prevents early training instability
|
||||||
|
- Projection matrix refreshed every T steps (default 200), not every step
|
||||||
|
- Channel-wise scaling for rank>1, tensor-wise for rank=1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
class Apollo(Optimizer):
|
class Apollo(Optimizer):
|
||||||
"""Apollo: configurable-rank tensor-wise gradient scaling.
|
"""Apollo: configurable-rank gradient scaling optimizer.
|
||||||
|
|
||||||
rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
|
rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
|
||||||
Higher ranks cost proportionally more memory but may improve
|
rank>1 is full Apollo (channel-wise scaling).
|
||||||
training quality for fine-grained behavioral fine-tuning.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: model parameters
|
params: model parameters
|
||||||
lr: learning rate (default: 1e-4)
|
lr: learning rate (default: 1e-4)
|
||||||
rank: projection rank (default: 1 = Apollo-Mini)
|
rank: projection rank (default: 256)
|
||||||
betas: coefficients for moment estimates (default: (0.9, 0.999))
|
betas: Adam momentum coefficients (default: (0.9, 0.999))
|
||||||
eps: term for numerical stability (default: 1e-8)
|
eps: numerical stability term (default: 1e-8)
|
||||||
weight_decay: decoupled weight decay (default: 0.01)
|
weight_decay: decoupled weight decay (default: 0.01)
|
||||||
warmup_steps: linear warmup steps (default: 0)
|
warmup_steps: linear lr warmup steps (default: 0)
|
||||||
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
|
scale: gradient scale factor α. Default None = auto √(n/r).
|
||||||
|
Paper uses √128 for Apollo-Mini.
|
||||||
|
proj_refresh: refresh projection matrix every T steps (default: 200)
|
||||||
|
norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
|
||||||
|
Set to None to disable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8,
|
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
|
||||||
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
|
eps=1e-8, weight_decay=0.01, warmup_steps=0,
|
||||||
|
scale=None, proj_refresh=200, norm_growth_limit=1.01):
|
||||||
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
scale_type=scale_type)
|
scale=scale,
|
||||||
|
proj_refresh=proj_refresh,
|
||||||
|
norm_growth_limit=norm_growth_limit)
|
||||||
super().__init__(params, defaults)
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
@ -55,6 +69,9 @@ class Apollo(Optimizer):
|
||||||
beta1, beta2 = group['betas']
|
beta1, beta2 = group['betas']
|
||||||
eps = group['eps']
|
eps = group['eps']
|
||||||
weight_decay = group['weight_decay']
|
weight_decay = group['weight_decay']
|
||||||
|
rank = group['rank']
|
||||||
|
proj_refresh = group['proj_refresh']
|
||||||
|
norm_growth_limit = group['norm_growth_limit']
|
||||||
|
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
|
|
@ -66,58 +83,75 @@ class Apollo(Optimizer):
|
||||||
# Initialize state
|
# Initialize state
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
state['step'] = 0
|
state['step'] = 0
|
||||||
state['seed'] = id(p) # deterministic per-param seed
|
state['seed'] = id(p) % (2**31)
|
||||||
|
|
||||||
# Determine projection dimension
|
|
||||||
rank = group['rank']
|
|
||||||
if grad.ndim >= 2 and min(grad.shape) >= rank:
|
if grad.ndim >= 2 and min(grad.shape) >= rank:
|
||||||
if grad.shape[0] >= grad.shape[1]:
|
# Determine projection dimension (project along smaller dim)
|
||||||
state['proj_dim'] = 'right'
|
if grad.shape[0] <= grad.shape[1]:
|
||||||
moment_shape = (grad.shape[0], rank)
|
state['proj_dim'] = 'left' # P: [r, m], R = P @ G → [r, n]
|
||||||
else:
|
state['m'] = grad.shape[0]
|
||||||
state['proj_dim'] = 'left'
|
state['n'] = grad.shape[1]
|
||||||
moment_shape = (rank, grad.shape[1])
|
moment_shape = (rank, grad.shape[1])
|
||||||
|
else:
|
||||||
|
state['proj_dim'] = 'right' # P: [r, n], R = G @ P^T → [m, r]
|
||||||
|
state['m'] = grad.shape[0]
|
||||||
|
state['n'] = grad.shape[1]
|
||||||
|
moment_shape = (grad.shape[0], rank)
|
||||||
|
|
||||||
state['exp_avg'] = torch.zeros(moment_shape,
|
state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
|
||||||
device=p.device)
|
state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
|
||||||
state['exp_avg_sq'] = torch.zeros(moment_shape,
|
|
||||||
device=p.device)
|
|
||||||
state['has_proj'] = True
|
state['has_proj'] = True
|
||||||
state['rank'] = rank
|
state['prev_scaled_norm'] = None
|
||||||
|
|
||||||
|
# Auto scale factor: α = √(smaller_dim / rank)
|
||||||
|
smaller_dim = min(grad.shape)
|
||||||
|
if group['scale'] is not None:
|
||||||
|
state['alpha'] = group['scale']
|
||||||
|
else:
|
||||||
|
state['alpha'] = math.sqrt(smaller_dim / rank)
|
||||||
else:
|
else:
|
||||||
# 1D params (biases, norms): use standard Adam
|
# 1D or small params: standard Adam
|
||||||
state['exp_avg'] = torch.zeros_like(grad)
|
state['exp_avg'] = torch.zeros_like(grad)
|
||||||
state['exp_avg_sq'] = torch.zeros_like(grad)
|
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||||
state['has_proj'] = False
|
state['has_proj'] = False
|
||||||
|
|
||||||
state['step'] += 1
|
state['step'] += 1
|
||||||
|
step = state['step']
|
||||||
|
|
||||||
# Learning rate warmup
|
# Learning rate warmup
|
||||||
if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']:
|
if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
|
||||||
lr_scale = state['step'] / group['warmup_steps']
|
lr_scale = step / group['warmup_steps']
|
||||||
else:
|
else:
|
||||||
lr_scale = 1.0
|
lr_scale = 1.0
|
||||||
|
|
||||||
if state['has_proj']:
|
if state['has_proj']:
|
||||||
rank = state['rank']
|
alpha = state['alpha']
|
||||||
|
|
||||||
# Generate deterministic random projection matrix
|
# Generate projection matrix (refresh every proj_refresh steps)
|
||||||
gen = torch.Generator(device=p.device)
|
if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
|
||||||
gen.manual_seed(state['seed'] + state['step'])
|
gen = torch.Generator(device=p.device)
|
||||||
|
gen.manual_seed(state['seed'] + step)
|
||||||
|
|
||||||
# Project gradient to low-rank
|
if state['proj_dim'] == 'left':
|
||||||
if state['proj_dim'] == 'right':
|
# P: [rank, m], normalized rows
|
||||||
proj_mat = torch.randn(grad.shape[1], rank,
|
P = torch.randn(rank, state['m'],
|
||||||
device=p.device,
|
device=p.device, generator=gen)
|
||||||
generator=gen)
|
P = P / (P.norm(dim=1, keepdim=True) + eps)
|
||||||
proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
|
state['proj_matrix'] = P
|
||||||
proj_grad = grad @ proj_mat # [m, rank]
|
else:
|
||||||
|
# P: [rank, n], normalized rows
|
||||||
|
P = torch.randn(rank, state['n'],
|
||||||
|
device=p.device, generator=gen)
|
||||||
|
P = P / (P.norm(dim=1, keepdim=True) + eps)
|
||||||
|
state['proj_matrix'] = P
|
||||||
|
|
||||||
|
P = state['proj_matrix']
|
||||||
|
|
||||||
|
# Project gradient to low-rank space
|
||||||
|
if state['proj_dim'] == 'left':
|
||||||
|
proj_grad = P @ grad # [rank, n]
|
||||||
else:
|
else:
|
||||||
proj_mat = torch.randn(rank, grad.shape[0],
|
proj_grad = grad @ P.t() # [m, rank]
|
||||||
device=p.device,
|
|
||||||
generator=gen)
|
|
||||||
proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
|
|
||||||
proj_grad = proj_mat @ grad # [rank, n]
|
|
||||||
|
|
||||||
# Update moments in projected space
|
# Update moments in projected space
|
||||||
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
||||||
|
|
@ -125,29 +159,52 @@ class Apollo(Optimizer):
|
||||||
proj_grad, proj_grad, value=1 - beta2)
|
proj_grad, proj_grad, value=1 - beta2)
|
||||||
|
|
||||||
# Bias correction
|
# Bias correction
|
||||||
bc1 = 1 - beta1 ** state['step']
|
bc1 = 1 - beta1 ** step
|
||||||
bc2 = 1 - beta2 ** state['step']
|
bc2 = 1 - beta2 ** step
|
||||||
m_hat = state['exp_avg'] / bc1
|
m_hat = state['exp_avg'] / bc1
|
||||||
v_hat = state['exp_avg_sq'] / bc2
|
v_hat = state['exp_avg_sq'] / bc2
|
||||||
|
|
||||||
# Adam update in projected space
|
# Adam update in projected space
|
||||||
adam_update = m_hat / (v_hat.sqrt() + eps)
|
adam_update = m_hat / (v_hat.sqrt() + eps)
|
||||||
|
|
||||||
# Tensor-wise scaling factor
|
# Compute scaling factor
|
||||||
scaling = adam_update.norm() / (proj_grad.norm() + eps)
|
if rank == 1:
|
||||||
|
# Tensor-wise: single scalar (Apollo-Mini)
|
||||||
|
scaling = adam_update.norm() / (proj_grad.norm() + eps)
|
||||||
|
scaled_grad = grad * (alpha * scaling)
|
||||||
|
else:
|
||||||
|
# Channel-wise: one factor per channel
|
||||||
|
if state['proj_dim'] == 'left':
|
||||||
|
# Channels are columns: scale along dim 1
|
||||||
|
s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
|
||||||
|
scaled_grad = grad * (alpha * s.unsqueeze(0))
|
||||||
|
else:
|
||||||
|
# Channels are rows: scale along dim 0
|
||||||
|
s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
|
||||||
|
scaled_grad = grad * (alpha * s.unsqueeze(1))
|
||||||
|
|
||||||
# Apply to full gradient
|
# Norm-growth limiter (equation 4)
|
||||||
|
if norm_growth_limit is not None:
|
||||||
|
current_norm = scaled_grad.norm()
|
||||||
|
if state['prev_scaled_norm'] is not None:
|
||||||
|
prev_norm = state['prev_scaled_norm']
|
||||||
|
if current_norm > norm_growth_limit * prev_norm:
|
||||||
|
scaled_grad = scaled_grad * (
|
||||||
|
norm_growth_limit * prev_norm / (current_norm + eps))
|
||||||
|
state['prev_scaled_norm'] = scaled_grad.norm().item()
|
||||||
|
|
||||||
|
# Apply update
|
||||||
step_size = lr * lr_scale
|
step_size = lr * lr_scale
|
||||||
p.add_(grad.to(p.dtype) * (-step_size * scaling))
|
p.add_(scaled_grad.to(p.dtype), alpha=-step_size)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard Adam for 1D params
|
# Standard Adam for 1D / small params
|
||||||
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
|
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||||
state['exp_avg_sq'].mul_(beta2).addcmul_(
|
state['exp_avg_sq'].mul_(beta2).addcmul_(
|
||||||
grad, grad, value=1 - beta2)
|
grad, grad, value=1 - beta2)
|
||||||
|
|
||||||
bc1 = 1 - beta1 ** state['step']
|
bc1 = 1 - beta1 ** step
|
||||||
bc2 = 1 - beta2 ** state['step']
|
bc2 = 1 - beta2 ** step
|
||||||
m_hat = state['exp_avg'] / bc1
|
m_hat = state['exp_avg'] / bc1
|
||||||
v_hat = state['exp_avg_sq'] / bc2
|
v_hat = state['exp_avg_sq'] / bc2
|
||||||
|
|
||||||
|
|
|
||||||
273
training/research/apollo-paper-analysis.md
Normal file
|
|
@ -0,0 +1,273 @@
|
||||||
|
# Apollo Paper: Deep Analysis
|
||||||
|
|
||||||
|
Source: arXiv:2412.05270v4, MLSys 2025 Outstanding Paper Honorable Mention
|
||||||
|
Authors: Zhu, Zhang, Cong, Liu, Park, Chandra, Long, Pan, Wang, Lee
|
||||||
|
|
||||||
|
## The Core Insight
|
||||||
|
|
||||||
|
AdamW's per-element learning rate scaling is massively redundant for LLMs.
|
||||||
|
The element-wise scaling can be coarsened to channel-wise or even tensor-wise
|
||||||
|
without loss — and with slight improvement in some cases.
|
||||||
|
|
||||||
|
### The mathematical argument
|
||||||
|
|
||||||
|
AdamW's update rule, rewritten as a pure scaling operation:
|
||||||
|
|
||||||
|
```
|
||||||
|
Standard AdamW:
|
||||||
|
M_t = β₁M_{t-1} + (1-β₁)G_t # first moment
|
||||||
|
V_t = β₂V_{t-1} + (1-β₂)G_t² # second moment
|
||||||
|
G̃_t = M_t / (√V_t + ε) # scaled gradient
|
||||||
|
W_{t+1} = W_t - η·G̃_t - η·λ·W_t # weight update
|
||||||
|
|
||||||
|
Rewritten as scaling:
|
||||||
|
W_{t+1} = W_t - η · (G̃_t/G_t) · G_t # S = G̃_t/G_t is the scaling matrix
|
||||||
|
```
|
||||||
|
|
||||||
|
The scaling matrix S ∈ ℝ^{m×n} is element-wise: each weight gets its own
|
||||||
|
learning rate. The paper's key observation: **this per-element granularity
|
||||||
|
is unnecessary.** S can be coarsened to:
|
||||||
|
|
||||||
|
- **Channel-wise**: one scaling factor per column (or row), s_j for channel j
|
||||||
|
- **Tensor-wise**: one scalar for the whole tensor (Apollo-Mini)
|
||||||
|
|
||||||
|
### Channel-wise scaling factor (equation 3)
|
||||||
|
|
||||||
|
```
|
||||||
|
s_j = ‖G̃_t[:,j]‖₂ / ‖G_t[:,j]‖₂
|
||||||
|
```
|
||||||
|
|
||||||
|
This computes the ratio of norms between the Adam-scaled gradient and the
|
||||||
|
raw gradient for each channel. It tells you: "how much should this channel's
|
||||||
|
gradient be amplified or dampened?"
|
||||||
|
|
||||||
|
The paper shows empirically that channel-wise scaling achieves slightly
|
||||||
|
BETTER perplexity than element-wise (24.43 vs 25.08 on LLaMA-130M).
|
||||||
|
The coarsening acts as implicit regularization.
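
As a small illustration of equation 3, the sketch below computes one factor per column from a raw gradient and its Adam-scaled counterpart. It uses plain Python lists so it runs without any dependencies; the helper names (`column_norms`, `channel_scaling`) and the toy matrices are made up for this note, and the `eps` in the denominator mirrors what the optimizer code in this commit does rather than anything stated in equation 3.

```python
import math

def column_norms(mat):
    """L2 norm of each column of a list-of-lists (row-major) matrix."""
    n_cols = len(mat[0])
    return [math.sqrt(sum(row[j] ** 2 for row in mat)) for j in range(n_cols)]

def channel_scaling(g_tilde, g, eps=1e-8):
    """s_j = ||G~[:, j]|| / ||G[:, j]|| for every column j."""
    return [a / (b + eps) for a, b in zip(column_norms(g_tilde), column_norms(g))]

# Toy 2x3 matrices (illustrative values only, not from the paper).
G = [[3.0, 0.0, 1.0],
     [4.0, 2.0, 1.0]]
G_tilde = [[0.6, 0.0, 1.0],
           [0.8, 1.0, 1.0]]

s = channel_scaling(G_tilde, G)  # one factor per column: dampen, dampen, keep
```

Here the first column has raw norm 5 and scaled norm 1, so its factor is about 0.2; the third column has equal norms, so its factor is about 1.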
|
||||||
|
|
||||||
|
## Apollo: Approximating the Scaling Factor
|
||||||
|
|
||||||
|
Computing channel-wise scaling still requires the full M_t and V_t matrices.
|
||||||
|
Apollo's contribution: approximate s_j using a low-rank auxiliary optimizer.
|
||||||
|
|
||||||
|
### Algorithm (Algorithm 1)
|
||||||
|
|
||||||
|
```
|
||||||
|
Input: W ∈ ℝ^{m×n} (m ≤ n), lr η, scale factor α, rank r
|
||||||
|
Initialize: t = 0
|
||||||
|
|
||||||
|
repeat:
|
||||||
|
G_t = ∇φ(W_t) # full gradient
|
||||||
|
|
||||||
|
# Step 1: Project to low-rank space
|
||||||
|
if t mod T = 0:
|
||||||
|
P_t ← N(0, 1/r) # new random projection [r×m]
|
||||||
|
seed ← random
|
||||||
|
R_t = P_t · G_t # projected gradient [r×n]
|
||||||
|
|
||||||
|
# Step 2: Adam in low-rank space
|
||||||
|
M_t^R, V_t^R ← AdamW(R_t, β₁, β₂, λ=0) # moments on projected gradient
|
||||||
|
R̃_t = M_t^R / (√V_t^R + ε) # Adam-scaled projected gradient
|
||||||
|
|
||||||
|
# Step 3: Approximate channel-wise scaling
|
||||||
|
if APOLLO:
|
||||||
|
S ← diag(s₁^R, s₂^R, ..., s_n^R)
|
||||||
|
where s_j^R = ‖R̃_t[:,j]‖₂ / ‖R_t[:,j]‖₂
|
||||||
|
elif APOLLO-Mini:
|
||||||
|
S ← s^R · I
|
||||||
|
where s^R = ‖R̃_t‖₂ / ‖R_t‖₂ # single scalar
|
||||||
|
|
||||||
|
# Step 4: Update weight in original space
|
||||||
|
W_{t+1} = W_t - η·α · G_t·S - η·λ·W_t
|
||||||
|
```
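
To make the four steps concrete, here is a dependency-free sketch of a single channel-wise update at t = 1 with zero-initialized moments. It is a toy under stated assumptions, not the optimizer in this commit: the function name `apollo_first_step`, the `seed` argument, and the list-of-lists matrices are all hypothetical, the projection is redrawn on every call, and the update is written as a descent step (subtracting the scaled gradient) to match the AdamW rule earlier in this note.

```python
import math
import random

def matmul(a, b):
    """Plain list-of-lists matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def column_norms(mat):
    """L2 norm of each column."""
    return [math.sqrt(sum(row[j] ** 2 for row in mat)) for j in range(len(mat[0]))]

def apollo_first_step(W, G, rank, lr, alpha, beta1=0.9, beta2=0.999,
                      eps=1e-8, weight_decay=0.0, seed=0):
    """One channel-wise APOLLO-style update at t = 1 (hypothetical sketch)."""
    m, n = len(G), len(G[0])
    rng = random.Random(seed)
    # Step 1: P ~ N(0, 1/r) with shape [r, m]; R = P @ G has shape [r, n].
    sigma = 1.0 / math.sqrt(rank)
    P = [[rng.gauss(0.0, sigma) for _ in range(m)] for _ in range(rank)]
    R = matmul(P, G)
    # Step 2: Adam moments on R, bias-corrected for t = 1.
    bc1, bc2 = 1 - beta1, 1 - beta2
    M = [[(1 - beta1) * x for x in row] for row in R]
    V = [[(1 - beta2) * x * x for x in row] for row in R]
    R_tilde = [[(mv / bc1) / (math.sqrt(vv / bc2) + eps) for mv, vv in zip(mr, vr)]
               for mr, vr in zip(M, V)]
    # Step 3: one scaling factor per column of the projected gradient.
    s = [a / (b + eps) for a, b in zip(column_norms(R_tilde), column_norms(R))]
    # Step 4: W <- W - lr * alpha * G @ diag(s) - lr * weight_decay * W.
    W_new = [[W[i][j] - lr * alpha * G[i][j] * s[j] - lr * weight_decay * W[i][j]
              for j in range(n)] for i in range(m)]
    return W_new, s, R

data_rng = random.Random(1)
W = [[data_rng.gauss(0, 1) for _ in range(6)] for _ in range(4)]
G = [[data_rng.gauss(0, 1) for _ in range(6)] for _ in range(4)]
W_new, s, R = apollo_first_step(W, G, rank=2, lr=1e-3, alpha=math.sqrt(4 / 2))
```

At the very first step the bias-corrected Adam update is approximately the element-wise sign of R, so each factor reduces to roughly √r / ‖R[:,j]‖. That is a convenient sanity check for any implementation of steps 2 and 3.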
|
||||||
|
|
||||||
|
### Key differences from our implementation
|
||||||
|
|
||||||
|
1. **Scale factor α**: The paper uses a gradient scale factor α (default √128
|
||||||
|
for Apollo-Mini) to compensate for the ratio √(n/r) between compact and
|
||||||
|
original space scaling factors. This is the `scale` parameter in
|
||||||
|
`apollo_torch.APOLLOAdamW`. **Our implementation is missing this.**
|
||||||
|
|
||||||
|
2. **Norm-growth limiter**: Instead of gradient clipping, they use a norm
|
||||||
|
growth limiter (equation 4):
|
||||||
|
```
|
||||||
|
if ‖G̃_t‖/‖G̃_{t-1}‖ > γ:
|
||||||
|
G̃_t ← (G̃_t/‖G̃_t‖) · γ · ‖G̃_{t-1}‖
|
||||||
|
```
|
||||||
|
Default γ = 1.01. This prevents loss spikes in early training.
|
||||||
|
**Our implementation is missing this.**
|
||||||
|
|
||||||
|
3. **Projection matrix refresh**: P_t is regenerated every T steps (default
|
||||||
|
T=200). Not every step. This amortizes the projection cost.
|
||||||
|
**Our implementation regenerates every step — wasteful.**
|
||||||
|
|
||||||
|
4. **The scaling is applied as G_t · S (post-multiply by diagonal)**:
|
||||||
|
Each channel (column) of G_t is multiplied by its own factor s_j, rather
than being replaced by an Adam-style update. The gradient direction within
each channel is preserved; only the per-channel magnitude changes.
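
A quick check of the last point, as a sketch with made-up numbers: post-multiplying by diag(s) leaves every column pointing the same way (cosine similarity 1 with the original column), while the flattened tensor as a whole can rotate when the factors differ. The helpers `scale_columns`, `cosine`, `column`, and `flatten` are illustrative names only.

```python
import math

def scale_columns(G, s):
    """Return G @ diag(s): column j of G multiplied by s[j]."""
    return [[g * sj for g, sj in zip(row, s)] for row in G]

def cosine(u, v):
    """Cosine similarity between two flat vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def column(mat, j):
    return [row[j] for row in mat]

def flatten(mat):
    return [x for row in mat for x in row]

G = [[1.0, -2.0], [3.0, 0.5], [-1.0, 4.0]]
s = [0.1, 2.0]  # illustrative positive scaling factors
GS = scale_columns(G, s)

per_channel = [cosine(column(G, j), column(GS, j)) for j in range(2)]
whole_tensor = cosine(flatten(G), flatten(GS))
```

With a single tensor-wise scalar (Apollo-Mini) all factors are equal, so the whole-tensor direction is preserved as well.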
|
||||||
|
|
||||||
|
## Theoretical Guarantees
|
||||||
|
|
||||||
|
### Theorem 4.1: First-moment approximation bound
|
||||||
|
|
||||||
|
For projected gradient R_t = P·G_t where P ∈ ℝ^{r×m} is random Gaussian:
|
||||||
|
|
||||||
|
```
|
||||||
|
(1-ε)‖M_t[:,j]‖² ≤ ‖M_t^R[:,j]‖² ≤ (1+ε)‖M_t[:,j]‖²
|
||||||
|
```
|
||||||
|
|
||||||
|
with probability at least 1 - 2exp(-rε²/8).
|
||||||
|
|
||||||
|
This is a Johnson-Lindenstrauss result: random projection approximately
|
||||||
|
preserves norms. The channel-wise first moment norms in the projected space
|
||||||
|
are close to the original space norms.
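
The norm-preservation claim is easy to probe empirically. The sketch below draws a random vector (standing in for one channel of a gradient or first moment) and a Gaussian projection with entries from N(0, 1/r), then reports the ratio of squared norms. It is only a spot check with a fixed seed, not a verification of the theorem's constants, and `projected_norm_ratio` is a name invented for this note.

```python
import math
import random

def projected_norm_ratio(m, r, seed=0):
    """Return ||P x||^2 / ||x||^2 for random x in R^m and P ~ N(0, 1/r) in R^{r x m}."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(m)]
    sigma = 1.0 / math.sqrt(r)
    proj_sq = 0.0
    for _ in range(r):
        row = [rng.gauss(0.0, sigma) for _ in range(m)]  # one row of P
        proj_sq += sum(p * xi for p, xi in zip(row, x)) ** 2
    return proj_sq / sum(xi * xi for xi in x)

# Ratios should concentrate around 1 as the rank r grows.
ratios = {r: projected_norm_ratio(m=512, r=r) for r in (16, 64, 256)}
```

The spread of the ratio around 1 shrinks roughly like √(2/r), which is why larger ranks give tighter scaling-factor estimates.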
|
||||||
|
|
||||||
|
### Theorem 4.2: Second-moment approximation bound
|
||||||
|
|
||||||
|
For ℓ₁ norm (element-wise second moment):
|
||||||
|
|
||||||
|
```
|
||||||
|
(1-ε)‖V_t[:,j]‖₁ ≤ ‖V_t^R[:,j]‖₁ ≤ (1+ε)‖V_t[:,j]‖₁
|
||||||
|
```
|
||||||
|
|
||||||
|
with probability at least 1-δ/2, when r ≥ (8/ε²)·log(2t/δ).
|
||||||
|
|
||||||
|
### Bounded update ratio (equation 9)
|
||||||
|
|
||||||
|
The ratio between compact and original scaling factors:
|
||||||
|
|
||||||
|
```
|
||||||
|
(√(1-ε))/(1+ε) ≤ √(n/r) · (s_j^R/s_j) ≤ (√(1+ε))/(1-ε)
|
||||||
|
```
|
||||||
|
|
||||||
|
This means the approximated scaling factor s_j^R differs from the true
|
||||||
|
scaling factor s_j by a predictable ratio of √(n/r), which is compensated
|
||||||
|
by the gradient scale factor α.
|
||||||
|
|
||||||
|
**This is why α = √128 for Apollo-Mini**: when r=1 and n is the smaller
|
||||||
|
dimension (typically ~128 for head dimensions), √(n/r) ≈ √128 ≈ 11.3.
|
||||||
|
The α compensates for this systematic ratio.
|
||||||
|
|
||||||
|
## Apollo-Mini: Tensor-wise Scaling
|
||||||
|
|
||||||
|
For rank r=1, channel-wise scaling becomes numerically unstable (one element
|
||||||
|
per channel in the projected space). Apollo-Mini coarsens further to a
|
||||||
|
single tensor-wise scaling factor:
|
||||||
|
|
||||||
|
```
|
||||||
|
s = ‖R̃_t‖₂ / ‖R_t‖₂
|
||||||
|
```
|
||||||
|
|
||||||
|
One scalar for the entire tensor. This is maximally coarse.
|
||||||
|
|
||||||
|
**Why it works**: The tensor-wise average of channel-wise scaling factors
|
||||||
|
smooths out the noise from rank-1 projection. The errors cancel across
|
||||||
|
channels. The paper shows Apollo-Mini actually OUTPERFORMS AdamW on
|
||||||
|
pre-training (Table 2, 3) — the coarsening acts as regularization.
|
||||||
|
|
||||||
|
## Why Apollo Can Beat AdamW (Section 5.5)
|
||||||
|
|
||||||
|
The paper provides two hypotheses:
|
||||||
|
|
||||||
|
### Hypothesis 1: Directional sharpness
|
||||||
|
|
||||||
|
Adam achieves lower directional sharpness than SGD, improving Transformer
|
||||||
|
training. But if directional sharpness is already too low (over-smoothed
|
||||||
|
landscape), the updates become too conservative. Apollo's coarser scaling
|
||||||
|
resembles SGD more (depends more on current gradient, less on history),
|
||||||
|
which can escape local optima that AdamW gets stuck in.
|
||||||
|
|
||||||
|
**Table 10**: Apollo/Apollo-Mini achieve lower directional sharpness than
|
||||||
|
Adam at epochs 5-20, comparable to SGD. This means Apollo navigates the
|
||||||
|
loss landscape more effectively.
|
||||||
|
|
||||||
|
### Hypothesis 2: Block-wise adaptive learning rates
|
||||||
|
|
||||||
|
Transformer blocks have varying Hessian spectra. Block-wise (channel/tensor)
|
||||||
|
adaptive rates are sufficient; fully per-element rates are redundant given
|
||||||
|
this structure. Apollo's channel/tensor-wise scaling naturally aligns with
|
||||||
|
the block structure of Transformers.
|
||||||
|
|
||||||
|
## Fine-tuning Results (Section 5.2)
|
||||||
|
|
||||||
|
On fine-tuning (Table 5, 6):
|
||||||
|
|
||||||
|
- **Common-sense reasoning (8 tasks)**: Apollo-Mini achieves 68.23 average
|
||||||
|
vs AdamW's 68.07. Essentially identical, with 0G optimizer memory.
|
||||||
|
- **MMLU**: Apollo-Mini competitive across all categories (STEM, Social
|
||||||
|
Sciences, Humanities, Other).
|
||||||
|
- **Learning rate range**: Sweeping [5e-6, 7.5e-6, 1e-5, 2.5e-5, 5e-5,
|
||||||
|
7.5e-5, 1e-4, 1.5e-4, 2e-4]. Best results at 1e-5 to 1e-4 range.
|
||||||
|
|
||||||
|
**Key finding for us**: Apollo-Mini performs on par with full AdamW for
|
||||||
|
fine-tuning. The rank doesn't matter much for fine-tuning quality — even
|
||||||
|
rank-1 is sufficient. The quality comes from the gradient direction (which
|
||||||
|
is preserved at full rank); only the scaling magnitude is approximated.
|
||||||
|
|
||||||
|
## Ablation Studies (Section 5.4)
|
||||||
|
|
||||||
|
### A1: Random projection ≈ SVD
|
||||||
|
Apollo performs equally well with random projection as SVD. Random projection
|
||||||
|
is dramatically cheaper (matrix multiply vs O(mn²) SVD).
|
||||||
|
|
||||||
|
### A2: Apollo-Mini effective even at rank 1
|
||||||
|
Apollo-Mini (rank-1) outperforms AdamW on pre-training. The tensor-wise
|
||||||
|
averaging of noise is a feature, not a bug.
|
||||||
|
|
||||||
|
### A3: Channel vs tensor granularity
|
||||||
|
Table 9: Difference between channel-wise and tensor-wise scaling is minimal
|
||||||
|
(~0.15 perplexity). For extreme low-rank (rank-1), tensor-wise actually
|
||||||
|
outperforms channel-wise.
|
||||||
|
|
||||||
|
### A4: Better with larger models and more tokens
|
||||||
|
Apollo's advantage over AdamW grows with model size and training tokens.
|
||||||
|
For larger models, the structured scaling becomes more beneficial.
|
||||||
|
|
||||||
|
### A5: Long-context training
|
||||||
|
Apollo performs on par with or better than AdamW for long-context pre-training
|
||||||
|
(sequence length 1024), with drastic memory savings.
|
||||||
|
|
||||||
|
## Implications for Our Use Case
|
||||||
|
|
||||||
|
### Learning rate
|
||||||
|
The paper sweeps [5e-6 to 2e-4] for fine-tuning. Our lr=1e-5 to 1e-4
|
||||||
|
range is in the sweet spot.
|
||||||
|
|
||||||
|
### Scale factor α
|
||||||
|
**We need to add this.** For rank-256 (our default), α should be
|
||||||
|
√(n/256) where n is the smaller weight dimension. For typical attention
|
||||||
|
weights with n=5120, that's √20 ≈ 4.5. For rank-1 it would be √5120 ≈ 71.6.
|
||||||
|
The `apollo_torch` library sets this as the `scale` parameter.
|
||||||
|
|
||||||
|
Our `apollo_mini.py` is missing the α factor entirely. This likely
|
||||||
|
means our scaling factors are systematically too small by √(n/r).
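
The arithmetic above, written out as a tiny helper. `auto_scale` is a hypothetical name; it only encodes the α = √(smaller_dim / rank) rule discussed in this section, and the 5120 dimension is the example figure used above.

```python
import math

def auto_scale(smaller_dim, rank):
    """Gradient scale factor alpha = sqrt(smaller_dim / rank)."""
    return math.sqrt(smaller_dim / rank)

alpha_full = auto_scale(5120, 256)  # sqrt(20), about 4.47
alpha_mini = auto_scale(5120, 1)    # sqrt(5120), about 71.55
alpha_paper_mini = math.sqrt(128)   # about 11.31, the paper's Apollo-Mini setting
```

Note the gap between the auto value for rank-1 at this width (about 71.6) and the paper's fixed √128 (about 11.3); which one to use for Apollo-Mini is a tuning decision, not something this arithmetic settles.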
|
||||||
|
|
||||||
|
### Norm-growth limiter
|
||||||
|
We should add this (γ=1.01) for training stability, especially in early
|
||||||
|
steps. It prevents the loss spikes visible in Figure 3.
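
A minimal sketch of the limiter on scalar norms, assuming the rule from equation 4: if the new norm exceeds γ times the previous one, rescale so it equals γ times the previous one. The function name and the zero-norm guard are additions for this note, not part of the paper's formulation.

```python
def limit_norm_growth(current_norm, prev_norm, gamma=1.01):
    """Return (multiplier for the scaled gradient, norm to remember for next step)."""
    # No history yet, or growth within the allowed ratio: leave the update alone.
    # The zero check is an addition of this sketch: a zero previous norm would
    # otherwise pin every later update to zero.
    if prev_norm is None or prev_norm == 0.0 or current_norm <= gamma * prev_norm:
        return 1.0, current_norm
    limit = gamma * prev_norm
    return limit / current_norm, limit

# Simulated raw norms with a sudden 5x spike at the second step.
prev = None
applied = []
for raw in [1.0, 5.0, 5.0, 0.5]:
    factor, prev = limit_norm_growth(raw, prev)
    applied.append(raw * factor)
```

The spike is admitted only 1% per step (1.0, then 1.01, then 1.0201), while a shrinking norm passes through unchanged, so the limiter caps growth without acting as a fixed clipping threshold.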
|
||||||
|
|
||||||
|
### Projection refresh
|
||||||
|
We can regenerate P every 200 steps instead of every step. Saves compute
|
||||||
|
and the theory shows it doesn't matter.
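
To see what the refresh schedule amounts to, the sketch below lists the steps at which a new projection would be drawn under the rule "first step, then every T-th step". The function name `refresh_steps` is hypothetical; the condition mirrors the one used in the optimizer code in this commit.

```python
def refresh_steps(total_steps, proj_refresh=200):
    """1-indexed steps at which a new projection matrix would be drawn."""
    return [step for step in range(1, total_steps + 1)
            if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0)]

schedule = refresh_steps(1000)  # 6 draws instead of 1000
```

One detail worth knowing: with this rule the first projection lives for 199 steps (steps 1 to 199) and every later one for 200.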
|
||||||
|
|
||||||
|
### Channel vs tensor scaling
|
||||||
|
For rank-256, channel-wise is slightly better. For rank-1, tensor-wise
|
||||||
|
is better. Since we default to rank-256, we should use channel-wise
|
||||||
|
(which we planned).
|
||||||
|
|
||||||
|
### Fine-tuning vs pre-training
|
||||||
|
The paper shows Apollo is slightly more beneficial for pre-training than
|
||||||
|
fine-tuning (where it merely matches AdamW). For fine-tuning, the gradient
|
||||||
|
direction matters more than the scaling precision — and Apollo preserves
|
||||||
|
the full gradient direction. This means our behavioral fine-tuning should
|
||||||
|
work well regardless of rank.
|
||||||
|
|
||||||
|
## Corrections to Our Implementation
|
||||||
|
|
||||||
|
1. **Add gradient scale factor α = √(n/r)** — critical for correct
|
||||||
|
scaling magnitude
|
||||||
|
2. **Add norm-growth limiter (γ=1.01)** — prevents early training instability
|
||||||
|
3. **Refresh projection every T=200 steps, not every step**
|
||||||
|
4. **Channel-wise scaling for rank>1, tensor-wise for rank=1**
|
||||||
|
5. **The scaling applies as G·diag(s), not s·G** — post-multiply, preserving
|
||||||
|
gradient direction per channel
|
||||||