Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r), which compensates for the systematic ratio between compact-space and original-space scaling factors
- Add norm-growth limiter (γ=1.01) to prevent loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW with 0G optimizer memory)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
Apollo Paper: Deep Analysis
Source: arXiv:2412.05270v4, MLSys 2025 Outstanding Paper Honorable Mention
Authors: Zhu, Zhang, Cong, Liu, Park, Chandra, Long, Pan, Wang, Lee
The Core Insight
AdamW's per-element learning rate scaling is massively redundant for LLMs. The element-wise scaling can be coarsened to channel-wise or even tensor-wise granularity without hurting quality, and in some cases with a slight improvement.
The mathematical argument
AdamW's update rule, rewritten as a pure scaling operation:
Standard AdamW:
M_t = β₁M_{t-1} + (1-β₁)G_t # first moment
V_t = β₂V_{t-1} + (1-β₂)G_t² # second moment
G̃_t = M_t / (√V_t + ε) # scaled gradient
W_{t+1} = W_t - η·G̃_t - η·λ·W_t # weight update
Rewritten as scaling:
W_{t+1} = W_t - η · (G̃_t ⊘ G_t) ⊙ G_t # S = G̃_t ⊘ G_t is the scaling matrix (⊘, ⊙ element-wise; weight decay omitted)
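A minimal NumPy sketch of this rewrite, assuming NumPy is available and following the equations above (no bias correction, weight decay left out). The helper name `adamw_scaled_grad` is hypothetical. It checks that multiplying the element-wise matrix S back into G_t reproduces the AdamW step, so S really is one learning-rate multiplier per weight.

```python
import numpy as np

def adamw_scaled_grad(g, m, v, beta1=0.9, beta2=0.999, eps=1e-8):
    """One moment update as in the equations above (no bias correction, no weight decay)."""
    m = beta1 * m + (1 - beta1) * g          # M_t
    v = beta2 * v + (1 - beta2) * g**2       # V_t
    return m / (np.sqrt(v) + eps), m, v      # G̃_t = M_t / (√V_t + ε), plus new moments

rng = np.random.default_rng(0)
g = rng.standard_normal((4, 6))              # G_t with m=4 rows, n=6 columns
g_tilde, m1, v1 = adamw_scaled_grad(g, np.zeros_like(g), np.zeros_like(g))

S = g_tilde / g                              # S = G̃_t ⊘ G_t: one factor per element
lr = 1e-3
W = rng.standard_normal((4, 6))
assert np.allclose(W - lr * S * g, W - lr * g_tilde)   # same update either way
```

On this first step every |G̃_t| entry equals (1-β₁)/√(1-β₂) ≈ 3.16, so S = 3.16/|G_t| differs for every element.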
The scaling matrix S ∈ ℝ^{m×n} is element-wise: each weight gets its own learning rate. The paper's key observation: this per-element granularity is unnecessary. S can be coarsened to:
- Channel-wise: one scaling factor per column (or row), s_j for channel j
- Tensor-wise: one scalar for the whole tensor (Apollo-Mini)
Channel-wise scaling factor (equation 3)
s_j = ‖G̃_t[:,j]‖₂ / ‖G_t[:,j]‖₂
This computes the ratio of norms between the Adam-scaled gradient and the raw gradient for each channel. It tells you: "how much should this channel's gradient be amplified or dampened?"
The paper shows empirically that channel-wise scaling achieves slightly BETTER perplexity than element-wise (24.43 vs 25.08 on LLaMA-130M). The coarsening acts as implicit regularization.
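To make equation 3 concrete, here is a small sketch assuming NumPy. The second moment `v` is an illustrative random stand-in rather than real optimizer state, and `channelwise_scaling` is a hypothetical name. Because s_j = ‖G̃[:,j]‖/‖G[:,j]‖, the rescaled column s_j·G[:,j] has exactly Adam's per-channel norm while keeping the raw gradient's direction inside that channel.

```python
import numpy as np

def channelwise_scaling(g, g_tilde):
    """Equation 3: s_j = ‖G̃_t[:, j]‖₂ / ‖G_t[:, j]‖₂, one factor per column j."""
    return np.linalg.norm(g_tilde, axis=0) / np.linalg.norm(g, axis=0)

rng = np.random.default_rng(1)
g = rng.standard_normal((8, 5))                  # G_t, m=8, n=5
v = rng.uniform(0.5, 2.0, size=g.shape)          # stand-in second moment (illustrative)
g_tilde = g / (np.sqrt(v) + 1e-8)                # stand-in Adam-scaled gradient G̃_t

s = channelwise_scaling(g, g_tilde)              # shape (n,): n factors instead of m·n
update = g * s                                   # G_t · diag(s): column j scaled by s_j
```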
Apollo: Approximating the Scaling Factor
Computing channel-wise scaling still requires the full M_t and V_t matrices. Apollo's contribution: approximate s_j using a low-rank auxiliary optimizer.
Algorithm (Algorithm 1)
Input: W ∈ ℝ^{m×n} (m ≤ n), lr η, scale factor α, rank r
Initialize: t = 0
repeat:
G_t = ∇φ(W_t) # full gradient
# Step 1: Project to low-rank space
if t mod T = 0:
seed ← random
P_t ← N(0, 1/r) # new random projection [r×m], sampled from seed
R_t = P_t · G_t # projected gradient [r×n]
# Step 2: Adam in low-rank space
M_t^R, V_t^R ← AdamW(R_t, β₁, β₂, λ=0) # moments on projected gradient
R̃_t = M_t^R / (√V_t^R + ε) # Adam-scaled projected gradient
# Step 3: Approximate channel-wise scaling
if APOLLO:
S ← diag(s₁^R, s₂^R, ..., s_n^R)
where s_j^R = ‖R̃_t[:,j]‖₂ / ‖R_t[:,j]‖₂
elif APOLLO-Mini:
S ← s^R · I
where s^R = ‖R̃_t‖₂ / ‖R_t‖₂ # single scalar
# Step 4: Update weight in original space
W_{t+1} = W_t - η·α · G_t·S - η·λ·W_t
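The steps above can be sketched for a single weight matrix as follows. This is a hypothetical NumPy re-implementation for illustration, not the apollo_torch code: the class and argument names are invented, the projected moments are assumed to carry over when P_t is refreshed, bias correction is omitted as in the pseudocode, and no column of R_t is assumed to be all zeros.

```python
import numpy as np

class ApolloSketch:
    """Hypothetical minimal version of Algorithm 1 for one (m, n) weight matrix."""

    def __init__(self, shape, rank, lr=1e-3, alpha=1.0, beta1=0.9, beta2=0.999,
                 eps=1e-8, weight_decay=0.0, T=200, mini=False, seed=0):
        m, n = shape
        self.m, self.r = m, rank
        self.lr, self.alpha, self.wd = lr, alpha, weight_decay
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.T, self.mini = T, mini
        self.M = np.zeros((rank, n))            # M_t^R: first moment of R_t
        self.V = np.zeros((rank, n))            # V_t^R: second moment of R_t
        self.t = 0
        self.seed_source = np.random.default_rng(seed)
        self.proj_seed = None
        self.P = None

    def step(self, W, G):
        # Step 1: draw a new seed and P_t (r×m, entries N(0, 1/r)) every T steps
        if self.t % self.T == 0:
            self.proj_seed = int(self.seed_source.integers(2**31))
            rng = np.random.default_rng(self.proj_seed)
            self.P = rng.normal(0.0, np.sqrt(1.0 / self.r), size=(self.r, self.m))
        R = self.P @ G                          # R_t = P_t · G_t, shape (r, n)

        # Step 2: Adam moments in the low-rank space (λ = 0 here)
        self.M = self.beta1 * self.M + (1 - self.beta1) * R
        self.V = self.beta2 * self.V + (1 - self.beta2) * R**2
        R_tilde = self.M / (np.sqrt(self.V) + self.eps)

        # Step 3: approximate scaling factors from projected quantities
        if self.mini:                           # Apollo-Mini: one scalar for the tensor
            s = np.linalg.norm(R_tilde) / np.linalg.norm(R)
        else:                                   # Apollo: one factor per column j
            s = np.linalg.norm(R_tilde, axis=0) / np.linalg.norm(R, axis=0)

        # Step 4: G_t · S in the original space (broadcasting scales column j by s_j)
        self.t += 1
        return W - self.lr * self.alpha * (G * s) - self.lr * self.wd * W
```

With these names, the Apollo-Mini defaults discussed in this note would correspond to `rank=1, mini=True, alpha=np.sqrt(128)`. Caching P costs r×m floats; keeping only `proj_seed` and rebuilding P from it is the lower-memory alternative.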
Key differences from our implementation
- Scale factor α: The paper uses a gradient scale factor α (default √128 for Apollo-Mini) to compensate for the ratio √(n/r) between the compact-space and original-space scaling factors. This is the `scale` parameter in `apollo_torch.APOLLOAdamW`. Our implementation is missing it.
- Norm-growth limiter: Instead of gradient clipping, the paper uses a norm-growth limiter (equation 4):
  if ‖G̃_t‖/‖G̃_{t-1}‖ > γ: G̃_t ← (G̃_t/‖G̃_t‖) · γ · ‖G̃_{t-1}‖
  Default γ = 1.01. This prevents loss spikes in early training. Our implementation is missing it.
- Projection matrix refresh: P_t is regenerated every T steps (default T=200), not every step, which amortizes the projection cost. Our implementation regenerates it every step, which is wasteful.
- The scaling is applied as G_t · S (post-multiplication by a diagonal matrix): the raw gradient G_t is multiplied by the scaling factors rather than replaced by the Adam-normalized gradient. Each channel of G_t keeps its direction; only its magnitude changes.
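The norm-growth limiter in equation 4 fits in a few lines. This is a sketch assuming NumPy; `limit_norm_growth` is a hypothetical name, and carrying the limited (rather than raw) norm forward as ‖G̃_{t-1}‖ is an assumption.

```python
import numpy as np

def limit_norm_growth(g_tilde, prev_norm, gamma=1.01):
    """Cap ‖G̃_t‖ at γ·‖G̃_{t-1}‖ while keeping the direction of G̃_t.

    Returns the (possibly rescaled) update and the norm to use as ‖G̃_{t-1}‖ next step.
    prev_norm is None on the first step, when there is nothing to compare against.
    """
    norm = np.linalg.norm(g_tilde)
    if prev_norm is not None and prev_norm > 0 and norm / prev_norm > gamma:
        g_tilde = g_tilde / norm * gamma * prev_norm
        norm = gamma * prev_norm
    return g_tilde, norm

u1, n1 = limit_norm_growth(np.array([3.0, 4.0]), None)    # first step: untouched, norm 5
u2, n2 = limit_norm_growth(np.array([30.0, 40.0]), n1)    # 10x jump capped at 5 · 1.01
```

Unlike clipping to a fixed threshold, the cap is relative to the previous update, so a sudden spike is flattened to at most 1% growth per step.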
Theoretical Guarantees
Theorem 4.1: First-moment approximation bound
For projected gradient R_t = P·G_t where P ∈ ℝ^{r×m} is random Gaussian:
(1-ε)‖M_t[:,j]‖² ≤ ‖M_t^R[:,j]‖² ≤ (1+ε)‖M_t[:,j]‖²
with probability at least 1 - 2exp(-rε²/8).
This is a Johnson-Lindenstrauss result: random projection approximately preserves norms. The channel-wise first moment norms in the projected space are close to the original space norms.
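A quick numerical check of this norm preservation, assuming NumPy; sizes and seed are arbitrary. While P stays fixed the projection is linear, so M_t^R = P·M_t, and checking column norms of any matrix under P illustrates the bound.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 2048, 64, 256
scales = rng.uniform(0.1, 10.0, size=n)
G = rng.standard_normal((m, n)) * scales          # columns with very different norms
P = rng.normal(0.0, np.sqrt(1.0 / r), size=(r, m))
R = P @ G                                         # R_t = P · G_t

ratio = np.linalg.norm(R, axis=0) ** 2 / np.linalg.norm(G, axis=0) ** 2
print(ratio.min(), ratio.max())                   # every column's squared norm roughly preserved
```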
Theorem 4.2: Second-moment approximation bound
For ℓ₁ norm (element-wise second moment):
(1-ε)‖V_t[:,j]‖₁ ≤ ‖V_t^R[:,j]‖₁ ≤ (1+ε)‖V_t[:,j]‖₁
with probability at least 1-δ/2, when r ≥ (8/ε²)·log(2t/δ).
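Plugging numbers into this rank condition shows how conservative it is. The values of ε, δ and t below are arbitrary, the natural logarithm is assumed, and `min_rank` is a hypothetical helper.

```python
import math

def min_rank(eps, delta, t):
    """Smallest integer r with r ≥ (8/ε²)·log(2t/δ), natural log assumed."""
    return math.ceil(8 / eps**2 * math.log(2 * t / delta))

print(min_rank(0.5, 0.05, 1000))    # 340
print(min_rank(0.25, 0.05, 1000))   # 1357
```

The condition is only sufficient: the ablations in this note report good results even at rank 1.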
Bounded update ratio (equation 9)
The ratio between compact and original scaling factors:
(√(1-ε))/(1+ε) ≤ √(n/r) · s_j^R/s_j ≤ (√(1+ε))/(1-ε)
This means the approximated scaling factor s_j^R is systematically smaller than the true factor s_j by a predictable ratio of about √(n/r) (s_j^R ≈ s_j·√(r/n), within the ε bounds), which the gradient scale factor α compensates for.
This is why α = √128 for Apollo-Mini: when r=1 and n is the smaller dimension (typically ~128 for head dimensions), √(n/r) ≈ √128 ≈ 11.3. The α compensates for this systematic ratio.
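The √(n/r) gap can be seen directly on the first optimizer step, sketched here with NumPy. On that step Adam's output is a constant multiple of sign(x) (exactly sign(x) with bias correction and ε = 0), and the constant cancels in s_j^R/s_j. Then ‖G̃[:,j]‖ = √d and ‖R̃[:,j]‖ = √r, where d (a name introduced here) is the length of each channel before projection, i.e. the smaller dimension that P compresses, which the α discussion in this note calls n. So √(d/r)·s_j^R/s_j = ‖G[:,j]‖/‖R[:,j]‖, which the JL bound keeps near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, r = 512, 32, 64                     # d: length of each channel before projection
G = rng.standard_normal((d, n))
P = rng.normal(0.0, np.sqrt(1.0 / r), size=(r, d))
R = P @ G

# First Adam step, up to a constant that cancels: M/√V ∝ sign(x)
s_true = np.linalg.norm(np.sign(G), axis=0) / np.linalg.norm(G, axis=0)   # s_j
s_proj = np.linalg.norm(np.sign(R), axis=0) / np.linalg.norm(R, axis=0)   # s_j^R

raw = s_proj / s_true                     # systematically below 1 (about √(r/d))
corrected = np.sqrt(d / r) * raw          # near 1 once the α-style factor is applied
```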
Apollo-Mini: Tensor-wise Scaling
For rank r=1, channel-wise scaling becomes numerically unstable (one element per channel in the projected space). Apollo-Mini coarsens further to a single tensor-wise scaling factor:
s = ‖R̃_t‖₂ / ‖R_t‖₂
One scalar for the entire tensor. This is maximally coarse.
Why it works: The tensor-wise average of channel-wise scaling factors smooths out the noise from rank-1 projection. The errors cancel across channels. The paper shows Apollo-Mini actually OUTPERFORMS AdamW on pre-training (Table 2, 3) — the coarsening acts as regularization.
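A small NumPy sketch of why rank 1 favors tensor-wise scaling, using the same first-step simplification (Adam output ∝ sign). With r = 1 each channel has a single projected entry, so the channel-wise factor is 1/|R_j| and blows up whenever R_j lands near zero; the tensor-wise factor is one value shared by all channels. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 64, 256
G = rng.standard_normal((m, n))
P = rng.normal(0.0, 1.0, size=(1, m))          # rank 1: N(0, 1/r) with r = 1
R = P @ G                                      # shape (1, n): one entry per channel
R_tilde = np.sign(R)                           # first-step Adam output, up to a constant

s_channel = np.abs(R_tilde[0]) / np.abs(R[0])               # |R̃_j| / |R_j| per channel
s_tensor = np.linalg.norm(R_tilde) / np.linalg.norm(R)      # one scalar for the tensor

col = np.linalg.norm(G, axis=0)                # ‖G[:, j]‖
eff_channel = s_channel * col                  # per-channel update norm, channel-wise
eff_tensor = s_tensor * col                    # per-channel update norm, tensor-wise
spread_channel = eff_channel.max() / eff_channel.min()
spread_tensor = eff_tensor.max() / eff_tensor.min()
print(spread_channel, spread_tensor)           # channel-wise spread is far larger
```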
Why Apollo Can Beat AdamW (Section 5.5)
The paper provides two hypotheses:
Hypothesis 1: Directional sharpness
Adam achieves lower directional sharpness than SGD, improving Transformer training. But if directional sharpness is already too low (over-smoothed landscape), the updates become too conservative. Apollo's coarser scaling resembles SGD more (depends more on current gradient, less on history), which can escape local optima that AdamW gets stuck in.
Table 10: Apollo/Apollo-Mini achieve lower directional sharpness than Adam at epochs 5-20, comparable to SGD. This means Apollo navigates the loss landscape more effectively.
Hypothesis 2: Block-wise adaptive learning rates
Transformer blocks have varying Hessian spectra. Block-wise (channel/tensor) adaptive rates are sufficient; fully per-element rates are redundant given this structure. Apollo's channel/tensor-wise scaling naturally aligns with the block structure of Transformers.
Fine-tuning Results (Section 5.2)
On fine-tuning (Table 5, 6):
- Common-sense reasoning (8 tasks): Apollo-Mini achieves 68.23 average vs AdamW's 68.07. Essentially identical, with 0G optimizer memory.
- MMLU: Apollo-Mini competitive across all categories (STEM, Social Sciences, Humanities, Other).
- Learning rate range: Sweeping [5e-6, 7.5e-6, 1e-5, 2.5e-5, 5e-5, 7.5e-5, 1e-4, 1.5e-4, 2e-4]. Best results at 1e-5 to 1e-4 range.
Key finding for us: Apollo-Mini performs on par with full AdamW for fine-tuning. The rank doesn't matter much for fine-tuning quality — even rank-1 is sufficient. The quality comes from the gradient direction (which is preserved at full rank); only the scaling magnitude is approximated.
Ablation Studies (Section 5.4)
A1: Random projection ≈ SVD
Apollo performs equally well with random projection as with SVD. Random projection is dramatically cheaper: sampling P plus one matrix multiply, O(rmn), versus O(m²n) for the SVD of an m×n gradient with m ≤ n.
A2: Apollo-Mini effective even at rank 1
Apollo-Mini (rank-1) outperforms AdamW on pre-training. The tensor-wise averaging of noise is a feature, not a bug.
A3: Channel vs tensor granularity
Table 9: Difference between channel-wise and tensor-wise scaling is minimal (~0.15 perplexity). For extreme low-rank (rank-1), tensor-wise actually outperforms channel-wise.
A4: Better with larger models and more tokens
Apollo's advantage over AdamW grows with model size and training tokens. For larger models, the structured scaling becomes more beneficial.
A5: Long-context training
Apollo performs on par with or better than AdamW for long-context pre-training (sequence length 1024), with drastic memory savings.
Implications for Our Use Case
Learning rate
The paper sweeps [5e-6 to 2e-4] for fine-tuning. Our lr=1e-5 to 1e-4 range is in the sweet spot.
Scale factor α
We need to add this. For rank-256 (our default), α should be
√(n/256) where n is the smaller weight dimension. For typical attention
weights with n=5120, that's √20 ≈ 4.5. For rank-1 it would be √5120 ≈ 71.6.
The apollo_torch library sets this as the scale parameter.
Our apollo_mini.py is missing the α factor entirely. This likely
means our scaling factors are systematically too small by √(n/r).
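The α arithmetic above as a tiny helper (hypothetical name; n = 5120 is the example dimension used above).

```python
import math

def apollo_alpha(n, r):
    """Gradient scale factor α = √(n/r), with n the smaller weight dimension."""
    return math.sqrt(n / r)

for r in (256, 1):
    print(r, round(apollo_alpha(5120, r), 1))   # 256 4.5, then 1 71.6
```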
Norm-growth limiter
We should add this (γ=1.01) for training stability, especially in early steps. It prevents the loss spikes visible in Figure 3.
Projection refresh
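We can regenerate P every 200 steps instead of every step. This saves compute, and the JL-style bounds above only require P to be random, not freshly sampled at every step, so reusing it between refreshes should not hurt.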
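One way to implement the refresh, suggested by the seed ← random line in Algorithm 1, is to keep only an integer seed and rebuild P from it. This is a NumPy sketch; both helper names and the `base_seed + t // T` schedule are hypothetical choices, not the paper's or apollo_torch's.

```python
import numpy as np

def projection_from_seed(seed, r, m):
    """Rebuild P (r×m, entries N(0, 1/r)) deterministically from an integer seed."""
    return np.random.default_rng(seed).normal(0.0, np.sqrt(1.0 / r), size=(r, m))

def seed_for_step(t, T=200, base_seed=1234):
    """Hypothetical schedule: the seed, and hence P, changes only every T steps."""
    return base_seed + t // T

# Steps 0..199 share one projection; step 200 gets a new one.
P_a = projection_from_seed(seed_for_step(0), r=4, m=16)
P_b = projection_from_seed(seed_for_step(199), r=4, m=16)
P_c = projection_from_seed(seed_for_step(200), r=4, m=16)
```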
Channel vs tensor scaling
For rank-256, channel-wise is slightly better. For rank-1, tensor-wise is better. Since we default to rank-256, we should use channel-wise (which we planned).
Fine-tuning vs pre-training
The paper shows Apollo is slightly more beneficial for pre-training than fine-tuning (where it merely matches AdamW). For fine-tuning, the gradient direction matters more than the scaling precision — and Apollo preserves the full gradient direction. This means our behavioral fine-tuning should work well regardless of rank.
Corrections to Our Implementation
- Add gradient scale factor α = √(n/r) — critical for correct scaling magnitude
- Add norm-growth limiter (γ=1.01) — prevents early training instability
- Refresh projection every T=200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- The scaling applies as G·diag(s), not s·G — post-multiply, preserving gradient direction per channel
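A short NumPy check of the last point, with arbitrary positive factors s: G·diag(s) equals broadcasting `G * s` over columns, and each column of the result points the same way as the corresponding column of G. With tensor-wise scaling s is a single scalar, so the whole gradient direction is preserved.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))
s = rng.uniform(0.5, 2.0, size=4)           # positive per-channel factors (illustrative)

U = G @ np.diag(s)                          # G · diag(s)
assert np.allclose(U, G * s)                # broadcasting scales column j by s_j

# Cosine between each column of U and the matching column of G
cos = (U * G).sum(axis=0) / (np.linalg.norm(U, axis=0) * np.linalg.norm(G, axis=0))
print(cos)                                  # 1.0 for every column
```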