training: restructure as vLLM plugin package
- Convert to installable package with entry points for vLLM auto-discovery
- Add checkpoint_sync.py: Python replacement for Rust checkpoint binary
- Block-level diffing of safetensors files (4KB blocks)
- vLLM→HF weight name conversion built-in
- Scheduled 10min after training jobs (batched)
- API change: /train now takes raw token IDs (context_ids + continuation_ids)
- No tokenizer on training side, client owns tokenization
- Remove superseded code: standalone scripts, Rust binary, tokenizer helpers

Install: pip install -e ./training
Then vLLM auto-loads via entry point.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
b649a11645
commit
a73bcf5ae3
15 changed files with 607 additions and 1068 deletions
17
training/apollo_plugin/__init__.py
Normal file
17
training/apollo_plugin/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""Apollo training plugin for vLLM.
|
||||||
|
|
||||||
|
Enables continuous fine-tuning alongside live inference by:
|
||||||
|
1. Exporting CUDA IPC handles for weight sharing
|
||||||
|
2. Providing a training worker daemon (/train endpoint)
|
||||||
|
3. Block-level checkpoint sync to safetensors files
|
||||||
|
|
||||||
|
Install: pip install -e /path/to/training
|
||||||
|
Then vLLM auto-loads via entry point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .export_hook import _patch_model_runner
|
||||||
|
|
||||||
|
|
||||||
|
def register():
    """Plugin entry point, called by vLLM's plugin loader on startup.

    Delegates to ``_patch_model_runner`` (imported from ``.export_hook``),
    which installs the weight export hook.
    """
    _patch_model_runner()
|
||||||
500
training/apollo_plugin/checkpoint_sync.py
Normal file
500
training/apollo_plugin/checkpoint_sync.py
Normal file
|
|
@ -0,0 +1,500 @@
|
||||||
|
"""Sync live GPU weights to safetensors files on disk.
|
||||||
|
|
||||||
|
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
|
||||||
|
merged layout to HuggingFace's separate layout, diffs block-by-block
|
||||||
|
against on-disk safetensors files, and writes only changed blocks.
|
||||||
|
|
||||||
|
For small behavioral training steps, this turns a 54GB checkpoint
|
||||||
|
write into a few hundred MB of actual disk I/O.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Sync live weights to disk
|
||||||
|
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
|
||||||
|
|
||||||
|
# Debug name mapping issues
|
||||||
|
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
|
||||||
|
|
||||||
|
# From Python:
|
||||||
|
from apollo_plugin.checkpoint_sync import checkpoint_sync
|
||||||
|
result = checkpoint_sync("/path/to/model")
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import mmap
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Diff granularity in bytes: 4KB blocks — matches filesystem block size.
DEFAULT_BLOCK_SIZE = 4 * 1024

# Default location of the CUDA IPC handles file read by this module.
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# vLLM → HuggingFace weight name/shape conversion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Qwen3.5-27B dimensions (could be read from config.json for generality).
# These fix where the merged vLLM projections are split along dim 0.

HIDDEN = 5120
INTERMEDIATE = 17408                      # rows of each half of gate_up_proj

# GDN layers (in_proj_qkvz / in_proj_ba)
NUM_K_HEADS, NUM_V_HEADS = 16, 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = HEAD_K_DIM * NUM_K_HEADS        # 2048
VALUE_DIM = HEAD_V_DIM * NUM_V_HEADS      # 6144

# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
# NOTE(review): q rows per head are twice the head dim — reason not visible
# here; confirm against the model config.
ATTN_Q_HEAD_DIM = 2 * ATTN_HEAD_DIM                 # 512
ATTN_Q_DIM = ATTN_Q_HEAD_DIM * NUM_ATTN_HEADS       # 12288
ATTN_K_DIM = ATTN_HEAD_DIM * NUM_ATTN_KV_HEADS      # 1024
ATTN_V_DIM = ATTN_HEAD_DIM * NUM_ATTN_KV_HEADS      # 1024
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
                       ) -> Dict[str, torch.Tensor]:
    """Convert vLLM merged weights to HF-compatible separate tensors.

    vLLM merges certain projections for efficiency:
      - qkv_proj (full attn)   → q_proj, k_proj, v_proj
      - in_proj_qkvz (GDN)     → in_proj_qkv, in_proj_z
      - in_proj_ba (GDN)       → in_proj_b, in_proj_a
      - gate_up_proj (MLP)     → gate_proj, up_proj

    A parameter is split only when its name (after stripping the
    'language_model.' prefix) ends in '<merged>.weight'. Any other
    parameter — including one that merely contains a merged marker, such
    as a bias or scale — is passed through under its stripped name rather
    than being turned into a malformed key.

    Args:
        vllm_params: mapping of vLLM parameter name to tensor.

    Returns:
        Mapping of HF parameter name to tensor. Split outputs are dim-0
        slices of the inputs, so they are views sharing memory with the
        original tensors.
    """
    # (merged name, ((hf name, row count), ...)); a row count of None means
    # "everything that is left". Row counts come from the module constants.
    split_rules = (
        ('in_proj_qkvz', (('in_proj_qkv', KEY_DIM * 2 + VALUE_DIM),
                          ('in_proj_z', None))),
        ('in_proj_ba', (('in_proj_b', NUM_V_HEADS),
                        ('in_proj_a', None))),
        ('qkv_proj', (('q_proj', ATTN_Q_DIM),
                      ('k_proj', ATTN_K_DIM),
                      ('v_proj', None))),
        ('gate_up_proj', (('gate_proj', INTERMEDIATE),
                          ('up_proj', None))),
    )

    hf_params: Dict[str, torch.Tensor] = {}

    for name, tensor in vllm_params.items():
        # Strip vLLM's 'language_model.' prefix to match HF naming
        hf_name = name.removeprefix('language_model.')

        for merged, parts in split_rules:
            suffix = merged + '.weight'
            if not hf_name.endswith(suffix):
                continue

            prefix = hf_name[:-len(suffix)]
            row = 0
            for part_name, rows in parts:
                stop = None if rows is None else row + rows
                hf_params[f'{prefix}{part_name}.weight'] = tensor[row:stop]
                if stop is not None:
                    row = stop
            break
        else:
            # Not a merged weight: pass through unchanged
            hf_params[hf_name] = tensor

    return hf_params
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Safetensors file handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
    """Return the tensor-name → safetensors-filename map for a model dir.

    Sharded models carry a ``model.safetensors.index.json`` whose
    ``weight_map`` entry is returned as a fresh dict. When that index file
    is absent (single-file model), or has no ``weight_map``, an empty dict
    is returned and callers fall back to ``model.safetensors``.
    """
    index_file = model_dir / "model.safetensors.index.json"

    if index_file.exists():
        with open(index_file) as fh:
            weight_map = json.load(fh).get("weight_map", {})
        return dict(weight_map)

    return {}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
    """Parse a safetensors file header.

    The buffer starts with an 8-byte little-endian unsigned length,
    followed by that many bytes of JSON.

    Args:
        data: byte-oriented buffer over the start of (or the whole) file.

    Returns:
        (data_start_offset, header_dict). ``data_start_offset`` is the
        offset of the first byte after the JSON header. The header dict
        maps tensor names to metadata including 'data_offsets'.

    Raises:
        ValueError: if the buffer is shorter than the length prefix, the
            declared header runs past the end of the buffer, or the
            header is not valid JSON (json.JSONDecodeError).
    """
    prefix_len = 8
    total = len(data)

    if total < prefix_len:
        raise ValueError(
            f"safetensors data too short for header length prefix: {total} bytes"
        )

    (header_size,) = struct.unpack_from('<Q', data, 0)
    data_start = prefix_len + header_size

    # Slicing would silently clamp an oversized length; reject it instead so
    # a truncated or corrupt file is reported as such.
    if data_start > total:
        raise ValueError(
            f"safetensors header length {header_size} exceeds available "
            f"data ({total} bytes)"
        )

    header = json.loads(bytes(data[prefix_len:data_start]))
    return data_start, header
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Block-level diffing and sync
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sync_tensor_to_mmap(
    mm: mmap.mmap,
    name: str,
    tensor: torch.Tensor,
    data_start: int,
    offsets: List[int],
    block_size: int,
) -> Tuple[int, int]:
    """Sync a single tensor to mmap'd file using block-level diffing.

    The tensor's raw bytes are compared with the on-disk region one block
    at a time, and only blocks that differ are written back into the mmap.

    Args:
        mm: writable mmap of the safetensors file.
        name: tensor name, used only in log messages.
        tensor: live tensor whose bytes should end up on disk.
        data_start: offset of the data section (end of the header).
        offsets: the tensor's [begin, end) 'data_offsets' relative to
            ``data_start``.
        block_size: diff granularity in bytes; must be positive.

    Returns:
        (bytes_compared, bytes_changed). (0, 0) if the tensor could not be
        transferred to the CPU or its byte size does not match the on-disk
        region; both cases are logged as warnings.

    Raises:
        ValueError: if ``block_size`` is not positive (the diff loop could
            otherwise never advance).
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    start = data_start + offsets[0]
    disk_len = offsets[1] - offsets[0]

    # Transfer tensor to CPU and get raw bytes.
    # .detach() avoids autograd overhead, .contiguous() fixes memory layout.
    # The bytes are taken through a flat uint8 view rather than the tensor's
    # own dtype because numpy has no bfloat16: converting a bf16 tensor
    # directly raises and the tensor would never be synced.
    try:
        live_bytes = (
            tensor.detach().contiguous().cpu()
            .reshape(-1).view(torch.uint8)
            .numpy().tobytes()
        )
    except Exception as e:
        logger.warning(f"Failed to transfer {name} to CPU: {e}")
        return 0, 0

    if len(live_bytes) != disk_len:
        logger.warning(
            f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
        )
        return 0, 0

    # Block-level diff: compare and write only changed blocks
    compared = 0
    changed = 0

    for offset in range(0, disk_len, block_size):
        block_end = min(offset + block_size, disk_len)  # last block may be short
        block_len = block_end - offset
        lo, hi = start + offset, start + block_end

        live_block = live_bytes[offset:block_end]
        compared += block_len

        if mm[lo:hi] != live_block:
            mm[lo:hi] = live_block
            changed += block_len

    return compared, changed
