# Catastrophic Forgetting in LLM Fine-Tuning

## What it is

When a pre-trained LLM is fine-tuned on new data, its performance on tasks
it could previously do can degrade. The effect is called "catastrophic"
because the degradation can be sudden and severe: a few thousand training
steps on a narrow dataset can destroy broad capabilities built over
trillions of pre-training tokens.

## Why it happens

The gradient from the fine-tuning data pushes the weights away from the
pre-trained solution. If the fine-tuning signal is narrow (e.g., "always
respond in formal English"), it overwrites the broad patterns that enabled
diverse capabilities. The pre-trained weights encode a complex, multi-task
solution; fine-tuning replaces it with a simpler, narrower one.

In loss-landscape terms: the pre-trained weights sit in a basin that is
good for many tasks. Fine-tuning moves the weights toward a different basin
that is optimal for the fine-tuning task but worse for the others. The
further the weights move, the more the model forgets.
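
## Key factors that control forgetting

### 1. Learning rate × number of steps = total displacement

For plain SGD with a constant learning rate, the triangle inequality bounds
the total distance the weights move from their pre-trained values:

```
‖W_final - W_pretrained‖ ≤ lr × Σ_t ‖g_t‖ = lr × steps × avg_grad_norm
```

The bound is reached only when every step's gradient points the same way;
gradients that point in varying directions partly cancel. With Adam-style
optimizers each step is normalized per coordinate, so the step size is set
mainly by lr rather than by the raw gradient norm, but displacement still
grows roughly with lr × steps. More displacement generally means more
forgetting. Both learning rate and step count matter; their product sets
the scale of the total weight change.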
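
A toy check of this bound, assuming plain SGD with a constant learning rate and random stand-in gradients (the function name `sgd_displacement` and all sizes are illustrative, not taken from our training code). It compares the actual displacement with `lr × Σ‖g_t‖` for gradients pointing in random directions and for one gradient repeated at every step.

```python
import numpy as np


def sgd_displacement(grads, lr):
    """Run plain SGD (w <- w - lr * g) from w_0 = 0 and return
    (actual ||w_final - w_0||, bound lr * sum_t ||g_t||)."""
    w0 = np.zeros(grads.shape[1])
    w = w0.copy()
    for g in grads:
        w -= lr * g
    actual = np.linalg.norm(w - w0)
    bound = lr * np.sum(np.linalg.norm(grads, axis=1))
    return actual, bound


rng = np.random.default_rng(0)
steps, dim, lr = 500, 1_000, 1e-5

# Stand-in gradients pointing in unrelated random directions
random_grads = rng.standard_normal((steps, dim))
# The same gradient at every step (a perfectly "narrow" signal)
aligned_grads = np.tile(rng.standard_normal(dim), (steps, 1))

for name, grads in [("random", random_grads), ("aligned", aligned_grads)]:
    actual, bound = sgd_displacement(grads, lr)
    print(f"{name:8s} actual={actual:.3e} bound={bound:.3e} ratio={actual / bound:.3f}")
```

In the aligned case the ratio is 1; in the random case it comes out near 1/√steps (about 0.045 here), which is why the bound overstates displacement when gradients vary from step to step.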

**Typical safe ranges (from literature)**:

- Full fine-tuning: lr = 1e-5 to 5e-5, 1-3 epochs on ~1000-10000 examples
- LoRA: lr = 1e-4 to 3e-4 (higher because the adapter update starts from zero)
- Apollo: lr = 1e-5 to 1e-4 (comparable to full fine-tuning; paper Section 5.2)

### 2. Dataset diversity

Narrow datasets cause more forgetting than diverse ones. Training on 1000
examples of the same pattern pushes the weights along the same few
directions again and again, so the updates accumulate. Training on 1000
diverse examples produces gradients that point in different directions;
they partly cancel and spread the change across many directions instead of
concentrating it along one.
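
One way to put a number on diversity is the mean pairwise cosine similarity of per-example gradients (or, as a cheaper proxy, per-example embeddings). This is a NumPy sketch on synthetic data; the function name and the "narrow"/"diverse" vectors are hypothetical.

```python
import numpy as np


def mean_pairwise_cosine(vectors):
    """Average cosine similarity over all distinct pairs of row vectors.
    Near 1: the vectors point the same way (a narrow dataset).
    Near 0: they point in unrelated directions (a diverse dataset)."""
    v = np.asarray(vectors, dtype=float)
    unit = v / np.linalg.norm(v, axis=1, keepdims=True)
    sims = unit @ unit.T                       # (n, n) cosine matrix
    n = len(v)
    off_diag_sum = sims.sum() - np.trace(sims)  # drop self-similarities
    return off_diag_sum / (n * (n - 1))


rng = np.random.default_rng(0)
base = rng.standard_normal(512)
# "Narrow": every example's gradient is one direction plus small noise
narrow = base + 0.1 * rng.standard_normal((100, 512))
# "Diverse": unrelated directions
diverse = rng.standard_normal((100, 512))

print("narrow :", round(mean_pairwise_cosine(narrow), 3))
print("diverse:", round(mean_pairwise_cosine(diverse), 3))
```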

**This is our key defense.** Our training data includes:

- Agent logs (graph walking, linking, reasoning)
- Conversation transcripts (technical discussion, emotional engagement)
- Dream-generated scenarios (diverse behavioral situations)
- Personality patterns (voice, boundaries, mode awareness)

This diversity means no single direction in weight space receives
disproportionate updates.
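
### 3. Gradient sparsity (more precisely, low rank)

For a short training example (50-256 decision tokens), the gradient is
concentrated, but not sparse in the element-wise sense. In a dense
transformer nearly every weight matrix participates in every token, so
nearly every weight receives some gradient; exact zeros appear only in
special cases, such as untied input-embedding rows for tokens absent from
the example, or experts that a mixture-of-experts router did not select.
What is limited is the rank: for a linear layer, the gradient is a sum of
one outer product per token position the gradient flows through, so its
rank is at most the number of such positions.

This natural concentration is another defense: each example moves the
weights along only a few directions, leaving most of the pre-trained
solution untouched.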
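
A toy NumPy check of the rank argument, assuming a single linear layer with a squared-error loss in place of the real model and loss (the sizes are arbitrary). With 64 token positions, the weight gradient is dense but has rank 64.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_tokens = 512, 512, 64

W = rng.standard_normal((d_out, d_in)) / np.sqrt(d_in)
X = rng.standard_normal((n_tokens, d_in))     # layer inputs, one row per token
Y = X @ W.T                                    # layer outputs, (n_tokens, d_out)
target = rng.standard_normal((n_tokens, d_out))

# Loss L = 0.5 * sum_t ||y_t - target_t||^2, so dL/dy_t = y_t - target_t
delta = Y - target                             # (n_tokens, d_out)

# dL/dW = sum_t delta_t x_t^T: one outer product per token position
grad_W = delta.T @ X                           # (d_out, d_in)

print("nonzero fraction:", np.mean(grad_W != 0))
print("rank:", np.linalg.matrix_rank(grad_W), "of max", min(d_out, d_in))
```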

### 4. Rank of the gradient

Biderman et al. (2024, "LoRA Learns Less and Forgets Less") found that
full fine-tuning produces weight perturbations with effective rank
10-100× higher than typical LoRA configurations. Higher-rank perturbations
modify more independent directions in weight space, which increases
both learning AND forgetting.

LoRA's low-rank constraint acts as implicit regularization against
forgetting: it can only modify a low-dimensional subspace of the
weights, leaving most of the pre-trained solution intact.
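
A NumPy sketch of the standard LoRA parameterization `W0 + (α/r)·B·A`, with `A` randomly initialized and `B` initialized to zero as in the LoRA paper. The sizes and α are illustrative. It shows both properties the paragraph relies on: the pre-trained weight is untouched at initialization, and whatever the adapters learn, the update has rank at most r.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 256, 256, 8, 16

W0 = rng.standard_normal((d_out, d_in))   # frozen pre-trained weight
A = rng.standard_normal((r, d_in)) * 0.01 # trainable, random init
B = np.zeros((d_out, r))                  # trainable, zero init


def effective_weight(W0, A, B, alpha, r):
    # LoRA's forward pass uses W0 + (alpha / r) * B @ A; W0 is never updated
    return W0 + (alpha / r) * B @ A


# At initialization the adapter contributes nothing
assert np.allclose(effective_weight(W0, A, B, alpha, r), W0)

# Pretend training has moved A and B to arbitrary values
A = rng.standard_normal((r, d_in))
B = rng.standard_normal((d_out, r))
delta = effective_weight(W0, A, B, alpha, r) - W0
print("rank of update:", np.linalg.matrix_rank(delta), "<= r =", r)
```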

**Apollo's rank-256 sits between LoRA and full fine-tuning.** The
projected optimizer constrains the scaling, not the gradient: it estimates
its per-channel learning-rate scaling from a 256-dimensional projection of
the gradient, then applies that scaling to the full gradient, which is still
full-rank. Apollo therefore modifies all weights (like full fine-tuning) but
with a more structured update pattern (like LoRA). Whether this provides
forgetting protection similar to LoRA's is an open question.
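
A simplified sketch of the idea as we read it from the Apollo paper, assuming a fixed Gaussian projection `P`, Adam moments kept only in the projected space, and column-wise scaling. It omits other details of the paper's algorithm, and every name here is hypothetical. What it illustrates: the scaling factors come from a low-rank space, but the update is the full gradient rescaled column by column, so it stays full-rank.

```python
import numpy as np


def init_state(rank, n_cols):
    return {"t": 0, "m": np.zeros((rank, n_cols)), "v": np.zeros((rank, n_cols))}


def apollo_like_update(G, P, state, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical simplified Apollo-style step for one weight matrix.
    G: full gradient (m, n). P: fixed random projection (rank, m)."""
    R = P @ G                                       # projected gradient, (rank, n)
    state["t"] += 1
    t = state["t"]
    state["m"] = beta1 * state["m"] + (1 - beta1) * R
    state["v"] = beta2 * state["v"] + (1 - beta2) * R**2
    m_hat = state["m"] / (1 - beta1**t)
    v_hat = state["v"] / (1 - beta2**t)
    R_adam = m_hat / (np.sqrt(v_hat) + eps)         # Adam step in the low-rank space
    # Per-column scale: how much Adam rescaled each column, measured after projection
    scale = np.linalg.norm(R_adam, axis=0) / (np.linalg.norm(R, axis=0) + eps)
    return G * scale                                # full gradient, rescaled column-wise


rng = np.random.default_rng(0)
m, n, rank = 64, 32, 4
P = rng.standard_normal((rank, m)) / np.sqrt(rank)
state = init_state(rank, n)
G = rng.standard_normal((m, n))
U = apollo_like_update(G, P, state)
print("update rank:", np.linalg.matrix_rank(U), "of", min(m, n))
```

Only the `(rank, n)` moment buffers are stored per matrix, which is where the memory saving comes from; the update itself is never confined to a rank-`rank` subspace.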

## Defenses against forgetting

### 1. Diversity of training data (our primary defense)

The most effective and simplest defense. If the training data covers
the same breadth as the pre-training data (proportionally), the model
maintains its broad capabilities while learning new patterns.

For us: mix behavioral examples with general capability examples.
Include agent logs alongside conversation corrections.

### 2. Low learning rate + few steps

Keep the total weight displacement small. The working assumption is that
the pre-trained basin is wide enough that small perturbations stay within it.

For us: lr=1e-5 as a starting point. One epoch over diverse data.
Monitor perplexity on held-out general text.

### 3. Context-frozen training (our approach)

By computing gradients only on decision tokens (50-256 tokens) rather
than on full conversations (thousands of tokens), we limit how much each
example can change the weights. The context still shapes the activations
in the forward pass, but, provided it is processed without gradient
tracking, it contributes nothing to the backward pass, which determines
how the weights change.

This bounds the rank of each layer's per-example gradient by the number of
decision tokens, and, if the loss is summed over tokens rather than
averaged, it also limits the total gradient norm per training step.
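
The loss side of this can be sketched as a mean cross-entropy restricted to decision positions. This is a hypothetical NumPy illustration, not our training code: the 250/50 context/decision split and the function names are assumptions. It covers only loss masking; keeping gradient out of the context's own forward computation additionally requires running the context without gradient tracking in the actual training framework, which this sketch does not model.

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def decision_token_loss(logits, targets, decision_mask):
    """Mean cross-entropy over decision tokens only.
    logits: (T, V); targets: (T,) int token ids; decision_mask: (T,) bool."""
    logp = log_softmax(logits)
    token_nll = -logp[np.arange(len(targets)), targets]    # (T,)
    return token_nll[decision_mask].mean()


rng = np.random.default_rng(0)
T, V = 300, 50
n_context = 250                     # hypothetical: first 250 positions are context
logits = rng.standard_normal((T, V))
targets = rng.integers(0, V, size=T)
mask = np.zeros(T, dtype=bool)
mask[n_context:] = True             # last 50 positions are decision tokens

loss = decision_token_loss(logits, targets, mask)

# Perturbing the logits at context positions leaves the loss unchanged
logits2 = logits.copy()
logits2[:n_context] += rng.standard_normal((n_context, V))
assert np.isclose(loss, decision_token_loss(logits2, targets, mask))
print("decision-token loss:", loss)
```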

### 4. Elastic Weight Consolidation (EWC)

Add a regularization term that penalizes changes to "important" weights:

```
L_total = L_task + λ Σ_i F_i (θ_i - θ*_i)²
```

where F_i is the (diagonal) Fisher information for parameter i and θ*_i is
its pre-trained value. Important weights (high Fisher information) are
penalized more for changing.

**Drawback**: requires computing and storing the diagonal of the Fisher
information matrix (one value per parameter, i.e., the same size as the
model) plus a copy of the pre-trained weights. Not practical for 27B
parameters.
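
A minimal NumPy sketch of the penalty and of a diagonal Fisher estimate. It assumes the empirical Fisher (mean squared per-example gradient on the old data), a common approximation; the toy parameters and gradients are made up.

```python
import numpy as np


def diagonal_fisher(per_example_grads):
    """Empirical diagonal Fisher: mean of squared per-example gradients.
    per_example_grads: (n_examples, n_params)."""
    return np.mean(per_example_grads**2, axis=0)


def ewc_penalty(theta, theta_star, fisher_diag, lam):
    """lam * sum_i F_i * (theta_i - theta*_i)^2."""
    return lam * np.sum(fisher_diag * (theta - theta_star) ** 2)


def ewc_penalty_grad(theta, theta_star, fisher_diag, lam):
    # Gradient of the penalty; added to the task gradient during fine-tuning
    return 2 * lam * fisher_diag * (theta - theta_star)


theta_star = np.array([1.0, -2.0, 0.5])         # pre-trained values
grads = np.array([[2.0, 0.0, 1.0],
                  [0.0, 0.0, 1.0]])             # toy per-example grads on old data
F = diagonal_fisher(grads)                      # [2.0, 0.0, 1.0]
theta = theta_star + np.array([0.1, 5.0, 0.1])  # big move on an "unimportant" param
print(ewc_penalty(theta, theta_star, F, lam=10.0))
```

The second parameter moves by 5.0 yet adds nothing to the penalty, because its Fisher value is zero. The storage cost follows directly: at 27B parameters, each fp32 model-sized tensor is 27e9 × 4 bytes ≈ 108 GB, so F and θ* together add roughly 216 GB.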

### 5. Replay / rehearsal

Mix in examples from the pre-training distribution alongside the
fine-tuning data. This maintains the original capabilities by continuing to
train on them.

**Drawback**: requires access to the pre-training data or a representative
subset. For open-weight models like Qwen3.5, the pre-training data isn't
publicly available. A proxy could be used instead (e.g., Wikipedia, code).
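
A sketch of mixing a proxy corpus into the fine-tuning set at a fixed fraction, using only the Python standard library. The function name, the string placeholders for examples, and the 0.25 default (the middle of the 20-30% range in the recommendations below) are assumptions.

```python
import random


def mix_with_replay(task_examples, replay_pool, replay_fraction=0.25, seed=0):
    """Return a shuffled training set in which about `replay_fraction`
    of the examples come from the replay/proxy pool."""
    if not 0 <= replay_fraction < 1:
        raise ValueError("replay_fraction must be in [0, 1)")
    rng = random.Random(seed)
    # Solve n_replay / (n_task + n_replay) = replay_fraction for n_replay
    n_replay = round(len(task_examples) * replay_fraction / (1 - replay_fraction))
    n_replay = min(n_replay, len(replay_pool))
    mixed = list(task_examples) + rng.sample(list(replay_pool), n_replay)
    rng.shuffle(mixed)
    return mixed


task = [f"behavior-{i}" for i in range(750)]
proxy = [f"wiki-{i}" for i in range(10_000)]   # hypothetical proxy corpus
train = mix_with_replay(task, proxy, replay_fraction=0.25)
print(len(train), sum(x.startswith("wiki-") for x in train))
```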

### 6. Apollo's implicit regularization

The Apollo paper (Section 5.5) provides evidence that Apollo has lower
directional sharpness than AdamW and that its "SGD-like" update behavior
provides natural regularization. Table 10 shows Apollo and Apollo-Mini
achieving directional sharpness comparable to, or better than, SGD's.

This suggests Apollo may be inherently more resistant to forgetting
than AdamW, though the paper doesn't test this directly.
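
To make "directional sharpness" concrete: assuming the usual definition, it is the curvature vᵀHv of the loss along the unit-normalized update direction v, where H is the Hessian. The sketch below estimates it with a central finite difference of gradients (a Hessian-vector product) on a toy ill-conditioned quadratic. It is not the paper's measurement code, and the names are hypothetical.

```python
import numpy as np


def directional_sharpness(grad_fn, x, direction, eps=1e-4):
    """Curvature v^T H v along the unit-normalized direction v,
    estimated by a central finite difference of gradients."""
    v = direction / np.linalg.norm(direction)
    hvp = (grad_fn(x + eps * v) - grad_fn(x - eps * v)) / (2 * eps)  # approx. H v
    return float(v @ hvp)


# Ill-conditioned quadratic f(x) = 0.5 * x^T H x, so grad f(x) = H x
H = np.diag([100.0, 1.0])


def grad_fn(x):
    return H @ x


x = np.array([1.0, 1.0])
for name, d in [("steep axis", [1.0, 0.0]), ("flat axis", [0.0, 1.0]),
                ("diagonal", [1.0, 1.0])]:
    print(name, directional_sharpness(grad_fn, x, np.array(d)))
```

A lower value along the direction an optimizer actually steps means each step lands in flatter territory, which is the sense in which lower directional sharpness is read as regularization.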

## Monitoring for forgetting

### Perplexity on held-out data

The simplest and most reliable metric. Compute perplexity on a diverse
held-out set (e.g., WikiText, a mix of code and natural language) before
and after training. If perplexity increases significantly, forgetting
is occurring.
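
Perplexity is the exponential of the mean per-token negative log-likelihood. A minimal sketch, assuming natural-log per-token probabilities from scoring the held-out set; the function names and the 5% alarm threshold are hypothetical placeholders that would need calibrating against run-to-run noise.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood (natural log) over tokens."""
    if not token_logprobs:
        raise ValueError("need at least one token")
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)


def forgetting_alarm(ppl_before, ppl_after, max_relative_increase=0.05):
    """True if held-out perplexity rose by more than the given fraction.
    The 5% default is a hypothetical placeholder, not a validated threshold."""
    return (ppl_after - ppl_before) / ppl_before > max_relative_increase


# Toy log-probs; in practice these come from scoring the held-out set
before = perplexity([math.log(0.25)] * 8)   # uniform over 4 choices
after = perplexity([math.log(0.20)] * 8)    # uniform over 5 choices
print(before, after, forgetting_alarm(before, after))
```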

### Task-specific benchmarks

Run a small suite of tasks (code generation, reasoning, general knowledge)
before and after each training session. Track scores over time.
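
Tracking scores over time reduces to comparing each session's results with the previous ones. A hypothetical sketch; the task names, scores, and 0.02 tolerance are made up for illustration.

```python
def regressions(before, after, tolerance=0.02):
    """Tasks whose score dropped by more than `tolerance` (absolute).
    before/after: dict of task_name -> score, higher is better."""
    return sorted(
        task for task in before
        if task in after and before[task] - after[task] > tolerance
    )


before = {"code_gen": 0.62, "reasoning": 0.55, "knowledge": 0.71}  # hypothetical
after = {"code_gen": 0.58, "reasoning": 0.56, "knowledge": 0.70}
print(regressions(before, after))
```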

### Output quality spot-checks

For our use case, the most relevant check: does the model still write
good code? Still reason about filesystem internals? Still maintain
conversation coherently? These are qualitative but immediately noticeable.

## Practical recommendations for our system

1. **Start with lr=1e-5, single epoch, diverse training set**
2. **Monitor perplexity on held-out text after each training session**
3. **Include 20-30% general capability examples alongside behavioral ones**
4. **Use context-frozen training to limit gradient magnitude**
5. **The dream loop generates diversity naturally: different scenarios
   exercise different model capabilities**
6. **If forgetting is detected: reduce lr, increase data diversity,
   or reduce training frequency**
7. **Keep the pre-trained checkpoint on moria as a rollback safety net**