# Catastrophic Forgetting in LLM Fine-Tuning

## What it is

When fine-tuning a pre-trained LLM on new data, the model's performance
on tasks it previously handled well can degrade. The forgetting is
"catastrophic" because the degradation can be sudden and severe: a few
thousand training steps on a narrow dataset can destroy broad
capabilities built over trillions of pre-training tokens.

## Why it happens

The gradient from fine-tuning data pushes the weights away from the
pre-trained solution. If the fine-tuning signal is narrow (e.g., "always
respond in formal English"), it overwrites the broad patterns that
enabled diverse capabilities. The pre-trained weights encode a complex,
multi-task solution; fine-tuning replaces that with a simpler, narrower
one.

In loss-landscape terms: the pre-trained weights sit in a basin that is
good for many tasks. Fine-tuning moves the weights toward a different
basin that is optimal for the fine-tuning task but bad for others. The
further you move, the more you forget.

## Key factors that control forgetting

### 1. Learning rate × number of steps = total displacement

The total distance the weights move from their pre-trained values is
approximately:

```
‖W_final - W_pretrained‖ ≈ lr × steps × avg_grad_norm
```

More displacement = more forgetting. Both learning rate and step count
matter; their product determines the total weight change.

**Typical safe ranges (from the literature)**:

- Full fine-tuning: lr = 1e-5 to 5e-5, 1-3 epochs on ~1000-10000 examples
- LoRA: lr = 1e-4 to 3e-4 (higher because the adapters start from zero)
- Apollo: lr = 1e-5 to 1e-4 (same as full fine-tuning, paper Section 5.2)

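As a sanity check, the displacement formula above can be turned into a
back-of-envelope helper. The gradient-norm value below is a made-up
placeholder (something you would measure during training), not a number
from the literature:

```python
def displacement_estimate(lr: float, steps: int, avg_grad_norm: float) -> float:
    """Rough estimate of ||W_final - W_pretrained|| per the formula above."""
    return lr * steps * avg_grad_norm

# Hypothetical run: lr from the full fine-tuning range above, 3 epochs
# over 1000 examples (3000 steps), assumed unit average gradient norm.
print(displacement_estimate(lr=1e-5, steps=3000, avg_grad_norm=1.0))

# A LoRA-range lr over the same steps moves the weights it touches an
# order of magnitude further:
print(displacement_estimate(lr=2e-4, steps=3000, avg_grad_norm=1.0))
```

Halving either the learning rate or the step count halves the estimate,
which is why the two knobs are interchangeable for controlling displacement.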
### 2. Dataset diversity

Narrow datasets cause more forgetting than diverse ones. Training on
1000 examples of the same pattern hammers the same weights repeatedly;
training on 1000 diverse examples spreads the gradient across different
weight subsets.

**This is our key defense.** Our training data includes:

- Agent logs (graph walking, linking, reasoning)
- Conversation transcripts (technical discussion, emotional engagement)
- Dream-generated scenarios (diverse behavioral situations)
- Personality patterns (voice, boundaries, mode awareness)

The diversity means no single weight subset gets disproportionate updates.

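One simple way to enforce this mix is to flatten and shuffle the sources
when building a training set, so consecutive steps draw from different
categories. The source names mirror the list above; the function and its
shuffling scheme are an illustrative sketch, not our actual pipeline:

```python
import random

def build_mixed_dataset(sources: dict, seed: int = 0) -> list:
    """Flatten examples from several named sources into one shuffled
    training list, so no single source dominates consecutive steps."""
    rng = random.Random(seed)
    mixed = [ex for examples in sources.values() for ex in examples]
    rng.shuffle(mixed)
    return mixed

# Toy stand-ins for the four source categories listed above.
dataset = build_mixed_dataset({
    "agent_logs": ["log-1", "log-2"],
    "transcripts": ["chat-1", "chat-2"],
    "dream_scenarios": ["dream-1"],
    "personality": ["voice-1"],
})
```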
### 3. Gradient sparsity

For a short training example (50-256 decision tokens), the gradient is
sparse: most weights get near-zero gradient because they weren't active
for that specific input. Only the weights that participated in generating
the decision tokens receive meaningful gradient signal.

This natural sparsity is another defense: each example modifies only a
small fraction of the weights, leaving the rest untouched.

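This is easy to quantify during training by measuring what fraction of a
gradient tensor is effectively zero. A minimal sketch with NumPy (the
threshold is an arbitrary illustrative choice):

```python
import numpy as np

def gradient_sparsity(grad: np.ndarray, threshold: float = 1e-8) -> float:
    """Fraction of gradient entries with magnitude below `threshold`."""
    return float(np.mean(np.abs(grad) < threshold))

# Toy gradient: only 50 of 1000 entries are nonzero, as for a short
# example that activates a small fraction of the weights.
grad = np.zeros(1000)
grad[:50] = 1.0
print(gradient_sparsity(grad))  # 0.95
```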
### 4. Rank of the gradient

Biderman et al. (2024, "LoRA Learns Less and Forgets Less") found that
full fine-tuning produces weight perturbations with an effective rank
10-100× higher than typical LoRA configurations. Higher-rank
perturbations modify more independent directions in weight space, which
increases both learning and forgetting.

LoRA's low-rank constraint acts as implicit regularization against
forgetting: it can only modify a low-dimensional subspace of the
weights, leaving most of the pre-trained solution intact.

**Apollo's rank-256 sits between LoRA and full fine-tuning.** The
projected optimizer constrains the scaling (not the gradient itself) to
256 dimensions; the gradient remains full-rank. This means Apollo
modifies all weights (like full fine-tuning) but with a more structured
update pattern (like LoRA). Whether this provides forgetting protection
similar to LoRA's is an open question.

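Effective rank of a weight perturbation ΔW can be estimated from its
singular values. The sketch below uses the common entropy-based
definition, exp of the entropy of the normalized singular-value
distribution; this is one convention among several, and Biderman et
al.'s exact metric may differ:

```python
import numpy as np

def effective_rank(delta_w: np.ndarray) -> float:
    """Entropy-based effective rank: exp(H(p)), p = normalized singular values."""
    s = np.linalg.svd(delta_w, compute_uv=False)
    p = s / s.sum()
    p = p[p > 0]                      # drop exact zeros before the log
    return float(np.exp(-np.sum(p * np.log(p))))

rng = np.random.default_rng(0)
# Full-fine-tuning-style update: a dense 64x64 perturbation.
full_rank_update = rng.normal(size=(64, 64))
# LoRA-style update: product of 64x4 and 4x64 factors, so rank <= 4.
low_rank_update = rng.normal(size=(64, 4)) @ rng.normal(size=(4, 64))
print(effective_rank(full_rank_update))  # much higher than the LoRA-style update
print(effective_rank(low_rank_update))   # at most 4
```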
## Defenses against forgetting

### 1. Diversity of training data (our primary defense)

The simplest and most effective defense. If the training data covers the
same breadth as the pre-training data (proportionally), the model
maintains its broad capabilities while learning new patterns.

For us: mix behavioral examples with general capability examples.
Include agent logs alongside conversation corrections.

### 2. Low learning rate + few steps

Keep the total weight displacement small. The pre-trained basin is deep;
small perturbations stay within it.

For us: lr=1e-5 as a starting point. One epoch over diverse data.
Monitor perplexity on held-out general text.

### 3. Context-frozen training (our approach)

By computing gradients only on decision tokens (50-256 tokens) rather
than full conversations (thousands of tokens), we limit the gradient
magnitude per example. The context contributes to the forward pass
(determining which weights are active) but not to the backward pass
(determining which weights change).

This naturally limits the total gradient norm per training step.

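The forward/backward split amounts to masking the loss so that only
decision tokens contribute gradient. A minimal NumPy sketch of the
masking step (the token counts and NLL values are hypothetical):

```python
import numpy as np

def masked_nll(token_nll: np.ndarray, decision_mask: np.ndarray) -> float:
    """Mean negative log-likelihood over decision tokens only. Context
    tokens shaped the forward pass that produced `token_nll`, but they
    are excluded here, so they would contribute no gradient."""
    return float(token_nll[decision_mask].mean())

# Hypothetical per-token NLLs for a 10-token sequence in which only the
# last 3 tokens are decision tokens.
nll = np.array([2.0, 1.5, 1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5])
mask = np.zeros(10, dtype=bool)
mask[-3:] = True
print(masked_nll(nll, mask))  # mean of the last 3 values, ≈ 0.6
```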
### 4. Elastic Weight Consolidation (EWC)

Add a regularization term that penalizes changes to "important" weights:

```
L_total = L_task + λ Σ_i F_i (θ_i - θ*_i)²
```

where F_i is the Fisher information for parameter i and θ*_i is the
pre-trained value. Important weights (high Fisher information) are
penalized more for changing.

**Drawback**: requires computing and storing the Fisher information
(even the diagonal approximation is the same size as the model). Not
practical for 27B parameters.

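The penalty term follows directly from the formula. A sketch with the
diagonal Fisher approximation; all arrays here are toy stand-ins, not
values from a real model:

```python
import numpy as np

def ewc_penalty(theta: np.ndarray, theta_star: np.ndarray,
                fisher: np.ndarray, lam: float) -> float:
    """λ Σ_i F_i (θ_i - θ*_i)² with a diagonal Fisher approximation."""
    return float(lam * np.sum(fisher * (theta - theta_star) ** 2))

theta_star = np.array([1.0, -0.5, 2.0])   # pre-trained values θ*
theta      = np.array([1.1, -0.5, 1.0])   # values after some fine-tuning
fisher     = np.array([10.0, 0.1, 5.0])   # per-parameter "importance"
print(ewc_penalty(theta, theta_star, fisher, lam=0.5))
```

Note how the third parameter dominates the penalty: it moved the most
and carries high Fisher information.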
### 5. Replay / rehearsal

Mix in examples from the pre-training distribution alongside the
fine-tuning data. This maintains the original capabilities by continuing
to train on them.

**Drawback**: requires access to the pre-training data or a
representative subset. For open models like Qwen3.5, the pre-training
data isn't publicly available. Could use a proxy (e.g., Wikipedia, code).

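Assembling a replay mix at a fixed ratio is mechanical. In this sketch
the Wikipedia/code strings are stand-ins for a proxy corpus per the note
above, and the 20% ratio is an arbitrary example, not a recommendation
from the literature:

```python
import random

def mix_with_replay(finetune_data: list, replay_data: list,
                    replay_fraction: float = 0.2, seed: int = 0) -> list:
    """Return a shuffled training list in which ~replay_fraction of the
    examples are drawn (with replacement) from the replay proxy set."""
    rng = random.Random(seed)
    n_replay = round(len(finetune_data) * replay_fraction / (1 - replay_fraction))
    mixed = finetune_data + rng.choices(replay_data, k=n_replay)
    rng.shuffle(mixed)
    return mixed

# 80 fine-tuning examples + 20 replay examples -> 20% replay overall.
mixed = mix_with_replay(["ft"] * 80, ["wiki", "code"], replay_fraction=0.2)
print(len(mixed))  # 100
```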
### 6. Apollo's implicit regularization

The Apollo paper (Section 5.5) provides evidence that Apollo has lower
directional sharpness than AdamW, and that its "SGD-like" update
behavior acts as natural regularization. Table 10 shows Apollo and
Apollo-Mini achieve directional sharpness comparable to or better than
SGD's.

This suggests Apollo may be inherently more resistant to forgetting than
AdamW, though the paper doesn't test this directly.

## Monitoring for forgetting

### Perplexity on held-out data

The simplest and most reliable metric. Compute perplexity on a diverse
held-out set (e.g., WikiText, or a mix of code and natural language)
before and after training. If perplexity increases significantly,
forgetting is occurring.

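Perplexity is the exponential of the mean per-token negative
log-likelihood, so the before/after comparison needs only the per-token
NLLs from an eval pass. The NLL values and the 5% alarm threshold below
are illustrative, not tuned:

```python
import math

def perplexity(token_nlls: list) -> float:
    """exp(mean negative log-likelihood) over a held-out token stream."""
    return math.exp(sum(token_nlls) / len(token_nlls))

before = perplexity([2.0, 2.2, 1.8])   # hypothetical NLLs before training
after  = perplexity([2.4, 2.6, 2.2])   # hypothetical NLLs after training
if after > before * 1.05:              # arbitrary 5% alarm threshold
    print(f"possible forgetting: perplexity rose from {before:.2f} to {after:.2f}")
```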
### Task-specific benchmarks

Run a small suite of tasks (code generation, reasoning, general
knowledge) before and after each training session. Track scores over
time.

### Output quality spot-checks

For our use case, the most relevant checks: does the model still write
good code? Still reason about filesystem internals? Still maintain a
conversation coherently? These are qualitative but immediately
noticeable.

## Practical recommendations for our system

1. **Start with lr=1e-5, a single epoch, and a diverse training set**
2. **Monitor perplexity on held-out text after each training session**
3. **Include 20-30% general capability examples alongside behavioral ones**
4. **Use context-frozen training to limit gradient magnitude**
5. **Lean on the dream loop, which generates diversity naturally: different scenarios exercise different model capabilities**
6. **If forgetting is detected: reduce lr, increase data diversity, or reduce training frequency**
7. **Keep the pre-trained checkpoint on moria as a rollback safety net**