consciousness/training/apollo_plugin/checkpoint_sync.py
ProofOfConcept 2c6a5c0f4a training: move to dedicated subprocess with ZMQ communication
- Add training_worker.py: long-lived subprocess that handles GPU training
  work, owns HF model wrapper (views into vLLM GPU memory), Apollo
  optimizer, and checkpoint sync

- train_router.py: now forwards /train requests via async ZMQ instead of
  running training in-process. Adds /checkpoint and /train/status endpoints

- export_hook.py: store model_path in __metadata__ so training worker can
  find it without cross-process communication

- This fixes two bugs:
  1. Process boundary issue - model_path was set in worker process but
     needed in API server process
  2. Blocking event loop - training blocked vLLM's async event loop

Architecture: vLLM API server <-> ZMQ <-> training subprocess
The subprocess loads IPC handles once, creates views into vLLM's GPU
memory, and handles training requests without blocking inference.
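
A minimal sketch of the forwarding pattern (socket path and message
shape here are illustrative, not the actual protocol in train_router.py):

    import zmq
    import zmq.asyncio

    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect("ipc:///tmp/training_worker.sock")  # hypothetical endpoint

    async def forward_train(payload: dict) -> dict:
        # await instead of blocking, so vLLM's event loop keeps serving
        await sock.send_json({"op": "train", **payload})
        return await sock.recv_json()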

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-16 02:04:26 -04:00

"""Sync live GPU weights to safetensors files on disk.
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
merged layout to HuggingFace's separate layout, diffs block-by-block
against on-disk safetensors files, and writes only changed blocks.
For small behavioral training steps, this turns a 54GB checkpoint
write into a few hundred MB of actual disk I/O.
Usage:
# Sync live weights to disk
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
# Debug name mapping issues
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
# From Python:
from checkpoint_sync import checkpoint_sync
result = checkpoint_sync("/path/to/model")
"""
import json
import mmap
import struct
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any
import logging
import torch
logger = logging.getLogger(__name__)
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
# ---------------------------------------------------------------------------
# vLLM → HuggingFace weight name/shape conversion
# ---------------------------------------------------------------------------
# Qwen3.5-27B dimensions (could be read from config.json for generality)
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
INTERMEDIATE = 17408
# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
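# Worked example of the qkvz split below, using the dims above:
# in_proj_qkvz stacks 2*KEY_DIM + 2*VALUE_DIM = 4096 + 12288 = 16384 rows;
# split_at = 2*2048 + 6144 = 10240, so rows [0, 10240) are the qkv block
# and rows [10240, 16384) are z.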
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
                       ) -> Dict[str, torch.Tensor]:
    """Convert vLLM merged weights to HF-compatible separate tensors.

    vLLM merges certain projections for efficiency:
    - qkv_proj (full attn) → q_proj, k_proj, v_proj
    - in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z
    - in_proj_ba (GDN) → in_proj_b, in_proj_a
    - gate_up_proj (MLP) → gate_proj, up_proj

    Returns views that share GPU memory with the original tensors.
    """
    hf_params = {}
    for name, tensor in vllm_params.items():
        # Strip vLLM's 'language_model.' prefix to match HF naming
        hf_name = name.removeprefix('language_model.')
        if 'in_proj_qkvz' in name:
            # GDN layer: [key*2 + value*2, hidden] → qkv + z
            prefix = hf_name.replace('in_proj_qkvz.weight', '')
            split_at = KEY_DIM * 2 + VALUE_DIM
            hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
            hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
        elif 'in_proj_ba' in name:
            # GDN layer: [num_v_heads*2, hidden] → b + a
            prefix = hf_name.replace('in_proj_ba.weight', '')
            hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
            hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
        elif 'qkv_proj' in name:
            # Full attention: [q + k + v, hidden] → separate
            prefix = hf_name.replace('qkv_proj.weight', '')
            hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
            hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
            hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
        elif 'gate_up_proj' in name:
            # MLP: [intermediate*2, hidden] → gate + up
            prefix = hf_name.replace('gate_up_proj.weight', '')
            hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
            hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
        else:
            # Pass through unchanged
            hf_params[hf_name] = tensor
    return hf_params

# ---------------------------------------------------------------------------
# Safetensors file handling
# ---------------------------------------------------------------------------
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
    """Map tensor names to safetensors filenames.

    For sharded models, reads model.safetensors.index.json.
    For single-file models, returns empty dict (default to model.safetensors).
    """
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        return {}
    with open(index_path) as f:
        index = json.load(f)
    return dict(index.get("weight_map", {}))

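# safetensors layout: bytes [0, 8) hold a little-endian u64 header length N,
# bytes [8, 8+N) hold a JSON header mapping tensor names to
# {"dtype", "shape", "data_offsets": [begin, end]}, and the raw tensor data
# follows; data_offsets are relative to the start of the data section.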
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
    """Parse safetensors file header.

    Returns (data_start_offset, header_dict).
    Header dict maps tensor names to metadata including 'data_offsets'.
    """
    header_size = struct.unpack('<Q', data[:8])[0]
    header = json.loads(bytes(data[8:8 + header_size]))
    return 8 + header_size, header

# ---------------------------------------------------------------------------
# Block-level diffing and sync
# ---------------------------------------------------------------------------
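# With the default 4 KiB blocks, a 54 GB checkpoint is ~13 million blocks;
# after a small training step only the blocks whose bytes actually differ
# are rewritten, which is how a full checkpoint write shrinks to a few
# hundred MB of real I/O (see module docstring).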
def sync_tensor_to_mmap(
    mm: mmap.mmap,
    name: str,
    tensor: torch.Tensor,
    data_start: int,
    offsets: List[int],
    block_size: int,
) -> Tuple[int, int]:
    """Sync a single tensor to mmap'd file using block-level diffing.

    Returns (bytes_compared, bytes_changed).
    """
    start = data_start + offsets[0]
    end = data_start + offsets[1]
    disk_len = end - start

    # Transfer tensor to CPU and get raw bytes.
    # .detach() avoids autograd overhead, .contiguous() fixes memory layout,
    # and the uint8 view sidesteps numpy's lack of bfloat16 support.
    try:
        live_bytes = (tensor.detach().contiguous().view(torch.uint8)
                      .cpu().numpy().tobytes())
    except Exception as e:
        logger.warning(f"Failed to transfer {name} to CPU: {e}")
        return 0, 0

    if len(live_bytes) != disk_len:
        logger.warning(
            f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
        )
        return 0, 0

    # Block-level diff: compare and write only changed blocks
    compared = 0
    changed = 0
    offset = 0
    while offset < disk_len:
        block_end = min(offset + block_size, disk_len)
        block_len = block_end - offset
        disk_block = mm[start + offset:start + block_end]
        live_block = live_bytes[offset:block_end]
        compared += block_len
        if disk_block != live_block:
            mm[start + offset:start + block_end] = live_block
            changed += block_len
        offset = block_end
    return compared, changed

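# Note: assigning into the mmap dirties a page even when the bytes are equal,
# so the compare-before-write above is what keeps unchanged pages clean and
# lets mm.flush() (an msync) write back only the blocks we touched.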
def sync_file(
    file_path: Path,
    tensors: Dict[str, torch.Tensor],
    block_size: int,
) -> Tuple[int, int, int, int]:
    """Sync tensors to a single safetensors file.

    Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
    """
    with open(file_path, 'r+b') as f:
        mm = mmap.mmap(f.fileno(), 0)
        try:
            data_start, header = parse_safetensors_header(memoryview(mm))
            total_compared = 0
            total_changed = 0
            found = 0
            missing = 0
            for name, tensor in tensors.items():
                if name == "__metadata__":
                    continue
                if name not in header:
                    missing += 1
                    continue
                found += 1
                meta = header[name]
                offsets = meta['data_offsets']
                compared, changed = sync_tensor_to_mmap(
                    mm, name, tensor, data_start, offsets, block_size
                )
                total_compared += compared
                total_changed += changed
            # Flush changes to disk
            if total_changed > 0:
                mm.flush()
            return total_compared, total_changed, found, missing
        finally:
            mm.close()

# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
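# Each stored handle is a (rebuild_fn, args) pair, presumably produced by
# torch.multiprocessing.reductions.reduce_tensor() in the export hook;
# calling rebuild_fn(*args) opens the CUDA IPC memory handle and wraps
# vLLM's live allocation in a tensor without copying it.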
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
    """Load vLLM weight tensors from CUDA IPC handles.

    The handles file is written by vllm_export_hook.py on vLLM startup.
    Each handle can be used to reconstruct a tensor pointing to vLLM's
    GPU memory — no copy, direct access.
    """
    handles = torch.load(handles_path, weights_only=False)
    # Skip metadata entry
    handles.pop('__metadata__', None)
    weights = {}
    for name, info in handles.items():
        func, args = info['handle']
        try:
            weights[name] = func(*args)
        except Exception as e:
            logger.warning(f"Failed to reconstruct {name}: {e}")
    return weights

def checkpoint_sync(
    model_dir: str,
    handles_path: str = DEFAULT_HANDLES_PATH,
    block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
    """Sync live GPU weights to model safetensors files.

    This is the main entry point. Call this after training steps
    or periodically to checkpoint weights without full serialization.

    Args:
        model_dir: Directory containing safetensors files
        handles_path: Path to vLLM weight IPC handles file
        block_size: Block size for diffing (default 4KB)

    Returns:
        Dict with sync statistics:
        - total_compared: bytes compared
        - total_changed: bytes actually written
        - files_changed: list of modified filenames
        - tensors_synced: number of tensors processed
        - tensors_missing: tensors not found in safetensors
    """
    model_dir = Path(model_dir)
    if not Path(handles_path).exists():
        raise FileNotFoundError(
            f"Weight handles not found: {handles_path}. "
            "Is vLLM running with the export hook?"
        )

    # Step 1: Load live weights from GPU via IPC
    logger.info("Loading live weights from GPU...")
    vllm_weights = load_vllm_weights(handles_path)
    logger.info(f"  Loaded {len(vllm_weights)} vLLM tensors")

    # Step 2: Convert to HF naming/layout
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    logger.info(f"  Converted to {len(hf_weights)} HF tensors")

    # Step 3: Map tensors to safetensors files
    weight_map = read_safetensors_index(model_dir)
    by_file: Dict[str, Dict[str, torch.Tensor]] = {}
    unmapped = []
    for name, tensor in hf_weights.items():
        filename = weight_map.get(name)
        if filename is None:
            # Single-file model or missing from index
            if (model_dir / "model.safetensors").exists():
                filename = "model.safetensors"
            else:
                unmapped.append(name)
                continue
        by_file.setdefault(filename, {})[name] = tensor
    if unmapped:
        logger.warning(f"  {len(unmapped)} tensors not in index: {unmapped[:3]}...")

    # Step 4: Sync each file
    total_compared = 0
    total_changed = 0
    total_found = 0
    total_missing = 0
    files_changed = []
    for filename in sorted(by_file.keys()):
        tensors = by_file[filename]
        file_path = model_dir / filename
        if not file_path.exists():
            logger.warning(f"  File not found: {filename}")
            total_missing += len(tensors)
            continue
        compared, changed, found, missing = sync_file(file_path, tensors, block_size)
        total_compared += compared
        total_changed += changed
        total_found += found
        total_missing += missing
        if changed > 0:
            files_changed.append(filename)
            logger.info(f"  {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")

    # Summary
    if total_changed == 0:
        logger.info("No changes - model files are up to date")
    else:
        pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
        logger.info(
            f"Synced: {total_changed / 1e6:.2f} MB changed / "
            f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
        )
    if total_missing > 0:
        logger.warning(f"  {total_missing} tensors not found in safetensors files")

    return {
        "total_compared": total_compared,
        "total_changed": total_changed,
        "files_changed": files_changed,
        "tensors_synced": total_found,
        "tensors_missing": total_missing,
    }

# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
    """Print diagnostic info about weight name mappings.

    Useful for debugging mismatches between vLLM and safetensors names.
    """
    model_dir = Path(model_dir)

    # Load and convert vLLM weights
    vllm_weights = load_vllm_weights(handles_path)
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    hf_names = set(hf_weights.keys())

    # Read safetensors index
    weight_map = read_safetensors_index(model_dir)
    disk_names = set(weight_map.keys())

    # If single-file model, parse that file's header
    if not disk_names:
        st_path = model_dir / "model.safetensors"
        if st_path.exists():
            with open(st_path, 'rb') as f:
                mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                _, header = parse_safetensors_header(memoryview(mm))
                disk_names = {k for k in header.keys() if k != "__metadata__"}
                mm.close()

    print(f"vLLM tensors (raw): {len(vllm_weights)}")
    print(f"HF tensors (converted): {len(hf_names)}")
    print(f"Disk tensors: {len(disk_names)}")
    print()

    in_both = hf_names & disk_names
    only_hf = hf_names - disk_names
    only_disk = disk_names - hf_names
    print(f"Matched: {len(in_both)}")
    print(f"Only in HF (won't sync): {len(only_hf)}")
    print(f"Only on disk (not updated): {len(only_disk)}")
    if only_hf:
        print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
    if only_disk:
        print(f"\nSample disk-only: {sorted(only_disk)[:5]}")

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    import argparse

    parser = argparse.ArgumentParser(
        description="Sync live GPU weights to safetensors files"
    )
    subparsers = parser.add_subparsers(dest="command", help="Command")

    # sync command
    sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
    sync_parser.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    sync_parser.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )
    sync_parser.add_argument(
        "--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
        help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
    )
    sync_parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose output"
    )

    # diagnose command
    diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
    diag_parser.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    diag_parser.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )

    args = parser.parse_args()
    if args.command is None:
        parser.print_help()
        sys.exit(1)

    logging.basicConfig(
        level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
        format='%(message)s'
    )

    try:
        if args.command == "sync":
            result = checkpoint_sync(args.model_dir, args.handles, args.block_size)
            print(json.dumps(result, indent=2))
        elif args.command == "diagnose":
            diagnose(args.model_dir, args.handles)
    except FileNotFoundError as e:
        logger.error(str(e))
        sys.exit(1)
    except Exception as e:
        logger.exception(f"Failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()