checkpoint: sync live weights back into model safetensors in-place
mmap each safetensors file, diff block-by-block against live GPU weights, memcpy only changed blocks. No separate checkpoint files — the model directory IS the checkpoint. Every 10 min via cron.
This commit is contained in:
parent
c1245ab139
commit
d0883e101b
1 changed files with 185 additions and 201 deletions
|
|
@ -1,30 +1,30 @@
|
||||||
// apollo-checkpoint — Efficient GPU weight checkpointing via mmap + diff.
|
// apollo-checkpoint — Sync live GPU weights back to model files on disk.
|
||||||
//
|
//
|
||||||
// mmaps the previous checkpoint, reads live weights from GPU via a
|
// mmaps the model's safetensors files, reads live weights from GPU via
|
||||||
// Python helper (CUDA IPC handles), compares block by block, and only
|
// Python helper (CUDA IPC handles), compares block by block, and memcpys
|
||||||
// writes changed regions. For small behavioral training steps, this
|
// only changed regions back into the mmap. For small behavioral training
|
||||||
// turns a 54GB write into a few hundred MB.
|
// steps, this turns a 54GB write into a few hundred MB.
|
||||||
|
//
|
||||||
|
// The model files on disk are the checkpoint. No separate checkpoint
|
||||||
|
// directory — just keep the model up to date.
|
||||||
//
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
// apollo-checkpoint save \
|
// apollo-checkpoint sync \
|
||||||
// --handles /tmp/vllm_weight_handles.pt \
|
// --handles /tmp/vllm_weight_handles.pt \
|
||||||
// --checkpoint-dir /home/ubuntu/checkpoints \
|
// --model-dir /path/to/Qwen3.5-27B
|
||||||
// --block-size 4096
|
|
||||||
//
|
//
|
||||||
// Runs every 10 minutes via cron to protect against vLLM crashes.
|
// Runs every 10 minutes via cron. Daily rsync to moria.
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
use anyhow::{Context, Result, bail};
|
||||||
use chrono::Utc;
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use memmap2::MmapOptions;
|
use memmap2::MmapMut;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::Write;
|
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "apollo-checkpoint", about = "Efficient GPU weight checkpointing")]
|
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
command: Cmd,
|
command: Cmd,
|
||||||
|
|
@ -32,67 +32,57 @@ struct Cli {
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
enum Cmd {
|
enum Cmd {
|
||||||
/// Save a checkpoint (diff against previous, write only changes)
|
/// Sync live GPU weights back to model safetensors files
|
||||||
Save {
|
Sync {
|
||||||
/// Path to vLLM weight IPC handles
|
/// Path to vLLM weight IPC handles
|
||||||
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
|
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
|
||||||
handles: PathBuf,
|
handles: PathBuf,
|
||||||
|
|
||||||
/// Checkpoint directory
|
/// Model directory containing safetensors files
|
||||||
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
|
#[arg(long)]
|
||||||
checkpoint_dir: PathBuf,
|
model_dir: PathBuf,
|
||||||
|
|
||||||
/// Block size for diffing (bytes)
|
/// Block size for diffing (bytes)
|
||||||
#[arg(long, default_value_t = 4096)]
|
#[arg(long, default_value_t = 4096)]
|
||||||
block_size: usize,
|
block_size: usize,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// List checkpoints
|
|
||||||
List {
|
|
||||||
#[arg(long, default_value = "/home/ubuntu/checkpoints")]
|
|
||||||
checkpoint_dir: PathBuf,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dump live GPU weights to a flat binary file via Python helper.
|
/// Dump live GPU weights to a flat binary file, ordered by safetensors
|
||||||
|
/// file and offset to match the on-disk layout.
|
||||||
///
|
///
|
||||||
/// The Python script imports the CUDA IPC handles and saves each
|
/// Returns a map of (safetensors filename, tensor name) → raw bytes.
|
||||||
/// tensor's raw bytes to a flat file, plus a JSON index mapping
|
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
|
||||||
/// parameter names to (offset, size, shape, dtype).
|
let dump_path = output_dir.join(".live_dump.bin");
|
||||||
fn dump_live_weights(handles_path: &Path, output_path: &Path) -> Result<HashMap<String, TensorMeta>> {
|
let index_path = output_dir.join(".live_dump.json");
|
||||||
let index_path = output_path.with_extension("json");
|
|
||||||
|
|
||||||
let status = Command::new("python3")
|
let status = Command::new("python3")
|
||||||
.arg("-c")
|
.arg("-c")
|
||||||
.arg(format!(r#"
|
.arg(format!(r#"
|
||||||
import torch, json
|
import torch, json
|
||||||
|
|
||||||
handles = torch.load("{}", weights_only=False)
|
handles = torch.load("{handles}", weights_only=False)
|
||||||
index = {{}}
|
index = {{}}
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
||||||
with open("{}", "wb") as f:
|
with open("{dump}", "wb") as f:
|
||||||
for name, info in sorted(handles.items()):
|
for name in sorted(handles.keys()):
|
||||||
|
info = handles[name]
|
||||||
func, args = info["handle"]
|
func, args = info["handle"]
|
||||||
tensor = func(*args)
|
tensor = func(*args)
|
||||||
data = tensor.contiguous().cpu().numpy().tobytes()
|
data = tensor.contiguous().cpu().numpy().tobytes()
|
||||||
f.write(data)
|
f.write(data)
|
||||||
index[name] = {{
|
index[name] = {{"offset": offset, "size": len(data)}}
|
||||||
"offset": offset,
|
|
||||||
"size": len(data),
|
|
||||||
"shape": list(tensor.shape),
|
|
||||||
"dtype": str(tensor.dtype),
|
|
||||||
}}
|
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
||||||
with open("{}", "w") as f:
|
with open("{index}", "w") as f:
|
||||||
json.dump(index, f)
|
json.dump(index, f)
|
||||||
|
|
||||||
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
||||||
"#,
|
"#,
|
||||||
handles_path.display(),
|
handles = handles_path.display(),
|
||||||
output_path.display(),
|
dump = dump_path.display(),
|
||||||
index_path.display(),
|
index = index_path.display(),
|
||||||
))
|
))
|
||||||
.status()
|
.status()
|
||||||
.context("Failed to run Python weight dump")?;
|
.context("Failed to run Python weight dump")?;
|
||||||
|
|
@ -101,168 +91,165 @@ print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
||||||
bail!("Python weight dump failed");
|
bail!("Python weight dump failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the index
|
let index_str = fs::read_to_string(&index_path)?;
|
||||||
let index_str = fs::read_to_string(&index_path)
|
let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
|
||||||
.context("Failed to read weight index")?;
|
let dump_data = fs::read(&dump_path)?;
|
||||||
let index: HashMap<String, TensorMeta> = serde_json::from_str(&index_str)?;
|
|
||||||
Ok(index)
|
let mut result = HashMap::new();
|
||||||
|
for (name, entry) in &index {
|
||||||
|
result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up temp files
|
||||||
|
let _ = fs::remove_file(&dump_path);
|
||||||
|
let _ = fs::remove_file(&index_path);
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Deserialize, serde::Serialize, Clone)]
|
#[derive(serde::Deserialize)]
|
||||||
struct TensorMeta {
|
struct DumpEntry {
|
||||||
offset: usize,
|
offset: usize,
|
||||||
size: usize,
|
size: usize,
|
||||||
shape: Vec<usize>,
|
|
||||||
dtype: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Diff two flat binary files block by block, return changed byte ranges.
|
/// Read the safetensors index to map parameter names to files.
|
||||||
fn diff_blocks(old: &[u8], new: &[u8], block_size: usize) -> Vec<(usize, usize)> {
|
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
|
||||||
assert_eq!(old.len(), new.len(), "File sizes must match for diffing");
|
let index_path = model_dir.join("model.safetensors.index.json");
|
||||||
let mut changed = Vec::new();
|
if !index_path.exists() {
|
||||||
let mut i = 0;
|
// Single file model
|
||||||
|
return Ok(HashMap::new());
|
||||||
|
}
|
||||||
|
|
||||||
while i < old.len() {
|
let index_str = fs::read_to_string(&index_path)?;
|
||||||
let end = (i + block_size).min(old.len());
|
let index: serde_json::Value = serde_json::from_str(&index_str)?;
|
||||||
if old[i..end] != new[i..end] {
|
let weight_map = index["weight_map"]
|
||||||
// Extend contiguous changed region
|
.as_object()
|
||||||
let start = i;
|
.context("No weight_map in index")?;
|
||||||
while i < old.len() {
|
|
||||||
let end = (i + block_size).min(old.len());
|
let mut result = HashMap::new();
|
||||||
if old[i..end] == new[i..end] {
|
for (name, file) in weight_map {
|
||||||
break;
|
result.insert(name.clone(), file.as_str().unwrap().to_string());
|
||||||
}
|
}
|
||||||
i = end;
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sync changed blocks from live weights into a mmap'd safetensors file.
|
||||||
|
/// Returns (total_bytes_compared, bytes_changed).
|
||||||
|
fn sync_tensors_to_file(
|
||||||
|
file_path: &Path,
|
||||||
|
tensors: &[(String, Vec<u8>)],
|
||||||
|
block_size: usize,
|
||||||
|
) -> Result<(usize, usize)> {
|
||||||
|
use safetensors::SafeTensors;
|
||||||
|
|
||||||
|
let file = fs::OpenOptions::new()
|
||||||
|
.read(true)
|
||||||
|
.write(true)
|
||||||
|
.open(file_path)
|
||||||
|
.with_context(|| format!("Failed to open {}", file_path.display()))?;
|
||||||
|
|
||||||
|
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
||||||
|
|
||||||
|
// Parse safetensors header to find tensor offsets
|
||||||
|
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
|
||||||
|
let header_json: serde_json::Value =
|
||||||
|
serde_json::from_slice(&mmap[8..8 + header_size])?;
|
||||||
|
let data_start = 8 + header_size;
|
||||||
|
|
||||||
|
let mut total_compared = 0usize;
|
||||||
|
let mut total_changed = 0usize;
|
||||||
|
|
||||||
|
for (name, live_data) in tensors {
|
||||||
|
let meta = match header_json.get(name) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => {
|
||||||
|
eprintln!(" Warning: {} not found in {}", name, file_path.display());
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
changed.push((start, i));
|
|
||||||
} else {
|
|
||||||
i = end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
changed
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cmd_save(handles: PathBuf, checkpoint_dir: PathBuf, block_size: usize) -> Result<()> {
|
|
||||||
fs::create_dir_all(&checkpoint_dir)?;
|
|
||||||
|
|
||||||
let timestamp = Utc::now().format("%Y-%m-%d-%H%M").to_string();
|
|
||||||
let current_dir = checkpoint_dir.join(×tamp);
|
|
||||||
fs::create_dir_all(¤t_dir)?;
|
|
||||||
|
|
||||||
let live_path = current_dir.join("weights.bin");
|
|
||||||
eprintln!("Dumping live weights from GPU...");
|
|
||||||
let index = dump_live_weights(&handles, &live_path)?;
|
|
||||||
|
|
||||||
// Find previous checkpoint
|
|
||||||
let latest_link = checkpoint_dir.join("latest");
|
|
||||||
let previous_path = if latest_link.exists() {
|
|
||||||
let prev_dir = fs::read_link(&latest_link)?;
|
|
||||||
let prev_weights = checkpoint_dir.join(&prev_dir).join("weights.bin");
|
|
||||||
if prev_weights.exists() {
|
|
||||||
Some(prev_weights)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref prev_path) = previous_path {
|
|
||||||
// Diff against previous
|
|
||||||
eprintln!("Diffing against previous checkpoint...");
|
|
||||||
let prev_file = fs::File::open(prev_path)?;
|
|
||||||
let prev_mmap = unsafe { MmapOptions::new().map(&prev_file)? };
|
|
||||||
|
|
||||||
let live_file = fs::File::open(&live_path)?;
|
|
||||||
let live_mmap = unsafe { MmapOptions::new().map(&live_file)? };
|
|
||||||
|
|
||||||
if prev_mmap.len() == live_mmap.len() {
|
|
||||||
let changed = diff_blocks(&prev_mmap, &live_mmap, block_size);
|
|
||||||
let changed_bytes: usize = changed.iter().map(|(s, e)| e - s).sum();
|
|
||||||
let total_bytes = live_mmap.len();
|
|
||||||
|
|
||||||
eprintln!(
|
|
||||||
"Changed: {:.1} MB / {:.1} GB ({:.2}%)",
|
|
||||||
changed_bytes as f64 / 1e6,
|
|
||||||
total_bytes as f64 / 1e9,
|
|
||||||
changed_bytes as f64 / total_bytes as f64 * 100.0,
|
|
||||||
);
|
|
||||||
|
|
||||||
// If nothing changed, remove the new checkpoint dir
|
|
||||||
if changed.is_empty() {
|
|
||||||
eprintln!("No changes — skipping checkpoint");
|
|
||||||
fs::remove_dir_all(¤t_dir)?;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
eprintln!(
|
|
||||||
"Size mismatch ({} vs {}), writing full checkpoint",
|
|
||||||
prev_mmap.len(),
|
|
||||||
live_mmap.len()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
eprintln!("No previous checkpoint — writing full snapshot");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save index
|
|
||||||
let index_path = current_dir.join("weights.json");
|
|
||||||
let index_str = serde_json::to_string_pretty(&index)?;
|
|
||||||
fs::write(&index_path, index_str)?;
|
|
||||||
|
|
||||||
// Save metadata
|
|
||||||
let meta = serde_json::json!({
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"n_params": index.len(),
|
|
||||||
"total_bytes": index.values().map(|m| m.size).sum::<usize>(),
|
|
||||||
});
|
|
||||||
fs::write(
|
|
||||||
current_dir.join("checkpoint-meta.json"),
|
|
||||||
serde_json::to_string_pretty(&meta)?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Update latest symlink
|
|
||||||
if latest_link.is_symlink() {
|
|
||||||
fs::remove_file(&latest_link)?;
|
|
||||||
}
|
|
||||||
std::os::unix::fs::symlink(×tamp, &latest_link)?;
|
|
||||||
|
|
||||||
let size_gb = fs::metadata(&live_path)?.len() as f64 / 1e9;
|
|
||||||
eprintln!("Checkpoint saved: {} ({:.1} GB)", current_dir.display(), size_gb);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
|
|
||||||
if !checkpoint_dir.exists() {
|
|
||||||
println!("No checkpoints directory");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let latest = if checkpoint_dir.join("latest").exists() {
|
|
||||||
fs::read_link(checkpoint_dir.join("latest"))?
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string()
|
|
||||||
} else {
|
|
||||||
String::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut entries: Vec<_> = fs::read_dir(&checkpoint_dir)?
|
|
||||||
.filter_map(|e| e.ok())
|
|
||||||
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
|
||||||
.collect();
|
|
||||||
entries.sort_by_key(|e| e.file_name());
|
|
||||||
|
|
||||||
for entry in entries {
|
|
||||||
let name = entry.file_name().to_string_lossy().to_string();
|
|
||||||
let weights = entry.path().join("weights.bin");
|
|
||||||
let size = if weights.exists() {
|
|
||||||
format!("{:.1} GB", fs::metadata(&weights)?.len() as f64 / 1e9)
|
|
||||||
} else {
|
|
||||||
"no weights".to_string()
|
|
||||||
};
|
};
|
||||||
let marker = if name == latest { " ← latest" } else { "" };
|
|
||||||
println!(" {} ({}){}", name, size, marker);
|
let offsets = meta["data_offsets"].as_array().unwrap();
|
||||||
|
let start = data_start + offsets[0].as_u64().unwrap() as usize;
|
||||||
|
let end = data_start + offsets[1].as_u64().unwrap() as usize;
|
||||||
|
let disk_data = &mmap[start..end];
|
||||||
|
|
||||||
|
if disk_data.len() != live_data.len() {
|
||||||
|
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
|
||||||
|
name, disk_data.len(), live_data.len());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diff block by block, memcpy only changed blocks
|
||||||
|
let mut offset = 0;
|
||||||
|
while offset < disk_data.len() {
|
||||||
|
let block_end = (offset + block_size).min(disk_data.len());
|
||||||
|
total_compared += block_end - offset;
|
||||||
|
|
||||||
|
if disk_data[offset..block_end] != live_data[offset..block_end] {
|
||||||
|
mmap[start + offset..start + block_end]
|
||||||
|
.copy_from_slice(&live_data[offset..block_end]);
|
||||||
|
total_changed += block_end - offset;
|
||||||
|
}
|
||||||
|
offset = block_end;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mmap.flush()?;
|
||||||
|
Ok((total_compared, total_changed))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
|
||||||
|
if !handles.exists() {
|
||||||
|
bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
|
||||||
|
handles.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
eprintln!("Dumping live weights from GPU...");
|
||||||
|
let live_weights = dump_live_weights(&handles, &model_dir)?;
|
||||||
|
eprintln!(" {} tensors dumped", live_weights.len());
|
||||||
|
|
||||||
|
// Map parameter names to safetensors files
|
||||||
|
let weight_map = read_safetensors_index(&model_dir)?;
|
||||||
|
|
||||||
|
// Group tensors by safetensors file
|
||||||
|
let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
|
||||||
|
for (name, data) in live_weights {
|
||||||
|
let file = weight_map
|
||||||
|
.get(&name)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "model.safetensors".to_string());
|
||||||
|
by_file.entry(file).or_default().push((name, data));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut total_compared = 0usize;
|
||||||
|
let mut total_changed = 0usize;
|
||||||
|
|
||||||
|
for (filename, tensors) in &by_file {
|
||||||
|
let file_path = model_dir.join(filename);
|
||||||
|
if !file_path.exists() {
|
||||||
|
eprintln!(" Warning: {} not found, skipping", filename);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
|
||||||
|
total_compared += compared;
|
||||||
|
total_changed += changed;
|
||||||
|
|
||||||
|
if changed > 0 {
|
||||||
|
eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_changed == 0 {
|
||||||
|
eprintln!("No changes — model files are up to date");
|
||||||
|
} else {
|
||||||
|
eprintln!(
|
||||||
|
"Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
|
||||||
|
total_changed as f64 / 1e6,
|
||||||
|
total_compared as f64 / 1e9,
|
||||||
|
total_changed as f64 / total_compared as f64 * 100.0,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -271,11 +258,8 @@ fn cmd_list(checkpoint_dir: PathBuf) -> Result<()> {
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
match cli.command {
|
match cli.command {
|
||||||
Cmd::Save { handles, checkpoint_dir, block_size } => {
|
Cmd::Sync { handles, model_dir, block_size } => {
|
||||||
cmd_save(handles, checkpoint_dir, block_size)
|
cmd_sync(handles, model_dir, block_size)
|
||||||
}
|
|
||||||
Cmd::List { checkpoint_dir } => {
|
|
||||||
cmd_list(checkpoint_dir)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue