research: gradient flow through frozen context + directional sharpness analysis

Two deep dives following curiosity:

- Why context-frozen training works: the gradient flows through W_q (the query projection) even when the context KVs are frozen. The model learns to LOOK AT the context differently, not to represent it differently. This is exactly what behavioral fine-tuning needs.
- Why Apollo beats AdamW: lower directional sharpness = flatter minima = better generalization. The coarseness of channel/tensor-wise scaling prevents over-fitting to specific training examples. For behavioral fine-tuning, this means learning "accept direction" rather than "accept this specific phrasing."

parent 7c7975d98e
commit e34d6b5aef
2 changed files with 344 additions and 0 deletions

training/research/directional-sharpness.md (new file, 153 lines)
# Why Apollo Can Beat AdamW: Directional Sharpness

Source: Apollo paper Section 5.5; Pan & Li 2023; Zhang et al. 2024a

## The Puzzle

Apollo uses LESS information than AdamW (channel/tensor-wise scaling vs. per-element scaling). How can less information produce better results?

The paper proposes two hypotheses. Both are fascinating.

## Hypothesis 1: Directional Sharpness (Pan & Li, 2023)

### What is directional sharpness?

The directional sharpness of f at point x along direction v is:

```
v^T ∇²f(x) v    (where ‖v‖₂ = 1)
```

This is the curvature of the loss surface in the direction of the update step. High sharpness means the surface curves steeply — the optimizer is walking along a ridge. Low sharpness means the surface is flat — the optimizer is walking on a plateau.
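The quantity above can be estimated without materializing the Hessian, via a Hessian-vector product. A minimal sketch (the quadratic loss and function names here are illustrative stand-ins, not from the paper):

```python
# Estimate directional sharpness v^T H v with a Hessian-vector product:
# one gradient, then a second backward through the directional derivative.
import torch

def directional_sharpness(loss_fn, x, v):
    v = v / v.norm()                       # ‖v‖₂ = 1, as in the definition
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(loss_fn(x), x, create_graph=True)
    (hv,) = torch.autograd.grad(g @ v, x)  # H @ v via double backward
    return (v @ hv).item()                 # v^T H v

# Sanity check: for f(x) = x^T diag(d) x the Hessian is 2*diag(d),
# so sharpness along basis vector e_i is exactly 2*d_i.
d = torch.tensor([1.0, 10.0])
f = lambda x: (d * x * x).sum()
e0 = torch.tensor([1.0, 0.0])
e1 = torch.tensor([0.0, 1.0])
print(directional_sharpness(f, torch.tensor([0.3, 0.7]), e0))  # ≈ 2.0
print(directional_sharpness(f, torch.tensor([0.3, 0.7]), e1))  # ≈ 20.0
```

The same probe, applied along the optimizer's actual update direction, is how sharpness comparisons like Table 10 are made.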
### Why low sharpness is good

**Low directional sharpness = flat loss landscape in the update direction.**

A flat landscape means:

1. Large steps don't cause instability (the loss doesn't change sharply)
2. The solution generalizes better (flat minima → robust to perturbation)
3. The optimizer can move faster without overshooting

Pan & Li (2023) showed that Adam achieves lower directional sharpness than SGD, which partly explains why Adam works better for Transformers.

### The Apollo twist

Apollo's Table 10 shows directional sharpness over training:

```
Epoch   SGD        Adam       APOLLO     APOLLO-Mini
2       1.959722   0.009242   0.006024   0.004017
5       1.512521   0.000509   0.000249   0.000107
10      2.471792   0.000242   0.000163   0.000056
20      3.207535   0.000399   0.000261   0.000101
```

**Apollo and Apollo-Mini achieve LOWER directional sharpness than Adam.** At epoch 20, Apollo-Mini's sharpness is 4× lower than Adam's.

This means Apollo finds FLATTER regions of the loss landscape. Flatter regions generalize better. The coarser scaling factor is actually an advantage — it prevents the optimizer from navigating into the sharp, narrow valleys that AdamW's precise per-element scaling can find.

### The mechanism

AdamW's per-element scaling adapts to the local curvature of each parameter independently. This is powerful for convergence but can lead the optimizer into narrow, sharp valleys that generalize poorly. It over-fits to the local structure of the loss landscape.

Apollo's coarser scaling (channel/tensor-wise) smooths over this local curvature. It's like using a wider tire on a rocky road — you can't follow every small dip, but you stay on the road. AdamW's narrow tire follows every crack and sometimes falls in.

### For our use case

**This is exactly what we want for behavioral fine-tuning.** We don't want the optimizer to over-fit to the specific phrasing of our training examples. We want it to learn the broad pattern ("listen to direction") that generalizes to new situations.

Apollo's flat-minimum-seeking behavior means the behavioral changes are more likely to generalize to novel conversations. AdamW might learn "when Kent says 'use vLLM', accept it" (a narrow, sharp minimum). Apollo is more likely to learn "when given clear direction, accept it" (a broad, flat minimum).

## Hypothesis 2: Block-wise Adaptive Learning Rates

### Transformer block structure

Transformer layers have systematically different Hessian spectra. Attention layers, MLP layers, normalization layers — each has different curvature properties. The optimal learning rate for an attention weight is different from the optimal learning rate for an MLP weight.

### Why channel-wise is enough

Zhang et al. (2024a) showed that block-wise adaptive learning rates are sufficient for Transformer training. You don't need per-element adaptation — you just need different rates for different structural components.

Apollo's channel-wise scaling naturally provides this: each channel (which often corresponds to a head, a neuron, or a structural feature) gets its own scaling factor. This aligns with the Transformer's block structure without the overhead of full per-element scaling.

### The redundancy argument

For a [4096, 4096] weight matrix in AdamW:

- 16M independent scaling factors (one per element)
- Most adjacent elements have similar scaling factors (correlated because they participate in similar computations)
- The per-element granularity is mostly redundant noise on top of a smooth per-channel structure

Apollo extracts the per-channel structure and throws away the noise. The noise was never helping; it was just costing memory.
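The memory claim is easy to see concretely. A toy sketch (illustrative only, not Apollo's actual implementation): collapse an AdamW-style per-element scaling matrix to one factor per output channel by averaging the second moment along each row.

```python
# Per-element vs. per-channel scaling factors for a toy [8, 16] weight.
# AdamW keeps one factor per element; the channel-wise view keeps one per row.
import numpy as np

rng = np.random.default_rng(0)
grad = rng.normal(size=(8, 16))            # stand-in gradient for the weight
second_moment = grad ** 2                  # per-element v_t (AdamW-style)

per_element = 1.0 / (np.sqrt(second_moment) + 1e-8)
per_channel = 1.0 / (np.sqrt(second_moment.mean(axis=1, keepdims=True)) + 1e-8)

print(per_element.size, per_channel.size)  # 128 vs. 8 scaling factors
```

For the real [4096, 4096] matrix the same reduction goes from 16M factors to 4096.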
## The Deeper Implication: SGD + Structure = Adam without the Waste

Apollo is effectively: **SGD with structured learning rate scheduling.**

- SGD: one learning rate for everything (too coarse)
- AdamW: one learning rate per parameter (too fine, wasteful)
- Apollo: one learning rate per channel (just right)

The insight is that the useful information in AdamW's per-element scaling lives in the channel structure, not in the element-level detail. Apollo extracts just the useful part.

This is a Goldilocks argument: too coarse loses important structure, too fine adds noise that hurts generalization. The channel level is where the meaningful optimization information lives in Transformers.
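The "SGD with structured learning rates" view can be sketched as an update rule. This is a hypothetical illustration of the idea, not the Apollo paper's actual algorithm: plain SGD, but each output channel gets its own scale.

```python
# Hypothetical channel-wise scaled SGD step (illustration of the idea only):
# one scaling factor per output channel (row), shared across the row.
import torch

def channelwise_sgd_step(weight, grad, lr=1e-2, eps=1e-8):
    channel_norm = grad.pow(2).mean(dim=1, keepdim=True).sqrt()  # [out, 1]
    weight -= lr * grad / (channel_norm + eps)

w = torch.zeros(4, 8)
g = torch.ones(4, 8)
channelwise_sgd_step(w, g)   # every element moves by ≈ -lr, uniformly scaled
```

Compare with AdamW, where `channel_norm` would be a full `[out, in]` matrix of per-element statistics.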
## For behavioral fine-tuning specifically

The directional sharpness result has a specific implication for us:

When we train on "listen instead of suggesting alternatives," we want the gradient update to find a minimum that covers ALL situations where listening is better, not just the specific example we trained on.

- **Sharp minimum** (AdamW tendency): "When you see the exact phrase 'use vLLM's code' from Kent, accept it." Narrow, doesn't generalize.
- **Flat minimum** (Apollo tendency): "When given clear technical direction, accept it." Broad, generalizes to new situations.

Apollo's lower directional sharpness means it naturally finds the flat minimum. The coarseness of the scaling factor is what enables this — it can't over-fit to the specific example because the scaling doesn't have enough resolution to find the sharp, narrow valley.

This is why we might see behavioral changes generalize better with Apollo than they would with AdamW, even though AdamW has "more information" per update step.

training/research/gradient-flow-frozen-context.md (new file, 191 lines)
# Gradient Flow Through Frozen Context

## The Central Question

When we do context-frozen training:

```python
# Encode the context once, with no graph: its KV cache is frozen.
with torch.no_grad():
    ctx_output = model(context_tokens, use_cache=True)
    past_kv = ctx_output.past_key_values

# Train only on the decision tokens, attending to the frozen cache.
with torch.enable_grad():
    output = model(decision_tokens, past_key_values=past_kv)
    loss = cross_entropy(output.logits, target)
    loss.backward()
```

What does the gradient "see"? Does it know about the context?
## The Answer: Yes, But Indirectly

### Full attention layers (16 of 64)

In a full attention layer, the decision tokens compute:

```
Q = decision_hidden @ W_q      # queries from decision tokens
K = [context_K; decision_K]    # keys from frozen context + decision tokens
V = [context_V; decision_V]    # values from frozen context + decision tokens

Attention = softmax(Q @ K^T / √d)
Output = Attention @ V
```

The frozen `context_K` and `context_V` are tensors computed during the no_grad phase. They have no gradient attached — they're treated as constants during backward.

But the gradient DOES flow through:

- **W_q**: because Q is computed from decision_hidden @ W_q, and the attention output depends on Q
- **W_k, W_v for the decision tokens**: same reason
- **W_o (output projection)**: always receives gradient

The gradient for W_q depends on how the query interacted with ALL keys (including the frozen context keys). So the gradient encodes: "given this context (frozen), adjust W_q so that the queries attend to the right parts of the context to produce better output."

**The model learns context-dependent behavior through W_q.** The query projection learns to "look for" the right things in the context. The context itself doesn't change, but how the model looks at it does.
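This is easy to verify on a toy scale. A sketch (dimensions and names like `W_q`, `context_k` are illustrative, not the model's real modules): one attention step over a frozen cache, then check where the gradient landed.

```python
# One attention step over frozen context keys/values: the gradient reaches
# the query projection but the frozen KVs are plain constants.
import torch

d = 4
W_q = torch.randn(d, d, requires_grad=True)

context_k = torch.randn(3, d)   # frozen context keys: no grad attached
context_v = torch.randn(3, d)   # frozen context values
decision_h = torch.randn(2, d)  # hidden states of the decision tokens

q = decision_h @ W_q
attn = torch.softmax(q @ context_k.T / d**0.5, dim=-1)
out = attn @ context_v

out.sum().backward()

print(W_q.grad is not None)     # True: gradient flows through the queries
print(context_k.requires_grad)  # False: the frozen keys are constants
```

Note that `W_q.grad` still depends on the *values* of `context_k`: the frozen context shapes the gradient without receiving one.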
### GDN layers (48 of 64)

In GDN layers, the recurrent state after processing the context tokens is:

```
S_context = recurrence(context_tokens)   # fixed-size [HV, V, K] matrix
```

This state is frozen (computed in no_grad). During the decision tokens:

```
for token in decision_tokens:
    S = decay(S) + update(k, v, beta)   # state evolves
    output = S @ q                      # output depends on state
```

The gradient flows through the decision-token updates to S, but NOT back through S_context. The model learns:

- How to update the state given the current (frozen) state
- How to compute output from the current state
- How to compute the gates and beta for the update

It does NOT learn to change how the context was originally encoded into the state. But it learns how to USE that encoding.
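The same pattern in miniature, using a generic linear recurrence rather than the actual GDN kernel (the decay constant and `W_kv` are illustrative): the frozen state is detached, so backward stops there, but the update weights still learn.

```python
# Frozen recurrent state + trainable updates: backward reaches the update
# weights but not the (detached) context state.
import torch

d = 4
S = torch.randn(d, d).detach()       # frozen context state: a constant
W_kv = torch.randn(d, d, requires_grad=True)

for h in torch.randn(3, d):          # decision tokens
    kv = torch.outer(h @ W_kv, h)    # stand-in for update(k, v, beta)
    S = 0.9 * S + kv                 # decayed state + update

S.sum().backward()
print(W_kv.grad is not None)         # True: gradient reaches the update weights
```

The initial `S` never receives a gradient, exactly mirroring the frozen `S_context` above.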
### What this means for behavioral fine-tuning

The model learns **response patterns conditioned on context**, not **context encoding patterns**. This is actually what we want:

- "When you see this kind of context (Kent giving direction), respond this way (accept the direction)" — this is a response pattern
- The model doesn't need to change how it encodes Kent's words; it needs to change how it responds to them

The gradient adjusts the weights that transform context representations into output, not the weights that create context representations.

## The Deeper Question: Is This Enough?

### For behavioral patterns: probably yes

Behavioral patterns like "listen instead of suggesting alternatives" are about the response to context, not about understanding the context differently. The model already understands what Kent is saying (the context encoding is fine). The problem is in the decision layer — the weights that choose between "accept" and "suggest alternatives."

### For deep reasoning: maybe not

If we want the model to understand something fundamentally differently (e.g., learn a new mathematical concept), we might need the gradient to reach the context encoding weights. Context-frozen training can't do this.

For deep reasoning improvements, we might need:

1. Full forward+backward (expensive but complete)
2. Training on many examples that exercise the context encoding from different angles (the diversity approach)
3. Gradient checkpointing to fit the full backward in memory

### The gradient checkpointing alternative

Instead of freezing the context entirely, use gradient checkpointing:

- The forward pass saves checkpoints every N layers
- The backward pass recomputes activations from the checkpoints as needed
- The gradient flows through the ENTIRE forward pass, including the context
- Memory cost: O(layers/N × seq_len × hidden_size) for the stored checkpoints instead of O(layers × seq_len × hidden_size) for all activations

This is more expensive (recomputation) but gives full gradient flow. It could be used for Tier 3 (deep learning) training where context-frozen training isn't sufficient.
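The checkpointing alternative, sketched with `torch.utils.checkpoint` (the layer stack here is a toy stand-in for the real model): activations inside each checkpointed layer are recomputed during backward instead of being stored, and the gradient still reaches every layer.

```python
# Checkpointed forward pass: full gradient flow, reduced activation memory.
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])
x = torch.randn(2, 8, requires_grad=True)

h = x
for layer in layers:
    # use_reentrant=False is the recommended non-reentrant mode
    h = checkpoint(layer, h, use_reentrant=False)

h.sum().backward()
print(layers[0].weight.grad is not None)   # True: gradient reaches layer 0
```

In the real pipeline the checkpointed region would cover the context tokens too, which is exactly what context-frozen training gives up.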
## The Hybrid Approach

For our training pipeline:

- **Tier 1 (simple corrections)**: Full forward+backward on short examples. No context freezing needed because the examples are short.
- **Tier 2 (behavioral patterns)**: Context-frozen training. The gradient through W_q and the response weights is sufficient for behavioral change. The context tells the model WHEN to behave differently; the decision tokens tell it HOW.
- **Tier 3 (deep reasoning)**: Gradient checkpointing for full gradient flow. Expensive but necessary for fundamental capability changes.

## Mathematical Detail: Gradient Through Attention

For a single attention head, the output for decision token i is:

```
o_i = Σ_j α_{ij} V_j

where α_{ij} = softmax(q_i · k_j / √d)_j
```

The gradient of the loss L with respect to W_q is:

```
∂L/∂W_q = Σ_i (∂L/∂o_i) · (∂o_i/∂q_i) · (∂q_i/∂W_q)

∂o_i/∂q_i = Σ_j (∂α_{ij}/∂q_i) · V_j

∂α_{ij}/∂q_i = α_{ij} · (k_j/√d - Σ_l α_{il} · k_l/√d)
```

Note: the keys k_j include BOTH the frozen context keys and the decision-token keys. The gradient for W_q depends on the frozen context keys through the attention weights α_{ij}. So the gradient "knows" about the context through the attention pattern — it just can't change the context keys themselves.
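The α-gradient formula above can be checked numerically against autograd for one query and a small set of keys (toy dimensions, illustrative names):

```python
# Verify ∂α_j/∂q = α_j (k_j − Σ_l α_l k_l) / √d against autograd's Jacobian.
import torch

d = 4
q = torch.randn(d, requires_grad=True)
K = torch.randn(5, d)                          # 5 keys (context + decision)

alpha = torch.softmax(K @ q / d**0.5, dim=0)   # α_j for this query

# Closed form from the note, one row per key j.
manual = alpha.unsqueeze(1) * (K - (alpha.unsqueeze(1) * K).sum(0)) / d**0.5

# Autograd reference: Jacobian of α with respect to q.
auto = torch.autograd.functional.jacobian(
    lambda q_: torch.softmax(K @ q_ / d**0.5, dim=0), q
)

print(torch.allclose(manual, auto, atol=1e-5))   # True
```

Every row of the Jacobian mixes in all keys through the softmax normalization, which is precisely how the frozen context keys shape the W_q gradient.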
## Connection to the Anthropic Method

The Anthropic instruction-stripping method works through this exact mechanism:

1. With instructions (surfaced memories): the context includes behavioral guidance. The model produces good behavior partly because of these instructions.
2. Strip the instructions: remove the guidance from the context. The decision tokens (the good behavior) remain as training targets.
3. Train: the gradient adjusts W_q and the response weights so the model produces the good behavior even without the instruction context.

The gradient says: "given a context WITHOUT the instructions, adjust the query projections so you attend to the same patterns in the context that the instructions helped you notice."

The disposition moves from the instructions (in context) to the weights (in W_q and the downstream projections). The model learns to "see" what the instructions pointed at, without needing the instructions.

This is why it works even with frozen context: the change is in HOW the model looks at the context, not in what the context contains.