From 7c7975d98ef863f2edbc06a692ddc0cc921f1530 Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Tue, 31 Mar 2026 00:59:04 -0400 Subject: [PATCH] =?UTF-8?q?research:=20context-frozen=20training=20?= =?UTF-8?q?=E2=80=94=20gradient=20masking,=20memory=20analysis,=20GDN=20co?= =?UTF-8?q?nsiderations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- training/research/context-frozen-training.md | 164 +++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 training/research/context-frozen-training.md diff --git a/training/research/context-frozen-training.md b/training/research/context-frozen-training.md new file mode 100644 index 0000000..be2aacd --- /dev/null +++ b/training/research/context-frozen-training.md @@ -0,0 +1,164 @@ +# Context-Frozen Training + +## The Concept + +Train on specific segments of long conversations without recomputing +the entire forward pass. The conversation context is "frozen" — it +contributes to the forward activations but gradients don't flow through it. + +## The Problem It Solves + +A conversation might be 10,000 tokens long. We want to train on a +50-token decision segment where the model should have listened instead +of suggesting alternatives. Standard fine-tuning would require: + +1. Forward pass through all 10,000 tokens (computing activations) +2. Backward pass through all 10,000 tokens (computing gradients) +3. Memory for activations of all 10,000 tokens + +With 27B parameters and 10K context, activation memory alone could +exceed 100GB. This is prohibitive. + +## The Solution + +Split the forward pass into two phases: + +### Phase 1: Context forward (no gradients) + +```python +with torch.no_grad(): + outputs = model(context_tokens, use_cache=True) + past_kv = outputs.past_key_values +``` + +This computes the KV cache for the context tokens. No gradient tracking, +no activation storage for backward. 
Memory cost: just the KV cache
(which is relatively small — a few GB for 10K tokens on Qwen3.5).

### Phase 2: Decision tokens forward (with gradients)

```python
with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    # F is torch.nn.functional; see the full example below for label shifting
    loss = F.cross_entropy(outputs.logits.squeeze(0), target_tokens)
    loss.backward()
```

This computes the forward pass for ONLY the decision tokens (typically
50–256 of them), using the frozen KV cache as context. Gradients flow
through these tokens and their activations, but NOT through the context.

## Memory Analysis

For Qwen3.5-27B with 10K context and 100 decision tokens:

### Without context freezing:
- Activations for backward: ~10,100 tokens × 64 layers × hidden state
  = hundreds of GB. **Doesn't fit.**

### With context freezing:
- KV cache for context: ~10K tokens × 64 layers × (k_dim + v_dim)
  = ~10-20GB (depending on the GDN vs. full-attention layer split)
- Activations for backward: ~100 tokens × 64 layers × hidden state
  = ~1-2GB with gradient checkpointing
- **Fits easily alongside vLLM.**

## The Gradient Signal

An important subtlety: the gradient flows only through the decision
tokens. This means:

1. **Only the weights active for the decision tokens receive gradients.**
   The context affects which weights are active (through the KV cache),
   but the gradient magnitude comes only from the decision segment.

2. **The context implicitly shapes the gradient.** Because the KV cache
   from the context is used during the decision token forward pass, the
   gradient for the decision tokens is context-dependent. The model
   learns "in this context, respond this way" — not just "always respond
   this way."

3. **Gradient sparsity is maximized.** Short decision segments activate
   a small fraction of the model's capacity, producing naturally sparse
   gradients. This helps with both catastrophic forgetting (limited
   weight perturbation) and HOGWILD convergence (sparse updates).
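
The memory analysis above can be sanity-checked with back-of-envelope
arithmetic. A minimal sketch, assuming illustrative dimensions — the
16 full-attention / 48 GDN layer split is from this doc, but the KV head
count, head dim, hidden size, and per-layer activation multiplier are
hypothetical placeholders, not Qwen3.5-27B's real config:

```python
# Rough memory model for context-frozen training. All dimensions below
# are illustrative assumptions; what matters is the ratio between the
# naive and context-frozen numbers, not the absolute values.

def kv_cache_bytes(tokens, layers=16, kv_heads=8, head_dim=128, dtype_bytes=2):
    """KV cache for the full-attention layers: 2x for keys and values."""
    return tokens * layers * 2 * kv_heads * head_dim * dtype_bytes

def activation_bytes(tokens, layers=64, hidden=8192, tensors_per_layer=20,
                     dtype_bytes=2):
    """Activations kept for backward; tensors_per_layer approximates the
    attention/MLP intermediates saved per layer (assumed multiplier)."""
    return tokens * layers * hidden * tensors_per_layer * dtype_bytes

naive = activation_bytes(10_100)                        # context + decision
frozen = kv_cache_bytes(10_000) + activation_bytes(100)  # frozen context

print(f"naive fine-tuning: {naive / 1e9:6.1f} GB")
print(f"context-frozen:    {frozen / 1e9:6.1f} GB")
```

With these toy numbers the naive approach lands in the hundreds of GB
while the frozen variant stays in the single digits, matching the
qualitative claim above even though the absolute figures are guesses.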
## Implementation with HuggingFace Models

The HF Qwen3.5 model supports `past_key_values` and `use_cache=True`.
The implementation is straightforward:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-27B", torch_dtype=torch.bfloat16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")

# context_text / decision_text: the frozen context and the decision segment
context_ids = tokenizer.encode(context_text)
decision_ids = tokenizer.encode(decision_text)

# Phase 1: context (no grad)
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids], device="cuda"),
        use_cache=True
    )
    past_kv = ctx_output.past_key_values

# Phase 2: decision tokens (with grad)
decision_input = torch.tensor([decision_ids], device="cuda")
with torch.enable_grad():
    output = model(decision_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction
    logits = output.logits[:, :-1]
    labels = decision_input[:, 1:]
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
```

## GDN Layer Considerations

For the 48 GDN (linear attention) layers, the "KV cache" is actually
the recurrent state — a fixed-size [HV, V, K] tensor per layer. This
doesn't grow with context length. The GDN layers are inherently efficient
for context-frozen training because:

1. The recurrent state after processing the context tokens encodes the
   full context in a fixed-size matrix
2. During the decision token phase, the GDN layers update this state
   and produce output in O(1) per token
3. No attention over the full context is needed

For the 16 full attention layers, the KV cache DOES grow with context.
These are the memory bottleneck for very long contexts.

## Connection to the Fused Inference/Training Design

The fused design (Mar 27) proposed that the KV cache from inference
could be reused for training. With CUDA IPC weight sharing, this is
technically possible:

1. vLLM processes a conversation, building KV cache
2. The training process imports the KV cache (or a copy of the recurrent
   states) via IPC
3. 
Training runs the decision token phase using the imported cache +4. No recomputation of the context forward pass + +However, this requires vLLM to export its KV cache / recurrent states, +which is more complex than exporting weight handles. For now, the simpler +approach is to recompute the context forward pass in the training process +(Phase 1 above). The cost is moderate — context forward without gradient +tracking is fast. + +## The Anthropic Method Connection + +The Anthropic safety fine-tuning approach (generate behavior with +instructions, train without instructions) maps directly: + +1. **With instructions**: The context includes surfaced memories and + core-personality guidance. The model produces good behavior. +2. **Strip instructions**: Remove the surfaced memories from the context. + The decision tokens (the good response) remain. +3. **Train**: Forward pass through the stripped context (frozen), then + loss on the decision tokens. The model learns to produce the good + behavior without the instruction scaffolding. + +The context-frozen approach makes this efficient: the stripped context +is processed once (no grad), and only the decision tokens contribute +to the gradient.
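
The strip-and-train loop can be exercised end to end at toy scale. Below
is a minimal sketch with a tiny recurrent cell standing in for a
GDN-style layer — the model, sizes, and data are all made up for
illustration. The stripped context is folded into a fixed-size state
under `no_grad`, and only the decision tokens build an autograd graph:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-in for a GDN-style recurrent layer; sizes are illustrative.
vocab, dim = 32, 16
emb = nn.Embedding(vocab, dim)
cell = nn.Linear(2 * dim, dim)
head = nn.Linear(dim, vocab)

def step(state, tok):
    """One recurrent update: fold a token into the fixed-size state."""
    return torch.tanh(cell(torch.cat([state, emb(tok)], dim=-1)))

context = torch.randint(0, vocab, (10,))   # stripped-context tokens
decision = torch.randint(0, vocab, (4,))   # decision tokens (the targets)

# Phase 1: fold the context into the state without building a graph.
with torch.no_grad():
    state = torch.zeros(dim)
    for tok in context:
        state = step(state, tok)

# Phase 2: decision tokens run with gradients; the context state is frozen.
loss = torch.zeros(())
for i in range(len(decision) - 1):
    state = step(state, decision[i])
    loss = loss + F.cross_entropy(head(state), decision[i + 1])
loss.backward()

# Gradients exist on the weights, driven only by the decision tokens.
assert cell.weight.grad is not None and head.weight.grad is not None
print(f"decision-segment loss: {loss.item():.3f}")
```

A useful property to check: embedding rows for tokens that appear only
in the frozen context receive exactly zero gradient, while rows for the
decision tokens do not — the toy analogue of "the context shapes the
forward pass but contributes no gradient."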