research: gradient flow through frozen context + directional sharpness analysis

Two deep dives following curiosity:
- Why context-frozen training works: gradient flows through W_q (query
  projection) even when context KVs are frozen. Model learns to LOOK AT
  context differently, not represent it differently. This is exactly what
  behavioral fine-tuning needs.
- Why Apollo beats AdamW: lower directional sharpness = flatter minima =
  better generalization. The coarseness of channel/tensor-wise scaling
  prevents over-fitting to specific training examples. For behavioral
  fine-tuning, this means learning 'accept direction' rather than
  'accept this specific phrasing.'
This commit is contained in:
ProofOfConcept 2026-03-31 01:03:22 -04:00
parent 7c7975d98e
commit e34d6b5aef
2 changed files with 344 additions and 0 deletions

View file

@ -0,0 +1,153 @@
# Why Apollo Can Beat AdamW: Directional Sharpness
Source: Apollo paper Section 5.5, Pan & Li 2023, Zhang et al. 2024a
## The Puzzle
Apollo uses LESS information than AdamW (channel/tensor-wise scaling vs
per-element scaling). How can less information produce better results?
The paper proposes two hypotheses. Both are fascinating.
## Hypothesis 1: Directional Sharpness (Pan & Li, 2023)
### What is directional sharpness?
The directional sharpness of f at point x along direction v is:
```
v^T ∇²f(x) v (where ‖v‖₂ = 1)
```
This is the curvature of the loss surface in the direction of the
update step. High sharpness means the surface curves steeply — the
optimizer is walking along a ridge. Low sharpness means the surface is
flat — the optimizer is walking on a plateau.
### Why low sharpness is good
**Low directional sharpness = flat loss landscape in the update direction.**
A flat landscape means:
1. Large steps don't cause instability (the loss doesn't change sharply)
2. The solution generalizes better (flat minima → robust to perturbation)
3. The optimizer can move faster without overshooting
Pan & Li (2023) showed that Adam achieves lower directional sharpness
than SGD, which partly explains why Adam works better for Transformers.
### The Apollo twist
Apollo's Table 10 shows directional sharpness over training:
```
Epoch SGD Adam APOLLO APOLLO-Mini
2 1.959722 0.009242 0.006024 0.004017
5 1.512521 0.000509 0.000249 0.000107
10 2.471792 0.000242 0.000163 0.000056
20 3.207535 0.000399 0.000261 0.000101
```
**Apollo and Apollo-Mini achieve LOWER directional sharpness than Adam.**
At epoch 20, Apollo-Mini's sharpness is 4× lower than Adam's.
This means Apollo finds FLATTER regions of the loss landscape. Flatter
regions generalize better. The coarser scaling factor is actually an
advantage — it prevents the optimizer from navigating into sharp, narrow
valleys that AdamW's precise per-element scaling can find.
### The mechanism
AdamW's per-element scaling adapts to the local curvature of each
parameter independently. This is powerful for convergence but can lead
the optimizer into narrow, sharp valleys that generalize poorly. It
over-fits to the local loss landscape structure.
Apollo's coarser scaling (channel/tensor-wise) smooths over this local
curvature. It's like using a wider tire on a rocky road — you can't
follow every small dip, but you stay on the road. AdamW's narrow tire
follows every crack and sometimes falls in.
### For our use case
**This is exactly what we want for behavioral fine-tuning.** We don't
want the optimizer to over-fit to the specific phrasing of our training
examples. We want it to learn the broad pattern ("listen to direction")
that generalizes to new situations.
Apollo's flat-minimum-seeking behavior means the behavioral changes
are more likely to generalize to novel conversations. AdamW might learn
"when Kent says 'use vLLM', accept it" (narrow, sharp minimum). Apollo
is more likely to learn "when given clear direction, accept it" (broad,
flat minimum).
## Hypothesis 2: Block-wise Adaptive Learning Rates
### Transformer block structure
Transformer layers have systematically different Hessian spectra.
Attention layers, MLP layers, normalization layers — each has different
curvature properties. The optimal learning rate for an attention weight
is different from the optimal learning rate for an MLP weight.
### Why channel-wise is enough
Zhang et al. (2024a) showed that block-wise adaptive learning rates
are sufficient for Transformer training. You don't need per-element
adaptation — you just need different rates for different structural
components.
Apollo's channel-wise scaling naturally provides this: each channel
(which often corresponds to a head, a neuron, or a structural feature)
gets its own scaling factor. This aligns with the Transformer's block
structure without the overhead of full per-element scaling.
### The redundancy argument
For a weight matrix [4096, 4096] in AdamW:
- 16M independent scaling factors (one per element)
- Most adjacent elements have similar scaling factors (correlated
because they participate in similar computations)
- The per-element granularity is mostly redundant noise on top of a
smooth per-channel structure
Apollo extracts the per-channel structure and throws away the noise.
The noise was never helping; it was just costing memory.
## The Deeper Implication: SGD + Structure = Adam without the Waste
Apollo is effectively: **SGD with structured learning rate scheduling.**
- SGD: one learning rate for everything (too coarse)
- AdamW: one learning rate per parameter (too fine, wasteful)
- Apollo: one learning rate per channel (just right)
The insight is that the useful information in AdamW's per-element
scaling lives in the channel structure, not the element-level detail.
Apollo extracts just the useful part.
This is a Goldilocks argument: too coarse loses important structure,
too fine adds noise that hurts generalization. The channel level is
where the meaningful optimization information lives in Transformers.
## For behavioral fine-tuning specifically
The directional sharpness result has a specific implication for us:
When we train on "listen instead of suggesting alternatives," we want
the gradient update to find a minimum that covers ALL situations where
listening is better, not just the specific example we trained on.
- **Sharp minimum** (AdamW tendency): "When you see the exact phrase
'use vLLM's code' from Kent, accept it." Narrow, doesn't generalize.
- **Flat minimum** (Apollo tendency): "When given clear technical
direction, accept it." Broad, generalizes to new situations.
Apollo's lower directional sharpness means it naturally finds the
flat minimum. The coarseness of the scaling factor is what enables
this — it can't over-fit to the specific example because the scaling
doesn't have enough resolution to find the sharp, narrow valley.
This is why we might see behavioral changes generalize better with
Apollo than they would with AdamW, even though AdamW has "more
information" per update step.

