# Gradient Flow Through Frozen Context

## The Central Question

When we do context-frozen training (assuming `model` is a Hugging Face `transformers`-style causal LM, and `context_tokens`, `decision_tokens`, and `target` are token-ID tensors prepared elsewhere):

```python
import torch
import torch.nn.functional as F

# Phase 1: encode the context without building an autograd graph.
with torch.no_grad():
    ctx_output = model(context_tokens, use_cache=True)
    past_kv = ctx_output.past_key_values

# Phase 2: decision tokens attend to the frozen cache; gradients are on.
with torch.enable_grad():
    output = model(decision_tokens, past_key_values=past_kv)
    logits = output.logits  # [batch, decision_len, vocab]
    # flatten to [N, vocab] logits vs [N] labels; target is aligned with logits
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    loss.backward()
```

What does the gradient "see"? Does it know about the context?

## The Answer: Yes, But Indirectly

### Full attention layers (16 of 64)

In a full attention layer, the decision tokens compute:

```
Q = decision_hidden @ W_q           # queries from decision tokens only
K = [context_K; decision_K]         # keys from frozen context + decision
V = [context_V; decision_V]         # values from frozen context + decision
Attention = softmax(Q @ K^T / √d)   # causal mask among decision tokens omitted
Output = Attention @ V
```

The frozen `context_K` and `context_V` are tensors computed during the no_grad phase. They carry no autograd history, so backward treats them as constants.

But the gradient DOES flow through:

- **W_q**: Q is computed as decision_hidden @ W_q, and the attention output depends on Q.
- **W_k and W_v, via the decision tokens only**: decision_K = decision_hidden @ W_k and decision_V = decision_hidden @ W_v sit inside the same attention computation. No gradient path runs back through context_K or context_V.
- **W_o (output projection)**: always receives gradient.

The gradient for W_q depends on how each query interacted with ALL keys, including the frozen context keys. So the gradient encodes: "given this (frozen) context, adjust W_q so that the queries attend to the right parts of the context to produce better output."

**The model learns context-dependent behavior through W_q.** The query projection learns to "look for" the right things in the context. The context itself doesn't change, but how the model looks at it does.

### GDN layers (48 of 64)

In GDN layers, the recurrent state after processing the context tokens is:

```
S_context = recurrence(context_tokens)   # fixed-size [HV, V, K] matrix
```

This state is frozen (computed in no_grad).
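To make the freeze concrete, here is a toy, framework-free sketch. It assumes one common form of the gated delta rule, S ← α·S(I − β k kᵀ) + β v kᵀ with output S q, and uses tiny made-up vectors; the real GDN kernel, its per-head shapes, and its learned gate projections are not modeled. `S_context` enters only as a constant input (mirroring the no_grad phase), a single scalar `beta` stands in for the trainable parameters, and central finite differences stand in for autograd. The same decision tokens yield different gradients for `beta` under two different frozen states: the gradient "sees" the context without flowing back into it.

```python
def matvec(M, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def gdn_step(S, k, v, alpha, beta):
    """Assumed gated delta rule: S <- alpha * S (I - beta k k^T) + beta v k^T."""
    Sk = matvec(S, k)  # the delta-rule update reads the current state
    return [
        [alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j]
         for j in range(len(k))]
        for i in range(len(S))
    ]

def decision_loss(S_context, tokens, alpha, beta):
    """Run decision tokens on top of a frozen state; return squared error."""
    S = S_context  # constant input: nothing here can change how it was built
    loss = 0.0
    for k, v, q, target in tokens:
        S = gdn_step(S, k, v, alpha, beta)
        out = matvec(S, q)
        loss += sum((o - t) ** 2 for o, t in zip(out, target))
    return loss

def grad_beta(S_context, tokens, alpha, beta, eps=1e-6):
    """Central finite difference, standing in for autograd's dL/dbeta."""
    return (decision_loss(S_context, tokens, alpha, beta + eps)
            - decision_loss(S_context, tokens, alpha, beta - eps)) / (2 * eps)

# Hypothetical decision tokens: (key, value, query, target output).
tokens = [
    ([1.0, 0.0], [0.5, -0.5], [1.0, 0.0], [1.0, 0.0]),
    ([0.0, 1.0], [0.2, 0.8], [0.0, 1.0], [0.0, 1.0]),
]
S_a = [[0.9, 0.1], [0.0, 0.3]]   # frozen state left by one (made-up) context
S_b = [[-0.4, 0.0], [0.7, 0.2]]  # frozen state left by a different context

g_a = grad_beta(S_a, tokens, alpha=0.9, beta=0.5)
g_b = grad_beta(S_b, tokens, alpha=0.9, beta=0.5)
print(f"dL/dbeta with context A: {g_a:.4f}")
print(f"dL/dbeta with context B: {g_b:.4f}")
```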
During the decision tokens:

```
S = S_context
for token in decision_tokens:
    S = decay(S) + update(S, k, v, beta)   # state evolves; the delta-rule update reads S
    output = S @ q                         # output depends on state
```

The gradient flows through the decision-token updates to S, but stops at S_context: nothing that produced it is in the graph. Because every decision-token step decays and reads S_context, the gradients for the gates and beta still depend on what the context wrote into the state. The model learns:

- How to update the state given the current (frozen) state
- How to compute output from the current state
- How to compute the gates and beta for the update

It does NOT get a gradient signal for how the context was originally encoded into the state. But it learns how to USE that encoding.

### What this means for behavioral fine-tuning

The model learns **response patterns conditioned on context**, not **context encoding patterns**. This is actually what we want:

- "When you see this kind of context (Kent giving direction), respond this way (accept the direction)": this is a response pattern.
- The model doesn't need to change how it encodes Kent's words; it needs to change how it responds to them.

The gradient optimizes how context representations are turned into output, not how they are created. (Every position shares the same weights, so an update to W_k, W_v, or the MLPs also shifts how future contexts are encoded; that shift is a side effect, not something the loss optimized.)

## The Deeper Question: Is This Enough?

### For behavioral patterns: probably yes

Behavioral patterns like "listen instead of suggesting alternatives" are about the response to context, not about understanding the context differently. The model already understands what Kent is saying (the context encoding is fine). The problem is in the decision layer: the weights that choose between "accept" and "suggest alternatives."

### For deep reasoning: maybe not

If we want the model to understand something fundamentally differently (e.g., learn a new mathematical concept), we might need the gradient to reach the context encoding weights. Context-frozen training can't do this.

For deep reasoning improvements, we might need:

1. Full forward+backward (expensive but complete)
2. Training on many examples that exercise the context encoding from different angles (the diversity approach)
3. Gradient checkpointing to fit the full backward in memory

### The gradient checkpointing alternative

Instead of freezing the context entirely, use gradient checkpointing:

- The forward pass keeps activations only every N layers (the checkpoints).
- The backward pass recomputes the activations between checkpoints as needed.
- Gradient flows through the ENTIRE forward pass, including the context.
- Activation memory: roughly O(seq_len × (layers/N + N) × hidden_size) instead of O(seq_len × layers × hidden_size); choosing N ≈ √layers gives O(seq_len × √layers × hidden_size).

This costs extra compute (roughly one additional forward pass for the recomputation) but gives full gradient flow. It could be used for Tier 3 (deep reasoning) training where context-frozen training isn't sufficient.

## The Hybrid Approach

For our training pipeline:

- **Tier 1 (simple corrections)**: Full forward+backward on short examples. No context freezing needed because the examples are short.
- **Tier 2 (behavioral patterns)**: Context-frozen training. The gradient through W_q and the response weights is sufficient for behavioral change. The context tells the model WHEN to behave differently; the decision tokens tell it HOW.
- **Tier 3 (deep reasoning)**: Gradient checkpointing for full gradient flow. Expensive but necessary for fundamental capability changes.

## Mathematical Detail: Gradient Through Attention

For a single attention head, the output for decision token i is:

```
o_i = Σ_j α_{ij} V_j     where α_{ij} = softmax(q_i · k_j / √d)_j
```

The gradient of the loss L with respect to W_q is:

```
∂L/∂W_q      = Σ_i (∂L/∂o_i) · (∂o_i/∂q_i) · (∂q_i/∂W_q)
∂o_i/∂q_i    = Σ_j V_j (∂α_{ij}/∂q_i)^T        (outer product)
∂α_{ij}/∂q_i = α_{ij} · (k_j/√d - Σ_l α_{il} · k_l/√d)
```

Note: k_j ranges over BOTH the frozen context keys and the decision token keys (and V_j over both sets of values). The gradient for W_q depends on the frozen context through the attention weights α_{ij}, through the weighted key average Σ_l α_{il} k_l, and through the context values V_j. So the gradient "knows" about the context through the attention pattern; it just can't change the context keys themselves.

**This is exactly what we want**: adjust the query projection so the model attends to the right parts of the context to produce the desired behavior.
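The softmax-derivative line can be checked numerically, and the same check shows the context dependence. This is a framework-free sketch with made-up 2-d vectors and illustrative helper names: it compares the analytic ∂α_{ij}/∂q_i against central finite differences for every key, then evaluates the analytic gradient for the same query and the same decision key under two different frozen contexts.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(q, keys):
    """alpha_ij = softmax_j(q_i . k_j / sqrt(d)) over all keys j."""
    d = len(q)
    return softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys])

def dalpha_dq_analytic(q, keys, j):
    """alpha_ij * (k_j/sqrt(d) - sum_l alpha_il * k_l/sqrt(d)), per query dim."""
    d = len(q)
    alpha = attention_weights(q, keys)
    avg_key = [sum(alpha[l] * keys[l][c] for l in range(len(keys))) for c in range(d)]
    return [alpha[j] * (keys[j][c] - avg_key[c]) / math.sqrt(d) for c in range(d)]

def dalpha_dq_numeric(q, keys, j, eps=1e-6):
    """Central finite differences of alpha_ij w.r.t. each component of q_i."""
    grad = []
    for c in range(len(q)):
        up, down = list(q), list(q)
        up[c] += eps
        down[c] -= eps
        diff = attention_weights(up, keys)[j] - attention_weights(down, keys)[j]
        grad.append(diff / (2 * eps))
    return grad

# Made-up vectors: two frozen context keys plus one live decision key.
context_a = [[1.0, 0.0], [0.0, 1.0]]
context_b = [[2.0, -1.0], [-0.5, 1.0]]
decision_keys = [[0.5, 0.5]]
q = [0.3, -0.7]

keys = context_a + decision_keys
max_err = max(
    abs(x - y)
    for j in range(len(keys))
    for x, y in zip(dalpha_dq_analytic(q, keys, j), dalpha_dq_numeric(q, keys, j))
)
print(f"max |analytic - numeric| = {max_err:.2e}")

# Same query, same decision key (index 2), different frozen context.
grad_a = dalpha_dq_analytic(q, context_a + decision_keys, j=2)
grad_b = dalpha_dq_analytic(q, context_b + decision_keys, j=2)
print("context A:", [round(g, 4) for g in grad_a])
print("context B:", [round(g, 4) for g in grad_b])
```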
The context is the fixed stimulus; the response is what we're training.

## Connection to the Anthropic Method

The Anthropic instruction-stripping method works through this exact mechanism:

1. With instructions (surfaced memories): the context includes behavioral guidance. The model produces good behavior partly because of these instructions.
2. Strip instructions: remove the guidance from the context. The decision tokens (good behavior) remain as training targets.
3. Train: the gradient adjusts W_q and the response weights so the model produces the good behavior even without the instruction context.

The gradient says: "given a context WITHOUT the instructions, adjust the query projections so you attend to the same patterns in the context that the instructions helped you notice."

The disposition moves from the instructions (in context) to the weights (in W_q and downstream projections). The model learns to "see" what the instructions pointed at, without needing the instructions.

This is why it works even with frozen context: the change is in HOW the model looks at context, not in what the context contains.
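Steps 1-3 amount to a small data-preparation pass before the context-frozen step. The sketch below is hypothetical throughout: the `Segment` type, the `is_instruction` flag, `render`, and the `generate` callable are illustrative names, not part of any published implementation of the method. It generates a response with the guidance present, then pairs that response with the stripped context.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

@dataclass(frozen=True)
class Segment:
    """One piece of the prompt context (hypothetical structure)."""
    role: str                     # e.g. "memory", "user"
    text: str
    is_instruction: bool = False  # True for surfaced behavioral guidance

def render(segments: List[Segment]) -> str:
    """Flatten segments into the prompt string the model would see."""
    return "\n".join(f"[{s.role}] {s.text}" for s in segments)

def make_stripped_example(
    context: List[Segment],
    generate: Callable[[str], str],
) -> Tuple[str, str]:
    # Step 1: generate WITH the instructions, so the response reflects them.
    response = generate(render(context))
    # Step 2: drop the guidance; the response stays as the training target.
    stripped = [s for s in context if not s.is_instruction]
    # Step 3 happens downstream: the stripped context becomes the frozen
    # prefix, and the response supplies the decision tokens for the loss.
    return render(stripped), response

def fake_generate(prompt: str) -> str:
    """Stand-in generator; a real pipeline would call the model here."""
    return "Sounds good, switching to the simpler schema."

context = [
    Segment("memory", "When Kent gives direction, accept it.", is_instruction=True),
    Segment("user", "Let's use the simpler schema."),
]
prompt, target = make_stripped_example(context, fake_generate)
print(prompt)
print(target)
```

Under context-frozen training only the decision tokens carry labels, so the stripped prompt needs no separate loss mask.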