// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff. // // mmaps the previous checkpoint, reads live weights from GPU via a // Python helper (CUDA IPC handles), compares block by block, and only // writes changed regions. For small behavioral training steps, this // turns a 54GB write into a few hundred MB. // // Usage: // apollo-checkpoint save \ // --handles /tmp/vllm_weight_handles.pt \ // --checkpoint-dir /home/ubuntu/checkpoints \ // --block-size 4096 // // Runs every 10 minutes via cron to protect against vLLM crashes. use anyhow::{Context, Result, bail}; use chrono::Utc; use clap::{Parser, Subcommand}; use memmap2::MmapOptions; use std::collections::HashMap; use std::fs; use std::io::Write; use std::path::{Path, PathBuf}; use std::process::Command; #[derive(Parser)] #[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")] struct Cli { #[command(subcommand)] command: Cmd, } #[derive(Subcommand)] enum Cmd { /// Save a checkpoint (diff against previous, write only changes) Save { /// Path to vLLM weight IPC handles #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] handles: PathBuf, /// Checkpoint directory #[arg(long, default_value = "/home/ubuntu/checkpoints")] checkpoint_dir: PathBuf, /// Block size for diffing (bytes) #[arg(long, default_value_t = 4096)] block_size: usize, }, /// List checkpoints List { #[arg(long, default_value = "/home/ubuntu/checkpoints")] checkpoint_dir: PathBuf, }, } /// Dump live GPU weights to a flat binary file via Python helper. /// /// The Python script imports the CUDA IPC handles and saves each /// tensor's raw bytes to a flat file, plus a JSON index mapping /// parameter names to (offset, size, shape, dtype). fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result> { let index_path = output_path.with_extension("json"); let status = Command::new("python3") .arg("-c") .arg(format!(r#" import torch, json handles = torch.load("{}", weights_only=False) index = {{}} offset = 0 with open("{}", "wb") as f: for name, info in sorted(handles.items()): func, args = info["handle"] tensor = func(*args) data = tensor.contiguous().cpu().numpy().tobytes() f.write(data) index[name] = {{ "offset": offset, "size": len(data), "shape": list(tensor.shape), "dtype": str(tensor.dtype), }} offset += len(data) with open("{}", "w") as f: json.dump(index, f) print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") "#, handles_path.display(), output_path.display(), index_path.display(), )) .status() .context("Failed to run Python weight dump")?; if !status.success() { bail!("Python weight dump failed"); } // Read the index let index_str = fs::read_to_string(&index_path) .context("Failed to read weight index")?; let index: HashMap = serde_json::from_str(&index_str)?; Ok(index) } #[derive(serde::Deserialize, serde::Serialize, Clone)] struct TensorMeta { offset: usize, size: usize, shape: Vec, dtype: String, } /// Diff two flat binary files block by block, return changed byte ranges. fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> { assert_eq!(old.len(), new.len(), "File sizes must match for diffing"); let mut changed = Vec::new(); let mut i = 0; while i < old.len() { let end = (i + block_size).min(old.len()); if old[i..end] != new[i..end] { // Extend contiguous changed region let start = i; while i < old.len() { let end = (i + block_size).min(old.len()); if old[i..end] == new[i..end] { break; } i = end; } changed.push((start, i)); } else { i = end; } } changed } fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> { fs::create_dir_all(&checkpoint_dir)?; let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string(); let current_dir = checkpoint_dir.join(×tamp); fs::create_dir_all(¤t_dir)?; let live_path = current_dir.join("weights.bin"); eprintln!("Dumping live weights from GPU..."); let index = dump_live_weights(&handles, &live_path)?; // Find previous checkpoint let latest_link = checkpoint_dir.join("latest"); let previous_path = if latest_link.exists() { let prev_dir = fs::read_link(&latest_link)?; let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin"); if prev_weights.exists() { Some(prev_weights) } else { None } } else { None }; if let Some(ref prev_path) = previous_path { // Diff against previous eprintln!("Diffing against previous checkpoint..."); let prev_file = fs::File::open(prev_path)?; let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? }; let live_file = fs::File::open(&live_path)?; let live_mmap = unsafe { MmapOptions::new().map(&live_file)? }; if prev_mmap.len() == live_mmap.len() { let changed = diff_blocks(&prev_mmap, &live_mmap, block_size); let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum(); let total_bytes = live_mmap.len(); eprintln!( "Changed: {:.1} MB / {:.1} GB ({:.2}%)", changed_bytes as f64 / 1e6, total_bytes as f64 / 1e9, changed_bytes as f64 / total_bytes as f64 * 100.0, ); // If nothing changed, remove the new checkpoint dir if changed.is_empty() { eprintln!("No changes — skipping checkpoint"); fs::remove_dir_all(¤t_dir)?; return Ok(()); } } else { eprintln!( "Size mismatch ({} vs {}), writing full checkpoint", prev_mmap.len(), live_mmap.len() ); } } else { eprintln!("No previous checkpoint — writing full snapshot"); } // Save index let index_path = current_dir.join("weights.json"); let index_str = serde_json::to_string_pretty(&index)?; fs::write(&index_path, index_str)?; // Save metadata let meta = serde_json::json!({ "timestamp": timestamp, "n_params": index.len(), "total_bytes": index.values().map(|m| m.size).sum::(), }); fs::write( current_dir.join("checkpoint-meta.json"), serde_json::to_string_pretty(&meta)?, )?; // Update latest symlink if latest_link.is_symlink() { fs::remove_file(&latest_link)?; } std::os::unix::fs::symlink(×tamp, &latest_link)?; let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9; eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb); Ok(()) } fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> { if !checkpoint_dir.exists() { println!("No checkpoints directory"); return Ok(()); } let latest = if checkpoint_dir.join("latest").exists() { fs::read_link(checkpoint_dir.join("latest"))? .to_string_lossy() .to_string() } else { String::new() }; let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)? .filter_map(|e| e.ok()) .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false)) .collect(); entries.sort_by_key(|e| e.file_name()); for entry in entries { let name = entry.file_name().to_string_lossy().to_string(); let weights = entry.path().join("weights.bin"); let size = if weights.exists() { format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9) } else { "no weights".to_string() }; let marker = if name == latest { " ← latest" } else { "" }; println!(" {} ({}){}", name, size, marker); } Ok(()) } fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { Cmd::Save { handles, checkpoint_dir, block_size } => { cmd_save(handles, checkpoint_dir, block_size) } Cmd::List { checkpoint_dir } => { cmd_list(checkpoint_dir) } } }