training: restructure as vLLM plugin package
- Convert to installable package with entry points for vLLM auto-discovery
- Add checkpoint_sync.py: Python replacement for Rust checkpoint binary
  - Block-level diffing of safetensors files (4KB blocks)
  - vLLM→HF weight name conversion built-in
  - Scheduled 10min after training jobs (batched)
- API change: /train now takes raw token IDs (context_ids + continuation_ids; see the sketch below)
  - No tokenizer on training side; client owns tokenization
- Remove superseded code: standalone scripts, Rust binary, tokenizer helpers

Install: pip install -e ./training
Then vLLM auto-loads via entry point.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
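A minimal sketch of the new request shape. The context_ids/continuation_ids
field names come from the summary above; the endpoint host/port and the
client code around them are illustrative assumptions, not part of this diff:

    import requests  # hypothetical client; host/port are assumptions

    resp = requests.post(
        "http://localhost:8000/train",
        json={
            "context_ids": [101, 2023, 2003],       # client-tokenized prompt
            "continuation_ids": [1996, 4633, 102],  # client-tokenized target
        },
    )
    resp.raise_for_status()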
parent b649a11645 → commit a73bcf5ae3
15 changed files with 607 additions and 1068 deletions
training/apollo_plugin/checkpoint_sync.py (new file, 500 lines)
@@ -0,0 +1,500 @@
"""Sync live GPU weights to safetensors files on disk.
|
||||
|
||||
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
|
||||
merged layout to HuggingFace's separate layout, diffs block-by-block
|
||||
against on-disk safetensors files, and writes only changed blocks.
|
||||
|
||||
For small behavioral training steps, this turns a 54GB checkpoint
|
||||
write into a few hundred MB of actual disk I/O.
|
||||
|
||||
Usage:
|
||||
# Sync live weights to disk
|
||||
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
|
||||
|
||||
# Debug name mapping issues
|
||||
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
|
||||
|
||||
# From Python:
|
||||
from checkpoint_sync import checkpoint_sync
|
||||
result = checkpoint_sync("/path/to/model")
|
||||
"""
|
||||
|
||||
import json
|
||||
import mmap
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
|
||||
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# vLLM → HuggingFace weight name/shape conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
# Qwen3.5-27B dimensions (could be read from config.json for generality)
|
||||
|
||||
HIDDEN = 5120
|
||||
NUM_K_HEADS = 16
|
||||
NUM_V_HEADS = 48
|
||||
HEAD_K_DIM = 128
|
||||
HEAD_V_DIM = 128
|
||||
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
|
||||
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
|
||||
INTERMEDIATE = 17408
|
||||
|
||||
# Full attention (some layers use standard attention, not GDN)
|
||||
NUM_ATTN_HEADS = 24
|
||||
NUM_ATTN_KV_HEADS = 4
|
||||
ATTN_HEAD_DIM = 256
|
||||
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
|
||||
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
|
||||
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
|
||||
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
|
||||
|
||||
|
||||
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Convert vLLM merged weights to HF-compatible separate tensors.
|
||||
|
||||
vLLM merges certain projections for efficiency:
|
||||
- qkv_proj (full attn) → q_proj, k_proj, v_proj
|
||||
- in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z
|
||||
- in_proj_ba (GDN) → in_proj_b, in_proj_a
|
||||
- gate_up_proj (MLP) → gate_proj, up_proj
|
||||
|
||||
Returns views that share GPU memory with the original tensors.
|
||||
"""
|
||||
hf_params = {}
|
||||
|
||||
for name, tensor in vllm_params.items():
|
||||
# Strip vLLM's 'language_model.' prefix to match HF naming
|
||||
hf_name = name.removeprefix('language_model.')
|
||||
|
||||
if 'in_proj_qkvz' in name:
|
||||
# GDN layer: [key*2 + value*2, hidden] → qkv + z
|
||||
prefix = hf_name.replace('in_proj_qkvz.weight', '')
|
||||
split_at = KEY_DIM * 2 + VALUE_DIM
|
||||
hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
|
||||
hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
|
||||
|
||||
elif 'in_proj_ba' in name:
|
||||
# GDN layer: [num_v_heads*2, hidden] → b + a
|
||||
prefix = hf_name.replace('in_proj_ba.weight', '')
|
||||
hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
|
||||
hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
|
||||
|
||||
elif 'qkv_proj' in name:
|
||||
# Full attention: [q + k + v, hidden] → separate
|
||||
prefix = hf_name.replace('qkv_proj.weight', '')
|
||||
hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
|
||||
hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
|
||||
hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
|
||||
|
||||
elif 'gate_up_proj' in name:
|
||||
# MLP: [intermediate*2, hidden] → gate + up
|
||||
prefix = hf_name.replace('gate_up_proj.weight', '')
|
||||
hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
|
||||
hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
|
||||
|
||||
else:
|
||||
# Pass through unchanged
|
||||
hf_params[hf_name] = tensor
|
||||
|
||||
return hf_params
|
||||
|
||||
|
||||
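# Illustrative sketch (added for exposition, not called anywhere in the sync
# path): the merged→separate conversion above is pure slicing, so the outputs
# alias the merged tensor's storage. Toy dimensions here, not the model's.
def _demo_split_gate_up() -> None:
    intermediate = 4
    merged = torch.arange(2 * intermediate * 3, dtype=torch.float32).reshape(2 * intermediate, 3)
    gate, up = merged[:intermediate], merged[intermediate:]
    assert gate.data_ptr() == merged.data_ptr()  # view, no copy
    assert up.shape == (intermediate, 3)
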
# ---------------------------------------------------------------------------
# Safetensors file handling
# ---------------------------------------------------------------------------

def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
    """Map tensor names to safetensors filenames.

    For sharded models, reads model.safetensors.index.json.
    For single-file models, returns empty dict (default to model.safetensors).
    """
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        return {}

    with open(index_path) as f:
        index = json.load(f)

    return dict(index.get("weight_map", {}))


def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
    """Parse safetensors file header.

    Returns (data_start_offset, header_dict).
    Header dict maps tensor names to metadata including 'data_offsets'.
    """
    header_size = struct.unpack('<Q', data[:8])[0]
    header = json.loads(bytes(data[8:8 + header_size]))
    return 8 + header_size, header

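# Illustrative sketch (added for exposition, not called anywhere): the
# safetensors layout parse_safetensors_header() relies on is an 8-byte
# little-endian header length, then a JSON header, then raw tensor data.
def _demo_header_layout() -> None:
    header = {"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
    payload = json.dumps(header).encode()
    blob = struct.pack('<Q', len(payload)) + payload + b"\x00" * 8
    data_start, parsed = parse_safetensors_header(memoryview(blob))
    assert data_start == 8 + len(payload)
    assert parsed["w"]["data_offsets"] == [0, 8]
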
# ---------------------------------------------------------------------------
# Block-level diffing and sync
# ---------------------------------------------------------------------------

def sync_tensor_to_mmap(
    mm: mmap.mmap,
    name: str,
    tensor: torch.Tensor,
    data_start: int,
    offsets: List[int],
    block_size: int,
) -> Tuple[int, int]:
    """Sync a single tensor to mmap'd file using block-level diffing.

    Returns (bytes_compared, bytes_changed).
    """
    start = data_start + offsets[0]
    end = data_start + offsets[1]
    disk_len = end - start

    # Transfer tensor to CPU and get raw bytes.
    # Use .detach() to avoid autograd overhead, .contiguous() for memory
    # layout. Reinterpret as uint8 before .numpy(): numpy has no bfloat16
    # dtype, and the byte view is identical for every dtype anyway.
    try:
        live = tensor.detach().contiguous().cpu()
        live_bytes = live.view(torch.uint8).numpy().tobytes()
    except Exception as e:
        logger.warning(f"Failed to transfer {name} to CPU: {e}")
        return 0, 0

    if len(live_bytes) != disk_len:
        logger.warning(
            f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
        )
        return 0, 0

    # Block-level diff: compare and write only changed blocks
    compared = 0
    changed = 0
    offset = 0

    while offset < disk_len:
        block_end = min(offset + block_size, disk_len)
        block_len = block_end - offset

        disk_block = mm[start + offset:start + block_end]
        live_block = live_bytes[offset:block_end]

        compared += block_len

        if disk_block != live_block:
            mm[start + offset:start + block_end] = live_block
            changed += block_len

        offset = block_end

    return compared, changed

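# Illustrative sketch (added for exposition, not called anywhere): the diff
# loop above in miniature. Only blocks whose bytes differ are rewritten,
# which is what keeps a mostly-unchanged 54GB checkpoint cheap to sync.
def _demo_block_diff() -> None:
    disk = bytearray(b"A" * 16)  # stand-in for the mmap'd file
    live = b"A" * 8 + b"B" * 8   # second block changed
    block, changed = 8, 0
    for off in range(0, len(disk), block):
        if bytes(disk[off:off + block]) != live[off:off + block]:
            disk[off:off + block] = live[off:off + block]
            changed += block
    assert changed == 8          # only one of two blocks was written
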
def sync_file(
    file_path: Path,
    tensors: Dict[str, torch.Tensor],
    block_size: int,
) -> Tuple[int, int, int, int]:
    """Sync tensors to a single safetensors file.

    Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
    """
    with open(file_path, 'r+b') as f:
        mm = mmap.mmap(f.fileno(), 0)

        try:
            data_start, header = parse_safetensors_header(memoryview(mm))

            total_compared = 0
            total_changed = 0
            found = 0
            missing = 0

            for name, tensor in tensors.items():
                if name == "__metadata__":
                    continue

                if name not in header:
                    missing += 1
                    continue

                found += 1
                meta = header[name]
                offsets = meta['data_offsets']

                compared, changed = sync_tensor_to_mmap(
                    mm, name, tensor, data_start, offsets, block_size
                )
                total_compared += compared
                total_changed += changed

            # Flush changes to disk
            if total_changed > 0:
                mm.flush()

            return total_compared, total_changed, found, missing

        finally:
            mm.close()


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
    """Load vLLM weight tensors from CUDA IPC handles.

    The handles file is written by vllm_export_hook.py on vLLM startup.
    Each handle can be used to reconstruct a tensor pointing to vLLM's
    GPU memory — no copy, direct access.
    """
    handles = torch.load(handles_path, weights_only=False)

    weights = {}
    for name, info in handles.items():
        func, args = info['handle']
        try:
            weights[name] = func(*args)
        except Exception as e:
            logger.warning(f"Failed to reconstruct {name}: {e}")

    return weights

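# Illustrative sketch (added for exposition; the real vllm_export_hook.py is
# not part of this file) of the producer side load_vllm_weights() expects:
# torch's CUDA IPC reduction yields exactly the (rebuild_fn, args) pair that
# is unpacked as info['handle'] above. The 'shape' key is an assumption.
def _demo_export_handles(params: Dict[str, torch.Tensor],
                         path: str = DEFAULT_HANDLES_PATH) -> None:
    from torch.multiprocessing.reductions import reduce_tensor
    handles = {
        name: {'handle': reduce_tensor(t), 'shape': tuple(t.shape)}
        for name, t in params.items() if t.is_cuda
    }
    torch.save(handles, path)
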
def checkpoint_sync(
    model_dir: str,
    handles_path: str = DEFAULT_HANDLES_PATH,
    block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
    """Sync live GPU weights to model safetensors files.

    This is the main entry point. Call this after training steps
    or periodically to checkpoint weights without full serialization.

    Args:
        model_dir: Directory containing safetensors files
        handles_path: Path to vLLM weight IPC handles file
        block_size: Block size for diffing (default 4KB)

    Returns:
        Dict with sync statistics:
          - total_compared: bytes compared
          - total_changed: bytes actually written
          - files_changed: list of modified filenames
          - tensors_synced: number of tensors processed
          - tensors_missing: tensors not found in safetensors
    """
    model_dir = Path(model_dir)

    if not Path(handles_path).exists():
        raise FileNotFoundError(
            f"Weight handles not found: {handles_path}. "
            "Is vLLM running with the export hook?"
        )

    # Step 1: Load live weights from GPU via IPC
    logger.info("Loading live weights from GPU...")
    vllm_weights = load_vllm_weights(handles_path)
    logger.info(f"  Loaded {len(vllm_weights)} vLLM tensors")

    # Step 2: Convert to HF naming/layout
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    logger.info(f"  Converted to {len(hf_weights)} HF tensors")

    # Step 3: Map tensors to safetensors files
    weight_map = read_safetensors_index(model_dir)

    by_file: Dict[str, Dict[str, torch.Tensor]] = {}
    unmapped = []

    for name, tensor in hf_weights.items():
        filename = weight_map.get(name)
        if filename is None:
            # Single-file model or missing from index
            if (model_dir / "model.safetensors").exists():
                filename = "model.safetensors"
            else:
                unmapped.append(name)
                continue
        by_file.setdefault(filename, {})[name] = tensor

    if unmapped:
        logger.warning(f"  {len(unmapped)} tensors not in index: {unmapped[:3]}...")

    # Step 4: Sync each file
    total_compared = 0
    total_changed = 0
    total_found = 0
    total_missing = 0
    files_changed = []

    for filename in sorted(by_file.keys()):
        tensors = by_file[filename]
        file_path = model_dir / filename

        if not file_path.exists():
            logger.warning(f"  File not found: {filename}")
            total_missing += len(tensors)
            continue

        compared, changed, found, missing = sync_file(file_path, tensors, block_size)

        total_compared += compared
        total_changed += changed
        total_found += found
        total_missing += missing

        if changed > 0:
            files_changed.append(filename)
            logger.info(f"  {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")

    # Summary
    if total_changed == 0:
        logger.info("No changes - model files are up to date")
    else:
        pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
        logger.info(
            f"Synced: {total_changed / 1e6:.2f} MB changed / "
            f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
        )

    if total_missing > 0:
        logger.warning(f"  {total_missing} tensors not found in safetensors files")

    return {
        "total_compared": total_compared,
        "total_changed": total_changed,
        "files_changed": files_changed,
        "tensors_synced": total_found,
        "tensors_missing": total_missing,
    }

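# Illustrative sketch (added for exposition; the real scheduler lives in the
# plugin package, and the 10-minute delay comes from the commit message):
# debounced scheduling, where each finished training job resets the timer so
# a burst of jobs ends in a single batched sync.
def _demo_schedule_sync(model_dir: str, delay_s: float = 600.0):
    import threading

    timer: list = [None]

    def on_training_job_done() -> None:
        if timer[0] is not None:
            timer[0].cancel()  # coalesce pending syncs into one
        timer[0] = threading.Timer(delay_s, checkpoint_sync, args=(model_dir,))
        timer[0].daemon = True
        timer[0].start()

    return on_training_job_done
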
# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------

def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
    """Print diagnostic info about weight name mappings.

    Useful for debugging mismatches between vLLM and safetensors names.
    """
    model_dir = Path(model_dir)

    # Load and convert vLLM weights
    vllm_weights = load_vllm_weights(handles_path)
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    hf_names = set(hf_weights.keys())

    # Read safetensors index
    weight_map = read_safetensors_index(model_dir)
    disk_names = set(weight_map.keys())

    # If single-file model, parse that file's header
    if not disk_names:
        st_path = model_dir / "model.safetensors"
        if st_path.exists():
            with open(st_path, 'rb') as f:
                mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                _, header = parse_safetensors_header(memoryview(mm))
                disk_names = {k for k in header.keys() if k != "__metadata__"}
                mm.close()

    print(f"vLLM tensors (raw): {len(vllm_weights)}")
    print(f"HF tensors (converted): {len(hf_names)}")
    print(f"Disk tensors: {len(disk_names)}")
    print()

    in_both = hf_names & disk_names
    only_hf = hf_names - disk_names
    only_disk = disk_names - hf_names

    print(f"Matched: {len(in_both)}")
    print(f"Only in HF (won't sync): {len(only_hf)}")
    print(f"Only on disk (not updated): {len(only_disk)}")

    if only_hf:
        print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
    if only_disk:
        print(f"\nSample disk-only: {sorted(only_disk)[:5]}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    import argparse

    parser = argparse.ArgumentParser(
        description="Sync live GPU weights to safetensors files"
    )
    subparsers = parser.add_subparsers(dest="command", help="Command")

    # sync command
    sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
    sync_parser.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    sync_parser.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )
    sync_parser.add_argument(
        "--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
        help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
    )
    sync_parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose output"
    )

    # diagnose command
    diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
    diag_parser.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    diag_parser.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    logging.basicConfig(
        level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
        format='%(message)s'
    )

    try:
        if args.command == "sync":
            result = checkpoint_sync(args.model_dir, args.handles, args.block_size)
            print(json.dumps(result, indent=2))
        elif args.command == "diagnose":
            diagnose(args.model_dir, args.handles)
    except FileNotFoundError as e:
        logger.error(str(e))
        sys.exit(1)
    except Exception as e:
        logger.exception(f"Failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()