consciousness/training/checkpoint/src/main.rs
ProofOfConcept d0883e101b checkpoint: sync live weights back into model safetensors in-place
mmap each safetensors file, diff block-by-block against live GPU
weights, memcpy only changed blocks. No separate checkpoint files —
the model directory IS the checkpoint. Every 10 min via cron.
2026-03-30 22:55:23 -04:00

265 lines
8.3 KiB
Rust

// apollo-checkpoint — Sync live GPU weights back to model files on disk.
//
// mmaps the model's safetensors files, reads live weights from GPU via
// Python helper (CUDA IPC handles), compares block by block, and memcpys
// only changed regions back into the mmap. For small behavioral training
// steps, this turns a 54GB write into a few hundred MB.
//
// The model files on disk are the checkpoint. No separate checkpoint
// directory — just keep the model up to date.
//
// Usage:
// apollo-checkpoint sync \
// --handles /tmp/vllm_weight_handles.pt \
// --model-dir /path/to/Qwen3.5-27B
//
// Runs every 10 minutes via cron. Daily rsync to moria.
use anyhow::{Context, Result, bail};
use clap::{Parser, Subcommand};
use memmap2::MmapMut;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
// Top-level CLI definition. NOTE: clap turns `///` doc comments into help
// text, so annotations here use `//` to leave the --help output unchanged.
#[derive(Parser)]
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
struct Cli {
    // The selected subcommand (currently only `sync`).
    #[command(subcommand)]
    command: Cmd,
}
// Subcommands. `///` comments below are surfaced by clap as CLI help text,
// so extra reviewer notes use `//` to keep --help output byte-identical.
#[derive(Subcommand)]
enum Cmd {
    /// Sync live GPU weights back to model safetensors files
    Sync {
        /// Path to vLLM weight IPC handles
        // A torch-saved map of tensor name -> CUDA IPC handle, exported by
        // the vLLM side; consumed by the Python helper in dump_live_weights.
        #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
        handles: PathBuf,
        /// Model directory containing safetensors files
        // The directory IS the checkpoint — files are patched in place.
        #[arg(long)]
        model_dir: PathBuf,
        /// Block size for diffing (bytes)
        // 4096 matches a typical page size; smaller blocks localize changes
        // better at the cost of more comparisons.
        #[arg(long, default_value_t = 4096)]
        block_size: usize,
    },
}
/// Dump live GPU weights to a flat binary file via a Python helper, then
/// read them back into memory and remove the temp files.
///
/// The helper re-materializes each tensor from its CUDA IPC handle, writes
/// the raw bytes (sorted by tensor name) to `.live_dump.bin` under
/// `output_dir`, and writes a JSON index of per-tensor offset/size to
/// `.live_dump.json`.
///
/// Returns a map of tensor name → raw tensor bytes.
/// (Fixed doc: the key is just the tensor name, not a (file, name) pair.)
///
/// # Errors
/// Fails if python3 cannot be spawned, the helper exits non-zero, the temp
/// files cannot be read, or the dump index is inconsistent with the dump.
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
    let dump_path = output_dir.join(".live_dump.bin");
    let index_path = output_dir.join(".live_dump.json");
    let status = Command::new("python3")
        .arg("-c")
        .arg(format!(
            r#"
import torch, json
handles = torch.load("{handles}", weights_only=False)
index = {{}}
offset = 0
with open("{dump}", "wb") as f:
    for name in sorted(handles.keys()):
        info = handles[name]
        func, args = info["handle"]
        tensor = func(*args)
        data = tensor.contiguous().cpu().numpy().tobytes()
        f.write(data)
        index[name] = {{"offset": offset, "size": len(data)}}
        offset += len(data)
with open("{index}", "w") as f:
    json.dump(index, f)
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
"#,
            handles = handles_path.display(),
            dump = dump_path.display(),
            index = index_path.display(),
        ))
        .status()
        .context("Failed to run Python weight dump")?;
    if !status.success() {
        bail!("Python weight dump failed");
    }
    // Read both temp files fully, then delete them immediately so a parse
    // failure below no longer leaks a multi-GB dump in the model directory.
    let index_str = fs::read_to_string(&index_path)?;
    let dump_data = fs::read(&dump_path)?;
    let _ = fs::remove_file(&dump_path);
    let _ = fs::remove_file(&index_path);
    let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
    let mut result = HashMap::new();
    for (name, entry) in &index {
        // Bounds-check each index entry instead of slicing blindly: a stale
        // or truncated dump would otherwise panic here.
        let end = entry
            .offset
            .checked_add(entry.size)
            .with_context(|| format!("Dump index overflow for tensor {}", name))?;
        let bytes = dump_data
            .get(entry.offset..end)
            .with_context(|| format!("Dump index out of range for tensor {}", name))?;
        result.insert(name.clone(), bytes.to_vec());
    }
    Ok(result)
}
/// One entry in the Python helper's JSON dump index: where a tensor's raw
/// bytes live within the flat `.live_dump.bin` file.
// Debug aids error reporting; Clone/Copy are free for a two-usize POD.
#[derive(serde::Deserialize, Debug, Clone, Copy)]
struct DumpEntry {
    // Byte offset of the tensor's data within the dump file.
    offset: usize,
    // Length of the tensor's data in bytes.
    size: usize,
}
/// Read the safetensors index to map parameter names to shard files.
///
/// Sharded models carry a `model.safetensors.index.json` whose `weight_map`
/// maps tensor name → shard filename. Single-file models have no index; an
/// empty map is returned and the caller falls back to `model.safetensors`.
///
/// # Errors
/// Fails if the index exists but cannot be read, is not valid JSON, lacks a
/// `weight_map` object, or contains a non-string filename entry.
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
    let index_path = model_dir.join("model.safetensors.index.json");
    if !index_path.exists() {
        // Single file model
        return Ok(HashMap::new());
    }
    let index_str = fs::read_to_string(&index_path)?;
    let index: serde_json::Value = serde_json::from_str(&index_str)?;
    let weight_map = index["weight_map"]
        .as_object()
        .context("No weight_map in index")?;
    let mut result = HashMap::new();
    for (name, file) in weight_map {
        // Error out on malformed entries instead of panicking via unwrap():
        // this JSON comes from disk and is effectively untrusted input.
        let file = file
            .as_str()
            .with_context(|| format!("Non-string weight_map entry for {}", name))?;
        result.insert(name.clone(), file.to_string());
    }
    Ok(result)
}
/// Sync changed blocks from live weights into a mmap'd safetensors file.
/// Returns (total_bytes_compared, bytes_changed).
fn sync_tensors_to_file(
file_path: &Path,
tensors: &[(String, Vec<u8>)],
block_size: usize,
) -> Result<(usize, usize)> {
use safetensors::SafeTensors;
let file = fs::OpenOptions::new()
.read(true)
.write(true)
.open(file_path)
.with_context(|| format!("Failed to open {}", file_path.display()))?;
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
// Parse safetensors header to find tensor offsets
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
let header_json: serde_json::Value =
serde_json::from_slice(&mmap[8..8 + header_size])?;
let data_start = 8 + header_size;
let mut total_compared = 0usize;
let mut total_changed = 0usize;
for (name, live_data) in tensors {
let meta = match header_json.get(name) {
Some(m) => m,
None => {
eprintln!(" Warning: {} not found in {}", name, file_path.display());
continue;
}
};
let offsets = meta["data_offsets"].as_array().unwrap();
let start = data_start + offsets[0].as_u64().unwrap() as usize;
let end = data_start + offsets[1].as_u64().unwrap() as usize;
let disk_data = &mmap[start..end];
if disk_data.len() != live_data.len() {
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
name, disk_data.len(), live_data.len());
continue;
}
// Diff block by block, memcpy only changed blocks
let mut offset = 0;
while offset < disk_data.len() {
let block_end = (offset + block_size).min(disk_data.len());
total_compared += block_end - offset;
if disk_data[offset..block_end] != live_data[offset..block_end] {
mmap[start + offset..start + block_end]
.copy_from_slice(&live_data[offset..block_end]);
total_changed += block_end - offset;
}
offset = block_end;
}
}
mmap.flush()?;
Ok((total_compared, total_changed))
}
/// Entry point for the `sync` subcommand: dump live GPU weights, bucket
/// them by safetensors shard, and patch each shard file in place.
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
    if !handles.exists() {
        bail!(
            "Weight handles not found: {}. Is vLLM running with the export hook?",
            handles.display()
        );
    }
    eprintln!("Dumping live weights from GPU...");
    let live_weights = dump_live_weights(&handles, &model_dir)?;
    eprintln!(" {} tensors dumped", live_weights.len());
    // Resolve which shard each parameter lives in (empty map = single file).
    let weight_map = read_safetensors_index(&model_dir)?;
    // Bucket the dumped tensors per shard file.
    let mut grouped: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
    for (name, data) in live_weights {
        let shard = match weight_map.get(&name) {
            Some(f) => f.clone(),
            None => "model.safetensors".to_string(),
        };
        grouped.entry(shard).or_default().push((name, data));
    }
    let (mut compared_total, mut changed_total) = (0usize, 0usize);
    for (shard, tensors) in &grouped {
        let shard_path = model_dir.join(shard);
        if !shard_path.exists() {
            eprintln!(" Warning: {} not found, skipping", shard);
            continue;
        }
        let (compared, changed) = sync_tensors_to_file(&shard_path, tensors, block_size)?;
        compared_total += compared;
        changed_total += changed;
        if changed > 0 {
            eprintln!(" {}: {:.1} MB changed", shard, changed as f64 / 1e6);
        }
    }
    if changed_total == 0 {
        eprintln!("No changes — model files are up to date");
    } else {
        eprintln!(
            "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
            changed_total as f64 / 1e6,
            compared_total as f64 / 1e9,
            changed_total as f64 / compared_total as f64 * 100.0,
        );
    }
    Ok(())
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Cmd::Sync { handles, model_dir, block_size } => {
cmd_sync(handles, model_dir, block_size)
}
}
}