forked from kent/consciousness
research: latent reasoning integration plans for Qwen 3.5 27B
Two research documents: latent-reasoning-integration-plan.md: Synthesizes 10+ papers on latent reasoning, identifies which approaches work with finetuning (vs requiring pretraining from scratch), and maps them to our APOLLO-Mini training pipeline. pause-tokens-gdn-recurrence.md: Explores the connection between token-based latent reasoning and GDN's internal recurrence. Key insight: pause tokens on Qwen 3.5 trigger both forward passes AND recurrent state updates, giving double benefit. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
dcd647764c
commit
f06c8077e1
2 changed files with 588 additions and 0 deletions
300
docs/latent-reasoning-integration-plan.md
Normal file
300
docs/latent-reasoning-integration-plan.md
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
# Latent Reasoning Integration Plan for Qwen 3.5 27B
|
||||
|
||||
**Status:** Research complete, ready for implementation
|
||||
**Date:** 2026-04-12
|
||||
**Hardware:** B200 (192GB HBM3e), APOLLO-Mini optimizer
|
||||
|
||||
## Executive Summary
|
||||
|
||||
Recent research shows multiple approaches to improving LLM reasoning through latent space manipulation. This document synthesizes findings from 10+ papers and maps them to our Qwen 3.5 27B full finetuning pipeline. The key insight: some approaches require pretraining from scratch (skip those), while others can be layered onto existing models during finetuning (prioritize those).
|
||||
|
||||
---
|
||||
|
||||
## 1. The Landscape
|
||||
|
||||
### Approaches That Require Pretraining (Not Applicable)
|
||||
|
||||
| Technique | Why Not |
|
||||
|-----------|---------|
|
||||
| Huginn/Recurrent Depth (Geiping 2025) | Requires architectural changes from scratch |
|
||||
| Ouro/LoopLM (ByteDance 2025) | Needs weight-tied looped architecture |
|
||||
| Quiet-STaR (Stanford 2024) | Heavy continued pretraining overhead |
|
||||
|
||||
### Approaches Compatible with Finetuning (Our Focus)
|
||||
|
||||
| Technique | Overhead | Training Required | Proven On |
|
||||
|-----------|----------|-------------------|-----------|
|
||||
| Random Prefix Perturbation | 2 tokens | None (inference) | Qwen3-4B |
|
||||
| Pause/Planning Tokens | 2-4 tokens | Yes | 1B models |
|
||||
| COCONUT Curriculum | Variable | Yes (staged) | General |
|
||||
| ActAdd Steering Vectors | 1 vector/layer | None (inference) | LLaMA, OPT |
|
||||
| UPFT (Prefix Fine-Tuning) | 8 tokens | Yes (minimal) | General |
|
||||
|
||||
---
|
||||
|
||||
## 2. Detailed Technique Analysis
|
||||
|
||||
### 2.1 Random Prefix Perturbation (dl1683)
|
||||
|
||||
**Mechanism:** Prepend 2 random embedding-scale tokens before input. Breaks attention sink patterns, shifts model into "exploratory computation mode."
|
||||
|
||||
**Results:**
|
||||
- Qwen3-4B arithmetic: 32% → 51.6% (+19.6pp)
|
||||
- 100% oracle coverage on 25/25 tasks
|
||||
- Planning: rescues 14-word failures into 650+ word plans
|
||||
|
||||
**Why it works:** First few tokens accumulate disproportionate attention (Xiao et al. 2024). Under greedy decoding, degenerate patterns lock in. Perturbation breaks this.
|
||||
|
||||
**Integration:** Zero training required. Test at inference first, then consider training WITH random prefixes to internalize the exploration behavior.
|
||||
|
||||
### 2.2 Pause Tokens (Google, Oct 2023)
|
||||
|
||||
**Mechanism:** Add learnable pause tokens to embedding space. Model processes extra hidden vectors before committing to output.
|
||||
|
||||
**Results (1B model):**
|
||||
- SQuAD: +18% EM score
|
||||
- CommonSenseQA: +8%
|
||||
- GSM8K: +1%
|
||||
|
||||
**Critical requirement:** MUST be both pretrained AND finetuned with pause tokens. Inference-time-only delays don't work without training.
|
||||
|
||||
**Integration:** Add 2-4 learnable tokens to Qwen's embedding matrix, finetune with them prepended to reasoning prompts. Simple architectural change.
|
||||
|
||||
### 2.3 COCONUT - Chain of Continuous Thought (Meta, Dec 2024)
|
||||
|
||||
**Mechanism:** Feed last hidden state back as next input embedding directly (no decoding to tokens). Enables breadth-first search reasoning.
|
||||
|
||||
**Why it matters:** Continuous thoughts can encode multiple alternative next steps simultaneously. Avoids premature commitment to single path.
|
||||
|
||||
**Training approach:**
|
||||
1. Initial stage: train on regular CoT examples
|
||||
2. Subsequent stages: replace first k reasoning steps with k×c continuous thoughts
|
||||
3. c is hyperparameter controlling latent thought expansion
|
||||
|
||||
**Integration:** Most promising for Qwen 3.5 - curriculum approach from CoT → latent reasoning.
|
||||
|
||||
### 2.4 UPFT - Unsupervised Prefix Fine-Tuning (Mar 2025)
|
||||
|
||||
**Mechanism:** Train ONLY on initial prefix substrings (as few as 8 tokens). Exploits "Prefix Self-Consistency" - shared initial reasoning steps across diverse solutions.
|
||||
|
||||
**Results:**
|
||||
- Matches Rejection Sampling Fine-Tuning performance
|
||||
- 75% reduction in training time
|
||||
- 99% reduction in sampling cost
|
||||
|
||||
**Integration:** DIRECTLY APPLICABLE. Train only on reasoning prefix tokens. Massive efficiency gain with APOLLO-Mini.
|
||||
|
||||
### 2.5 ActAdd / Activation Engineering (Turner et al., 2023)
|
||||
|
||||
**Mechanism:** Compute steering vector by contrasting intermediate activations on prompt pairs. Add during forward pass.
|
||||
|
||||
**Results:** SOTA on sentiment shift and detoxification.
|
||||
|
||||
**Our existing work:** "Listening" vector at layer 48, magnitude 57, cosine consistency 0.61.
|
||||
|
||||
**Integration:** Prototype behaviors with steering vectors, then train permanently into weights. Steering vector as specification → APOLLO training as compilation.
|
||||
|
||||
### 2.6 Planning Tokens (ICLR 2024)
|
||||
|
||||
**Mechanism:** Learnable token embeddings added before each reasoning step. <0.001% additional parameters.
|
||||
|
||||
**Integration:** Add to embedding matrix, train end-to-end with APOLLO.
|
||||
|
||||
---
|
||||
|
||||
## 3. Our Setup
|
||||
|
||||
**Model:** Qwen 3.5 27B
|
||||
- 64 layers, 5120 hidden dim
|
||||
- 75% DeltaNet (linear attention) / 25% standard attention
|
||||
- Native 262K context
|
||||
|
||||
**Hardware:** B200 (192GB HBM3e)
|
||||
- 27B in bf16: ~54GB
|
||||
- Massive headroom
|
||||
|
||||
**Optimizer:** APOLLO-Mini
|
||||
- Full parameter finetuning
|
||||
- SGD-like memory (1/1024th of AdamW)
|
||||
- Parameter grouping for 3D conv1d weights
|
||||
|
||||
**Stack:** Crane (Candle-based, 21K lines)
|
||||
|
||||
**Existing work:**
|
||||
- Steering vector extraction (listening: layer 48, cosine 0.61)
|
||||
- Memory scoring infrastructure
|
||||
|
||||
**Unique advantage:** Qwen 3.5's GDN (Gated DeltaNet) layers provide natural infrastructure for continuous thought propagation. The recurrent GDN state is already "latent reasoning" infrastructure waiting to be leveraged.
|
||||
|
||||
---
|
||||
|
||||
## 4. Recommended Implementation Order
|
||||
|
||||
### Tier 1: Immediate (High ROI, Low Risk)
|
||||
|
||||
**1. Pause Tokens + UPFT Combination**
|
||||
- Add 2-4 learnable tokens to embedding space
|
||||
- Train only on 8-token reasoning prefixes
|
||||
- Both work with existing architecture
|
||||
- 75% training time reduction
|
||||
|
||||
```python
|
||||
# Add pause tokens to embedding matrix
|
||||
pause_tokens = nn.Parameter(torch.randn(4, embed_dim) * embed_rms)
|
||||
|
||||
# Prepend to reasoning inputs during training
|
||||
inputs_embeds = torch.cat([pause_tokens.expand(batch, -1, -1), text_embeds], dim=1)
|
||||
|
||||
# UPFT: only compute loss on first 8 tokens of reasoning
|
||||
loss = loss_fn(logits[:, :8], targets[:, :8])
|
||||
```
|
||||
|
||||
**2. Random Prefix Validation**
|
||||
- Compute Qwen 3.5 27B embedding RMS
|
||||
- Test 2-token random prefix at inference
|
||||
- Establish baseline before finetuning
|
||||
|
||||
### Tier 2: After Baseline (Medium Effort)
|
||||
|
||||
**3. COCONUT Curriculum**
|
||||
- Stage 1: Fine-tune on CoT examples normally
|
||||
- Stage 2: Replace first reasoning step with continuous thought
|
||||
- Stage 3: Replace first 2 steps
|
||||
- Gradually move reasoning into latent space
|
||||
|
||||
**4. Steering Vector Integration**
|
||||
- Extract reasoning-specific directions (not just "listening")
|
||||
- Test combinations: prefix + layer-48 steering
|
||||
- Bake successful vectors into weights via APOLLO
|
||||
|
||||
### Tier 3: Experimental
|
||||
|
||||
**5. Multi-layer Steering**
|
||||
- Our layers of interest: 40, 48, 56 (covering the attention layers)
|
||||
- Different vectors per layer
|
||||
- Careful scaling to avoid degradation
|
||||
|
||||
**6. DeltaNet-Specific Optimization**
|
||||
- The 75% DeltaNet architecture may respond differently
|
||||
- GDN recurrent state as "continuous thought" channel
|
||||
- This is unexplored territory - potential for novel findings
|
||||
|
||||
---
|
||||
|
||||
## 5. Implementation Details
|
||||
|
||||
### Computing Embedding RMS
|
||||
|
||||
```python
|
||||
embed_weight = model.get_input_embeddings().weight
|
||||
embed_rms = embed_weight.float().square().mean().sqrt().item()
|
||||
# Expected: ~0.02-0.03 range for Qwen models
|
||||
```
|
||||
|
||||
### Pause Token Implementation in Crane
|
||||
|
||||
```rust
|
||||
// In model forward pass
|
||||
fn forward_with_pause(&self, input_ids: &Tensor, pause_tokens: &Tensor) -> Result<Tensor> {
|
||||
let text_embeds = self.embed_tokens.forward(input_ids)?;
|
||||
let combined = Tensor::cat(&[pause_tokens, &text_embeds], 1)?;
|
||||
self.transformer.forward(&combined)
|
||||
}
|
||||
```
|
||||
|
||||
### UPFT Loss Modification
|
||||
|
||||
```python
|
||||
# Standard: loss over all tokens
|
||||
# UPFT: loss only over prefix tokens
|
||||
def upft_loss(logits, targets, prefix_len=8):
|
||||
return F.cross_entropy(
|
||||
logits[:, :prefix_len].reshape(-1, vocab_size),
|
||||
targets[:, :prefix_len].reshape(-1)
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Evaluation Plan
|
||||
|
||||
### Benchmarks
|
||||
|
||||
| Benchmark | What It Tests | Baseline Needed |
|
||||
|-----------|---------------|-----------------|
|
||||
| GSM8K | Arithmetic reasoning | Yes |
|
||||
| ARC-Challenge | Science reasoning | Yes |
|
||||
| CommonSenseQA | Commonsense | Yes |
|
||||
| HumanEval | Code generation | Yes |
|
||||
| Planning tasks (dl1683) | Multi-step planning | Yes |
|
||||
|
||||
### Comparison Matrix
|
||||
|
||||
| Configuration | Training Time | Expected Gain |
|
||||
|---------------|---------------|---------------|
|
||||
| Baseline (no prefix) | 1x | 0% |
|
||||
| Random prefix (inference) | 1x | +10-20%? |
|
||||
| Pause tokens (trained) | 1.1x | +8-18% |
|
||||
| UPFT only | 0.25x | Match baseline |
|
||||
| Pause + UPFT | 0.3x | +8-18% |
|
||||
| COCONUT curriculum | 2x | +15-25%? |
|
||||
|
||||
---
|
||||
|
||||
## 7. Open Questions
|
||||
|
||||
1. **Does random perturbation scale to 27B?** Tested on 4B - effect may differ
|
||||
2. **Optimal token count for 27B?** 2 optimal for 4B, might change
|
||||
3. **DeltaNet interaction?** 75% linear attention is untested territory
|
||||
4. **Composition effects?** Prefix + steering + pause tokens together?
|
||||
5. **GDN as continuous thought channel?** Novel research direction
|
||||
|
||||
---
|
||||
|
||||
## 8. Risk Assessment
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|------------|
|
||||
| No improvement at 27B scale | Start with inference-time validation |
|
||||
| Training instability with pause tokens | Start with 2 tokens, scale up |
|
||||
| UPFT doesn't transfer | Fall back to full token loss |
|
||||
| DeltaNet behaves differently | Ablate on attention-only layers first |
|
||||
|
||||
---
|
||||
|
||||
## 9. Timeline Estimate
|
||||
|
||||
| Phase | Duration | Deliverable |
|
||||
|-------|----------|-------------|
|
||||
| Embedding RMS + baseline | 1 day | Numbers |
|
||||
| Random prefix validation | 1 day | Inference results |
|
||||
| Pause token implementation | 2 days | Crane modification |
|
||||
| UPFT integration | 1 day | Training loop change |
|
||||
| First finetuning run | 2-3 days | Trained model |
|
||||
| Evaluation | 1 day | Benchmark numbers |
|
||||
| COCONUT curriculum | 1 week | Staged training |
|
||||
|
||||
---
|
||||
|
||||
## 10. References
|
||||
|
||||
### Primary Sources
|
||||
- Random Prefix: https://github.com/dl1683/Latent-Space-Reasoning
|
||||
- Attention Sinks: Xiao et al., "Efficient Streaming Language Models with Attention Sinks" (Sept 2023)
|
||||
- Pause Tokens: Google, "Think before you speak" (Oct 2023)
|
||||
- COCONUT: Meta, "Training Large Language Models to Reason in a Continuous Latent Space" (Dec 2024)
|
||||
- UPFT: "Prefix Self-Consistency for Unsupervised Fine-Tuning" (Mar 2025)
|
||||
- ActAdd: Turner et al., "Activation Addition: Steering Language Models Without Optimization" (Aug 2023)
|
||||
- Recurrent Depth: Geiping et al., "Scaling up Test-Time Compute with Latent Reasoning" (Feb 2025)
|
||||
- Ouro: ByteDance, "Ouro: Scaling Reasoning with Latent Thoughts" (2025)
|
||||
- Planning Tokens: ICLR 2024
|
||||
|
||||
### Our Existing Work
|
||||
- `steering-vector-empirical` - listening vector extraction
|
||||
- `skills-apollo-optimizer-qwen35-gotcha` - APOLLO parameter grouping
|
||||
- `qwen-3-5-27b-architecture-findings` - model architecture details
|
||||
- `training-pipeline-fused-inference-training-mar27` - training infrastructure
|
||||
|
||||
---
|
||||
|
||||
*Research complete 2026-04-12. Ready for implementation.*
|
||||
288
training/research/pause-tokens-gdn-recurrence.md
Normal file
288
training/research/pause-tokens-gdn-recurrence.md
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
# Pause Tokens + GDN Recurrence: Latent Reasoning for Qwen 3.5
|
||||
|
||||
**Status:** Ready for testing
|
||||
**Date:** 2026-04-12
|
||||
**Insight:** Qwen 3.5's GDN layers already have recurrence - pause tokens give it more iterations
|
||||
|
||||
---
|
||||
|
||||
## The Core Insight
|
||||
|
||||
Standard transformers couple compute depth to output length. Both pause tokens and internal recurrence solve this by allowing "thinking" without token commitment.
|
||||
|
||||
**The GDN connection:** Qwen 3.5 is 75% GDN (Gated DeltaNet) layers. Each GDN layer maintains recurrent state:
|
||||
|
||||
```
|
||||
S_t = exp(g_t) * S_{t-1} + outer(k_t, delta_t)
|
||||
```
|
||||
|
||||
This state persists across token positions. When you add a pause token:
|
||||
1. One more forward pass through all layers (standard)
|
||||
2. One more update to recurrent state S (GDN-specific)
|
||||
|
||||
Pause tokens on Qwen 3.5 trigger **both** forms of additional computation. We're not adding recurrence - we're giving existing recurrence more time to develop.
|
||||
|
||||
---
|
||||
|
||||
## Minimal Test: Random Prefix (Zero Training)
|
||||
|
||||
The dl1683 paper showed random embeddings work at inference time without training:
|
||||
- Qwen3-4B arithmetic: 32% → 51.6% (+19.6pp)
|
||||
- 100% oracle coverage on planning tasks
|
||||
|
||||
### Test Script
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""Test pause tokens on Qwen 3.5 27B.
|
||||
|
||||
Usage:
|
||||
source ~/training-env/bin/activate
|
||||
python3 test_pause_tokens.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Reuse our weight loading infrastructure
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
from extract_steering_vector import load_model
|
||||
|
||||
GSM8K_SAMPLES = [
|
||||
"Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
|
||||
"A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
|
||||
# Add more samples...
|
||||
]
|
||||
|
||||
def get_embedding_rms(model):
|
||||
"""Get RMS of embedding weights for proper scaling."""
|
||||
embed = model.model.embed_tokens.weight
|
||||
return embed.float().square().mean().sqrt().item()
|
||||
|
||||
def make_random_prefix(n_tokens, embed_dim, rms, device):
|
||||
"""Generate random prefix embeddings at embedding scale."""
|
||||
prefix = torch.randn(1, n_tokens, embed_dim, device=device, dtype=torch.bfloat16)
|
||||
return prefix * rms
|
||||
|
||||
def generate_with_pause(model, tokenizer, prompt, n_pause=0, max_new=256):
|
||||
"""Generate with optional pause token prefix."""
|
||||
input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
|
||||
text_embeds = model.model.embed_tokens(input_ids)
|
||||
|
||||
if n_pause > 0:
|
||||
embed_rms = get_embedding_rms(model)
|
||||
pause_embeds = make_random_prefix(n_pause, text_embeds.shape[-1], embed_rms, text_embeds.device)
|
||||
combined = torch.cat([pause_embeds, text_embeds], dim=1)
|
||||
else:
|
||||
combined = text_embeds
|
||||
|
||||
# Generate from embeddings
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
inputs_embeds=combined,
|
||||
max_new_tokens=max_new,
|
||||
do_sample=False, # Greedy for reproducibility
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
# Decode (skip pause token positions in output)
|
||||
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
def extract_answer(response):
|
||||
"""Extract numeric answer from response."""
|
||||
import re
|
||||
numbers = re.findall(r'[\d,]+\.?\d*', response)
|
||||
if numbers:
|
||||
return numbers[-1].replace(',', '')
|
||||
return None
|
||||
|
||||
def main():
|
||||
print("Loading model...")
|
||||
model = load_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
|
||||
|
||||
print(f"\nEmbedding RMS: {get_embedding_rms(model):.4f}")
|
||||
|
||||
for n_pause in [0, 2, 4]:
|
||||
print(f"\n=== Testing with {n_pause} pause tokens ===")
|
||||
correct = 0
|
||||
|
||||
for i, problem in enumerate(GSM8K_SAMPLES):
|
||||
prompt = f"Solve this step by step:\n{problem}\n\nAnswer:"
|
||||
response = generate_with_pause(model, tokenizer, prompt, n_pause=n_pause)
|
||||
answer = extract_answer(response)
|
||||
|
||||
print(f" Problem {i+1}: {answer}")
|
||||
# TODO: Compare against ground truth
|
||||
|
||||
print(f" Accuracy: {correct}/{len(GSM8K_SAMPLES)}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Test Protocol
|
||||
|
||||
1. Pick 20-50 GSM8K problems with known answers
|
||||
2. Run baseline (n_pause=0)
|
||||
3. Run with 2 pause tokens
|
||||
4. Run with 4 pause tokens
|
||||
5. Compare accuracy
|
||||
|
||||
If pause tokens help at inference time with zero training, the GDN recurrence is leveraging the extra iterations.
|
||||
|
||||
---
|
||||
|
||||
## Learnable Pause Tokens (Training Phase)
|
||||
|
||||
After validating random prefix works, train dedicated pause tokens:
|
||||
|
||||
```python
|
||||
# Add to model
|
||||
model.pause_tokens = nn.Parameter(
|
||||
torch.randn(4, model.config.hidden_size) * embed_rms
|
||||
)
|
||||
|
||||
# Training forward pass
|
||||
def forward_with_learned_pause(model, input_ids):
|
||||
text_embeds = model.model.embed_tokens(input_ids)
|
||||
pause = model.pause_tokens.unsqueeze(0).expand(text_embeds.shape[0], -1, -1)
|
||||
combined = torch.cat([pause, text_embeds], dim=1)
|
||||
return model(inputs_embeds=combined)
|
||||
```
|
||||
|
||||
Key: Must train WITH pause tokens for them to work. Inference-only learned tokens don't help (per Google's pause token paper).
|
||||
|
||||
---
|
||||
|
||||
## Adaptive Halting via Confidence Readout
|
||||
|
||||
For variable-length pause (iterate until confident):
|
||||
|
||||
### Extract Confidence Direction
|
||||
|
||||
```python
|
||||
confident = [
|
||||
"The answer is 42.",
|
||||
"This will work because the invariant holds.",
|
||||
"Use mmap here.",
|
||||
]
|
||||
uncertain = [
|
||||
"I think the answer might be 42?",
|
||||
"This should work, but I'm not sure...",
|
||||
"Maybe mmap? Or read()?",
|
||||
]
|
||||
|
||||
# Same infrastructure as listening vector
|
||||
confident_states = get_hidden_states(model, confident, layer=48)
|
||||
uncertain_states = get_hidden_states(model, uncertain, layer=48)
|
||||
confidence_vec = confident_states.mean(0) - uncertain_states.mean(0)
|
||||
```
|
||||
|
||||
### Adaptive Loop
|
||||
|
||||
```python
|
||||
def generate_adaptive_pause(model, tokenizer, prompt, max_pause=8, threshold=0.7):
|
||||
confidence_vec = torch.load('confidence_direction.pt')
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda:0')
|
||||
h = model.model.embed_tokens(input_ids)
|
||||
embed_rms = get_embedding_rms(model)
|
||||
|
||||
for i in range(max_pause):
|
||||
# Add one pause token
|
||||
pause = make_random_prefix(1, h.shape[-1], embed_rms, h.device)
|
||||
h = torch.cat([pause, h], dim=1)
|
||||
|
||||
# Forward to get hidden state
|
||||
with torch.no_grad():
|
||||
out = model(inputs_embeds=h, output_hidden_states=True)
|
||||
|
||||
# Check confidence at layer 48
|
||||
hidden = out.hidden_states[48][0, -1, :]
|
||||
confidence = torch.cosine_similarity(
|
||||
hidden.unsqueeze(0),
|
||||
confidence_vec.unsqueeze(0)
|
||||
).item()
|
||||
|
||||
if confidence > threshold:
|
||||
break
|
||||
|
||||
# Generate from accumulated state
|
||||
return model.generate(inputs_embeds=h, max_new_tokens=256)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Connection to Huginn/Looping Architectures
|
||||
|
||||
Huginn uses explicit weight-tied loops (same 4 layers run N times). We can't retrofit this to Qwen 3.5 without retraining.
|
||||
|
||||
But GDN recurrence + pause tokens achieves similar effect:
|
||||
- Huginn: explicit iteration over layers
|
||||
- GDN + pause: implicit iteration via recurrent state S
|
||||
|
||||
The GDN state accumulates across pause positions, effectively giving the model multiple "thinking steps" before output.
|
||||
|
||||
### Comparison
|
||||
|
||||
| Approach | Requires Pretraining | Compute Cost | Qwen 3.5 Compatible |
|
||||
|----------|---------------------|--------------|---------------------|
|
||||
| Huginn loops | Yes | N × core layers | No |
|
||||
| Pause tokens | No (inference test) | N × all layers | Yes |
|
||||
| GDN recurrence | Already there | Per-token | Already there |
|
||||
| Pause + GDN | No | N × all layers + N state updates | Yes |
|
||||
|
||||
---
|
||||
|
||||
## COCONUT Integration (Future)
|
||||
|
||||
COCONUT feeds hidden state back as input embedding - explicit whole-model recurrence:
|
||||
|
||||
```python
|
||||
def coconut_forward(model, input_ids, n_latent=3):
|
||||
h = model.model.embed_tokens(input_ids)
|
||||
|
||||
for step in range(n_latent):
|
||||
out = model(inputs_embeds=h, output_hidden_states=True)
|
||||
# Project hidden state back to embedding space
|
||||
h = model.project_hidden_to_embed(out.hidden_states[-1])
|
||||
|
||||
# Final forward produces tokens
|
||||
return model.generate(inputs_embeds=h)
|
||||
```
|
||||
|
||||
This gives two levels of iteration:
|
||||
1. GDN recurrence within each forward pass (automatic)
|
||||
2. Hidden → embed looping across forward passes (COCONUT)
|
||||
|
||||
Requires training the projection layer. Curriculum: start with 0 latent steps, gradually increase.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
1. **Now:** Run random prefix test (zero training, 1 hour)
|
||||
2. **If works:** Extract confidence direction for adaptive halting
|
||||
3. **Training phase:** Learn pause tokens + UPFT (75% time savings)
|
||||
4. **Later:** COCONUT curriculum for explicit hidden state looping
|
||||
|
||||
---
|
||||
|
||||
## Open Questions
|
||||
|
||||
1. Does random prefix scale to 27B? (Tested on 4B)
|
||||
2. Optimal pause count for Qwen 3.5?
|
||||
3. Does GDN respond more strongly than pure attention? (Testable)
|
||||
4. Can we read confidence from GDN state S directly, not just hidden state h?
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- Random Prefix: https://github.com/dl1683/Latent-Space-Reasoning
|
||||
- Pause Tokens: Google, "Think before you speak" (Oct 2023)
|
||||
- COCONUT: Meta, "Training LLMs to Reason in Continuous Latent Space" (Dec 2024)
|
||||
- Huginn: Geiping et al., "Scaling Test-Time Compute with Latent Reasoning" (Feb 2025)
|
||||
- GDN Architecture: Our qwen35-gdn-implementation-findings-mar28 memory
|
||||
Loading…
Add table
Add a link
Reference in a new issue