// consciousness/training/checkpoint/src/main.rs

// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff.
//
// mmaps the previous checkpoint, reads live weights from GPU via a
// Python helper (CUDA IPC handles), compares block by block, and only
// writes changed regions. For small behavioral training steps, this
// turns a 54GB write into a few hundred MB.
//
// Usage:
// apollo-checkpoint save \
// --handles /tmp/vllm_weight_handles.pt \
// --checkpoint-dir /home/ubuntu/checkpoints \
// --block-size 4096
//
// Runs every 10 minutes via cron to protect against vLLM crashes.
use anyhow::{Context, Result, bail};
use chrono::Utc;
use clap::{Parser, Subcommand};
use memmap2::MmapOptions;
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::Command;
/// Top-level command-line interface, parsed via clap's derive API.
#[derive(Parser)]
#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")]
struct Cli {
    /// Which subcommand to run (`save` or `list`).
    #[command(subcommand)]
    command: Cmd,
}
/// Subcommands supported by `apollo-checkpoint`.
#[derive(Subcommand)]
enum Cmd {
    /// Save a checkpoint (diff against previous, write only changes)
    Save {
        /// Path to vLLM weight IPC handles
        #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
        handles: PathBuf,
        /// Checkpoint directory
        #[arg(long, default_value = "/home/ubuntu/checkpoints")]
        checkpoint_dir: PathBuf,
        /// Block size for diffing (bytes)
        #[arg(long, default_value_t = 4096)]
        block_size: usize,
    },
    /// List checkpoints
    List {
        #[arg(long, default_value = "/home/ubuntu/checkpoints")]
        checkpoint_dir: PathBuf,
    },
}
/// Dump live GPU weights to a flat binary file via a Python helper.
///
/// The helper imports the CUDA IPC handles, writes each tensor's raw
/// bytes (sorted by parameter name, so the layout is deterministic) to
/// `output_path`, and writes a JSON index mapping parameter names to
/// (offset, size, shape, dtype) alongside it — at `output_path` with its
/// extension replaced by `.json`.
///
/// Returns the parsed index.
///
/// # Errors
/// Fails when the Python process cannot be spawned, exits non-zero, or
/// the index file cannot be read or parsed.
fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result<HashMap<String, TensorMeta>> {
    let index_path = output_path.with_extension("json");
    // Pass the three paths via argv instead of interpolating them into the
    // script text: a path containing a quote or backslash would otherwise
    // corrupt the generated Python source.
    let script = r#"
import sys, json
import torch

handles_path, output_path, index_path = sys.argv[1:4]
handles = torch.load(handles_path, weights_only=False)
index = {}
offset = 0
with open(output_path, "wb") as f:
    for name, info in sorted(handles.items()):
        func, args = info["handle"]
        tensor = func(*args)
        data = tensor.contiguous().cpu().numpy().tobytes()
        f.write(data)
        index[name] = {
            "offset": offset,
            "size": len(data),
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
        }
        offset += len(data)
with open(index_path, "w") as f:
    json.dump(index, f)
print(f"Dumped {len(index)} tensors, {offset / 1e9:.1f} GB")
"#;
    let status = Command::new("python3")
        .arg("-c")
        .arg(script)
        .arg(handles_path)
        .arg(output_path)
        .arg(&index_path)
        .status()
        .context("Failed to run Python weight dump")?;
    if !status.success() {
        bail!("Python weight dump failed");
    }
    // Read back the index the helper just wrote.
    let index_str = fs::read_to_string(&index_path)
        .context("Failed to read weight index")?;
    let index: HashMap<String, TensorMeta> =
        serde_json::from_str(&index_str).context("Failed to parse weight index")?;
    Ok(index)
}
/// Location and layout of one tensor inside the flat `weights.bin` file,
/// mirroring one entry of the JSON index written by the Python dump helper.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
struct TensorMeta {
    /// Byte offset of the tensor's raw data within the flat file.
    offset: usize,
    /// Size of the tensor's raw data in bytes.
    size: usize,
    /// Tensor shape as reported by PyTorch.
    shape: Vec<usize>,
    /// PyTorch dtype string, e.g. "torch.bfloat16".
    dtype: String,
}
/// Diff two equal-length flat binary files block by block, returning the
/// changed byte ranges as half-open `(start, end)` pairs. Adjacent changed
/// blocks are merged into a single range; range boundaries fall on block
/// boundaries (the final block may be shorter than `block_size`).
///
/// # Panics
/// Panics if `block_size` is zero (the original loop would spin forever)
/// or if the two slices differ in length.
fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> {
    assert!(block_size > 0, "block_size must be non-zero");
    assert_eq!(old.len(), new.len(), "File sizes must match for diffing");
    let mut changed: Vec<(usize, usize)> = Vec::new();
    let mut i = 0;
    while i < old.len() {
        let end = (i + block_size).min(old.len());
        if old[i..end] != new[i..end] {
            // Extend the previous range when it abuts this block; otherwise
            // start a new changed region. Each block is compared exactly once.
            match changed.last_mut() {
                Some(last) if last.1 == i => last.1 = end,
                _ => changed.push((i, end)),
            }
        }
        i = end;
    }
    changed
}
/// Save a checkpoint: dump live GPU weights into a fresh timestamped
/// directory, diff against the previous checkpoint (found via the `latest`
/// symlink), and update `latest` to point at the new directory.
///
/// NOTE(review): despite the file header's "write only changed regions"
/// goal, `dump_live_weights` always writes the full `weights.bin` for the
/// new checkpoint; the block diff below is used only for reporting and for
/// deleting the new directory when nothing changed at all.
fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> {
    fs::create_dir_all(&checkpoint_dir)?;
    // One directory per checkpoint, named by UTC timestamp (minute granularity;
    // two saves in the same minute would reuse the same directory).
    let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string();
    let current_dir = checkpoint_dir.join(&timestamp);
    fs::create_dir_all(&current_dir)?;
    let live_path = current_dir.join("weights.bin");
    eprintln!("Dumping live weights from GPU...");
    let index = dump_live_weights(&handles, &live_path)?;
    // Locate the previous checkpoint via the `latest` symlink, if any.
    let latest_link = checkpoint_dir.join("latest");
    let previous_path = if latest_link.exists() {
        // `read_link` yields the symlink target (the timestamp directory name).
        let prev_dir = fs::read_link(&latest_link)?;
        let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin");
        if prev_weights.exists() {
            Some(prev_weights)
        } else {
            None
        }
    } else {
        None
    };
    if let Some(ref prev_path) = previous_path {
        // Diff against previous: mmap both flat files and compare block-wise.
        eprintln!("Diffing against previous checkpoint...");
        let prev_file = fs::File::open(prev_path)?;
        // NOTE(review): mapping is `unsafe` because the underlying files could
        // be mutated while mapped; assumed not to happen here — confirm no
        // concurrent writers touch checkpoint files.
        let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? };
        let live_file = fs::File::open(&live_path)?;
        let live_mmap = unsafe { MmapOptions::new().map(&live_file)? };
        if prev_mmap.len() == live_mmap.len() {
            let changed = diff_blocks(&prev_mmap, &live_mmap, block_size);
            let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum();
            let total_bytes = live_mmap.len();
            eprintln!(
                "Changed: {:.1} MB / {:.1} GB ({:.2}%)",
                changed_bytes as f64 / 1e6,
                total_bytes as f64 / 1e9,
                changed_bytes as f64 / total_bytes as f64 * 100.0,
            );
            // If nothing changed, remove the new checkpoint dir; `latest`
            // keeps pointing at the previous, identical snapshot.
            if changed.is_empty() {
                eprintln!("No changes — skipping checkpoint");
                fs::remove_dir_all(&current_dir)?;
                return Ok(());
            }
        } else {
            // Sizes differ (e.g. a different model) — a block diff is
            // meaningless, so keep the freshly written full snapshot.
            eprintln!(
                "Size mismatch ({} vs {}), writing full checkpoint",
                prev_mmap.len(),
                live_mmap.len()
            );
        }
    } else {
        eprintln!("No previous checkpoint — writing full snapshot");
    }
    // Save index (pretty-printed; this overwrites the compact index the
    // Python helper already wrote at the same `weights.json` path).
    let index_path = current_dir.join("weights.json");
    let index_str = serde_json::to_string_pretty(&index)?;
    fs::write(&index_path, index_str)?;
    // Save metadata
    let meta = serde_json::json!({
        "timestamp": timestamp,
        "n_params": index.len(),
        "total_bytes": index.values().map(|m| m.size).sum::<usize>(),
    });
    fs::write(
        current_dir.join("checkpoint-meta.json"),
        serde_json::to_string_pretty(&meta)?,
    )?;
    // Update the `latest` symlink to the new timestamp (relative target, so
    // the checkpoint directory stays relocatable).
    if latest_link.is_symlink() {
        fs::remove_file(&latest_link)?;
    }
    std::os::unix::fs::symlink(&timestamp, &latest_link)?;
    let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9;
    eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb);
    Ok(())
}
/// Print every checkpoint directory with its `weights.bin` size, marking
/// the one the `latest` symlink currently points to.
fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
    if !checkpoint_dir.exists() {
        println!("No checkpoints directory");
        return Ok(());
    }

    // Resolve the target of the `latest` symlink (empty string when absent,
    // which can never match a real directory name below).
    let latest_link = checkpoint_dir.join("latest");
    let latest = match latest_link.exists() {
        true => fs::read_link(&latest_link)?.to_string_lossy().into_owned(),
        false => String::new(),
    };

    // Collect checkpoint subdirectories, sorted by name — timestamps sort
    // lexicographically in chronological order.
    let mut dirs: Vec<_> = fs::read_dir(&checkpoint_dir)?
        .flatten()
        .filter(|entry| matches!(entry.file_type(), Ok(t) if t.is_dir()))
        .collect();
    dirs.sort_by_key(|entry| entry.file_name());

    for dir in dirs {
        let name = dir.file_name().to_string_lossy().into_owned();
        let weights_path = dir.path().join("weights.bin");
        let size = match weights_path.exists() {
            true => format!("{:.1} GB", fs::metadata(&weights_path)?.len() as f64 / 1e9),
            false => String::from("no weights"),
        };
        let marker = if name == latest { " ← latest" } else { "" };
        println!(" {} ({}){}", name, size, marker);
    }
    Ok(())
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Cmd::Save { handles, checkpoint_dir, block_size } => {
cmd_save(handles, checkpoint_dir, block_size)
}
Cmd::List { checkpoint_dir } => {
cmd_list(checkpoint_dir)
}
}
}