consciousness/training/research/pause-tokens-gdn-recurrence.md
Kent Overstreet f06c8077e1 research: latent reasoning integration plans for Qwen 3.5 27B
Two research documents:

latent-reasoning-integration-plan.md: Synthesizes 10+ papers on
latent reasoning, identifies which approaches work with finetuning
(vs requiring pretraining from scratch), and maps them to our
APOLLO-Mini training pipeline.

pause-tokens-gdn-recurrence.md: Explores the connection between
token-based latent reasoning and GDN's internal recurrence. Key
insight: pause tokens on Qwen 3.5 trigger both forward passes AND
recurrent state updates, giving double benefit.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-12 15:50:09 -04:00


Pause Tokens + GDN Recurrence: Latent Reasoning for Qwen 3.5

Status: Ready for testing
Date: 2026-04-12
Insight: Qwen 3.5's GDN layers already have recurrence; pause tokens give that recurrence more iterations.


The Core Insight

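Standard transformers couple compute to output length: every token gets the same fixed-depth pass, so the only way to spend more computation before answering is to emit more tokens. Pause tokens and internal recurrence both break this coupling, allowing "thinking" without committing to output tokens.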

The GDN connection: Qwen 3.5 is 75% GDN (Gated DeltaNet) layers. Each GDN layer maintains recurrent state:

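S_t = exp(g_t) * S_{t-1} + outer(k_t, delta_t)
delta_t = beta_t * (v_t - (exp(g_t) * S_{t-1})^T k_t)

where k_t and v_t are the key and value at position t, g_t <= 0 is a per-token log-decay gate, beta_t is a per-token write strength, and delta_t is the delta-rule correction: the new value minus what the decayed state already returns for k_t.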

This state persists across token positions. When you add a pause token:

  1. One more position processed by every layer (standard for any transformer)
  2. One more update to the recurrent state S in every GDN layer (GDN-specific)

Pause tokens on Qwen 3.5 trigger both forms of additional computation. We're not adding recurrence; we're giving the existing recurrence more steps to develop.
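A single-head NumPy sketch of the recurrence above, for illustration only. It assumes a (d_k, d_v) state layout and treats g and beta as given scalars; in the real layer they are computed per token from the input, and this is not the model's actual kernel. The names `gdn_step`, `read`, and `scan` are hypothetical. What it shows is the counting argument: every position, pause or text, costs exactly one `gdn_step` per head in each GDN layer, and the resulting state is what later positions read.

```python
import numpy as np

def gdn_step(S, k, v, g, beta):
    """One Gated DeltaNet state update for a single head (illustrative).

    S: state, shape (d_k, d_v); k: key, shape (d_k,), assumed L2-normalized;
    v: value, shape (d_v,); g: log-decay gate, g <= 0; beta: write strength in [0, 1].
    """
    decayed = np.exp(g) * S              # exp(g_t) * S_{t-1}
    delta = beta * (v - decayed.T @ k)   # value minus what the decayed state returns for k
    return decayed + np.outer(k, delta)  # S_t = exp(g_t) * S_{t-1} + outer(k_t, delta_t)

def read(S, q):
    """Output for query q: S^T q."""
    return S.T @ q

def scan(positions, d_k, d_v):
    """Apply gdn_step once per position; positions is a list of (k, v, g, beta)."""
    S = np.zeros((d_k, d_v))
    n_updates = 0
    for k, v, g, beta in positions:
        S = gdn_step(S, k, v, g, beta)
        n_updates += 1
    return S, n_updates

rng = np.random.default_rng(0)
k = rng.normal(size=4)
k /= np.linalg.norm(k)
v = rng.normal(size=3)
S = gdn_step(np.zeros((4, 3)), k, v, g=0.0, beta=1.0)
print(np.allclose(read(S, k), v))  # True
```

With g = 0 and beta = 1, one write followed by a read with the same key returns v exactly, which is the delta rule's defining property; with beta < 1 each update moves the stored value only part of the way toward v.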


Minimal Test: Random Prefix (Zero Training)

The dl1683 Latent-Space-Reasoning project reports that random prefix embeddings help at inference time, with no training:

  • Qwen3-4B arithmetic: 32% → 51.6% (+19.6pp)
  • 100% oracle coverage on planning tasks

Test Script

#!/usr/bin/env python3
"""Test pause tokens on Qwen 3.5 27B.

Usage:
    source ~/training-env/bin/activate
    python3 test_pause_tokens.py
"""

import re
import sys

import torch
from transformers import AutoTokenizer

# Reuse our weight loading infrastructure
sys.path.insert(0, '.')
from extract_steering_vector import load_model

# (problem, ground-truth answer) pairs
GSM8K_SAMPLES = [
    ("Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", "18"),
    ("A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?", "3"),
    # Add more samples...
]

@torch.no_grad()
def get_embedding_rms(model):
    """Get RMS of embedding weights for proper scaling.

    Accumulates in float32 chunk by chunk instead of materializing a full
    float32 copy of the (large) embedding table.
    """
    embed = model.model.embed_tokens.weight
    sq_sum = sum(chunk.float().square().sum() for chunk in embed.split(4096))
    return (sq_sum / embed.numel()).sqrt().item()

def make_random_prefix(n_tokens, embed_dim, rms, device, dtype=torch.bfloat16):
    """Generate random prefix embeddings at embedding scale.

    torch.randn has unit variance per element, so multiplying by rms gives
    vectors with roughly the embedding table's RMS.
    """
    prefix = torch.randn(1, n_tokens, embed_dim, device=device, dtype=dtype)
    return prefix * rms

def generate_with_pause(model, tokenizer, prompt, n_pause=0, max_new=256):
    """Generate with an optional random pause-token prefix."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
    with torch.no_grad():
        text_embeds = model.model.embed_tokens(input_ids)

    if n_pause > 0:
        embed_rms = get_embedding_rms(model)
        pause_embeds = make_random_prefix(n_pause, text_embeds.shape[-1], embed_rms,
                                          text_embeds.device, text_embeds.dtype)
        combined = torch.cat([pause_embeds, text_embeds], dim=1)
    else:
        combined = text_embeds

    # All positions, pause and prompt, are real input
    attention_mask = torch.ones(combined.shape[:2], dtype=torch.long, device=combined.device)

    # Generate from embeddings
    with torch.no_grad():
        outputs = model.generate(
            inputs_embeds=combined,
            attention_mask=attention_mask,
            max_new_tokens=max_new,
            do_sample=False,  # Greedy for reproducibility
            pad_token_id=tokenizer.pad_token_id,
        )

    # Given inputs_embeds without input_ids, generate() returns only the new
    # tokens, so neither the pause positions nor the prompt appear in outputs
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def extract_answer(response):
    """Extract the last number in the response as the answer."""
    # Require a leading digit so a bare comma never matches, and accept a
    # decimal point only when digits follow it (so "$18." yields "18")
    numbers = re.findall(r'-?\d[\d,]*(?:\.\d+)?', response)
    if numbers:
        return numbers[-1].replace(',', '')
    return None

def is_correct(answer, gold):
    """Compare numerically, so "18.0" matches "18"."""
    return answer is not None and float(answer) == float(gold)

def main():
    print("Loading model...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    torch.manual_seed(0)  # Same random pause embeddings on every run

    print(f"\nEmbedding RMS: {get_embedding_rms(model):.4f}")

    for n_pause in [0, 2, 4]:
        print(f"\n=== Testing with {n_pause} pause tokens ===")
        correct = 0

        for i, (problem, gold) in enumerate(GSM8K_SAMPLES):
            prompt = f"Solve this step by step:\n{problem}\n\nAnswer:"
            response = generate_with_pause(model, tokenizer, prompt, n_pause=n_pause)
            answer = extract_answer(response)
            ok = is_correct(answer, gold)
            correct += ok

            print(f"  Problem {i+1}: {answer} (expected {gold}) {'OK' if ok else 'WRONG'}")

        print(f"  Accuracy: {correct}/{len(GSM8K_SAMPLES)}")

if __name__ == '__main__':
    main()

Test Protocol

  1. Pick 20-50 GSM8K problems with known answers
  2. Run baseline (n_pause=0)
  3. Run with 2 pause tokens
  4. Run with 4 pause tokens
  5. Compare accuracy

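If pause tokens help at inference time with zero training, that is consistent with the GDN recurrence leveraging the extra iterations, but it does not prove it: the Qwen3-4B result above came from a model without GDN layers, so attention alone can account for some of the gain (see Open Question 3).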
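Two helpers for the protocol, as a sketch. `load_gsm8k` assumes the Hugging Face `datasets` package and the `openai/gsm8k` dataset (config `main`), whose `answer` field ends with `#### <number>`. `mcnemar_exact` is a standard exact McNemar test on paired per-problem results; it is a suggestion, not part of the protocol above. With 20-50 problems, each problem moves accuracy by 2-5 points, so a paired test on the problems where two runs disagree is a cheap check that a difference is more than noise. All function names are hypothetical.

```python
from math import comb

def gold_answer(solution):
    """GSM8K solutions end with '#### <answer>'; return that answer without commas."""
    return solution.split('####')[-1].strip().replace(',', '')

def load_gsm8k(n=50, seed=0):
    """Return n (question, answer) pairs sampled from the GSM8K test split.

    Assumes the `datasets` package is installed and the Hub is reachable.
    """
    from datasets import load_dataset
    ds = load_dataset("openai/gsm8k", "main", split="test")
    ds = ds.shuffle(seed=seed).select(range(n))
    return [(row["question"], gold_answer(row["answer"])) for row in ds]

def mcnemar_exact(baseline_ok, pause_ok):
    """Exact two-sided McNemar test on paired per-problem correctness.

    b = problems fixed by pauses, c = problems broken by pauses. Under the
    null hypothesis of no effect, b ~ Binomial(b + c, 0.5).
    Returns (b, c, p_value).
    """
    b = sum(1 for x, y in zip(baseline_ok, pause_ok) if not x and y)
    c = sum(1 for x, y in zip(baseline_ok, pause_ok) if x and not y)
    n = b + c
    if n == 0:
        return b, c, 1.0
    tail = sum(comb(n, i) for i in range(min(b, c) + 1)) / 2 ** n
    return b, c, min(1.0, 2 * tail)

# One bool per problem, same problem order in both runs
print(mcnemar_exact([True, False, False, True], [True, True, True, True]))  # (2, 0, 0.5)
```

Problems that both runs get right (or both get wrong) carry no information about the difference, which is why only b and c enter the test.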


Learnable Pause Tokens (Training Phase)

After validating random prefix works, train dedicated pause tokens:

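import torch
import torch.nn as nn

# Add to model: 4 learnable pause embeddings at embedding scale. Assigning
# an nn.Parameter to a module attribute registers it as a model parameter.
embed = model.model.embed_tokens.weight
model.pause_tokens = nn.Parameter(
    torch.randn(4, embed.shape[1], device=embed.device, dtype=embed.dtype) * get_embedding_rms(model)
)

# Training forward pass
def forward_with_learned_pause(model, input_ids, labels=None, attention_mask=None):
    text_embeds = model.model.embed_tokens(input_ids)
    batch, n_pause = text_embeds.shape[0], model.pause_tokens.shape[0]
    pause = model.pause_tokens.unsqueeze(0).expand(batch, -1, -1)
    combined = torch.cat([pause, text_embeds], dim=1)
    # Pad labels/mask for the pause positions; -100 is ignored by the LM loss
    if labels is not None:
        labels = torch.cat([labels.new_full((batch, n_pause), -100), labels], dim=1)
    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask.new_ones((batch, n_pause)), attention_mask], dim=1)
    return model(inputs_embeds=combined, labels=labels, attention_mask=attention_mask)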

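Key: the model must be trained WITH pause tokens for them to help. Google's pause token paper found that adding pause tokens only at inference time, to a model never trained with them, gives little or no benefit; its clearest gains came from using pauses in both pretraining and finetuning.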


Adaptive Halting via Confidence Readout

For variable-length pause (iterate until confident):

Extract Confidence Direction

confident = [
    "The answer is 42.",
    "This will work because the invariant holds.",
    "Use mmap here.",
]
uncertain = [
    "I think the answer might be 42?",
    "This should work, but I'm not sure...",
    "Maybe mmap? Or read()?",
]

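# Same infrastructure as the listening vector: get_hidden_states is assumed to
# return one hidden-state row per text, shape [n_texts, hidden_size]
confident_states = get_hidden_states(model, confident, layer=48)
uncertain_states = get_hidden_states(model, uncertain, layer=48)
confidence_vec = confident_states.mean(0) - uncertain_states.mean(0)

# Saved for generate_adaptive_pause, which loads it from this file
torch.save(confidence_vec, 'confidence_direction.pt')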
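The adaptive loop below compares a raw cosine similarity against a fixed 0.7. Hidden states in large transformers often share large components across inputs, which the difference of means cancels but a raw cosine does not, so a fixed cosine threshold may be hard to reach or calibrate. One common alternative, sketched here in NumPy as an option rather than the plan, is to center on the midpoint between the class means and project onto the unit direction, so a score above 0 means "closer to the confident examples". `fit_confidence_probe` and `confidence_score` are hypothetical names; torch states would be passed in via `.float().cpu().numpy()`.

```python
import numpy as np

def fit_confidence_probe(confident_states, uncertain_states):
    """Difference-of-means probe from [n_examples, hidden] arrays.

    Returns the unit direction from uncertain to confident and the midpoint
    between the two class means.
    """
    mu_conf = confident_states.mean(axis=0)
    mu_unc = uncertain_states.mean(axis=0)
    direction = mu_conf - mu_unc
    unit = direction / np.linalg.norm(direction)
    midpoint = (mu_conf + mu_unc) / 2
    return unit, midpoint

def confidence_score(hidden, unit, midpoint):
    """Signed distance along the probe: > 0 leans confident, < 0 leans uncertain."""
    return float((hidden - midpoint) @ unit)

# Synthetic check: a large component shared by both classes plus a small offset
rng = np.random.default_rng(0)
shared = rng.normal(size=64) * 10
conf = shared + 1.0 + rng.normal(size=(8, 64)) * 0.1
unc = shared - 1.0 + rng.normal(size=(8, 64)) * 0.1
unit, mid = fit_confidence_probe(conf, unc)
print(confidence_score(conf[0], unit, mid) > 0, confidence_score(unc[0], unit, mid) < 0)  # True True
```

With this probe the halting threshold becomes a signed score (0, or a margin picked on held-out prompts) instead of a cosine value.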

Adaptive Loop

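def generate_adaptive_pause(model, tokenizer, prompt, max_pause=8, threshold=0.7, layer=48):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
    with torch.no_grad():
        h = model.model.embed_tokens(input_ids)
    embed_rms = get_embedding_rms(model)

    # Saved by the extraction step; move to the hidden state's device as float32
    confidence_vec = torch.load('confidence_direction.pt').to(h.device, torch.float32)

    for i in range(max_pause):
        # Prepend one more random pause embedding
        pause = make_random_prefix(1, h.shape[-1], embed_rms, h.device)
        h = torch.cat([pause, h], dim=1)

        # Full forward pass over pauses + prompt. No cache is reused, so each
        # iteration recomputes the whole (one-longer) sequence from scratch.
        with torch.no_grad():
            out = model(inputs_embeds=h, output_hidden_states=True)

        # Check confidence at the last prompt position of `layer`
        hidden = out.hidden_states[layer][0, -1, :].float()
        confidence = torch.cosine_similarity(hidden, confidence_vec, dim=0).item()

        if confidence > threshold:
            break

    # Generate from the prompt plus however many pauses were added
    attention_mask = torch.ones(h.shape[:2], dtype=torch.long, device=h.device)
    output_ids = model.generate(inputs_embeds=h, attention_mask=attention_mask,
                                max_new_tokens=256, do_sample=False,
                                pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)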

Connection to Huginn/Looping Architectures

Huginn uses explicit weight-tied loops (same 4 layers run N times). We can't retrofit this to Qwen 3.5 without retraining.

GDN recurrence plus pause tokens may achieve a similar effect by a different route:

  • Huginn: explicit iteration over layers
  • GDN + pause: implicit iteration via recurrent state S

The GDN state accumulates across pause positions, effectively giving the model multiple "thinking steps" before output.

Comparison

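| Approach       | Requires Pretraining | Compute Cost                                            | Qwen 3.5 Compatible |
|----------------|----------------------|---------------------------------------------------------|---------------------|
| Huginn loops   | Yes                  | N × core layers                                         | No                  |
| Pause tokens   | No (inference test)  | N × all layers                                          | Yes                 |
| GDN recurrence | Already there        | Per-token                                               | Already there       |
| Pause + GDN    | No                   | N × all layers (includes N state updates per GDN layer) | Yes                 |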

COCONUT Integration (Future)

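COCONUT feeds the last hidden state back in as the next input embedding, giving explicit whole-model recurrence. The original method feeds the hidden state back unchanged; the projection step below is an addition in this plan: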

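def coconut_forward(model, input_ids, n_latent=3):
    h = model.model.embed_tokens(input_ids)

    for step in range(n_latent):
        out = model(inputs_embeds=h, output_hidden_states=True)
        # The last layer's hidden state at the final position is the next
        # "continuous thought". project_hidden_to_embed is the hypothetical,
        # to-be-trained projection; the model has no such method built in.
        thought = model.project_hidden_to_embed(out.hidden_states[-1][:, -1:, :])
        # Append it as one new input position, as COCONUT does
        h = torch.cat([h, thought], dim=1)

    # Decode ordinary tokens after the latent steps
    attention_mask = torch.ones(h.shape[:2], dtype=torch.long, device=h.device)
    return model.generate(inputs_embeds=h, attention_mask=attention_mask, max_new_tokens=256)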

This gives two levels of iteration:

  1. GDN recurrence within each forward pass (automatic)
  2. Hidden → embed looping across forward passes (COCONUT)

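Requires training the projection layer. Curriculum (following COCONUT): start with 0 latent steps, i.e. plain chain-of-thought, then at each stage replace one more leading reasoning step with latent steps.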
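A sketch of how training examples could be laid out per curriculum stage, following the scheme in the COCONUT paper: at stage k the first k language reasoning steps are replaced by k × c latent slots, and loss is masked on the question and the latent slots. The marker strings, the `LATENT` placeholder, the default c = 1, capping k at the number of steps, and the function name are all assumptions. In training, each `LATENT` slot is filled by the fed-back (here, projected) hidden state rather than a token embedding.

```python
LATENT = "<latent>"          # hypothetical placeholder for one continuous-thought slot
BOT, EOT = "<bot>", "<eot>"  # hypothetical begin/end-of-thought markers

def coconut_stage_layout(question, steps, answer, stage, c=1):
    """Return a list of (item, has_loss) pairs for curriculum stage `stage`.

    Stage 0 is plain chain-of-thought. At stage k, the first k reasoning
    steps are replaced by k * c latent slots between BOT and EOT. Loss is
    applied only to the remaining language steps and the answer.
    """
    k = min(stage, len(steps))
    layout = [(question, False)]
    if k > 0:
        layout.append((BOT, False))
        layout += [(LATENT, False)] * (k * c)
        layout.append((EOT, False))
    layout += [(s, True) for s in steps[k:]]
    layout.append((answer, True))
    return layout

for item, has_loss in coconut_stage_layout("Q?", ["step 1", "step 2"], "42", stage=1, c=2):
    print(item, has_loss)
```

Raising the stage one step at a time matches "start with 0 latent steps, gradually increase"; c controls how many latent slots stand in for each removed step.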


Implementation Priority

  1. Now: Run random prefix test (zero training, 1 hour)
  2. If works: Extract confidence direction for adaptive halting
  3. Training phase: Learn pause tokens + UPFT (Unsupervised Prefix Fine-Tuning; reported ~75% training-time savings)
  4. Later: COCONUT curriculum for explicit hidden state looping

Open Questions

  1. Does random prefix scale to 27B? (Tested on 4B)
  2. Optimal pause count for Qwen 3.5?
  3. Does GDN respond more strongly than pure attention? (Testable)
  4. Can we read confidence from GDN state S directly, not just hidden state h?

References

  • Random Prefix: https://github.com/dl1683/Latent-Space-Reasoning
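  • Pause Tokens: Goyal et al. (Google), "Think before you speak: Training Language Models With Pause Tokens" (Oct 2023)
  • COCONUT: Hao et al. (Meta), "Training Large Language Models to Reason in a Continuous Latent Space" (Dec 2024)
  • Huginn: Geiping et al., "Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach" (Feb 2025)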
  • GDN Architecture: Our qwen35-gdn-implementation-findings-mar28 memory