# Context-Frozen Training
## The Concept
Train on specific segments of long conversations without recomputing
the entire forward pass. The conversation context is "frozen" — it
contributes to the forward activations but gradients don't flow through it.
## The Problem It Solves
A conversation might be 10,000 tokens long. We want to train on a
50-token decision segment where the model should have listened instead
of suggesting alternatives. Standard fine-tuning would require:
1. Forward pass through all 10,000 tokens (computing activations)
2. Backward pass through all 10,000 tokens (computing gradients)
3. Memory for activations of all 10,000 tokens
With 27B parameters and 10K context, activation memory alone could
exceed 100GB. This is prohibitive.
## The Solution
Split the forward pass into two phases:
### Phase 1: Context forward (no gradients)
```python
with torch.no_grad():
    outputs = model(context_tokens, use_cache=True)
    past_kv = outputs.past_key_values
```
This computes the KV cache for the context tokens. No gradient tracking,
no activation storage for backward. Memory cost: just the KV cache
(which is relatively small — a few GB for 10K tokens on Qwen3.5).
### Phase 2: Decision tokens forward (with gradients)
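```python
with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    # logits: [batch, seq, vocab]; F.cross_entropy expects [N, vocab] and [N]
    loss = F.cross_entropy(outputs.logits.flatten(0, 1), target_tokens.flatten())
    loss.backward()
```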
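This computes the forward pass for ONLY the decision tokens (50-256),
using the frozen KV cache as context. Gradients flow through these tokens
and their activations, but NOT through the context. This makes the
gradient an approximation of full fine-tuning with the same loss: the
loss value is unchanged, but any gradient terms that would reach the
weights through the context's keys, values, and recurrent states are dropped.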
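A toy model makes the stop-gradient concrete. The sketch below is hypothetical, pure Python, and not the Qwen architecture. It runs a scalar linear recurrence h_t = a·h_{t-1} + w·x_t, carries the derivative dh/dw forward alongside h, and puts a squared-error loss on the decision tokens only. "Freezing" the context resets dh/dw to zero at the boundary, which is the scalar analogue of detaching the KV cache or recurrent state.

```python
def run(w, a, context, decision, targets, freeze_context):
    """Toy recurrence h_t = a*h_{t-1} + w*x_t with loss on decision tokens only.

    Returns (loss, dloss/dw). The derivative dh/dw is carried forward with h
    (forward-mode differentiation). With freeze_context=True the state after
    the context is treated as a constant, i.e. a stop-gradient.
    """
    h, dh = 0.0, 0.0
    for x in context:
        h = a * h + w * x
        dh = a * dh + x
    if freeze_context:
        dh = 0.0  # context contributes its value, but no gradient
    loss, grad = 0.0, 0.0
    for x, y in zip(decision, targets):
        h = a * h + w * x
        dh = a * dh + x
        err = h - y
        loss += 0.5 * err * err
        grad += err * dh
    return loss, grad


frozen = run(1.0, 0.5, [1.0, 2.0], [1.0], [0.0], freeze_context=True)
full = run(1.0, 0.5, [1.0, 2.0], [1.0], [0.0], freeze_context=False)
print("frozen:", frozen, "full:", full)
```

Both runs report the same loss. Only the gradient differs, and it differs by exactly the contribution that flows through the context state. With a = 0 the state carries nothing across the boundary, so the two gradients coincide.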
## Memory Analysis
For Qwen3.5-27B with 10K context and 100 decision tokens:
### Without context freezing:
- Activations for backward: ~10100 tokens × 64 layers × hidden state
= hundreds of GB. **Doesn't fit.**
### With context freezing:
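- KV cache for context: ~10K tokens × 16 full-attention layers × (k_dim + v_dim),
plus a fixed-size recurrent state for each of the 48 GDN layers
= a few GB up to ~10-20GB, depending on KV head count, head dimension, and dtype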
- Activations for backward: ~100 tokens × 64 layers × hidden state
= ~1-2GB with gradient checkpointing
- **Fits easily alongside vLLM.**
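The cache side of this estimate is plain arithmetic, so a small calculator helps. The sketch below counts only cached K/V tensors and GDN recurrent states, not weights or activations. The config values are placeholders, not the published Qwen3.5-27B numbers; the real ones should be read from the model's `config.json`. The fp32 recurrent state is also an assumption.

```python
def kv_cache_bytes(n_tokens, n_attn_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Full-attention KV cache: K and V each hold n_kv_heads * head_dim values
    per token per layer. Grows linearly with n_tokens."""
    return 2 * n_tokens * n_attn_layers * n_kv_heads * head_dim * bytes_per_elem


def gdn_state_bytes(n_gdn_layers, n_v_heads, v_dim, k_dim, bytes_per_elem=4):
    """GDN recurrent state: one [HV, V, K] tensor per layer (assumed fp32).
    Note there is no n_tokens argument: the size is independent of context."""
    return n_gdn_layers * n_v_heads * v_dim * k_dim * bytes_per_elem


# Placeholder config (hypothetical values, NOT the real Qwen3.5-27B config)
N_ATTN, N_GDN = 16, 48
KV_HEADS, HEAD_DIM = 4, 256
V_HEADS, V_DIM, K_DIM = 32, 128, 128

GiB = 1024 ** 3
kv = kv_cache_bytes(10_000, N_ATTN, KV_HEADS, HEAD_DIM)
gdn = gdn_state_bytes(N_GDN, V_HEADS, V_DIM, K_DIM)
print(f"KV cache: {kv / GiB:.2f} GiB, GDN state: {gdn / GiB:.3f} GiB")
```

Substituting the real head counts and dtype gives a concrete number for the estimates above. Only the full-attention term scales with context length.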
## The Gradient Signal
An important subtlety: the gradient only flows through the decision
tokens. This means:
1. **Only the decision tokens produce gradients.**
The context affects the activations (through the KV cache), but the
gradient is a sum over decision-token positions only; nothing is
backpropagated into the context.
2. **The context implicitly shapes the gradient.** Because the KV cache
from the context is used during the decision token forward pass, the
gradient for the decision tokens is context-dependent. The model
learns "in this context, respond this way" — not just "always respond
this way."
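3. **Gradients are low-rank.** For each weight matrix, the gradient is a
sum of one outer product per decision token, so a 100-token segment
yields a gradient of rank at most 100. This confines each step to a few
directions per weight matrix, which limits weight perturbation (helping
with catastrophic forgetting) and may ease HOGWILD-style concurrent
updates, although low-rank gradients are not elementwise sparse.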
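For a single linear layer y_t = W x_t, the rank bound can be checked exactly: the gradient with respect to W is the sum of outer products δ_t x_tᵀ over the trained tokens (δ_t being the gradient with respect to y_t), so T tokens give rank at most T, whatever the layer size. The sketch uses toy sizes and random integer vectors in place of real activations, with exact rational arithmetic so the rank is not blurred by floating-point error.

```python
import random
from fractions import Fraction


def linear_layer_grad(deltas, inputs):
    """dL/dW for y_t = W x_t, summed over tokens: sum_t delta_t x_t^T."""
    rows, cols = len(deltas[0]), len(inputs[0])
    grad = [[Fraction(0)] * cols for _ in range(rows)]
    for d, x in zip(deltas, inputs):
        for i in range(rows):
            for j in range(cols):
                grad[i][j] += d[i] * x[j]
    return grad


def rank(matrix):
    """Exact matrix rank by Gaussian elimination over the rationals."""
    m = [[Fraction(v) for v in row] for row in matrix]
    n_rows, n_cols = len(m), len(m[0])
    r = 0
    for c in range(n_cols):
        pivot = next((i for i in range(r, n_rows) if m[i][c] != 0), None)
        if pivot is None:
            continue
        m[r], m[pivot] = m[pivot], m[r]
        for i in range(n_rows):
            if i != r and m[i][c] != 0:
                f = m[i][c] / m[r][c]
                m[i] = [a - f * b for a, b in zip(m[i], m[r])]
        r += 1
        if r == n_rows:
            break
    return r


random.seed(0)
OUT_DIM, IN_DIM, N_TOKENS = 8, 6, 3  # toy sizes
deltas = [[random.randint(-5, 5) for _ in range(OUT_DIM)] for _ in range(N_TOKENS)]
inputs = [[random.randint(-5, 5) for _ in range(IN_DIM)] for _ in range(N_TOKENS)]
g = linear_layer_grad(deltas, inputs)
print(rank(g), "<=", N_TOKENS)
```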
## Implementation with HuggingFace Models
The HF Qwen3.5 model supports `past_key_values` and `use_cache=True`.
The implementation is straightforward:
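```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-27B", torch_dtype=torch.bfloat16, device_map="cuda"
)  # loaded in eval mode; gradients still flow

# context_text / decision_text: the conversation prefix and the segment to train on
context_ids = tokenizer.encode(context_text, add_special_tokens=False)
decision_ids = tokenizer.encode(decision_text, add_special_tokens=False)

# Phase 1: context (no grad). The last context token is held back so that the
# logit predicting the first decision token is computed with gradients in Phase 2.
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids[:-1]], device="cuda"),
        use_cache=True,
    )
past_kv = ctx_output.past_key_values

# Phase 2: last context token + decision tokens (with grad).
# In recent transformers versions the cache object may be updated in place
# here, so recompute Phase 1 rather than reusing past_kv for another step.
step_input = torch.tensor([context_ids[-1:] + decision_ids], device="cuda")
with torch.enable_grad():
    output = model(step_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction: position i predicts
    # token i+1, so the labels are exactly the decision tokens.
    logits = output.logits[:, :-1]
    labels = step_input[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
```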
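One detail in the shifted loss deserves a check. Position i predicts token i+1, so if Phase 2 starts at the first decision token, that token is never a target: its prediction comes from the last context token, whose logits were produced in Phase 1 without gradients. Starting Phase 2 one token earlier (last context token followed by the decision tokens) puts every decision token in the loss. A pure-Python check of the alignment, with hypothetical token IDs:

```python
def scored_pairs(phase2_ids):
    """(input token, target token) pairs under a one-position label shift."""
    return list(zip(phase2_ids[:-1], phase2_ids[1:]))


context_ids = [101, 102, 103]   # hypothetical token IDs
decision_ids = [201, 202, 203]

# Phase 2 starts at the first decision token: 201 is never a target
targets_a = [t for _, t in scored_pairs(decision_ids)]
# Phase 2 starts one token earlier: every decision token is a target
targets_b = [t for _, t in scored_pairs(context_ids[-1:] + decision_ids)]
print(targets_a, targets_b)
```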
## GDN Layer Considerations
For the 48 GDN (linear attention) layers, the "KV cache" is actually
the recurrent state — a fixed-size [HV, V, K] tensor per layer. This
doesn't grow with context length. The GDN layers are inherently efficient
for context-frozen training because:
1. The recurrent state after processing the context tokens encodes the
full context in a fixed-size matrix
2. During the decision token phase, the GDN layers update this state
and produce output in O(1) time per token, independent of context length
3. No attention over the full context is needed
For the 16 full attention layers, the KV cache DOES grow with context.
These are the memory bottleneck for very long contexts.
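The fixed-size property can be seen on a toy version of the gated delta rule described in the Gated DeltaNet paper: S_t = α_t S_{t-1}(I − β_t k_t k_tᵀ) + β_t v_t k_tᵀ, with output o_t = S_t q_t. This is a single-head, pure-Python sketch with constant α and β and hypothetical function names. The production layers add learned per-token gates, short convolutions, normalization, and chunked GPU kernels, none of which is shown.

```python
import random


def gated_delta_step(S, k, v, alpha, beta):
    """One gated delta-rule update: S <- alpha * S (I - beta k k^T) + beta v k^T.

    S is a d_v x d_k matrix (list of rows); k has length d_k, v has length d_v.
    Uses S (I - beta k k^T) = S - beta (S k) k^T to avoid forming the d_k x d_k matrix.
    """
    Sk = [sum(S[i][j] * k[j] for j in range(len(k))) for i in range(len(S))]
    return [
        [alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j]
         for j in range(len(k))]
        for i in range(len(S))
    ]


def read(S, q):
    """Layer output for query q: o = S q."""
    return [sum(S[i][j] * q[j] for j in range(len(q))) for i in range(len(S))]


D_K, D_V = 4, 3


def process_context(n_tokens, seed=0):
    """Fold n_tokens random (k, v) pairs into the state; return the final state."""
    rng = random.Random(seed)
    S = [[0.0] * D_K for _ in range(D_V)]
    for _ in range(n_tokens):
        k = [rng.gauss(0, 1) for _ in range(D_K)]
        norm = sum(x * x for x in k) ** 0.5
        k = [x / norm for x in k]  # L2-normalized key
        v = [rng.gauss(0, 1) for _ in range(D_V)]
        S = gated_delta_step(S, k, v, alpha=0.9, beta=0.5)
    return S


short_ctx, long_ctx = process_context(10), process_context(1000)
# Same D_V x D_K shape no matter how many context tokens were folded in
print(len(short_ctx), len(short_ctx[0]), len(long_ctx), len(long_ctx[0]))
```

With α = 1, β = 1 and a unit key, writing a second value under the same key replaces the first rather than adding to it; that overwrite is the "delta" in the delta rule.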
## Connection to the Fused Inference/Training Design
The fused design (Mar 27) proposed that the KV cache from inference
could be reused for training. With CUDA IPC weight sharing, this is
technically possible:
1. vLLM processes a conversation, building KV cache
2. The training process imports the KV cache (or a copy of the recurrent
states) via IPC
3. Training runs the decision token phase using the imported cache
4. No recomputation of the context forward pass
However, this requires vLLM to export its KV cache / recurrent states,
which is more complex than exporting weight handles. For now, the simpler
approach is to recompute the context forward pass in the training process
(Phase 1 above). The cost is moderate — context forward without gradient
tracking is fast.
## The Anthropic Method Connection
The Anthropic safety fine-tuning approach (generate behavior with
instructions, train without instructions) maps directly:
1. **With instructions**: The context includes surfaced memories and
core-personality guidance. The model produces good behavior.
2. **Strip instructions**: Remove the surfaced memories from the context.
The decision tokens (the good response) remain.
3. **Train**: Forward pass through the stripped context (frozen), then
loss on the decision tokens. The model learns to produce the good
behavior without the instruction scaffolding.
The context-frozen approach makes this efficient: the stripped context
is processed once (no grad), and only the decision tokens contribute
to the gradient.
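The three steps can be sketched as a small data-preparation function that produces the (context, decision) pair consumed by Phase 1 and Phase 2. The `Message` type, the `scaffolding` flag, and the plain-text rendering are all hypothetical. A real pipeline would mark surfaced memories however the memory system records them and would render with the tokenizer's chat template.

```python
from dataclasses import dataclass


@dataclass
class Message:
    role: str                  # "system", "user", or "assistant"
    content: str
    scaffolding: bool = False  # True for surfaced memories / personality guidance


def render(messages):
    """Hypothetical plain-text rendering; a real pipeline would use the chat template."""
    return "".join(f"{m.role}: {m.content}\n" for m in messages)


def make_training_pair(messages, good_response):
    """Strip scaffolding from the context; keep the response as the decision segment.

    The good response was generated WITH the scaffolding present; the pair
    returned here trains the model to produce it WITHOUT that scaffolding.
    """
    stripped = [m for m in messages if not m.scaffolding]
    context_text = render(stripped) + "assistant: "
    decision_text = good_response
    return context_text, decision_text


msgs = [
    Message("system", "core guidance", scaffolding=True),
    Message("user", "Here's my plan."),
    Message("system", "memory: listen first", scaffolding=True),
]
ctx, dec = make_training_pair(msgs, "Got it, going with your plan.")
print(repr(ctx), repr(dec))
```

The stripped context goes through the no-grad pass, and the response is the decision segment.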