research: latent reasoning integration plans for Qwen 3.5 27B

Two research documents:

latent-reasoning-integration-plan.md: Synthesizes 10+ papers on
latent reasoning, identifies which approaches work with finetuning
(vs requiring pretraining from scratch), and maps them to our
APOLLO-Mini training pipeline.

pause-tokens-gdn-recurrence.md: Explores the connection between
token-based latent reasoning and GDN's internal recurrence. Key
insight: pause tokens on Qwen 3.5 trigger both forward passes AND
recurrent state updates, giving double benefit.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-12 15:50:09 -04:00
parent dcd647764c
commit f06c8077e1
2 changed files with 588 additions and 0 deletions

# Pause Tokens + GDN Recurrence: Latent Reasoning for Qwen 3.5
**Status:** Ready for testing
**Date:** 2026-04-12
**Insight:** Qwen 3.5's GDN layers already have recurrence - pause tokens give it more iterations
---
## The Core Insight
Standard transformers couple compute depth to output length: the model gets one forward pass per emitted token. Both pause tokens and internal recurrence break this coupling by allowing "thinking" without committing to output tokens.
**The GDN connection:** Qwen 3.5 is 75% GDN (Gated DeltaNet) layers. Each GDN layer maintains recurrent state:
```
S_t = exp(g_t) * S_{t-1} + outer(k_t, delta_t)
```
This state persists across token positions. When you add a pause token:
1. One more forward pass through all layers (standard)
2. One more update to recurrent state S (GDN-specific)
Pause tokens on Qwen 3.5 trigger **both** forms of additional computation. We're not adding recurrence - we're giving existing recurrence more time to develop.
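This can be illustrated with a toy simulation of the recurrence above (plain PyTorch, not Qwen's actual GDN kernel; shapes and gating are simplified). Prepending two "pause" positions gives the state two extra update steps before the text tokens are processed:

```python
import torch

def gdn_state_updates(k, delta, g):
    """Toy simulation of S_t = exp(g_t) * S_{t-1} + outer(k_t, delta_t).

    k, delta: (seq_len, d) per-position key and delta vectors
    g: (seq_len,) gate log-values (negative -> decay)
    Returns the final state S after seq_len updates.
    """
    d = k.shape[-1]
    S = torch.zeros(d, d)
    for k_t, d_t, g_t in zip(k, delta, g):
        S = torch.exp(g_t) * S + torch.outer(k_t, d_t)
    return S

torch.manual_seed(0)
d, n_text = 8, 5
k = torch.randn(n_text, d)
delta = torch.randn(n_text, d)
g = -torch.rand(n_text)  # decay gates

S_base = gdn_state_updates(k, delta, g)

# Two pause positions = two extra recurrent state updates before the
# text tokens arrive; the state the text tokens see is different.
n_pause = 2
S_pause = gdn_state_updates(
    torch.cat([torch.randn(n_pause, d), k]),
    torch.cat([torch.randn(n_pause, d), delta]),
    torch.cat([-torch.rand(n_pause), g]),
)
assert not torch.allclose(S_base, S_pause)
```

The point is only that each prefix position contributes one more state update; the real GDN layers in Qwen 3.5 do this per layer, per head.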
---
## Minimal Test: Random Prefix (Zero Training)
The dl1683 work showed that random prefix embeddings improve accuracy at inference time, with no training:
- Qwen3-4B arithmetic: 32% → 51.6% (+19.6pp)
- 100% oracle coverage on planning tasks
### Test Script
```python
#!/usr/bin/env python3
"""Test pause tokens on Qwen 3.5 27B.

Usage:
    source ~/training-env/bin/activate
    python3 test_pause_tokens.py
"""
import re
import sys

import torch
from transformers import AutoTokenizer

# Reuse our weight loading infrastructure
sys.path.insert(0, '.')
from extract_steering_vector import load_model

GSM8K_SAMPLES = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
    # Add more samples...
]


def get_embedding_rms(model):
    """Get RMS of embedding weights for proper scaling."""
    embed = model.model.embed_tokens.weight
    return embed.float().square().mean().sqrt().item()


def make_random_prefix(n_tokens, embed_dim, rms, device):
    """Generate random prefix embeddings at embedding scale."""
    prefix = torch.randn(1, n_tokens, embed_dim, device=device, dtype=torch.bfloat16)
    return prefix * rms


def generate_with_pause(model, tokenizer, prompt, n_pause=0, max_new=256):
    """Generate with an optional random pause-token prefix."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
    text_embeds = model.model.embed_tokens(input_ids)
    if n_pause > 0:
        embed_rms = get_embedding_rms(model)
        pause_embeds = make_random_prefix(
            n_pause, text_embeds.shape[-1], embed_rms, text_embeds.device)
        combined = torch.cat([pause_embeds, text_embeds], dim=1)
    else:
        combined = text_embeds
    # Generate from embeddings
    with torch.no_grad():
        outputs = model.generate(
            inputs_embeds=combined,
            max_new_tokens=max_new,
            do_sample=False,  # Greedy for reproducibility
            pad_token_id=tokenizer.pad_token_id,
        )
    # With inputs_embeds, generate() returns only the newly generated
    # tokens, so no pause positions need stripping before decoding.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def extract_answer(response):
    """Extract the last numeric value in the response."""
    numbers = re.findall(r'[\d,]+\.?\d*', response)
    if numbers:
        return numbers[-1].replace(',', '')
    return None


def main():
    print("Loading model...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    print(f"\nEmbedding RMS: {get_embedding_rms(model):.4f}")
    for n_pause in [0, 2, 4]:
        print(f"\n=== Testing with {n_pause} pause tokens ===")
        correct = 0
        for i, problem in enumerate(GSM8K_SAMPLES):
            prompt = f"Solve this step by step:\n{problem}\n\nAnswer:"
            response = generate_with_pause(model, tokenizer, prompt, n_pause=n_pause)
            answer = extract_answer(response)
            print(f"  Problem {i+1}: {answer}")
            # TODO: Compare against ground truth
        print(f"  Accuracy: {correct}/{len(GSM8K_SAMPLES)}")


if __name__ == '__main__':
    main()
```
### Test Protocol
1. Pick 20-50 GSM8K problems with known answers
2. Run baseline (n_pause=0)
3. Run with 2 pause tokens
4. Run with 4 pause tokens
5. Compare accuracy
If pause tokens help at inference time with zero training, that suggests the GDN recurrence is leveraging the extra iterations.
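Filling in the ground-truth comparison left as a TODO in the script, a minimal scoring sketch (`GSM8K_ANSWERS` is a hypothetical gold-answer list aligned with `GSM8K_SAMPLES`; only the first two answers are shown):

```python
# Hypothetical gold answers aligned with GSM8K_SAMPLES.
GSM8K_ANSWERS = ["18", "3"]

def is_correct(predicted, gold, tol=1e-6):
    """Compare extracted answer strings numerically, tolerating formatting."""
    if predicted is None:
        return False
    try:
        return abs(float(predicted) - float(gold)) < tol
    except ValueError:
        return predicted.strip() == gold.strip()

def score(predictions, golds):
    """Fraction of problems answered correctly."""
    hits = sum(is_correct(p, g) for p, g in zip(predictions, golds))
    return hits / len(golds)

print(score(["18.0", "5"], GSM8K_ANSWERS))  # 0.5
```

Numeric comparison (rather than string equality) keeps "18", "18.0", and "18.00" from counting as different answers.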
---
## Learnable Pause Tokens (Training Phase)
After validating random prefix works, train dedicated pause tokens:
```python
import torch
import torch.nn as nn

# Add to model (embed_rms comes from get_embedding_rms above)
model.pause_tokens = nn.Parameter(
    torch.randn(4, model.config.hidden_size) * embed_rms
)

# Training forward pass
def forward_with_learned_pause(model, input_ids):
    text_embeds = model.model.embed_tokens(input_ids)
    pause = model.pause_tokens.unsqueeze(0).expand(text_embeds.shape[0], -1, -1)
    combined = torch.cat([pause, text_embeds], dim=1)
    return model(inputs_embeds=combined)
```
Key: the model must be trained WITH pause tokens for learned tokens to help. Learned tokens injected only at inference don't work (per Google's pause token paper).
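The pause prefix shifts every position, so the training loss must pad labels with `-100` (PyTorch's `ignore_index`) before the usual next-token shift. A minimal sketch of that alignment, assuming `pause_causal_lm_loss` as a hypothetical helper name, checked here with random logits rather than a real model:

```python
import torch
import torch.nn.functional as F

def pause_causal_lm_loss(logits, labels, n_pause):
    """Cross-entropy for a sequence with an n_pause-token pause prefix.

    The prefix adds positions at the front of the logits, so labels are
    padded with -100 to stay aligned; no loss is taken on "thinking"
    slots. Then the standard next-token shift is applied.
    """
    ignore = torch.full((labels.shape[0], n_pause), -100,
                        dtype=labels.dtype, device=labels.device)
    padded = torch.cat([ignore, labels], dim=1)
    shift_logits = logits[:, :-1, :]
    shift_labels = padded[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.shape[-1]),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )

# Toy check: loss is finite and pause positions contribute nothing.
vocab, n_pause, seq = 11, 4, 6
logits = torch.randn(2, n_pause + seq, vocab)
labels = torch.randint(0, vocab, (2, seq))
loss = pause_causal_lm_loss(logits, labels, n_pause)
```

In the real pipeline the logits would come from `forward_with_learned_pause` above; the alignment logic is the same.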
---
## Adaptive Halting via Confidence Readout
For variable-length pause (iterate until confident):
### Extract Confidence Direction
```python
confident = [
    "The answer is 42.",
    "This will work because the invariant holds.",
    "Use mmap here.",
]
uncertain = [
    "I think the answer might be 42?",
    "This should work, but I'm not sure...",
    "Maybe mmap? Or read()?",
]

# Same infrastructure as the listening vector
confident_states = get_hidden_states(model, confident, layer=48)
uncertain_states = get_hidden_states(model, uncertain, layer=48)
confidence_vec = confident_states.mean(0) - uncertain_states.mean(0)
```
### Adaptive Loop
```python
def generate_adaptive_pause(model, tokenizer, prompt, max_pause=8, threshold=0.7):
    confidence_vec = torch.load('confidence_direction.pt')
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
    h = model.model.embed_tokens(input_ids)
    embed_rms = get_embedding_rms(model)
    for i in range(max_pause):
        # Add one pause token
        pause = make_random_prefix(1, h.shape[-1], embed_rms, h.device)
        h = torch.cat([pause, h], dim=1)
        # Forward to get hidden states
        with torch.no_grad():
            out = model(inputs_embeds=h, output_hidden_states=True)
        # Check confidence at layer 48 (cast to float to match the stored vector)
        hidden = out.hidden_states[48][0, -1, :].float()
        confidence = torch.cosine_similarity(
            hidden.unsqueeze(0),
            confidence_vec.unsqueeze(0),
        ).item()
        if confidence > threshold:
            break
    # Generate from accumulated state
    return model.generate(inputs_embeds=h, max_new_tokens=256)
```
---
## Connection to Huginn/Looping Architectures
Huginn uses explicit weight-tied loops (same 4 layers run N times). We can't retrofit this to Qwen 3.5 without retraining.
But GDN recurrence + pause tokens achieves similar effect:
- Huginn: explicit iteration over layers
- GDN + pause: implicit iteration via recurrent state S
The GDN state accumulates across pause positions, effectively giving the model multiple "thinking steps" before output.
### Comparison
| Approach | Requires Pretraining | Compute Cost | Qwen 3.5 Compatible |
|----------|---------------------|--------------|---------------------|
| Huginn loops | Yes | N × core layers | No |
| Pause tokens | No (inference test) | N × all layers | Yes |
| GDN recurrence | Already there | Per-token | Already there |
| Pause + GDN | No | N × all layers + N state updates | Yes |
---
## COCONUT Integration (Future)
COCONUT feeds hidden state back as input embedding - explicit whole-model recurrence:
```python
def coconut_forward(model, input_ids, n_latent=3):
    h = model.model.embed_tokens(input_ids)
    for step in range(n_latent):
        out = model(inputs_embeds=h, output_hidden_states=True)
        # Project the final hidden state back to embedding space
        # (project_hidden_to_embed is the learned projection layer)
        h = model.project_hidden_to_embed(out.hidden_states[-1])
    # Final forward produces tokens
    return model.generate(inputs_embeds=h)
```
This gives two levels of iteration:
1. GDN recurrence within each forward pass (automatic)
2. Hidden → embed looping across forward passes (COCONUT)
Requires training the projection layer. Curriculum: start with 0 latent steps, gradually increase.
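That curriculum could be sketched as a simple step schedule (warmup and ramp lengths are hypothetical, not tuned values):

```python
def latent_steps_schedule(step, warmup=1000, ramp=5000, max_latent=3):
    """Curriculum for COCONUT latent steps.

    0 latent steps during warmup (plain supervised training), then ramp
    up one latent step at a time until max_latent is reached.
    """
    if step < warmup:
        return 0
    progress = (step - warmup) / ramp
    return min(max_latent, int(progress * max_latent) + 1)

# e.g. 0 latent steps early, max_latent once the ramp completes
assert latent_steps_schedule(0) == 0
assert latent_steps_schedule(10_000) == 3
```

The training loop would call `coconut_forward(model, input_ids, n_latent=latent_steps_schedule(step))` each step.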
---
## Implementation Priority
1. **Now:** Run random prefix test (zero training, 1 hour)
2. **If works:** Extract confidence direction for adaptive halting
3. **Training phase:** Learn pause tokens + UPFT (75% time savings)
4. **Later:** COCONUT curriculum for explicit hidden state looping
---
## Open Questions
1. Does random prefix scale to 27B? (Tested on 4B)
2. Optimal pause count for Qwen 3.5?
3. Does GDN respond more strongly than pure attention? (Testable)
4. Can we read confidence from GDN state S directly, not just hidden state h?
---
## References
- Random Prefix: https://github.com/dl1683/Latent-Space-Reasoning
- Pause Tokens: Google, "Think before you speak" (Oct 2023)
- COCONUT: Meta, "Training LLMs to Reason in Continuous Latent Space" (Dec 2024)
- Huginn: Geiping et al., "Scaling Test-Time Compute with Latent Reasoning" (Feb 2025)
- GDN Architecture: Our qwen35-gdn-implementation-findings-mar28 memory