mmap each safetensors file, diff block-by-block against the live GPU weights, and memcpy only the changed blocks back. No separate checkpoint files — the model directory IS the checkpoint. Runs every 10 minutes via cron. (Rust, 265 lines, 8.3 KiB)
// apollo-checkpoint — Sync live GPU weights back to model files on disk.
|
|
//
|
|
// mmaps the model's safetensors files, reads live weights from GPU via
|
|
// Python helper (CUDA IPC handles), compares block by block, and memcpys
|
|
// only changed regions back into the mmap. For small behavioral training
|
|
// steps, this turns a 54GB write into a few hundred MB.
|
|
//
|
|
// The model files on disk are the checkpoint. No separate checkpoint
|
|
// directory — just keep the model up to date.
|
|
//
|
|
// Usage:
|
|
// apollo-checkpoint sync \
|
|
// --handles /tmp/vllm_weight_handles.pt \
|
|
// --model-dir /path/to/Qwen3.5-27B
|
|
//
|
|
// Runs every 10 minutes via cron. Daily rsync to moria.
|
|
|
|
use anyhow::{Context, Result, bail};
|
|
use clap::{Parser, Subcommand};
|
|
use memmap2::MmapMut;
|
|
use std::collections::HashMap;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
use std::process::Command;
|
|
|
|
/// Top-level CLI definition: `apollo-checkpoint <subcommand>`.
#[derive(Parser)]
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
struct Cli {
    // The selected subcommand (currently only `sync`).
    #[command(subcommand)]
    command: Cmd,
}
|
|
|
|
/// Subcommands understood by apollo-checkpoint.
#[derive(Subcommand)]
enum Cmd {
    /// Sync live GPU weights back to model safetensors files
    Sync {
        /// Path to vLLM weight IPC handles
        // Written by the vLLM export hook; consumed by the Python helper
        // in `dump_live_weights`.
        #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
        handles: PathBuf,

        /// Model directory containing safetensors files
        // The files in this directory are patched in place — this
        // directory IS the checkpoint.
        #[arg(long)]
        model_dir: PathBuf,

        /// Block size for diffing (bytes)
        // 4096 matches the typical page size; smaller blocks find finer
        // diffs at the cost of more comparisons.
        #[arg(long, default_value_t = 4096)]
        block_size: usize,
    },
}
|
|
|
|
/// Dump live GPU weights to a flat binary file, ordered by safetensors
|
|
/// file and offset to match the on-disk layout.
|
|
///
|
|
/// Returns a map of (safetensors filename, tensor name) → raw bytes.
|
|
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
|
|
let dump_path = output_dir.join(".live_dump.bin");
|
|
let index_path = output_dir.join(".live_dump.json");
|
|
|
|
let status = Command::new("python3")
|
|
.arg("-c")
|
|
.arg(format!(r#"
|
|
import torch, json
|
|
|
|
handles = torch.load("{handles}", weights_only=False)
|
|
index = {{}}
|
|
offset = 0
|
|
|
|
with open("{dump}", "wb") as f:
|
|
for name in sorted(handles.keys()):
|
|
info = handles[name]
|
|
func, args = info["handle"]
|
|
tensor = func(*args)
|
|
data = tensor.contiguous().cpu().numpy().tobytes()
|
|
f.write(data)
|
|
index[name] = {{"offset": offset, "size": len(data)}}
|
|
offset += len(data)
|
|
|
|
with open("{index}", "w") as f:
|
|
json.dump(index, f)
|
|
|
|
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
|
"#,
|
|
handles = handles_path.display(),
|
|
dump = dump_path.display(),
|
|
index = index_path.display(),
|
|
))
|
|
.status()
|
|
.context("Failed to run Python weight dump")?;
|
|
|
|
if !status.success() {
|
|
bail!("Python weight dump failed");
|
|
}
|
|
|
|
let index_str = fs::read_to_string(&index_path)?;
|
|
let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
|
|
let dump_data = fs::read(&dump_path)?;
|
|
|
|
let mut result = HashMap::new();
|
|
for (name, entry) in &index {
|
|
result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
|
|
}
|
|
|
|
// Clean up temp files
|
|
let _ = fs::remove_file(&dump_path);
|
|
let _ = fs::remove_file(&index_path);
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// One entry in the Python helper's dump index: where a tensor's raw
/// bytes live inside `.live_dump.bin`.
#[derive(serde::Deserialize)]
struct DumpEntry {
    // Byte offset of the tensor within the flat dump file.
    offset: usize,
    // Length of the tensor's raw bytes.
    size: usize,
}
|
|
|
|
/// Read the safetensors index to map parameter names to files.
|
|
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
|
|
let index_path = model_dir.join("model.safetensors.index.json");
|
|
if !index_path.exists() {
|
|
// Single file model
|
|
return Ok(HashMap::new());
|
|
}
|
|
|
|
let index_str = fs::read_to_string(&index_path)?;
|
|
let index: serde_json::Value = serde_json::from_str(&index_str)?;
|
|
let weight_map = index["weight_map"]
|
|
.as_object()
|
|
.context("No weight_map in index")?;
|
|
|
|
let mut result = HashMap::new();
|
|
for (name, file) in weight_map {
|
|
result.insert(name.clone(), file.as_str().unwrap().to_string());
|
|
}
|
|
Ok(result)
|
|
}
|
|
|
|
/// Sync changed blocks from live weights into a mmap'd safetensors file.
|
|
/// Returns (total_bytes_compared, bytes_changed).
|
|
fn sync_tensors_to_file(
|
|
file_path: &Path,
|
|
tensors: &[(String, Vec<u8>)],
|
|
block_size: usize,
|
|
) -> Result<(usize, usize)> {
|
|
use safetensors::SafeTensors;
|
|
|
|
let file = fs::OpenOptions::new()
|
|
.read(true)
|
|
.write(true)
|
|
.open(file_path)
|
|
.with_context(|| format!("Failed to open {}", file_path.display()))?;
|
|
|
|
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
|
|
|
// Parse safetensors header to find tensor offsets
|
|
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
|
|
let header_json: serde_json::Value =
|
|
serde_json::from_slice(&mmap[8..8 + header_size])?;
|
|
let data_start = 8 + header_size;
|
|
|
|
let mut total_compared = 0usize;
|
|
let mut total_changed = 0usize;
|
|
|
|
for (name, live_data) in tensors {
|
|
let meta = match header_json.get(name) {
|
|
Some(m) => m,
|
|
None => {
|
|
eprintln!(" Warning: {} not found in {}", name, file_path.display());
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let offsets = meta["data_offsets"].as_array().unwrap();
|
|
let start = data_start + offsets[0].as_u64().unwrap() as usize;
|
|
let end = data_start + offsets[1].as_u64().unwrap() as usize;
|
|
let disk_data = &mmap[start..end];
|
|
|
|
if disk_data.len() != live_data.len() {
|
|
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
|
|
name, disk_data.len(), live_data.len());
|
|
continue;
|
|
}
|
|
|
|
// Diff block by block, memcpy only changed blocks
|
|
let mut offset = 0;
|
|
while offset < disk_data.len() {
|
|
let block_end = (offset + block_size).min(disk_data.len());
|
|
total_compared += block_end - offset;
|
|
|
|
if disk_data[offset..block_end] != live_data[offset..block_end] {
|
|
mmap[start + offset..start + block_end]
|
|
.copy_from_slice(&live_data[offset..block_end]);
|
|
total_changed += block_end - offset;
|
|
}
|
|
offset = block_end;
|
|
}
|
|
}
|
|
|
|
mmap.flush()?;
|
|
Ok((total_compared, total_changed))
|
|
}
|
|
|
|
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
|
|
if !handles.exists() {
|
|
bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
|
|
handles.display());
|
|
}
|
|
|
|
eprintln!("Dumping live weights from GPU...");
|
|
let live_weights = dump_live_weights(&handles, &model_dir)?;
|
|
eprintln!(" {} tensors dumped", live_weights.len());
|
|
|
|
// Map parameter names to safetensors files
|
|
let weight_map = read_safetensors_index(&model_dir)?;
|
|
|
|
// Group tensors by safetensors file
|
|
let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
|
|
for (name, data) in live_weights {
|
|
let file = weight_map
|
|
.get(&name)
|
|
.cloned()
|
|
.unwrap_or_else(|| "model.safetensors".to_string());
|
|
by_file.entry(file).or_default().push((name, data));
|
|
}
|
|
|
|
let mut total_compared = 0usize;
|
|
let mut total_changed = 0usize;
|
|
|
|
for (filename, tensors) in &by_file {
|
|
let file_path = model_dir.join(filename);
|
|
if !file_path.exists() {
|
|
eprintln!(" Warning: {} not found, skipping", filename);
|
|
continue;
|
|
}
|
|
|
|
let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
|
|
total_compared += compared;
|
|
total_changed += changed;
|
|
|
|
if changed > 0 {
|
|
eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6);
|
|
}
|
|
}
|
|
|
|
if total_changed == 0 {
|
|
eprintln!("No changes — model files are up to date");
|
|
} else {
|
|
eprintln!(
|
|
"Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
|
|
total_changed as f64 / 1e6,
|
|
total_compared as f64 / 1e9,
|
|
total_changed as f64 / total_compared as f64 * 100.0,
|
|
);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn main() -> Result<()> {
|
|
let cli = Cli::parse();
|
|
match cli.command {
|
|
Cmd::Sync { handles, model_dir, block_size } => {
|
|
cmd_sync(handles, model_dir, block_size)
|
|
}
|
|
}
|
|
}
|