From c1245ab1399aeead5e379afab12d90f6614c426e Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Mon, 30 Mar 2026 22:53:17 -0400 Subject: [PATCH] apollo-checkpoint: efficient diff-based GPU weight checkpointing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rust tool that mmaps previous checkpoint, diffs against live GPU weights (via CUDA IPC handles), and only writes changed blocks. For small behavioral training steps, turns 54GB write into ~500MB. Also includes vllm_export_hook.py with direct source patch approach — exports IPC handles from vLLM's worker subprocess after model load. Run every 10 minutes via cron to protect against vLLM crashes. Daily rsync to moria for long-term storage. --- training/checkpoint/Cargo.toml | 13 ++ training/checkpoint/src/main.rs | 281 ++++++++++++++++++++++++++++++++ training/vllm_export_hook.py | 16 +- 3 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 training/checkpoint/Cargo.toml create mode 100644 training/checkpoint/src/main.rs diff --git a/training/checkpoint/Cargo.toml b/training/checkpoint/Cargo.toml new file mode 100644 index 0000000..c28aa9f --- /dev/null +++ b/training/checkpoint/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "apollo-checkpoint" +version = "0.1.0" +edition = "2024" + +[dependencies] +memmap2 = "0.9" +safetensors = "0.5" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +anyhow = "1" +clap = { version = "4", features = ["derive"] } +chrono = "0.4" diff --git a/training/checkpoint/src/main.rs b/training/checkpoint/src/main.rs new file mode 100644 index 0000000..5829021 --- /dev/null +++ b/training/checkpoint/src/main.rs @@ -0,0 +1,281 @@ +// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff. +// +// mmaps the previous checkpoint, reads live weights from GPU via a +// Python helper (CUDA IPC handles), compares block by block, and only +// writes changed regions. 
For small behavioral training steps, this +// turns a 54GB write into a few hundred MB. +// +// Usage: +// apollo-checkpoint save \ +// --handles /tmp/vllm_weight_handles.pt \ +// --checkpoint-dir /home/ubuntu/checkpoints \ +// --block-size 4096 +// +// Runs every 10 minutes via cron to protect against vLLM crashes. + +use anyhow::{Context, Result, bail}; +use chrono::Utc; +use clap::{Parser, Subcommand}; +use memmap2::MmapOptions; +use std::collections::HashMap; +use std::fs; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::process::Command; + +#[derive(Parser)] +#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")] +struct Cli { + #[command(subcommand)] + command: Cmd, +} + +#[derive(Subcommand)] +enum Cmd { + /// Save a checkpoint (diff against previous, write only changes) + Save { + /// Path to vLLM weight IPC handles + #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] + handles: PathBuf, + + /// Checkpoint directory + #[arg(long, default_value = "/home/ubuntu/checkpoints")] + checkpoint_dir: PathBuf, + + /// Block size for diffing (bytes) + #[arg(long, default_value_t = 4096)] + block_size: usize, + }, + + /// List checkpoints + List { + #[arg(long, default_value = "/home/ubuntu/checkpoints")] + checkpoint_dir: PathBuf, + }, +} + +/// Dump live GPU weights to a flat binary file via Python helper. +/// +/// The Python script imports the CUDA IPC handles and saves each +/// tensor's raw bytes to a flat file, plus a JSON index mapping +/// parameter names to (offset, size, shape, dtype). 
+fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result<HashMap<String, TensorMeta>> {
+    let index_path = output_path.with_extension("json");
+
+    let status = Command::new("python3")
+        .arg("-c")
+        .arg(format!(r#"
+import torch, json
+
+handles = torch.load("{}", weights_only=False)
+index = {{}}
+offset = 0
+
+with open("{}", "wb") as f:
+    for name, info in sorted(handles.items()):
+        func, args = info["handle"]
+        tensor = func(*args)
+        data = tensor.contiguous().cpu().numpy().tobytes()
+        f.write(data)
+        index[name] = {{
+            "offset": offset,
+            "size": len(data),
+            "shape": list(tensor.shape),
+            "dtype": str(tensor.dtype),
+        }}
+        offset += len(data)
+
+with open("{}", "w") as f:
+    json.dump(index, f)
+
+print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
+"#,
+            handles_path.display(),
+            output_path.display(),
+            index_path.display(),
+        ))
+        .status()
+        .context("Failed to run Python weight dump")?;
+
+    if !status.success() {
+        bail!("Python weight dump failed");
+    }
+
+    // Read the index
+    let index_str = fs::read_to_string(&index_path)
+        .context("Failed to read weight index")?;
+    let index: HashMap<String, TensorMeta> = serde_json::from_str(&index_str)?;
+    Ok(index)
+}
+
+#[derive(serde::Deserialize, serde::Serialize, Clone)]
+struct TensorMeta {
+    offset: usize,
+    size: usize,
+    shape: Vec<usize>,
+    dtype: String,
+}
+
+/// Diff two flat binary files block by block, return changed byte ranges.
+fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> {
+    assert_eq!(old.len(), new.len(), "File sizes must match for diffing");
+    let mut changed = Vec::new();
+    let mut i = 0;
+
+    while i < old.len() {
+        let end = (i + block_size).min(old.len());
+        if old[i..end] != new[i..end] {
+            // Extend contiguous changed region
+            let start = i;
+            while i < old.len() {
+                let end = (i + block_size).min(old.len());
+                if old[i..end] == new[i..end] {
+                    break;
+                }
+                i = end;
+            }
+            changed.push((start, i));
+        } else {
+            i = end;
+        }
+    }
+
+    changed
+}
+
+fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> {
+    fs::create_dir_all(&checkpoint_dir)?;
+
+    let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string();
+    let current_dir = checkpoint_dir.join(&timestamp);
+    fs::create_dir_all(&current_dir)?;
+
+    let live_path = current_dir.join("weights.bin");
+    eprintln!("Dumping live weights from GPU...");
+    let index = dump_live_weights(&handles, &live_path)?;
+
+    // Find previous checkpoint
+    let latest_link = checkpoint_dir.join("latest");
+    let previous_path = if latest_link.exists() {
+        let prev_dir = fs::read_link(&latest_link)?;
+        let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin");
+        if prev_weights.exists() {
+            Some(prev_weights)
+        } else {
+            None
+        }
+    } else {
+        None
+    };
+
+    if let Some(ref prev_path) = previous_path {
+        // Diff against previous
+        eprintln!("Diffing against previous checkpoint...");
+        let prev_file = fs::File::open(prev_path)?;
+        let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? };
+
+        let live_file = fs::File::open(&live_path)?;
+        let live_mmap = unsafe { MmapOptions::new().map(&live_file)?
};
+
+        if prev_mmap.len() == live_mmap.len() {
+            let changed = diff_blocks(&prev_mmap, &live_mmap, block_size);
+            let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum();
+            let total_bytes = live_mmap.len();
+
+            eprintln!(
+                "Changed: {:.1} MB / {:.1} GB ({:.2}%)",
+                changed_bytes as f64 / 1e6,
+                total_bytes as f64 / 1e9,
+                changed_bytes as f64 / total_bytes as f64 * 100.0,
+            );
+
+            // If nothing changed, remove the new checkpoint dir
+            if changed.is_empty() {
+                eprintln!("No changes — skipping checkpoint");
+                fs::remove_dir_all(&current_dir)?;
+                return Ok(());
+            }
+        } else {
+            eprintln!(
+                "Size mismatch ({} vs {}), writing full checkpoint",
+                prev_mmap.len(),
+                live_mmap.len()
+            );
+        }
+    } else {
+        eprintln!("No previous checkpoint — writing full snapshot");
+    }
+
+    // Save index
+    let index_path = current_dir.join("weights.json");
+    let index_str = serde_json::to_string_pretty(&index)?;
+    fs::write(&index_path, index_str)?;
+
+    // Save metadata
+    let meta = serde_json::json!({
+        "timestamp": timestamp,
+        "n_params": index.len(),
+        "total_bytes": index.values().map(|m| m.size).sum::<usize>(),
+    });
+    fs::write(
+        current_dir.join("checkpoint-meta.json"),
+        serde_json::to_string_pretty(&meta)?,
+    )?;
+
+    // Update latest symlink
+    if latest_link.is_symlink() {
+        fs::remove_file(&latest_link)?;
+    }
+    std::os::unix::fs::symlink(&timestamp, &latest_link)?;
+
+    let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9;
+    eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb);
+    Ok(())
+}
+
+fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
+    if !checkpoint_dir.exists() {
+        println!("No checkpoints directory");
+        return Ok(());
+    }
+
+    let latest = if checkpoint_dir.join("latest").exists() {
+        fs::read_link(checkpoint_dir.join("latest"))?
+            .to_string_lossy()
+            .to_string()
+    } else {
+        String::new()
+    };
+
+    let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)?
+ .filter_map(|e| e.ok()) + .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false)) + .collect(); + entries.sort_by_key(|e| e.file_name()); + + for entry in entries { + let name = entry.file_name().to_string_lossy().to_string(); + let weights = entry.path().join("weights.bin"); + let size = if weights.exists() { + format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9) + } else { + "no weights".to_string() + }; + let marker = if name == latest { " ← latest" } else { "" }; + println!(" {} ({}){}", name, size, marker); + } + + Ok(()) +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + match cli.command { + Cmd::Save { handles, checkpoint_dir, block_size } => { + cmd_save(handles, checkpoint_dir, block_size) + } + Cmd::List { checkpoint_dir } => { + cmd_list(checkpoint_dir) + } + } +} diff --git a/training/vllm_export_hook.py b/training/vllm_export_hook.py index 8576faf..6a0bf1e 100644 --- a/training/vllm_export_hook.py +++ b/training/vllm_export_hook.py @@ -49,20 +49,26 @@ def export_model_weights(model): def _patch_model_runner(): - """Patch gpu_model_runner to export handles after load_model.""" - from vllm.v1.worker import gpu_model_runner + """Patch gpu_worker to export handles after model loading. - original_load = gpu_model_runner.GPUModelRunner.load_model + vLLM loads the model in a subprocess (EngineCore_DP0), so we + can't patch from the parent. Instead, patch the worker's + init_device or load_model at the module level — the subprocess + imports the same modules. 
+ """ + from vllm.v1.worker import gpu_worker + + original_load = gpu_worker.Worker.load_model def patched_load(self, *args, **kwargs): result = original_load(self, *args, **kwargs) try: - export_model_weights(self.model) + export_model_weights(self.model_runner.model) except Exception as e: print(f"[apollo] Failed to export weights: {e}") return result - gpu_model_runner.GPUModelRunner.load_model = patched_load + gpu_worker.Worker.load_model = patched_load print("[apollo] Weight export hook installed")