research: context-frozen training — gradient masking, memory analysis, GDN considerations

parent 6af9e6fa76
commit 7c7975d98e

1 changed file with 164 additions and 0 deletions:
training/research/context-frozen-training.md (new file)

# Context-Frozen Training

## The Concept

Train on specific segments of long conversations without recomputing
the entire forward pass. The conversation context is "frozen" — it
contributes to the forward activations but gradients don't flow through it.

## The Problem It Solves

A conversation might be 10,000 tokens long. We want to train on a
50-token decision segment where the model should have listened instead
of suggesting alternatives. Standard fine-tuning would require:

1. Forward pass through all 10,000 tokens (computing activations)
2. Backward pass through all 10,000 tokens (computing gradients)
3. Memory for activations of all 10,000 tokens

With 27B parameters and 10K context, activation memory alone could
exceed 100GB. This is prohibitive.

## The Solution

Split the forward pass into two phases:

### Phase 1: Context forward (no gradients)

```python
with torch.no_grad():
    outputs = model(context_tokens, use_cache=True)
past_kv = outputs.past_key_values
```

This computes the KV cache for the context tokens. No gradient tracking,
no activation storage for backward. Memory cost: just the KV cache
(which is relatively small — a few GB for 10K tokens on Qwen3.5).

### Phase 2: Decision tokens forward (with gradients)

```python
with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    loss = F.cross_entropy(
        outputs.logits.view(-1, outputs.logits.size(-1)),
        target_tokens.view(-1),
    )
    loss.backward()
```

This computes the forward pass for ONLY the decision tokens (50-256),
using the frozen KV cache as context. Gradients flow through these tokens
and their activations, but NOT through the context.

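This two-phase gradient flow can be verified at toy scale. The sketch below uses two standalone parameters instead of a real model (in practice the same weights serve both phases; distinct parameters just make the masking visible): anything computed under `torch.no_grad()` enters the graph as a constant, so backward reaches only the grad-enabled phase.

```python
import torch

# Toy check of the two-phase flow: the "context" is computed under
# no_grad, so backward reaches only the "decision"-phase parameter.
w_ctx = torch.nn.Parameter(torch.ones(3))   # stands in for the context path
w_dec = torch.nn.Parameter(torch.ones(3))   # stands in for the decision path

with torch.no_grad():
    context = w_ctx * 2.0                   # Phase 1: no graph recorded

out = (context * w_dec).sum()               # Phase 2: grads flow to w_dec only
out.backward()

print(w_ctx.grad)   # None — the frozen context contributed no gradient
print(w_dec.grad)   # tensor([2., 2., 2.])
```
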
## Memory Analysis

For Qwen3.5-27B with 10K context and 100 decision tokens:

### Without context freezing:

- Activations for backward: ~10,100 tokens × 64 layers × hidden state
  = hundreds of GB. **Doesn't fit.**

### With context freezing:

- KV cache for context: ~10K tokens × 64 layers × (k_dim + v_dim)
  = ~10-20GB (depends on GDN vs full attention split)
- Activations for backward: ~100 tokens × 64 layers × hidden state
  = ~1-2GB with gradient checkpointing
- **Fits easily alongside vLLM.**

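The arithmetic above can be sanity-checked back-of-envelope. The layer count comes from this document; the hidden width and per-layer activation multiplier are illustrative assumptions, not official Qwen3.5-27B figures:

```python
# Back-of-envelope activation memory for the backward pass, in bf16.
# `hidden` and `acts_per_layer` are assumed values for illustration only.
layers = 64
hidden = 5120          # assumed hidden width
acts_per_layer = 20    # assumed saved tensors per layer for backward
bytes_per = 2          # bf16

def activation_gb(tokens):
    return tokens * layers * hidden * acts_per_layer * bytes_per / 1e9

full = activation_gb(10_100)   # context + decision tokens
frozen = activation_gb(100)    # decision tokens only
print(f"full: ~{full:.0f} GB, frozen: ~{frozen:.1f} GB, ratio: {full / frozen:.0f}x")
# → full: ~132 GB, frozen: ~1.3 GB, ratio: 101x
```

The savings ratio is simply the token ratio (10,100 / 100 ≈ 101×); the absolute numbers scale with whatever per-layer multiplier the real model has.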
## The Gradient Signal

An important subtlety: the gradient only flows through the decision
tokens. This means:

1. **Only the weights active for the decision tokens receive gradients.**
   The context affects which weights are active (through the KV cache),
   but the gradient magnitude comes only from the decision segment.

2. **The context implicitly shapes the gradient.** Because the KV cache
   from the context is used during the decision token forward pass, the
   gradient for the decision tokens is context-dependent. The model
   learns "in this context, respond this way" — not just "always respond
   this way."

3. **Gradient sparsity is maximized.** Short decision segments activate
   a small fraction of the model's capacity, producing naturally sparse
   gradients. This helps with both catastrophic forgetting (limited
   weight perturbation) and HOGWILD convergence (sparse updates).

## Implementation with HuggingFace Models

The HF Qwen3.5 model supports `past_key_values` and `use_cache=True`.
The implementation is straightforward:

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")
context_ids = tokenizer.encode(context_text)
decision_ids = tokenizer.encode(decision_text)

# Phase 1: context (no grad)
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids], device="cuda"),
        use_cache=True,
    )
past_kv = ctx_output.past_key_values

# Phase 2: decision tokens (with grad)
decision_input = torch.tensor([decision_ids], device="cuda")
with torch.enable_grad():
    output = model(decision_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction
    logits = output.logits[:, :-1]
    labels = decision_input[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
```

## GDN Layer Considerations

For the 48 GDN (linear attention) layers, the "KV cache" is actually
the recurrent state — a fixed-size [HV, V, K] tensor per layer. This
doesn't grow with context length. The GDN layers are inherently efficient
for context-frozen training because:

1. The recurrent state after processing the context tokens encodes the
   full context in a fixed-size matrix
2. During the decision token phase, the GDN layers update this state
   and produce output in O(1) per token
3. No attention over the full context is needed

For the 16 full attention layers, the KV cache DOES grow with context.
These are the memory bottleneck for very long contexts.

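The fixed-size-state property can be illustrated with a minimal linear-attention recurrence. This is a toy update rule with arbitrary shapes, not the actual GDN kernel or its gating:

```python
import torch

d_k, d_v = 4, 4
S = torch.zeros(d_k, d_v)         # recurrent state; size independent of context length

def step(S, k, v):
    return S + torch.outer(k, v)  # O(1)-per-token state update: S += k v^T

# "Context phase": fold 1,000 tokens into the same fixed-size state
for _ in range(1000):
    S = step(S, torch.randn(d_k), torch.randn(d_v))

# "Decision phase": each new token is one update plus one read
q = torch.randn(d_k)
out = q @ S                       # read the context summary for query q
print(S.shape)                    # torch.Size([4, 4]) — unchanged after 1,000 tokens
```

However many context tokens are folded in, the state handed from Phase 1 to Phase 2 stays the same size — unlike the full-attention KV cache.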
## Connection to the Fused Inference/Training Design

The fused design (Mar 27) proposed that the KV cache from inference
could be reused for training. With CUDA IPC weight sharing, this is
technically possible:

1. vLLM processes a conversation, building KV cache
2. The training process imports the KV cache (or a copy of the recurrent
   states) via IPC
3. Training runs the decision token phase using the imported cache
4. No recomputation of the context forward pass

However, this requires vLLM to export its KV cache / recurrent states,
which is more complex than exporting weight handles. For now, the simpler
approach is to recompute the context forward pass in the training process
(Phase 1 above). The cost is moderate — context forward without gradient
tracking is fast.

## The Anthropic Method Connection

The Anthropic safety fine-tuning approach (generate behavior with
instructions, train without instructions) maps directly:

1. **With instructions**: The context includes surfaced memories and
   core-personality guidance. The model produces good behavior.
2. **Strip instructions**: Remove the surfaced memories from the context.
   The decision tokens (the good response) remain.
3. **Train**: Forward pass through the stripped context (frozen), then
   loss on the decision tokens. The model learns to produce the good
   behavior without the instruction scaffolding.

The context-frozen approach makes this efficient: the stripped context
is processed once (no grad), and only the decision tokens contribute
to the gradient.
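The data-preparation side of this mapping can be sketched as below. Field names and prompt formatting are hypothetical, for illustration only:

```python
# Hypothetical record: an instructed rollout that produced good behavior.
record = {
    "memories": "<surfaced memories + core-personality guidance>",
    "user": "I think we should rewrite the parser from scratch.",
    "response": "What's driving that? Tell me more before we weigh options.",
}

# Strip instructions: the frozen context keeps the conversation but
# drops the memory/guidance scaffolding entirely.
frozen_context = f"User: {record['user']}\nAssistant:"

# Decision tokens: the good response, which alone carries the loss.
decision_text = " " + record["response"]

print(frozen_context + decision_text)
```

`frozen_context` feeds Phase 1 (no grad); `decision_text` is tokenized into the decision segment for Phase 2.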