From d0883e101b3aff2dc40aa54ed56e8f63cd685c4f Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Mon, 30 Mar 2026 22:55:23 -0400 Subject: [PATCH] checkpoint: sync live weights back into model safetensors in-place MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mmap each safetensors file, diff block-by-block against live GPU weights, memcpy only changed blocks. No separate checkpoint files — the model directory IS the checkpoint. Every 10 min via cron. --- training/checkpoint/src/main.rs | 386 +++++++++++++++----------------- 1 file changed, 185 insertions(+), 201 deletions(-) diff --git a/training/checkpoint/src/main.rs b/training/checkpoint/src/main.rs index 5829021..1ebd0df 100644 --- a/training/checkpoint/src/main.rs +++ b/training/checkpoint/src/main.rs @@ -1,30 +1,30 @@ -// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff. +// apollo-checkpoint — Sync live GPU weights back to model files on disk. // -// mmaps the previous checkpoint, reads live weights from GPU via a -// Python helper (CUDA IPC handles), compares block by block, and only -// writes changed regions. For small behavioral training steps, this -// turns a 54GB write into a few hundred MB. +// mmaps the model's safetensors files, reads live weights from GPU via +// Python helper (CUDA IPC handles), compares block by block, and memcpys +// only changed regions back into the mmap. For small behavioral training +// steps, this turns a 54GB write into a few hundred MB. +// +// The model files on disk are the checkpoint. No separate checkpoint +// directory — just keep the model up to date. // // Usage: -// apollo-checkpoint save \ +// apollo-checkpoint sync \ // --handles /tmp/vllm_weight_handles.pt \ -// --checkpoint-dir /home/ubuntu/checkpoints \ -// --block-size 4096 +// --model-dir /path/to/Qwen3.5-27B // -// Runs every 10 minutes via cron to protect against vLLM crashes. +// Runs every 10 minutes via cron. Daily rsync to moria. use anyhow::{Context, Result, bail}; -use chrono::Utc; use clap::{Parser, Subcommand}; -use memmap2::MmapOptions; +use memmap2::MmapMut; use std::collections::HashMap; use std::fs; -use std::io::Write; use std::path::{Path, PathBuf}; use std::process::Command; #[derive(Parser)] -#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")] +#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")] struct Cli { #[command(subcommand)] command: Cmd, @@ -32,67 +32,57 @@ struct Cli { #[derive(Subcommand)] enum Cmd { - /// Save a checkpoint (diff against previous, write only changes) - Save { + /// Sync live GPU weights back to model safetensors files + Sync { /// Path to vLLM weight IPC handles #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] handles: PathBuf, - /// Checkpoint directory - #[arg(long, default_value = "/home/ubuntu/checkpoints")] - checkpoint_dir: PathBuf, + /// Model directory containing safetensors files + #[arg(long)] + model_dir: PathBuf, /// Block size for diffing (bytes) #[arg(long, default_value_t = 4096)] block_size: usize, }, - - /// List checkpoints - List { - #[arg(long, default_value = "/home/ubuntu/checkpoints")] - checkpoint_dir: PathBuf, - }, } -/// Dump live GPU weights to a flat binary file via Python helper. +/// Dump live GPU weights to a flat binary file, ordered by safetensors +/// file and offset to match the on-disk layout. /// -/// The Python script imports the CUDA IPC handles and saves each -/// tensor's raw bytes to a flat file, plus a JSON index mapping -/// parameter names to (offset, size, shape, dtype). -fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result> { - let index_path = output_path.with_extension("json"); +/// Returns a map of (safetensors filename, tensor name) → raw bytes. +fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result>> { + let dump_path = output_dir.join(".live_dump.bin"); + let index_path = output_dir.join(".live_dump.json"); let status = Command::new("python3") .arg("-c") .arg(format!(r#" import torch, json -handles = torch.load("{}", weights_only=False) +handles = torch.load("{handles}", weights_only=False) index = {{}} offset = 0 -with open("{}", "wb") as f: - for name, info in sorted(handles.items()): +with open("{dump}", "wb") as f: + for name in sorted(handles.keys()): + info = handles[name] func, args = info["handle"] tensor = func(*args) data = tensor.contiguous().cpu().numpy().tobytes() f.write(data) - index[name] = {{ - "offset": offset, - "size": len(data), - "shape": list(tensor.shape), - "dtype": str(tensor.dtype), - }} + index[name] = {{"offset": offset, "size": len(data)}} offset += len(data) -with open("{}", "w") as f: +with open("{index}", "w") as f: json.dump(index, f) print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") "#, - handles_path.display(), - output_path.display(), - index_path.display(), + handles = handles_path.display(), + dump = dump_path.display(), + index = index_path.display(), )) .status() .context("Failed to run Python weight dump")?; @@ -101,168 +91,165 @@ print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") bail!("Python weight dump failed"); } - // Read the index - let index_str = fs::read_to_string(&index_path) - .context("Failed to read weight index")?; - let index: HashMap = serde_json::from_str(&index_str)?; - Ok(index) + let index_str = fs::read_to_string(&index_path)?; + let index: HashMap = serde_json::from_str(&index_str)?; + let dump_data = fs::read(&dump_path)?; + + let mut result = HashMap::new(); + for (name, entry) in &index { + result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec()); + } + + // Clean up temp files + let _ = fs::remove_file(&dump_path); + let _ = fs::remove_file(&index_path); + + Ok(result) } -#[derive(serde::Deserialize, serde::Serialize, Clone)] -struct TensorMeta { +#[derive(serde::Deserialize)] +struct DumpEntry { offset: usize, size: usize, - shape: Vec, - dtype: String, } -/// Diff two flat binary files block by block, return changed byte ranges. -fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> { - assert_eq!(old.len(), new.len(), "File sizes must match for diffing"); - let mut changed = Vec::new(); - let mut i = 0; +/// Read the safetensors index to map parameter names to files. +fn read_safetensors_index(model_dir: &Path) -> Result> { + let index_path = model_dir.join("model.safetensors.index.json"); + if !index_path.exists() { + // Single file model + return Ok(HashMap::new()); + } - while i < old.len() { - let end = (i + block_size).min(old.len()); - if old[i..end] != new[i..end] { - // Extend contiguous changed region - let start = i; - while i < old.len() { - let end = (i + block_size).min(old.len()); - if old[i..end] == new[i..end] { - break; - } - i = end; + let index_str = fs::read_to_string(&index_path)?; + let index: serde_json::Value = serde_json::from_str(&index_str)?; + let weight_map = index["weight_map"] + .as_object() + .context("No weight_map in index")?; + + let mut result = HashMap::new(); + for (name, file) in weight_map { + result.insert(name.clone(), file.as_str().unwrap().to_string()); + } + Ok(result) +} + +/// Sync changed blocks from live weights into a mmap'd safetensors file. +/// Returns (total_bytes_compared, bytes_changed). +fn sync_tensors_to_file( + file_path: &Path, + tensors: &[(String, Vec)], + block_size: usize, +) -> Result<(usize, usize)> { + use safetensors::SafeTensors; + + let file = fs::OpenOptions::new() + .read(true) + .write(true) + .open(file_path) + .with_context(|| format!("Failed to open {}", file_path.display()))?; + + let mut mmap = unsafe { MmapMut::map_mut(&file)? }; + + // Parse safetensors header to find tensor offsets + let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize; + let header_json: serde_json::Value = + serde_json::from_slice(&mmap[8..8 + header_size])?; + let data_start = 8 + header_size; + + let mut total_compared = 0usize; + let mut total_changed = 0usize; + + for (name, live_data) in tensors { + let meta = match header_json.get(name) { + Some(m) => m, + None => { + eprintln!(" Warning: {} not found in {}", name, file_path.display()); + continue; } - changed.push((start, i)); - } else { - i = end; - } - } - - changed -} - -fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> { - fs::create_dir_all(&checkpoint_dir)?; - - let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string(); - let current_dir = checkpoint_dir.join(×tamp); - fs::create_dir_all(¤t_dir)?; - - let live_path = current_dir.join("weights.bin"); - eprintln!("Dumping live weights from GPU..."); - let index = dump_live_weights(&handles, &live_path)?; - - // Find previous checkpoint - let latest_link = checkpoint_dir.join("latest"); - let previous_path = if latest_link.exists() { - let prev_dir = fs::read_link(&latest_link)?; - let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin"); - if prev_weights.exists() { - Some(prev_weights) - } else { - None - } - } else { - None - }; - - if let Some(ref prev_path) = previous_path { - // Diff against previous - eprintln!("Diffing against previous checkpoint..."); - let prev_file = fs::File::open(prev_path)?; - let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? }; - - let live_file = fs::File::open(&live_path)?; - let live_mmap = unsafe { MmapOptions::new().map(&live_file)? }; - - if prev_mmap.len() == live_mmap.len() { - let changed = diff_blocks(&prev_mmap, &live_mmap, block_size); - let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum(); - let total_bytes = live_mmap.len(); - - eprintln!( - "Changed: {:.1} MB / {:.1} GB ({:.2}%)", - changed_bytes as f64 / 1e6, - total_bytes as f64 / 1e9, - changed_bytes as f64 / total_bytes as f64 * 100.0, - ); - - // If nothing changed, remove the new checkpoint dir - if changed.is_empty() { - eprintln!("No changes — skipping checkpoint"); - fs::remove_dir_all(¤t_dir)?; - return Ok(()); - } - } else { - eprintln!( - "Size mismatch ({} vs {}), writing full checkpoint", - prev_mmap.len(), - live_mmap.len() - ); - } - } else { - eprintln!("No previous checkpoint — writing full snapshot"); - } - - // Save index - let index_path = current_dir.join("weights.json"); - let index_str = serde_json::to_string_pretty(&index)?; - fs::write(&index_path, index_str)?; - - // Save metadata - let meta = serde_json::json!({ - "timestamp": timestamp, - "n_params": index.len(), - "total_bytes": index.values().map(|m| m.size).sum::(), - }); - fs::write( - current_dir.join("checkpoint-meta.json"), - serde_json::to_string_pretty(&meta)?, - )?; - - // Update latest symlink - if latest_link.is_symlink() { - fs::remove_file(&latest_link)?; - } - std::os::unix::fs::symlink(×tamp, &latest_link)?; - - let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9; - eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb); - Ok(()) -} - -fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> { - if !checkpoint_dir.exists() { - println!("No checkpoints directory"); - return Ok(()); - } - - let latest = if checkpoint_dir.join("latest").exists() { - fs::read_link(checkpoint_dir.join("latest"))? - .to_string_lossy() - .to_string() - } else { - String::new() - }; - - let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)? - .filter_map(|e| e.ok()) - .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false)) - .collect(); - entries.sort_by_key(|e| e.file_name()); - - for entry in entries { - let name = entry.file_name().to_string_lossy().to_string(); - let weights = entry.path().join("weights.bin"); - let size = if weights.exists() { - format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9) - } else { - "no weights".to_string() }; - let marker = if name == latest { " ← latest" } else { "" }; - println!(" {} ({}){}", name, size, marker); + + let offsets = meta["data_offsets"].as_array().unwrap(); + let start = data_start + offsets[0].as_u64().unwrap() as usize; + let end = data_start + offsets[1].as_u64().unwrap() as usize; + let disk_data = &mmap[start..end]; + + if disk_data.len() != live_data.len() { + eprintln!(" Warning: size mismatch for {}: disk={} live={}", + name, disk_data.len(), live_data.len()); + continue; + } + + // Diff block by block, memcpy only changed blocks + let mut offset = 0; + while offset < disk_data.len() { + let block_end = (offset + block_size).min(disk_data.len()); + total_compared += block_end - offset; + + if disk_data[offset..block_end] != live_data[offset..block_end] { + mmap[start + offset..start + block_end] + .copy_from_slice(&live_data[offset..block_end]); + total_changed += block_end - offset; + } + offset = block_end; + } + } + + mmap.flush()?; + Ok((total_compared, total_changed)) +} + +fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> { + if !handles.exists() { + bail!("Weight handles not found: {}. Is vLLM running with the export hook?", + handles.display()); + } + + eprintln!("Dumping live weights from GPU..."); + let live_weights = dump_live_weights(&handles, &model_dir)?; + eprintln!(" {} tensors dumped", live_weights.len()); + + // Map parameter names to safetensors files + let weight_map = read_safetensors_index(&model_dir)?; + + // Group tensors by safetensors file + let mut by_file: HashMap)>> = HashMap::new(); + for (name, data) in live_weights { + let file = weight_map + .get(&name) + .cloned() + .unwrap_or_else(|| "model.safetensors".to_string()); + by_file.entry(file).or_default().push((name, data)); + } + + let mut total_compared = 0usize; + let mut total_changed = 0usize; + + for (filename, tensors) in &by_file { + let file_path = model_dir.join(filename); + if !file_path.exists() { + eprintln!(" Warning: {} not found, skipping", filename); + continue; + } + + let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?; + total_compared += compared; + total_changed += changed; + + if changed > 0 { + eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6); + } + } + + if total_changed == 0 { + eprintln!("No changes — model files are up to date"); + } else { + eprintln!( + "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)", + total_changed as f64 / 1e6, + total_compared as f64 / 1e9, + total_changed as f64 / total_compared as f64 * 100.0, + ); } Ok(()) @@ -271,11 +258,8 @@ fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> { fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { - Cmd::Save { handles, checkpoint_dir, block_size } => { - cmd_save(handles, checkpoint_dir, block_size) - } - Cmd::List { checkpoint_dir } => { - cmd_list(checkpoint_dir) + Cmd::Sync { handles, model_dir, block_size } => { + cmd_sync(handles, model_dir, block_size) } } }