ProofOfConcept ac9a9034fb apollo: rewrite optimizer from paper's math + add research analysis
Corrections from reading the full paper (arXiv:2412.05270):
- Add gradient scale factor α = √(n/r) — compensates for systematic
  ratio between compact and original space scaling factors
- Add norm-growth limiter (γ=1.01) — prevents loss spikes in early training
- Refresh projection matrix every 200 steps, not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
- Scaling applies as G·diag(s), preserving gradient direction per channel

Research writeup in training/research/apollo-paper-analysis.md covers:
- Full mathematical derivation (equations 1-9)
- Theorems 4.1 and 4.2 (JL-based approximation bounds)
- Why Apollo can beat AdamW (directional sharpness, Hessian spectra)
- Fine-tuning results (matches AdamW at 0 memory cost)
- Ablation studies (rank, scaling granularity, projection method)
- Implications for our behavioral fine-tuning use case
2026-03-31 00:54:17 -04:00

Apollo Paper: Deep Analysis

Source: arXiv:2412.05270v4, MLSys 2025 Outstanding Paper Honorable Mention

Authors: Zhu, Zhang, Cong, Liu, Park, Chandra, Long, Pan, Wang, Lee

The Core Insight

AdamW's per-element learning rate scaling is massively redundant for LLMs. The element-wise scaling can be coarsened to channel-wise or even tensor-wise without loss, and in some cases with a slight improvement.

The mathematical argument

AdamW's update rule, rewritten as a pure scaling operation:

Standard AdamW:
  M_t = β₁M_{t-1} + (1-β₁)G_t               # first moment
  V_t = β₂V_{t-1} + (1-β₂)G_t²               # second moment
  G̃_t = M_t / (√V_t + ε)                     # scaled gradient
  W_{t+1} = W_t - η·G̃_t - η·λ·W_t           # weight update

Rewritten as scaling (weight decay omitted):
  W_{t+1} = W_t - η · S ⊙ G_t                # S = G̃_t / G_t (element-wise) is the scaling matrix; ⊙ is the element-wise product
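The following is a minimal NumPy sketch of this rewrite. It checks that one AdamW step equals an element-wise scaling of the raw gradient. It assumes weight decay λ = 0 and no bias correction, mirroring the equations above. `adamw_step` is a hypothetical helper, and the moment arrays are random stand-ins for mid-training state.

```python
import numpy as np


def adamw_step(W, G, M, V, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """One AdamW step as written above (no bias correction)."""
    M = beta1 * M + (1 - beta1) * G           # first moment
    V = beta2 * V + (1 - beta2) * G**2        # second moment
    G_tilde = M / (np.sqrt(V) + eps)          # Adam-scaled gradient
    return W - lr * G_tilde - lr * wd * W, G_tilde


rng = np.random.default_rng(0)
W, G = rng.normal(size=(4, 6)), rng.normal(size=(4, 6))
M, V = rng.normal(size=(4, 6)), rng.random((4, 6))    # stand-in optimizer state

W_adam, G_tilde = adamw_step(W, G, M, V)
S = G_tilde / G                  # element-wise scaling matrix, same shape as W
W_scaled = W - 1e-2 * S * G      # W - η · S ⊙ G gives the same update
```

S has one entry per weight. That full m×n granularity is exactly what channel-wise and tensor-wise coarsening remove.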

The scaling matrix S ∈ ℝ^{m×n} is element-wise: each weight gets its own learning rate. The paper's key observation: this per-element granularity is unnecessary. S can be coarsened to:

  • Channel-wise: one scaling factor per column (or row), s_j for channel j
  • Tensor-wise: one scalar for the whole tensor (Apollo-Mini)

Channel-wise scaling factor (equation 3)

s_j = ‖G̃_t[:,j]‖₂ / ‖G_t[:,j]‖₂

This computes the ratio of norms between the Adam-scaled gradient and the raw gradient for each channel. It tells you: "how much should this channel's gradient be amplified or dampened?"
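Equation 3 and its tensor-wise counterpart can be sketched in NumPy as follows. The helper names are hypothetical. The small G̃ is hand-picked so the numbers can be checked by hand.

```python
import numpy as np


def channel_scaling(G_tilde, G):
    """Equation 3: s_j = ||G~[:, j]||_2 / ||G[:, j]||_2, one factor per column."""
    return np.linalg.norm(G_tilde, axis=0) / np.linalg.norm(G, axis=0)


def tensor_scaling(G_tilde, G):
    """Tensor-wise version: a single ratio of Frobenius norms."""
    return np.linalg.norm(G_tilde) / np.linalg.norm(G)


G = np.array([[3.0, 1.0],
              [4.0, 1.0]])
G_tilde = np.array([[0.6, 1.0],
                    [0.8, 1.0]])       # hand-picked stand-in for an Adam output

s = channel_scaling(G_tilde, G)         # [1/5, sqrt(2)/sqrt(2)] = [0.2, 1.0]
s_tensor = tensor_scaling(G_tilde, G)   # sqrt(3) / sqrt(27) = 1/3
update = G * s                          # G · diag(s): column j scaled by s_j
```

In this example each column of G̃ happens to be parallel to the matching column of G, so G · diag(s) reproduces G̃ exactly. In general, channel-wise scaling keeps only each column's norm ratio.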

The paper shows empirically that channel-wise scaling achieves slightly BETTER perplexity than element-wise (24.43 vs 25.08 on LLaMA-130M). The coarsening acts as implicit regularization.

Apollo: Approximating the Scaling Factor

Computing channel-wise scaling still requires the full M_t and V_t matrices. Apollo's contribution: approximate s_j using a low-rank auxiliary optimizer.

Algorithm (Algorithm 1)

Input: W ∈ ℝ^{m×n} (m ≤ n), lr η, scale factor α, rank r
Initialize: t = 0

repeat:
  G_t = ∇φ(W_t)                              # full gradient

  # Step 1: Project to low-rank space
  if t mod T = 0:
    P_t ← N(0, 1/r)                          # new random projection [r×m]
    seed ← random
  R_t = P_t · G_t                             # projected gradient [r×n]

  # Step 2: Adam in low-rank space
  M_t^R, V_t^R ← AdamW(R_t, β₁, β₂, λ=0)   # moments on projected gradient
  R̃_t = M_t^R / (√V_t^R + ε)                # Adam-scaled projected gradient

  # Step 3: Approximate channel-wise scaling
  if APOLLO:
    S ← diag(s₁^R, s₂^R, ..., s_n^R)
    where s_j^R = ‖R̃_t[:,j]‖₂ / ‖R_t[:,j]‖₂
  elif APOLLO-Mini:
    S ← s^R · I
    where s^R = ‖R̃_t‖₂ / ‖R_t‖₂             # single scalar

  # Step 4: Update weight in original space
  W_{t+1} = W_t - η·α · G_t·S - η·λ·W_t
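The four steps of Algorithm 1 can be sketched in NumPy for a single weight matrix. This is a minimal sketch under stated assumptions, not the paper's or apollo_torch's implementation:

- `ApolloSketch` and `demo_quadratic` are hypothetical names.
- Adam moments use standard bias correction and are carried across projection refreshes.
- The tensor-wise path is chosen automatically for r = 1.
- The demo sets α = √(m/r), with m the dimension P projects.

The norm-growth limiter (equation 4) is not part of Algorithm 1 as written, so it is left out.

```python
import numpy as np


class ApolloSketch:
    """Single-matrix sketch of Algorithm 1 (hypothetical; not the apollo_torch API).

    Assumes W has shape (m, n) with m <= n and P has shape (r, m), so the
    optimizer state is two (r, n) arrays instead of two (m, n) arrays.
    """

    def __init__(self, shape, rank, lr=1e-3, alpha=1.0, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.0, refresh_every=200, seed=0):
        self.m, self.n = shape
        self.r = rank
        self.lr, self.alpha, self.eps = lr, alpha, eps
        self.beta1, self.beta2 = betas
        self.wd = weight_decay
        self.T = refresh_every
        self.tensor_wise = rank == 1            # Apollo-Mini when r = 1
        self.rng = np.random.default_rng(seed)  # source of fresh projection seeds
        self.seed = None
        self.M = np.zeros((rank, self.n))       # first moment of R_t
        self.V = np.zeros((rank, self.n))       # second moment of R_t
        self.t = 0

    def projection(self):
        # P ~ N(0, 1/r), rebuilt from the stored seed so P is never kept in memory.
        rng = np.random.default_rng(self.seed)
        return rng.normal(0.0, np.sqrt(1.0 / self.r), size=(self.r, self.m))

    def step(self, W, G):
        # Step 1: draw a new seed every T steps, then project the gradient.
        if self.t % self.T == 0:
            self.seed = int(self.rng.integers(2**31))
        R = self.projection() @ G               # shape (r, n)
        self.t += 1
        # Step 2: Adam moments in the low-rank space (bias-corrected, λ = 0).
        self.M = self.beta1 * self.M + (1 - self.beta1) * R
        self.V = self.beta2 * self.V + (1 - self.beta2) * R**2
        m_hat = self.M / (1 - self.beta1**self.t)
        v_hat = self.V / (1 - self.beta2**self.t)
        R_tilde = m_hat / (np.sqrt(v_hat) + self.eps)
        # Step 3: channel-wise factors (one per column) or one tensor-wise scalar.
        # The 1e-12 guards against an all-zero gradient column.
        if self.tensor_wise:
            s = np.linalg.norm(R_tilde) / (np.linalg.norm(R) + 1e-12)
        else:
            s = np.linalg.norm(R_tilde, axis=0) / (np.linalg.norm(R, axis=0) + 1e-12)
        # Step 4: W - η·α·G·diag(s) - η·λ·W; broadcasting scales column j by s_j.
        return W - self.lr * self.alpha * G * s - self.lr * self.wd * W


def demo_quadratic(rank, steps=50, shape=(32, 64), lr=0.05, seed=1):
    """Minimize 0.5 * ||W - W*||^2 and return (initial_loss, final_loss)."""
    rng = np.random.default_rng(seed)
    W, W_star = rng.normal(size=shape), rng.normal(size=shape)

    def loss(X):
        return 0.5 * np.sum((X - W_star) ** 2)

    alpha = np.sqrt(shape[0] / rank)            # assumed: sqrt(projected dim / r)
    opt = ApolloSketch(shape, rank, lr=lr, alpha=alpha, seed=seed)
    initial = loss(W)
    for _ in range(steps):
        W = opt.step(W, W - W_star)             # gradient of the quadratic
    return initial, loss(W)
```

Regenerating P from a stored integer seed is what keeps the state at two r×n arrays. Storing P itself would add an r×m matrix per weight.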

Key differences from my implementation

  1. Scale factor α: The paper uses a gradient scale factor α (default √128 for Apollo-Mini) to compensate for the ratio √(n/r) between compact and original space scaling factors. This is the scale parameter in apollo_torch.APOLLOAdamW. Our implementation is missing this.

  2. Norm-growth limiter: Instead of gradient clipping, they use a norm growth limiter (equation 4):

    if ‖G̃_t‖/‖G̃_{t-1}‖ > γ:
      G̃_t ← (G̃_t/‖G̃_t‖) · γ · ‖G̃_{t-1}‖
    

    Default γ = 1.01. This prevents loss spikes in early training. Our implementation is missing this.

  3. Projection matrix refresh: P_t is regenerated every T steps (default T=200), not every step, which amortizes the projection cost. Our implementation regenerates it every step, which is wasteful.

  4. The scaling is applied as G_t · S (post-multiplication by a diagonal matrix): column j of G_t is multiplied by s_j. Each channel's gradient direction is preserved; only its magnitude changes.
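Here is a quick NumPy check of what post-multiplication by diag(s) does, using hypothetical factors s. Every column keeps its direction (cosine 1 with the original column). The direction of the matrix as a whole changes whenever the s_j differ. Broadcasting `G * s` gives the same result without building an n×n diagonal matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.normal(size=(5, 3))
s = np.array([0.5, 2.0, 10.0])          # hypothetical per-channel factors

post = G @ np.diag(s)                    # G · diag(s): column j scaled by s_j
same = np.allclose(post, G * s)          # broadcasting gives the same matrix


def column_cosines(A, B):
    """Cosine similarity between matching columns of A and B."""
    num = np.sum(A * B, axis=0)
    return num / (np.linalg.norm(A, axis=0) * np.linalg.norm(B, axis=0))


per_column = column_cosines(G, post)     # all 1.0: per-channel direction kept
whole = G.ravel() @ post.ravel() / (np.linalg.norm(G) * np.linalg.norm(post))
```

`whole` falls below 1 because the columns are stretched by different amounts. Pre-multiplying by a diagonal would instead rescale rows.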

Theoretical Guarantees

Theorem 4.1: First-moment approximation bound

For projected gradient R_t = P·G_t where P ∈ ℝ^{r×m} is random Gaussian:

(1-ε)‖M_t[:,j]‖² ≤ ‖M_t^R[:,j]‖² ≤ (1+ε)‖M_t[:,j]‖²

with probability at least 1 - 2exp(-rε²/8).

This is a Johnson-Lindenstrauss result: random projection approximately preserves norms. The channel-wise first moment norms in the projected space are close to the original space norms.
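Theorem 4.1's norm preservation is easy to check numerically. This sketch draws a Gaussian P ~ N(0, 1/r) and a random matrix standing in for M_t. With P held fixed, the first moment of P·G_t equals P·M_t by linearity. Real moments are not i.i.d. Gaussian, so this only illustrates the JL mechanism. The sketch measures how many channels land inside the (1 ± ε) band.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 1024, 64, 256
M = rng.normal(size=(m, n))                          # random stand-in for M_t
P = rng.normal(0.0, np.sqrt(1.0 / r), size=(r, m))   # P ~ N(0, 1/r)

# Per-channel ratio ||M^R[:, j]||^2 / ||M[:, j]||^2 with M^R = P @ M.
ratios = np.linalg.norm(P @ M, axis=0) ** 2 / np.linalg.norm(M, axis=0) ** 2

eps = 0.3
inside = np.mean(np.abs(ratios - 1) <= eps)          # fraction of channels in (1 ± ε)
bound = 1 - 2 * np.exp(-r * eps**2 / 8)              # Theorem 4.1's bound, about 0.89
```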

Theorem 4.2: Second-moment approximation bound

For ℓ₁ norm (element-wise second moment):

(1-ε)‖V_t[:,j]‖₁ ≤ ‖V_t^R[:,j]‖₁ ≤ (1+ε)‖V_t[:,j]‖₁

with probability at least 1-δ/2, when r ≥ (8/ε²)·log(2t/δ).
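The rank condition in Theorem 4.2 is simple arithmetic, and a small helper (hypothetical name) makes its scaling visible. It assumes log is the natural logarithm. Here ε is the approximation tolerance, not Adam's ε.

```python
import math


def min_rank(eps, delta, t):
    """Smallest integer r with r >= (8 / eps**2) * log(2 * t / delta)."""
    return math.ceil(8 / eps**2 * math.log(2 * t / delta))


print(min_rank(0.5, 0.1, 1_000))       # 317
print(min_rank(0.5, 0.1, 1_000_000))   # 538: 1000x more steps, under 2x the rank
print(min_rank(0.1, 0.1, 1_000))       # 7923: a 5x tighter ε costs about 25x the rank
```

The required rank grows only logarithmically in t and 1/δ but quadratically in 1/ε.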

Bounded update ratio (equation 9)

The ratio between compact and original scaling factors:

√(1-ε)/(1+ε) ≤ √(n/r) · s_j^R/s_j ≤ √(1+ε)/(1-ε)

This means the approximated scaling factor s_j^R is systematically smaller than the true factor s_j by a predictable ratio of about √(n/r), which the gradient scale factor α compensates for.

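This is the likely rationale for α = √128 in Apollo-Mini: with r=1, √(n/r) = √n, and √128 ≈ 11.3 corresponds to n = 128 (typical of head dimensions). The α compensates for this systematic ratio. The dimension that matters is the one P projects. Algorithm 1 takes m ≤ n with P ∈ ℝ^{r×m}, so in that notation it is m, the smaller dimension.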
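The √(dimension/r) gap between s_j and s_j^R shows up directly in the simplest case. That case is a single Adam step from zero moments, where the Adam-scaled gradient is a constant times sign(G) (treating Adam's ε as negligible). Under that assumption, ‖G̃_t[:,j]‖ grows with the square root of the column length. Projecting columns from length m down to r therefore shrinks the ratio by about √(m/r), where m is the dimension P projects (P ∈ ℝ^{r×m} in Algorithm 1). This is an illustration, not a reproduction of the paper's proof.

```python
import numpy as np


def first_step_adam(X, beta1=0.9, beta2=0.999, eps=1e-12):
    """Adam-scaled gradient after one step from zero moments, no bias correction."""
    return (1 - beta1) * X / (np.sqrt((1 - beta2) * X**2) + eps)


def channel_scaling(X_tilde, X):
    """Per-column ratio of norms, as in equation 3."""
    return np.linalg.norm(X_tilde, axis=0) / np.linalg.norm(X, axis=0)


rng = np.random.default_rng(0)
m, n, r = 1024, 128, 64
G = rng.normal(size=(m, n))
P = rng.normal(0.0, np.sqrt(1.0 / r), size=(r, m))   # P ~ N(0, 1/r)
R = P @ G

s_true = channel_scaling(first_step_adam(G), G)      # s_j
s_proj = channel_scaling(first_step_adam(R), R)      # s_j^R
gap = np.median(s_true / s_proj)                     # about sqrt(m / r) = 4.0
```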

Apollo-Mini: Tensor-wise Scaling

For rank r=1, channel-wise scaling becomes numerically unstable (one element per channel in the projected space). Apollo-Mini coarsens further to a single tensor-wise scaling factor:

s = ‖R̃_t‖₂ / ‖R_t‖₂

One scalar for the entire tensor. This is maximally coarse.

Why it works: the tensor-wise factor pools all channels into one ratio of norms, so noise in individual channel estimates from the rank-1 projection is averaged out rather than applied channel by channel. The paper shows Apollo-Mini actually OUTPERFORMS AdamW on pre-training (Tables 2 and 3); the coarsening acts as regularization.
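One way to see why the tensor-wise factor is steadier at rank 1 is an exact identity. The quantity s² = ‖R̃_t‖²/‖R_t‖² is a weighted mean of the squared per-channel ratios, with weight ‖R_t[:,j]‖²/‖R_t‖² on channel j. Channels whose projected gradient is near zero are where the per-channel ratio is noisiest, and they get almost no weight. The sketch assumes the first Adam step, where R̃_t is proportional to sign(R_t). The values of R are made up.

```python
import numpy as np

R = np.array([[2.0, -1.0, 0.01, 0.5]])     # made-up rank-1 projected gradient (1 x n)
R_tilde = np.sign(R)                        # first Adam step: R~ proportional to sign(R)

s_channel = np.abs(R_tilde[0]) / np.abs(R[0])   # [0.5, 1, 100, 2]: near-zero channel blows up
s_tensor = np.linalg.norm(R_tilde) / np.linalg.norm(R)   # about 0.873
w = R[0] ** 2 / np.sum(R ** 2)                  # channel weights, summing to 1

weighted = np.sqrt(np.sum(w * s_channel ** 2))  # equals s_tensor exactly
```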

Why Apollo Can Beat AdamW (Section 5.5)

The paper provides two hypotheses:

Hypothesis 1: Directional sharpness

Adam achieves lower directional sharpness than SGD, improving Transformer training. But if directional sharpness is already too low (over-smoothed landscape), the updates become too conservative. Apollo's coarser scaling resembles SGD more (depends more on current gradient, less on history), which can escape local optima that AdamW gets stuck in.

Table 10: Apollo/Apollo-Mini achieve lower directional sharpness than Adam at epochs 5-20, comparable to SGD. The paper takes this as evidence that Apollo navigates the loss landscape more effectively.

Hypothesis 2: Block-wise adaptive learning rates

Transformer blocks have varying Hessian spectra. Block-wise (channel/tensor) adaptive rates are sufficient; fully per-element rates are redundant given this structure. Apollo's channel/tensor-wise scaling naturally aligns with the block structure of Transformers.

Fine-tuning Results (Section 5.2)

On fine-tuning (Table 5, 6):

  • Common-sense reasoning (8 tasks): Apollo-Mini achieves a 68.23 average vs AdamW's 68.07, essentially identical, with optimizer-state memory reported as 0G.
  • MMLU: Apollo-Mini competitive across all categories (STEM, Social Sciences, Humanities, Other).
  • Learning rate range: the sweep covers [5e-6, 7.5e-6, 1e-5, 2.5e-5, 5e-5, 7.5e-5, 1e-4, 1.5e-4, 2e-4], with the best results in the 1e-5 to 1e-4 range.
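The near-zero memory figure follows from the state shapes in Algorithm 1. AdamW keeps two m×n moment arrays per weight matrix, while Apollo keeps two r×n arrays plus a seed. Below is a back-of-the-envelope helper (hypothetical name). It assumes fp32 moments, ignores the seed, and uses a 5120×5120 weight as the example.

```python
def moment_bytes(shape, rank=None, bytes_per_value=4):
    """Bytes for first + second moments of one weight matrix.

    rank=None: AdamW, two arrays of the full shape.
    rank=r:    Apollo, two r x n arrays, n being the larger (unprojected) dimension.
    """
    m, n = sorted(shape)                 # m <= n, as in Algorithm 1
    rows = m if rank is None else rank
    return 2 * rows * n * bytes_per_value


print(moment_bytes((5120, 5120)))        # 209715200 (200 MiB), AdamW
print(moment_bytes((5120, 5120), 256))   # 10485760 (10 MiB), Apollo r=256
print(moment_bytes((5120, 5120), 1))     # 40960 (40 KiB), Apollo-Mini r=1
```

At 40 KiB per matrix, Apollo-Mini's state is negligible at gigabyte resolution.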

Key finding for us: Apollo-Mini performs on par with full AdamW for fine-tuning. The rank doesn't matter much for fine-tuning quality — even rank-1 is sufficient. The quality comes from the gradient direction (which is preserved at full rank); only the scaling magnitude is approximated.

Ablation Studies (Section 5.4)

A1: Random projection ≈ SVD

Apollo performs equally well with random projection as with SVD. Random projection is dramatically cheaper: computing P·G_t is an O(rmn) matrix multiply, versus O(m²n) for the SVD of an m×n gradient with m ≤ n.

A2: Apollo-Mini effective even at rank 1

Apollo-Mini (rank-1) outperforms AdamW on pre-training. The tensor-wise averaging of noise is a feature, not a bug.

A3: Channel vs tensor granularity

Table 9: Difference between channel-wise and tensor-wise scaling is minimal (~0.15 perplexity). For extreme low-rank (rank-1), tensor-wise actually outperforms channel-wise.

A4: Better with larger models and more tokens

Apollo's advantage over AdamW grows with model size and training tokens. For larger models, the structured scaling becomes more beneficial.

A5: Long-context training

Apollo performs on par with or better than AdamW for long-context pre-training (sequence length 1024), with drastic memory savings.

Implications for Our Use Case

Learning rate

The paper sweeps 5e-6 to 2e-4 for fine-tuning, with the best results between 1e-5 and 1e-4. Our lr range of 1e-5 to 1e-4 sits in that sweet spot.

Scale factor α

We need to add this. For rank-256 (our default), α should be √(n/256) where n is the smaller weight dimension. For typical attention weights with n=5120, that's √20 ≈ 4.5. For rank-1 it would be √5120 ≈ 71.6. The apollo_torch library sets this as the scale parameter.

Our apollo_mini.py is missing the α factor entirely. This likely means our scaling factors are systematically too small by √(n/r).
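The α arithmetic above can be written as a small helper. `apollo_alpha` is a hypothetical name that follows the convention in the text: α = √(n/r), with n the smaller weight dimension. It is not the apollo_torch `scale` parameter, whose exact semantics should be checked against that library.

```python
import math


def apollo_alpha(shape, rank):
    """alpha = sqrt(n / r), with n taken as the smaller weight dimension."""
    return math.sqrt(min(shape) / rank)


print(round(apollo_alpha((5120, 5120), 256), 1))   # 4.5, i.e. sqrt(20)
print(round(apollo_alpha((5120, 5120), 1), 1))     # 71.6, i.e. sqrt(5120)
print(round(math.sqrt(128), 1))                    # 11.3, the Apollo-Mini default
```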

Norm-growth limiter

We should add this (γ=1.01) for training stability, especially in early steps. It prevents the loss spikes visible in Figure 3.
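Below is a minimal sketch of the limiter in equation 4, applied to the scaled gradient G̃_t (the function name is hypothetical). When growth exceeds γ, it rescales the update to γ times the previous norm while keeping the direction. It returns the norm to carry into the next step. Passing `None` as the previous norm means there is nothing to compare against yet.

```python
import numpy as np


def limit_norm_growth(g, prev_norm, gamma=1.01):
    """Cap ||g|| at gamma * prev_norm (equation 4); return (g, norm for next step)."""
    norm = np.linalg.norm(g)
    if prev_norm is not None and norm > gamma * prev_norm:
        g = g * (gamma * prev_norm / norm)   # same direction, capped norm
        norm = gamma * prev_norm
    return g, norm


g1, n1 = limit_norm_growth(np.array([3.0, 4.0]), None)    # first step: norm 5.0
g2, n2 = limit_norm_growth(np.array([30.0, 40.0]), n1)    # 10x spike: capped at 5.05
g3, n3 = limit_norm_growth(np.array([0.3, 0.4]), n2)      # shrinking: passes through
```

Unlike clipping to a fixed threshold, the cap is relative. A sustained increase in update norm is still possible, at up to 1% per step.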

Projection refresh

We can regenerate P every 200 steps instead of every step. This saves compute, and nothing in the theory calls for per-step refresh: Theorems 4.1 and 4.2, as stated above, use a single random P.

Channel vs tensor scaling

For rank-256, channel-wise is slightly better. For rank-1, tensor-wise is better. Since we default to rank-256, we should use channel-wise (which we planned).

Fine-tuning vs pre-training

The paper shows Apollo is slightly more beneficial for pre-training than fine-tuning (where it merely matches AdamW). For fine-tuning, the gradient direction matters more than the scaling precision, and Apollo preserves the full gradient direction within each channel. This suggests our behavioral fine-tuning should work well regardless of rank.

Corrections to Our Implementation

  1. Add gradient scale factor α = √(n/r) — critical for correct scaling magnitude
  2. Add norm-growth limiter (γ=1.01) — prevents early training instability
  3. Refresh projection every T=200 steps, not every step
  4. Channel-wise scaling for rank>1, tensor-wise for rank=1
  5. The scaling applies as G·diag(s), not s·G — post-multiply, preserving gradient direction per channel