// apollo-checkpoint — Sync live GPU weights back to model files on disk. // // mmaps the model's safetensors files, reads live weights from GPU via // Python helper (CUDA IPC handles), compares block by block, and memcpys // only changed regions back into the mmap. For small behavioral training // steps, this turns a 54GB write into a few hundred MB. // // The model files on disk are the checkpoint. No separate checkpoint // directory — just keep the model up to date. // // Usage: // apollo-checkpoint sync \ // --handles /tmp/vllm_weight_handles.pt \ // --model-dir /path/to/Qwen3.5-27B // // Runs every 10 minutes via cron. Daily rsync to moria. use anyhow::{Context, Result, bail}; use clap::{Parser, Subcommand}; use memmap2::MmapMut; use std::collections::HashMap; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; #[derive(Parser)] #[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")] struct Cli { #[command(subcommand)] command: Cmd, } #[derive(Subcommand)] enum Cmd { /// Sync live GPU weights back to model safetensors files Sync { /// Path to vLLM weight IPC handles #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] handles: PathBuf, /// Model directory containing safetensors files #[arg(long)] model_dir: PathBuf, /// Block size for diffing (bytes) #[arg(long, default_value_t = 4096)] block_size: usize, }, } /// Dump live GPU weights to a flat binary file, ordered by safetensors /// file and offset to match the on-disk layout. /// /// Returns a map of (safetensors filename, tensor name) → raw bytes. fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result>> { let dump_path = output_dir.join(".live_dump.bin"); let index_path = output_dir.join(".live_dump.json"); let status = Command::new("python3") .arg("-c") .arg(format!(r#" import torch, json handles = torch.load("{handles}", weights_only=False) index = {{}} offset = 0 with open("{dump}", "wb") as f: for name in sorted(handles.keys()): info = handles[name] func, args = info["handle"] tensor = func(*args) data = tensor.contiguous().cpu().numpy().tobytes() f.write(data) index[name] = {{"offset": offset, "size": len(data)}} offset += len(data) with open("{index}", "w") as f: json.dump(index, f) print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") "#, handles = handles_path.display(), dump = dump_path.display(), index = index_path.display(), )) .status() .context("Failed to run Python weight dump")?; if !status.success() { bail!("Python weight dump failed"); } let index_str = fs::read_to_string(&index_path)?; let index: HashMap = serde_json::from_str(&index_str)?; let dump_data = fs::read(&dump_path)?; let mut result = HashMap::new(); for (name, entry) in &index { result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec()); } // Clean up temp files let _ = fs::remove_file(&dump_path); let _ = fs::remove_file(&index_path); Ok(result) } #[derive(serde::Deserialize)] struct DumpEntry { offset: usize, size: usize, } /// Read the safetensors index to map parameter names to files. fn read_safetensors_index(model_dir: &Path) -> Result> { let index_path = model_dir.join("model.safetensors.index.json"); if !index_path.exists() { // Single file model return Ok(HashMap::new()); } let index_str = fs::read_to_string(&index_path)?; let index: serde_json::Value = serde_json::from_str(&index_str)?; let weight_map = index["weight_map"] .as_object() .context("No weight_map in index")?; let mut result = HashMap::new(); for (name, file) in weight_map { result.insert(name.clone(), file.as_str().unwrap().to_string()); } Ok(result) } /// Sync changed blocks from live weights into a mmap'd safetensors file. /// Returns (total_bytes_compared, bytes_changed). fn sync_tensors_to_file( file_path: &Path, tensors: &[(String, Vec)], block_size: usize, ) -> Result<(usize, usize)> { use safetensors::SafeTensors; let file = fs::OpenOptions::new() .read(true) .write(true) .open(file_path) .with_context(|| format!("Failed to open {}", file_path.display()))?; let mut mmap = unsafe { MmapMut::map_mut(&file)? }; // Parse safetensors header to find tensor offsets let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize; let header_json: serde_json::Value = serde_json::from_slice(&mmap[8..8 + header_size])?; let data_start = 8 + header_size; let mut total_compared = 0usize; let mut total_changed = 0usize; for (name, live_data) in tensors { let meta = match header_json.get(name) { Some(m) => m, None => { eprintln!(" Warning: {} not found in {}", name, file_path.display()); continue; } }; let offsets = meta["data_offsets"].as_array().unwrap(); let start = data_start + offsets[0].as_u64().unwrap() as usize; let end = data_start + offsets[1].as_u64().unwrap() as usize; let disk_data = &mmap[start..end]; if disk_data.len() != live_data.len() { eprintln!(" Warning: size mismatch for {}: disk={} live={}", name, disk_data.len(), live_data.len()); continue; } // Diff block by block, memcpy only changed blocks let mut offset = 0; while offset < disk_data.len() { let block_end = (offset + block_size).min(disk_data.len()); total_compared += block_end - offset; if disk_data[offset..block_end] != live_data[offset..block_end] { mmap[start + offset..start + block_end] .copy_from_slice(&live_data[offset..block_end]); total_changed += block_end - offset; } offset = block_end; } } mmap.flush()?; Ok((total_compared, total_changed)) } fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> { if !handles.exists() { bail!("Weight handles not found: {}. Is vLLM running with the export hook?", handles.display()); } eprintln!("Dumping live weights from GPU..."); let live_weights = dump_live_weights(&handles, &model_dir)?; eprintln!(" {} tensors dumped", live_weights.len()); // Map parameter names to safetensors files let weight_map = read_safetensors_index(&model_dir)?; // Group tensors by safetensors file let mut by_file: HashMap)>> = HashMap::new(); for (name, data) in live_weights { let file = weight_map .get(&name) .cloned() .unwrap_or_else(|| "model.safetensors".to_string()); by_file.entry(file).or_default().push((name, data)); } let mut total_compared = 0usize; let mut total_changed = 0usize; for (filename, tensors) in &by_file { let file_path = model_dir.join(filename); if !file_path.exists() { eprintln!(" Warning: {} not found, skipping", filename); continue; } let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?; total_compared += compared; total_changed += changed; if changed > 0 { eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6); } } if total_changed == 0 { eprintln!("No changes — model files are up to date"); } else { eprintln!( "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)", total_changed as f64 / 1e6, total_compared as f64 / 1e9, total_changed as f64 / total_compared as f64 * 100.0, ); } Ok(()) } fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { Cmd::Sync { handles, model_dir, block_size } => { cmd_sync(handles, model_dir, block_size) } } }