# Context-Frozen Training

## The Concept

Train on specific segments of long conversations without recomputing
the entire forward pass. The conversation context is "frozen" — it
contributes to the forward activations but gradients don't flow through it.

## The Problem It Solves

A conversation might be 10,000 tokens long. We want to train on a
50-token decision segment where the model should have listened instead
of suggesting alternatives. Standard fine-tuning would require:

1. Forward pass through all 10,000 tokens (computing activations)
2. Backward pass through all 10,000 tokens (computing gradients)
3. Memory for activations of all 10,000 tokens

With 27B parameters and 10K context, activation memory alone could
exceed 100GB. This is prohibitive.

## The Solution

Split the forward pass into two phases:

### Phase 1: Context forward (no gradients)

```python
with torch.no_grad():
    outputs = model(context_tokens, use_cache=True)
    past_kv = outputs.past_key_values
```

This computes the KV cache for the context tokens. No gradient tracking,
no activation storage for backward. Memory cost: just the KV cache
(which is relatively small — a few GB for 10K tokens on Qwen3.5).

### Phase 2: Decision tokens forward (with gradients)

```python
with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    loss = cross_entropy(outputs.logits, target_tokens)
    loss.backward()
```

This computes the forward pass for ONLY the decision tokens (50-256),
using the frozen KV cache as context. Gradients flow through these tokens
and their activations, but NOT through the context.

## Memory Analysis

For Qwen3.5-27B with 10K context and 100 decision tokens:

### Without context freezing:
- Activations for backward: ~10,100 tokens × 64 layers × hidden state
  = hundreds of GB. **Doesn't fit.**

### With context freezing:
- KV cache for context: ~10K tokens × 64 layers × (k_dim + v_dim)
  = ~10-20GB (depends on GDN vs full attention split)
- Activations for backward: ~100 tokens × 64 layers × hidden state
  = ~1-2GB with gradient checkpointing
- **Fits easily alongside vLLM.**
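
These estimates can be sanity-checked with back-of-envelope arithmetic. A
minimal sketch, assuming placeholder dimensions: `KV_HEADS`, `HEAD_DIM`,
`HIDDEN`, and the activation `factor` below are illustrative stand-ins, not
the real Qwen3.5-27B config. The point is the scaling, not the exact numbers.

```python
def kv_cache_bytes(tokens, layers, kv_heads, head_dim, dtype_bytes=2):
    # K and V each store tokens x kv_heads x head_dim values per layer
    return tokens * layers * 2 * kv_heads * head_dim * dtype_bytes

def activation_bytes(tokens, layers, hidden, dtype_bytes=2, factor=10):
    # Rough per-layer activation footprint kept for backward;
    # `factor` crudely lumps together attention/MLP intermediates
    return tokens * layers * hidden * dtype_bytes * factor

# Placeholder dimensions (NOT the actual Qwen3.5-27B config)
FULL_ATTN_LAYERS, TOTAL_LAYERS, KV_HEADS, HEAD_DIM, HIDDEN = 16, 64, 8, 128, 5120

# Context-frozen: KV cache for the context plus activations for 100 tokens
frozen = kv_cache_bytes(10_000, FULL_ATTN_LAYERS, KV_HEADS, HEAD_DIM) \
       + activation_bytes(100, TOTAL_LAYERS, HIDDEN)
# Standard fine-tuning: activations for all 10,100 tokens
full = activation_bytes(10_100, TOTAL_LAYERS, HIDDEN)

print(f"frozen: {frozen / 1e9:.1f} GB, full backprop: {full / 1e9:.1f} GB")
```

Whatever the exact constants, the frozen variant's cost is dominated by the
context KV cache, which only the full-attention layers contribute to.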

## The Gradient Signal

An important subtlety: the gradient only flows through the decision
tokens. This means:

1. **Only the weights active for the decision tokens receive gradients.**
   The context affects which weights are active (through the KV cache),
   but the gradient magnitude comes only from the decision segment.

2. **The context implicitly shapes the gradient.** Because the KV cache
   from the context is used during the decision token forward pass, the
   gradient for the decision tokens is context-dependent. The model
   learns "in this context, respond this way" — not just "always respond
   this way."

3. **Gradient sparsity is maximized.** Short decision segments activate
   a small fraction of the model's capacity, producing naturally sparse
   gradients. This helps with both catastrophic forgetting (limited
   weight perturbation) and HOGWILD convergence (sparse updates).
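
A toy PyTorch sketch makes the gradient-flow claim concrete. A single linear
layer stands in for the model and a mean-pooled vector stands in for the KV
cache; this is not the actual cache machinery, just the autograd behavior:

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 4)

# Phase 1: "context" forward under no_grad; activations are not tracked
context = torch.randn(3, 4)
with torch.no_grad():
    frozen = layer(context).mean(dim=0)  # stand-in for the frozen KV cache

# Phase 2: "decision" forward with gradients, conditioned on the frozen state
decision = torch.randn(2, 4)
out = layer(decision) + frozen           # frozen enters as a constant
loss = out.pow(2).mean()
loss.backward()

print(frozen.requires_grad)              # False: no graph through the context
print(layer.weight.grad is not None)     # True: decision tokens drive the update
```

The frozen tensor still shapes the loss (it shifts `out`), but backward never
touches the context forward pass, which is exactly the property exploited above.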

## Implementation with HuggingFace Models

The HF Qwen3.5 model supports `past_key_values` and `use_cache=True`.
The implementation is straightforward:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-27B").cuda()

# context_text / decision_text come from the segmented conversation
context_ids = tokenizer.encode(context_text)
decision_ids = tokenizer.encode(decision_text)

# Phase 1: context (no grad)
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids], device="cuda"),
        use_cache=True,
    )
    past_kv = ctx_output.past_key_values

# Phase 2: decision tokens (with grad)
decision_input = torch.tensor([decision_ids], device="cuda")
with torch.enable_grad():
    output = model(decision_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction. The first decision
    # token gets no loss: its prediction came from the context's last position.
    logits = output.logits[:, :-1]
    labels = decision_input[:, 1:]
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
```

## GDN Layer Considerations

For the 48 GDN (linear attention) layers, the "KV cache" is actually
the recurrent state — a fixed-size [HV, V, K] tensor per layer. This
doesn't grow with context length. The GDN layers are inherently efficient
for context-frozen training because:

1. The recurrent state after processing the context tokens encodes the
   full context in a fixed-size matrix
2. During the decision token phase, the GDN layers update this state
   and produce output in O(1) per token
3. No attention over the full context is needed

For the 16 full attention layers, the KV cache DOES grow with context.
These are the memory bottleneck for very long contexts.
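
A minimal sketch of the fixed-size recurrent update. A simple scalar decay
stands in for GDN's learned gating, and the dimensions are arbitrary; the
real layer is more involved, but the memory behavior is the same:

```python
import torch

d_k, d_v = 8, 8
S = torch.zeros(d_k, d_v)  # fixed-size recurrent state for one layer

def recurrent_step(S, q, k, v, decay=0.9):
    # O(1) per token: fold the new (k, v) pair into the state, then read out
    S = decay * S + torch.outer(k, v)
    return S, q @ S

# Process a long "context" token by token; the state never grows
for _ in range(1_000):
    q, k, v = torch.randn(d_k), torch.randn(d_k), torch.randn(d_v)
    S, _ = recurrent_step(S, q, k, v)

print(S.shape)  # torch.Size([8, 8]) regardless of context length
```

After the context pass, `S` is the only thing the decision-token phase needs
from these layers, which is why freezing them is nearly free.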

## Connection to the Fused Inference/Training Design

The fused design (Mar 27) proposed that the KV cache from inference
could be reused for training. With CUDA IPC weight sharing, this is
technically possible:

1. vLLM processes a conversation, building KV cache
2. The training process imports the KV cache (or a copy of the recurrent
   states) via IPC
3. Training runs the decision token phase using the imported cache
4. No recomputation of the context forward pass

However, this requires vLLM to export its KV cache / recurrent states,
which is more complex than exporting weight handles. For now, the simpler
approach is to recompute the context forward pass in the training process
(Phase 1 above). The cost is moderate — context forward without gradient
tracking is fast.

## The Anthropic Method Connection

The Anthropic safety fine-tuning approach (generate behavior with
instructions, train without instructions) maps directly:

1. **With instructions**: The context includes surfaced memories and
   core-personality guidance. The model produces good behavior.
2. **Strip instructions**: Remove the surfaced memories from the context.
   The decision tokens (the good response) remain.
3. **Train**: Forward pass through the stripped context (frozen), then
   loss on the decision tokens. The model learns to produce the good
   behavior without the instruction scaffolding.

The context-frozen approach makes this efficient: the stripped context
is processed once (no grad), and only the decision tokens contribute
to the gradient.
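
The strip step might look like this on a hypothetical message schema. The
`"memory"` role and the message contents are invented for illustration; the
real pipeline's representation may differ:

```python
# Hypothetical conversation; "memory" entries are the injected guidance
conversation = [
    {"role": "system", "content": "core personality"},
    {"role": "memory", "content": "surfaced memory: user prefers listening"},
    {"role": "user", "content": "I had a rough day."},
    {"role": "assistant", "content": "That sounds hard. I'm here."},  # decision segment
]

def strip_instructions(messages):
    # Drop surfaced memories; keep everything else, including the good response
    return [m for m in messages if m["role"] != "memory"]

stripped = strip_instructions(conversation)
context = stripped[:-1]   # frozen context (Phase 1, no grad)
decision = stripped[-1]   # decision tokens (Phase 2, loss target)
print(len(stripped))      # 3
```

The frozen context is then tokenized for Phase 1 and the decision message for
Phase 2, exactly as in the implementation section above.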