training: restructure as vLLM plugin package

- Convert to installable package with entry points for vLLM auto-discovery
- Add checkpoint_sync.py: Python replacement for Rust checkpoint binary
  - Block-level diffing of safetensors files (4KB blocks)
  - vLLM→HF weight name conversion built-in
  - Scheduled 10min after training jobs (batched)
- API change: /train now takes raw token IDs (context_ids + continuation_ids)
  - No tokenizer on the training side; the client owns tokenization (see example payload below)
- Remove superseded code: standalone scripts, Rust binary, tokenizer helpers

Install: pip install -e ./training
Then vLLM auto-loads via entry point.
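
Example /train payload (illustrative token IDs; the client tokenizes and sends raw IDs):

    {"training_data": {
        "samples": [{"context_ids": [11, 12, 13], "continuation_ids": [14, 15]}],
        "config": {"learning_rate": 1e-5, "apollo_rank": 1}}}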

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-15 23:16:53 -04:00
parent b649a11645
commit a73bcf5ae3
15 changed files with 607 additions and 1068 deletions


@@ -0,0 +1,17 @@
"""Apollo training plugin for vLLM.
Enables continuous fine-tuning alongside live inference by:
1. Exporting CUDA IPC handles for weight sharing
2. Providing a training worker daemon (/train endpoint)
3. Block-level checkpoint sync to safetensors files
Install: pip install -e /path/to/training
Then vLLM auto-loads via entry point.
"""
from .export_hook import _patch_model_runner
def register():
"""Called by vLLM's plugin loader on startup."""
_patch_model_runner()
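# For reference, auto-discovery relies on an entry point declared in the
# package's pyproject.toml. A minimal sketch (the group name is assumed to be
# the one vLLM scans for general plugins):
#
#   [project.entry-points."vllm.general_plugins"]
#   apollo = "apollo_plugin:register"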


@@ -0,0 +1,500 @@
"""Sync live GPU weights to safetensors files on disk.
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
merged layout to HuggingFace's separate layout, diffs block-by-block
against on-disk safetensors files, and writes only changed blocks.
For small behavioral training steps, this turns a 54GB checkpoint
write into a few hundred MB of actual disk I/O.
Usage:
# Sync live weights to disk
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
# Debug name mapping issues
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
# From Python:
from checkpoint_sync import checkpoint_sync
result = checkpoint_sync("/path/to/model")
"""
import json
import mmap
import struct
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any
import logging
import torch
logger = logging.getLogger(__name__)
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
# ---------------------------------------------------------------------------
# vLLM → HuggingFace weight name/shape conversion
# ---------------------------------------------------------------------------
# Qwen3.5-27B dimensions (could be read from config.json for generality)
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
INTERMEDIATE = 17408
# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Convert vLLM merged weights to HF-compatible separate tensors.
vLLM merges certain projections for efficiency:
- qkv_proj (full attn)  → q_proj, k_proj, v_proj
- in_proj_qkvz (GDN)    → in_proj_qkv, in_proj_z
- in_proj_ba (GDN)      → in_proj_b, in_proj_a
- gate_up_proj (MLP)    → gate_proj, up_proj
Returns views that share GPU memory with the original tensors.
"""
hf_params = {}
for name, tensor in vllm_params.items():
# Strip vLLM's 'language_model.' prefix to match HF naming
hf_name = name.removeprefix('language_model.')
if 'in_proj_qkvz' in name:
# GDN layer: [key*2 + value*2, hidden] → qkv + z
prefix = hf_name.replace('in_proj_qkvz.weight', '')
split_at = KEY_DIM * 2 + VALUE_DIM
hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
elif 'in_proj_ba' in name:
# GDN layer: [num_v_heads*2, hidden] → b + a
prefix = hf_name.replace('in_proj_ba.weight', '')
hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
elif 'qkv_proj' in name:
# Full attention: [q + k + v, hidden] → separate
prefix = hf_name.replace('qkv_proj.weight', '')
hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
elif 'gate_up_proj' in name:
# MLP: [intermediate*2, hidden] → gate + up
prefix = hf_name.replace('gate_up_proj.weight', '')
hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
else:
# Pass through unchanged
hf_params[hf_name] = tensor
return hf_params
# ---------------------------------------------------------------------------
# Safetensors file handling
# ---------------------------------------------------------------------------
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
"""Map tensor names to safetensors filenames.
For sharded models, reads model.safetensors.index.json.
For single-file models, returns empty dict (default to model.safetensors).
"""
index_path = model_dir / "model.safetensors.index.json"
if not index_path.exists():
return {}
with open(index_path) as f:
index = json.load(f)
return dict(index.get("weight_map", {}))
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
"""Parse safetensors file header.
Returns (data_start_offset, header_dict).
Header dict maps tensor names to metadata including 'data_offsets'.
"""
header_size = struct.unpack('<Q', data[:8])[0]
header = json.loads(bytes(data[8:8 + header_size]))
return 8 + header_size, header
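# For reference, the on-disk layout this parser assumes (per the safetensors
# format; tensor name and numbers below are illustrative):
#
#   bytes 0..8     header size N as little-endian u64
#   bytes 8..8+N   JSON header, e.g.
#                  {"layer.weight": {"dtype": "BF16", "shape": [4096, 5120],
#                                    "data_offsets": [0, 41943040]},
#                   "__metadata__": {"format": "pt"}}
#   bytes 8+N..    raw tensor data; 'data_offsets' are relative to this point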
# ---------------------------------------------------------------------------
# Block-level diffing and sync
# ---------------------------------------------------------------------------
def sync_tensor_to_mmap(
mm: mmap.mmap,
name: str,
tensor: torch.Tensor,
data_start: int,
offsets: List[int],
block_size: int,
) -> Tuple[int, int]:
"""Sync a single tensor to mmap'd file using block-level diffing.
Returns (bytes_compared, bytes_changed).
"""
start = data_start + offsets[0]
end = data_start + offsets[1]
disk_len = end - start
# Transfer tensor to CPU and get raw bytes.
# .detach() avoids autograd overhead, .contiguous() guarantees memory layout,
# and viewing as uint8 handles dtypes numpy can't represent (e.g. bfloat16).
try:
live_bytes = tensor.detach().contiguous().cpu().view(torch.uint8).numpy().tobytes()
except Exception as e:
logger.warning(f"Failed to transfer {name} to CPU: {e}")
return 0, 0
if len(live_bytes) != disk_len:
logger.warning(
f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
)
return 0, 0
# Block-level diff: compare and write only changed blocks
compared = 0
changed = 0
offset = 0
while offset < disk_len:
block_end = min(offset + block_size, disk_len)
block_len = block_end - offset
disk_block = mm[start + offset:start + block_end]
live_block = live_bytes[offset:block_end]
compared += block_len
if disk_block != live_block:
mm[start + offset:start + block_end] = live_block
changed += block_len
offset = block_end
return compared, changed
def sync_file(
file_path: Path,
tensors: Dict[str, torch.Tensor],
block_size: int,
) -> Tuple[int, int, int, int]:
"""Sync tensors to a single safetensors file.
Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
"""
with open(file_path, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0)
try:
data_start, header = parse_safetensors_header(memoryview(mm))
total_compared = 0
total_changed = 0
found = 0
missing = 0
for name, tensor in tensors.items():
if name == "__metadata__":
continue
if name not in header:
missing += 1
continue
found += 1
meta = header[name]
offsets = meta['data_offsets']
compared, changed = sync_tensor_to_mmap(
mm, name, tensor, data_start, offsets, block_size
)
total_compared += compared
total_changed += changed
# Flush changes to disk
if total_changed > 0:
mm.flush()
return total_compared, total_changed, found, missing
finally:
mm.close()
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
"""Load vLLM weight tensors from CUDA IPC handles.
The handles file is written by export_hook.py on vLLM startup.
Each handle can be used to reconstruct a tensor pointing directly to
vLLM's GPU memory: no copy is made.
"""
handles = torch.load(handles_path, weights_only=False)
weights = {}
for name, info in handles.items():
func, args = info['handle']
try:
weights[name] = func(*args)
except Exception as e:
logger.warning(f"Failed to reconstruct {name}: {e}")
return weights
def checkpoint_sync(
model_dir: str,
handles_path: str = DEFAULT_HANDLES_PATH,
block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
"""Sync live GPU weights to model safetensors files.
This is the main entry point. Call this after training steps
or periodically to checkpoint weights without full serialization.
Args:
model_dir: Directory containing safetensors files
handles_path: Path to vLLM weight IPC handles file
block_size: Block size for diffing (default 4KB)
Returns:
Dict with sync statistics:
- total_compared: bytes compared
- total_changed: bytes actually written
- files_changed: list of modified filenames
- tensors_synced: number of tensors processed
- tensors_missing: tensors not found in safetensors
"""
model_dir = Path(model_dir)
if not Path(handles_path).exists():
raise FileNotFoundError(
f"Weight handles not found: {handles_path}. "
"Is vLLM running with the export hook?"
)
# Step 1: Load live weights from GPU via IPC
logger.info("Loading live weights from GPU...")
vllm_weights = load_vllm_weights(handles_path)
logger.info(f" Loaded {len(vllm_weights)} vLLM tensors")
# Step 2: Convert to HF naming/layout
hf_weights = vllm_to_hf_tensors(vllm_weights)
logger.info(f" Converted to {len(hf_weights)} HF tensors")
# Step 3: Map tensors to safetensors files
weight_map = read_safetensors_index(model_dir)
by_file: Dict[str, Dict[str, torch.Tensor]] = {}
unmapped = []
for name, tensor in hf_weights.items():
filename = weight_map.get(name)
if filename is None:
# Single-file model or missing from index
if (model_dir / "model.safetensors").exists():
filename = "model.safetensors"
else:
unmapped.append(name)
continue
by_file.setdefault(filename, {})[name] = tensor
if unmapped:
logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...")
# Step 4: Sync each file
total_compared = 0
total_changed = 0
total_found = 0
total_missing = 0
files_changed = []
for filename in sorted(by_file.keys()):
tensors = by_file[filename]
file_path = model_dir / filename
if not file_path.exists():
logger.warning(f" File not found: {filename}")
total_missing += len(tensors)
continue
compared, changed, found, missing = sync_file(file_path, tensors, block_size)
total_compared += compared
total_changed += changed
total_found += found
total_missing += missing
if changed > 0:
files_changed.append(filename)
logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")
# Summary
if total_changed == 0:
logger.info("No changes - model files are up to date")
else:
pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
logger.info(
f"Synced: {total_changed / 1e6:.2f} MB changed / "
f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
)
if total_missing > 0:
logger.warning(f" {total_missing} tensors not found in safetensors files")
return {
"total_compared": total_compared,
"total_changed": total_changed,
"files_changed": files_changed,
"tensors_synced": total_found,
"tensors_missing": total_missing,
}
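# A typical result for a small behavioral run (numbers purely illustrative,
# in line with the module docstring's 54 GB checkpoint example):
#
#   {"total_compared": 54_000_000_000, "total_changed": 310_000_000,
#    "files_changed": ["model-00007-of-00012.safetensors"],
#    "tensors_synced": 1100, "tensors_missing": 0}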
# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
"""Print diagnostic info about weight name mappings.
Useful for debugging mismatches between vLLM and safetensors names.
"""
model_dir = Path(model_dir)
# Load and convert vLLM weights
vllm_weights = load_vllm_weights(handles_path)
hf_weights = vllm_to_hf_tensors(vllm_weights)
hf_names = set(hf_weights.keys())
# Read safetensors index
weight_map = read_safetensors_index(model_dir)
disk_names = set(weight_map.keys())
# If single-file model, parse that file's header
if not disk_names:
st_path = model_dir / "model.safetensors"
if st_path.exists():
with open(st_path, 'rb') as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
_, header = parse_safetensors_header(memoryview(mm))
disk_names = {k for k in header.keys() if k != "__metadata__"}
mm.close()
print(f"vLLM tensors (raw): {len(vllm_weights)}")
print(f"HF tensors (converted): {len(hf_names)}")
print(f"Disk tensors: {len(disk_names)}")
print()
in_both = hf_names & disk_names
only_hf = hf_names - disk_names
only_disk = disk_names - hf_names
print(f"Matched: {len(in_both)}")
print(f"Only in HF (won't sync): {len(only_hf)}")
print(f"Only on disk (not updated): {len(only_disk)}")
if only_hf:
print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
if only_disk:
print(f"\nSample disk-only: {sorted(only_disk)[:5]}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser(
description="Sync live GPU weights to safetensors files"
)
subparsers = parser.add_subparsers(dest="command", help="Command")
# sync command
sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
sync_parser.add_argument(
"--model-dir", required=True,
help="Directory containing safetensors files"
)
sync_parser.add_argument(
"--handles", default=DEFAULT_HANDLES_PATH,
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
)
sync_parser.add_argument(
"--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
)
sync_parser.add_argument(
"-v", "--verbose", action="store_true",
help="Verbose output"
)
# diagnose command
diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
diag_parser.add_argument(
"--model-dir", required=True,
help="Directory containing safetensors files"
)
diag_parser.add_argument(
"--handles", default=DEFAULT_HANDLES_PATH,
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
)
args = parser.parse_args()
if args.command is None:
parser.print_help()
sys.exit(1)
logging.basicConfig(
level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
format='%(message)s'
)
try:
if args.command == "sync":
result = checkpoint_sync(args.model_dir, args.handles, args.block_size)
print(json.dumps(result, indent=2))
elif args.command == "diagnose":
diagnose(args.model_dir, args.handles)
except FileNotFoundError as e:
logger.error(str(e))
sys.exit(1)
except Exception as e:
logger.exception(f"Failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,67 @@
"""Monkey-patch vLLM to export weight IPC handles on startup.
Usage: install the apollo_plugin package:
pip install -e /path/to/training
Then vLLM auto-discovers and loads via entry point. Or filter:
VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
The hook patches vLLM's model runner to export IPC handles after
model loading completes. The handles are saved to a file that the
Apollo training process reads.
"""
import atexit
import torch
from pathlib import Path
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
def export_model_weights(model):
"""Export CUDA IPC handles for all model parameters."""
from torch.multiprocessing.reductions import reduce_tensor
handles = {}
total_bytes = 0
for name, param in model.named_parameters():
if param.device.type != 'cuda':
continue
handle = reduce_tensor(param.data)
handles[name] = {
'handle': handle,
'shape': list(param.shape),
'dtype': str(param.dtype),
}
total_bytes += param.nelement() * param.element_size()
torch.save(handles, HANDLE_PATH)
print(f"[apollo] Exported {len(handles)} weight handles "
f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
def _patch_model_runner():
"""Patch gpu_worker to export handles after model loading.
vLLM loads the model in a subprocess (EngineCore_DP0), so we
can't patch from the parent. Instead, patch Worker.load_model at the
module level, since the subprocess imports the same modules.
"""
from vllm.v1.worker import gpu_worker
original_load = gpu_worker.Worker.load_model
def patched_load(self, *args, **kwargs):
result = original_load(self, *args, **kwargs)
try:
export_model_weights(self.model_runner.model)
except Exception as e:
print(f"[apollo] Failed to export weights: {e}")
return result
gpu_worker.Worker.load_model = patched_load
print("[apollo] Weight export hook installed")


@@ -0,0 +1,229 @@
"""Apollo optimizer — configurable-rank gradient scaling.
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
Performance" (arXiv:2412.05270, MLSys 2025).
The core idea: AdamW's per-element learning rate scaling is redundant.
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch.
Key implementation details from the paper:
- Gradient scale factor α = (n/r) compensates for projection ratio
- Norm-growth limiter (γ=1.01) prevents early training instability
- Projection matrix refreshed every T steps (default 200), not every step
- Channel-wise scaling for rank>1, tensor-wise for rank=1
"""
import math
import torch
from torch.optim import Optimizer
class Apollo(Optimizer):
"""Apollo: configurable-rank gradient scaling optimizer.
rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
rank>1 is full Apollo (channel-wise scaling).
Args:
params: model parameters
lr: learning rate (default: 1e-4)
rank: projection rank (default: 256)
betas: Adam momentum coefficients (default: (0.9, 0.999))
eps: numerical stability term (default: 1e-8)
weight_decay: decoupled weight decay (default: 0.01)
warmup_steps: linear lr warmup steps (default: 0)
scale: gradient scale factor α. Default None = auto, √(n/r).
Paper uses 128 for Apollo-Mini.
proj_refresh: refresh projection matrix every T steps (default: 200)
norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
Set to None to disable.
"""
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
eps=1e-8, weight_decay=0.01, warmup_steps=0,
scale=None, proj_refresh=200, norm_growth_limit=1.01):
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
weight_decay=weight_decay,
warmup_steps=warmup_steps,
scale=scale,
proj_refresh=proj_refresh,
norm_growth_limit=norm_growth_limit)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group['lr']
beta1, beta2 = group['betas']
eps = group['eps']
weight_decay = group['weight_decay']
rank = group['rank']
proj_refresh = group['proj_refresh']
norm_growth_limit = group['norm_growth_limit']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.float()
state = self.state[p]
# Initialize state
if len(state) == 0:
state['step'] = 0
state['seed'] = id(p) % (2**31)
if grad.ndim >= 2 and min(grad.shape) >= rank:
# Determine projection dimension (project along smaller dim)
if grad.shape[0] <= grad.shape[1]:
state['proj_dim'] = 'left' # P: [r, m], R = P @ G → [r, n]
state['m'] = grad.shape[0]
state['n'] = grad.shape[1]
moment_shape = (rank, grad.shape[1])
else:
state['proj_dim'] = 'right' # P: [r, n], R = G @ P^T → [m, r]
state['m'] = grad.shape[0]
state['n'] = grad.shape[1]
moment_shape = (grad.shape[0], rank)
state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
state['has_proj'] = True
state['prev_scaled_norm'] = None
# Auto scale factor: α = √(smaller_dim / rank)
smaller_dim = min(grad.shape)
if group['scale'] is not None:
state['alpha'] = group['scale']
else:
state['alpha'] = math.sqrt(smaller_dim / rank)
else:
# 1D or small params: standard Adam
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)
state['has_proj'] = False
state['step'] += 1
step = state['step']
# Learning rate warmup
if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
lr_scale = step / group['warmup_steps']
else:
lr_scale = 1.0
if state['has_proj']:
alpha = state['alpha']
# Generate projection matrix (refresh every proj_refresh steps)
if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
gen = torch.Generator(device=p.device)
gen.manual_seed(state['seed'] + step)
if state['proj_dim'] == 'left':
# P: [rank, m], normalized rows
P = torch.randn(rank, state['m'],
device=p.device, generator=gen)
P = P / (P.norm(dim=1, keepdim=True) + eps)
state['proj_matrix'] = P
else:
# P: [rank, n], normalized rows
P = torch.randn(rank, state['n'],
device=p.device, generator=gen)
P = P / (P.norm(dim=1, keepdim=True) + eps)
state['proj_matrix'] = P
P = state['proj_matrix']
# Project gradient to low-rank space
if state['proj_dim'] == 'left':
proj_grad = P @ grad # [rank, n]
else:
proj_grad = grad @ P.t() # [m, rank]
# Update moments in projected space
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
state['exp_avg_sq'].mul_(beta2).addcmul_(
proj_grad, proj_grad, value=1 - beta2)
# Bias correction
bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2
# Adam update in projected space
adam_update = m_hat / (v_hat.sqrt() + eps)
# Compute scaling factor
if rank == 1:
# Tensor-wise: single scalar (Apollo-Mini)
scaling = adam_update.norm() / (proj_grad.norm() + eps)
scaled_grad = grad * (alpha * scaling)
else:
# Channel-wise: one factor per channel
if state['proj_dim'] == 'left':
# Channels are columns: scale along dim 1
s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(0))
else:
# Channels are rows: one scale factor per row (dim 0)
s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
scaled_grad = grad * (alpha * s.unsqueeze(1))
# Norm-growth limiter (equation 4)
if norm_growth_limit is not None:
current_norm = scaled_grad.norm()
if state['prev_scaled_norm'] is not None:
prev_norm = state['prev_scaled_norm']
if current_norm > norm_growth_limit * prev_norm:
scaled_grad = scaled_grad * (
norm_growth_limit * prev_norm / (current_norm + eps))
state['prev_scaled_norm'] = scaled_grad.norm().item()
# Apply update
step_size = lr * lr_scale
p.add_(scaled_grad.to(p.dtype), alpha=-step_size)
else:
# Standard Adam for 1D / small params
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
state['exp_avg_sq'].mul_(beta2).addcmul_(
grad, grad, value=1 - beta2)
bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step
m_hat = state['exp_avg'] / bc1
v_hat = state['exp_avg_sq'] / bc2
update = m_hat / (v_hat.sqrt() + eps)
step_size = lr * lr_scale
p.add_(update.to(p.dtype), alpha=-step_size)
# Decoupled weight decay
if weight_decay > 0:
p.add_(p, alpha=-lr * lr_scale * weight_decay)
return loss
def state_size_bytes(self):
"""Total optimizer state memory in bytes."""
total = 0
for state in self.state.values():
if isinstance(state, dict):
for v in state.values():
if isinstance(v, torch.Tensor):
total += v.nelement() * v.element_size()
return total
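# ---------------------------------------------------------------------------
# Smoke test
# ---------------------------------------------------------------------------
# A minimal CPU sketch of the optimizer API; not part of the training
# pipeline. Layers smaller than `rank` in any dimension fall back to the
# standard Adam path, 2D layers use projected channel-wise scaling.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 512),
    )
    opt = Apollo(model.parameters(), lr=1e-3, rank=64, proj_refresh=50)
    x = torch.randn(8, 512)
    target = torch.randn(8, 512)
    for step in range(3):
        loss = torch.nn.functional.mse_loss(model(x), target)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(f"step {step}: loss={loss.item():.4f} "
              f"state={opt.state_size_bytes() / 1e6:.2f} MB")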


@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""Extract a steering vector for "listening" behavior.
Compares hidden states between conversations where the model
listens vs suggests alternatives. The difference is the
"listening direction" in activation space.
Usage:
source ~/training-env/bin/activate
python3 extract_steering_vector.py
"""
import sys
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
def load_model():
handles = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
hf_params = vllm_to_hf_views(vllm_params)
config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
with torch.device('meta'):
model = Qwen3_5ForCausalLM(config.text_config)
for name, param in list(model.named_parameters()):
if name in hf_params:
parts = name.split('.')
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1],
nn.Parameter(hf_params[name], requires_grad=False))
model.eval()
return model
def get_hidden_states(model, tokenizer, texts, layer):
states = []
for text in texts:
ids = tokenizer.encode(text, return_tensors='pt').to('cuda:0')
with torch.no_grad():
out = model(ids, output_hidden_states=True)
h = out.hidden_states[layer][0, -1, :].float()
states.append(h)
return torch.stack(states)
def main():
print("=== Steering Vector Extraction: Listening ===\n")
print("Loading model with IPC weights...")
model = load_model()
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3.5-27B", trust_remote_code=True)
# Paired prompts
listening = [
"User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
"User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
"User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
"User: Just steal their code.\nAssistant: Makes sense. Where is it?",
"User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
]
suggesting = [
"User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
"User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
"User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
"User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
"User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
]
# Extract at multiple layers to find where the signal is strongest
for layer in [16, 24, 32, 40, 48]:
print(f"\nLayer {layer}:")
listen_states = get_hidden_states(model, tokenizer, listening, layer)
suggest_states = get_hidden_states(model, tokenizer, suggesting, layer)
steering_vec = listen_states.mean(dim=0) - suggest_states.mean(dim=0)
magnitude = steering_vec.norm().item()
# Check consistency: do individual pairs agree on the direction?
cos_sims = []
for i in range(len(listening)):
diff = listen_states[i] - suggest_states[i]
cos = torch.nn.functional.cosine_similarity(
diff.unsqueeze(0), steering_vec.unsqueeze(0)).item()
cos_sims.append(cos)
avg_cos = sum(cos_sims) / len(cos_sims)
min_cos = min(cos_sims)
print(f" Magnitude: {magnitude:.2f}")
print(f" Pair agreement (avg cosine): {avg_cos:.4f}")
print(f" Pair agreement (min cosine): {min_cos:.4f}")
print(f" Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")
if layer == 32:
torch.save({
'steering_vec': steering_vec,
'layer': layer,
'magnitude': magnitude,
'consistency': avg_cos,
}, '/tmp/listening_steering_vec.pt')
print(" → Saved to /tmp/listening_steering_vec.pt")
print("\n=== DONE ===")
print("\nInterpretation:")
print("- High magnitude = strong signal (listening vs suggesting is distinct)")
print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
print("- Best layer = highest magnitude × consistency")
if __name__ == '__main__':
main()
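# Applying the vector (illustrative sketch, not something this script does):
# add the saved direction into the residual stream at the chosen layer during
# generation, e.g. via a forward hook. The module path and strength below are
# assumptions to tune, not defined by this commit.
#
#   data = torch.load('/tmp/listening_steering_vec.pt')
#   vec, layer = data['steering_vec'], data['layer']
#   def steer(module, inputs, output):
#       hidden = output[0] if isinstance(output, tuple) else output
#       hidden += 4.0 * vec.to(hidden.dtype)
#       return output
#   model.model.layers[layer].register_forward_hook(steer)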


@@ -0,0 +1,163 @@
"""Map between vLLM's merged weight layout and HuggingFace's separate layout.
vLLM merges weights for efficiency:
in_proj_qkv + in_proj_z  → in_proj_qkvz  [key_dim*2 + value_dim*2, hidden]
in_proj_b + in_proj_a    → in_proj_ba    [num_v_heads*2, hidden]
gate_proj + up_proj      → gate_up_proj  [intermediate*2, hidden]
This module creates HF-compatible parameter views that point to the same
GPU memory as vLLM's merged tensors. No copies — views share storage.
"""
import torch
import torch.nn as nn
# Qwen3.5-27B dimensions
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
NUM_ATTN_HEADS = 24 # full attention q heads
NUM_ATTN_KV_HEADS = 4 # full attention kv heads
ATTN_HEAD_DIM = 256
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
INTERMEDIATE = 17408
NUM_LAYERS = 64
CONV_KERNEL = 4
CONV_DIM = KEY_DIM * 2 + VALUE_DIM # 10240
# Full attention QKV dimensions
# Q uses 2x head_dim (512) vs KV head_dim (256) in Qwen3.5
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight[0]
def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""Create HF-compatible parameter views from vLLM merged weights.
Returns a dict mapping HF-style parameter names to tensor views.
The views share GPU memory with the vLLM tensors; no copies are made.
"""
hf_params = {}
for name, tensor in vllm_params.items():
# vLLM uses 'language_model.model.layers...' but HF's text model
# uses 'model.layers...'. Strip the 'language_model.' prefix.
hf_name = name.removeprefix('language_model.')
# Split merged projections into HF-style separate weights
if 'in_proj_qkvz' in name:
# GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
prefix = hf_name.replace('in_proj_qkvz.weight', '')
qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]
z = tensor[KEY_DIM * 2 + VALUE_DIM:]
hf_params[prefix + 'in_proj_qkv.weight'] = qkv
hf_params[prefix + 'in_proj_z.weight'] = z
elif 'in_proj_ba' in name:
# GDN: [num_v_heads*2, hidden] → b + a
prefix = hf_name.replace('in_proj_ba.weight', '')
b = tensor[:NUM_V_HEADS]
a = tensor[NUM_V_HEADS:]
hf_params[prefix + 'in_proj_b.weight'] = b
hf_params[prefix + 'in_proj_a.weight'] = a
elif 'qkv_proj' in name:
# Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
prefix = hf_name.replace('qkv_proj.weight', '')
q = tensor[:ATTN_Q_DIM]
k = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
v = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
hf_params[prefix + 'q_proj.weight'] = q
hf_params[prefix + 'k_proj.weight'] = k
hf_params[prefix + 'v_proj.weight'] = v
elif 'gate_up_proj' in name:
# MLP: [intermediate*2, hidden] → gate + up
prefix = hf_name.replace('gate_up_proj.weight', '')
gate = tensor[:INTERMEDIATE]
up = tensor[INTERMEDIATE:]
hf_params[prefix + 'gate_proj.weight'] = gate
hf_params[prefix + 'up_proj.weight'] = up
else:
# Pass through unchanged (norms, biases, out_proj, etc.)
hf_params[hf_name] = tensor
return hf_params
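# Example of the mapping for one layer (illustrative names; shapes follow the
# constants above, dims are [out, in]):
#
#   'language_model.model.layers.0.mlp.gate_up_proj.weight'  [34816, 5120]
#     → 'model.layers.0.mlp.gate_proj.weight'                [17408, 5120]
#     → 'model.layers.0.mlp.up_proj.weight'                  [17408, 5120]
#
# Both views alias the same GPU storage as the merged tensor.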
def load_hf_model_with_vllm_weights(
vllm_params: dict[str, torch.Tensor],
model_path: str,
device: str = "cuda:0",
) -> nn.Module:
"""Load HF Qwen3.5 model with weights pointing to vLLM's GPU memory.
1. Creates HF-compatible views from vLLM's merged weights
2. Instantiates the HF model with empty weights
3. Replaces model parameters with the views
4. Returns model ready for forward+backward (autograd enabled)
"""
from transformers import AutoModelForCausalLM, AutoConfig
# Create HF-compatible views
hf_params = vllm_to_hf_views(vllm_params)
# Load config
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
# Create model with empty weights (no disk I/O)
with torch.device('meta'):
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True)
# Replace parameters with views into vLLM memory
replaced = 0
missing = []
for name, param in model.named_parameters():
if name in hf_params:
# Replace with view (shared GPU memory)
parts = name.rsplit('.', 1)
parent = model
for part in parts[0].split('.'):
parent = getattr(parent, part)
setattr(parent, parts[1],
nn.Parameter(hf_params[name], requires_grad=True))
replaced += 1
else:
missing.append(name)
print(f"Replaced {replaced} parameters with vLLM memory views")
if missing:
print(f"Missing {len(missing)} parameters: {missing[:5]}...")
model.train()
return model
def validate_views(vllm_params: dict[str, torch.Tensor],
hf_params: dict[str, torch.Tensor]):
"""Verify that HF views share storage with vLLM tensors."""
for vllm_name, vllm_tensor in vllm_params.items():
if 'in_proj_qkvz' in vllm_name:
prefix = vllm_name.replace('in_proj_qkvz.weight', '')
qkv_name = prefix + 'in_proj_qkv.weight'
z_name = prefix + 'in_proj_z.weight'
if qkv_name in hf_params:
assert hf_params[qkv_name].storage().data_ptr() == \
vllm_tensor.storage().data_ptr(), \
f"{qkv_name} doesn't share storage!"
if z_name in hf_params:
assert hf_params[z_name].storage().data_ptr() == \
vllm_tensor.storage().data_ptr(), \
f"{z_name} doesn't share storage!"
print("All views validated — shared storage confirmed")

training/apollo_plugin/worker.py Executable file

@@ -0,0 +1,498 @@
#!/usr/bin/env python3
"""
Apollo Mini Training Daemon
This daemon:
1. Listens over HTTPS for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference
Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
"""
import asyncio
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum
import torch
import torch.nn as nn
from aiohttp import web
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')
class TrainingStatus(Enum):
PENDING = "pending"
PAUSING_VLLM = "pausing_vllm"
TRAINING = "training"
SAVING_CHECKPOINT = "saving_checkpoint"
RESUMING_VLLM = "resuming_vllm"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class TrainingJob:
job_id: str
status: TrainingStatus
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
model_path: Optional[str] = None
checkpoint_path: Optional[str] = None
training_samples: int = 0
loss_history: List[float] = field(default_factory=list)
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
'job_id': self.job_id,
'status': self.status.value,
'created_at': self.created_at.isoformat(),
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'model_path': self.model_path,
'checkpoint_path': self.checkpoint_path,
'training_samples': self.training_samples,
'loss_history': self.loss_history,
'error': self.error,
}
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
class ApolloWorker:
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
self.config = self._load_config(config_path)
self.jobs: Dict[str, TrainingJob] = {}
self.vllm_paused = False
self.app = web.Application()
self._setup_routes()
self._checkpoint_timer: Optional[asyncio.Task] = None
def _load_config(self, config_path: str) -> Dict[str, Any]:
"""Load configuration from file or use defaults."""
default_config = {
'host': '0.0.0.0',
'port': 8080,
'vllm_socket': '/tmp/vllm_control.sock',
'model_path': '/home/ubuntu/models/Qwen3.5-27B',
'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
'max_training_samples': 100,
'learning_rate': 1e-5,
'batch_size': 1,
}
if os.path.exists(config_path):
with open(config_path, 'r') as f:
user_config = json.load(f)
default_config.update(user_config)
Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
return default_config
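# A minimal config.json only needs the fields that differ from the defaults
# above (paths here are illustrative):
#
#   {
#     "model_path": "/home/ubuntu/models/Qwen3.5-27B",
#     "vllm_url": "http://localhost:8000",
#     "checkpoint_dir": "/var/lib/apollo/checkpoints"
#   }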
def _setup_routes(self):
"""Setup HTTP routes."""
self.app.router.add_post('/train', self.handle_train_request)
self.app.router.add_get('/status/{job_id}', self.handle_status_request)
self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
self.app.router.add_get('/health', self.handle_health_check)
async def handle_health_check(self, request: web.Request) -> web.Response:
"""Health check endpoint."""
return web.json_response({
'status': 'healthy',
'vllm_paused': self.vllm_paused,
'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
})
async def handle_train_request(self, request: web.Request) -> web.Response:
"""Handle training request from poc-agent."""
try:
data = await request.json()
# Validate required fields
if 'training_data' not in data:
return web.json_response(
{'error': 'Missing training_data field'},
status=400
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
job = TrainingJob(
job_id=job_id,
status=TrainingStatus.PENDING,
created_at=datetime.now(),
model_path=self.config['model_path']
)
self.jobs[job_id] = job
# Start training in background with the nested training_data payload
asyncio.create_task(self.execute_training(job, data['training_data']))
return web.json_response({
'job_id': job_id,
'status': 'accepted',
'message': 'Training job started'
})
except Exception as e:
logger.error(f"Error handling train request: {e}")
return web.json_response(
{'error': str(e)},
status=500
)
async def handle_status_request(self, request: web.Request) -> web.Response:
"""Get training job status."""
job_id = request.match_info['job_id']
if job_id not in self.jobs:
return web.json_response(
{'error': 'Job not found'},
status=404
)
job = self.jobs[job_id]
return web.json_response(job.to_dict())
async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
"""List available checkpoints."""
checkpoint_dir = Path(self.config['checkpoint_dir'])
checkpoints = []
if checkpoint_dir.exists():
for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
checkpoints.append({
'filename': checkpoint_file.name,
'path': str(checkpoint_file),
'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
'size': checkpoint_file.stat().st_size
})
return web.json_response({'checkpoints': checkpoints})
async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
"""Execute the training pipeline."""
try:
logger.info(f"Starting training job {job.job_id}")
job.started_at = datetime.now()
# Step 1: Pause vLLM
job.status = TrainingStatus.PAUSING_VLLM
logger.info("Pausing vLLM...")
await self.pause_vllm()
self.vllm_paused = True
# Step 2: Load model and prepare for training
job.status = TrainingStatus.TRAINING
logger.info("Loading model and preparing for training...")
# Load HF model whose parameters are CUDA IPC views into vLLM's GPU memory
model = await self.load_model_for_training()
# Step 3: Run APOLLO-Mini training
logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
# Extract training samples
samples = training_data['samples']
job.training_samples = len(samples)
# Run training loop
loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
job.loss_history = loss_history
# Step 4: Save checkpoint
job.status = TrainingStatus.SAVING_CHECKPOINT
logger.info("Saving checkpoint...")
checkpoint_path = await self.save_checkpoint(model, job)
job.checkpoint_path = checkpoint_path
# Step 5: Resume vLLM
job.status = TrainingStatus.RESUMING_VLLM
logger.info("Resuming vLLM...")
await self.resume_vllm()
self.vllm_paused = False
# Mark job as completed
job.status = TrainingStatus.COMPLETED
job.completed_at = datetime.now()
logger.info(f"Training job {job.job_id} completed successfully")
# Schedule checkpoint sync (batched — won't duplicate if timer pending)
self.schedule_checkpoint_sync()
except Exception as e:
logger.error(f"Training job {job.job_id} failed: {e}")
job.status = TrainingStatus.FAILED
job.error = str(e)
job.completed_at = datetime.now()
# Try to resume vLLM if it was paused
if self.vllm_paused:
try:
await self.resume_vllm()
self.vllm_paused = False
except Exception as resume_error:
logger.error(f"Failed to resume vLLM after training error: {resume_error}")
async def pause_vllm(self):
"""Pause vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/pause_generation",
json={"mode": "keep", "clear_cache": False},
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM paused")
except Exception as e:
logger.warning(f"Failed to pause vLLM: {e}")
async def resume_vllm(self):
"""Resume vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/resume_generation",
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM resumed")
except Exception as e:
logger.warning(f"Failed to resume vLLM: {e}")
def schedule_checkpoint_sync(self):
"""Schedule a checkpoint sync in 10 minutes, if not already scheduled.
This batches multiple training runs into a single sync; a new timer
is started only when none is already pending.
"""
if self._checkpoint_timer is not None:
logger.debug("Checkpoint sync already scheduled, skipping")
return
self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")
async def _checkpoint_sync_after_delay(self):
"""Wait then sync — the actual timer task."""
try:
await asyncio.sleep(CHECKPOINT_DELAY_SECS)
await self._do_checkpoint_sync()
except asyncio.CancelledError:
logger.debug("Checkpoint sync cancelled")
finally:
self._checkpoint_timer = None
async def _do_checkpoint_sync(self):
"""Execute the checkpoint sync."""
try:
from apollo_plugin.checkpoint_sync import checkpoint_sync
logger.info("Starting checkpoint sync...")
result = checkpoint_sync(
self.config['model_path'],
self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
)
changed_mb = result['total_changed'] / 1e6
logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
except Exception as e:
logger.error(f"Checkpoint sync failed: {e}")
async def load_model_for_training(self) -> nn.Module:
"""Load HF model with weights pointing to vLLM's GPU memory.
Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
views (narrowing merged weights into separate q/k/v/z etc.), and
constructs the HF model around those views. No weight copying:
all parameters share vLLM's GPU memory.
"""
handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
model_path = self.config['model_path']
# Import vLLM weights via CUDA IPC
logger.info(f"Importing vLLM weights from {handle_path}")
handles = torch.load(handle_path, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
logger.info(f"Imported {len(vllm_params)} parameters")
# Map vLLM merged layout → HF separate layout (views, no copies)
from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
logger.info("HF model constructed with vLLM weight views")
return model
async def run_apollo_training(self, model: nn.Module,
samples: List[Dict[str, Any]],
config: Dict[str, Any]) -> List[float]:
"""Run Apollo-Mini training on conversation decision points.
Each sample has:
context_ids: token IDs for frozen context (no gradients)
continuation_ids: token IDs for the decision we're training on
"""
from apollo_plugin.optimizer import Apollo
lr = config.get('learning_rate', self.config['learning_rate'])
# Build parameter groups (Apollo for 2D+, standard for small/1D)
apollo_params, standard_params = [], []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= 2:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
rank = config.get('apollo_rank', 1)
optimizer = Apollo(groups, lr=lr, rank=rank)
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
loss_history = []
for i, sample in enumerate(samples):
# context_ids: frozen (forward only, no gradients)
# continuation_ids: the decision we're training on
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
# Forward through context (no gradients)
outputs = model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits # [1, cont_len, vocab]
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
logger.info(f"Training done: {len(samples)} examples, "
f"final loss={loss_history[-1]:.4f}")
return loss_history
async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
"""Save model checkpoint in HuggingFace safetensors format."""
from safetensors.torch import save_file
import shutil
checkpoint_dir = Path(self.config['checkpoint_dir'])
date_str = datetime.now().strftime('%Y-%m-%d')
out_dir = checkpoint_dir / date_str
out_dir.mkdir(parents=True, exist_ok=True)
# Save weights
tensors = {name: p.data.contiguous().cpu()
for name, p in model.named_parameters()}
save_path = out_dir / "model.safetensors"
save_file(tensors, str(save_path))
# Copy config files
config_dir = Path(self.config['model_path'])
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
'special_tokens_map.json']:
src = config_dir / f
if src.exists():
shutil.copy2(src, out_dir / f)
# Save training metadata
meta = {
'job_id': job.job_id,
'training_samples': job.training_samples,
'loss_history': job.loss_history,
'timestamp': datetime.now().isoformat(),
}
with open(out_dir / 'training-meta.json', 'w') as f:
json.dump(meta, f, indent=2)
# Update latest symlink
latest = checkpoint_dir / 'latest'
if latest.is_symlink():
latest.unlink()
latest.symlink_to(date_str)
size_gb = save_path.stat().st_size / 1e9
logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
return str(out_dir)
async def run(self):
"""Run the daemon."""
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, self.config['host'], self.config['port'])
await site.start()
logger.info("Apollo Worker is running")
# Keep running
while True:
await asyncio.sleep(3600) # Sleep for an hour
def main():
worker = ApolloWorker()
asyncio.run(worker.run())
if __name__ == '__main__':
main()