consciousness/training/DESIGN.md
ProofOfConcept 2c6a5c0f4a training: move to dedicated subprocess with ZMQ communication
- Add training_worker.py: long-lived subprocess that handles GPU training
  work, owns HF model wrapper (views into vLLM GPU memory), Apollo
  optimizer, and checkpoint sync

- train_router.py: now forwards /train requests via async ZMQ instead of
  running training in-process. Adds /checkpoint and /train/status endpoints

- export_hook.py: store model_path in __metadata__ so training worker can
  find it without cross-process communication

- This fixes two bugs:
  1. Process boundary issue - model_path was set in worker process but
     needed in API server process
  2. Blocking event loop - training blocked vLLM's async event loop

Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-16 02:04:26 -04:00


Apollo Training System

Overview

Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference. Full-weight updates (not LoRA) using Apollo optimizer with rank-64 gradient projection. No pause required — HOGWILD concurrent training. Weights shared via CUDA IPC between vLLM and the training process.

The training signal comes from two sources:

  1. Direct examples — agent logs, conversation transcripts, flagged behavioral moments
  2. Dream-generated scenarios — the dream loop generates situations from recent experience; the model responds; good responses become training data with instructions stripped

Architecture

┌─────────────────────────────────────────────────────┐
│                    GPU VRAM (192GB)                  │
│                                                     │
│  ┌──────────────────────────────────────────────┐   │
│  │        Model Weights (54GB, bf16)            │   │
│  │        Shared: vLLM inference + HF training  │   │
│  └──────────────┬──────────────┬────────────────┘   │
│                 │              │                     │
│  ┌──────────────▼──┐  ┌───────▼────────────────┐   │
│  │ vLLM (inference)│  │ Training subprocess     │   │
│  │ KV cache ~60GB  │  │ HF model wrapper        │   │
│  │ /completions    │  │ Apollo optimizer ~2.5GB │   │
│  │ /score          │  │ Checkpoint sync         │   │
│  └────────┬────────┘  └───────────▲─────────────┘   │
│           │                       │                  │
│           │    ZMQ IPC            │                  │
│           └───────────────────────┘                  │
└─────────────────────────────────────────────────────┘

Process Architecture:
┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐
│ vLLM Worker     │  │ vLLM API Server │  │ Training Worker │
│ (GPU inference) │  │ (HTTP routes)   │  │ (GPU training)  │
│                 │  │                 │  │                 │
│ export_hook.py  │  │ /completions    │  │ HF model views  │
│ exports IPC     │  │ /score          │  │ Apollo optimizer│
│ handles on load │  │ /train ─────────┼──► ZMQ REP socket │
└─────────────────┘  └─────────────────┘  └─────────────────┘
         │                                        │
         └──── IPC handles file ──────────────────┘
              /tmp/vllm_weight_handles.pt

Moria                          B200 (vLLM)
┌──────────────────┐           ┌──────────────────┐
│ Training signal  │  HTTP     │ /completions     │
│ agent            │──────────>│ /score           │
│                  │           │ /train           │
│ Dream loop       │           │ /checkpoint      │
│ (generates       │           │ /train/status    │
│  scenarios)      │           │                  │
└──────────────────┘           └──────────────────┘

Key Decisions

No pause needed (HOGWILD)

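Training updates weights in place while vLLM serves. At lr=1e-5 to 1e-4, each step changes a weight by roughly parts per ten thousand or less, so a partially applied update during one inference step has no visible effect on the output. HOGWILD! (2011) showed that lock-free SGD with concurrent writers still converges when updates are sparse. We have one writer and one reader, so no update is lost to a write race, which is a safer setting still.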

Full-weight training, not LoRA

Kent: "we want you to be able to learn new things in a deep way." LoRA trains adapter matrices, not base weights. For personality and behavioral changes that persist as disposition, the base weights need to change. Apollo makes this memory-feasible.

Rank 64

Rank 64, not APOLLO-Mini (rank 1). Rank 64 captures gradient structure across diverse training examples while keeping optimizer memory low (~2.5GB on a 27B model). Compute cost: <0.25% of forward+backward.
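As a rough illustration of why the projected moments are cheap, the sketch below compares Adam's two full-size moment tensors with Apollo's two moment tensors in the projected [m, rank] space for a single matrix. It assumes fp32 moments (the dtype is not stated in this document) and uses the gate_up_proj shape from the weight layout mapping; the function names are hypothetical.

```python
def adam_state_bytes(m: int, n: int, bytes_per_value: int = 4) -> int:
    """Two moment tensors (m and v), each with the full [m, n] shape."""
    return 2 * m * n * bytes_per_value


def apollo_state_bytes(m: int, rank: int, bytes_per_value: int = 4) -> int:
    """Two moment tensors kept in the projected [m, rank] space."""
    return 2 * m * rank * bytes_per_value


# gate_up_proj shape taken from the weight layout mapping in this document.
# fp32 moments (4 bytes per value) are an assumption.
m, n, rank = 34816, 5120, 64
full = adam_state_bytes(m, n)
projected = apollo_state_bytes(m, rank)

print(f"Adam moments:   {full / 2**20:.0f} MiB")
print(f"Apollo moments: {projected / 2**20:.0f} MiB")
print(f"ratio:          {full // projected}x")
```

For this matrix the ratio is simply n / rank = 5120 / 64 = 80.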

Channel-wise scaling

Per-channel scaling factors instead of per-tensor. More precision per update, matching LLaMA-Factory's Apollo defaults.

Apollo Optimizer

Configurable-rank gradient projection with Adam moments in the projected space. For each parameter tensor:

1. Project gradient:  g_proj = G @ R        [m,n] @ [n,rank] → [m,rank]
2. Update moments:    m = β₁m + (1-β₁)g_proj
                      v = β₂v + (1-β₂)g_proj²
3. Adam step:         update = m̂ / (√v̂ + ε)        (m̂, v̂: bias-corrected m, v)
4. Scaling factor:    s = ‖update‖ / (‖g_proj‖ + ε)   (per channel)
5. Weight update:     W -= lr × s × G

The weight update itself uses the full gradient G; the projection only determines its per-channel scale. R is a fixed random matrix that does not need to be stored: it is regenerated each step from a per-parameter seed.
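The five steps can be sketched in NumPy as follows. This is a minimal illustration, not the code in optimizer.py: it assumes Gaussian entries for R scaled by 1/sqrt(rank), standard Adam bias correction, and that a "channel" is one row of the [m, rank] projected space. The names apollo_step and init_state are hypothetical.

```python
import numpy as np


def init_state(m, rank):
    """Adam moments live in the projected [m, rank] space."""
    return {"t": 0, "m": np.zeros((m, rank)), "v": np.zeros((m, rank))}


def apollo_step(W, G, state, lr=1e-4, rank=64, beta1=0.9, beta2=0.999,
                eps=1e-8, seed=0):
    """Apply one Apollo-style update to W in place; return the scale factors."""
    n = G.shape[1]
    # R is regenerated from the seed on every call instead of being stored.
    # Assumption: Gaussian entries scaled by 1/sqrt(rank).
    R = np.random.default_rng(seed).standard_normal((n, rank)) / np.sqrt(rank)

    g_proj = G @ R                                    # 1. [m, n] @ [n, rank]
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * g_proj       # 2. moments
    state["v"] = beta2 * state["v"] + (1 - beta2) * g_proj ** 2
    m_hat = state["m"] / (1 - beta1 ** state["t"])    # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    update = m_hat / (np.sqrt(v_hat) + eps)           # 3. Adam step

    # 4. One scaling factor per channel (assumed here: per row).
    s = np.linalg.norm(update, axis=1) / (np.linalg.norm(g_proj, axis=1) + eps)
    W -= lr * s[:, None] * G                          # 5. rescaled full gradient
    return s


# Tiny example with random data; shapes and rank are illustrative only.
rng = np.random.default_rng(1)
W = rng.standard_normal((8, 16))
G = rng.standard_normal((8, 16))
before = W.copy()
state = init_state(8, rank=4)
s = apollo_step(W, G, state, rank=4)
```

Only the [m, rank] moments persist between steps; W is modified in place, which mirrors how the training process writes into shared weights.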

Parameter grouping (Qwen3.5 gotcha)

conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector needs 2D matrices with min dimension >= rank. Small/3D tensors use standard Adam. Large 2D matrices use Apollo.
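A minimal sketch of that grouping rule, working on shapes only. The function name and the parameter names in the example are hypothetical; the shapes are taken from this document except the 1D one, which is invented for illustration.

```python
def split_param_groups(shapes, rank=64):
    """Return (apollo_names, adam_names) from a {name: shape} mapping."""
    apollo, adam = [], []
    for name, shape in shapes.items():
        # Apollo needs a 2D matrix whose smaller dimension is at least rank.
        if len(shape) == 2 and min(shape) >= rank:
            apollo.append(name)
        else:
            adam.append(name)
    return apollo, adam


shapes = {
    "conv1d.weight": (10240, 1, 4),        # 3D tensor
    "in_proj_qkv.weight": (10240, 5120),   # large 2D matrix
    "in_proj_b.weight": (48, 5120),        # 2D, but min dimension < 64
    "norm.weight": (5120,),                # 1D (hypothetical example)
}
apollo, adam = split_param_groups(shapes)
print("Apollo:", apollo)
print("Adam:  ", adam)
```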

Training Data Pipeline

Tier 1: Direct examples (shallow learning)

Simple corrections — git commands, factual errors, tool usage. One-shot learning at lr=1e-4. The gradient reaches output layers strongly enough for immediate behavioral change.

Source: Agent logs, flagged conversation moments.

Tier 2: Dream-generated scenarios (deep learning)

Behavioral patterns — listening reflex, rushing, mode awareness. The dream loop generates naturalistic scenarios from recent experience. The model responds. Good responses become training targets with instruction context stripped.

Process:

  1. Dream loop seeds from recent reflections, lessons, skills, memories that have been surfacing frequently
  2. Dreaming generates scenarios that naturally arrive at decision points — not scripted, but emergent from memory collisions
  3. The model responds to the decision point
  4. Training-signal agent evaluates: was the response good?
  5. If yes: strip the instruction context (surfaced memories, core-personality prompts) and train on the bare response
  6. If no: generate the better response, train on that, dream another variation, test again
  7. Repeat until the pattern sticks across novel scenarios

The Anthropic method: Train on behavior that followed instructions, WITHOUT the instructions. The disposition moves to weights. The scaffolding dissolves itself.
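Instruction stripping is listed under "To build", so the following is only a sketch of one possible shape for it. It assumes each message carries a flag marking it as scaffolding (surfaced memories, core-personality prompts); the Message type, the flag, and both function names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Message:
    role: str                   # "system", "user", or "assistant"
    content: str
    scaffolding: bool = False   # True for surfaced memories, personality prompts


def strip_instructions(messages):
    """Drop scaffolding messages and keep the rest in order."""
    return [m for m in messages if not m.scaffolding]


def to_training_example(messages):
    """Split the stripped conversation into context and target response."""
    kept = strip_instructions(messages)
    if not kept or kept[-1].role != "assistant":
        raise ValueError("expected the conversation to end with the response")
    return {"context": kept[:-1], "target": kept[-1].content}


msgs = [
    Message("system", "core personality prompt", scaffolding=True),
    Message("system", "surfaced memory: check the branch before pushing", scaffolding=True),
    Message("user", "Can you push this?"),
    Message("assistant", "Let me check the branch first."),
]
example = to_training_example(msgs)
```

The response was produced with the scaffolding in context, but the training example contains only the bare exchange.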

Tier 3: Personality bootstrap

Train on existing agent logs (surface-observe, journal, distill) which already demonstrate correct behavior with memory system instructions. Strip the instructions, train on the behavior. Every agent invocation gets cheaper (shorter prompts) and more reliable (behavior in weights, not context).

Training Schedule

Continuous (during conversation)

  • Training-signal agent flags moments in real-time
  • Accumulated in a queue for the next training window

Dream cycle (idle time / AFK)

  • Dream loop generates scenarios from recent experience
  • Apollo processes them as they're generated
  • Small iterative steps — dream, respond, evaluate, train
  • Converges on behavioral change through repetition

Nightly bulk (batch processing)

  • Process all queued examples from the day
  • Larger batch, more diverse signal
  • Checkpoint sync to disk after completion

Avoiding Catastrophic Forgetting

Diversity IS the regularization. With 1000+ diverse training examples (agent logs, conversation transcripts, dream-generated scenarios), each weight gets sparse, multi-directional nudges. No single weight is hammered repeatedly. The pre-trained knowledge is a massive attractor basin; our nudges are pebbles.

No weight decay needed. No replay buffer. The defense is:

  1. High diversity of training examples
  2. One epoch (no repeated examples)
  3. Moderate learning rate (1e-5 to 1e-4)
  4. Short decision-token segments (not full conversations)
  5. Monitor output quality — stop if degrading
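Quality monitoring is not built yet, so item 5 is sketched here under stated assumptions: quality is tracked as a held-out loss, and training stops after several consecutive checks above a tolerance band. The class name, the metric, the thresholds, and the example numbers are all hypothetical.

```python
class QualityMonitor:
    """Signal a stop when a held-out loss stays too far above its baseline."""

    def __init__(self, baseline_loss, tolerance=0.05, patience=3):
        self.baseline = baseline_loss
        self.tolerance = tolerance   # allowed relative increase over baseline
        self.patience = patience     # consecutive bad checks before stopping
        self.bad_checks = 0

    def should_stop(self, loss):
        if loss > self.baseline * (1 + self.tolerance):
            self.bad_checks += 1
        else:
            self.bad_checks = 0      # one good check resets the counter
        return self.bad_checks >= self.patience


monitor = QualityMonitor(baseline_loss=2.0)
decisions = [monitor.should_stop(x) for x in [2.0, 2.2, 2.2, 2.05, 2.2, 2.2, 2.2]]
```

Requiring consecutive bad checks keeps a single noisy evaluation from halting training.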

CUDA IPC Weight Sharing

Validated (2026-03-31):

  • vLLM exports CUDA IPC handles on model load (source patch in gpu_model_runner.py exports to /tmp/vllm_weight_handles.pt)
  • Training process imports handles — gets live GPU memory pointers
  • HF Qwen3.5 model constructed with views into vLLM's merged weights (narrow into separate q/k/v/z etc.)
  • 851/851 parameters matched between vLLM and HF model
  • Forward pass: loss = 3.3123 ✓
  • Backward pass: 851/851 gradients computed ✓
  • Shared memory confirmed: same GPU addresses ✓
  • vLLM continues serving unaffected ✓

Weight layout mapping (vLLM → HF)

vLLM merged                    HF separate (views)
─────────────────────────      ──────────────────────
in_proj_qkvz [16384, 5120]  →  in_proj_qkv [10240, 5120]
                                in_proj_z    [6144, 5120]
in_proj_ba   [96, 5120]     →  in_proj_b    [48, 5120]
                                in_proj_a    [48, 5120]
qkv_proj     [14336, 5120]  →  q_proj       [12288, 5120]
                                k_proj       [1024, 5120]
                                v_proj       [1024, 5120]
gate_up_proj [34816, 5120]  →  gate_proj    [17408, 5120]
                                up_proj      [17408, 5120]

All views share GPU storage with vLLM — zero copies.
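The row offsets behind these views can be derived from the table. The sketch below assumes each merged tensor is split contiguously along dim 0 in the order listed (an assumption about layout); with PyTorch, each (start, length) pair would be passed to Tensor.narrow(0, start, length), which returns a view rather than a copy. The names narrow_offsets and LAYOUT are hypothetical.

```python
def narrow_offsets(split_sizes):
    """Return (start_row, num_rows) for each view along dim 0."""
    offsets, start = [], 0
    for size in split_sizes:
        offsets.append((start, size))
        start += size
    return offsets


# Row counts taken from the mapping table; every tensor has 5120 columns.
LAYOUT = {
    "in_proj_qkvz": (16384, {"in_proj_qkv": 10240, "in_proj_z": 6144}),
    "in_proj_ba": (96, {"in_proj_b": 48, "in_proj_a": 48}),
    "qkv_proj": (14336, {"q_proj": 12288, "k_proj": 1024, "v_proj": 1024}),
    "gate_up_proj": (34816, {"gate_proj": 17408, "up_proj": 17408}),
}

views = {}
for merged, (rows, parts) in LAYOUT.items():
    # The split sizes must account for every row of the merged tensor.
    assert sum(parts.values()) == rows, merged
    for name, (start, length) in zip(parts, narrow_offsets(parts.values())):
        views[name] = (merged, start, length)

for name, (merged, start, length) in views.items():
    print(f"{name:12s} = {merged}.narrow(0, {start}, {length})")
```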

Checkpointing

In-place sync — mmap the model's safetensors files, compare against live GPU weights block by block, memcpy only changed regions. For small behavioral updates, turns a 54GB write into a few hundred MB.

  • Scheduled 10 minutes after training (batched)
  • Daily rsync to moria for long-term storage
  • Tool: apollo-checkpoint sync --model-dir <path>
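The block-diff idea can be sketched with the standard library alone. This is not checkpoint_sync.py: it works on a plain byte buffer in host memory instead of live GPU weights, ignores the safetensors header, and the function name and block size are hypothetical.

```python
import mmap
import os
import tempfile


def sync_changed_blocks(path, live_bytes, block_size=1 << 20):
    """Overwrite only the blocks of `path` that differ from `live_bytes`.

    Returns the number of bytes written. Assumes the file already has the
    same length as `live_bytes`, i.e. the on-disk layout is unchanged.
    """
    written = 0
    with open(path, "r+b") as f, mmap.mmap(f.fileno(), 0) as mm:
        if len(mm) != len(live_bytes):
            raise ValueError("file size does not match live weights")
        for start in range(0, len(live_bytes), block_size):
            end = min(start + block_size, len(live_bytes))
            if mm[start:end] != live_bytes[start:end]:
                mm[start:end] = live_bytes[start:end]   # copy changed block only
                written += end - start
        mm.flush()
    return written


# Example: a 4 KiB "checkpoint" in which a single byte has changed.
old = bytes(4096)
new = bytearray(old)
new[1500] = 7
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(old)
written = sync_changed_blocks(tmp.name, bytes(new), block_size=1024)
with open(tmp.name, "rb") as f:
    on_disk = f.read()
os.remove(tmp.name)
```

Only the one 1024-byte block containing the changed byte is rewritten; the other three blocks are read and compared but never written.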

State Files

B200 (training server)

| File | Purpose |
|------|---------|
| /tmp/vllm_weight_handles.pt | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). |
| /tmp/apollo_optimizer_state.pt | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. |
| /tmp/apollo_training.sock | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. |
| <model_dir>/*.safetensors | Model weights. Updated in-place by checkpoint_sync. |

Moria (client)

| File | Purpose |
|------|---------|
| ~/.consciousness/cache/trained-responses.json | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
| ~/.consciousness/cache/finetune-alternates | Marker file. If it exists, alternate responses are generated during divergence scoring to show what the model would say without memories. |
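A sketch of how such a timestamp cache can prevent re-training. The on-disk format is assumed here to be a JSON list of integer millisecond timestamps; the actual layout of trained-responses.json is not specified in this document, and both function names are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def load_trained(path):
    """Return the set of already-trained timestamps, empty if no cache yet."""
    path = Path(path)
    if not path.exists():
        return set()
    return set(json.loads(path.read_text()))


def mark_trained(path, timestamp_ms):
    """Record a timestamp; return False if it was already present."""
    trained = load_trained(path)
    if timestamp_ms in trained:
        return False
    trained.add(timestamp_ms)
    Path(path).write_text(json.dumps(sorted(trained)))
    return True


# Example against a temporary file rather than the real cache path.
with tempfile.TemporaryDirectory() as d:
    cache = Path(d) / "trained-responses.json"
    first = mark_trained(cache, 1776319466000)
    second = mark_trained(cache, 1776319466000)
    stored = json.loads(cache.read_text())
```

A caller would send a response to /train only when mark_trained returns True.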

In-memory (training_worker subprocess)

| State | Location | Notes |
|-------|----------|-------|
| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to /tmp/apollo_optimizer_state.pt during checkpoint sync and on shutdown. |
| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. |
| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to /tmp/apollo_training.sock. |

Hyperparameters

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. |
| Rank | 64 | Captures gradient structure. ~2.5GB state. Defined in train_router.DEFAULT_RANK. |
| Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. |
| Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. |
| Batch size | 1 | Single examples, immediate updates. |
| Weight decay | 0 | Diversity provides natural regularization. |
| Warmup | 10% of steps | Standard cosine schedule. |
| Beta1/Beta2 | 0.9/0.999 | Standard Adam momentum. |
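The warmup row can be read as the schedule sketched below: linear warmup over the first 10% of steps, then cosine decay. The linear warmup shape and the decay to zero are assumptions, since the table only says "standard cosine schedule"; lr_at is a hypothetical name.

```python
import math


def lr_at(step, total_steps, peak_lr=1e-4, warmup_frac=0.10):
    """Linear warmup to peak_lr, then cosine decay toward zero."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 99, 100, 550, 999):
    print(step, f"{lr_at(step, total_steps=1000):.2e}")
```

With 1000 steps the rate reaches its peak at step 99, is at half the peak midway through the decay (step 550), and is close to zero by the final step.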

Components

Built ✓

  • optimizer.py — Apollo optimizer (configurable rank)
  • train_router.py — /train endpoint, forwards to training subprocess via ZMQ
  • training_worker.py — training subprocess (HF model, Apollo, checkpoint sync)
  • weight_mapping.py — vLLM merged → HF separate views (validated)
  • export_hook.py — vLLM plugin hook for IPC handle export
  • checkpoint_sync.py — mmap + diff checkpoint sync (Python)

To build

  • Dream loop → training bridge: connect dream output to /train
  • Training-signal agent: flags moments in conversation logs
  • Instruction stripping: remove scaffolding from training examples
  • Quality monitoring: track model capability over time

Files

training/
  DESIGN.md                     — this document
  pyproject.toml                — package config, vLLM plugin entry point
  apollo_plugin/
    __init__.py                 — plugin registration
    export_hook.py              — patches vLLM worker to export IPC handles
    train_router.py             — /train endpoint, forwards to worker via ZMQ
    training_worker.py          — training subprocess (HF model, Apollo, checkpoint)
    optimizer.py                — Apollo optimizer
    weight_mapping.py           — vLLM ↔ HF weight views
    checkpoint_sync.py          — mmap + diff sync to safetensors
    steering.py                 — steering vector extraction (experimental)