"""Sync live GPU weights to safetensors files on disk. Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's merged layout to HuggingFace's separate layout, diffs block-by-block against on-disk safetensors files, and writes only changed blocks. For small behavioral training steps, this turns a 54GB checkpoint write into a few hundred MB of actual disk I/O. Usage: # Sync live weights to disk python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B # Debug name mapping issues python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B # From Python: from checkpoint_sync import checkpoint_sync result = checkpoint_sync("/path/to/model") """ import json import mmap import struct import sys from pathlib import Path from typing import Dict, List, Tuple, Any import logging import torch logger = logging.getLogger(__name__) DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt" # --------------------------------------------------------------------------- # vLLM → HuggingFace weight name/shape conversion # --------------------------------------------------------------------------- # Qwen3.5-27B dimensions (could be read from config.json for generality) HIDDEN = 5120 NUM_K_HEADS = 16 NUM_V_HEADS = 48 HEAD_K_DIM = 128 HEAD_V_DIM = 128 KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048 VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144 INTERMEDIATE = 17408 # Full attention (some layers use standard attention, not GDN) NUM_ATTN_HEADS = 24 NUM_ATTN_KV_HEADS = 4 ATTN_HEAD_DIM = 256 ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512 ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288 ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: """Convert vLLM merged weights to HF-compatible separate tensors. vLLM merges certain projections for efficiency: - qkv_proj (full attn) → q_proj, k_proj, v_proj - in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z - in_proj_ba (GDN) → in_proj_b, in_proj_a - gate_up_proj (MLP) → gate_proj, up_proj Returns views that share GPU memory with the original tensors. """ hf_params = {} for name, tensor in vllm_params.items(): # Strip vLLM's 'language_model.' prefix to match HF naming hf_name = name.removeprefix('language_model.') if 'in_proj_qkvz' in name: # GDN layer: [key*2 + value*2, hidden] → qkv + z prefix = hf_name.replace('in_proj_qkvz.weight', '') split_at = KEY_DIM * 2 + VALUE_DIM hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at] hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:] elif 'in_proj_ba' in name: # GDN layer: [num_v_heads*2, hidden] → b + a prefix = hf_name.replace('in_proj_ba.weight', '') hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS] hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:] elif 'qkv_proj' in name: # Full attention: [q + k + v, hidden] → separate prefix = hf_name.replace('qkv_proj.weight', '') hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM] hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM] hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:] elif 'gate_up_proj' in name: # MLP: [intermediate*2, hidden] → gate + up prefix = hf_name.replace('gate_up_proj.weight', '') hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE] hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:] else: # Pass through unchanged hf_params[hf_name] = tensor return hf_params # --------------------------------------------------------------------------- # Safetensors file handling # --------------------------------------------------------------------------- def read_safetensors_index(model_dir: Path) -> Dict[str, str]: """Map tensor names to safetensors filenames. For sharded models, reads model.safetensors.index.json. For single-file models, returns empty dict (default to model.safetensors). """ index_path = model_dir / "model.safetensors.index.json" if not index_path.exists(): return {} with open(index_path) as f: index = json.load(f) return dict(index.get("weight_map", {})) def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]: """Parse safetensors file header. Returns (data_start_offset, header_dict). Header dict maps tensor names to metadata including 'data_offsets'. """ header_size = struct.unpack(' Tuple[int, int]: """Sync a single tensor to mmap'd file using block-level diffing. Returns (bytes_compared, bytes_changed). """ start = data_start + offsets[0] end = data_start + offsets[1] disk_len = end - start # Transfer tensor to CPU and get raw bytes # Use .detach() to avoid autograd overhead, .contiguous() for memory layout try: live_bytes = tensor.detach().contiguous().cpu().numpy().tobytes() except Exception as e: logger.warning(f"Failed to transfer {name} to CPU: {e}") return 0, 0 if len(live_bytes) != disk_len: logger.warning( f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} " f"(shape={list(tensor.shape)}, dtype={tensor.dtype})" ) return 0, 0 # Block-level diff: compare and write only changed blocks compared = 0 changed = 0 offset = 0 while offset < disk_len: block_end = min(offset + block_size, disk_len) block_len = block_end - offset disk_block = mm[start + offset:start + block_end] live_block = live_bytes[offset:block_end] compared += block_len if disk_block != live_block: mm[start + offset:start + block_end] = live_block changed += block_len offset = block_end return compared, changed def sync_file( file_path: Path, tensors: Dict[str, torch.Tensor], block_size: int, ) -> Tuple[int, int, int, int]: """Sync tensors to a single safetensors file. Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing). """ with open(file_path, 'r+b') as f: mm = mmap.mmap(f.fileno(), 0) try: data_start, header = parse_safetensors_header(memoryview(mm)) total_compared = 0 total_changed = 0 found = 0 missing = 0 for name, tensor in tensors.items(): if name == "__metadata__": continue if name not in header: missing += 1 continue found += 1 meta = header[name] offsets = meta['data_offsets'] compared, changed = sync_tensor_to_mmap( mm, name, tensor, data_start, offsets, block_size ) total_compared += compared total_changed += changed # Flush changes to disk if total_changed > 0: mm.flush() return total_compared, total_changed, found, missing finally: mm.close() # --------------------------------------------------------------------------- # Main entry point # --------------------------------------------------------------------------- def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]: """Load vLLM weight tensors from CUDA IPC handles. The handles file is written by vllm_export_hook.py on vLLM startup. Each handle can be used to reconstruct a tensor pointing to vLLM's GPU memory — no copy, direct access. """ handles = torch.load(handles_path, weights_only=False) # Skip metadata entry handles.pop('__metadata__', None) weights = {} for name, info in handles.items(): func, args = info['handle'] try: weights[name] = func(*args) except Exception as e: logger.warning(f"Failed to reconstruct {name}: {e}") return weights def checkpoint_sync( model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH, block_size: int = DEFAULT_BLOCK_SIZE, ) -> Dict[str, Any]: """Sync live GPU weights to model safetensors files. This is the main entry point. Call this after training steps or periodically to checkpoint weights without full serialization. Args: model_dir: Directory containing safetensors files handles_path: Path to vLLM weight IPC handles file block_size: Block size for diffing (default 4KB) Returns: Dict with sync statistics: - total_compared: bytes compared - total_changed: bytes actually written - files_changed: list of modified filenames - tensors_synced: number of tensors processed - tensors_missing: tensors not found in safetensors """ model_dir = Path(model_dir) if not Path(handles_path).exists(): raise FileNotFoundError( f"Weight handles not found: {handles_path}. " "Is vLLM running with the export hook?" ) # Step 1: Load live weights from GPU via IPC logger.info("Loading live weights from GPU...") vllm_weights = load_vllm_weights(handles_path) logger.info(f" Loaded {len(vllm_weights)} vLLM tensors") # Step 2: Convert to HF naming/layout hf_weights = vllm_to_hf_tensors(vllm_weights) logger.info(f" Converted to {len(hf_weights)} HF tensors") # Step 3: Map tensors to safetensors files weight_map = read_safetensors_index(model_dir) by_file: Dict[str, Dict[str, torch.Tensor]] = {} unmapped = [] for name, tensor in hf_weights.items(): filename = weight_map.get(name) if filename is None: # Single-file model or missing from index if (model_dir / "model.safetensors").exists(): filename = "model.safetensors" else: unmapped.append(name) continue by_file.setdefault(filename, {})[name] = tensor if unmapped: logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...") # Step 4: Sync each file total_compared = 0 total_changed = 0 total_found = 0 total_missing = 0 files_changed = [] for filename in sorted(by_file.keys()): tensors = by_file[filename] file_path = model_dir / filename if not file_path.exists(): logger.warning(f" File not found: {filename}") total_missing += len(tensors) continue compared, changed, found, missing = sync_file(file_path, tensors, block_size) total_compared += compared total_changed += changed total_found += found total_missing += missing if changed > 0: files_changed.append(filename) logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)") # Summary if total_changed == 0: logger.info("No changes - model files are up to date") else: pct = (total_changed / total_compared * 100) if total_compared > 0 else 0 logger.info( f"Synced: {total_changed / 1e6:.2f} MB changed / " f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)" ) if total_missing > 0: logger.warning(f" {total_missing} tensors not found in safetensors files") return { "total_compared": total_compared, "total_changed": total_changed, "files_changed": files_changed, "tensors_synced": total_found, "tensors_missing": total_missing, } # --------------------------------------------------------------------------- # Diagnostics # --------------------------------------------------------------------------- def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH): """Print diagnostic info about weight name mappings. Useful for debugging mismatches between vLLM and safetensors names. """ model_dir = Path(model_dir) # Load and convert vLLM weights vllm_weights = load_vllm_weights(handles_path) hf_weights = vllm_to_hf_tensors(vllm_weights) hf_names = set(hf_weights.keys()) # Read safetensors index weight_map = read_safetensors_index(model_dir) disk_names = set(weight_map.keys()) # If single-file model, parse that file's header if not disk_names: st_path = model_dir / "model.safetensors" if st_path.exists(): with open(st_path, 'rb') as f: mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) _, header = parse_safetensors_header(memoryview(mm)) disk_names = {k for k in header.keys() if k != "__metadata__"} mm.close() print(f"vLLM tensors (raw): {len(vllm_weights)}") print(f"HF tensors (converted): {len(hf_names)}") print(f"Disk tensors: {len(disk_names)}") print() in_both = hf_names & disk_names only_hf = hf_names - disk_names only_disk = disk_names - hf_names print(f"Matched: {len(in_both)}") print(f"Only in HF (won't sync): {len(only_hf)}") print(f"Only on disk (not updated): {len(only_disk)}") if only_hf: print(f"\nSample HF-only: {sorted(only_hf)[:5]}") if only_disk: print(f"\nSample disk-only: {sorted(only_disk)[:5]}") # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main(): import argparse parser = argparse.ArgumentParser( description="Sync live GPU weights to safetensors files" ) subparsers = parser.add_subparsers(dest="command", help="Command") # sync command sync_parser = subparsers.add_parser("sync", help="Sync weights to disk") sync_parser.add_argument( "--model-dir", required=True, help="Directory containing safetensors files" ) sync_parser.add_argument( "--handles", default=DEFAULT_HANDLES_PATH, help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" ) sync_parser.add_argument( "--block-size", type=int, default=DEFAULT_BLOCK_SIZE, help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})" ) sync_parser.add_argument( "-v", "--verbose", action="store_true", help="Verbose output" ) # diagnose command diag_parser = subparsers.add_parser("diagnose", help="Check name mappings") diag_parser.add_argument( "--model-dir", required=True, help="Directory containing safetensors files" ) diag_parser.add_argument( "--handles", default=DEFAULT_HANDLES_PATH, help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" ) args = parser.parse_args() if args.command is None: parser.print_help() sys.exit(1) logging.basicConfig( level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO, format='%(message)s' ) try: if args.command == "sync": result = checkpoint_sync(args.model_dir, args.handles, args.block_size) print(json.dumps(result, indent=2)) elif args.command == "diagnose": diagnose(args.model_dir, args.handles) except FileNotFoundError as e: logger.error(str(e)) sys.exit(1) except Exception as e: logger.exception(f"Failed: {e}") sys.exit(1) if __name__ == "__main__": main()