# Context-Frozen Training

## The Concept

Train on specific segments of long conversations without recomputing
the entire forward pass. The conversation context is "frozen": it
contributes to the forward activations, but gradients don't flow through it.
## The Problem It Solves

A conversation might be 10,000 tokens long. We want to train on a
50-token decision segment where the model should have listened instead
of suggesting alternatives. Standard fine-tuning would require:

1. Forward pass through all 10,000 tokens (computing activations)
2. Backward pass through all 10,000 tokens (computing gradients)
3. Memory for activations of all 10,000 tokens

With 27B parameters and 10K context, activation memory alone could
exceed 100GB. This is prohibitive.
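The 100GB figure can be sanity-checked with rough arithmetic. The hidden
size and the per-layer activation multiplier below are illustrative
assumptions, not published model dimensions:

```python
# Back-of-envelope activation memory for a full backward pass.
# hidden size and the activations-per-hidden multiplier are assumptions.
tokens, layers, hidden = 10_000, 64, 5_120
acts_per_hidden = 16    # crude multiplier for attention/FFN intermediates
bytes_per_elem = 2      # bf16
total_gb = tokens * layers * hidden * acts_per_hidden * bytes_per_elem / 1e9
print(f"~{total_gb:.0f} GB")  # ~105 GB
```

The exact multiplier depends on the architecture and on what gradient
checkpointing discards, but any reasonable choice lands in the
hundreds-of-GB range at 10K tokens.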
## The Solution

Split the forward pass into two phases:

### Phase 1: Context forward (no gradients)

```python
with torch.no_grad():
    outputs = model(context_tokens, use_cache=True)
    past_kv = outputs.past_key_values
```

This computes the KV cache for the context tokens. No gradient tracking,
no activation storage for backward. Memory cost: just the KV cache
(which is relatively small: a few GB for 10K tokens on Qwen3.5).
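Tensors produced under `torch.no_grad()` carry no autograd graph, which
is what makes the cache safe to reuse. A minimal illustration:

```python
import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    cache = x * 2           # computed with autograd disabled
print(cache.requires_grad)  # False: reusing cache later cannot backprop into x
```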
### Phase 2: Decision tokens forward (with gradients)

```python
with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    # Shift so each position predicts the following token
    logits = outputs.logits[:, :-1]
    labels = decision_tokens[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
```

This computes the forward pass for ONLY the decision tokens (50-256),
using the frozen KV cache as context. Gradients flow through these tokens
and their activations, but NOT through the context.
## Memory Analysis

For Qwen3.5-27B with 10K context and 100 decision tokens:

### Without context freezing:
- Activations for backward: ~10,100 tokens × 64 layers × hidden state
  = hundreds of GB. **Doesn't fit.**

### With context freezing:
- KV cache for context: ~10K tokens × 64 layers × (k_dim + v_dim)
  = ~10-20GB (depends on the GDN vs. full-attention split)
- Activations for backward: ~100 tokens × 64 layers × hidden state
  = ~1-2GB with gradient checkpointing
- **Fits easily alongside vLLM.**
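The KV-cache figure can also be estimated directly. The layer count and
head dimensions below are illustrative assumptions, not the model's
actual configuration:

```python
# Back-of-envelope KV-cache size for layers that cache per-token K/V.
def kv_cache_bytes(tokens, layers, kv_heads, head_dim, bytes_per_elem=2):
    # K and V each store [tokens, kv_heads, head_dim] per layer
    return tokens * layers * kv_heads * head_dim * 2 * bytes_per_elem

print(kv_cache_bytes(10_000, 16, 8, 128) / 1e9)  # GB, for the assumed dims
```

Actual totals depend heavily on the head configuration and on how many
layers cache per-token K/V, which is why estimates range from under a
GB to tens of GB.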
## The Gradient Signal

An important subtlety: the gradient only flows through the decision
tokens. This means:

1. **Only the weights active for the decision tokens receive gradients.**
   The context affects which weights are active (through the KV cache),
   but the gradient magnitude comes only from the decision segment.

2. **The context implicitly shapes the gradient.** Because the KV cache
   from the context is used during the decision token forward pass, the
   gradient for the decision tokens is context-dependent. The model
   learns "in this context, respond this way", not just "always respond
   this way."

3. **Gradient sparsity is maximized.** Short decision segments activate
   a small fraction of the model's capacity, producing naturally sparse
   gradients. This helps with both catastrophic forgetting (limited
   weight perturbation) and HOGWILD convergence (sparse updates).
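The context-dependence of the gradient shows up even in a toy scalar
example: the frozen context enters the backward pass as a constant, so
the same parameter receives a different gradient under a different context:

```python
import torch

# w plays the role of a model weight; "context" is computed under
# no_grad, so w's contribution to the context produces no gradient.
w = torch.nn.Parameter(torch.tensor(2.0))
with torch.no_grad():
    context = w * 3.0   # frozen: treated as the constant 6.0
loss = w * context      # decision phase: differentiable in w
loss.backward()
print(w.grad.item())    # 6.0, i.e. the frozen context value
```

Change the context and the gradient changes with it, even though the
backward pass never touches the context computation itself.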
## Implementation with HuggingFace Models

The HF Qwen3.5 model supports `past_key_values` and `use_cache=True`.
The implementation is straightforward:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-27B").cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")
context_ids = tokenizer.encode(context_text)
decision_ids = tokenizer.encode(decision_text)

# Phase 1: context (no grad)
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids], device="cuda"),
        use_cache=True
    )
past_kv = ctx_output.past_key_values

# Phase 2: decision tokens (with grad)
decision_input = torch.tensor([decision_ids], device="cuda")
with torch.enable_grad():
    output = model(decision_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction
    logits = output.logits[:, :-1]
    labels = decision_input[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
```
## GDN Layer Considerations

For the 48 GDN (linear attention) layers, the "KV cache" is actually
the recurrent state: a fixed-size [HV, V, K] tensor per layer. This
doesn't grow with context length. The GDN layers are inherently efficient
for context-frozen training because:

1. The recurrent state after processing the context tokens encodes the
   full context in a fixed-size matrix
2. During the decision token phase, the GDN layers update this state
   and produce output in O(1) per token
3. No attention over the full context is needed

For the 16 full attention layers, the KV cache DOES grow with context.
These are the memory bottleneck for very long contexts.
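A minimal linear-attention recurrence (a generic sketch, not the actual
GDN update rule) shows why the state stays fixed-size:

```python
import torch

d_k, d_v = 4, 4
S = torch.zeros(d_k, d_v)      # fixed-size recurrent state

def step(S, k, v, q):
    S = S + torch.outer(k, v)  # absorb one token; state shape never changes
    return S, q @ S            # readout cost independent of context length

torch.manual_seed(0)
for _ in range(1_000):         # 1,000 "context" tokens
    k, v, q = torch.randn(3, d_k)
    S, out = step(S, k, v, q)

print(tuple(S.shape))  # (4, 4): same size after 1,000 tokens as after 1
```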
## Connection to the Fused Inference/Training Design

The fused design (Mar 27) proposed that the KV cache from inference
could be reused for training. With CUDA IPC weight sharing, this is
technically possible:

1. vLLM processes a conversation, building KV cache
2. The training process imports the KV cache (or a copy of the recurrent
   states) via IPC
3. Training runs the decision token phase using the imported cache
4. No recomputation of the context forward pass

However, this requires vLLM to export its KV cache / recurrent states,
which is more complex than exporting weight handles. For now, the simpler
approach is to recompute the context forward pass in the training process
(Phase 1 above). The cost is moderate: a context forward pass without
gradient tracking is fast.
## The Anthropic Method Connection

The Anthropic safety fine-tuning approach (generate behavior with
instructions, train without instructions) maps directly:

1. **With instructions**: The context includes surfaced memories and
   core-personality guidance. The model produces good behavior.
2. **Strip instructions**: Remove the surfaced memories from the context.
   The decision tokens (the good response) remain.
3. **Train**: Forward pass through the stripped context (frozen), then
   loss on the decision tokens. The model learns to produce the good
   behavior without the instruction scaffolding.

The context-frozen approach makes this efficient: the stripped context
is processed once (no grad), and only the decision tokens contribute
to the gradient.
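The three steps above can be sketched at the string level. The memory
tag and helper name here are hypothetical, purely for illustration:

```python
# Hypothetical data-prep helper for the strip-then-train recipe.
def make_example(context_with_instructions, instructions, response):
    # The response was generated WITH the instructions present; strip
    # them so training teaches the behavior without the scaffolding.
    stripped = context_with_instructions.replace(instructions, "")
    # stripped -> Phase 1 (frozen context); response -> Phase 2 (decision tokens)
    return {"context": stripped, "decision": response}

ex = make_example(
    "[memory: user prefers listening]\nUser: I'm stuck.\n",
    "[memory: user prefers listening]\n",
    "Assistant: That sounds hard. Tell me more.",
)
print(ex["context"])  # "User: I'm stuck.\n"
```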