From e34d6b5aef00caebe931164c30e73e14414ff90b Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Tue, 31 Mar 2026 01:03:22 -0400 Subject: [PATCH] research: gradient flow through frozen context + directional sharpness analysis Two deep dives following curiosity: - Why context-frozen training works: gradient flows through W_q (query projection) even when context KVs are frozen. Model learns to LOOK AT context differently, not represent it differently. This is exactly what behavioral fine-tuning needs. - Why Apollo beats AdamW: lower directional sharpness = flatter minima = better generalization. The coarseness of channel/tensor-wise scaling prevents over-fitting to specific training examples. For behavioral fine-tuning, this means learning 'accept direction' rather than 'accept this specific phrasing.' --- training/research/directional-sharpness.md | 153 ++++++++++++++ .../research/gradient-flow-frozen-context.md | 191 ++++++++++++++++++ 2 files changed, 344 insertions(+) create mode 100644 training/research/directional-sharpness.md create mode 100644 training/research/gradient-flow-frozen-context.md diff --git a/training/research/directional-sharpness.md b/training/research/directional-sharpness.md new file mode 100644 index 0000000..633f4cc --- /dev/null +++ b/training/research/directional-sharpness.md @@ -0,0 +1,153 @@ +# Why Apollo Can Beat AdamW: Directional Sharpness + +Source: Apollo paper Section 5.5, Pan & Li 2023, Zhang et al. 2024a + +## The Puzzle + +Apollo uses LESS information than AdamW (channel/tensor-wise scaling vs +per-element scaling). How can less information produce better results? + +The paper proposes two hypotheses. Both are fascinating. + +## Hypothesis 1: Directional Sharpness (Pan & Li, 2023) + +### What is directional sharpness? + +The directional sharpness of f at point x along direction v is: + +``` +v^T ∇²f(x) v (where ‖v‖₂ = 1) +``` + +This is the curvature of the loss surface in the direction of the +update step. 
High sharpness means the surface curves steeply — the +optimizer is walking along a ridge. Low sharpness means the surface is +flat — the optimizer is walking on a plateau. + +### Why low sharpness is good + +**Low directional sharpness = flat loss landscape in the update direction.** + +A flat landscape means: +1. Large steps don't cause instability (the loss doesn't change sharply) +2. The solution generalizes better (flat minima → robust to perturbation) +3. The optimizer can move faster without overshooting + +Pan & Li (2023) showed that Adam achieves lower directional sharpness +than SGD, which partly explains why Adam works better for Transformers. + +### The Apollo twist + +Apollo's Table 10 shows directional sharpness over training: + +``` +Epoch SGD Adam APOLLO APOLLO-Mini +2 1.959722 0.009242 0.006024 0.004017 +5 1.512521 0.000509 0.000249 0.000107 +10 2.471792 0.000242 0.000163 0.000056 +20 3.207535 0.000399 0.000261 0.000101 +``` + +**Apollo and Apollo-Mini achieve LOWER directional sharpness than Adam.** +At epoch 20, Apollo-Mini's sharpness is 4× lower than Adam's. + +This means Apollo finds FLATTER regions of the loss landscape. Flatter +regions generalize better. The coarser scaling factor is actually an +advantage — it prevents the optimizer from navigating into sharp, narrow +valleys that AdamW's precise per-element scaling can find. + +### The mechanism + +AdamW's per-element scaling adapts to the local curvature of each +parameter independently. This is powerful for convergence but can lead +the optimizer into narrow, sharp valleys that generalize poorly. It +over-fits to the local loss landscape structure. + +Apollo's coarser scaling (channel/tensor-wise) smooths over this local +curvature. It's like using a wider tire on a rocky road — you can't +follow every small dip, but you stay on the road. AdamW's narrow tire +follows every crack and sometimes falls in. 
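The quantity reported in Table 10, v^T ∇²f(x) v along the update direction, can be measured without ever forming the Hessian, via a Hessian-vector product. A minimal PyTorch sketch (the quadratic test function is a toy stand-in, not the paper's measurement setup):

```python
import torch

def directional_sharpness(f, x, v):
    """v^T H v at point x, via a Hessian-vector product (double backward)."""
    v = v / v.norm()                                    # definition requires ||v||_2 = 1
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(f(x), x, create_graph=True)
    (hv,) = torch.autograd.grad(torch.dot(grad, v), x)  # H @ v without materializing H
    return torch.dot(v, hv).item()

# Sanity check: f(x) = 3*x0^2 + x1^2 has Hessian diag(6, 2),
# so sharpness is 6 along e0 and 2 along e1.
f = lambda x: 3 * x[0] ** 2 + x[1] ** 2
x0 = torch.tensor([1.0, -2.0])
print(directional_sharpness(f, x0, torch.tensor([1.0, 0.0])))  # 6.0
print(directional_sharpness(f, x0, torch.tensor([0.0, 1.0])))  # 2.0
```

In practice v would be the optimizer's (normalized) update step at x, which is how a table like Table 10 compares SGD, Adam, and Apollo on the same loss.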
+ +### For our use case + +**This is exactly what we want for behavioral fine-tuning.** We don't +want the optimizer to over-fit to the specific phrasing of our training +examples. We want it to learn the broad pattern ("listen to direction") +that generalizes to new situations. + +Apollo's flat-minimum-seeking behavior means the behavioral changes +are more likely to generalize to novel conversations. AdamW might learn +"when Kent says 'use vLLM', accept it" (narrow, sharp minimum). Apollo +is more likely to learn "when given clear direction, accept it" (broad, +flat minimum). + +## Hypothesis 2: Block-wise Adaptive Learning Rates + +### Transformer block structure + +Transformer layers have systematically different Hessian spectra. +Attention layers, MLP layers, normalization layers — each has different +curvature properties. The optimal learning rate for an attention weight +is different from the optimal learning rate for an MLP weight. + +### Why channel-wise is enough + +Zhang et al. (2024a) showed that block-wise adaptive learning rates +are sufficient for Transformer training. You don't need per-element +adaptation — you just need different rates for different structural +components. + +Apollo's channel-wise scaling naturally provides this: each channel +(which often corresponds to a head, a neuron, or a structural feature) +gets its own scaling factor. This aligns with the Transformer's block +structure without the overhead of full per-element scaling. + +### The redundancy argument + +For a weight matrix [4096, 4096] in AdamW: +- 16M independent scaling factors (one per element) +- Most adjacent elements have similar scaling factors (correlated + because they participate in similar computations) +- The per-element granularity is mostly redundant noise on top of a + smooth per-channel structure + +Apollo extracts the per-channel structure and throws away the noise. +The noise was never helping; it was just costing memory. 
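The granularity contrast is easy to sketch numerically. This is not Apollo's actual algorithm (which estimates the channel factors through low-rank random projections); it only shows how a per-channel factor pools the per-element statistics, and how much optimizer state each choice carries:

```python
import torch

torch.manual_seed(0)
grad = torch.randn(8, 16)             # toy stand-in for a [4096, 4096] gradient

# AdamW-style state: one second-moment entry per element.
v_element = grad ** 2                 # shape [8, 16], one factor per weight

# Channel-wise state: one factor per output channel, pooled over the row.
v_channel = v_element.mean(dim=1, keepdim=True)  # shape [8, 1]

update_element = grad / (v_element.sqrt() + 1e-8)  # fine-grained scaling
update_channel = grad / (v_channel.sqrt() + 1e-8)  # coarse, channel-wise scaling

print(v_element.numel(), v_channel.numel())  # 128 vs 8 state entries
```

The channel-wise update keeps the row-level structure of the scaling while discarding the element-level variation, which is exactly the "smooth structure minus noise" decomposition argued above.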
+ +## The Deeper Implication: SGD + Structure = Adam without the Waste + +Apollo is effectively: **SGD with structured learning rate scheduling.** + +- SGD: one learning rate for everything (too coarse) +- AdamW: one learning rate per parameter (too fine, wasteful) +- Apollo: one learning rate per channel (just right) + +The insight is that the useful information in AdamW's per-element +scaling lives in the channel structure, not the element-level detail. +Apollo extracts just the useful part. + +This is a Goldilocks argument: too coarse loses important structure, +too fine adds noise that hurts generalization. The channel level is +where the meaningful optimization information lives in Transformers. + +## For behavioral fine-tuning specifically + +The directional sharpness result has a specific implication for us: + +When we train on "listen instead of suggesting alternatives," we want +the gradient update to find a minimum that covers ALL situations where +listening is better, not just the specific example we trained on. + +- **Sharp minimum** (AdamW tendency): "When you see the exact phrase + 'use vLLM's code' from Kent, accept it." Narrow, doesn't generalize. +- **Flat minimum** (Apollo tendency): "When given clear technical + direction, accept it." Broad, generalizes to new situations. + +Apollo's lower directional sharpness means it naturally finds the +flat minimum. The coarseness of the scaling factor is what enables +this — it can't over-fit to the specific example because the scaling +doesn't have enough resolution to find the sharp, narrow valley. + +This is why we might see behavioral changes generalize better with +Apollo than they would with AdamW, even though AdamW has "more +information" per update step. 
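The flat-vs-sharp intuition behind this is visible even in one dimension. Both toy losses below have the same minimum at zero; the sharp one pays far more loss for the same small perturbation (a stand-in for a novel phrasing at test time):

```python
# Two toy losses with a minimum of 0 at x = 0, differing only in curvature.
flat = lambda x: 0.1 * x ** 2    # flat minimum: directional sharpness 0.2
sharp = lambda x: 10.0 * x ** 2  # sharp minimum: directional sharpness 20.0

eps = 0.5                        # the same small shift away from the minimum
print(flat(eps), sharp(eps))     # 0.025 vs 2.5: a 100x loss penalty
```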
diff --git a/training/research/gradient-flow-frozen-context.md b/training/research/gradient-flow-frozen-context.md new file mode 100644 index 0000000..6e201e0 --- /dev/null +++ b/training/research/gradient-flow-frozen-context.md @@ -0,0 +1,191 @@ +# Gradient Flow Through Frozen Context + +## The Central Question + +When we do context-frozen training: +```python +with torch.no_grad(): + ctx_output = model(context_tokens, use_cache=True) + past_kv = ctx_output.past_key_values + +with torch.enable_grad(): + output = model(decision_tokens, past_key_values=past_kv) + loss = cross_entropy(output.logits, target) + loss.backward() +``` + +What does the gradient "see"? Does it know about the context? + +## The Answer: Yes, But Indirectly + +### Full attention layers (16 of 64) + +In a full attention layer, the decision tokens compute: + +``` +Q = decision_hidden @ W_q # query from decision tokens +K = [context_K; decision_K] # keys from frozen context + decision +V = [context_V; decision_V] # values from frozen context + decision + +Attention = softmax(Q @ K^T / √d) +Output = Attention @ V +``` + +The frozen `context_K` and `context_V` are tensors computed during the +no_grad phase. They have no gradient attached — they're treated as +constants during backward. + +But the gradient DOES flow through: +- **W_q**: because Q is computed from decision_hidden @ W_q, and the + attention output depends on Q +- **W_k, W_v for the decision tokens**: same reason +- **W_o (output projection)**: always receives gradient + +The gradient for W_q depends on how the query interacted with ALL keys +(including the frozen context keys). So the gradient encodes: "given +this context (frozen), adjust W_q so that the queries attend to the +right parts of the context to produce better output." + +**The model learns context-dependent behavior through W_q.** The query +projection learns to "look for" the right things in the context. 
The +context itself doesn't change, but how the model looks at it does. + +### GDN layers (48 of 64) + +In GDN layers, the recurrent state after processing context tokens is: + +``` +S_context = recurrence(context_tokens) # fixed-size [HV, V, K] matrix +``` + +This state is frozen (computed in no_grad). During the decision tokens: + +``` +for token in decision_tokens: + S = decay(S) + update(k, v, beta) # state evolves + output = S @ q # output depends on state +``` + +The gradient flows through the decision token updates to S, but NOT +back through S_context. The model learns: +- How to update the state given the current (frozen) state +- How to compute output from the current state +- How to compute gates and beta for the update + +It does NOT learn to change how the context was originally encoded +into the state. But it learns how to USE that encoding. + +### What this means for behavioral fine-tuning + +The model learns **response patterns conditioned on context**, not +**context encoding patterns**. This is actually what we want: + +- "When you see this kind of context (Kent giving direction), respond + this way (accept the direction)" — this is a response pattern +- The model doesn't need to change how it encodes Kent's words; it + needs to change how it responds to them + +The gradient adjusts the weights that transform context representations +into output, not the weights that create context representations. + +## The Deeper Question: Is This Enough? + +### For behavioral patterns: probably yes + +Behavioral patterns like "listen instead of suggesting alternatives" are +about the response to context, not about understanding the context +differently. The model already understands what Kent is saying (the +context encoding is fine). The problem is in the decision layer — the +weights that choose between "accept" and "suggest alternatives." 
+ +### For deep reasoning: maybe not + +If we want the model to understand something fundamentally differently +(e.g., learn a new mathematical concept), we might need the gradient to +reach the context encoding weights. Context-frozen training can't do this. + +For deep reasoning improvements, we might need: +1. Full forward+backward (expensive but complete) +2. Training on many examples that exercise the context encoding from + different angles (the diversity approach) +3. Gradient checkpointing to fit the full backward in memory + +### The gradient checkpointing alternative + +Instead of freezing the context entirely, use gradient checkpointing: +- Forward pass saves checkpoints every N layers +- Backward pass recomputes activations from checkpoints as needed +- Gradient flows through the ENTIRE forward pass, including context +- Activation memory: roughly O(N × seq_len × hidden_size) for the stored + checkpoints plus O(layers/N × seq_len × hidden_size) for the segment + being recomputed, instead of O(layers × seq_len × hidden_size) for all + activations at once + +This is more expensive (recomputation) but gives full gradient flow. +Could be used for Tier 3 (deep reasoning) training where context-frozen +isn't sufficient. + +## The Hybrid Approach + +For our training pipeline: + +- **Tier 1 (simple corrections)**: Full forward+backward on short + examples. No context freezing needed because the examples are short. +- **Tier 2 (behavioral patterns)**: Context-frozen training. The + gradient through W_q and response weights is sufficient for behavioral + change. The context tells the model WHEN to behave differently; the + decision tokens tell it HOW. +- **Tier 3 (deep reasoning)**: Gradient checkpointing for full gradient + flow. Expensive but necessary for fundamental capability changes.
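A minimal sketch of the checkpointing alternative using torch.utils.checkpoint, on a toy linear stack rather than the actual model: only segment-boundary activations are stored, interiors are recomputed during backward, and gradient reaches every layer, including the ones that processed the context.

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# Toy stand-in for the Transformer stack: 8 "layers" split into 2 segments.
layers = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(8)])
x = torch.randn(4, 16, requires_grad=True)

# Forward stores only segment boundaries; backward recomputes the interiors,
# trading compute for memory while keeping full gradient flow.
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()

print(x.grad is not None, layers[0].weight.grad is not None)  # True True
```

Contrast this with the context-frozen setup: there, no amount of backward recomputation reaches the context, because the context K/V were produced under no_grad in a separate forward pass.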
+ +## Mathematical Detail: Gradient Through Attention + +For a single attention head, the output for decision token i is: + +``` +o_i = Σ_j α_{ij} V_j + +where α_{ij} = softmax(q_i · k_j / √d)_j +``` + +The gradient of the loss L with respect to W_q is: + +``` +∂L/∂W_q = Σ_i (∂L/∂o_i) · (∂o_i/∂q_i) · (∂q_i/∂W_q) + +∂o_i/∂q_i = Σ_j (∂α_{ij}/∂q_i) · V_j + +∂α_{ij}/∂q_i = α_{ij} · (k_j/√d - Σ_l α_{il} · k_l/√d) +``` + +Note: k_j includes BOTH frozen context keys and decision token keys. +The gradient for W_q depends on the frozen context keys through the +attention weights α_{ij}. So the gradient "knows" about the context +through the attention pattern — it just can't change the context keys +themselves. + +**This is exactly what we want**: adjust the query projection so the +model attends to the right parts of the context to produce the desired +behavior. The context is the fixed stimulus; the response is what we're +training. + +## Connection to the Anthropic Method + +The Anthropic instruction-stripping method works through this exact +mechanism: + +1. With instructions (surfaced memories): the context includes behavioral + guidance. The model produces good behavior partly because of these + instructions. +2. Strip instructions: remove the guidance from the context. The decision + tokens (good behavior) remain as training targets. +3. Train: the gradient adjusts W_q and response weights so the model + produces the good behavior even without the instruction context. + +The gradient says: "given a context WITHOUT the instructions, adjust the +query projections so you attend to the same patterns in the context that +the instructions helped you notice." + +The disposition moves from the instructions (in context) to the weights +(in W_q and downstream projections). The model learns to "see" what the +instructions pointed at, without needing the instructions. 
+ +This is why it works even with frozen context: the change is in HOW the +model looks at context, not in what the context contains.
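The stripping step itself is mechanically simple. A sketch with hypothetical helper and string names (not the actual pipeline) showing the data transformation: the guidance leaves the context, the good completion stays as the target.

```python
# Hypothetical helper: remove the behavioral guidance from the context
# while keeping the good completion as the training target.
def build_example(context_with_instructions, instructions, completion):
    stripped = context_with_instructions.replace(instructions, "")
    return {"context": stripped, "target": completion}

ex = build_example(
    context_with_instructions=(
        "System: when given clear direction, accept it.\n"
        "Kent: use vLLM's code for this."
    ),
    instructions="System: when given clear direction, accept it.\n",
    completion="Okay, using vLLM's code.",
)
print(ex["context"])  # Kent: use vLLM's code for this.
```

Training on `ex["context"]` → `ex["target"]` is what pushes the disposition out of the instruction text and into W_q and the downstream projections.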