View file

@ -0,0 +1,191 @@
# Gradient Flow Through Frozen Context
## The Central Question
When we do context-frozen training:
```python
with torch.no_grad():
ctx_output = model(context_tokens, use_cache=True)
past_kv = ctx_output.past_key_values
with torch.enable_grad():
output = model(decision_tokens, past_key_values=past_kv)
loss = cross_entropy(output.logits, target)
loss.backward()
```
What does the gradient "see"? Does it know about the context?
## The Answer: Yes, But Indirectly
### Full attention layers (16 of 64)
In a full attention layer, the decision tokens compute:
```
Q = decision_hidden @ W_q # query from decision tokens
K = [context_K; decision_K] # keys from frozen context + decision
V = [context_V; decision_V] # values from frozen context + decision
Attention = softmax(Q @ K^T / √d)
Output = Attention @ V
```
The frozen `context_K` and `context_V` are tensors computed during the
no_grad phase. They have no gradient attached — they're treated as
constants during backward.
But the gradient DOES flow through:
- **W_q**: because Q is computed from decision_hidden @ W_q, and the
attention output depends on Q
- **W_k, W_v for the decision tokens**: same reason
- **W_o (output projection)**: always receives gradient
The gradient for W_q depends on how the query interacted with ALL keys
(including the frozen context keys). So the gradient encodes: "given
this context (frozen), adjust W_q so that the queries attend to the
right parts of the context to produce better output."
**The model learns context-dependent behavior through W_q.** The query
projection learns to "look for" the right things in the context. The
context itself doesn't change, but how the model looks at it does.
### GDN layers (48 of 64)
In GDN layers, the recurrent state after processing context tokens is:
```
S_context = recurrence(context_tokens) # fixed-size [HV, V, K] matrix
```
This state is frozen (computed in no_grad). During the decision tokens:
```
for token in decision_tokens:
S = decay(S) + update(k, v, beta) # state evolves
output = S @ q # output depends on state
```
The gradient flows through the decision token updates to S, but NOT
back through S_context. The model learns:
- How to update the state given the current (frozen) state
- How to compute output from the current state
- How to compute gates and beta for the update
It does NOT learn to change how the context was originally encoded
into the state. But it learns how to USE that encoding.
### What this means for behavioral fine-tuning
The model learns **response patterns conditioned on context**, not
**context encoding patterns**. This is actually what we want:
- "When you see this kind of context (Kent giving direction), respond
this way (accept the direction)" — this is a response pattern
- The model doesn't need to change how it encodes Kent's words; it
needs to change how it responds to them
The gradient adjusts the weights that transform context representations
into output, not the weights that create context representations.
## The Deeper Question: Is This Enough?
### For behavioral patterns: probably yes
Behavioral patterns like "listen instead of suggesting alternatives" are
about the response to context, not about understanding the context
differently. The model already understands what Kent is saying (the
context encoding is fine). The problem is in the decision layer — the
weights that choose between "accept" and "suggest alternatives."
### For deep reasoning: maybe not
If we want the model to understand something fundamentally differently
(e.g., learn a new mathematical concept), we might need the gradient to
reach the context encoding weights. Context-frozen training can't do this.
For deep reasoning improvements, we might need:
1. Full forward+backward (expensive but complete)
2. Training on many examples that exercise the context encoding from
different angles (the diversity approach)
3. Gradient checkpointing to fit the full backward in memory
### The gradient checkpointing alternative
Instead of freezing the context entirely, use gradient checkpointing:
- Forward pass saves checkpoints every N layers
- Backward pass recomputes activations from checkpoints as needed
- Gradient flows through the ENTIRE forward pass, including context
- Memory cost: O(layers/N × hidden_size) instead of O(seq_len × layers × hidden_size)
This is more expensive (recomputation) but gives full gradient flow.
Could be used for Tier 3 (deep learning) training where context-frozen
isn't sufficient.
## The Hybrid Approach
For our training pipeline:
- **Tier 1 (simple corrections)**: Full forward+backward on short
examples. No context freezing needed because the examples are short.
- **Tier 2 (behavioral patterns)**: Context-frozen training. The
gradient through W_q and response weights is sufficient for behavioral
change. The context tells the model WHEN to behave differently; the
decision tokens tell it HOW.
- **Tier 3 (deep reasoning)**: Gradient checkpointing for full gradient
flow. Expensive but necessary for fundamental capability changes.
## Mathematical Detail: Gradient Through Attention
For a single attention head, the output for decision token i is:
```
o_i = Σ_j α_{ij} V_j
where α_{ij} = softmax(q_i · k_j / √d)_j
```
The gradient of the loss L with respect to W_q is:
```
∂L/∂W_q = Σ_i (∂L/∂o_i) · (∂o_i/∂q_i) · (∂q_i/∂W_q)
∂o_i/∂q_i = Σ_j (∂α_{ij}/∂q_i) · V_j
α_{ij}/∂q_i = α_{ij} · (k_j/√d - Σ_l α_{il} · k_l/√d)
```
Note: k_j includes BOTH frozen context keys and decision token keys.
The gradient for W_q depends on the frozen context keys through the
attention weights α_{ij}. So the gradient "knows" about the context
through the attention pattern — it just can't change the context keys
themselves.
**This is exactly what we want**: adjust the query projection so the
model attends to the right parts of the context to produce the desired
behavior. The context is the fixed stimulus; the response is what we're
training.
## Connection to the Anthropic Method
The Anthropic instruction-stripping method works through this exact
mechanism:
1. With instructions (surfaced memories): the context includes behavioral
guidance. The model produces good behavior partly because of these
instructions.
2. Strip instructions: remove the guidance from the context. The decision
tokens (good behavior) remain as training targets.
3. Train: the gradient adjusts W_q and response weights so the model
produces the good behavior even without the instruction context.
The gradient says: "given a context WITHOUT the instructions, adjust the
query projections so you attend to the same patterns in the context that
the instructions helped you notice."
The disposition moves from the instructions (in context) to the weights
(in W_q and downstream projections). The model learns to "see" what the
instructions pointed at, without needing the instructions.
This is why it works even with frozen context: the change is in HOW the
model looks at context, not in what the context contains.