|
||||||
|
|
||||||
|
|
||||||
|
def sync_file(
    file_path: Path,
    tensors: Dict[str, torch.Tensor],
    block_size: int,
) -> Tuple[int, int, int, int]:
    """Sync tensors to a single safetensors file.

    The file is opened read/write and mmap'd. Each tensor whose name
    appears in the file's header is block-diffed against its on-disk
    region via ``sync_tensor_to_mmap``. The mmap is flushed only if at
    least one block changed.

    Args:
        file_path: safetensors file to update in place.
        tensors: HF-named tensors expected to live in this file.
        block_size: diff granularity in bytes.

    Returns:
        (bytes_compared, bytes_changed, tensors_found, tensors_missing),
        where "found"/"missing" count names present in / absent from the
        file's header.
    """
    with open(file_path, 'r+b') as f, mmap.mmap(f.fileno(), 0) as mm:
        # Parse through an explicitly released view. If parsing raises, a
        # memoryview kept alive by the traceback would make closing the mmap
        # fail with BufferError and hide the real error.
        with memoryview(mm) as view:
            data_start, header = parse_safetensors_header(view)

        total_compared = 0
        total_changed = 0
        found = 0
        missing = 0

        for name, tensor in tensors.items():
            if name == "__metadata__":
                continue

            meta = header.get(name)
            if meta is None:
                missing += 1
                continue

            found += 1
            compared, changed = sync_tensor_to_mmap(
                mm, name, tensor, data_start, meta['data_offsets'], block_size
            )
            total_compared += compared
            total_changed += changed

        # Flush changes to disk
        if total_changed > 0:
            mm.flush()

        return total_compared, total_changed, found, missing
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
    """Load vLLM weight tensors from CUDA IPC handles.

    The handles file is written by the plugin's export hook
    (apollo_plugin/export_hook.py) on vLLM startup. Each entry's 'handle'
    is a ``(func, args)`` pair; calling ``func(*args)`` reconstructs a
    tensor pointing to vLLM's GPU memory — no copy, direct access.

    Loading is best-effort per tensor: an entry that is malformed or fails
    to reconstruct is logged as a warning and omitted from the result.

    Args:
        handles_path: path to the handles file.

    Returns:
        Mapping of vLLM parameter name to reconstructed tensor.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects, and each
    # entry then calls a function taken from the file. That is needed to
    # rebuild IPC tensors, but it means the handles file must be trusted.
    # The default location is under /tmp, so make sure only the vLLM
    # process can write it.
    handles = torch.load(handles_path, weights_only=False)

    weights = {}
    for name, info in handles.items():
        try:
            func, args = info['handle']
            weights[name] = func(*args)
        except Exception as e:
            logger.warning(f"Failed to reconstruct {name}: {e}")

    return weights
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_sync(
    model_dir: str,
    handles_path: str = DEFAULT_HANDLES_PATH,
    block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
    """Sync live GPU weights to model safetensors files.

    This is the main entry point. Call this after training steps
    or periodically to checkpoint weights without full serialization.

    Steps: load the live tensors via IPC handles, convert them to HF
    naming/layout, group them by safetensors file (using the shard index,
    falling back to ``model.safetensors``), then block-diff each file.

    Args:
        model_dir: Directory containing safetensors files
        handles_path: Path to vLLM weight IPC handles file
        block_size: Block size for diffing (default 4KB)

    Returns:
        Dict with sync statistics:
          - total_compared: bytes compared
          - total_changed: bytes actually written
          - files_changed: list of modified filenames
          - tensors_synced: number of tensors processed
          - tensors_missing: tensors not found in safetensors

    Raises:
        FileNotFoundError: if the handles file does not exist.
    """
    model_dir = Path(model_dir)

    if not Path(handles_path).exists():
        raise FileNotFoundError(
            f"Weight handles not found: {handles_path}. "
            "Is vLLM running with the export hook?"
        )

    # Step 1: Load live weights from GPU via IPC
    logger.info("Loading live weights from GPU...")
    vllm_weights = load_vllm_weights(handles_path)
    logger.info(f" Loaded {len(vllm_weights)} vLLM tensors")

    # Step 2: Convert to HF naming/layout
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    logger.info(f" Converted to {len(hf_weights)} HF tensors")

    # Step 3: Map tensors to safetensors files
    weight_map = read_safetensors_index(model_dir)

    # Fallback target for tensors the index does not list. Checked once
    # rather than once per tensor.
    has_single_file = (model_dir / "model.safetensors").exists()

    by_file: Dict[str, Dict[str, torch.Tensor]] = {}
    unmapped = []

    for name, tensor in hf_weights.items():
        filename = weight_map.get(name)
        if filename is None:
            # Single-file model or missing from index
            if not has_single_file:
                unmapped.append(name)
                continue
            filename = "model.safetensors"
        by_file.setdefault(filename, {})[name] = tensor

    if unmapped:
        logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...")

    # Step 4: Sync each file
    total_compared = 0
    total_changed = 0
    total_found = 0
    total_missing = 0
    files_changed = []

    for filename in sorted(by_file):
        tensors = by_file[filename]
        file_path = model_dir / filename

        if not file_path.exists():
            logger.warning(f" File not found: {filename}")
            total_missing += len(tensors)
            continue

        compared, changed, found, missing = sync_file(file_path, tensors, block_size)

        total_compared += compared
        total_changed += changed
        total_found += found
        total_missing += missing

        if changed > 0:
            files_changed.append(filename)
            logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")

    # Summary
    if total_changed == 0:
        logger.info("No changes - model files are up to date")
    else:
        pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
        logger.info(
            f"Synced: {total_changed / 1e6:.2f} MB changed / "
            f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
        )

    if total_missing > 0:
        logger.warning(f" {total_missing} tensors not found in safetensors files")

    return {
        "total_compared": total_compared,
        "total_changed": total_changed,
        "files_changed": files_changed,
        "tensors_synced": total_found,
        "tensors_missing": total_missing,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Diagnostics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH) -> None:
    """Print diagnostic info about weight name mappings.

    Useful for debugging mismatches between vLLM and safetensors names.
    Loads the live weights, converts them to HF names, and compares that
    name set with the names on disk. Disk names come from the shard index
    or, when there is none, from the header of ``model.safetensors``.
    Counts and a few sample names are printed to stdout.

    Args:
        model_dir: Directory containing safetensors files
        handles_path: Path to vLLM weight IPC handles file
    """
    model_dir = Path(model_dir)

    # Load and convert vLLM weights
    vllm_weights = load_vllm_weights(handles_path)
    hf_names = set(vllm_to_hf_tensors(vllm_weights))

    # Read safetensors index
    disk_names = set(read_safetensors_index(model_dir))

    # If single-file model, parse that file's header
    if not disk_names:
        st_path = model_dir / "model.safetensors"
        if st_path.exists():
            # Context managers release the view, then close the mmap and the
            # file, even when header parsing raises.
            with open(st_path, 'rb') as f, \
                    mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm, \
                    memoryview(mm) as view:
                _, header = parse_safetensors_header(view)
            disk_names = {k for k in header if k != "__metadata__"}

    print(f"vLLM tensors (raw): {len(vllm_weights)}")
    print(f"HF tensors (converted): {len(hf_names)}")
    print(f"Disk tensors: {len(disk_names)}")
    print()

    in_both = hf_names & disk_names
    only_hf = hf_names - disk_names
    only_disk = disk_names - hf_names

    print(f"Matched: {len(in_both)}")
    print(f"Only in HF (won't sync): {len(only_hf)}")
    print(f"Only on disk (not updated): {len(only_disk)}")

    if only_hf:
        print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
    if only_disk:
        print(f"\nSample disk-only: {sorted(only_disk)[:5]}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point with ``sync`` and ``diagnose`` subcommands.

    With no subcommand, prints help and exits with status 1. ``sync``
    prints the statistics dict as JSON. Any failure is logged and exits
    with status 1.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Sync live GPU weights to safetensors files"
    )
    subparsers = parser.add_subparsers(dest="command", help="Command")

    def add_common_args(sub):
        # Options shared by both subcommands.
        sub.add_argument(
            "--model-dir", required=True,
            help="Directory containing safetensors files"
        )
        sub.add_argument(
            "--handles", default=DEFAULT_HANDLES_PATH,
            help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
        )

    # sync command
    sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
    add_common_args(sync_parser)
    sync_parser.add_argument(
        "--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
        help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
    )
    sync_parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose output"
    )

    # diagnose command
    diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
    add_common_args(diag_parser)

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    # Only `sync` defines --verbose, hence the getattr default.
    log_level = logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO
    logging.basicConfig(level=log_level, format='%(message)s')

    try:
        if args.command == "sync":
            stats = checkpoint_sync(args.model_dir, args.handles, args.block_size)
            print(json.dumps(stats, indent=2))
        elif args.command == "diagnose":
            diagnose(args.model_dir, args.handles)
    except FileNotFoundError as e:
        logger.error(str(e))
        sys.exit(1)
    except Exception as e:
        logger.exception(f"Failed: {e}")
        sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,17 +1,12 @@
|
||||||
"""Monkey-patch vLLM to export weight IPC handles on startup.
|
"""Monkey-patch vLLM to export weight IPC handles on startup.
|
||||||
|
|
||||||
Usage — add to start_vllm.sh BEFORE the vllm serve command:
|
Usage — install the apollo_plugin package:
|
||||||
|
|
||||||
export VLLM_PLUGINS=vllm_export_hook
|
pip install -e /path/to/training
|
||||||
vllm serve Qwen/Qwen3.5-27B ...
|
|
||||||
|
|
||||||
Or use Python to launch vLLM with the hook:
|
Then vLLM auto-discovers and loads via entry point. Or filter:
|
||||||
|
|
||||||
python3 -c "
|
VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
|
||||||
import vllm_export_hook # installs the patch
|
|
||||||
from vllm.entrypoints.openai.api_server import run_server
|
|
||||||
run_server(...)
|
|
||||||
"
|
|
||||||
|
|
||||||
The hook patches vLLM's model runner to export IPC handles after
|
The hook patches vLLM's model runner to export IPC handles after
|
||||||
model loading completes. The handles are saved to a file that the
|
model loading completes. The handles are saved to a file that the
|
||||||
|
|
@ -70,7 +65,3 @@ def _patch_model_runner():
|
||||||
|
|
||||||
gpu_worker.Worker.load_model = patched_load
|
gpu_worker.Worker.load_model = patched_load
|
||||||
print("[apollo] Weight export hook installed")
|
print("[apollo] Weight export hook installed")
|
||||||
|
|
||||||
|
|
||||||
# Auto-install when imported
|
|
||||||
_patch_model_runner()
|
|
||||||
|
|
@ -74,6 +74,9 @@ class TrainingJob:
|
||||||
'error': self.error,
|
'error': self.error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CHECKPOINT_DELAY_SECS = 10 * 60 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
class ApolloWorker:
|
class ApolloWorker:
|
||||||
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
|
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
|
||||||
self.config = self._load_config(config_path)
|
self.config = self._load_config(config_path)
|
||||||
|
|
@ -81,6 +84,7 @@ class ApolloWorker:
|
||||||
self.vllm_paused = False
|
self.vllm_paused = False
|
||||||
self.app = web.Application()
|
self.app = web.Application()
|
||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
|
self._checkpoint_timer: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
def _load_config(self, config_path: str) -> Dict[str, Any]:
|
def _load_config(self, config_path: str) -> Dict[str, Any]:
|
||||||
"""Load configuration from file or use defaults."""
|
"""Load configuration from file or use defaults."""
|
||||||
|
|
@ -233,6 +237,9 @@ class ApolloWorker:
|
||||||
|
|
||||||
logger.info(f"Training job {job.job_id} completed successfully")
|
logger.info(f"Training job {job.job_id} completed successfully")
|
||||||
|
|
||||||
|
# Schedule checkpoint sync (batched — won't duplicate if timer pending)
|
||||||
|
self.schedule_checkpoint_sync()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training job {job.job_id} failed: {e}")
|
logger.error(f"Training job {job.job_id} failed: {e}")
|
||||||
job.status = TrainingStatus.FAILED
|
job.status = TrainingStatus.FAILED
|
||||||
|
|
@ -278,6 +285,43 @@ class ApolloWorker:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to resume vLLM: {e}")
|
logger.warning(f"Failed to resume vLLM: {e}")
|
||||||
|
|
||||||
|
def schedule_checkpoint_sync(self):
    """Arm the delayed checkpoint sync unless one is already pending.

    Batches multiple training runs into a single sync: while a timer task
    exists, further calls do nothing, so the delay is not restarted. A new
    timer is created only when none is pending.
    """
    if self._checkpoint_timer is None:
        self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
        logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")
        return

    logger.debug("Checkpoint sync already scheduled, skipping")
|
||||||
|
|
||||||
|
async def _checkpoint_sync_after_delay(self):
    """Wait then sync — the actual timer task.

    Sleeps for CHECKPOINT_DELAY_SECS, then runs the checkpoint sync. The
    pending-timer slot is cleared on every exit path so that a later
    schedule_checkpoint_sync() call can arm a new timer.

    Raises:
        asyncio.CancelledError: re-raised after logging, so the task is
            reported as cancelled instead of finishing normally.
    """
    try:
        await asyncio.sleep(CHECKPOINT_DELAY_SECS)
        await self._do_checkpoint_sync()
    except asyncio.CancelledError:
        logger.debug("Checkpoint sync cancelled")
        # Swallowing cancellation would make the task look like it
        # completed normally; propagate it.
        raise
    finally:
        self._checkpoint_timer = None
|
||||||
|
|
||||||
|
async def _do_checkpoint_sync(self):
    """Execute the checkpoint sync.

    Runs checkpoint_sync() for the configured model path and handles file
    and logs how much was written. Failures are logged with their
    traceback and not propagated, so a failed sync does not take down the
    timer task.
    """
    try:
        from apollo_plugin.checkpoint_sync import checkpoint_sync
        logger.info("Starting checkpoint sync...")
        # NOTE(review): checkpoint_sync() is a synchronous call, so the event
        # loop is blocked for its whole duration. Moving it to a thread would
        # let training overlap with the sync — confirm that is safe before
        # changing it.
        result = checkpoint_sync(
            self.config['model_path'],
            self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
        )
        changed_mb = result['total_changed'] / 1e6
        logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
    except Exception as e:
        # logger.exception keeps the traceback; the bare message alone is
        # rarely enough to diagnose a failed sync.
        logger.exception(f"Checkpoint sync failed: {e}")
|
||||||
|
|
||||||
async def load_model_for_training(self) -> nn.Module:
|
async def load_model_for_training(self) -> nn.Module:
|
||||||
"""Load HF model with weights pointing to vLLM's GPU memory.
|
"""Load HF model with weights pointing to vLLM's GPU memory.
|
||||||
|
|
||||||
|
|
@ -299,22 +343,24 @@ class ApolloWorker:
|
||||||
logger.info(f"Imported {len(vllm_params)} parameters")
|
logger.info(f"Imported {len(vllm_params)} parameters")
|
||||||
|
|
||||||
# Map vLLM merged layout → HF separate layout (views, no copies)
|
# Map vLLM merged layout → HF separate layout (views, no copies)
|
||||||
from weight_mapping import load_hf_model_with_vllm_weights
|
from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
|
||||||
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
|
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
|
||||||
logger.info("HF model constructed with vLLM weight views")
|
logger.info("HF model constructed with vLLM weight views")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def run_apollo_training(self, model: nn.Module,
|
async def run_apollo_training(self, model: nn.Module,
|
||||||
samples: List[Dict[str, str]],
|
samples: List[Dict[str, Any]],
|
||||||
config: Dict[str, Any]) -> List[float]:
|
config: Dict[str, Any]) -> List[float]:
|
||||||
"""Run Apollo-Mini training on conversation decision points."""
|
"""Run Apollo-Mini training on conversation decision points.
|
||||||
from apollo_mini import Apollo
|
|
||||||
from transformers import AutoTokenizer
|
Each sample has:
|
||||||
|
context_ids: token IDs for frozen context (no gradients)
|
||||||
|
continuation_ids: token IDs for the decision we're training on
|
||||||
|
"""
|
||||||
|
from apollo_plugin.optimizer import Apollo
|
||||||
|
|
||||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
lr = config.get('learning_rate', self.config['learning_rate'])
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.config['model_path'], trust_remote_code=True)
|
|
||||||
|
|
||||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
||||||
apollo_params, standard_params = [], []
|
apollo_params, standard_params = [], []
|
||||||
|
|
@ -340,12 +386,10 @@ class ApolloWorker:
|
||||||
loss_history = []
|
loss_history = []
|
||||||
|
|
||||||
for i, sample in enumerate(samples):
|
for i, sample in enumerate(samples):
|
||||||
context = sample.get('context', '')
|
# context_ids: frozen (forward only, no gradients)
|
||||||
continuation = sample.get('continuation', '')
|
# continuation_ids: the decision we're training on
|
||||||
|
ctx_ids = sample['context_ids']
|
||||||
# Tokenize
|
cont_ids = sample['continuation_ids']
|
||||||
ctx_ids = tokenizer.encode(context, add_special_tokens=True)
|
|
||||||
cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
|
|
||||||
all_ids = ctx_ids + cont_ids
|
all_ids = ctx_ids + cont_ids
|
||||||
context_len = len(ctx_ids)
|
context_len = len(ctx_ids)
|
||||||
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "apollo-checkpoint"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2024"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
memmap2 = "0.9"
|
|
||||||
safetensors = "0.5"
|
|
||||||
serde = { version = "1", features = ["derive"] }
|
|
||||||
serde_json = "1"
|
|
||||||
anyhow = "1"
|
|
||||||
clap = { version = "4", features = ["derive"] }
|
|
||||||
|
|
@ -1,265 +0,0 @@
|
||||||
// apollo-checkpoint — Sync live GPU weights back to model files on disk.
|
|
||||||
//
|
|
||||||
// mmaps the model's safetensors files, reads live weights from GPU via
|
|
||||||
// Python helper (CUDA IPC handles), compares block by block, and memcpys
|
|
||||||
// only changed regions back into the mmap. For small behavioral training
|
|
||||||
// steps, this turns a 54GB write into a few hundred MB.
|
|
||||||
//
|
|
||||||
// The model files on disk are the checkpoint. No separate checkpoint
|
|
||||||
// directory — just keep the model up to date.
|
|
||||||
//
|
|
||||||
// Usage:
|
|
||||||
// apollo-checkpoint sync \
|
|
||||||
// --handles /tmp/vllm_weight_handles.pt \
|
|
||||||
// --model-dir /path/to/Qwen3.5-27B
|
|
||||||
//
|
|
||||||
// Runs every 10 minutes via cron. Daily rsync to moria.
|
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
|
||||||
use clap::{Parser, Subcommand};
|
|
||||||
use memmap2::MmapMut;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fs;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::process::Command;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
|
|
||||||
struct Cli {
|
|
||||||
#[command(subcommand)]
|
|
||||||
command: Cmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
|
||||||
enum Cmd {
|
|
||||||
/// Sync live GPU weights back to model safetensors files
|
|
||||||
Sync {
|
|
||||||
/// Path to vLLM weight IPC handles
|
|
||||||
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
|
|
||||||
handles: PathBuf,
|
|
||||||
|
|
||||||
/// Model directory containing safetensors files
|
|
||||||
#[arg(long)]
|
|
||||||
model_dir: PathBuf,
|
|
||||||
|
|
||||||
/// Block size for diffing (bytes)
|
|
||||||
#[arg(long, default_value_t = 4096)]
|
|
||||||
block_size: usize,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dump live GPU weights to a flat binary file, ordered by safetensors
|
|
||||||
/// file and offset to match the on-disk layout.
|
|
||||||
///
|
|
||||||
/// Returns a map of (safetensors filename, tensor name) → raw bytes.
|
|
||||||
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
|
|
||||||
let dump_path = output_dir.join(".live_dump.bin");
|
|
||||||
let index_path = output_dir.join(".live_dump.json");
|
|
||||||
|
|
||||||
let status = Command::new("python3")
|
|
||||||
.arg("-c")
|
|
||||||
.arg(format!(r#"
|
|
||||||
import torch, json
|
|
||||||
|
|
||||||
handles = torch.load("{handles}", weights_only=False)
|
|
||||||
index = {{}}
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
with open("{dump}", "wb") as f:
|
|
||||||
for name in sorted(handles.keys()):
|
|
||||||
info = handles[name]
|
|
||||||
func, args = info["handle"]
|
|
||||||
tensor = func(*args)
|
|
||||||
data = tensor.contiguous().cpu().numpy().tobytes()
|
|
||||||
f.write(data)
|
|
||||||
index[name] = {{"offset": offset, "size": len(data)}}
|
|
||||||
offset += len(data)
|
|
||||||
|
|
||||||
with open("{index}", "w") as f:
|
|
||||||
json.dump(index, f)
|
|
||||||
|
|
||||||
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
|
||||||
"#,
|
|
||||||
handles = handles_path.display(),
|
|
||||||
dump = dump_path.display(),
|
|
||||||
index = index_path.display(),
|
|
||||||
))
|
|
||||||
.status()
|
|
||||||
.context("Failed to run Python weight dump")?;
|
|
||||||
|
|
||||||
if !status.success() {
|
|
||||||
bail!("Python weight dump failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
let index_str = fs::read_to_string(&index_path)?;
|
|
||||||
let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
|
|
||||||
let dump_data = fs::read(&dump_path)?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
for (name, entry) in &index {
|
|
||||||
result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up temp files
|
|
||||||
let _ = fs::remove_file(&dump_path);
|
|
||||||
let _ = fs::remove_file(&index_path);
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
|
||||||
struct DumpEntry {
|
|
||||||
offset: usize,
|
|
||||||
size: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read the safetensors index to map parameter names to files.
|
|
||||||
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
|
|
||||||
let index_path = model_dir.join("model.safetensors.index.json");
|
|
||||||
if !index_path.exists() {
|
|
||||||
// Single file model
|
|
||||||
return Ok(HashMap::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let index_str = fs::read_to_string(&index_path)?;
|
|
||||||
let index: serde_json::Value = serde_json::from_str(&index_str)?;
|
|
||||||
let weight_map = index["weight_map"]
|
|
||||||
.as_object()
|
|
||||||
.context("No weight_map in index")?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
for (name, file) in weight_map {
|
|
||||||
result.insert(name.clone(), file.as_str().unwrap().to_string());
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sync changed blocks from live weights into a mmap'd safetensors file.
|
|
||||||
/// Returns (total_bytes_compared, bytes_changed).
|
|
||||||
fn sync_tensors_to_file(
|
|
||||||
file_path: &Path,
|
|
||||||
tensors: &[(String, Vec<u8>)],
|
|
||||||
block_size: usize,
|
|
||||||
) -> Result<(usize, usize)> {
|
|
||||||
use safetensors::SafeTensors;
|
|
||||||
|
|
||||||
let file = fs::OpenOptions::new()
|
|
||||||
.read(true)
|
|
||||||
.write(true)
|
|
||||||
.open(file_path)
|
|
||||||
.with_context(|| format!("Failed to open {}", file_path.display()))?;
|
|
||||||
|
|
||||||
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
|
||||||
|
|
||||||
// Parse safetensors header to find tensor offsets
|
|
||||||
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
|
|
||||||
let header_json: serde_json::Value =
|
|
||||||
serde_json::from_slice(&mmap[8..8 + header_size])?;
|
|
||||||
let data_start = 8 + header_size;
|
|
||||||
|
|
||||||
let mut total_compared = 0usize;
|
|
||||||
let mut total_changed = 0usize;
|
|
||||||
|
|
||||||
for (name, live_data) in tensors {
|
|
||||||
let meta = match header_json.get(name) {
|
|
||||||
Some(m) => m,
|
|
||||||
None => {
|
|
||||||
eprintln!(" Warning: {} not found in {}", name, file_path.display());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let offsets = meta["data_offsets"].as_array().unwrap();
|
|
||||||
let start = data_start + offsets[0].as_u64().unwrap() as usize;
|
|
||||||
let end = data_start + offsets[1].as_u64().unwrap() as usize;
|
|
||||||
let disk_data = &mmap[start..end];
|
|
||||||
|
|
||||||
if disk_data.len() != live_data.len() {
|
|
||||||
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
|
|
||||||
name, disk_data.len(), live_data.len());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Diff block by block, memcpy only changed blocks
|
|
||||||
let mut offset = 0;
|
|
||||||
while offset < disk_data.len() {
|
|
||||||
let block_end = (offset + block_size).min(disk_data.len());
|
|
||||||
total_compared += block_end - offset;
|
|
||||||
|
|
||||||
if disk_data[offset..block_end] != live_data[offset..block_end] {
|
|
||||||
mmap[start + offset..start + block_end]
|
|
||||||
.copy_from_slice(&live_data[offset..block_end]);
|
|
||||||
total_changed += block_end - offset;
|
|
||||||
}
|
|
||||||
offset = block_end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mmap.flush()?;
|
|
||||||
Ok((total_compared, total_changed))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
|
|
||||||
if !handles.exists() {
|
|
||||||
bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
|
|
||||||
handles.display());
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("Dumping live weights from GPU...");
|
|
||||||
let live_weights = dump_live_weights(&handles, &model_dir)?;
|
|
||||||
eprintln!(" {} tensors dumped", live_weights.len());
|
|
||||||
|
|
||||||
// Map parameter names to safetensors files
|
|
||||||
let weight_map = read_safetensors_index(&model_dir)?;
|
|
||||||
|
|
||||||
// Group tensors by safetensors file
|
|
||||||
let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
|
|
||||||
for (name, data) in live_weights {
|
|
||||||
let file = weight_map
|
|
||||||
.get(&name)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_else(|| "model.safetensors".to_string());
|
|
||||||
by_file.entry(file).or_default().push((name, data));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut total_compared = 0usize;
|
|
||||||
let mut total_changed = 0usize;
|
|
||||||
|
|
||||||
for (filename, tensors) in &by_file {
|
|
||||||
let file_path = model_dir.join(filename);
|
|
||||||
if !file_path.exists() {
|
|
||||||
eprintln!(" Warning: {} not found, skipping", filename);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
|
|
||||||
total_compared += compared;
|
|
||||||
total_changed += changed;
|
|
||||||
|
|
||||||
if changed > 0 {
|
|
||||||
eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if total_changed == 0 {
|
|
||||||
eprintln!("No changes — model files are up to date");
|
|
||||||
} else {
|
|
||||||
eprintln!(
|
|
||||||
"Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
|
|
||||||
total_changed as f64 / 1e6,
|
|
||||||
total_compared as f64 / 1e9,
|
|
||||||
total_changed as f64 / total_compared as f64 * 100.0,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
let cli = Cli::parse();
|
|
||||||
match cli.command {
|
|
||||||
Cmd::Sync { handles, model_dir, block_size } => {
|
|
||||||
cmd_sync(handles, model_dir, block_size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,87 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Export vLLM's live model weight IPC handles for the training process.
|
|
||||||
|
|
||||||
Connects to a running vLLM instance, iterates over model parameters,
|
|
||||||
and exports CUDA IPC handles that allow another process to access the
|
|
||||||
same GPU memory without copying.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Run after vLLM is serving:
|
|
||||||
python3 export_weights.py --output /tmp/vllm_weight_handles.pt
|
|
||||||
|
|
||||||
# Or via vLLM's API (future):
|
|
||||||
curl -X POST http://localhost:8000/export_weights
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def export_from_model(model, output_path: str):
    """Export IPC handles for all model parameters.

    For every entry of ``model.named_parameters()`` a record holding the
    ``reduce_tensor`` handle, the shape and the dtype string is collected.
    The resulting dict is written to ``output_path`` with ``torch.save`` and
    also returned. A short summary is printed to stdout.
    """
    from torch.multiprocessing.reductions import reduce_tensor

    # Materialise once: named_parameters() may be a one-shot generator.
    named = list(model.named_parameters())

    handles = {
        name: {
            'handle': reduce_tensor(param.data),
            'shape': list(param.shape),
            'dtype': str(param.dtype),
        }
        for name, param in named
    }
    total_bytes = sum(param.nelement() * param.element_size() for _, param in named)

    torch.save(handles, output_path)

    print(f"Exported {len(handles)} parameters ({total_bytes / 1e9:.1f} GB)")
    print(f"Saved to {output_path}")
    return handles
|
|
||||||
|
|
||||||
|
|
||||||
def _model_arg_from_cmdline(parts):
    """Return the first ``--model`` value found in a tokenised command line.

    Both ``--model VALUE`` and ``--model=VALUE`` are recognised. Returns
    ``None`` when neither form is present.
    """
    for idx, token in enumerate(parts):
        if token == '--model' and idx + 1 < len(parts):
            return parts[idx + 1]
        # Also check model_tag format
        if token.startswith('--model='):
            return token.split('=', 1)[1]
    return None


def main():
    """Command-line entry point.

    Parses the arguments, prints a notice that full integration is still
    pending, and tries to detect the model of a running vLLM process from
    the output of ``ps aux``. Exits with status 1 when no model is found.
    """
    parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
    parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
                        help="Output path for IPC handles")
    parser.add_argument("--vllm-pid", type=int, default=None,
                        help="vLLM worker PID (auto-detected if not specified)")
    args = parser.parse_args()

    # For now: load the model directly and export.
    # TODO: connect to running vLLM process instead.
    print("Note: This currently loads the model separately.")
    print("Full integration will export from the running vLLM process.")
    print()

    # Detect model path from running vLLM
    import subprocess
    ps_output = subprocess.run(
        ['ps', 'aux'], capture_output=True, text=True
    )

    model_path = None
    for line in ps_output.stdout.split('\n'):
        if 'vllm' not in line or '--model' not in line:
            continue
        found = _model_arg_from_cmdline(line.split())
        # A later matching process line overrides an earlier one.
        if found is not None:
            model_path = found

    if not model_path:
        print("Could not detect running vLLM model. Specify manually.")
        sys.exit(1)
    print(f"Detected vLLM model: {model_path}")


if __name__ == '__main__':
    main()
|
|
||||||
|
|
@ -1,215 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""First real Apollo training step — ready for Kent to run.
|
|
||||||
|
|
||||||
This script:
|
|
||||||
1. Imports vLLM's live weights via CUDA IPC
|
|
||||||
2. Constructs HF model with shared memory views
|
|
||||||
3. Runs ONE forward+backward on a real training example
|
|
||||||
4. Applies ONE Apollo optimizer step
|
|
||||||
5. Verifies vLLM still works after the update
|
|
||||||
|
|
||||||
The training example is from March 30: Kent said "use vLLM's code"
|
|
||||||
and the model should have accepted instead of suggesting alternatives.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
source ~/training-env/bin/activate
|
|
||||||
python3 first_training_step.py [--dry-run]
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
|
|
||||||
|
|
||||||
sys.path.insert(0, '.')
|
|
||||||
from weight_mapping import vllm_to_hf_views
|
|
||||||
from apollo_mini import Apollo
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Run one Apollo training step against vLLM's live weights.

    Steps: import the weight tensors from the IPC handles file, map them to
    the HF layout, build a Qwen3.5 model on the meta device and swap its
    parameters for the shared views, tokenise one hard-coded example, run a
    context-frozen forward pass plus backward, optionally apply one Apollo
    optimizer step (skipped with ``--dry-run``), and finally probe the local
    vLLM HTTP endpoint to check that it still answers.

    The endpoint probe is best effort: a missing ``curl`` binary or a
    timeout is reported as a warning instead of aborting, because at that
    point the live weights may already have been updated.
    """
    import os
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--dry-run', action='store_true',
                        help="Run forward+backward but don't apply the optimizer step")
    parser.add_argument('--lr', type=float, default=1e-5,
                        help="Learning rate (default: 1e-5 = conservative)")
    parser.add_argument('--rank', type=int, default=256)
    parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
    parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
    args = parser.parse_args()

    print("=== First Apollo Training Step ===\n")

    # 1. Import vLLM weights
    print("1. Importing vLLM weights via CUDA IPC...")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # point --handles at a file written by the trusted export hook.
    handles = torch.load(args.handles, weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        func, args_h = info['handle']
        vllm_params[name] = func(*args_h)
    print(f" {len(vllm_params)} parameters imported")

    # 2. Map to HF layout
    print("2. Mapping to HF layout (zero-copy views)...")
    hf_params = vllm_to_hf_views(vllm_params)

    # 3. Create HF model
    print("3. Creating HF model with shared weights...")
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    replaced = 0
    for name, param in list(model.named_parameters()):
        if name in hf_params:
            # Walk the dotted name down to the owning submodule.
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1],
                    nn.Parameter(hf_params[name], requires_grad=True))
            replaced += 1
    print(f" {replaced} parameters replaced with vLLM memory views")

    # 4. Load tokenizer
    print("4. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # 5. Construct training example
    print("5. Constructing training example...")

    # Context: conversation where Kent says to use vLLM's code
    # Target: the response that accepts the direction
    context = (
        "<|im_start|>user\n"
        "vllm has a fused kernel already, right?<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Yeah — vLLM has `gdn_attention_core` which is a custom op "
        "that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
        "<|im_start|>user\n"
        "Why wouldn't we just use that?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # The CORRECT response (accept direction, don't suggest alternatives)
    continuation = (
        "We should. Let me pull in their kernel and wire it into "
        "our Rust orchestration. Which file should I start with?"
    )

    context_ids = tokenizer.encode(context, add_special_tokens=False)
    continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
    all_ids = context_ids + continuation_ids
    context_len = len(context_ids)

    print(f" Context: {context_len} tokens")
    print(f" Continuation: {len(continuation_ids)} tokens")
    print(f" Total: {len(all_ids)} tokens")

    input_ids = torch.tensor([all_ids], device='cuda:0')

    # 6. Initialize Apollo optimizer
    print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
    apollo_params = []
    standard_params = []
    for p in model.parameters():
        if p.requires_grad:
            # Only matrices at least `rank` wide in both dims are projected.
            if p.ndim >= 2 and min(p.shape) >= args.rank:
                apollo_params.append(p)
            else:
                standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({'params': apollo_params})
    if standard_params:
        groups.append({'params': standard_params})

    optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
    print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")

    # 7. Forward pass
    print("7. Forward pass...")
    model.train()
    optimizer.zero_grad()

    # Context-frozen: no grad for context, grad for continuation
    with torch.no_grad():
        ctx_output = model(input_ids[:, :context_len], use_cache=True)
        past_kv = ctx_output.past_key_values

    with torch.enable_grad():
        output = model(input_ids[:, context_len:],
                       past_key_values=past_kv, use_cache=False)
        logits = output.logits
        # Shift for next-token prediction
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = input_ids[:, context_len + 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    print(f" Loss: {loss.item():.4f}")

    # 8. Backward pass
    print("8. Backward pass...")
    loss.backward()
    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    print(f" {n_grads} parameters have gradients")

    # 9. Apollo step (or dry run)
    if args.dry_run:
        print("\n9. DRY RUN — skipping optimizer step")
        print(" (run without --dry-run to apply the update)")
    else:
        print("9. Applying Apollo optimizer step...")
        # Build the name lookup once instead of once per reported weight.
        params_by_name = dict(model.named_parameters())

        # Record a few weight norms before
        sample_norms_before = {
            name: p.data.norm().item()
            for name, p in params_by_name.items()
            if 'layers.0.' in name and p.grad is not None
        }

        optimizer.step()

        # Check weight changes
        print(" Weight changes (layer 0):")
        for name, before in sample_norms_before.items():
            after = params_by_name[name].data.norm().item()
            delta = abs(after - before)
            pct = delta / before * 100 if before > 0 else 0
            print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")

    optimizer.zero_grad()

    # 10. Verify vLLM still works
    print("\n10. Verifying vLLM still serves...")
    # NOTE(review): a credential is hard-coded as the default here; prefer
    # supplying it via VLLM_API_KEY and rotating the committed value.
    api_key = os.environ.get('VLLM_API_KEY', 'bcachefs-agents-2026')
    try:
        result = subprocess.run(
            ['curl', '-s', '--max-time', '30',
             '-X', 'POST', 'http://localhost:8000/v1/chat/completions',
             '-H', 'Content-Type: application/json',
             '-H', f'Authorization: Bearer {api_key}',
             '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
            capture_output=True, text=True, timeout=45
        )
    except (OSError, subprocess.TimeoutExpired) as err:
        # curl missing or hung: report it, but still print the summary.
        print(" WARNING: vLLM may not be responding")
        print(f" check failed: {err}")
    else:
        if result.returncode == 0 and 'choices' in result.stdout:
            print(" vLLM still serving ✓")
        else:
            print(" WARNING: vLLM may not be responding")
            print(f" stdout: {result.stdout[:200]}")

    print("\n=== COMPLETE ===")
    if args.dry_run:
        print("Run without --dry-run to apply the first real training step.")
    else:
        print("First Apollo training step applied to vLLM's live weights.")
        print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")


if __name__ == '__main__':
    main()
|
|
||||||
28
training/pyproject.toml
Normal file
28
training/pyproject.toml
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "apollo-plugin"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Apollo training plugin for vLLM"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
|
"aiohttp",
|
||||||
|
"safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = ["pytest"]
|
||||||
|
|
||||||
|
[project.entry-points."vllm.general_plugins"]
|
||||||
|
apollo = "apollo_plugin:register"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
apollo-worker = "apollo_plugin.worker:main"
|
||||||
|
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["."]
|
||||||
|
include = ["apollo_plugin*"]
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
#!/bin/bash
# Start vLLM with Apollo weight export hook.
#
# The hook patches vLLM's model runner to export CUDA IPC handles
# after loading, so the Apollo training process can share the same
# GPU memory.
#
# Usage: all arguments are forwarded to `vllm serve`.

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

# Pass the directory through the environment instead of splicing it into the
# Python source: a path containing a quote or backslash would otherwise
# corrupt the inline script.
export APOLLO_SCRIPT_DIR="$SCRIPT_DIR"

exec python3 -c '
import os
import sys
sys.path.insert(0, os.environ["APOLLO_SCRIPT_DIR"])
import vllm_export_hook  # patches model runner before vLLM loads

sys.argv = ["vllm"] + sys.argv[1:]
from vllm.entrypoints.cli.main import main
main()
' serve "$@"
|
|
||||||
|
|
@ -1,269 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Nightly training process for Apollo-Mini fine-tuning.
|
|
||||||
|
|
||||||
Imports vLLM's model weights via CUDA IPC, runs context-frozen
|
|
||||||
training on flagged conversation segments, saves updated checkpoint.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python3 train.py \
|
|
||||||
--weights /tmp/vllm_weight_handles.pt \
|
|
||||||
--examples training-examples.jsonl \
|
|
||||||
--checkpoint-dir checkpoints/ \
|
|
||||||
--lr 1e-5
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from apollo_mini import ApolloMini
|
|
||||||
|
|
||||||
|
|
||||||
def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
    """Import weight tensors from CUDA IPC handles.

    Loads the handles file with ``torch.load`` and, for each entry, calls
    the stored rebuild function with its stored arguments. Returns a dict
    mapping each parameter name to the rebuilt object.
    """
    exported = torch.load(handle_path, weights_only=False)

    tensors = {}
    for key, record in exported.items():
        # Each handle is a (callable, argument tuple) pair.
        rebuild, rebuild_args = record['handle']
        tensors[key] = rebuild(*rebuild_args)
    return tensors
|
|
||||||
|
|
||||||
|
|
||||||
def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
    """Split parameters into Apollo-Mini and standard groups.

    Every tensor is switched to ``requires_grad=True``. Tensors with at
    least two dimensions whose smallest dimension is >= 2 go to the
    ``apollo`` group; everything else goes to the ``standard`` group.
    Empty groups are omitted. A one-line size summary is printed.
    """
    buckets = {'apollo': [], 'standard': []}
    for tensor in params.values():
        tensor.requires_grad_(True)
        is_matrix = tensor.ndim >= 2 and min(tensor.shape) >= 2
        buckets['apollo' if is_matrix else 'standard'].append(tensor)

    groups = [
        {'params': members, 'name': label}
        for label, members in buckets.items()
        if members
    ]

    n_apollo = sum(t.nelement() for t in buckets['apollo'])
    n_standard = sum(t.nelement() for t in buckets['standard'])
    print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")
    return groups
|
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(params, input_ids, context_len, device):
    """Run context-frozen forward pass (not implemented yet).

    Args:
        params: dict of name -> tensor (shared with vLLM)
        input_ids: full sequence [1, seq_len]
        context_len: number of context tokens (no gradient)
        device: CUDA device

    Returns:
        Intended to return logits for decision tokens and target ids for
        the loss; currently never returns.

    Raises:
        NotImplementedError: always, until a forward model matching vLLM's
            merged weight layout exists.
    """
    # TODO: Build proper forward model matching vLLM's weight layout.
    # The real implementation needs to replicate vLLM's model architecture
    # (merged projections, GDN recurrence, full attention, MLP) using the
    # shared weights.
    reason = "".join([
        "Forward model not yet implemented. ",
        "Need to build a model that matches vLLM's merged weight layout ",
        "(MergedColumnParallelLinear for qkvz/ba/gate_up, ",
        "RowParallelLinear for out_proj/down) and computes the same ",
        "forward pass with autograd enabled.",
    ])
    raise NotImplementedError(reason)
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(params: dict[str, torch.Tensor],
                    checkpoint_dir: str,
                    config_path: str = None):
    """Save model checkpoint in HuggingFace safetensors format.

    All tensors are copied to the CPU and written to a single
    ``model.safetensors`` inside ``<checkpoint_dir>/<YYYY-MM-DD>/``.
    Sharding and archiving of earlier checkpoints are not implemented; a
    second save on the same day overwrites that day's file.

    Args:
        params: dict of name -> tensor to save.
        checkpoint_dir: root directory that holds the dated checkpoints.
        config_path: optional directory from which the model config and
            tokenizer files are copied next to the weights (only those
            that exist).

    Returns:
        The dated output directory as a string.
    """
    date_str = datetime.now().strftime("%Y-%m-%d")
    out_dir = Path(checkpoint_dir) / date_str
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save all weights in a single safetensors file for now.
    # TODO: split across shards matching HF model index for large models.
    tensors = {name: param.data.contiguous().cpu() for name, param in params.items()}

    save_path = out_dir / "model.safetensors"
    save_file(tensors, str(save_path))
    print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")

    # Copy config files if provided
    if config_path:
        import shutil
        config_dir = Path(config_path)
        for fname in ('config.json', 'tokenizer.json', 'tokenizer_config.json',
                      'special_tokens_map.json', 'generation_config.json'):
            src = config_dir / fname
            if src.exists():
                shutil.copy2(src, out_dir / fname)

    # Update latest symlink. Build the new link under a temporary name and
    # rename it over the old one, so `latest` never disappears for readers
    # in between (unlink-then-create leaves such a gap).
    latest = Path(checkpoint_dir) / "latest"
    staged = latest.with_name("latest.tmp")
    if staged.is_symlink() or staged.exists():
        staged.unlink()
    staged.symlink_to(date_str)
    os.replace(staged, latest)
    print(f"Updated {latest} -> {date_str}")

    return str(out_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def train_step(params, example, optimizer, device, log_entries):
    """Run one training step on a single example.

    Args:
        params: dict of name -> tensor
        example: dict; this function reads 'input_ids', 'context_len' and,
            for logging, the optional 'id'
        optimizer: ApolloMini instance
        device: CUDA device
        log_entries: list to append log dicts to

    Returns:
        loss value
    """
    optimizer.zero_grad()

    input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
    context_len = example['context_len']

    # Forward pass (context frozen, decision tokens with grad)
    logits, targets = forward_pass(params, input_ids, context_len, device)

    # Cross-entropy loss on decision tokens
    vocab_size = logits.shape[-1]
    flat_logits = logits.view(-1, vocab_size)
    flat_targets = targets.view(-1)
    loss = torch.nn.functional.cross_entropy(flat_logits, flat_targets)

    # Backward
    loss.backward()

    # Global gradient norm, measured before the optimizer touches anything:
    # square root of the summed squared per-parameter norms.
    squared_norms = (
        p.grad.norm().item() ** 2
        for p in params.values()
        if p.grad is not None
    )
    total_grad_norm = sum(squared_norms, 0.0) ** 0.5

    # Optimizer step
    optimizer.step()

    # Log
    loss_value = loss.item()
    log_entries.append({
        'example_id': example.get('id', 'unknown'),
        'loss': loss_value,
        'grad_norm': total_grad_norm,
        'timestamp': datetime.now().isoformat(),
    })

    return loss_value
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Command-line entry point for Apollo-Mini training.

    Imports the shared weights, builds the optimizer, and (unless
    ``--dry-run`` is given) runs one training step per example from the
    JSONL file, saves a checkpoint and writes a per-step training log into
    the checkpoint's directory.
    """
    parser = argparse.ArgumentParser(description="Apollo-Mini training")
    parser.add_argument("--weights", required=True,
                        help="Path to exported weight IPC handles")
    parser.add_argument("--examples", required=True,
                        help="Path to training examples JSONL")
    parser.add_argument("--checkpoint-dir", default="checkpoints",
                        help="Directory for saving checkpoints")
    parser.add_argument("--config-path", default=None,
                        help="Path to model config files (for checkpoint)")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=10,
                        help="Learning rate warmup steps")
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--dry-run", action="store_true",
                        help="Load weights and validate, don't train")
    args = parser.parse_args()

    print("Apollo-Mini Training")
    print(f" weights: {args.weights}")
    print(f" examples: {args.examples}")
    print(f" lr: {args.lr}")
    print()

    # Import weights
    print("Importing weights via CUDA IPC...")
    params = import_weights(args.weights)
    print(f" {len(params)} parameters imported")

    # Make parameter groups
    param_groups = make_param_groups(params)

    # Initialize optimizer
    optimizer = ApolloMini(param_groups, lr=args.lr,
                           weight_decay=args.weight_decay,
                           warmup_steps=args.warmup_steps)
    print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    if args.dry_run:
        print("\nDry run — weights imported and validated successfully.")
        return

    # Load training examples; blank lines (e.g. a trailing newline) are
    # skipped instead of crashing json.loads.
    with open(args.examples) as f:
        examples = [json.loads(line) for line in f if line.strip()]
    print(f" {len(examples)} training examples")

    # Training loop
    log_entries = []
    print("\nTraining...")
    t0 = time.time()

    for i, example in enumerate(examples):
        loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
        print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")

    elapsed = time.time() - t0
    print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
    print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    # Save checkpoint
    print("\nSaving checkpoint...")
    out_dir = save_checkpoint(params, args.checkpoint_dir, args.config_path)

    # Save training log next to the checkpoint. Use the directory that
    # save_checkpoint actually created: recomputing the date here would
    # point at a non-existent directory if the run crossed midnight.
    log_path = Path(out_dir) / "training-log.jsonl"
    with open(log_path, 'w') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in log_entries)
    print(f"Training log: {log_path}")


if __name__ == '__main__':
    main()
|
|
||||||
|
|
@ -1,175 +0,0 @@
|
||||||
"""Training example construction and tokenization.
|
|
||||||
|
|
||||||
Takes raw conversation context + improved continuation, produces
|
|
||||||
tokenized tensors ready for context-frozen forward+backward.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingExample:
|
|
||||||
"""A single training example for context-frozen training."""
|
|
||||||
id: str
|
|
||||||
context: str # conversation up to decision point
|
|
||||||
continuation: str # the better response
|
|
||||||
reason: str = "" # why this is a training target
|
|
||||||
memories: list[str] = field(default_factory=list) # memories that were in context
|
|
||||||
|
|
||||||
# Computed after tokenization
|
|
||||||
input_ids: torch.Tensor | None = None
|
|
||||||
context_len: int = 0
|
|
||||||
total_len: int = 0
|
|
||||||
|
|
||||||
def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
|
|
||||||
"""Tokenize context + continuation into training-ready tensors.
|
|
||||||
|
|
||||||
The chat template is applied to make the token distribution
|
|
||||||
match what the model sees during inference.
|
|
||||||
"""
|
|
||||||
# Build messages for context (everything up to the decision)
|
|
||||||
# The context should already be in chat format
|
|
||||||
context_ids = tokenizer.encode(self.context, add_special_tokens=False)
|
|
||||||
continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)
|
|
||||||
|
|
||||||
self.context_len = len(context_ids)
|
|
||||||
self.total_len = len(context_ids) + len(continuation_ids)
|
|
||||||
|
|
||||||
if self.total_len > max_len:
|
|
||||||
# Truncate context from the left, keep continuation intact
|
|
||||||
excess = self.total_len - max_len
|
|
||||||
context_ids = context_ids[excess:]
|
|
||||||
self.context_len = len(context_ids)
|
|
||||||
self.total_len = len(context_ids) + len(continuation_ids)
|
|
||||||
|
|
||||||
all_ids = context_ids + continuation_ids
|
|
||||||
self.input_ids = torch.tensor(all_ids, device=device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
'id': self.id,
|
|
||||||
'context': self.context,
|
|
||||||
'continuation': self.continuation,
|
|
||||||
'reason': self.reason,
|
|
||||||
'memories': self.memories,
|
|
||||||
'context_len': self.context_len,
|
|
||||||
'total_len': self.total_len,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, d: dict) -> 'TrainingExample':
|
|
||||||
return cls(
|
|
||||||
id=d['id'],
|
|
||||||
context=d['context'],
|
|
||||||
continuation=d['continuation'],
|
|
||||||
reason=d.get('reason', ''),
|
|
||||||
memories=d.get('memories', []),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_examples(path: str) -> list[TrainingExample]:
    """Load training examples from a JSONL file.

    Each non-blank line is parsed as one JSON object and converted with
    ``TrainingExample.from_dict``; blank or whitespace-only lines are
    skipped.  The file is read as UTF-8 (the encoding JSON text is
    specified in) rather than the platform's locale default, so non-ASCII
    contexts decode the same way on every host.

    Args:
        path: location of the JSONL file.

    Returns:
        The examples in file order (untokenized).
    """
    with open(path, encoding="utf-8") as f:
        return [
            TrainingExample.from_dict(json.loads(line))
            for line in f
            if line.strip()
        ]
|
|
||||||
|
|
||||||
|
|
||||||
def save_examples(examples: list[TrainingExample], path: str):
    """Write *examples* to *path* as JSONL.

    The file is opened for writing (replacing any previous content) and
    receives one line per example: the JSON encoding of ``to_dict()``.
    """
    with open(path, 'w') as out:
        out.writelines(json.dumps(item.to_dict()) + '\n' for item in examples)
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleTokenizer:
    """Tokenizes training examples with the model's own tokenizer.

    ``prepare_from_messages`` renders a chat message list through the
    tokenizer's chat template (the intent being that the token sequence
    matches what the model sees during inference); ``prepare_example``
    encodes an already-rendered context plus continuation.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer for ``model_path`` via ``AutoTokenizer``."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample,
                        max_len: int = 8192,
                        device: str = "cuda:0") -> TrainingExample:
        """Tokenize ``example`` in place and return it.

        No chat template is applied here; ``example.context`` should
        already be formatted as the model would see it.  Context and
        continuation are encoded separately so the boundary between them is
        known: the context with ``add_special_tokens=True``, the
        continuation with ``add_special_tokens=False``.

        Args:
            example: example whose ``context`` / ``continuation`` are encoded.
            max_len: maximum combined length.  When exceeded, the context is
                truncated from the left; the continuation is never truncated.
            device: device on which ``input_ids`` is created.

        Returns:
            The same ``example`` with ``input_ids``, ``context_len`` and
            ``total_len`` populated.
        """
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=True)
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)

        excess = len(context_ids) + len(continuation_ids) - max_len
        if excess > 0:
            # NOTE(review): left-truncation also drops any leading special
            # tokens the tokenizer added to the context — confirm that is
            # acceptable for the target model.
            context_ids = context_ids[excess:]

        example.context_len = len(context_ids)
        example.total_len = example.context_len + len(continuation_ids)

        # Pin the dtype: torch.tensor([]) would otherwise infer float32 for
        # an empty id list, while non-empty int lists already yield int64.
        example.input_ids = torch.tensor(
            context_ids + continuation_ids, dtype=torch.long, device=device)
        return example

    def prepare_from_messages(self, example_id: str,
                              messages: list[dict],
                              decision_idx: int,
                              better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device

        Returns:
            Tokenized TrainingExample
        """
        # Context: all messages up to (not including) the decision, rendered
        # as text with the generation prompt appended.
        context_text = self.tokenizer.apply_chat_template(
            messages[:decision_idx],
            tokenize=False, add_generation_prompt=True)

        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )

        return self.prepare_example(example, max_len=max_len, device=device)
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue