apollo-checkpoint: efficient diff-based GPU weight checkpointing

Rust tool that mmaps previous checkpoint, diffs against live GPU weights
(via CUDA IPC handles), and only writes changed blocks. For small
behavioral training steps, turns 54GB write into ~500MB.

Also includes vllm_export_hook.py with direct source patch approach —
exports IPC handles from vLLM's worker subprocess after model load.

Run every 10 minutes via cron to protect against vLLM crashes.
Daily rsync to moria for long-term storage.
This commit is contained in:
ProofOfConcept 2026-03-30 22:53:17 -04:00
parent 5f41898bb8
commit c1245ab139
3 changed files with 305 additions and 5 deletions

View file

@ -0,0 +1,13 @@
[package]
name = "apollo-checkpoint"
version = "0.1.0"
edition = "2024"
[dependencies]
memmap2 = "0.9"
safetensors = "0.5"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
anyhow = "1"
clap = { version = "4", features = ["derive"] }
chrono = "0.4"

View file

@ -0,0 +1,281 @@
// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff.
//
// mmaps the previous checkpoint, reads live weights from GPU via a
// Python helper (CUDA IPC handles), compares block by block, and only
// writes changed regions. For small behavioral training steps, this
// turns a 54GB write into a few hundred MB.
//
// Usage:
// apollo-checkpoint save \
// --handles /tmp/vllm_weight_handles.pt \
// --checkpoint-dir /home/ubuntu/checkpoints \
// --block-size 4096
//
// Runs every 10 minutes via cron to protect against vLLM crashes.
use anyhow::{Context, Result, bail};
use chrono::Utc;
use clap::{Parser, Subcommand};
use memmap2::MmapOptions;
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::Command;
#[derive(Parser)]
#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")]
struct Cli {
#[command(subcommand)]
command: Cmd,
}
#[derive(Subcommand)]
enum Cmd {
/// Save a checkpoint (diff against previous, write only changes)
Save {
/// Path to vLLM weight IPC handles
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
handles: PathBuf,
/// Checkpoint directory
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
checkpoint_dir: PathBuf,
/// Block size for diffing (bytes)
#[arg(long, default_value_t = 4096)]
block_size: usize,
},
/// List checkpoints
List {
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
checkpoint_dir: PathBuf,
},
}
/// Dump live GPU weights to a flat binary file via Python helper.
///
/// The Python script imports the CUDA IPC handles and saves each
/// tensor's raw bytes to a flat file, plus a JSON index mapping
/// parameter names to (offset, size, shape, dtype).
fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result<HashMap<String, TensorMeta>> {
let index_path = output_path.with_extension("json");
let status = Command::new("python3")
.arg("-c")
.arg(format!(r#"
import torch, json
handles = torch.load("{}", weights_only=False)
index = {{}}
offset = 0
with open("{}", "wb") as f:
for name, info in sorted(handles.items()):
func, args = info["handle"]
tensor = func(*args)
data = tensor.contiguous().cpu().numpy().tobytes()
f.write(data)
index[name] = {{
"offset": offset,
"size": len(data),
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
}}
offset += len(data)
with open("{}", "w") as f:
json.dump(index, f)
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
"#,
handles_path.display(),
output_path.display(),
index_path.display(),
))
.status()
.context("Failed to run Python weight dump")?;
if !status.success() {
bail!("Python weight dump failed");
}
// Read the index
let index_str = fs::read_to_string(&index_path)
.context("Failed to read weight index")?;
let index: HashMap<String, TensorMeta> = serde_json::from_str(&index_str)?;
Ok(index)
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
struct TensorMeta {
offset: usize,
size: usize,
shape: Vec<usize>,
dtype: String,
}
/// Diff two flat binary files block by block, return changed byte ranges.
fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> {
assert_eq!(old.len(), new.len(), "File sizes must match for diffing");
let mut changed = Vec::new();
let mut i = 0;
while i < old.len() {
let end = (i + block_size).min(old.len());
if old[i..end] != new[i..end] {
// Extend contiguous changed region
let start = i;
while i < old.len() {
let end = (i + block_size).min(old.len());
if old[i..end] == new[i..end] {
break;
}
i = end;
}
changed.push((start, i));
} else {
i = end;
}
}
changed
}
fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> {
fs::create_dir_all(&checkpoint_dir)?;
let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string();
let current_dir = checkpoint_dir.join(&timestamp);
fs::create_dir_all(&current_dir)?;
let live_path = current_dir.join("weights.bin");
eprintln!("Dumping live weights from GPU...");
let index = dump_live_weights(&handles, &live_path)?;
// Find previous checkpoint
let latest_link = checkpoint_dir.join("latest");
let previous_path = if latest_link.exists() {
let prev_dir = fs::read_link(&latest_link)?;
let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin");
if prev_weights.exists() {
Some(prev_weights)
} else {
None
}
} else {
None
};
if let Some(ref prev_path) = previous_path {
// Diff against previous
eprintln!("Diffing against previous checkpoint...");
let prev_file = fs::File::open(prev_path)?;
let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? };
let live_file = fs::File::open(&live_path)?;
let live_mmap = unsafe { MmapOptions::new().map(&live_file)? };
if prev_mmap.len() == live_mmap.len() {
let changed = diff_blocks(&prev_mmap, &live_mmap, block_size);
let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum();
let total_bytes = live_mmap.len();
eprintln!(
"Changed: {:.1} MB / {:.1} GB ({:.2}%)",
changed_bytes as f64 / 1e6,
total_bytes as f64 / 1e9,
changed_bytes as f64 / total_bytes as f64 * 100.0,
);
// If nothing changed, remove the new checkpoint dir
if changed.is_empty() {
eprintln!("No changes — skipping checkpoint");
fs::remove_dir_all(&current_dir)?;
return Ok(());
}
} else {
eprintln!(
"Size mismatch ({} vs {}), writing full checkpoint",
prev_mmap.len(),
live_mmap.len()
);
}
} else {
eprintln!("No previous checkpoint — writing full snapshot");
}
// Save index
let index_path = current_dir.join("weights.json");
let index_str = serde_json::to_string_pretty(&index)?;
fs::write(&index_path, index_str)?;
// Save metadata
let meta = serde_json::json!({
"timestamp": timestamp,
"n_params": index.len(),
"total_bytes": index.values().map(|m| m.size).sum::<usize>(),
});
fs::write(
current_dir.join("checkpoint-meta.json"),
serde_json::to_string_pretty(&meta)?,
)?;
// Update latest symlink
if latest_link.is_symlink() {
fs::remove_file(&latest_link)?;
}
std::os::unix::fs::symlink(&timestamp, &latest_link)?;
let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9;
eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb);
Ok(())
}
fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
if !checkpoint_dir.exists() {
println!("No checkpoints directory");
return Ok(());
}
let latest = if checkpoint_dir.join("latest").exists() {
fs::read_link(checkpoint_dir.join("latest"))?
.to_string_lossy()
.to_string()
} else {
String::new()
};
let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)?
.filter_map(|e| e.ok())
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
.collect();
entries.sort_by_key(|e| e.file_name());
for entry in entries {
let name = entry.file_name().to_string_lossy().to_string();
let weights = entry.path().join("weights.bin");
let size = if weights.exists() {
format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9)
} else {
"no weights".to_string()
};
let marker = if name == latest { " ← latest" } else { "" };
println!(" {} ({}){}", name, size, marker);
}
Ok(())
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Cmd::Save { handles, checkpoint_dir, block_size } => {
cmd_save(handles, checkpoint_dir, block_size)
}
Cmd::List { checkpoint_dir } => {
cmd_list(checkpoint_dir)
}
}
}

View file

@ -49,20 +49,26 @@ def export_model_weights(model):
def _patch_model_runner(): def _patch_model_runner():
"""Patch gpu_model_runner to export handles after load_model.""" """Patch gpu_worker to export handles after model loading.
from vllm.v1.worker import gpu_model_runner
original_load = gpu_model_runner.GPUModelRunner.load_model vLLM loads the model in a subprocess (EngineCore_DP0), so we
can't patch from the parent. Instead, patch the worker's
init_device or load_model at the module level the subprocess
imports the same modules.
"""
from vllm.v1.worker import gpu_worker
original_load = gpu_worker.Worker.load_model
def patched_load(self, *args, **kwargs): def patched_load(self, *args, **kwargs):
result = original_load(self, *args, **kwargs) result = original_load(self, *args, **kwargs)
try: try:
export_model_weights(self.model) export_model_weights(self.model_runner.model)
except Exception as e: except Exception as e:
print(f"[apollo] Failed to export weights: {e}") print(f"[apollo] Failed to export weights: {e}")
return result return result
gpu_model_runner.GPUModelRunner.load_model = patched_load gpu_worker.Worker.load_model = patched_load
print("[apollo] Weight export hook installed") print("[apollo] Weight export hook installed")