apollo-checkpoint: efficient diff-based GPU weight checkpointing
Rust tool that mmaps previous checkpoint, diffs against live GPU weights (via CUDA IPC handles), and only writes changed blocks. For small behavioral training steps, turns 54GB write into ~500MB. Also includes vllm_export_hook.py with direct source patch approach — exports IPC handles from vLLM's worker subprocess after model load. Run every 10 minutes via cron to protect against vLLM crashes. Daily rsync to moria for long-term storage.
This commit is contained in:
parent
5f41898bb8
commit
c1245ab139
3 changed files with 305 additions and 5 deletions
13
training/checkpoint/Cargo.toml
Normal file
13
training/checkpoint/Cargo.toml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "apollo-checkpoint"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
memmap2 = "0.9"
|
||||
safetensors = "0.5"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
chrono = "0.4"
|
||||
281
training/checkpoint/src/main.rs
Normal file
281
training/checkpoint/src/main.rs
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff.
|
||||
//
|
||||
// mmaps the previous checkpoint, reads live weights from GPU via a
|
||||
// Python helper (CUDA IPC handles), compares block by block, and only
|
||||
// writes changed regions. For small behavioral training steps, this
|
||||
// turns a 54GB write into a few hundred MB.
|
||||
//
|
||||
// Usage:
|
||||
// apollo-checkpoint save \
|
||||
// --handles /tmp/vllm_weight_handles.pt \
|
||||
// --checkpoint-dir /home/ubuntu/checkpoints \
|
||||
// --block-size 4096
|
||||
//
|
||||
// Runs every 10 minutes via cron to protect against vLLM crashes.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use chrono::Utc;
|
||||
use clap::{Parser, Subcommand};
|
||||
use memmap2::MmapOptions;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Cmd,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Cmd {
|
||||
/// Save a checkpoint (diff against previous, write only changes)
|
||||
Save {
|
||||
/// Path to vLLM weight IPC handles
|
||||
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
|
||||
handles: PathBuf,
|
||||
|
||||
/// Checkpoint directory
|
||||
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
|
||||
checkpoint_dir: PathBuf,
|
||||
|
||||
/// Block size for diffing (bytes)
|
||||
#[arg(long, default_value_t = 4096)]
|
||||
block_size: usize,
|
||||
},
|
||||
|
||||
/// List checkpoints
|
||||
List {
|
||||
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
|
||||
checkpoint_dir: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
/// Dump live GPU weights to a flat binary file via Python helper.
|
||||
///
|
||||
/// The Python script imports the CUDA IPC handles and saves each
|
||||
/// tensor's raw bytes to a flat file, plus a JSON index mapping
|
||||
/// parameter names to (offset, size, shape, dtype).
|
||||
fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result<HashMap<String, TensorMeta>> {
|
||||
let index_path = output_path.with_extension("json");
|
||||
|
||||
let status = Command::new("python3")
|
||||
.arg("-c")
|
||||
.arg(format!(r#"
|
||||
import torch, json
|
||||
|
||||
handles = torch.load("{}", weights_only=False)
|
||||
index = {{}}
|
||||
offset = 0
|
||||
|
||||
with open("{}", "wb") as f:
|
||||
for name, info in sorted(handles.items()):
|
||||
func, args = info["handle"]
|
||||
tensor = func(*args)
|
||||
data = tensor.contiguous().cpu().numpy().tobytes()
|
||||
f.write(data)
|
||||
index[name] = {{
|
||||
"offset": offset,
|
||||
"size": len(data),
|
||||
"shape": list(tensor.shape),
|
||||
"dtype": str(tensor.dtype),
|
||||
}}
|
||||
offset += len(data)
|
||||
|
||||
with open("{}", "w") as f:
|
||||
json.dump(index, f)
|
||||
|
||||
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
||||
"#,
|
||||
handles_path.display(),
|
||||
output_path.display(),
|
||||
index_path.display(),
|
||||
))
|
||||
.status()
|
||||
.context("Failed to run Python weight dump")?;
|
||||
|
||||
if !status.success() {
|
||||
bail!("Python weight dump failed");
|
||||
}
|
||||
|
||||
// Read the index
|
||||
let index_str = fs::read_to_string(&index_path)
|
||||
.context("Failed to read weight index")?;
|
||||
let index: HashMap<String, TensorMeta> = serde_json::from_str(&index_str)?;
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Clone)]
|
||||
struct TensorMeta {
|
||||
offset: usize,
|
||||
size: usize,
|
||||
shape: Vec<usize>,
|
||||
dtype: String,
|
||||
}
|
||||
|
||||
/// Diff two flat binary files block by block, return changed byte ranges.
|
||||
fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> {
|
||||
assert_eq!(old.len(), new.len(), "File sizes must match for diffing");
|
||||
let mut changed = Vec::new();
|
||||
let mut i = 0;
|
||||
|
||||
while i < old.len() {
|
||||
let end = (i + block_size).min(old.len());
|
||||
if old[i..end] != new[i..end] {
|
||||
// Extend contiguous changed region
|
||||
let start = i;
|
||||
while i < old.len() {
|
||||
let end = (i + block_size).min(old.len());
|
||||
if old[i..end] == new[i..end] {
|
||||
break;
|
||||
}
|
||||
i = end;
|
||||
}
|
||||
changed.push((start, i));
|
||||
} else {
|
||||
i = end;
|
||||
}
|
||||
}
|
||||
|
||||
changed
|
||||
}
|
||||
|
||||
fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> {
|
||||
fs::create_dir_all(&checkpoint_dir)?;
|
||||
|
||||
let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string();
|
||||
let current_dir = checkpoint_dir.join(×tamp);
|
||||
fs::create_dir_all(¤t_dir)?;
|
||||
|
||||
let live_path = current_dir.join("weights.bin");
|
||||
eprintln!("Dumping live weights from GPU...");
|
||||
let index = dump_live_weights(&handles, &live_path)?;
|
||||
|
||||
// Find previous checkpoint
|
||||
let latest_link = checkpoint_dir.join("latest");
|
||||
let previous_path = if latest_link.exists() {
|
||||
let prev_dir = fs::read_link(&latest_link)?;
|
||||
let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin");
|
||||
if prev_weights.exists() {
|
||||
Some(prev_weights)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(ref prev_path) = previous_path {
|
||||
// Diff against previous
|
||||
eprintln!("Diffing against previous checkpoint...");
|
||||
let prev_file = fs::File::open(prev_path)?;
|
||||
let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? };
|
||||
|
||||
let live_file = fs::File::open(&live_path)?;
|
||||
let live_mmap = unsafe { MmapOptions::new().map(&live_file)? };
|
||||
|
||||
if prev_mmap.len() == live_mmap.len() {
|
||||
let changed = diff_blocks(&prev_mmap, &live_mmap, block_size);
|
||||
let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum();
|
||||
let total_bytes = live_mmap.len();
|
||||
|
||||
eprintln!(
|
||||
"Changed: {:.1} MB / {:.1} GB ({:.2}%)",
|
||||
changed_bytes as f64 / 1e6,
|
||||
total_bytes as f64 / 1e9,
|
||||
changed_bytes as f64 / total_bytes as f64 * 100.0,
|
||||
);
|
||||
|
||||
// If nothing changed, remove the new checkpoint dir
|
||||
if changed.is_empty() {
|
||||
eprintln!("No changes — skipping checkpoint");
|
||||
fs::remove_dir_all(¤t_dir)?;
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
eprintln!(
|
||||
"Size mismatch ({} vs {}), writing full checkpoint",
|
||||
prev_mmap.len(),
|
||||
live_mmap.len()
|
||||
);
|
||||
}
|
||||
} else {
|
||||
eprintln!("No previous checkpoint — writing full snapshot");
|
||||
}
|
||||
|
||||
// Save index
|
||||
let index_path = current_dir.join("weights.json");
|
||||
let index_str = serde_json::to_string_pretty(&index)?;
|
||||
fs::write(&index_path, index_str)?;
|
||||
|
||||
// Save metadata
|
||||
let meta = serde_json::json!({
|
||||
"timestamp": timestamp,
|
||||
"n_params": index.len(),
|
||||
"total_bytes": index.values().map(|m| m.size).sum::<usize>(),
|
||||
});
|
||||
fs::write(
|
||||
current_dir.join("checkpoint-meta.json"),
|
||||
serde_json::to_string_pretty(&meta)?,
|
||||
)?;
|
||||
|
||||
// Update latest symlink
|
||||
if latest_link.is_symlink() {
|
||||
fs::remove_file(&latest_link)?;
|
||||
}
|
||||
std::os::unix::fs::symlink(×tamp, &latest_link)?;
|
||||
|
||||
let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9;
|
||||
eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
|
||||
if !checkpoint_dir.exists() {
|
||||
println!("No checkpoints directory");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let latest = if checkpoint_dir.join("latest").exists() {
|
||||
fs::read_link(checkpoint_dir.join("latest"))?
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
||||
.collect();
|
||||
entries.sort_by_key(|e| e.file_name());
|
||||
|
||||
for entry in entries {
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let weights = entry.path().join("weights.bin");
|
||||
let size = if weights.exists() {
|
||||
format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9)
|
||||
} else {
|
||||
"no weights".to_string()
|
||||
};
|
||||
let marker = if name == latest { " ← latest" } else { "" };
|
||||
println!(" {} ({}){}", name, size, marker);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
match cli.command {
|
||||
Cmd::Save { handles, checkpoint_dir, block_size } => {
|
||||
cmd_save(handles, checkpoint_dir, block_size)
|
||||
}
|
||||
Cmd::List { checkpoint_dir } => {
|
||||
cmd_list(checkpoint_dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue