
Context-Frozen Training

The Concept

Train on specific segments of long conversations without backpropagating through the entire conversation. The conversation context is "frozen": it contributes to the forward activations, but gradients don't flow through it and its activations are not stored for the backward pass.

The Problem It Solves

A conversation might be 10,000 tokens long. We want to train on a 50-token decision segment where the model should have listened instead of suggesting alternatives. Standard fine-tuning would require:

  1. Forward pass through all 10,000 tokens (computing activations)
  2. Backward pass through all 10,000 tokens (computing gradients)
  3. Memory for activations of all 10,000 tokens

With 27B parameters and 10K context, the activations stored for the backward pass could alone exceed 100GB without gradient checkpointing. This is prohibitive.
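The following back-of-envelope Python sketch shows where a figure like 100GB comes from. The hidden size (5120) and the number of hidden-size tensors each layer keeps alive for backward (16) are hypothetical placeholders, not Qwen3.5-27B's published configuration. Only the 64-layer count comes from this note.

```python
# Back-of-envelope activation memory for backprop through a transformer stack.
# HIDDEN_SIZE and TENSORS_PER_LAYER are hypothetical placeholders, not
# Qwen3.5-27B's real config; substitute values from the model's config file.

NUM_LAYERS = 64          # layer count used throughout this note
HIDDEN_SIZE = 5120       # hypothetical hidden size
BYTES_PER_ELEM = 2       # bf16
TENSORS_PER_LAYER = 16   # hypothetical: hidden-size tensors saved per layer for backward


def activation_bytes(num_tokens: int) -> int:
    """Bytes of activations kept for the backward pass, without checkpointing."""
    return num_tokens * NUM_LAYERS * HIDDEN_SIZE * BYTES_PER_ELEM * TENSORS_PER_LAYER


full = activation_bytes(10_100)  # context + decision tokens, all with gradients
frozen = activation_bytes(100)   # decision tokens only
print(f"full sequence: {full / 1e9:.0f} GB")
print(f"decision only: {frozen / 1e9:.1f} GB")
```

Under these assumptions the full 10,100-token sequence needs about 106GB and the 100 decision tokens about 1GB. The absolute numbers shift with the assumptions. In this simple linear model the ratio is always the token ratio, because activation memory scales with the number of tokens that take part in backward.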

The Solution

Split the forward pass into two phases:

Phase 1: Context forward (no gradients)

with torch.no_grad():
    outputs = model(context_tokens, use_cache=True)
    past_kv = outputs.past_key_values

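This computes the KV cache for the context tokens. There is no gradient tracking and no activation storage for backward. Memory cost is essentially the KV cache, which is relatively small (a few GB for 10K tokens on Qwen3.5), plus per-layer activations that are freed as the pass proceeds. A plain model(...) call also materializes logits for every context position, which can add a few GB at a large vocabulary. On models whose forward accepts it, logits_to_keep=1 computes only the last position.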

Phase 2: Decision tokens forward (with gradients)

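with torch.enable_grad():
    outputs = model(decision_tokens, past_key_values=past_kv)
    # logits: [batch, seq, vocab]; target_tokens: [batch, seq], where
    # target_tokens[:, i] is the token that should follow position i
    loss = F.cross_entropy(outputs.logits.flatten(0, 1).float(), target_tokens.flatten())
    loss.backward()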

This computes the forward pass for ONLY the decision tokens (50-256), using the frozen KV cache as context. Gradients flow through these tokens and their activations, but NOT through the context.
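Here is a toy illustration of the two phases that needs only PyTorch. It uses a single hand-written causal attention layer, not the Qwen3.5 architecture. The context keys and values are computed under torch.no_grad() and then reused as a frozen cache. The script checks two things: the decision-token outputs match a full forward over context plus decision, and no gradient reaches the context inputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, VOCAB = 16, 32  # toy sizes, purely illustrative

# One toy causal self-attention layer plus an output head (not Qwen3.5).
wq = torch.nn.Linear(D, D, bias=False)
wk = torch.nn.Linear(D, D, bias=False)
wv = torch.nn.Linear(D, D, bias=False)
head = torch.nn.Linear(D, VOCAB, bias=False)


def attend(x, past_k=None, past_v=None):
    """Causal attention of x over [past ; x]. Returns (output, k, v) for x."""
    q, k, v = wq(x), wk(x), wv(x)
    k_all = k if past_k is None else torch.cat([past_k, k], dim=0)
    v_all = v if past_v is None else torch.cat([past_v, v], dim=0)
    n_past = k_all.shape[0] - x.shape[0]
    # New position i may see every cached position plus new positions <= i.
    mask = torch.ones(x.shape[0], k_all.shape[0]).tril(diagonal=n_past).bool()
    scores = (q @ k_all.T) / D ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v_all, k, v


context = torch.randn(10, D, requires_grad=True)  # stands in for context embeddings
decision = torch.randn(4, D)                      # decision-token embeddings
targets = torch.randint(0, VOCAB, (4,))           # hypothetical next-token targets

# Phase 1: context forward with no graph; ctx_k / ctx_v act as a frozen KV cache.
with torch.no_grad():
    _, ctx_k, ctx_v = attend(context)

# Phase 2: decision forward with gradients, attending to the frozen cache.
out, _, _ = attend(decision, ctx_k, ctx_v)
loss = F.cross_entropy(head(out), targets)
loss.backward()

# Reference: one full forward over context + decision, keeping the decision rows.
with torch.no_grad():
    full_out, _, _ = attend(torch.cat([context, decision]))

print(torch.allclose(out, full_out[10:], atol=1e-5))  # decision outputs match
print(context.grad is None)                          # nothing reached the context
```

The weights still receive gradients (wq, wk, wv, head), but only through the four decision rows. The ten context rows enter the graph as constants.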

Memory Analysis

For Qwen3.5-27B with 10K context and 100 decision tokens:

Without context freezing:

  • Activations for backward: ~10100 tokens × 64 layers × hidden state = hundreds of GB. Doesn't fit.

With context freezing:

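  • KV cache for context: ~10K tokens × 16 full-attention layers × (k_dim + v_dim), plus one fixed-size recurrent state per GDN layer that does not grow with context. The exact size depends on KV head count and precision. It is on the order of a few GB, consistent with the Phase 1 estimate
  • Activations for backward: ~100 tokens × 64 layers × hidden state = ~1-2GB with gradient checkpointing
  • Fits easily alongside vLLM. This budget covers cache and activations only. Gradients and optimizer state for the trained parameters are additional.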
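The cache term can be estimated the same way for a hybrid model. In this sketch, only the 16 full-attention layers store per-token keys and values. Each of the 48 GDN layers holds one fixed [HV, V, K] state. The head counts, head dimensions, and fp32 state precision are hypothetical placeholders, not Qwen3.5-27B's actual configuration.

```python
# Cache memory for a hybrid model: full-attention layers keep K and V for every
# token, GDN layers keep one fixed-size recurrent state. All head counts and
# dimensions are hypothetical placeholders, not Qwen3.5-27B's real config.

FULL_ATTN_LAYERS = 16
GDN_LAYERS = 48
KV_HEADS, HEAD_DIM = 4, 256             # hypothetical full-attention KV shape
GDN_HEADS, GDN_V, GDN_K = 48, 128, 128  # hypothetical [HV, V, K] state shape
KV_BYTES = 2                            # bf16 K/V cache
STATE_BYTES = 4                         # fp32 recurrent state (an assumption)


def full_attention_cache_bytes(num_tokens: int) -> int:
    # K and V: two [KV_HEADS, num_tokens, HEAD_DIM] tensors per layer.
    return FULL_ATTN_LAYERS * 2 * KV_HEADS * HEAD_DIM * num_tokens * KV_BYTES


def gdn_state_bytes() -> int:
    # One [HV, V, K] state per GDN layer, independent of context length.
    return GDN_LAYERS * GDN_HEADS * GDN_V * GDN_K * STATE_BYTES


for n in (10_000, 100_000):
    total = full_attention_cache_bytes(n) + gdn_state_bytes()
    print(f"{n:>7} tokens: {total / 1e9:.2f} GB")
```

With these placeholders, 10K tokens come to about 0.81GB and 100K tokens to about 6.70GB. The full-attention term grows linearly with context, while the GDN term stays constant.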

The Gradient Signal

An important subtlety: the gradient only flows through the decision tokens. This means:

  1. Only the decision tokens contribute to the loss. The context shapes the decision tokens' activations through the KV cache. In a mixture-of-experts model it can also change which experts those tokens are routed to. But the gradient magnitude comes only from the decision segment.

  2. The context implicitly shapes the gradient. Because the KV cache from the context is used during the decision token forward pass, the gradient for the decision tokens is context-dependent. The model learns "in this context, respond this way" — not just "always respond this way."

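  3. The gradient is structurally small. For each linear projection, the weight gradient is a sum of one outer product per decision token, so its rank is at most the segment length. In a dense model this makes the gradient low-rank rather than element-wise sparse, since every weight participates in every token's forward pass. A small, low-rank update limits weight perturbation, which helps with catastrophic forgetting. For HOGWILD, whose convergence argument assumes sparse updates, low rank is a weaker guarantee than true sparsity.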
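The structure of the gradient can be checked on a single linear layer y = W x. Its weight gradient is a sum over decision tokens of the outer product of the upstream gradient and the layer input, so its rank cannot exceed the number of tokens. Below is a pure-Python sketch with exact rational arithmetic and toy sizes, unrelated to any real model.

```python
import random
from fractions import Fraction


def rank(rows):
    """Rank of a matrix (list of rows) by exact Gauss-Jordan elimination."""
    m = [list(r) for r in rows]
    r = 0
    for c in range(len(m[0])):
        pivot = next((i for i in range(r, len(m)) if m[i][c] != 0), None)
        if pivot is None:
            continue
        m[r], m[pivot] = m[pivot], m[r]
        for i in range(len(m)):
            if i != r and m[i][c] != 0:
                f = m[i][c] / m[r][c]
                m[i] = [a - f * b for a, b in zip(m[i], m[r])]
        r += 1
        if r == len(m):
            break
    return r


random.seed(0)
OUT, IN, T = 12, 10, 3  # toy 12x10 weight matrix, 3 decision tokens


def rand_vec(n):
    return [Fraction(random.randint(-5, 5)) for _ in range(n)]


# dL/dW for y = W x, summed over tokens: sum_t g_t x_t^T (one outer product per token).
grad = [[Fraction(0)] * IN for _ in range(OUT)]
for _ in range(T):
    x_t = rand_vec(IN)   # layer input at one decision token
    g_t = rand_vec(OUT)  # upstream gradient dL/dy at that token
    grad = [[grad[i][j] + g_t[i] * x_t[j] for j in range(IN)] for i in range(OUT)]

print(f"{OUT}x{IN} gradient from {T} tokens has rank {rank(grad)}")
```

The same bound holds for a 100-token segment against, say, a 5120 × 5120 projection. The update touches every entry but lives in a subspace of dimension at most 100.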

Implementation with HuggingFace Models

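The HF Qwen3.5 model supports past_key_values and use_cache=True, so the implementation is short. Before relying on it, check that a multi-token forward with past_key_values really continues the GDN recurrent state from the cache. To do this, compare the Phase 2 logits with the matching positions of one full forward under torch.no_grad(). Not every hybrid or state-space implementation supports continuing a multi-token pass from a cache; some only use the cached state for single-token decoding.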

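import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes the checkpoint loads with AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-27B", torch_dtype=torch.bfloat16, device_map="cuda"
)
# Encoding the two pieces separately can tokenize the boundary differently
# than encoding the whole conversation; split one token list if that matters.
context_ids = tokenizer.encode(context_text, add_special_tokens=False)
decision_ids = tokenizer.encode(decision_text, add_special_tokens=False)

# Phase 1: every context token except the last (no grad)
with torch.no_grad():
    ctx_output = model(
        torch.tensor([context_ids[:-1]], device="cuda"),
        use_cache=True
    )
    past_kv = ctx_output.past_key_values

# Phase 2: last context token + decision tokens (with grad). The last context
# token's logits predict the first decision token, so every decision token trains.
phase2_input = torch.tensor([context_ids[-1:] + decision_ids], device="cuda")
with torch.enable_grad():
    # HF cache objects (e.g. DynamicCache) are generally updated in place by this
    # call; recompute Phase 1 or copy.deepcopy(past_kv) before reusing the cache.
    output = model(phase2_input, past_key_values=past_kv, use_cache=False)
    # Shift logits and labels for next-token prediction
    logits = output.logits[:, :-1].float()
    labels = phase2_input[:, 1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()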

GDN Layer Considerations

For the 48 GDN (linear attention) layers, the "KV cache" is actually the recurrent state — a fixed-size [HV, V, K] tensor per layer. This doesn't grow with context length. The GDN layers are inherently efficient for context-frozen training because:

  1. The recurrent state after processing the context tokens encodes the full context in a fixed-size matrix
  2. During the decision token phase, the GDN layers update this state and produce output in O(1) time and memory per token, independent of context length
  3. No attention over the full context is needed
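The fixed-size-state property can be illustrated with a simplified, single-head gated delta rule in pure Python. This is an assumption-laden toy, not Qwen3.5's GDN kernel. It omits the short convolution, output gating, and multi-head layout, and it draws the decay alpha and write strength beta at random instead of computing them from the input. Folding 1,000 context tokens into the V x K state and then continuing over the decision tokens gives exactly the outputs of one pass over the whole sequence.

```python
import math
import random


def gated_delta_step(S, q, k, v, alpha, beta):
    """One step of a simplified single-head gated delta rule.

    S is a V x K state (one head of the [HV, V, K] tensor):
        S <- alpha * S (I - beta k k^T) + beta v k^T,   o = S q
    """
    Sk = [sum(S[i][j] * k[j] for j in range(len(k))) for i in range(len(S))]
    S = [[alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j]
          for j in range(len(k))] for i in range(len(S))]
    o = [sum(S[i][j] * q[j] for j in range(len(q))) for i in range(len(S))]
    return S, o


def run(S, tokens):
    outputs = []
    for q, k, v, alpha, beta in tokens:
        S, o = gated_delta_step(S, q, k, v, alpha, beta)
        outputs.append(o)
    return S, outputs


random.seed(0)
V_DIM, K_DIM = 3, 4


def rand_token():
    q = [random.uniform(-1, 1) for _ in range(K_DIM)]
    k = [random.uniform(-1, 1) for _ in range(K_DIM)]
    norm = math.sqrt(sum(x * x for x in k))
    k = [x / norm for x in k]  # unit-norm keys keep the update stable
    v = [random.uniform(-1, 1) for _ in range(V_DIM)]
    return q, k, v, random.uniform(0.8, 1.0), random.uniform(0.1, 0.5)


context = [rand_token() for _ in range(1000)]
decision = [rand_token() for _ in range(5)]
zero = [[0.0] * K_DIM for _ in range(V_DIM)]

# Phase 1: fold the whole context into one fixed-size V x K state.
ctx_state, _ = run(zero, context)
# Phase 2: continue from that state over the decision tokens only.
_, dec_out = run(ctx_state, decision)
# Reference: the full sequence in a single pass.
_, full_out = run(zero, context + decision)
print(dec_out == full_out[-5:])
```

The state carried from Phase 1 to Phase 2 is 3 × 4 numbers whether the context has 10 tokens or 10,000, which is why the GDN layers add no context-length term to the cache.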

For the 16 full attention layers, the KV cache DOES grow with context. These are the memory bottleneck for very long contexts.

Connection to the Fused Inference/Training Design

The fused design (Mar 27) proposed that the KV cache from inference could be reused for training. With CUDA IPC weight sharing, this is technically possible:

  1. vLLM processes a conversation, building KV cache
  2. The training process imports the KV cache (or a copy of the recurrent states) via IPC
  3. Training runs the decision token phase using the imported cache
  4. No recomputation of the context forward pass

However, this requires vLLM to export its KV cache / recurrent states, which is more complex than exporting weight handles. For now, the simpler approach is to recompute the context forward pass in the training process (Phase 1 above). The cost is moderate: a forward pass without gradient tracking stores no activations for backward and needs no backward pass.
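The recomputation cost can be estimated with the common rule of thumb: about 2N FLOPs per token for a forward pass and 4N for backward, where N is the parameter count. This sketch ignores attention-over-context FLOPs and kernel efficiency, so treat it as an order-of-magnitude comparison only.

```python
# Rule-of-thumb compute: ~2N FLOPs per token forward, ~4N per token backward
# (N = parameter count). Attention-over-context FLOPs are ignored.

N_PARAMS = 27e9                 # parameter count from this note
CONTEXT, DECISION = 10_000, 100


def forward_flops(tokens):
    return 2 * N_PARAMS * tokens


def backward_flops(tokens):
    return 4 * N_PARAMS * tokens


full_backprop = forward_flops(CONTEXT + DECISION) + backward_flops(CONTEXT + DECISION)
frozen = forward_flops(CONTEXT) + forward_flops(DECISION) + backward_flops(DECISION)
print(f"full backprop:  {full_backprop:.2e} FLOPs")
print(f"context-frozen: {frozen:.2e} FLOPs ({frozen / full_backprop:.0%} of full)")
print(f"context forward share of frozen cost: {forward_flops(CONTEXT) / frozen:.0%}")
```

Under this approximation, the context-frozen step costs about 34% of backpropagating through all 10,100 tokens. About 97% of that is the recomputed context forward, which is the part that importing vLLM's cache would save.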

The Anthropic Method Connection

The Anthropic safety fine-tuning approach (generate behavior with instructions, train without instructions) maps directly:

  1. With instructions: The context includes surfaced memories and core-personality guidance. The model produces good behavior.
  2. Strip instructions: Remove the surfaced memories from the context. The decision tokens (the good response) remain.
  3. Train: Forward pass through the stripped context (frozen), then loss on the decision tokens. The model learns to produce the good behavior without the instruction scaffolding.

The context-frozen approach makes this efficient: the stripped context is processed once (no grad), and only the decision tokens contribute to the gradient.
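The data side of this recipe can be sketched as a small function that drops the scaffolding from the context and keeps the good response as the target. The Segment type, its is_scaffolding flag, and the <role> rendering are hypothetical. A real pipeline would more likely render the stripped context with the model's chat template (for example tokenizer.apply_chat_template) before Phase 1.

```python
from dataclasses import dataclass


@dataclass
class Segment:
    role: str                     # e.g. "system", "user", "assistant" (illustrative)
    text: str
    is_scaffolding: bool = False  # hypothetical flag: surfaced memory / guidance


def build_training_example(segments, decision_index):
    """Return (stripped_context, decision_text) for context-frozen training.

    Scaffolding segments before the decision are dropped from the context;
    the decision segment (the good response) is kept verbatim as the target.
    """
    kept = [s for s in segments[:decision_index] if not s.is_scaffolding]
    stripped = "".join(f"<{s.role}>{s.text}\n" for s in kept)
    return stripped, segments[decision_index].text


segments = [
    Segment("system", "You are a helpful assistant."),
    Segment("system", "Memory: user prefers you listen first.", is_scaffolding=True),
    Segment("user", "I've decided to use Postgres."),
    Segment("assistant", "Sounds good. Want help setting up the schema?"),
]
ctx, target = build_training_example(segments, decision_index=3)
print(ctx)
print(target)
```

The response was generated while the memory was visible, but the training pair never shows it. Phase 1 sees only the stripped context, and the loss is applied to the unchanged response.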