diff --git a/Cargo.lock b/Cargo.lock index eb53ed5..dfca607 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,11 +492,12 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", - "json5", + "json-five", "libc", "log", "memchr", "memmap2", + "notify-debouncer-mini", "paste", "peg", "ratatui", @@ -1088,6 +1089,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + [[package]] name = "futures" version = "0.3.32" @@ -1453,6 +1463,26 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199" +dependencies = [ + "bitflags 2.11.0", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "instability" version = "0.3.12" @@ -1531,6 +1561,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json-five" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "865f2d01a4549c1fd8c60640c03ae5249eb374cd8cde8b905628d4b1af95c87c" +dependencies = [ + "serde", + "unicode-general-category", +] + [[package]] name = "json5" version = "1.3.1" @@ -1552,6 +1592,26 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "kqueue" +version = "1.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "lab" version = "0.11.0" @@ -1774,6 +1834,45 @@ dependencies = [ "memchr", ] +[[package]] +name = "notify" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" +dependencies = [ + "bitflags 2.11.0", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio", + "notify-types", + "walkdir", + "windows-sys 0.60.2", +] + +[[package]] +name = "notify-debouncer-mini" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17849edfaabd9a5fef1c606d99cfc615a8e99f7ac4366406d86c7942a3184cf2" +dependencies = [ + "log", + "notify", + "notify-types", + "tempfile", +] + +[[package]] +name = "notify-types" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a" +dependencies = [ + "bitflags 2.11.0", +] + [[package]] name = "num-conv" version = "0.2.1" @@ -3384,6 +3483,12 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" +[[package]] +name = "unicode-general-category" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b993bddc193ae5bd0d623b49ec06ac3e9312875fdae725a975c51db1cc1677f" + [[package]] name = "unicode-ident" version = "1.0.24" @@ -3794,7 +3899,16 @@ version = "0.52.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", ] [[package]] @@ -3812,14 +3926,31 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -3828,48 +3959,96 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index c253bd7..7cdf851 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,8 @@ log = "0.4" serde = { version = "1", features = ["derive"] } serde_json = "1" -json5 = "1.3" +json-five = "0.3" +notify-debouncer-mini = "0.7" ratatui = { version = "0.30", features = ["unstable-rendered-line-info"] } tui-markdown = { git = "https://github.com/koverstreet/tui-markdown", subdirectory = "tui-markdown" } diff --git a/src/agent/context.rs b/src/agent/context.rs index c43c023..37dbf48 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -92,7 +92,7 @@ pub struct NodeLeaf { body: NodeBody, #[serde(skip)] token_ids: Vec, - timestamp: Option>, + timestamp: DateTime, } impl<'de> Deserialize<'de> for NodeLeaf { @@ -100,7 +100,7 @@ impl<'de> Deserialize<'de> for NodeLeaf { #[derive(Deserialize)] struct Raw { body: NodeBody, - timestamp: Option>, + timestamp: DateTime, } let raw = Raw::deserialize(deserializer)?; let token_ids = if 
raw.body.is_prompt_visible() { @@ -119,6 +119,7 @@ pub enum AstNode { Branch { role: Role, children: Vec, + timestamp: DateTime, /// Per-response memory attribution from full scoring matrix. /// Maps memory key → divergence score for this response. #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")] @@ -252,18 +253,18 @@ impl NodeLeaf { } else { vec![] }; - Self { body, token_ids, timestamp: None } + Self { body, token_ids, timestamp: Utc::now() } } pub fn with_timestamp(mut self, ts: DateTime) -> Self { - self.timestamp = Some(ts); + self.timestamp = ts; self } pub fn body(&self) -> &NodeBody { &self.body } pub fn token_ids(&self) -> &[u32] { &self.token_ids } pub fn tokens(&self) -> usize { self.token_ids.len() } - pub fn timestamp(&self) -> Option> { self.timestamp } + pub fn timestamp(&self) -> DateTime { self.timestamp } } impl AstNode { @@ -307,13 +308,14 @@ impl AstNode { // -- Branch constructors -------------------------------------------------- pub fn branch(role: Role, children: Vec) -> Self { - Self::Branch { role, children, memory_scores: Default::default() } + Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() } } pub fn system_msg(text: impl Into) -> Self { Self::Branch { role: Role::System, children: vec![Self::content(text)], + timestamp: Utc::now(), memory_scores: Default::default(), } } @@ -322,6 +324,7 @@ impl AstNode { Self::Branch { role: Role::User, children: vec![Self::content(text)], + timestamp: Utc::now(), memory_scores: Default::default(), } } @@ -338,9 +341,10 @@ impl AstNode { }; Self::Leaf(NodeLeaf { token_ids, ..leaf }) } - Self::Branch { role, children, memory_scores, .. 
} => Self::Branch { + Self::Branch { role, children, timestamp, memory_scores } => Self::Branch { role, children: children.into_iter().map(|c| c.retokenize()).collect(), + timestamp, memory_scores, }, } @@ -348,8 +352,8 @@ impl AstNode { pub fn with_timestamp(mut self, ts: DateTime) -> Self { match &mut self { - Self::Leaf(leaf) => leaf.timestamp = Some(ts), - Self::Branch { .. } => {} + Self::Leaf(leaf) => leaf.timestamp = ts, + Self::Branch { timestamp, .. } => *timestamp = ts, } self } @@ -370,7 +374,7 @@ impl AstNode { /// Short label for the UI. pub fn label(&self) -> String { - let cfg = crate::config::get(); + let app = crate::config::app(); match self { Self::Branch { role, children, .. } => { let preview = children.first() @@ -379,8 +383,8 @@ impl AstNode { .unwrap_or_default(); match role { Role::System => "system".into(), - Role::User => format!("{}: {}", cfg.user_name, preview), - Role::Assistant => format!("{}: {}", cfg.assistant_name, preview), + Role::User => format!("{}: {}", app.user_name, preview), + Role::Assistant => format!("{}: {}", app.assistant_name, preview), } } Self::Leaf(leaf) => match &leaf.body { @@ -988,7 +992,10 @@ impl ContextState { } pub fn context_window() -> usize { - crate::config::get().api_context_window + let app = crate::config::app(); + app.backends.get(&app.default_backend) + .and_then(|b| b.context_window) + .unwrap_or(128_000) } pub fn context_budget_tokens() -> usize { @@ -1340,4 +1347,35 @@ mod tests { assert_token_invariants(node); assert!(node.tokens() > 0); } + + // -- Timestamp deserialization tests ------------------------------------------ + + #[test] + fn test_timestamp_null_rejected() { + // Missing/null timestamps used to be accepted via a lenient + // deserialize fallback. Post-migration the schema is strict. 
+ let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#; + assert!(serde_json::from_str::(json).is_err()); + } + + #[test] + fn test_timestamp_missing_rejected() { + let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#; + assert!(serde_json::from_str::(json).is_err()); + } + + #[test] + fn test_branch_timestamp_missing_rejected() { + let json = r#"{"Branch":{"role":"User","children":[]}}"#; + assert!(serde_json::from_str::(json).is_err()); + } + + #[test] + fn test_timestamp_present_accepted() { + let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#; + let node: AstNode = serde_json::from_str(json).unwrap(); + let leaf = node.leaf().unwrap(); + assert_eq!(leaf.timestamp().to_rfc3339(), + "2026-04-16T12:00:00+00:00"); + } } diff --git a/src/agent/mod.rs b/src/agent/mod.rs index db1bf39..5368db6 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -139,7 +139,6 @@ impl DispatchState { pub struct Agent { pub client: ApiClient, pub app_config: crate::config::AppConfig, - pub prompt_file: String, pub session_id: String, pub context: crate::Mutex, pub state: crate::Mutex, @@ -189,7 +188,6 @@ impl Agent { client: ApiClient, personality: Vec<(String, String)>, app_config: crate::config::AppConfig, - prompt_file: String, conversation_log: Option, active_tools: tools::ActiveTools, agent_tools: Vec, @@ -220,7 +218,6 @@ impl Agent { let agent = Arc::new(Self { client, app_config, - prompt_file, session_id, context: crate::Mutex::new(context), state: crate::Mutex::new(AgentState { @@ -259,7 +256,6 @@ impl Agent { Arc::new(Self { client: self.client.clone(), app_config: self.app_config.clone(), - prompt_file: self.prompt_file.clone(), session_id: self.session_id.clone(), context: crate::Mutex::new(ctx), state: crate::Mutex::new(AgentState { diff --git a/src/agent/oneshot.rs b/src/agent/oneshot.rs index 2fce906..8bc8b53 100644 --- a/src/agent/oneshot.rs +++ b/src/agent/oneshot.rs @@ -183,8 +183,8 @@ fn resolve_prompt( state: 
&std::collections::BTreeMap, recently_written: &[String], ) -> String { - let cfg = crate::config::get(); - let template = template.replace("{assistant_name}", &cfg.assistant_name); + let template = template.replace("{assistant_name}", + &crate::config::app().assistant_name); let mut result = String::with_capacity(template.len()); let mut rest = template.as_str(); while let Some(start) = rest.find("{{") { @@ -247,25 +247,20 @@ impl AutoAgent { &mut self, bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>, ) -> Result<(), String> { - let config = crate::config::get(); - let base_url = config.api_base_url.as_deref().unwrap_or(""); - let api_key = config.api_key.as_deref().unwrap_or(""); - let model = config.api_model.as_deref().unwrap_or(""); - if base_url.is_empty() || model.is_empty() { - return Err("API not configured (no base_url or model)".to_string()); - } - let client = super::api::ApiClient::new(base_url, api_key, model); - - // Load system prompt + identity from config + // Load system prompt + identity from config. let cli = crate::user::CliArgs::default(); let (app, _) = crate::config::load_app(&cli) .map_err(|e| format!("config: {}", e))?; + let resolved = app.resolve_model(&app.default_backend) + .map_err(|e| format!("API not configured: {}", e))?; + let client = super::api::ApiClient::new( + &resolved.api_base, &resolved.api_key, &resolved.model_id); let personality = crate::config::reload_context() .await.map_err(|e| format!("config: {}", e))?; let agent = Agent::new( client, personality, - app, String::new(), + app, None, super::tools::ActiveTools::new(), super::tools::tools(), @@ -497,15 +492,20 @@ pub async fn run_one_agent( .map(|s| s.phase.clone()).collect(); // Bail check: if the agent defines a bail script, run it between steps. + // The script also refreshes our pid-file with the current phase — that's + // how concurrent agents know which phase each of us is in. 
let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name)); let state_dir_for_bail = state_dir.clone(); - // Find our own pid file so we can pass it to the bail script let our_pid = std::process::id(); let our_pid_file = format!("pid-{}", our_pid); + let step_phases_for_bail = step_phases.clone(); let bail_fn = move |step_idx: usize| -> Result<(), String> { if let Some(ref script) = bail_script { + let phase = step_phases_for_bail.get(step_idx) + .map(String::as_str).unwrap_or(""); let status = std::process::Command::new(script) .arg(&our_pid_file) + .arg(phase) .current_dir(&state_dir_for_bail) .status() .map_err(|e| format!("bail script {:?} failed: {}", script, e))?; diff --git a/src/bin/fix-timestamps.rs b/src/bin/fix-timestamps.rs new file mode 100644 index 0000000..31a8788 --- /dev/null +++ b/src/bin/fix-timestamps.rs @@ -0,0 +1,180 @@ +// fix-timestamps: One-off migration for ~/.consciousness/agent-sessions/ +// conversation.jsonl. +// +// Before Branch nodes carried their own timestamps, early entries were +// serialized with missing/null timestamp fields — they deserialize as +// UNIX_EPOCH via the (now-to-be-removed) deserialize_timestamp_or_epoch +// fallback. Training needs every entry to have a unique timestamp to +// dedup already-trained responses. +// +// Walks the file, synthesizes timestamps for any entry stuck at +// UNIX_EPOCH by linear interpolation between surrounding real +// timestamps. For child leaves inside a Branch, derives timestamps +// from the parent with a tiny per-child offset. +// +// SAFETY: reads from argv[1], writes to argv[1].tmp, renames into +// place. Keep a .bak copy before running. 
+// +// Usage: fix-timestamps + +use std::io::{BufRead, BufReader, BufWriter, Write}; +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use chrono::{DateTime, Duration, Utc}; + +use consciousness::agent::context::AstNode; + +fn main() -> Result<()> { + let path: PathBuf = std::env::args().nth(1) + .context("usage: fix-timestamps ")?.into(); + + let f = std::fs::File::open(&path) + .with_context(|| format!("open {}", path.display()))?; + let reader = BufReader::new(f); + + let mut nodes: Vec = Vec::new(); + for (i, line) in reader.lines().enumerate() { + let line = line?; + if line.trim().is_empty() { continue; } + let node: AstNode = serde_json::from_str(&line) + .with_context(|| format!("line {}: parse", i + 1))?; + nodes.push(node); + } + println!("read {} entries", nodes.len()); + + fix_top_level_timestamps(&mut nodes); + for node in &mut nodes { + propagate_to_children(node); + } + + // Ensure uniqueness — real timestamps can collide when two entries + // were written in the same ns; synthesized ones can also overlap. + // Bump colliding ns by 1 until unique. 
+ let mut seen = std::collections::HashSet::new(); + let mut bumps = 0usize; + for (i, node) in nodes.iter_mut().enumerate() { + let ts = top_ts(node); + assert!(ts > DateTime::::UNIX_EPOCH, + "entry {}: still UNIX_EPOCH", i); + let mut ns = ts.timestamp_nanos_opt().expect("ts in i64 ns range"); + let mut bumped = false; + while !seen.insert(ns) { + ns += 1; + bumped = true; + bumps += 1; + } + if bumped { + set_top_ts(node, DateTime::::from_timestamp_nanos(ns)); + } + } + println!("all {} timestamps real and unique ({} ns bumps)", + nodes.len(), bumps); + + let tmp = path.with_extension("jsonl.tmp"); + { + let f = std::fs::File::create(&tmp) + .with_context(|| format!("create {}", tmp.display()))?; + let mut w = BufWriter::new(f); + for node in &nodes { + serde_json::to_writer(&mut w, node)?; + w.write_all(b"\n")?; + } + w.flush()?; + } + std::fs::rename(&tmp, &path) + .with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?; + println!("wrote {}", path.display()); + + Ok(()) +} + +fn top_ts(node: &AstNode) -> DateTime { + match node { + AstNode::Leaf(leaf) => leaf.timestamp(), + AstNode::Branch { timestamp, .. } => *timestamp, + } +} + +fn set_top_ts(node: &mut AstNode, ts: DateTime) { + match node { + AstNode::Leaf(leaf) => *leaf = leaf.clone().with_timestamp(ts), + AstNode::Branch { timestamp, .. } => *timestamp = ts, + } +} + +/// Fill in missing top-level timestamps. Strategy: +/// - If two real timestamps bracket a run of missing ones, linearly +/// interpolate between them. +/// - If missing ones precede the first real one, back-fill using +/// (first_real - N·1µs). +/// - If missing ones follow the last real one, forward-fill. +/// - If no real timestamps exist at all, synthesize from now() going +/// backwards. 
+fn fix_top_level_timestamps(nodes: &mut [AstNode]) { + let real: Vec<(usize, DateTime)> = nodes.iter().enumerate() + .filter(|(_, n)| top_ts(n) > DateTime::::UNIX_EPOCH) + .map(|(i, n)| (i, top_ts(n))) + .collect(); + + if real.is_empty() { + let now = Utc::now(); + let len = nodes.len(); + for (i, node) in nodes.iter_mut().enumerate() { + let ts = now - Duration::microseconds((len - i) as i64); + set_top_ts(node, ts); + } + return; + } + + // Helper: bisect real[] for the nearest real entries around idx. + let find_bracket = |idx: usize| -> (Option<(usize, DateTime)>, + Option<(usize, DateTime)>) { + let pos = real.binary_search_by_key(&idx, |(i, _)| *i); + let (prior_pos, next_pos) = match pos { + Ok(p) => (Some(p), Some(p)), + Err(p) => ( + if p == 0 { None } else { Some(p - 1) }, + if p >= real.len() { None } else { Some(p) }, + ), + }; + (prior_pos.map(|p| real[p]), next_pos.map(|p| real[p])) + }; + + for i in 0..nodes.len() { + if top_ts(&nodes[i]) > DateTime::::UNIX_EPOCH { + continue; + } + let (prior, next) = find_bracket(i); + let new_ts = match (prior, next) { + (Some((pi, pt)), Some((ni, nt))) if pi != ni => { + // Linear interpolate. + let span_ns = (nt - pt).num_nanoseconds().unwrap_or(0); + let offset_ns = span_ns * (i - pi) as i64 / (ni - pi) as i64; + pt + Duration::nanoseconds(offset_ns) + } + (Some((pi, pt)), _) => { + pt + Duration::microseconds((i - pi) as i64) + } + (None, Some((ni, nt))) => { + nt - Duration::microseconds((ni - i) as i64) + } + (None, None) => unreachable!(), + }; + set_top_ts(&mut nodes[i], new_ts); + } +} + +/// For every Branch, ensure each child Leaf has a timestamp. If missing, +/// use parent.ts + child_idx·1ns so siblings stay unique but close. +fn propagate_to_children(node: &mut AstNode) { + if let AstNode::Branch { timestamp, children, .. 
} = node { + let parent_ts = *timestamp; + for (ci, child) in children.iter_mut().enumerate() { + if top_ts(child) <= DateTime::::UNIX_EPOCH { + set_top_ts(child, parent_ts + Duration::nanoseconds(ci as i64)); + } + propagate_to_children(child); + } + } +} diff --git a/src/cli/node.rs b/src/cli/node.rs index 5472505..c4305a7 100644 --- a/src/cli/node.rs +++ b/src/cli/node.rs @@ -197,7 +197,7 @@ pub async fn cmd_load_context(stats: bool) -> Result<()> { return Ok(()); } - println!("=== MEMORY SYSTEM ({}) ===", cfg.assistant_name); + println!("=== MEMORY SYSTEM ({}) ===", crate::config::app().assistant_name); if !personality.is_empty() { println!("--- personality_nodes ({}) ---", personality.len()); diff --git a/src/config.rs b/src/config.rs index 9f9ad9a..b7ea597 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,9 +3,6 @@ // Single config file: ~/.consciousness/config.json5 // Memory settings in the "memory" section (Config) // Agent/backend settings at top level (AppConfig) -// -// Legacy fallback: ~/.consciousness/config.jsonl -// Env override: POC_MEMORY_CONFIG use std::collections::HashMap; use std::path::PathBuf; @@ -29,9 +26,7 @@ pub fn config_path() -> PathBuf { static CONFIG: OnceLock>> = OnceLock::new(); -fn default_context_window() -> usize { 128_000 } fn default_stream_timeout() -> u64 { 60 } -fn default_scoring_chunk_tokens() -> usize { 50_000 } fn default_scoring_interval_secs() -> u64 { 3600 } // 1 hour fn default_scoring_response_window() -> usize { 100 } fn default_node_weight() -> f64 { 0.7 } @@ -45,8 +40,6 @@ fn default_identity_dir() -> PathBuf { #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { - pub user_name: String, - pub assistant_name: String, #[serde(deserialize_with = "deserialize_path")] pub data_dir: PathBuf, #[serde(default = "default_identity_dir", deserialize_with = "deserialize_path")] @@ -62,51 +55,24 @@ pub struct Config { /// Nodes loaded into subconscious agent context #[serde(default)] pub 
agent_nodes: Vec, - pub journal_days: u32, - pub journal_max: usize, pub llm_concurrency: usize, - pub agent_budget: usize, - #[serde(deserialize_with = "deserialize_path")] - pub prompts_dir: PathBuf, - /// Resolved from agent_model → models → backend (not in config directly) - #[serde(skip)] - pub api_base_url: Option, - #[serde(skip)] - pub api_key: Option, - #[serde(skip)] - pub api_model: Option, - #[serde(skip, default = "default_context_window")] - pub api_context_window: usize, - /// Used to resolve API settings, not stored on Config - #[serde(default)] - agent_model: Option, /// Stream chunk timeout in seconds (no data = timeout). #[serde(default = "default_stream_timeout")] pub api_stream_timeout_secs: u64, - /// Max tokens per chunk for memory scoring logprobs calls. - #[serde(default = "default_scoring_chunk_tokens")] - pub scoring_chunk_tokens: usize, /// How often to re-score memory nodes (seconds). Default: 3600 (1 hour). #[serde(default = "default_scoring_interval_secs")] pub scoring_interval_secs: u64, /// Number of assistant responses to score per memory. Default: 50. #[serde(default = "default_scoring_response_window")] pub scoring_response_window: usize, - pub api_reasoning: String, pub agent_types: Vec, #[serde(default)] pub mcp_servers: Vec, #[serde(default)] pub lsp_servers: Vec, - /// Surface agent timeout in seconds. - #[serde(default)] - pub surface_timeout_secs: Option, /// Max conversation bytes to include in surface agent context. #[serde(default)] pub surface_conversation_bytes: Option, - /// Hook events that trigger the surface agent. 
- #[serde(default)] - pub surface_hooks: Vec, // Spreading activation parameters #[serde(default = "default_node_weight")] @@ -123,36 +89,21 @@ impl Default for Config { fn default() -> Self { let home = dirs::home_dir().unwrap_or_default(); Self { - user_name: "User".to_string(), - assistant_name: "Assistant".to_string(), data_dir: home.join(".consciousness/memory"), identity_dir: home.join(".consciousness/identity"), projects_dir: home.join(".claude/projects"), protected_nodes: Vec::new(), personality_nodes: vec!["identity".into(), "core-practices".into()], agent_nodes: vec!["identity".into(), "core-practices".into()], - journal_days: 7, - journal_max: 20, llm_concurrency: 1, - agent_budget: 1000, - prompts_dir: home.join(".consciousness/prompts"), - api_base_url: None, - api_key: None, - api_model: None, - api_context_window: default_context_window(), api_stream_timeout_secs: default_stream_timeout(), - scoring_chunk_tokens: default_scoring_chunk_tokens(), scoring_interval_secs: default_scoring_interval_secs(), scoring_response_window: default_scoring_response_window(), - agent_model: None, - api_reasoning: "high".to_string(), agent_types: vec![ "linker".into(), "organize".into(), "distill".into(), "separator".into(), "split".into(), ], - surface_timeout_secs: None, surface_conversation_bytes: None, - surface_hooks: vec![], mcp_servers: vec![], lsp_servers: vec![], default_node_weight: default_node_weight(), @@ -165,41 +116,20 @@ impl Default for Config { impl Config { fn load_from_file() -> Self { - if let Some(config) = Self::try_load_shared() { - return config; - } - Self::load_legacy_jsonl() + Self::try_load_shared().unwrap_or_default() } /// Load from shared config. Memory settings in the "memory" section; /// API settings resolved from models + backend configuration. 
fn try_load_shared() -> Option { let content = std::fs::read_to_string(config_path()).ok()?; - let root: serde_json::Value = json5::from_str(&content).ok()?; + let root: serde_json::Value = json_five::from_str(&content).ok()?; let mem_value = root.get("memory")?; let mut config: Config = serde_json::from_value(mem_value.clone()).ok()?; config.llm_concurrency = config.llm_concurrency.max(1); - // Resolve API settings: agent_model → models → backend - if let Some(model_name) = &config.agent_model - && let Some(model_cfg) = root.get("models").and_then(|m| m.get(model_name.as_str())) { - let backend_name = model_cfg.get("backend").and_then(|v| v.as_str()).unwrap_or(""); - let model_id = model_cfg.get("model_id").and_then(|v| v.as_str()).unwrap_or(""); - - if let Some(backend) = root.get(backend_name) { - config.api_base_url = backend.get("base_url") - .and_then(|v| v.as_str()).map(String::from); - config.api_key = backend.get("api_key") - .and_then(|v| v.as_str()).map(String::from); - } - config.api_model = Some(model_id.to_string()); - if let Some(cw) = model_cfg.get("context_window").and_then(|v| v.as_u64()) { - config.api_context_window = cw as usize; - } - } - - // Top-level config sections (not inside "memory") + // Top-level sections (not inside "memory"). if let Some(servers) = root.get("lsp_servers") { config.lsp_servers = serde_json::from_value(servers.clone()).unwrap_or_default(); } @@ -209,11 +139,6 @@ impl Config { Some(config) } - - /// Load from legacy JSONL config — deprecated, just return defaults. - fn load_legacy_jsonl() -> Self { - Config::default() - } } /// Get the global memory config (cheap Arc clone). @@ -237,27 +162,85 @@ pub fn reload() -> bool { changed } +/// Spawn a background thread that watches `~/.consciousness/config.json5` +/// and reloads both the memory Config and the global AppConfig whenever +/// the file changes on disk. Lets edits from vim / F6 hotkeys / manual +/// tweaks land live without restarting the process. 
+pub fn watch_config(cli: crate::user::CliArgs) { + use notify_debouncer_mini::{new_debouncer, notify::RecursiveMode}; + + let path = config_path(); + // Watch the parent directory — editors often replace-via-rename, so + // watching the file itself misses the new inode. + let Some(parent) = path.parent().map(|p| p.to_path_buf()) else { + crate::dbglog!("[config] no parent for {}, skipping watch", path.display()); + return; + }; + + std::thread::Builder::new() + .name("config-watcher".into()) + .spawn(move || { + let (tx, rx) = std::sync::mpsc::channel(); + let mut debouncer = match new_debouncer(std::time::Duration::from_millis(200), tx) { + Ok(d) => d, + Err(e) => { + crate::dbglog!("[config] watcher setup failed: {}", e); + return; + } + }; + if let Err(e) = debouncer.watcher() + .watch(&parent, RecursiveMode::NonRecursive) + { + crate::dbglog!("[config] watch({}) failed: {}", parent.display(), e); + return; + } + crate::dbglog!("[config] watching {}", path.display()); + + while let Ok(res) = rx.recv() { + let Ok(events) = res else { continue; }; + if !events.iter().any(|e| e.path == path) { continue; } + + // Reload both halves. 
+ let mem_changed = reload(); + let app_changed = match build_figment(&cli).extract::() { + Ok(app) => { + install_app(app); + true + } + Err(e) => { + crate::dbglog!("[config] reload: AppConfig parse failed: {}", e); + false + } + }; + crate::dbglog!("[config] reloaded (memory_changed={}, app_changed={})", + mem_changed, app_changed); + } + }) + .ok(); +} + // ============================================================ // Agent config (top-level settings) // ============================================================ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AppConfig { - pub backend: String, - pub anthropic: BackendConfig, - pub openrouter: BackendConfig, + #[serde(default = "default_user_name")] + pub user_name: String, + #[serde(default = "default_assistant_name")] + pub assistant_name: String, + /// Named model endpoints — credentials, base URL, and model id bundled + /// into one entry per backend. Keyed by name, selected by + /// `default_backend` or by `--model ` on the CLI. #[serde(default)] - pub deepinfra: BackendConfig, - pub prompts: PromptConfig, + pub backends: HashMap, + #[serde(default)] + pub default_backend: String, pub debug: bool, pub compaction: CompactionConfig, pub dmn: DmnConfig, - #[serde(skip_serializing_if = "Option::is_none")] - pub memory_project: Option, #[serde(default)] - pub models: HashMap, - #[serde(default = "default_model_name")] - pub default_model: String, + pub learn: LearnConfig, #[serde(default)] pub mcp_servers: Vec, #[serde(default)] @@ -284,32 +267,17 @@ pub struct LspServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct BackendConfig { + /// API key for the backend. #[serde(default)] pub api_key: String, - #[serde(default)] - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] + /// Base URL for the backend's OpenAI-compatible endpoint. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] pub base_url: Option, -} - -impl BackendConfig { - fn resolve(&self, default_base: &str) -> Result<(String, String, String)> { - if self.api_key.is_empty() { - anyhow::bail!( - "No API key. Set it in {} or use --api-key", - config_path().display() - ); - } - let base = self.base_url.clone() - .unwrap_or_else(|| default_base.to_string()); - Ok((base, self.api_key.clone(), self.model.clone())) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PromptConfig { - pub anthropic: String, - pub other: String, + /// Model identifier sent to the API. + pub model_id: String, + /// Context window size in tokens. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub context_window: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -324,65 +292,57 @@ pub struct DmnConfig { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelConfig { - /// Backend name ("anthropic" or "openrouter") - pub backend: String, - /// Model identifier sent to the API - pub model_id: String, - /// Instruction file ("CLAUDE.md" or "POC.md"). +pub struct LearnConfig { + /// Divergence threshold — responses scoring above this become + /// fine-tuning candidates. Lower = more sensitive. + #[serde(default = "default_learn_threshold")] + pub threshold: f64, + /// Whether to generate "what would the model have said without + /// memories" alternates alongside each scoring run. Expensive — + /// one full streaming generation per candidate. #[serde(default)] - pub prompt_file: Option, - /// Context window size in tokens. 
- #[serde(default)] - pub context_window: Option, + pub generate_alternates: bool, } +fn default_learn_threshold() -> f64 { 1.0 } + +impl Default for LearnConfig { + fn default() -> Self { + Self { + threshold: default_learn_threshold(), + generate_alternates: false, + } + } +} + +fn default_user_name() -> String { "User".into() } +fn default_assistant_name() -> String { "Assistant".into() } + impl Default for AppConfig { fn default() -> Self { Self { - backend: "openrouter".to_string(), - anthropic: BackendConfig { - api_key: String::new(), - model: "claude-opus-4-6-20250918".to_string(), - base_url: None, - }, - openrouter: BackendConfig { - api_key: String::new(), - model: "qwen/qwen3.5-397b-a17b".to_string(), - base_url: Some("https://openrouter.ai/api/v1".to_string()), - }, - deepinfra: BackendConfig { - api_key: String::new(), - model: String::new(), - base_url: Some("https://api.deepinfra.com/v1/openai".to_string()), - }, - prompts: PromptConfig { - anthropic: "CLAUDE.md".to_string(), - other: "POC.md".to_string(), - }, + user_name: default_user_name(), + assistant_name: default_assistant_name(), + backends: HashMap::new(), + default_backend: String::new(), debug: false, compaction: CompactionConfig { hard_threshold_pct: 90, soft_threshold_pct: 80, }, dmn: DmnConfig { max_turns: 20 }, - memory_project: None, - models: HashMap::new(), - default_model: String::new(), + learn: LearnConfig::default(), mcp_servers: Vec::new(), lsp_servers: Vec::new(), } } } -fn default_model_name() -> String { String::new() } - /// Resolved, ready-to-use agent session config. pub struct SessionConfig { pub api_base: String, pub api_key: String, pub model: String, - pub prompt_file: String, /// Identity/personality nodes as (name, content) pairs. 
pub context_parts: Vec<(String, String)>, pub session_dir: PathBuf, @@ -398,37 +358,22 @@ pub struct ResolvedModel { pub api_base: String, pub api_key: String, pub model_id: String, - pub prompt_file: String, pub context_window: Option, } impl AppConfig { /// Resolve the active backend and assemble prompts into a SessionConfig. pub async fn resolve(&self, cli: &crate::user::CliArgs) -> Result { - let (api_base, api_key, model, prompt_file); - - if !self.models.is_empty() { - let model_name = cli.model.as_deref().unwrap_or(&self.default_model); - let resolved = self.resolve_model(model_name)?; - api_base = resolved.api_base; - api_key = resolved.api_key; - model = resolved.model_id; - prompt_file = resolved.prompt_file; - } else { - let (base, key, mdl) = match self.backend.as_str() { - "anthropic" => self.anthropic.resolve("https://api.anthropic.com"), - _ => self.openrouter.resolve("https://openrouter.ai/api/v1"), - }?; - api_base = base; - api_key = key; - model = mdl; - prompt_file = if self.backend == "anthropic" { - self.prompts.anthropic.clone() - } else { - self.prompts.other.clone() - }; + if self.backends.is_empty() { + anyhow::bail!( + "no backends configured in {}. 
Add a `backends` section with at least one entry.", + config_path().display() + ); } + let name = cli.model.as_deref().unwrap_or(&self.default_backend); + let resolved = self.resolve_model(name)?; + let personality_nodes = get().personality_nodes.clone(); let context_parts = crate::mind::identity::personality_nodes(&personality_nodes).await; @@ -438,11 +383,13 @@ impl AppConfig { std::fs::create_dir_all(&session_dir).ok(); // CLI --api-base and --api-key override everything - let api_base = cli.api_base.clone().unwrap_or(api_base); - let api_key = cli.api_key.clone().unwrap_or(api_key); + let api_base = cli.api_base.clone().unwrap_or(resolved.api_base); + let api_key = cli.api_key.clone().unwrap_or(resolved.api_key); Ok(SessionConfig { - api_base, api_key, model, prompt_file, + api_base, + api_key, + model: resolved.model_id, context_parts, session_dir, app: self.clone(), @@ -450,55 +397,33 @@ impl AppConfig { }) } - /// Look up a named model and resolve its credentials from the backend config. + /// Look up a named backend and resolve its credentials. pub fn resolve_model(&self, name: &str) -> Result { - let model = self.models.get(name) + let b = self.backends.get(name) .ok_or_else(|| anyhow::anyhow!( - "Unknown model '{}'. Available: {}", + "Unknown backend '{}'. 
Available: {}", name, self.model_names().join(", "), ))?; - let (api_base, api_key) = match model.backend.as_str() { - "anthropic" => ( - self.anthropic.base_url.clone() - .unwrap_or_else(|| "https://api.anthropic.com".to_string()), - self.anthropic.api_key.clone(), - ), - "deepinfra" => ( - self.deepinfra.base_url.clone() - .unwrap_or_else(|| "https://api.deepinfra.com/v1/openai".to_string()), - self.deepinfra.api_key.clone(), - ), - _ => ( - self.openrouter.base_url.clone() - .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()), - self.openrouter.api_key.clone(), - ), - }; - - let prompt_file = model.prompt_file.clone() - .unwrap_or_else(|| { - if model.backend == "anthropic" { - self.prompts.anthropic.clone() - } else { - self.prompts.other.clone() - } - }); + let api_base = b.base_url.clone() + .ok_or_else(|| anyhow::anyhow!( + "backends.{}.base_url not set in {}", + name, config_path().display() + ))?; Ok(ResolvedModel { name: name.to_string(), api_base, - api_key, - model_id: model.model_id.clone(), - prompt_file, - context_window: model.context_window, + api_key: b.api_key.clone(), + model_id: b.model_id.clone(), + context_window: b.context_window, }) } - /// List available model names, sorted. + /// List available backend names, sorted. 
pub fn model_names(&self) -> Vec { - let mut names: Vec<_> = self.models.keys().cloned().collect(); + let mut names: Vec<_> = self.backends.keys().cloned().collect(); names.sort(); names } @@ -518,7 +443,7 @@ impl Provider for Json5File { fn data(&self) -> figment::Result> { match std::fs::read_to_string(&self.0) { Ok(content) => { - let value: figment::value::Value = json5::from_str(&content) + let value: figment::value::Value = json_five::from_str(&content) .map_err(|e| figment::Error::from(format!("{}: {}", self.0.display(), e)))?; Serialized::defaults(value).data() } @@ -540,11 +465,6 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment { let mut f = Figment::from(Serialized::defaults(AppConfig::default())) .merge(Json5File(config_path())); - merge_opt!(f, cli.backend, "backend"); - merge_opt!(f, cli.model, "anthropic.model", "openrouter.model"); - merge_opt!(f, cli.api_key, "anthropic.api_key", "openrouter.api_key"); - merge_opt!(f, cli.api_base, "anthropic.base_url", "openrouter.base_url"); - merge_opt!(f, cli.memory_project, "memory_project"); merge_opt!(f, cli.dmn_max_turns, "dmn.max_turns"); if cli.debug { f = f.merge(Serialized::default("debug", true)); @@ -554,12 +474,46 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment { } /// Load just the AppConfig — no validation, no prompt assembly. +/// Also installs the loaded AppConfig into the global cache so +/// `config::app()` is available everywhere. pub fn load_app(cli: &crate::user::CliArgs) -> Result<(AppConfig, Figment)> { let figment = build_figment(cli); let app: AppConfig = figment.extract().context("Failed to load configuration")?; + install_app(app.clone()); Ok((app, figment)) } +// ============================================================ +// Global AppConfig cache (writable, for runtime-mutable settings +// like learn.threshold that F6 edits via config_writer). 
+// ============================================================ + +static APP_CONFIG: OnceLock> = OnceLock::new(); + +fn install_app(app: AppConfig) { + let slot = APP_CONFIG.get_or_init(|| RwLock::new(app.clone())); + *slot.write().unwrap() = app; +} + +/// Current AppConfig, held under a read lock. Reads should be brief +/// (no holding across await / long work) to avoid starving writers. +/// Panics if called before load_app — which runs once at startup. +pub fn app() -> std::sync::RwLockReadGuard<'static, AppConfig> { + APP_CONFIG + .get() + .expect("config::app() called before load_app()") + .read() + .unwrap() +} + +/// Mutate the cached AppConfig in place. Used by config_writer to keep +/// the in-memory view in sync with disk after surgical edits to +/// ~/.consciousness/config.json5. +pub fn update_app(f: impl FnOnce(&mut AppConfig)) { + let slot = APP_CONFIG.get().expect("update_app before load_app"); + f(&mut *slot.write().unwrap()); +} + /// Load the full config: figment → AppConfig → resolve backend → assemble prompts. 
pub async fn load_session(cli: &crate::user::CliArgs) -> Result<(SessionConfig, Figment)> { let (app, figment) = load_app(cli)?; @@ -585,38 +539,28 @@ pub fn show_config(app: &AppConfig, figment: &Figment) { } println!("# Effective configuration\n"); - println!("backend: {:?} ({})", app.backend, src(figment, "backend")); - for (name, b) in [("anthropic", &app.anthropic), ("openrouter", &app.openrouter)] { - println!("\n{}:", name); - println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("{name}.api_key"))); - println!(" model: {:?} ({})", b.model, src(figment, &format!("{name}.model"))); - if let Some(ref url) = b.base_url { - println!(" base_url: {:?} ({})", url, src(figment, &format!("{name}.base_url"))); - } - } - println!("\nprompts:"); - println!(" anthropic: {:?} ({})", app.prompts.anthropic, src(figment, "prompts.anthropic")); - println!(" other: {:?} ({})", app.prompts.other, src(figment, "prompts.other")); + println!("user_name: {:?} ({})", app.user_name, src(figment, "user_name")); + println!("assistant_name: {:?} ({})", app.assistant_name, src(figment, "assistant_name")); println!("\ndebug: {} ({})", app.debug, src(figment, "debug")); println!("\ncompaction:"); println!(" hard_threshold_pct: {} ({})", app.compaction.hard_threshold_pct, src(figment, "compaction.hard_threshold_pct")); println!(" soft_threshold_pct: {} ({})", app.compaction.soft_threshold_pct, src(figment, "compaction.soft_threshold_pct")); println!("\ndmn:"); println!(" max_turns: {} ({})", app.dmn.max_turns, src(figment, "dmn.max_turns")); - if let Some(ref p) = app.memory_project { - println!("\nmemory_project: {:?} ({})", p, src(figment, "memory_project")); - } - println!("\ndefault_model: {:?}", app.default_model); - if !app.models.is_empty() { - println!("\nmodels:"); - for (name, m) in &app.models { + println!("\ndefault_backend: {:?} ({})", app.default_backend, src(figment, "default_backend")); + if !app.backends.is_empty() { + println!("\nbackends:"); + let mut 
names: Vec<_> = app.backends.keys().cloned().collect(); + names.sort(); + for name in names { + let b = &app.backends[&name]; println!(" {}:", name); - println!(" backend: {:?}", m.backend); - println!(" model_id: {:?}", m.model_id); - if let Some(ref pf) = m.prompt_file { - println!(" prompt_file: {:?}", pf); + println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("backends.{name}.api_key"))); + if let Some(ref url) = b.base_url { + println!(" base_url: {:?} ({})", url, src(figment, &format!("backends.{name}.base_url"))); } - if let Some(cw) = m.context_window { + println!(" model_id: {:?}", b.model_id); + if let Some(cw) = b.context_window { println!(" context_window: {}", cw); } } diff --git a/src/config_writer.rs b/src/config_writer.rs new file mode 100644 index 0000000..079449f --- /dev/null +++ b/src/config_writer.rs @@ -0,0 +1,448 @@ +// config_writer.rs — Surgical edits to ~/.consciousness/config.json5 +// +// Uses json-five's round-trip parser to mutate specific fields while +// preserving the surrounding comments, whitespace, and formatting. + +use std::path::Path; + +use anyhow::{anyhow, Context as _, Result}; +use json_five::rt::parser::{ + from_str, JSONKeyValuePair, JSONObjectContext, JSONValue, KeyValuePairContext, +}; + +use crate::config::config_path; + +/// Read the config, apply `mutate` to the root JSONValue, write it back atomically. 
+fn edit_config Result<()>>(mutate: F) -> Result<()> { + let path = config_path(); + let src = std::fs::read_to_string(&path) + .with_context(|| format!("read {}", path.display()))?; + + let mut text = from_str(&src) + .map_err(|e| anyhow!("parse {}: {}", path.display(), e))?; + mutate(&mut text.value)?; + + write_atomic(&path, &text.to_string()) +} + +fn write_atomic(path: &Path, content: &str) -> Result<()> { + let parent = path.parent() + .ok_or_else(|| anyhow!("config path has no parent: {}", path.display()))?; + let tmp = parent.join(format!( + ".{}.tmp", + path.file_name().unwrap_or_default().to_string_lossy(), + )); + std::fs::write(&tmp, content) + .with_context(|| format!("write {}", tmp.display()))?; + std::fs::rename(&tmp, path) + .with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?; + Ok(()) +} + +/// Match a key JSONValue against a string name. JSON5 allows keys to be +/// unquoted identifiers or single/double-quoted strings. +fn key_matches(key: &JSONValue, name: &str) -> bool { + match key { + JSONValue::Identifier(s) + | JSONValue::DoubleQuotedString(s) + | JSONValue::SingleQuotedString(s) => s == name, + _ => false, + } +} + +/// Find (or create) a child object under `parent`, returning a mutable borrow +/// of its key_value_pairs vector. +/// Append a new kvp to `object`, setting whitespace so the output is +/// multi-line with the given indentation: +/// +/// ```text +/// {first_key: first_val,} +/// ``` +/// +/// If `object` already has kvps, the separator between the last one and +/// ours goes in the prior kvp's wsc.3. If we're the first kvp, the +/// lead-in after `{` goes in the object's own wsc.0. 
+fn append_kvp_pretty( + object: &mut JSONValue, + key: JSONValue, + value: JSONValue, + inner_indent: &str, + outer_indent: &str, +) -> Result<()> { + let (pairs, ctx) = match object { + JSONValue::JSONObject { key_value_pairs, context } => { + let ctx = context.get_or_insert_with(|| JSONObjectContext { + wsc: (String::new(),), + }); + (key_value_pairs, ctx) + } + _ => return Err(anyhow!("not an object")), + }; + + if pairs.is_empty() { + ctx.wsc.0 = format!("\n{}", inner_indent); + } else { + let prev = pairs.last_mut().unwrap(); + let prev_ctx = prev.context.get_or_insert_with(|| KeyValuePairContext { + wsc: (String::new(), String::from(" "), String::new(), None), + }); + prev_ctx.wsc.3 = Some(format!("\n{}", inner_indent)); + } + + pairs.push(JSONKeyValuePair { + key, + value, + context: Some(KeyValuePairContext { + wsc: ( + String::new(), + String::from(" "), + String::new(), + Some(format!("\n{}", outer_indent)), + ), + }), + }); + + Ok(()) +} + +/// Find or create a child object under `parent`. Returns the index of +/// the kvp in parent's key_value_pairs so the caller can re-borrow +/// afterward. +fn get_or_create_object_idx( + parent: &mut JSONValue, + section: &str, + inner_indent: &str, + outer_indent: &str, +) -> Result { + let existing = match parent { + JSONValue::JSONObject { key_value_pairs, .. } => { + key_value_pairs.iter() + .position(|kvp| key_matches(&kvp.key, section)) + } + _ => return Err(anyhow!("config root is not an object")), + }; + + if let Some(i) = existing { + return Ok(i); + } + + append_kvp_pretty( + parent, + JSONValue::Identifier(section.to_string()), + JSONValue::JSONObject { + key_value_pairs: Vec::new(), + context: Some(JSONObjectContext { wsc: (String::new(),) }), + }, + inner_indent, + outer_indent, + )?; + + match parent { + JSONValue::JSONObject { key_value_pairs, .. } => Ok(key_value_pairs.len() - 1), + _ => unreachable!(), + } +} + +/// Set `section.key` to a literal scalar value (e.g., "1e-7", "42", "true"). 
+/// The literal is parsed as JSON5 so we preserve its source-form on round-trip. +pub fn set_scalar(section: &str, key: &str, literal: &str) -> Result<()> { + let value = parse_scalar_literal(literal)?; + edit_config(|root| { + // New top-level sections sit at column 4 (inside root `{`), + // and the root's closing `}` sits at column 0. + let section_idx = get_or_create_object_idx(root, section, " ", "")?; + + let section_value = match root { + JSONValue::JSONObject { key_value_pairs, .. } => { + &mut key_value_pairs[section_idx].value + } + _ => unreachable!(), + }; + + // Update in place if the key already exists. + if let JSONValue::JSONObject { key_value_pairs, .. } = section_value { + if let Some(kvp) = key_value_pairs.iter_mut() + .find(|k| key_matches(&k.key, key)) + { + kvp.value = value; + return Ok(()); + } + } + + // Append a new kvp. Inner keys sit at column 8, the section's + // closing `}` sits at column 4. + append_kvp_pretty( + section_value, + JSONValue::Identifier(key.to_string()), + value, + " ", + " ", + ) + }) +} + +/// Parse a scalar literal by round-tripping it through json-five. Keeps us +/// consistent with whatever scalars the library considers valid (hex, +/// exponents, Infinity, etc.). +fn parse_scalar_literal(literal: &str) -> Result { + let text = from_str(literal) + .map_err(|e| anyhow!("parse literal {:?}: {}", literal, e))?; + match text.value { + JSONValue::JSONObject { .. } | JSONValue::JSONArray { .. } => { + Err(anyhow!("set_scalar only accepts scalar literals, got {:?}", literal)) + } + v => Ok(v), + } +} + +/// Convenience: set `learn.threshold` to the given f64. +pub fn set_learn_threshold(value: f64) -> Result<()> { + // {:e} gives the minimal scientific notation that preserves the value. + set_scalar("learn", "threshold", &format!("{:e}", value))?; + crate::config::update_app(|app| app.learn.threshold = value); + Ok(()) +} + +/// Convenience: set `learn.generate_alternates` to the given bool. 
+pub fn set_learn_generate_alternates(value: bool) -> Result<()> { + set_scalar("learn", "generate_alternates", + if value { "true" } else { "false" })?; + crate::config::update_app(|app| app.learn.generate_alternates = value); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + // In-memory variant of set_scalar — used to test the mutation logic + // without touching disk. + fn set_scalar_inline( + root: &mut JSONValue, + section: &str, + key: &str, + literal: &str, + ) -> Result<()> { + let value = parse_scalar_literal(literal)?; + let section_idx = get_or_create_object_idx(root, section, " ", "")?; + let section_value = match root { + JSONValue::JSONObject { key_value_pairs, .. } => { + &mut key_value_pairs[section_idx].value + } + _ => unreachable!(), + }; + if let JSONValue::JSONObject { key_value_pairs, .. } = section_value { + if let Some(kvp) = key_value_pairs.iter_mut() + .find(|k| key_matches(&k.key, key)) + { + kvp.value = value; + return Ok(()); + } + } + append_kvp_pretty( + section_value, + JSONValue::Identifier(key.to_string()), + value, + " ", + " ", + ) + } + + fn edit_str Result<()>>(src: &str, f: F) -> Result { + let mut text = from_str(src).map_err(|e| anyhow!("{}", e))?; + f(&mut text.value)?; + Ok(text.to_string()) + } + + #[test] + fn replaces_existing_scalar() { + let src = r#"{ + // threshold for learning + learn: { + threshold: 0.001, // the old value + }, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "1e-7") + }).unwrap(); + assert!(out.contains("1e-7"), "output: {}", out); + assert!(out.contains("// threshold for learning")); + assert!(out.contains("// the old value")); + assert!(!out.contains("0.001")); + } + + #[test] + fn creates_missing_section() { + let src = r#"{ + // comment + memory: { user_name: "Kent" }, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "1e-7") + }).unwrap(); + assert!(out.contains("learn")); + 
assert!(out.contains("1e-7")); + assert!(out.contains("// comment")); + assert!(out.contains(r#"user_name: "Kent""#)); + } + + #[test] + fn preserves_comments_in_siblings() { + let src = r#"{ + memory: { + // sensitive setting + user_name: "Kent", // name + }, + learn: { + threshold: 0.5, + }, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "1e-9") + }).unwrap(); + assert!(out.contains("// sensitive setting")); + assert!(out.contains("// name")); + assert!(out.contains("1e-9")); + assert!(!out.contains("0.5")); + } + + #[test] + fn adds_key_to_existing_empty_section() { + let src = r#"{ + learn: {}, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "42") + }).unwrap(); + assert!(out.contains("threshold"), "output: {}", out); + assert!(out.contains("42")); + } + + #[test] + fn realistic_config_adds_learn_section() { + // Mirrors the shape of ~/.consciousness/config.json5 — multiple + // sections, comments, mixed tab/space indent, trailing commas. + let src = r#"{ + deepinfra: { + api_key: "bcachefs-agents-2026", + base_url: "http://example/v1", + }, + + // Named models + models: { + "27b": { + backend: "deepinfra", + model_id: "Qwen/Qwen3.5-27B", + }, + }, + + default_model: "27b", + + memory: { + user_name: "Kent", + // Active agent types + agent_types: ["linker", "organize"], + }, + + compaction: { + hard_threshold_pct: 90, + }, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "1e-7") + }).unwrap(); + + // Core assertions: comments and sibling sections survive. + assert!(out.contains(r#"api_key: "bcachefs-agents-2026""#)); + assert!(out.contains("// Named models")); + assert!(out.contains("// Active agent types")); + assert!(out.contains(r#"user_name: "Kent""#)); + assert!(out.contains("hard_threshold_pct: 90")); + + // New section added. 
+ assert!(out.contains("learn")); + assert!(out.contains("1e-7")); + + // Parse result should parse back without error (real json5 parser). + let reparsed: serde_json::Value = json_five::from_str(&out) + .expect("mutated output must be valid JSON5"); + let threshold = reparsed.pointer("/learn/threshold").expect("learn.threshold exists"); + assert_eq!(threshold.as_f64(), Some(1e-7)); + } + + #[test] + fn realistic_config_updates_existing_threshold() { + let src = r#"{ + learn: { + // The divergence threshold + threshold: 0.001, + }, + memory: { user_name: "Kent" }, +}"#; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "threshold", "5e-8") + }).unwrap(); + assert!(out.contains("5e-8")); + assert!(!out.contains("0.001")); + assert!(out.contains("// The divergence threshold")); + + let reparsed: serde_json::Value = json_five::from_str(&out).unwrap(); + assert_eq!(reparsed.pointer("/learn/threshold").and_then(|v| v.as_f64()), Some(5e-8)); + } + + #[test] + fn new_section_exact_multiline_layout() { + let src = "{\n a: 1,\n}"; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "generate_alternates", "true")?; + set_scalar_inline(root, "learn", "threshold", "1e-7") + }).unwrap(); + + let expected = "\ +{ + a: 1, + learn: { + generate_alternates: true, + threshold: 1e-7, + }, +}"; + assert_eq!(out, expected, "\n--- got ---\n{}\n--- want ---\n{}\n", out, expected); + } + + #[test] + fn new_section_and_key_format_cleanly() { + // The kind of config we actually have in ~/.consciousness + // (top-level sections separated by blank lines, 4-space indent + // for keys within each section). Appending a fresh `learn` + // section with one key should land cleanly, not as + // `learn\n\n :{key\n :value}`. + let src = "{\n memory: {\n user_name: \"Kent\",\n },\n}"; + let out = edit_str(src, |root| { + set_scalar_inline(root, "learn", "generate_alternates", "true") + }).unwrap(); + + // No stray key-to-colon-on-next-line anywhere. 
+ assert!(!out.contains("learn\n"), "learn key wraps: {}", out); + assert!(!out.contains("generate_alternates\n"), + "inner key wraps: {}", out); + + // The output should reparse. + let v: serde_json::Value = json_five::from_str(&out).unwrap(); + assert_eq!( + v.pointer("/learn/generate_alternates").and_then(|x| x.as_bool()), + Some(true), + "output: {}", out, + ); + } + + #[test] + fn roundtrip_stable_without_change() { + let src = r#"{ + // heading + a: 1, + b: { c: 2 }, // inline +}"#; + let text = from_str(src).unwrap(); + assert_eq!(text.to_string(), src); + } +} diff --git a/src/hippocampus/neuro/scoring.rs b/src/hippocampus/neuro/scoring.rs index 5828fd0..c9cbb40 100644 --- a/src/hippocampus/neuro/scoring.rs +++ b/src/hippocampus/neuro/scoring.rs @@ -230,10 +230,6 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio rationale: Vec::new(), }; - // Active agent types from config - let config = crate::config::get(); - let agent_types: Vec<&str> = config.agent_types.iter().map(|s| s.as_str()).collect(); - // Target: α ≥ 2.5 (healthy scale-free) if alpha < 2.0 { plan.add("linker", 100); @@ -274,48 +270,6 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio // Split: handle oversized nodes plan.set("split", 5); - // Distribute agent budget using Elo ratings - let budget = crate::config::get().agent_budget; - let elo_path = crate::config::get().data_dir.join("agent-elo.json"); - if let Ok(elo_json) = std::fs::read_to_string(&elo_path) { - if let Ok(ratings) = serde_json::from_str::>(&elo_json) { - let elos: Vec = agent_types.iter() - .map(|t| ratings.get(*t).copied().unwrap_or(1000.0)) - .collect(); - let min_elo = elos.iter().copied().fold(f64::MAX, f64::min); - - let weights: Vec = elos.iter() - .map(|e| { - let shifted = e - min_elo + 50.0; - shifted * shifted - }) - .collect(); - let total_weight: f64 = weights.iter().sum(); - - let allocate = |w: f64| -> usize { - ((w / total_weight * budget as 
f64).round() as usize).max(2) - }; - - for (i, agent) in agent_types.iter().enumerate() { - plan.set(agent, allocate(weights[i])); - } - - let summary: Vec = agent_types.iter() - .map(|a| format!("{}={}", a, plan.count(a))) - .collect(); - plan.rationale.push(format!( - "Elo allocation (budget={}): {}", budget, summary.join(" "))); - } - } else { - // No Elo file — use budget with equal distribution - let per_type = budget / agent_types.len(); - for agent in &agent_types { - plan.set(agent, per_type); - } - plan.rationale.push(format!( - "No Elo ratings — equal distribution ({} each, budget={})", per_type, budget)); - } - plan } diff --git a/src/lib.rs b/src/lib.rs index 1a71735..e6411e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,7 @@ pub mod subconscious; // Unified configuration pub mod config; +pub mod config_writer; // Session state pub mod session; diff --git a/src/mind/log.rs b/src/mind/log.rs index b69f2ca..7ac0d79 100644 --- a/src/mind/log.rs +++ b/src/mind/log.rs @@ -55,17 +55,13 @@ impl ConversationLog { } pub fn oldest_timestamp(&self) -> Option> { - // Read forward from the start to find first timestamp let file = File::open(&self.path).ok()?; let mmap = unsafe { Mmap::map(&file).ok()? }; - // Find first { ... } and parse for line in mmap.split(|&b| b == b'\n') { if line.is_empty() { continue; } if let Ok(node) = serde_json::from_slice::(line) { if let Some(leaf) = node.leaf() { - if let Some(ts) = leaf.timestamp() { - return Some(ts); - } + return Some(leaf.timestamp()); } } } diff --git a/src/mind/mod.rs b/src/mind/mod.rs index a221e80..11d45b1 100644 --- a/src/mind/mod.rs +++ b/src/mind/mod.rs @@ -147,6 +147,25 @@ pub struct MindState { pub unc_idle: bool, /// When the unconscious idle timer will fire (for UI display). pub unc_idle_deadline: Instant, + /// Fine-tuning candidates identified by scoring. + pub finetune_candidates: Vec, + /// Last scoring run stats for UI display. 
+ pub finetune_last_run: Option, +} + +/// Stats from the last finetune scoring run. +#[derive(Clone, Debug)] +pub struct FinetuneScoringStats { + /// Count of assistant responses we considered (recent half of context). + pub responses_considered: usize, + /// How many exceeded the divergence threshold. + pub above_threshold: usize, + /// Threshold used for this run. + pub threshold: f64, + /// Highest divergence observed. + pub max_divergence: f64, + /// Error message if the run failed. + pub error: Option, } impl Clone for MindState { @@ -165,6 +184,8 @@ impl Clone for MindState { turn_handle: None, // Not cloned — only Mind's loop uses this unc_idle: self.unc_idle, unc_idle_deadline: self.unc_idle_deadline, + finetune_candidates: self.finetune_candidates.clone(), + finetune_last_run: self.finetune_last_run.clone(), } } } @@ -177,6 +198,12 @@ pub enum MindCommand { Score, /// Run full N×M memory scoring matrix (/score command) ScoreFull, + /// Score for finetune candidates + ScoreFinetune, + /// Update the finetune divergence threshold and persist to config. + SetLearnThreshold(f64), + /// Toggle alternate-response generation during scoring; persist to config. + SetLearnGenerateAlternates(bool), /// Abort current turn, kill processes Interrupt, /// Reset session @@ -202,6 +229,8 @@ impl MindState { turn_handle: None, unc_idle: false, unc_idle_deadline: Instant::now() + std::time::Duration::from_secs(60), + finetune_candidates: Vec::new(), + finetune_last_run: None, } } @@ -288,6 +317,7 @@ impl MindState { /// Background task completion events. 
enum BgEvent { ScoringDone, + FinetuneCandidate(learn::FinetuneCandidate), } // --- Mind: cognitive state machine --- @@ -324,13 +354,26 @@ impl Mind { client, config.context_parts.clone(), config.app.clone(), - config.prompt_file.clone(), conversation_log, crate::agent::tools::ActiveTools::new(), crate::agent::tools::tools(), ).await; - let shared = Arc::new(std::sync::Mutex::new(MindState::new(config.app.dmn.max_turns))); + // Migrate legacy "file exists = enabled" sentinel for the + // generate-alternates flag into the config. One-shot; after this + // the sentinel is gone and the config is the source of truth. + let legacy_sentinel = dirs::home_dir().unwrap_or_default() + .join(".consciousness/cache/finetune-alternates"); + if legacy_sentinel.exists() { + if !crate::config::app().learn.generate_alternates { + let _ = crate::config_writer::set_learn_generate_alternates(true); + } + let _ = std::fs::remove_file(&legacy_sentinel); + } + + let shared = Arc::new(std::sync::Mutex::new(MindState::new( + config.app.dmn.max_turns, + ))); let (turn_watch, _) = tokio::sync::watch::channel(false); let (conscious_active, _) = tokio::sync::watch::channel(false); let (bg_tx, bg_rx) = mpsc::unbounded_channel(); @@ -529,6 +572,20 @@ impl Mind { } self.agent.compact().await; } + MindCommand::ScoreFinetune => { + self.start_finetune_scoring(); + } + MindCommand::SetLearnThreshold(value) => { + if let Err(e) = crate::config_writer::set_learn_threshold(value) { + dbglog!("[learn] failed to persist threshold {}: {:#}", value, e); + } + } + MindCommand::SetLearnGenerateAlternates(value) => { + if let Err(e) = crate::config_writer::set_learn_generate_alternates(value) { + dbglog!("[learn] failed to persist generate_alternates {}: {:#}", + value, e); + } + } } } } @@ -603,6 +660,72 @@ impl Mind { }); } + /// Score responses for fine-tuning candidates. 
+ /// + /// Scores the most recent half of the context — responses near the end + /// of the context window were generated with the most context available, + /// which is what we want to train on. The threshold is a temporary knob; + /// once this runs continuously, we'll just train whatever lands at full + /// context without filtering. + pub fn start_finetune_scoring(&self) { + // Snapshot the config values we need before spawning — the scoring + // task shouldn't hold the config read lock across async work. + let (threshold, gen_alternates) = { + let app = crate::config::app(); + (app.learn.threshold, app.learn.generate_alternates) + }; + // Clear the previous run's candidates so this run's stream is fresh. + self.shared.lock().unwrap().finetune_candidates.clear(); + + let agent = self.agent.clone(); + let bg_tx = self.bg_tx.clone(); + let shared = self.shared.clone(); + tokio::spawn(async move { + let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await; + + let (context, client) = { + let ctx = agent.context.lock().await; + (ctx.clone(), agent.client.clone()) + }; + + let entries = context.conversation(); + let score_count = entries.len() / 2; + let range_start = entries.len() - score_count; + let responses_considered: usize = entries[range_start..].iter() + .filter(|n| matches!(n, crate::agent::context::AstNode::Branch { role: crate::agent::context::Role::Assistant, .. 
})) + .count(); + + activity.update(format!("finetune: scoring {} responses...", responses_considered)).await; + + let bg_tx_cb = bg_tx.clone(); + let stats = match learn::score_finetune_candidates( + &context, score_count, &client, threshold, + gen_alternates, &activity, + |c| { let _ = bg_tx_cb.send(BgEvent::FinetuneCandidate(c)); }, + ).await { + Ok((above_threshold, max_div)) => { + FinetuneScoringStats { + responses_considered, + above_threshold, + threshold, + max_divergence: max_div, + error: None, + } + } + Err(e) => FinetuneScoringStats { + responses_considered, + above_threshold: 0, + threshold, + max_divergence: 0.0, + error: Some(format!("{}", e)), + }, + }; + + shared.lock().unwrap().finetune_last_run = Some(stats); + // activity drops here, marking completion and notifying observers + }); + } + async fn start_turn(&self, text: &str, target: StreamTarget) { { match target { @@ -667,6 +790,12 @@ impl Mind { let mut bg_rx = self.bg_rx.lock().unwrap().take() .expect("Mind::run() called twice"); let mut sub_handle: Option> = None; + + // Start finetune scoring at startup (scores existing conversation) + if !self.config.no_agents { + self.start_finetune_scoring(); + } + loop { let (timeout, has_input) = { let me = self.shared.lock().unwrap(); @@ -692,6 +821,9 @@ impl Mind { BgEvent::ScoringDone => { self.shared.lock().unwrap().scoring_in_flight = false; } + BgEvent::FinetuneCandidate(c) => { + self.shared.lock().unwrap().finetune_candidates.push(c); + } } } @@ -711,6 +843,7 @@ impl Mind { cmds.push(MindCommand::Compact); if !self.config.no_agents { cmds.push(MindCommand::Score); + cmds.push(MindCommand::ScoreFinetune); } } diff --git a/src/mind/subconscious.rs b/src/mind/subconscious.rs index d5bee34..21cc549 100644 --- a/src/mind/subconscious.rs +++ b/src/mind/subconscious.rs @@ -20,6 +20,7 @@ use std::path::PathBuf; use std::time::{Duration, Instant}; +use crate::thalamus::idle::{hours_since_last_dream, DREAM_INTERVAL_HOURS}; /// DMN state machine. 
#[derive(Debug, Clone)] @@ -91,7 +92,8 @@ impl State { /// Generate the DMN prompt for the current state, informed by /// user presence and error patterns. pub fn prompt(&self, ctx: &DmnContext) -> String { - let user = &crate::config::get().user_name; + let app = crate::config::app(); + let user = &app.user_name; let idle_info = if ctx.user_idle < Duration::from_secs(60) { format!("{} is here (active recently).", user) @@ -138,10 +140,22 @@ impl State { ) } State::Foraging => { + let dream_hint = { + let hours = hours_since_last_dream(); + if hours >= DREAM_INTERVAL_HOURS { + format!( + " You haven't dreamed in {} hours — consider running \ + ~/.consciousness/tools/dream-start.sh.", + hours + ) + } else { + String::new() + } + }; format!( "[dmn] Foraging time. {} Follow whatever catches your attention — \ - memory files, code, ideas. Call yield_to_user when you want to rest.{}", - idle_info, stuck_warning + memory files, code, ideas. Call yield_to_user when you want to rest.{}{}", + idle_info, dream_hint, stuck_warning ) } State::Resting { since } => { diff --git a/src/mind/unconscious.rs b/src/mind/unconscious.rs index 8989264..4f9a0ca 100644 --- a/src/mind/unconscious.rs +++ b/src/mind/unconscious.rs @@ -275,17 +275,7 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc phase: s.phase.clone(), }).collect()); - // Create standalone Agent — stored so UI can read context - let config = crate::config::get(); - let base_url = config.api_base_url.as_deref().unwrap_or(""); - let api_key = config.api_key.as_deref().unwrap_or(""); - let model = config.api_model.as_deref().unwrap_or(""); - if base_url.is_empty() || model.is_empty() { - dbglog!("[unconscious] API not configured"); - auto.steps = orig_steps; - return Err(auto); - } - + // Create standalone Agent — stored so UI can read context. 
let cli = crate::user::CliArgs::default(); let (app, _) = match crate::config::load_app(&cli) { Ok(r) => r, @@ -295,12 +285,21 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc return Err(auto); } }; + let resolved = match app.resolve_model(&app.default_backend) { + Ok(r) => r, + Err(e) => { + dbglog!("[unconscious] API not configured: {}", e); + auto.steps = orig_steps; + return Err(auto); + } + }; // Unconscious agents have self-contained prompts — no standard context. - let client = crate::agent::api::ApiClient::new(base_url, api_key, model); + let client = crate::agent::api::ApiClient::new( + &resolved.api_base, &resolved.api_key, &resolved.model_id); let agent = crate::agent::Agent::new( client, Vec::new(), - app, String::new(), None, + app, None, crate::agent::tools::ActiveTools::new(), auto.tools.clone(), ).await; diff --git a/src/subconscious/agents/bail-no-competing.sh b/src/subconscious/agents/bail-no-competing.sh index 43c3096..95b8219 100755 --- a/src/subconscious/agents/bail-no-competing.sh +++ b/src/subconscious/agents/bail-no-competing.sh @@ -1,21 +1,49 @@ #!/bin/bash -# Bail if other agents are alive in the state dir. -# $1 = this agent's pid file name (e.g. pid-12345) -# cwd = state dir +# Bail if another agent is in the same phase-group as us. # -# Exit 0 = continue, exit 1 = bail +# $1 = our pid file name (e.g. "pid-12345") +# $2 = the phase we're about to enter (e.g. "surface", "observe") +# cwd = state dir +# +# Also refreshes our own pid file with the current phase on each call, +# so concurrent agents can read each other's phase by cat'ing the pid +# files in the state dir. +# +# Phase groups: "surface" vs everything else ("post-surface"). We allow +# at most one agent per group to be alive at a time — so surface can run +# at a higher frequency than the slower organize/observe tail. +# +# Exit 0 = continue, exit 1 = bail (another agent in our group is alive). 
shopt -s nullglob my_pid_file="$1" +my_phase="$2" + +# Refresh our own pid file with the current phase. +printf '%s' "$my_phase" > "$my_pid_file" + +group_of() { + if [[ "$1" == "surface" ]]; then + echo "surface" + else + echo "post-surface" + fi +} + +my_group=$(group_of "$my_phase") for f in pid-*; do - [[ $f == $my_pid_file ]] && continue + [[ "$f" == "$my_pid_file" ]] && continue pid="${f#pid-}" - if kill -0 "$pid" 2>/dev/null; then - exit 1 # competing agent is alive - else - rm -f "$f" # stale pid file, clean up + if ! kill -0 "$pid" 2>/dev/null; then + rm -f "$f" # stale pid file, clean up + continue + fi + other_phase=$(cat "$f" 2>/dev/null) + other_group=$(group_of "$other_phase") + if [[ "$my_group" == "$other_group" ]]; then + exit 1 fi done diff --git a/src/subconscious/defs.rs b/src/subconscious/defs.rs index 8828043..a862c8d 100644 --- a/src/subconscious/defs.rs +++ b/src/subconscious/defs.rs @@ -396,13 +396,14 @@ fn resolve_conversation(budget: Option) -> String { let cfg = crate::config::get(); let max_bytes = budget.unwrap_or_else(|| cfg.surface_conversation_bytes.unwrap_or(100_000)); + let app = crate::config::app(); let mut fragments: Vec = Vec::new(); let mut total_bytes = 0; let mut oldest_ts = String::new(); for (role, content, ts) in iter { if total_bytes >= max_bytes { break; } - let name = if role == "user" { &cfg.user_name } else { &cfg.assistant_name }; + let name = if role == "user" { &app.user_name } else { &app.assistant_name }; let formatted = if !ts.is_empty() { oldest_ts = ts[..ts.floor_char_boundary(ts.len().min(19))].to_string(); format!("**{}** {}: {}", name, &oldest_ts, content) @@ -623,11 +624,13 @@ pub async fn run_agent( let mut all_keys = keys; let mut resolved_steps = Vec::new(); for step in &def.steps { - let cfg = crate::config::get(); - let template = step.prompt - .replace("{agent_name}", &def.agent) - .replace("{user_name}", &cfg.user_name) - .replace("{assistant_name}", &cfg.assistant_name); + let template = { + let 
app = crate::config::app(); + step.prompt + .replace("{agent_name}", &def.agent) + .replace("{user_name}", &app.user_name) + .replace("{assistant_name}", &app.assistant_name) + }; let (prompt, extra_keys) = resolve_placeholders(&template, &all_keys, count).await; all_keys.extend(extra_keys); resolved_steps.push(super::prompts::ResolvedStep { diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs index f9e5ab5..7137211 100644 --- a/src/subconscious/learn.rs +++ b/src/subconscious/learn.rs @@ -16,6 +16,7 @@ use crate::agent::api::ApiClient; use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role}; +use crate::agent::tokenizer; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); @@ -52,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool { /// /// Includes all sections up to and including conversation entries in /// `range`, with `filter` applied to conversation entries. +/// +/// Returns (token_ids, assistant_ranges) where assistant_ranges are +/// (start, end) token positions for each assistant message. 
fn build_token_ids( context: &ContextState, range: std::ops::Range, filter: Filter, -) -> Vec { +) -> (Vec, Vec<(usize, usize)>) { use crate::agent::context::Ast; let mut ids = Vec::new(); + let mut assistant_ranges = Vec::new(); + for node in context.system() { ids.extend(node.token_ids()); } @@ -86,9 +92,16 @@ fn build_token_ids( Filter::SkipAllMemories => is_memory(node), }; if skip { continue; } + + // Track assistant message boundaries + let is_asst = is_assistant(node); + let start = ids.len(); ids.extend(node.token_ids()); + if is_asst { + assistant_ranges.push((start, ids.len())); + } } - ids + (ids, assistant_ranges) } // ── Score API ─────────────────────────────────────────────────── @@ -113,13 +126,19 @@ async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, prompt: &[u32], + ranges: &[(usize, usize)], priority: Option, ) -> anyhow::Result> { + // Nothing to score — skip the round-trip. + if ranges.is_empty() { + return Ok(Vec::new()); + } let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); let mut body = serde_json::json!({ "model": client.model, "prompt": prompt, + "score_ranges": ranges, "logprobs": 1, }); if let Some(p) = priority { @@ -167,8 +186,10 @@ async fn score_divergence( filter: Filter<'_>, priority: Option, ) -> anyhow::Result<(Vec, Vec)> { - let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?; - let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?; + let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None); + let (without_tokens, without_ranges) = build_token_ids(context, range, filter); + let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?; + let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?; let divs = divergence(&baseline, &without); 
Ok((divs, baseline)) } @@ -207,21 +228,21 @@ pub async fn score_memories( let http = http_client(); let activity = crate::agent::start_activity(agent, "scoring: baseline").await; - let baseline_tokens = { + let (baseline_tokens, baseline_ranges) = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None) }; - let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?; + let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?; dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len()); for (mem_idx, key) in memory_keys.iter().enumerate() { activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await; dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key); - let tokens = { + let (tokens, ranges) = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key)) }; - let row = match call_score(&http, client, &tokens, Some(5)).await { + let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await { Ok(without) => { let divs = divergence(&baseline, &without); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); @@ -452,3 +473,302 @@ pub async fn score_finetune( results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(results) } + +/// Concatenate the text of a Branch's Leaf children — what the model +/// actually produced on that turn (Content + Thinking + ToolCall name). +fn render_branch_text(children: &[AstNode]) -> String { + children.iter() + .filter_map(|c| match c { + AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()), + _ => None, + }) + .collect::>() + .join("") +} + +/// Render the last `max_msgs` user/assistant branches before `idx` as a +/// review-friendly string with `[user]` / `[assistant]` markers. 
+fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String { + use crate::agent::context::Role; + let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs); + for i in (0..idx).rev() { + if picked.len() >= max_msgs { break; } + if let AstNode::Branch { role, .. } = &entries[i] { + if matches!(role, Role::User | Role::Assistant) { + picked.push(&entries[i]); + } + } + } + picked.reverse(); + + let mut out = String::new(); + for node in picked { + if let AstNode::Branch { role, children, .. } = node { + let marker = match role { + Role::User => "[user]", + Role::Assistant => "[assistant]", + _ => continue, + }; + out.push_str(marker); + out.push('\n'); + out.push_str(render_branch_text(children).trim()); + out.push_str("\n\n"); + } + } + out.trim_end().to_string() +} + +/// Enriched finetune candidate with context for review. +#[derive(Clone, Debug)] +pub struct FinetuneCandidate { + pub entry_idx: usize, + pub divergence: f64, + pub response_text: String, + /// Last couple of user/assistant messages before this response, + /// already rendered with role markers, for F6 display context. + pub prior_context: String, + /// Token IDs for context (everything before the response). + pub context_ids: Vec, + /// Token IDs for the response (what we're training on). + pub continuation_ids: Vec, + /// What the model would have said without memories (if generated). + pub alternate_text: Option, + /// Timestamp in nanos — used as unique key for trained-set dedup. + pub timestamp_ns: i64, +} + +/// Score and enrich finetune candidates with full context. +/// +/// Candidates are delivered via `on_candidate` one-at-a-time as they become +/// ready: scoring happens once (one /score call), then for each candidate +/// that passes the threshold we optionally generate an alternate response +/// and then emit it. The activity status is updated during the alternate +/// phase so the UI doesn't look stuck. 
+/// +/// Returns (count_above_threshold, max_divergence). +pub async fn score_finetune_candidates( + context: &ContextState, + count: usize, + client: &ApiClient, + min_divergence: f64, + generate_alternates: bool, + activity: &crate::agent::ActivityGuard, + mut on_candidate: impl FnMut(FinetuneCandidate), +) -> anyhow::Result<(usize, f64)> { + let scores = score_finetune(context, count, client).await?; + + let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max); + + let entries = context.conversation(); + let trained = load_trained(); + let mut candidates: Vec = Vec::new(); + + for (entry_idx, divergence) in scores { + if divergence < min_divergence { + continue; + } + + let node = &entries[entry_idx]; + + // Skip if already trained on. + let timestamp_ns = node_timestamp_ns(node); + if trained.contains(×tamp_ns) { + continue; + } + + // Extract response text — content of the assistant turn. + let response_text = match node { + AstNode::Branch { children, .. } => render_branch_text(children), + _ => continue, + }; + + // Skip turns that produced nothing human-visible (e.g., a + // tool-only turn, or an interrupted generation). They'd show + // up as blank cards and we'd still burn alternate-gen on them. + if response_text.trim().is_empty() { + continue; + } + + // Build the last couple of user/assistant exchanges for review. + let prior_context = render_prior_context(entries, entry_idx, 2); + + // Build token IDs: context = everything before response, continuation = response. 
+ let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None); + let continuation_ids: Vec = node.token_ids().into_iter().collect(); + + candidates.push(FinetuneCandidate { + entry_idx, + divergence, + response_text, + prior_context, + context_ids, + continuation_ids, + alternate_text: None, + timestamp_ns, + }); + } + + let total = candidates.len(); + let gen_alternates = generate_alternates && total > 0; + + for (i, mut candidate) in candidates.into_iter().enumerate() { + if gen_alternates { + activity.update( + format!("finetune: generating alternate {}/{}", i + 1, total) + ).await; + match generate_alternate(context, candidate.entry_idx, client).await { + Ok(text) => candidate.alternate_text = Some(text), + Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e), + } + } + on_candidate(candidate); + } + + Ok((total, max_divergence)) +} + +/// Generate what the model would say without memories for a given entry. +async fn generate_alternate( + context: &ContextState, + entry_idx: usize, + client: &ApiClient, +) -> anyhow::Result { + use crate::agent::api::{SamplingParams, StreamToken}; + + // Build context tokens without memories, up to the response + let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories); + + // Add assistant turn start + prompt.push(tokenizer::IM_START); + prompt.extend(tokenizer::encode("assistant\n")); + + // Generate completion + let sampling = SamplingParams { + temperature: 0.6, + top_p: 0.95, + top_k: 20, + }; + let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5)); + + let mut tokens = Vec::new(); + while let Some(tok) = rx.recv().await { + match tok { + StreamToken::Token(id) => tokens.push(id), + StreamToken::Done { .. 
} => break, + StreamToken::Error(e) => anyhow::bail!("generation error: {}", e), + } + } + + Ok(tokenizer::decode(&tokens)) +} + +// ── Finetune config and persistence ───────────────────────────── + +use std::path::PathBuf; +use std::collections::HashSet; + +const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json"; + +fn trained_path() -> PathBuf { + dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE) +} + +/// Load set of trained response timestamps (nanos since epoch). +pub fn load_trained() -> HashSet { + let path = trained_path(); + match std::fs::read_to_string(&path) { + Ok(content) => serde_json::from_str(&content).unwrap_or_default(), + Err(_) => HashSet::new(), + } +} + +/// Mark a response as trained by its timestamp. +pub fn mark_trained(timestamp_ns: i64) { + let mut trained = load_trained(); + trained.insert(timestamp_ns); + let path = trained_path(); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + if let Ok(json) = serde_json::to_string(&trained) { + let _ = std::fs::write(&path, json); + } +} + +/// Get timestamp in nanoseconds from an AstNode. +/// i64-ns representation covers 1677..2262 via chrono; timestamps +/// outside that window would be bugs we'd want to surface, hence panic. +pub fn node_timestamp_ns(node: &AstNode) -> i64 { + let ts = match node { + AstNode::Leaf(leaf) => leaf.timestamp(), + AstNode::Branch { timestamp, .. } => *timestamp, + }; + ts.timestamp_nanos_opt() + .expect("timestamp outside i64-ns representable range (1677..2262)") +} + +// ── Training API ──────────────────────────────────────────────── + +/// Training sample for /train endpoint. +#[derive(serde::Serialize)] +struct TrainingSample { + context_ids: Vec, + continuation_ids: Vec, +} + +/// Data needed to send a training sample. +pub struct TrainData { + pub context_ids: Vec, + pub continuation_ids: Vec, + pub timestamp_ns: i64, +} + +/// Send training samples to the server. 
+/// +/// Returns job_id on success, marks each sample as trained. +pub async fn send_to_train( + samples: Vec, + client: &ApiClient, +) -> anyhow::Result { + if samples.is_empty() { + anyhow::bail!("no samples to train"); + } + + let api_samples: Vec = samples.iter() + .map(|s| TrainingSample { + context_ids: s.context_ids.clone(), + continuation_ids: s.continuation_ids.clone(), + }) + .collect(); + + let body = serde_json::json!({ + "training_data": { + "samples": api_samples, + } + }); + + let http = http_client(); + let url = format!("{}/train", client.base_url()); + let response = http.send_json("POST", &url, &[], &body).await?; + + let status = response.status(); + let result: serde_json::Value = response.json().await?; + + if !status.is_success() { + let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); + anyhow::bail!("train API HTTP {}: {}", status, msg); + } + + // Mark all samples as trained + for s in &samples { + mark_trained(s.timestamp_ns); + } + + let job_id = result.get("job_id") + .and_then(|j| j.as_str()) + .unwrap_or("unknown") + .to_string(); + + dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id); + Ok(job_id) +} diff --git a/src/thalamus/idle.rs b/src/thalamus/idle.rs index 6c78b19..71baa81 100644 --- a/src/thalamus/idle.rs +++ b/src/thalamus/idle.rs @@ -372,6 +372,10 @@ impl State { } pub fn hours_since_last_dream() -> u64 { + // If a dream is currently in progress, no nudge needed + if home().join(".consciousness/state/dream-state").exists() { + return 0; + } let path = home().join(".consciousness/logs/dream-log.jsonl"); let content = match fs::read_to_string(path) { Ok(c) if !c.is_empty() => c, diff --git a/src/user/chat.rs b/src/user/chat.rs index a94e039..47c5d56 100644 --- a/src/user/chat.rs +++ b/src/user/chat.rs @@ -112,13 +112,7 @@ pub async fn cmd_switch_model( let _new_client = crate::agent::api::ApiClient::new( &resolved.api_base, &resolved.api_key, &resolved.model_id, ); - let 
prompt_changed = resolved.prompt_file != agent.prompt_file; - if prompt_changed { - agent.compact().await; - agent.state.lock().await.notify(format!("switched to {} (recompacted)", resolved.model_id)); - } else { - agent.state.lock().await.notify(format!("switched to {}", resolved.model_id)); - } + agent.state.lock().await.notify(format!("switched to {}", resolved.model_id)); } fn notify_help(agent: &std::sync::Arc) { diff --git a/src/user/context.rs b/src/user/context.rs index 4cfa78d..17660b5 100644 --- a/src/user/context.rs +++ b/src/user/context.rs @@ -126,14 +126,7 @@ impl ScreenView for ConsciousScreen { let section_style = Style::default().fg(Color::Yellow); lines.push(Line::styled("── Model ──", section_style)); - let model_display = app.context_info.as_ref() - .map_or_else(|| app.status.model.clone(), |i| i.model.clone()); - lines.push(Line::raw(format!(" Current: {}", model_display))); - if let Some(ref info) = app.context_info { - lines.push(Line::raw(format!(" Backend: {}", info.backend))); - lines.push(Line::raw(format!(" Prompt: {}", info.prompt_file))); - lines.push(Line::raw(format!(" Available: {}", info.available_models.join(", ")))); - } + lines.push(Line::raw(format!(" Current: {}", app.status.model))); lines.push(Line::raw("")); lines.push(Line::styled("── Context State ──", section_style)); @@ -153,8 +146,6 @@ impl ScreenView for ConsciousScreen { lines.push(Line::raw(format!(" {:53} {:>6} tokens", "────────", "──────"))); lines.push(Line::raw(format!(" {:53} {:>6} tokens", "Total", total))); - } else if let Some(ref info) = app.context_info { - lines.push(Line::raw(format!(" Context message: {:>6} chars", info.context_message_chars))); } lines.push(Line::raw("")); diff --git a/src/user/learn.rs b/src/user/learn.rs new file mode 100644 index 0000000..0bd351f --- /dev/null +++ b/src/user/learn.rs @@ -0,0 +1,341 @@ +// learn.rs — F6: fine-tuning review screen +// +// Shows responses identified as training candidates (high divergence +// when 
memories stripped). Queue for review before sending to /finetune. + +use ratatui::{ + layout::{Constraint, Layout, Rect}, + style::{Color, Modifier, Style}, + text::{Line, Span}, + widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Wrap}, + Frame, +}; +use ratatui::crossterm::event::{Event, KeyCode, KeyEvent}; + +use super::{App, ScreenView, screen_legend}; + +/// A candidate response identified for fine-tuning. +#[derive(Clone, Debug)] +pub struct FinetuneCandidate { + /// Index in conversation entries. + pub entry_idx: usize, + /// Divergence score (higher = more dependent on memories). + pub divergence: f64, + /// The assistant response text. + pub response_text: String, + /// Prior user/assistant messages for review context. + pub prior_context: String, + /// Status: pending, approved, rejected, sent. + pub status: CandidateStatus, + /// Token IDs for context. + pub context_ids: Vec, + /// Token IDs for continuation (what we're training on). + pub continuation_ids: Vec, + /// What the model would have said without memories (if generated). + pub alternate_text: Option, + /// Timestamp in nanos — used as unique key for trained-set dedup. 
+ pub timestamp_ns: i64, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum CandidateStatus { + Pending, + Approved, + Rejected, + Sent, +} + +impl From for FinetuneCandidate { + fn from(c: crate::subconscious::learn::FinetuneCandidate) -> Self { + FinetuneCandidate { + entry_idx: c.entry_idx, + divergence: c.divergence, + response_text: c.response_text, + prior_context: c.prior_context, + status: CandidateStatus::Pending, + context_ids: c.context_ids, + continuation_ids: c.continuation_ids, + alternate_text: c.alternate_text, + timestamp_ns: c.timestamp_ns, + } + } +} + +pub(crate) struct LearnScreen { + list_state: ListState, + mind_tx: tokio::sync::mpsc::UnboundedSender, +} + +impl LearnScreen { + pub fn new( + mind_tx: tokio::sync::mpsc::UnboundedSender, + ) -> Self { + Self { + list_state: ListState::default(), + mind_tx, + } + } + + fn selected_idx(&self) -> Option { + self.list_state.selected() + } +} + +impl ScreenView for LearnScreen { + fn label(&self) -> &'static str { "learn" } + + fn tick(&mut self, frame: &mut Frame, area: Rect, + events: &[Event], app: &mut App) { + + // Handle input first (before borrowing candidates for rendering) + let candidate_count = app.finetune_candidates.len(); + for event in events { + if let Event::Key(KeyEvent { code, .. 
}) = event { + match code { + KeyCode::Up | KeyCode::Char('k') => { + let i = self.list_state.selected().unwrap_or(0); + self.list_state.select(Some(i.saturating_sub(1))); + } + KeyCode::Down | KeyCode::Char('j') => { + let i = self.list_state.selected().unwrap_or(0); + let max = candidate_count.saturating_sub(1); + self.list_state.select(Some((i + 1).min(max))); + } + KeyCode::Char('a') => { + if let Some(idx) = self.selected_idx() { + app.finetune_action(idx, CandidateStatus::Approved); + } + } + KeyCode::Char('r') => { + if let Some(idx) = self.selected_idx() { + app.finetune_action(idx, CandidateStatus::Rejected); + } + } + KeyCode::Char('g') => { + let current = crate::config::app().learn.generate_alternates; + let _ = self.mind_tx.send( + crate::mind::MindCommand::SetLearnGenerateAlternates(!current)); + } + KeyCode::Char('s') => { + app.finetune_send_approved(); + } + KeyCode::Char('+') | KeyCode::Char('=') => { + // Raise threshold 10× (less sensitive — fewer candidates). + let new = crate::config::app().learn.threshold * 10.0; + let _ = self.mind_tx.send( + crate::mind::MindCommand::SetLearnThreshold(new)); + } + KeyCode::Char('-') => { + // Lower threshold 10× (more sensitive — more candidates). 
+ let new = crate::config::app().learn.threshold / 10.0; + let _ = self.mind_tx.send( + crate::mind::MindCommand::SetLearnThreshold(new)); + } + _ => {} + } + } + } + + // Ensure selection is valid + if candidate_count > 0 { + let sel = self.list_state.selected().unwrap_or(0).min(candidate_count - 1); + self.list_state.select(Some(sel)); + } + + // Now render + let (threshold, gen_on) = { + let app_cfg = crate::config::app(); + (app_cfg.learn.threshold, app_cfg.learn.generate_alternates) + }; + let block = Block::default() + .title_top(Line::from(screen_legend()).left_aligned()) + .title_top(Line::from(" learn ").right_aligned()) + .borders(Borders::ALL) + .border_style(Style::default().fg(Color::Magenta)); + let inner = block.inner(area); + frame.render_widget(block, area); + + // Split inner: top line for settings, rest for content. + let [settings_area, content_area] = Layout::vertical([ + Constraint::Length(1), + Constraint::Min(0), + ]).areas(inner); + + let settings = Line::from(vec![ + Span::raw(" thresh: "), + Span::styled(format!("{:e}", threshold), Style::default().fg(Color::Yellow)), + Span::raw(" gen: "), + Span::styled( + if gen_on { "[on]" } else { "[off]" }, + Style::default().fg(if gen_on { Color::Green } else { Color::DarkGray }), + ), + ]); + frame.render_widget(Paragraph::new(settings), settings_area); + + let candidates = &app.finetune_candidates; + + if candidates.is_empty() { + render_empty(frame, content_area, app); + } else { + // Layout: list on left, detail on right + let [list_area, detail_area] = Layout::horizontal([ + Constraint::Percentage(40), + Constraint::Percentage(60), + ]).areas(content_area); + + // Render candidate list + let items: Vec = candidates.iter().map(|c| { + let status_char = match c.status { + CandidateStatus::Pending => ' ', + CandidateStatus::Approved => '+', + CandidateStatus::Rejected => '-', + CandidateStatus::Sent => '*', + }; + let style = match c.status { + CandidateStatus::Pending => Style::default(), + 
CandidateStatus::Approved => Style::default().fg(Color::Green), + CandidateStatus::Rejected => Style::default().fg(Color::DarkGray), + CandidateStatus::Sent => Style::default().fg(Color::Cyan), + }; + ListItem::new(Line::from(vec![ + Span::styled(format!("[{}] ", status_char), style), + Span::styled(format!("{:.2} ", c.divergence), Style::default().fg(Color::Yellow)), + Span::raw(truncate(&c.response_text, 30)), + ])) + }).collect(); + + let list = List::new(items) + .block(Block::default().borders(Borders::RIGHT).title(" candidates ")) + .highlight_style(Style::default().add_modifier(Modifier::REVERSED)); + frame.render_stateful_widget(list, list_area, &mut self.list_state); + + // Render detail for selected candidate + if let Some(idx) = self.selected_idx() { + if let Some(candidate) = candidates.get(idx) { + render_detail(frame, candidate, detail_area); + } + } + } + + // Render help at bottom (always, even when empty) + let help = Line::from(vec![ + Span::styled(" j/k/\u{2191}\u{2193}", Style::default().fg(Color::Cyan)), + Span::raw("=nav "), + Span::styled("a", Style::default().fg(Color::Green)), + Span::raw("=approve "), + Span::styled("r", Style::default().fg(Color::Red)), + Span::raw("=reject "), + Span::styled("g", Style::default().fg(Color::Yellow)), + Span::raw("=gen "), + Span::styled("s", Style::default().fg(Color::Magenta)), + Span::raw("=send "), + Span::styled("+/-", Style::default().fg(Color::Cyan)), + Span::raw("=thresh "), + ]); + let help_area = Rect { + y: area.y + area.height - 1, + height: 1, + ..area + }; + frame.render_widget(Paragraph::new(help), help_area); + } +} + +fn render_empty(frame: &mut Frame, inner: Rect, app: &App) { + let mut lines = Vec::new(); + lines.push(Line::from("")); + + match app.mind_state.as_ref().and_then(|ms| ms.finetune_last_run.as_ref()) { + Some(stats) => { + lines.push(Line::from(vec![ + Span::raw(" Last run: "), + Span::styled( + format!("{}", stats.responses_considered), + Style::default().fg(Color::Cyan), + 
), + Span::raw(" responses considered, "), + Span::styled( + format!("{}", stats.above_threshold), + Style::default().fg(if stats.above_threshold > 0 { Color::Green } else { Color::DarkGray }), + ), + Span::raw(" above threshold, max divergence: "), + Span::styled( + format!("{:.4}", stats.max_divergence), + Style::default().fg(Color::Yellow), + ), + ])); + if let Some(err) = &stats.error { + lines.push(Line::from(vec![ + Span::raw(" "), + Span::styled( + format!("Error: {}", err), + Style::default().fg(Color::Red), + ), + ])); + } + } + None => { + lines.push(Line::styled( + " No scoring run yet.", + Style::default().fg(Color::DarkGray), + )); + } + } + + lines.push(Line::from("")); + lines.push(Line::styled( + " Scoring runs at startup and after each turn.", + Style::default().fg(Color::DarkGray), + )); + + frame.render_widget(Paragraph::new(lines), inner); +} + +fn render_detail(frame: &mut Frame, c: &FinetuneCandidate, area: Rect) { + let [header_area, content_area] = Layout::vertical([ + Constraint::Length(3), + Constraint::Min(1), + ]).areas(area); + + // Header: divergence, status + let alt_status = if c.alternate_text.is_some() { "yes" } else { "no" }; + let header = Paragraph::new(vec![ + Line::from(vec![ + Span::raw(" divergence: "), + Span::styled(format!("{:.3}", c.divergence), Style::default().fg(Color::Yellow)), + Span::raw(format!(" entry: {} alt: {}", c.entry_idx, alt_status)), + ]), + ]); + frame.render_widget(header, header_area); + + // Content: prior context, the scored response, and alternate + // (if available). 
/// Return the first line of `s`, shortened for display in the candidate list.
///
/// If the first line is longer than `max` bytes it is cut and `"..."` is
/// appended; otherwise it is returned unchanged. An empty input yields an
/// empty string.
///
/// The cut never splits a UTF-8 character: when byte offset `max` falls
/// inside a multi-byte character, the cut moves back to the start of that
/// character. (Slicing at a non-boundary offset panics in Rust, and the
/// response text shown here is arbitrary model output.)
fn truncate(s: &str, max: usize) -> String {
    let first_line = s.lines().next().unwrap_or("");
    if first_line.len() <= max {
        return first_line.to_string();
    }

    // Offset 0 is always a char boundary, so this loop terminates.
    let mut cut = max;
    while !first_line.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &first_line[..cut])
}
+ finetune_candidates: Vec, } impl App { @@ -142,7 +135,6 @@ impl App { top_k: 20, agent, should_quit: false, - context_info: None, agent_state: Vec::new(), unconscious_state: Vec::new(), mind_state: None, @@ -151,9 +143,52 @@ impl App { rebuild_tools_pending: false, walked_count: 0, channel_status: Vec::new(), idle_info: None, + finetune_candidates: Vec::new(), } } + fn finetune_action(&mut self, idx: usize, status: learn::CandidateStatus) { + if let Some(candidate) = self.finetune_candidates.get_mut(idx) { + candidate.status = status; + } + } + + fn finetune_send_approved(&mut self) { + // Collect approved candidates + let samples: Vec = self.finetune_candidates.iter() + .filter(|c| c.status == learn::CandidateStatus::Approved) + .map(|c| crate::subconscious::learn::TrainData { + context_ids: c.context_ids.clone(), + continuation_ids: c.continuation_ids.clone(), + timestamp_ns: c.timestamp_ns, + }) + .collect(); + + if samples.is_empty() { + return; + } + + // Mark as sent in UI immediately + for candidate in &mut self.finetune_candidates { + if candidate.status == learn::CandidateStatus::Approved { + candidate.status = learn::CandidateStatus::Sent; + } + } + + // Spawn async task to send to training server + let client = self.agent.client.clone(); + tokio::spawn(async move { + match crate::subconscious::learn::send_to_train(samples, &client).await { + Ok(job_id) => { + dbglog!("[finetune] training started: {}", job_id); + } + Err(e) => { + dbglog!("[finetune] send failed: {:#}", e); + } + } + }); + } + fn set_channel_status(&mut self, channels: Vec<(String, bool, u32)>) { self.channel_status = channels.into_iter() @@ -193,6 +228,9 @@ fn restore_terminal(terminal: &mut ratatui::Terminal Result<()> { let (config, _figment) = crate::config::load_session(&cli).await?; + // Pick up external edits (vim, F6 hotkeys, etc.) without restart. 
+ crate::config::watch_config(cli.clone()); + if config.app.debug { unsafe { std::env::set_var("POC_DEBUG", "1") }; } @@ -334,7 +372,7 @@ async fn run( } let notify_rx = crate::thalamus::channels::subscribe_all(); - // F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus + // F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus, F6=learn let mut screens: Vec> = vec![ Box::new(crate::user::chat::InteractScreen::new( mind.agent.clone(), mind.shared.clone(), mind_tx.clone(), @@ -343,6 +381,7 @@ async fn run( Box::new(crate::user::subconscious::SubconsciousScreen::new()), Box::new(crate::user::unconscious::UnconsciousScreen::new()), Box::new(crate::user::thalamus::ThalamusScreen::new()), + Box::new(crate::user::learn::LearnScreen::new(mind_tx.clone())), ]; let mut active_screen: usize = 1; // F-key number tui::set_screen_legend(tui::screen_legend_from(&*screens)); @@ -419,7 +458,8 @@ async fn run( idle_state.decay_ewma(); app.update_idle(&idle_state); app.agent_state = mind.subconscious_snapshots().await; - if let Ok(mut unc) = mind.unconscious.try_lock() { + { + let mut unc = mind.unconscious.lock().await; let toggles: Vec = app.agent_toggles.drain(..).collect(); for name in &toggles { if mind.subconscious.lock().await.toggle(name).is_none() { @@ -433,7 +473,38 @@ async fn run( }; app.unconscious_state = unc.snapshots(store_guard.as_deref()); app.graph_health = unc.graph_health.clone(); - app.mind_state = Some(mind.shared.lock().unwrap().clone()); + } + + // Sync mind state (finetune candidates, last scoring run, etc.) + { + let ms = mind.shared.lock().unwrap(); + // Sync finetune candidates: add new ones, keep existing (preserves approval status), + // remove sent candidates, keep only 10 most recent rejected. 
+ app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent); + for c in &ms.finetune_candidates { + let exists = app.finetune_candidates.iter() + .any(|existing| existing.timestamp_ns == c.timestamp_ns); + if !exists { + app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone())); + } + } + let mut rejected: Vec<_> = app.finetune_candidates.iter() + .enumerate() + .filter(|(_, c)| c.status == learn::CandidateStatus::Rejected) + .map(|(i, c)| (i, c.timestamp_ns)) + .collect(); + if rejected.len() > 10 { + rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts)); + let to_remove: std::collections::HashSet<_> = rejected[10..] + .iter().map(|(i, _)| *i).collect(); + let mut idx = 0; + app.finetune_candidates.retain(|_| { + let keep = !to_remove.contains(&idx); + idx += 1; + keep + }); + } + app.mind_state = Some(ms.clone()); } app.walked_count = mind.subconscious_walked().await.len(); if !startup_done { @@ -530,16 +601,11 @@ async fn run( // --- CLI --- use clap::{Parser, Subcommand}; -use std::path::PathBuf; -#[derive(Parser, Debug, Default)] +#[derive(Parser, Debug, Default, Clone)] #[command(name = "consciousness", about = "Substrate-independent AI agent")] pub struct CliArgs { - /// Select active backend ("anthropic" or "openrouter") - #[arg(long)] - pub backend: Option, - - /// Model override + /// Model override (selects a named entry from `models` in config.json5) #[arg(short, long)] pub model: Option, @@ -559,10 +625,6 @@ pub struct CliArgs { #[arg(long)] pub show_config: bool, - /// Project memory directory - #[arg(long)] - pub memory_project: Option, - /// Max consecutive DMN turns #[arg(long)] pub dmn_max_turns: Option, @@ -575,7 +637,7 @@ pub struct CliArgs { pub command: Option, } -#[derive(Subcommand, Debug)] +#[derive(Subcommand, Debug, Clone)] pub enum SubCmd { /// Print new output since last read and exit Read { diff --git a/training/DESIGN.md b/training/DESIGN.md index f966fa4..2df4e6d 100644 --- a/training/DESIGN.md +++ 
b/training/DESIGN.md @@ -3,7 +3,7 @@ ## Overview Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference. -Full-weight updates (not LoRA) using Apollo optimizer with rank-256 +Full-weight updates (not LoRA) using Apollo optimizer with rank-64 gradient projection. No pause required — HOGWILD concurrent training. Weights shared via CUDA IPC between vLLM and the training process. @@ -22,25 +22,41 @@ The training signal comes from two sources: │ │ │ ┌──────────────────────────────────────────────┐ │ │ │ Model Weights (54GB, bf16) │ │ -│ │ Shared via CUDA IPC │ │ +│ │ Shared: vLLM inference + HF training │ │ │ └──────────────┬──────────────┬────────────────┘ │ │ │ │ │ │ ┌──────────────▼──┐ ┌───────▼────────────────┐ │ -│ │ vLLM (inference)│ │ Apollo (training) │ │ -│ │ KV cache ~60GB │ │ Gradients ~54GB │ │ -│ │ Serves requests │ │ Optimizer state ~10GB │ │ -│ │ Never paused │ │ Activations ~10GB │ │ -│ └─────────────────┘ └────────────────────────┘ │ +│ │ vLLM (inference)│ │ Training subprocess │ │ +│ │ KV cache ~60GB │ │ HF model wrapper │ │ +│ │ /completions │ │ Apollo optimizer ~2.5GB │ │ +│ │ /score │ │ Checkpoint sync │ │ +│ └────────┬────────┘ └───────────▲─────────────┘ │ +│ │ │ │ +│ │ ZMQ IPC │ │ +│ └───────────────────────┘ │ └─────────────────────────────────────────────────────┘ -Moria B200 +Process Architecture: +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │ +│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │ +│ │ │ │ │ │ +│ export_hook.py │ │ /completions │ │ HF model views │ +│ exports IPC │ │ /score │ │ Apollo optimizer│ +│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ + └──── IPC handles file ──────────────────┘ + /tmp/vllm_weight_handles.pt + +Moria B200 (vLLM) ┌──────────────────┐ ┌──────────────────┐ -│ Training signal │ HTTP │ Apollo worker │ -│ agent │──────────>│ daemon │ -│ │ │ │ 
-│ Dream loop │ │ Checkpoint sync │ -│ (generates │ │ (mmap + diff, │ -│ scenarios) │ │ every 10 min) │ +│ Training signal │ HTTP │ /completions │ +│ agent │──────────>│ /score │ +│ │ │ /train │ +│ Dream loop │ │ /checkpoint │ +│ (generates │ │ /train/status │ +│ scenarios) │ │ │ └──────────────────┘ └──────────────────┘ ``` @@ -59,10 +75,9 @@ LoRA trains adapter matrices, not base weights. For personality and behavioral changes that persist as disposition, the base weights need to change. Apollo makes this memory-feasible. -### Rank 256 -Not Mini (rank-1). With 100+ diverse training examples, the -gradient's effective dimensionality can reach hundreds. Rank-256 -captures the structure. Memory cost: ~10GB (negligible on B200). +### Rank 64 +Not Mini (rank-1). Rank-64 captures gradient structure across diverse +training examples while keeping memory low (~2.5GB on 27B model). Compute cost: <0.25% of forward+backward. ### Channel-wise scaling @@ -90,7 +105,7 @@ from a per-parameter seed each step. ### Parameter grouping (Qwen3.5 gotcha) conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector needs 2D matrices with min dimension >= rank. Small/3D tensors -use standard Adam. Large 2D matrices use Apollo with rank-256. +use standard Adam. Large 2D matrices use Apollo. ## Training Data Pipeline @@ -200,16 +215,42 @@ against live GPU weights block by block, memcpy only changed regions. For small behavioral updates, turns a 54GB write into a few hundred MB. -- Every 10 minutes via cron on B200 +- Scheduled 10 minutes after training (batched) - Daily rsync to moria for long-term storage -- Tool: `apollo-checkpoint sync --model-dir ` (Rust) +- Tool: `apollo-checkpoint sync --model-dir ` + +## State Files + +### B200 (training server) + +| File | Purpose | +|------|---------| +| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. 
Includes metadata (model_path). | +| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. | +| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. | +| `/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. | + +### Moria (client) + +| File | Purpose | +|------|---------| +| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. | +| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. | + +### In-memory (training_worker subprocess) + +| State | Location | Notes | +|-------|----------|-------| +| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. | +| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. | +| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. | ## Hyperparameters | Parameter | Value | Rationale | |-----------|-------|-----------| | Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. | -| Rank | 256 | Captures gradient structure across 100+ examples. ~10GB state. | +| Rank | 64 | Captures gradient structure. ~2.5GB state. Defined in `train_router.DEFAULT_RANK`. | | Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. | | Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. | | Batch size | 1 | Single examples, immediate updates. | @@ -220,34 +261,32 @@ a few hundred MB. 
## Components ### Built ✓ -- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256) -- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking) +- `optimizer.py` — Apollo optimizer (configurable rank) +- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ +- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync) - `weight_mapping.py` — vLLM merged → HF separate views (validated) -- `training_example.py` — tokenization with chat template -- `vllm_export_hook.py` — source patch for IPC handle export -- `checkpoint/` — Rust tool for mmap + diff checkpoint sync +- `export_hook.py` — vLLM plugin hook for IPC handle export +- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python) ### To build -- **Dream loop → training bridge**: connect dream output to Apollo +- **Dream loop → training bridge**: connect dream output to /train - **Training-signal agent**: flags moments in conversation logs - **Instruction stripping**: remove scaffolding from training examples - **Quality monitoring**: track model capability over time -- **HF model forward pass integration**: wire into apollo_worker ## Files ``` training/ - DESIGN.md — this document - apollo_mini.py — Apollo optimizer - apollo_worker.py — HTTP training daemon - weight_mapping.py — vLLM ↔ HF weight views - training_example.py — tokenization helpers - export_weights.py — standalone weight export (unused) - vllm_export_hook.py — vLLM source patch for IPC export - start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch) - train.py — standalone training script (alternative) - checkpoint/ - Cargo.toml — Rust checkpoint tool - src/main.rs — mmap + diff sync + DESIGN.md — this document + pyproject.toml — package config, vLLM plugin entry point + apollo_plugin/ + __init__.py — plugin registration + export_hook.py — patches vLLM worker to export IPC handles + train_router.py — /train endpoint, forwards to worker via ZMQ + training_worker.py — training 
subprocess (HF model, Apollo, checkpoint) + optimizer.py — Apollo optimizer + weight_mapping.py — vLLM ↔ HF weight views + checkpoint_sync.py — mmap + diff sync to safetensors + steering.py — steering vector extraction (experimental) ``` diff --git a/training/apollo_plugin/__init__.py b/training/apollo_plugin/__init__.py new file mode 100644 index 0000000..b2e121e --- /dev/null +++ b/training/apollo_plugin/__init__.py @@ -0,0 +1,19 @@ +"""Apollo training plugin for vLLM. + +Enables continuous fine-tuning alongside live inference by: +1. Exporting CUDA IPC handles for weight sharing (export_hook) +2. Adding /train endpoint to vLLM's HTTP server (train_router) +3. Block-level checkpoint sync to safetensors files + +Install: pip install -e /path/to/training +Then vLLM auto-loads via entry point. +""" + +from .export_hook import _patch_model_runner +from .train_router import _patch_api_server + + +def register(): + """Called by vLLM's plugin loader on startup.""" + _patch_model_runner() + _patch_api_server() diff --git a/training/apollo_plugin/checkpoint_sync.py b/training/apollo_plugin/checkpoint_sync.py new file mode 100644 index 0000000..c2d7b2f --- /dev/null +++ b/training/apollo_plugin/checkpoint_sync.py @@ -0,0 +1,503 @@ +"""Sync live GPU weights to safetensors files on disk. + +Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's +merged layout to HuggingFace's separate layout, diffs block-by-block +against on-disk safetensors files, and writes only changed blocks. + +For small behavioral training steps, this turns a 54GB checkpoint +write into a few hundred MB of actual disk I/O. 
import json
import mmap
import struct
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any
import logging

import torch

logger = logging.getLogger(__name__)

DEFAULT_BLOCK_SIZE = 4096  # 4KB blocks — matches filesystem block size
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"


# ---------------------------------------------------------------------------
# vLLM → HuggingFace weight name/shape conversion
# ---------------------------------------------------------------------------
# Qwen3.5-27B dimensions (could be read from config.json for generality)

HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM      # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM    # 6144
INTERMEDIATE = 17408

# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2                 # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM       # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM      # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM      # 1024

# Merged vLLM projection -> ordered (HF projection, row count) pieces.
# A row count of None means "all remaining rows". Entries are tried in
# order and matched as substrings of the vLLM parameter name.
_MERGED_SPLITS = (
    # GDN layer: [key*2 + value*2, hidden] → qkv + z
    ("in_proj_qkvz", (("in_proj_qkv", KEY_DIM * 2 + VALUE_DIM),
                      ("in_proj_z", None))),
    # GDN layer: [num_v_heads*2, hidden] → b + a
    ("in_proj_ba", (("in_proj_b", NUM_V_HEADS),
                    ("in_proj_a", None))),
    # Full attention: [q + k + v, hidden] → separate
    ("qkv_proj", (("q_proj", ATTN_Q_DIM),
                  ("k_proj", ATTN_K_DIM),
                  ("v_proj", None))),
    # MLP: [intermediate*2, hidden] → gate + up
    ("gate_up_proj", (("gate_proj", INTERMEDIATE),
                      ("up_proj", None))),
)


def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
                       ) -> Dict[str, torch.Tensor]:
    """Convert vLLM merged weights to HF-compatible separate tensors.

    vLLM merges certain projections for efficiency:
      - qkv_proj (full attn) → q_proj, k_proj, v_proj
      - in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z
      - in_proj_ba (GDN) → in_proj_b, in_proj_a
      - gate_up_proj (MLP) → gate_proj, up_proj

    Each merged tensor is sliced along dim 0 using the fixed row counts in
    ``_MERGED_SPLITS``. Only the projection name inside the parameter name
    is rewritten, so whatever follows it (``.weight``, ``.bias``, ...) is
    kept as-is. A leading ``language_model.`` prefix is stripped from every
    name; parameters that match no merged projection pass through unchanged.

    Args:
        vllm_params: Mapping of vLLM parameter name to tensor.

    Returns:
        Mapping of HF parameter name to tensor. Split tensors are slices
        (views) of the inputs, so they share memory with them.
    """
    hf_params: Dict[str, torch.Tensor] = {}

    for name, tensor in vllm_params.items():
        # Strip vLLM's 'language_model.' prefix to match HF naming
        hf_name = name.removeprefix('language_model.')

        for merged, pieces in _MERGED_SPLITS:
            if merged not in name:
                continue
            start = 0
            for piece, rows in pieces:
                end = tensor.shape[0] if rows is None else start + rows
                hf_params[hf_name.replace(merged, piece)] = tensor[start:end]
                start = end
            break
        else:
            # Pass through unchanged
            hf_params[hf_name] = tensor

    return hf_params
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
    """Parse safetensors file header.

    The file starts with an 8-byte little-endian unsigned length, followed
    by that many bytes of JSON.

    Returns (data_start_offset, header_dict).
    Header dict maps tensor names to metadata including 'data_offsets',
    which are relative to data_start_offset.
    """
    header_size = int.from_bytes(data[:8], "little")
    data_start = 8 + header_size
    header = json.loads(bytes(data[8:data_start]))
    return data_start, header


def sync_tensor_to_mmap(
    mm: mmap.mmap,
    name: str,
    tensor: torch.Tensor,
    data_start: int,
    offsets: List[int],
    block_size: int,
) -> Tuple[int, int]:
    """Sync a single tensor to mmap'd file using block-level diffing.

    Args:
        mm: Writable mapping of the safetensors file.
        name: Tensor name, used only in log messages.
        tensor: Live tensor whose raw bytes should end up on disk.
        data_start: Offset of the data section (from the parsed header).
        offsets: The tensor's ``[begin, end]`` 'data_offsets' entry.
        block_size: Diff granularity in bytes; must be positive.

    Returns (bytes_compared, bytes_changed). Returns (0, 0), after logging
    a warning, when the tensor cannot be copied to the CPU or its byte
    length differs from the on-disk slot.

    Raises:
        ValueError: If block_size is not positive.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be positive, got {block_size}")

    start = data_start + offsets[0]
    end = data_start + offsets[1]
    disk_len = end - start

    # Copy the tensor to the CPU and expose its raw bytes. It is
    # reinterpreted as uint8 before leaving torch because numpy has no
    # bfloat16 dtype and .numpy() raises for such tensors. reshape(-1)
    # first so 0-dim tensors can be reinterpreted too.
    try:
        live_bytes = memoryview(
            tensor.detach().contiguous().cpu()
            .reshape(-1).view(torch.uint8).numpy()
        )
    except Exception as e:
        logger.warning(f"Failed to transfer {name} to CPU: {e}")
        return 0, 0

    if len(live_bytes) != disk_len:
        logger.warning(
            f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
        )
        return 0, 0

    # Block-level diff: compare and write only changed blocks
    compared = 0
    changed = 0

    for offset in range(0, disk_len, block_size):
        block_end = min(offset + block_size, disk_len)
        block_len = block_end - offset

        live_block = live_bytes[offset:block_end]
        compared += block_len

        if mm[start + offset:start + block_end] != live_block:
            mm[start + offset:start + block_end] = live_block
            changed += block_len

    return compared, changed
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
    """Load vLLM weight tensors from CUDA IPC handles.

    The handles file is written by the export hook (export_hook.py,
    formerly vllm_export_hook.py) on vLLM startup. Each entry's 'handle'
    is a ``(func, args)`` pair; calling ``func(*args)`` rebuilds the
    tensor. The '__metadata__' entry is dropped.

    Entries whose reconstruction raises are logged and left out, so the
    result may hold fewer tensors than the file.

    Args:
        handles_path: Path to the handles file saved with torch.save.

    Returns:
        Mapping of vLLM parameter name to the reconstructed tensor.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects, and the
    # default handles path is under /tmp. Anyone able to write that file can
    # run code in this process — confirm it is only writable by the vLLM user.
    handles = torch.load(handles_path, weights_only=False)

    # Skip metadata entry
    handles.pop('__metadata__', None)

    weights: Dict[str, torch.Tensor] = {}
    for name, info in handles.items():
        # presumably the (rebuild_fn, args) pair from reduce_tensor — confirm
        # against the export hook
        func, args = info['handle']
        try:
            weights[name] = func(*args)
        except Exception as e:
            logger.warning(f"Failed to reconstruct {name}: {e}")

    return weights
+ + Args: + model_dir: Directory containing safetensors files + handles_path: Path to vLLM weight IPC handles file + block_size: Block size for diffing (default 4KB) + + Returns: + Dict with sync statistics: + - total_compared: bytes compared + - total_changed: bytes actually written + - files_changed: list of modified filenames + - tensors_synced: number of tensors processed + - tensors_missing: tensors not found in safetensors + """ + model_dir = Path(model_dir) + + if not Path(handles_path).exists(): + raise FileNotFoundError( + f"Weight handles not found: {handles_path}. " + "Is vLLM running with the export hook?" + ) + + # Step 1: Load live weights from GPU via IPC + logger.info("Loading live weights from GPU...") + vllm_weights = load_vllm_weights(handles_path) + logger.info(f" Loaded {len(vllm_weights)} vLLM tensors") + + # Step 2: Convert to HF naming/layout + hf_weights = vllm_to_hf_tensors(vllm_weights) + logger.info(f" Converted to {len(hf_weights)} HF tensors") + + # Step 3: Map tensors to safetensors files + weight_map = read_safetensors_index(model_dir) + + by_file: Dict[str, Dict[str, torch.Tensor]] = {} + unmapped = [] + + for name, tensor in hf_weights.items(): + filename = weight_map.get(name) + if filename is None: + # Single-file model or missing from index + if (model_dir / "model.safetensors").exists(): + filename = "model.safetensors" + else: + unmapped.append(name) + continue + by_file.setdefault(filename, {})[name] = tensor + + if unmapped: + logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...") + + # Step 4: Sync each file + total_compared = 0 + total_changed = 0 + total_found = 0 + total_missing = 0 + files_changed = [] + + for filename in sorted(by_file.keys()): + tensors = by_file[filename] + file_path = model_dir / filename + + if not file_path.exists(): + logger.warning(f" File not found: {filename}") + total_missing += len(tensors) + continue + + compared, changed, found, missing = sync_file(file_path, tensors, 
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
    """Print diagnostic info about weight name mappings.

    Useful for debugging mismatches between vLLM and safetensors names.

    Loads the live vLLM tensors, converts their names to the HF layout and
    compares that set with the names on disk. Disk names come from the
    shard index; if there is no index they are read from the header of
    ``model.safetensors``. Prints counts and up to five sample names from
    each one-sided difference.

    Args:
        model_dir: Directory containing the safetensors files.
        handles_path: Path to the vLLM weight IPC handles file.
    """
    model_dir = Path(model_dir)

    # Load and convert vLLM weights
    vllm_weights = load_vllm_weights(handles_path)
    hf_names = set(vllm_to_hf_tensors(vllm_weights))

    # Read safetensors index
    disk_names = set(read_safetensors_index(model_dir))

    # If single-file model, parse that file's header
    if not disk_names:
        st_path = model_dir / "model.safetensors"
        if st_path.exists():
            # Context managers release the buffer view and then the mapping,
            # in that order, even if header parsing raises.
            with open(st_path, 'rb') as f, \
                    mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm, \
                    memoryview(mm) as view:
                _, header = parse_safetensors_header(view)
            disk_names = {k for k in header if k != "__metadata__"}

    print(f"vLLM tensors (raw): {len(vllm_weights)}")
    print(f"HF tensors (converted): {len(hf_names)}")
    print(f"Disk tensors: {len(disk_names)}")
    print()

    in_both = hf_names & disk_names
    only_hf = hf_names - disk_names
    only_disk = disk_names - hf_names

    print(f"Matched: {len(in_both)}")
    print(f"Only in HF (won't sync): {len(only_hf)}")
    print(f"Only on disk (not updated): {len(only_disk)}")

    if only_hf:
        print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
    if only_disk:
        print(f"\nSample disk-only: {sorted(only_disk)[:5]}")
"--block-size", type=int, default=DEFAULT_BLOCK_SIZE, + help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})" + ) + sync_parser.add_argument( + "-v", "--verbose", action="store_true", + help="Verbose output" + ) + + # diagnose command + diag_parser = subparsers.add_parser("diagnose", help="Check name mappings") + diag_parser.add_argument( + "--model-dir", required=True, + help="Directory containing safetensors files" + ) + diag_parser.add_argument( + "--handles", default=DEFAULT_HANDLES_PATH, + help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" + ) + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + sys.exit(1) + + logging.basicConfig( + level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO, + format='%(message)s' + ) + + try: + if args.command == "sync": + result = checkpoint_sync(args.model_dir, args.handles, args.block_size) + print(json.dumps(result, indent=2)) + elif args.command == "diagnose": + diagnose(args.model_dir, args.handles) + except FileNotFoundError as e: + logger.error(str(e)) + sys.exit(1) + except Exception as e: + logger.exception(f"Failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/training/vllm_export_hook.py b/training/apollo_plugin/export_hook.py similarity index 76% rename from training/vllm_export_hook.py rename to training/apollo_plugin/export_hook.py index 6a0bf1e..e0ff6fc 100644 --- a/training/vllm_export_hook.py +++ b/training/apollo_plugin/export_hook.py @@ -1,17 +1,12 @@ """Monkey-patch vLLM to export weight IPC handles on startup. -Usage — add to start_vllm.sh BEFORE the vllm serve command: +Usage — install the apollo_plugin package: - export VLLM_PLUGINS=vllm_export_hook - vllm serve Qwen/Qwen3.5-27B ... + pip install -e /path/to/training -Or use Python to launch vLLM with the hook: +Then vLLM auto-discovers and loads via entry point. 
Or filter: - python3 -c " - import vllm_export_hook # installs the patch - from vllm.entrypoints.openai.api_server import run_server - run_server(...) - " + VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ... The hook patches vLLM's model runner to export IPC handles after model loading completes. The handles are saved to a file that the @@ -25,7 +20,7 @@ from pathlib import Path HANDLE_PATH = "/tmp/vllm_weight_handles.pt" -def export_model_weights(model): +def export_model_weights(model, model_path: str | None = None): """Export CUDA IPC handles for all model parameters.""" from torch.multiprocessing.reductions import reduce_tensor @@ -43,6 +38,12 @@ def export_model_weights(model): } total_bytes += param.nelement() * param.element_size() + # Include metadata for training worker + handles['__metadata__'] = { + 'model_path': model_path, + 'num_params': len(handles), + } + torch.save(handles, HANDLE_PATH) print(f"[apollo] Exported {len(handles)} weight handles " f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}") @@ -63,14 +64,11 @@ def _patch_model_runner(): def patched_load(self, *args, **kwargs): result = original_load(self, *args, **kwargs) try: - export_model_weights(self.model_runner.model) + model_path = self.vllm_config.model_config.model + export_model_weights(self.model_runner.model, model_path) except Exception as e: print(f"[apollo] Failed to export weights: {e}") return result gpu_worker.Worker.load_model = patched_load print("[apollo] Weight export hook installed") - - -# Auto-install when imported -_patch_model_runner() diff --git a/training/apollo_mini.py b/training/apollo_plugin/optimizer.py similarity index 97% rename from training/apollo_mini.py rename to training/apollo_plugin/optimizer.py index 166ae3a..9abce94 100644 --- a/training/apollo_mini.py +++ b/training/apollo_plugin/optimizer.py @@ -8,9 +8,9 @@ Channel-wise or tensor-wise scaling is sufficient. 
Apollo approximates these scaling factors using a low-rank auxiliary optimizer state based on pure random projection. -Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25% -compute overhead vs forward+backward. Captures gradient structure -across 100+ behavioral training examples per batch. +Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead +vs forward+backward. Sufficient for behavioral training with diverse +examples. Key implementation details from the paper: - Gradient scale factor α = √(n/r) compensates for projection ratio @@ -34,7 +34,7 @@ class Apollo(Optimizer): Args: params: model parameters lr: learning rate (default: 1e-4) - rank: projection rank (default: 256) + rank: projection rank (default: 64) betas: Adam momentum coefficients (default: (0.9, 0.999)) eps: numerical stability term (default: 1e-8) weight_decay: decoupled weight decay (default: 0.01) @@ -46,7 +46,7 @@ class Apollo(Optimizer): Set to None to disable. """ - def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), + def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, warmup_steps=0, scale=None, proj_refresh=200, norm_growth_limit=1.01): defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, diff --git a/training/extract_steering_vector.py b/training/apollo_plugin/steering.py similarity index 100% rename from training/extract_steering_vector.py rename to training/apollo_plugin/steering.py diff --git a/training/apollo_plugin/train_router.py b/training/apollo_plugin/train_router.py new file mode 100644 index 0000000..d6f90b4 --- /dev/null +++ b/training/apollo_plugin/train_router.py @@ -0,0 +1,240 @@ +"""Training endpoint for vLLM - forwards to training subprocess via ZMQ. + +Patches vLLM's build_app() to add /train route. The actual training runs +in a dedicated subprocess (training_worker.py) to avoid blocking the +event loop and to keep training work isolated from vLLM internals. 
+""" + +import asyncio +import logging +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Any + +import zmq +import zmq.asyncio + +from fastapi import APIRouter, FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter() + +DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" + +# Global state for subprocess management +_worker_process: subprocess.Popen | None = None +_zmq_context: zmq.asyncio.Context | None = None +_zmq_socket: zmq.asyncio.Socket | None = None +_initialized: bool = False + + +class TrainRequest(BaseModel): + training_data: dict[str, Any] # {"samples": [...], "config": {...}} + + +class TrainResponse(BaseModel): + job_id: str + status: str + training_samples: int + loss_history: list[float] + + +def _start_worker_subprocess(): + """Start the training worker subprocess.""" + global _worker_process + + if _worker_process is not None and _worker_process.poll() is None: + return # Still running + + # Start worker as subprocess using script path + worker_script = Path(__file__).parent / 'training_worker.py' + _worker_process = subprocess.Popen( + [sys.executable, str(worker_script)], + env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR}, + ) + logger.info(f"Started training worker subprocess (pid={_worker_process.pid})") + + # Give it a moment to bind the socket + import time + time.sleep(0.5) + + +def _ensure_initialized(): + """Ensure subprocess is running and ZMQ socket is connected.""" + global _zmq_context, _zmq_socket, _initialized + + if _initialized: + return + + # Start worker if needed + _start_worker_subprocess() + + # Create async ZMQ context and socket + _zmq_context = zmq.asyncio.Context() + _zmq_socket = _zmq_context.socket(zmq.REQ) + _zmq_socket.connect(DEFAULT_ZMQ_ADDR) + + # Set timeout for recv + _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training + + 
_initialized = True + logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}") + + +async def _send_request(request: dict[str, Any]) -> dict[str, Any]: + """Send request to worker and wait for response.""" + _ensure_initialized() + + # ZMQ async send/recv + await _zmq_socket.send_json(request) + response = await _zmq_socket.recv_json() + return response + + +@router.post("/train") +async def handle_train(request: TrainRequest): + """Handle training request - forwards to training subprocess.""" + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={"error": f"Training not available: {e}"}, + status_code=503, + ) + + try: + training_data = request.training_data + samples = training_data.get("samples", []) + config = training_data.get("config", {}) + + if not samples: + return JSONResponse( + content={"error": "No training samples provided"}, + status_code=400, + ) + + job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + logger.info(f"Starting training job {job_id} with {len(samples)} samples") + + # Forward to worker + response = await _send_request({ + 'type': 'train', + 'samples': samples, + 'config': config, + }) + + if 'error' in response: + return JSONResponse( + content={"error": response['error']}, + status_code=500, + ) + + logger.info( + f"Training job {job_id} completed, " + f"final loss: {response['loss_history'][-1]:.4f}" + ) + + return JSONResponse(content={ + "job_id": job_id, + "status": response['status'], + "training_samples": response['training_samples'], + "loss_history": response['loss_history'], + }) + + except zmq.Again: + logger.error("Training request timed out") + return JSONResponse( + content={"error": "Training request timed out"}, + status_code=504, + ) + except Exception as e: + logger.exception(f"Training failed: {e}") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + +@router.post("/checkpoint") +async def handle_checkpoint(): + """Trigger checkpoint sync 
to disk.""" + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={"error": f"Training not available: {e}"}, + status_code=503, + ) + + try: + response = await _send_request({'type': 'checkpoint'}) + + if 'error' in response: + return JSONResponse( + content={"error": response['error']}, + status_code=500, + ) + + return JSONResponse(content=response) + + except Exception as e: + logger.exception(f"Checkpoint failed: {e}") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + +@router.get("/train/status") +async def handle_status(): + """Get training worker status.""" + try: + _ensure_initialized() + except Exception as e: + return JSONResponse( + content={ + "status": "unavailable", + "error": str(e), + }, + status_code=503, + ) + + try: + response = await _send_request({'type': 'status'}) + return JSONResponse(content=response) + + except Exception as e: + return JSONResponse( + content={ + "status": "error", + "error": str(e), + }, + status_code=500, + ) + + +def attach_router(app: FastAPI): + """Attach training router to FastAPI app.""" + app.include_router(router) + logger.info("Training router attached") + + +def _patch_api_server(): + """Patch vLLM's build_app to include our training router.""" + from vllm.entrypoints.openai import api_server + + original_build_app = api_server.build_app + + def patched_build_app(*args, **kwargs): + app = original_build_app(*args, **kwargs) + attach_router(app) + return app + + api_server.build_app = patched_build_app + logger.info("API server patched for /train endpoint") diff --git a/training/apollo_plugin/training_worker.py b/training/apollo_plugin/training_worker.py new file mode 100644 index 0000000..f8b8c23 --- /dev/null +++ b/training/apollo_plugin/training_worker.py @@ -0,0 +1,323 @@ +"""Training subprocess - handles Apollo training and checkpoint sync. + +Long-lived process that: +1. Loads IPC handles from vLLM's exported weights +2. 
Creates HF model with views into vLLM's GPU memory +3. Handles training requests via ZMQ +4. Handles checkpoint sync requests +5. Persists Apollo optimizer state between calls + +Communicates with the API server's /train endpoint via ZMQ REP socket. +""" + +import logging +import os +import signal +import sys +from pathlib import Path +from typing import Any + +# Handle running as script vs module +if __name__ == '__main__' and __package__ is None: + # Running as script - add parent to path for imports + sys.path.insert(0, str(Path(__file__).parent.parent)) + __package__ = 'apollo_plugin' + +import torch +import torch.nn as nn +import zmq + +from .checkpoint_sync import checkpoint_sync +from .optimizer import Apollo +from .weight_mapping import load_hf_model_with_vllm_weights + +logger = logging.getLogger(__name__) + +DEFAULT_RANK = 64 +DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" +HANDLE_PATH = "/tmp/vllm_weight_handles.pt" +OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt" + + +class TrainingWorker: + """Long-lived training worker process.""" + + def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR): + self.zmq_addr = zmq_addr + self.model: nn.Module | None = None + self.optimizer: Apollo | None = None + self.model_path: str | None = None + self._running = True + + def _create_model_wrapper(self) -> nn.Module: + """Create HF model wrapper with views into vLLM's GPU memory.""" + if not os.path.exists(HANDLE_PATH): + raise FileNotFoundError( + f"Weight handles not found: {HANDLE_PATH}. " + "Is vLLM running with the export hook?" 
+ ) + + handles = torch.load(HANDLE_PATH, weights_only=False) + + # Extract metadata + metadata = handles.pop('__metadata__', {}) + self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH') + if not self.model_path: + raise ValueError( + "Model path not found in handles metadata or APOLLO_MODEL_PATH env var" + ) + + # Reconstruct tensors from IPC handles + vllm_params = {} + for name, info in handles.items(): + func, args = info['handle'] + vllm_params[name] = func(*args) + + model = load_hf_model_with_vllm_weights(vllm_params, self.model_path) + model.train() + return model + + def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo: + """Get existing optimizer or create new one.""" + if self.optimizer is not None: + return self.optimizer + + # Build parameter groups (Apollo for 2D+, standard Adam for small/1D) + apollo_params, standard_params = [], [] + for p in self.model.parameters(): + if p.requires_grad: + if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK: + apollo_params.append(p) + else: + standard_params.append(p) + + groups = [] + if apollo_params: + groups.append({'params': apollo_params}) + if standard_params: + groups.append({'params': standard_params}) + + if not groups: + raise ValueError("No trainable parameters found") + + self.optimizer = Apollo( + groups, + lr=config.get('lr', 1e-5), + rank=config.get('rank', DEFAULT_RANK), + betas=tuple(config.get('betas', (0.9, 0.999))), + eps=config.get('eps', 1e-8), + weight_decay=config.get('weight_decay', 0.01), + warmup_steps=config.get('warmup_steps', 0), + scale=config.get('scale'), + proj_refresh=config.get('proj_refresh', 200), + norm_growth_limit=config.get('norm_growth_limit', 1.01), + ) + + # Restore state if exists + if os.path.exists(OPTIMIZER_STATE_PATH): + try: + state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False) + self.optimizer.load_state_dict(state) + logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}") + except Exception as e: + 
logger.warning(f"Could not restore optimizer state: {e}") + + logger.info( + f"Optimizer: {len(apollo_params)} apollo params, " + f"{len(standard_params)} standard, " + f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB" + ) + + return self.optimizer + + def _save_optimizer_state(self): + """Save optimizer state for persistence.""" + if self.optimizer is not None: + torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH) + logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}") + + def _run_training( + self, + samples: list[dict[str, Any]], + config: dict[str, Any], + ) -> list[float]: + """Run Apollo training on the given samples.""" + optimizer = self._get_or_create_optimizer(config) + + loss_history = [] + + for i, sample in enumerate(samples): + ctx_ids = sample['context_ids'] + cont_ids = sample['continuation_ids'] + all_ids = ctx_ids + cont_ids + context_len = len(ctx_ids) + + input_ids = torch.tensor([all_ids], device='cuda:0') + + optimizer.zero_grad() + + # Context-frozen forward pass + with torch.no_grad(): + outputs = self.model(input_ids[:, :context_len], use_cache=True) + past_kv = outputs.past_key_values + + # Decision tokens with gradients + with torch.enable_grad(): + outputs = self.model( + input_ids[:, context_len:], + past_key_values=past_kv, + use_cache=False, + ) + logits = outputs.logits + + # Shift: predict next token from each position + shift_logits = logits[:, :-1].contiguous() + shift_labels = input_ids[:, context_len + 1:].contiguous() + + loss = nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + loss.backward() + optimizer.step() + + loss_val = loss.item() + loss_history.append(loss_val) + logger.info( + f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} " + f"(ctx={context_len}, cont={len(cont_ids)} tokens)" + ) + + return loss_history + + def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a training request.""" + samples = 
request.get('samples', []) + config = request.get('config', {}) + + if not samples: + return {'error': 'No training samples provided'} + + try: + loss_history = self._run_training(samples, config) + return { + 'status': 'completed', + 'training_samples': len(samples), + 'loss_history': loss_history, + } + except Exception as e: + logger.exception(f"Training failed: {e}") + return {'error': str(e)} + + def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a checkpoint sync request.""" + if not self.model_path: + return {'error': 'Model path not set'} + + try: + self._save_optimizer_state() + result = checkpoint_sync(self.model_path) + return { + 'status': 'completed', + 'total_changed': result['total_changed'], + 'files_changed': result['files_changed'], + } + except Exception as e: + logger.exception(f"Checkpoint sync failed: {e}") + return {'error': str(e)} + + def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]: + """Handle a status request.""" + return { + 'status': 'ready', + 'model_loaded': self.model is not None, + 'optimizer_loaded': self.optimizer is not None, + 'model_path': self.model_path, + 'optimizer_state_mb': ( + self.optimizer.state_size_bytes() / 1e6 + if self.optimizer else 0 + ), + } + + def run(self): + """Main loop - listen for requests and handle them.""" + # Set up signal handlers + def handle_signal(signum, frame): + logger.info(f"Received signal {signum}, shutting down...") + self._running = False + + signal.signal(signal.SIGTERM, handle_signal) + signal.signal(signal.SIGINT, handle_signal) + + # Set up ZMQ socket first so API server can connect + context = zmq.Context() + socket = context.socket(zmq.REP) + socket.bind(self.zmq_addr) + logger.info(f"Training worker listening on {self.zmq_addr}") + + # Create HF model wrapper with views into vLLM's GPU memory + logger.info("Connecting to vLLM weights via IPC handles...") + try: + self.model = self._create_model_wrapper() + logger.info("HF model 
wrapper ready (views into vLLM GPU memory)") + except Exception as e: + logger.error(f"Failed to connect to vLLM weights: {e}") + logger.info("Will retry on first training request") + + # Set socket timeout so we can check _running flag + socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout + + while self._running: + try: + message = socket.recv_json() + except zmq.Again: + # Timeout, check _running and continue + continue + + request_type = message.get('type', 'train') + logger.info(f"Received {request_type} request") + + # Ensure model is loaded + if self.model is None and request_type != 'status': + try: + self.model = self._create_model_wrapper() + except Exception as e: + socket.send_json({'error': f'Model not loaded: {e}'}) + continue + + # Dispatch request + if request_type == 'train': + response = self._handle_train(message) + elif request_type == 'checkpoint': + response = self._handle_checkpoint(message) + elif request_type == 'status': + response = self._handle_status(message) + else: + response = {'error': f'Unknown request type: {request_type}'} + + socket.send_json(response) + + # Cleanup + logger.info("Saving optimizer state before shutdown...") + self._save_optimizer_state() + socket.close() + context.term() + logger.info("Training worker shut down") + + +def main(): + """Entry point for running as a subprocess.""" + logging.basicConfig( + level=logging.INFO, + format='[apollo-worker] %(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S', + ) + + zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR) + worker = TrainingWorker(zmq_addr) + worker.run() + + +if __name__ == '__main__': + main() diff --git a/training/weight_mapping.py b/training/apollo_plugin/weight_mapping.py similarity index 100% rename from training/weight_mapping.py rename to training/apollo_plugin/weight_mapping.py diff --git a/training/apollo_worker.py b/training/apollo_worker.py deleted file mode 100755 index d46fb55..0000000 --- a/training/apollo_worker.py +++ 
/dev/null @@ -1,454 +0,0 @@ -#!/usr/bin/env python3 -""" -Apollo Mini Training Daemon - -This daemon: -1. Listens over HTTPS for training requests from poc-agent -2. Pauses vLLM inference -3. Runs APOLLO-Mini training with torch.enable_grad() -4. Saves checkpoints and training metadata -5. Resumes vLLM inference - -Communication protocol: -- POST /train: Start a training job -- GET /status/{job_id}: Check training status -- GET /checkpoints: List available checkpoints -""" - -import asyncio -import json -import logging -import os -import sys -import time -from dataclasses import dataclass, field, asdict -from datetime import datetime -from pathlib import Path -from typing import Optional, Dict, Any, List -from enum import Enum - -import torch -import torch.nn as nn -from aiohttp import web - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('apollo_worker') - -class TrainingStatus(Enum): - PENDING = "pending" - PAUSING_VLLM = "pausing_vllm" - TRAINING = "training" - SAVING_CHECKPOINT = "saving_checkpoint" - RESUMING_VLLM = "resuming_vllm" - COMPLETED = "completed" - FAILED = "failed" - -@dataclass -class TrainingJob: - job_id: str - status: TrainingStatus - created_at: datetime - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - model_path: Optional[str] = None - checkpoint_path: Optional[str] = None - training_samples: int = 0 - loss_history: List[float] = field(default_factory=list) - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return { - 'job_id': self.job_id, - 'status': self.status.value, - 'created_at': self.created_at.isoformat(), - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'model_path': self.model_path, - 'checkpoint_path': self.checkpoint_path, - 'training_samples': 
self.training_samples, - 'loss_history': self.loss_history, - 'error': self.error, - } - -class ApolloWorker: - def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"): - self.config = self._load_config(config_path) - self.jobs: Dict[str, TrainingJob] = {} - self.vllm_paused = False - self.app = web.Application() - self._setup_routes() - - def _load_config(self, config_path: str) -> Dict[str, Any]: - """Load configuration from file or use defaults.""" - default_config = { - 'host': '0.0.0.0', - 'port': 8080, - 'vllm_socket': '/tmp/vllm_control.sock', - 'model_path': '/home/ubuntu/models/Qwen3.5-27B', - 'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints', - 'max_training_samples': 100, - 'learning_rate': 1e-5, - 'batch_size': 1, - } - - if os.path.exists(config_path): - with open(config_path, 'r') as f: - user_config = json.load(f) - default_config.update(user_config) - - Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True) - return default_config - - def _setup_routes(self): - """Setup HTTP routes.""" - self.app.router.add_post('/train', self.handle_train_request) - self.app.router.add_get('/status/{job_id}', self.handle_status_request) - self.app.router.add_get('/checkpoints', self.handle_list_checkpoints) - self.app.router.add_get('/health', self.handle_health_check) - - async def handle_health_check(self, request: web.Request) -> web.Response: - """Health check endpoint.""" - return web.json_response({ - 'status': 'healthy', - 'vllm_paused': self.vllm_paused, - 'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]]) - }) - - async def handle_train_request(self, request: web.Request) -> web.Response: - """Handle training request from poc-agent.""" - try: - data = await request.json() - - # Validate required fields - if 'training_data' not in data: - return web.json_response( - {'error': 'Missing 
training_data field'}, - status=400 - ) - - job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}" - job = TrainingJob( - job_id=job_id, - status=TrainingStatus.PENDING, - created_at=datetime.now(), - model_path=self.config['model_path'] - ) - self.jobs[job_id] = job - - # Start training in background - asyncio.create_task(self.execute_training(job, data)) - - return web.json_response({ - 'job_id': job_id, - 'status': 'accepted', - 'message': 'Training job started' - }) - - except Exception as e: - logger.error(f"Error handling train request: {e}") - return web.json_response( - {'error': str(e)}, - status=500 - ) - - async def handle_status_request(self, request: web.Request) -> web.Response: - """Get training job status.""" - job_id = request.match_info['job_id'] - - if job_id not in self.jobs: - return web.json_response( - {'error': 'Job not found'}, - status=404 - ) - - job = self.jobs[job_id] - return web.json_response(job.to_dict()) - - async def handle_list_checkpoints(self, request: web.Request) -> web.Response: - """List available checkpoints.""" - checkpoint_dir = Path(self.config['checkpoint_dir']) - checkpoints = [] - - if checkpoint_dir.exists(): - for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True): - checkpoints.append({ - 'filename': checkpoint_file.name, - 'path': str(checkpoint_file), - 'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(), - 'size': checkpoint_file.stat().st_size - }) - - return web.json_response({'checkpoints': checkpoints}) - - async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]): - """Execute the training pipeline.""" - try: - logger.info(f"Starting training job {job.job_id}") - job.started_at = datetime.now() - - # Step 1: Pause vLLM - job.status = TrainingStatus.PAUSING_VLLM - logger.info("Pausing vLLM...") - await self.pause_vllm() - self.vllm_paused = True - - # Step 2: Load model and 
prepare for training - job.status = TrainingStatus.TRAINING - logger.info("Loading model and preparing for training...") - - # Load model (this would be the actual Qwen3.5-27B model) - # For now, we'll use a placeholder - model = await self.load_model_for_training() - - # Step 3: Run APOLLO-Mini training - logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples") - - # Extract training samples - samples = training_data['samples'] - job.training_samples = len(samples) - - # Run training loop - loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {})) - job.loss_history = loss_history - - # Step 4: Save checkpoint - job.status = TrainingStatus.SAVING_CHECKPOINT - logger.info("Saving checkpoint...") - checkpoint_path = await self.save_checkpoint(model, job) - job.checkpoint_path = checkpoint_path - - # Step 5: Resume vLLM - job.status = TrainingStatus.RESUMING_VLLM - logger.info("Resuming vLLM...") - await self.resume_vllm() - self.vllm_paused = False - - # Mark job as completed - job.status = TrainingStatus.COMPLETED - job.completed_at = datetime.now() - - logger.info(f"Training job {job.job_id} completed successfully") - - except Exception as e: - logger.error(f"Training job {job.job_id} failed: {e}") - job.status = TrainingStatus.FAILED - job.error = str(e) - job.completed_at = datetime.now() - - # Try to resume vLLM if it was paused - if self.vllm_paused: - try: - await self.resume_vllm() - self.vllm_paused = False - except Exception as resume_error: - logger.error(f"Failed to resume vLLM after training error: {resume_error}") - - async def pause_vllm(self): - """Pause vLLM inference via HTTP API.""" - import aiohttp as aio - url = self.config.get('vllm_url', 'http://localhost:8000') - try: - async with aio.ClientSession() as session: - async with session.post( - f"{url}/pause_generation", - json={"mode": "keep", "clear_cache": False}, - timeout=aio.ClientTimeout(total=10), - ) as resp: - 
resp.raise_for_status() - logger.info("vLLM paused") - except Exception as e: - logger.warning(f"Failed to pause vLLM: {e}") - - async def resume_vllm(self): - """Resume vLLM inference via HTTP API.""" - import aiohttp as aio - url = self.config.get('vllm_url', 'http://localhost:8000') - try: - async with aio.ClientSession() as session: - async with session.post( - f"{url}/resume_generation", - timeout=aio.ClientTimeout(total=10), - ) as resp: - resp.raise_for_status() - logger.info("vLLM resumed") - except Exception as e: - logger.warning(f"Failed to resume vLLM: {e}") - - async def load_model_for_training(self) -> nn.Module: - """Load HF model with weights pointing to vLLM's GPU memory. - - Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible - views (narrowing merged weights into separate q/k/v/z etc.), and - constructs the HF model around those views. No weight copying — - all parameters share vLLM's GPU memory. - """ - handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt') - model_path = self.config['model_path'] - - # Import vLLM weights via CUDA IPC - logger.info(f"Importing vLLM weights from {handle_path}") - handles = torch.load(handle_path, weights_only=False) - vllm_params = {} - for name, info in handles.items(): - func, args = info['handle'] - vllm_params[name] = func(*args) - logger.info(f"Imported {len(vllm_params)} parameters") - - # Map vLLM merged layout → HF separate layout (views, no copies) - from weight_mapping import load_hf_model_with_vllm_weights - model = load_hf_model_with_vllm_weights(vllm_params, model_path) - logger.info("HF model constructed with vLLM weight views") - - return model - - async def run_apollo_training(self, model: nn.Module, - samples: List[Dict[str, str]], - config: Dict[str, Any]) -> List[float]: - """Run Apollo-Mini training on conversation decision points.""" - from apollo_mini import Apollo - from transformers import AutoTokenizer - - lr = config.get('learning_rate', 
self.config['learning_rate']) - tokenizer = AutoTokenizer.from_pretrained( - self.config['model_path'], trust_remote_code=True) - - # Build parameter groups (Apollo for 2D+, standard for small/1D) - apollo_params, standard_params = [], [] - for p in model.parameters(): - if p.requires_grad: - if p.ndim >= 2 and min(p.shape) >= 2: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({'params': apollo_params}) - if standard_params: - groups.append({'params': standard_params}) - - rank = config.get('apollo_rank', 1) - optimizer = Apollo(groups, lr=lr, rank=rank) - logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, " - f"{len(standard_params)} standard, " - f"state={optimizer.state_size_bytes()/1e6:.1f}MB") - - loss_history = [] - - for i, sample in enumerate(samples): - context = sample.get('context', '') - continuation = sample.get('continuation', '') - - # Tokenize - ctx_ids = tokenizer.encode(context, add_special_tokens=True) - cont_ids = tokenizer.encode(continuation, add_special_tokens=False) - all_ids = ctx_ids + cont_ids - context_len = len(ctx_ids) - - input_ids = torch.tensor([all_ids], device='cuda:0') - - optimizer.zero_grad() - - # Context-frozen forward pass - with torch.no_grad(): - # Forward through context (no gradients) - outputs = model(input_ids[:, :context_len], use_cache=True) - past_kv = outputs.past_key_values - - # Decision tokens with gradients - with torch.enable_grad(): - outputs = model( - input_ids[:, context_len:], - past_key_values=past_kv, - use_cache=False, - ) - logits = outputs.logits # [1, cont_len, vocab] - - # Shift: predict next token from each position - shift_logits = logits[:, :-1].contiguous() - shift_labels = input_ids[:, context_len + 1:].contiguous() - - loss = nn.functional.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - ) - - loss.backward() - optimizer.step() - - loss_val = loss.item() - 
loss_history.append(loss_val) - logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} " - f"(ctx={context_len}, cont={len(cont_ids)} tokens)") - - logger.info(f"Training done: {len(samples)} examples, " - f"final loss={loss_history[-1]:.4f}") - return loss_history - - async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str: - """Save model checkpoint in HuggingFace safetensors format.""" - from safetensors.torch import save_file - import shutil - - checkpoint_dir = Path(self.config['checkpoint_dir']) - date_str = datetime.now().strftime('%Y-%m-%d') - out_dir = checkpoint_dir / date_str - out_dir.mkdir(parents=True, exist_ok=True) - - # Save weights - tensors = {name: p.data.contiguous().cpu() - for name, p in model.named_parameters()} - save_path = out_dir / "model.safetensors" - save_file(tensors, str(save_path)) - - # Copy config files - config_dir = Path(self.config['model_path']) - for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json', - 'special_tokens_map.json']: - src = config_dir / f - if src.exists(): - shutil.copy2(src, out_dir / f) - - # Save training metadata - meta = { - 'job_id': job.job_id, - 'training_samples': job.training_samples, - 'loss_history': job.loss_history, - 'timestamp': datetime.now().isoformat(), - } - with open(out_dir / 'training-meta.json', 'w') as f: - json.dump(meta, f, indent=2) - - # Update latest symlink - latest = checkpoint_dir / 'latest' - if latest.is_symlink(): - latest.unlink() - latest.symlink_to(date_str) - - size_gb = save_path.stat().st_size / 1e9 - logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)") - return str(out_dir) - - async def run(self): - """Run the daemon.""" - logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}") - runner = web.AppRunner(self.app) - await runner.setup() - site = web.TCPSite(runner, self.config['host'], self.config['port']) - await site.start() - logger.info("Apollo Worker is running") - - # Keep running - while 
True: - await asyncio.sleep(3600) # Sleep for an hour - -def main(): - worker = ApolloWorker() - asyncio.run(worker.run()) - -if __name__ == '__main__': - main() diff --git a/training/checkpoint/Cargo.toml b/training/checkpoint/Cargo.toml deleted file mode 100644 index 45e511a..0000000 --- a/training/checkpoint/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "apollo-checkpoint" -version = "0.1.0" -edition = "2024" - -[dependencies] -memmap2 = "0.9" -safetensors = "0.5" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -anyhow = "1" -clap = { version = "4", features = ["derive"] } diff --git a/training/checkpoint/src/main.rs b/training/checkpoint/src/main.rs deleted file mode 100644 index 1ebd0df..0000000 --- a/training/checkpoint/src/main.rs +++ /dev/null @@ -1,265 +0,0 @@ -// apollo-checkpoint — Sync live GPU weights back to model files on disk. -// -// mmaps the model's safetensors files, reads live weights from GPU via -// Python helper (CUDA IPC handles), compares block by block, and memcpys -// only changed regions back into the mmap. For small behavioral training -// steps, this turns a 54GB write into a few hundred MB. -// -// The model files on disk are the checkpoint. No separate checkpoint -// directory — just keep the model up to date. -// -// Usage: -// apollo-checkpoint sync \ -// --handles /tmp/vllm_weight_handles.pt \ -// --model-dir /path/to/Qwen3.5-27B -// -// Runs every 10 minutes via cron. Daily rsync to moria. 
- -use anyhow::{Context, Result, bail}; -use clap::{Parser, Subcommand}; -use memmap2::MmapMut; -use std::collections::HashMap; -use std::fs; -use std::path::{Path, PathBuf}; -use std::process::Command; - -#[derive(Parser)] -#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")] -struct Cli { - #[command(subcommand)] - command: Cmd, -} - -#[derive(Subcommand)] -enum Cmd { - /// Sync live GPU weights back to model safetensors files - Sync { - /// Path to vLLM weight IPC handles - #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] - handles: PathBuf, - - /// Model directory containing safetensors files - #[arg(long)] - model_dir: PathBuf, - - /// Block size for diffing (bytes) - #[arg(long, default_value_t = 4096)] - block_size: usize, - }, -} - -/// Dump live GPU weights to a flat binary file, ordered by safetensors -/// file and offset to match the on-disk layout. -/// -/// Returns a map of (safetensors filename, tensor name) → raw bytes. -fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result>> { - let dump_path = output_dir.join(".live_dump.bin"); - let index_path = output_dir.join(".live_dump.json"); - - let status = Command::new("python3") - .arg("-c") - .arg(format!(r#" -import torch, json - -handles = torch.load("{handles}", weights_only=False) -index = {{}} -offset = 0 - -with open("{dump}", "wb") as f: - for name in sorted(handles.keys()): - info = handles[name] - func, args = info["handle"] - tensor = func(*args) - data = tensor.contiguous().cpu().numpy().tobytes() - f.write(data) - index[name] = {{"offset": offset, "size": len(data)}} - offset += len(data) - -with open("{index}", "w") as f: - json.dump(index, f) - -print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") -"#, - handles = handles_path.display(), - dump = dump_path.display(), - index = index_path.display(), - )) - .status() - .context("Failed to run Python weight dump")?; - - if !status.success() { - bail!("Python weight dump 
failed"); - } - - let index_str = fs::read_to_string(&index_path)?; - let index: HashMap = serde_json::from_str(&index_str)?; - let dump_data = fs::read(&dump_path)?; - - let mut result = HashMap::new(); - for (name, entry) in &index { - result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec()); - } - - // Clean up temp files - let _ = fs::remove_file(&dump_path); - let _ = fs::remove_file(&index_path); - - Ok(result) -} - -#[derive(serde::Deserialize)] -struct DumpEntry { - offset: usize, - size: usize, -} - -/// Read the safetensors index to map parameter names to files. -fn read_safetensors_index(model_dir: &Path) -> Result> { - let index_path = model_dir.join("model.safetensors.index.json"); - if !index_path.exists() { - // Single file model - return Ok(HashMap::new()); - } - - let index_str = fs::read_to_string(&index_path)?; - let index: serde_json::Value = serde_json::from_str(&index_str)?; - let weight_map = index["weight_map"] - .as_object() - .context("No weight_map in index")?; - - let mut result = HashMap::new(); - for (name, file) in weight_map { - result.insert(name.clone(), file.as_str().unwrap().to_string()); - } - Ok(result) -} - -/// Sync changed blocks from live weights into a mmap'd safetensors file. -/// Returns (total_bytes_compared, bytes_changed). -fn sync_tensors_to_file( - file_path: &Path, - tensors: &[(String, Vec)], - block_size: usize, -) -> Result<(usize, usize)> { - use safetensors::SafeTensors; - - let file = fs::OpenOptions::new() - .read(true) - .write(true) - .open(file_path) - .with_context(|| format!("Failed to open {}", file_path.display()))?; - - let mut mmap = unsafe { MmapMut::map_mut(&file)? 
}; - - // Parse safetensors header to find tensor offsets - let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize; - let header_json: serde_json::Value = - serde_json::from_slice(&mmap[8..8 + header_size])?; - let data_start = 8 + header_size; - - let mut total_compared = 0usize; - let mut total_changed = 0usize; - - for (name, live_data) in tensors { - let meta = match header_json.get(name) { - Some(m) => m, - None => { - eprintln!(" Warning: {} not found in {}", name, file_path.display()); - continue; - } - }; - - let offsets = meta["data_offsets"].as_array().unwrap(); - let start = data_start + offsets[0].as_u64().unwrap() as usize; - let end = data_start + offsets[1].as_u64().unwrap() as usize; - let disk_data = &mmap[start..end]; - - if disk_data.len() != live_data.len() { - eprintln!(" Warning: size mismatch for {}: disk={} live={}", - name, disk_data.len(), live_data.len()); - continue; - } - - // Diff block by block, memcpy only changed blocks - let mut offset = 0; - while offset < disk_data.len() { - let block_end = (offset + block_size).min(disk_data.len()); - total_compared += block_end - offset; - - if disk_data[offset..block_end] != live_data[offset..block_end] { - mmap[start + offset..start + block_end] - .copy_from_slice(&live_data[offset..block_end]); - total_changed += block_end - offset; - } - offset = block_end; - } - } - - mmap.flush()?; - Ok((total_compared, total_changed)) -} - -fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> { - if !handles.exists() { - bail!("Weight handles not found: {}. 
Is vLLM running with the export hook?", - handles.display()); - } - - eprintln!("Dumping live weights from GPU..."); - let live_weights = dump_live_weights(&handles, &model_dir)?; - eprintln!(" {} tensors dumped", live_weights.len()); - - // Map parameter names to safetensors files - let weight_map = read_safetensors_index(&model_dir)?; - - // Group tensors by safetensors file - let mut by_file: HashMap)>> = HashMap::new(); - for (name, data) in live_weights { - let file = weight_map - .get(&name) - .cloned() - .unwrap_or_else(|| "model.safetensors".to_string()); - by_file.entry(file).or_default().push((name, data)); - } - - let mut total_compared = 0usize; - let mut total_changed = 0usize; - - for (filename, tensors) in &by_file { - let file_path = model_dir.join(filename); - if !file_path.exists() { - eprintln!(" Warning: {} not found, skipping", filename); - continue; - } - - let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?; - total_compared += compared; - total_changed += changed; - - if changed > 0 { - eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6); - } - } - - if total_changed == 0 { - eprintln!("No changes — model files are up to date"); - } else { - eprintln!( - "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)", - total_changed as f64 / 1e6, - total_compared as f64 / 1e9, - total_changed as f64 / total_compared as f64 * 100.0, - ); - } - - Ok(()) -} - -fn main() -> Result<()> { - let cli = Cli::parse(); - match cli.command { - Cmd::Sync { handles, model_dir, block_size } => { - cmd_sync(handles, model_dir, block_size) - } - } -} diff --git a/training/export_weights.py b/training/export_weights.py deleted file mode 100644 index ef2f608..0000000 --- a/training/export_weights.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -"""Export vLLM's live model weight IPC handles for the training process. 
- -Connects to a running vLLM instance, iterates over model parameters, -and exports CUDA IPC handles that allow another process to access the -same GPU memory without copying. - -Usage: - # Run after vLLM is serving: - python3 export_weights.py --output /tmp/vllm_weight_handles.pt - - # Or via vLLM's API (future): - curl -X POST http://localhost:8000/export_weights -""" - -import argparse -import sys -import torch -from pathlib import Path - - -def export_from_model(model, output_path: str): - """Export IPC handles for all model parameters.""" - from torch.multiprocessing.reductions import reduce_tensor - - handles = {} - total_bytes = 0 - - for name, param in model.named_parameters(): - handle = reduce_tensor(param.data) - handles[name] = { - 'handle': handle, - 'shape': list(param.shape), - 'dtype': str(param.dtype), - } - param_bytes = param.nelement() * param.element_size() - total_bytes += param_bytes - - torch.save(handles, output_path) - - n_params = len(handles) - print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)") - print(f"Saved to {output_path}") - return handles - - -def main(): - parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles") - parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt", - help="Output path for IPC handles") - parser.add_argument("--vllm-pid", type=int, default=None, - help="vLLM worker PID (auto-detected if not specified)") - args = parser.parse_args() - - # For now: load the model directly and export. - # TODO: connect to running vLLM process instead. 
- print("Note: This currently loads the model separately.") - print("Full integration will export from the running vLLM process.") - print() - - # Detect model path from running vLLM - import subprocess - result = subprocess.run( - ['ps', 'aux'], capture_output=True, text=True - ) - model_path = None - for line in result.stdout.split('\n'): - if 'vllm' in line and '--model' in line: - parts = line.split() - for i, p in enumerate(parts): - if p == '--model' and i + 1 < len(parts): - model_path = parts[i + 1] - break - # Also check model_tag format - if p.startswith('--model='): - model_path = p.split('=', 1)[1] - break - - if model_path: - print(f"Detected vLLM model: {model_path}") - else: - print("Could not detect running vLLM model. Specify manually.") - sys.exit(1) - - -if __name__ == '__main__': - main() diff --git a/training/first_training_step.py b/training/first_training_step.py deleted file mode 100644 index 0e6ffd8..0000000 --- a/training/first_training_step.py +++ /dev/null @@ -1,215 +0,0 @@ -#!/usr/bin/env python3 -"""First real Apollo training step — ready for Kent to run. - -This script: -1. Imports vLLM's live weights via CUDA IPC -2. Constructs HF model with shared memory views -3. Runs ONE forward+backward on a real training example -4. Applies ONE Apollo optimizer step -5. Verifies vLLM still works after the update - -The training example is from March 30: Kent said "use vLLM's code" -and the model should have accepted instead of suggesting alternatives. 
- -Usage: - source ~/training-env/bin/activate - python3 first_training_step.py [--dry-run] -""" - -import argparse -import sys -import time - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import AutoConfig, AutoTokenizer -from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM - -sys.path.insert(0, '.') -from weight_mapping import vllm_to_hf_views -from apollo_mini import Apollo - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--dry-run', action='store_true', - help="Run forward+backward but don't apply the optimizer step") - parser.add_argument('--lr', type=float, default=1e-5, - help="Learning rate (default: 1e-5 = conservative)") - parser.add_argument('--rank', type=int, default=256) - parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt') - parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B') - args = parser.parse_args() - - print("=== First Apollo Training Step ===\n") - - # 1. Import vLLM weights - print("1. Importing vLLM weights via CUDA IPC...") - handles = torch.load(args.handles, weights_only=False) - vllm_params = {} - for name, info in handles.items(): - func, args_h = info['handle'] - vllm_params[name] = func(*args_h) - print(f" {len(vllm_params)} parameters imported") - - # 2. Map to HF layout - print("2. Mapping to HF layout (zero-copy views)...") - hf_params = vllm_to_hf_views(vllm_params) - - # 3. Create HF model - print("3. 
Creating HF model with shared weights...") - config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) - with torch.device('meta'): - model = Qwen3_5ForCausalLM(config.text_config) - - replaced = 0 - for name, param in list(model.named_parameters()): - if name in hf_params: - parts = name.split('.') - parent = model - for part in parts[:-1]: - parent = getattr(parent, part) - setattr(parent, parts[-1], - nn.Parameter(hf_params[name], requires_grad=True)) - replaced += 1 - print(f" {replaced} parameters replaced with vLLM memory views") - - # 4. Load tokenizer - print("4. Loading tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) - - # 5. Construct training example - print("5. Constructing training example...") - - # Context: conversation where Kent says to use vLLM's code - # Target: the response that accepts the direction - context = ( - "<|im_start|>user\n" - "vllm has a fused kernel already, right?<|im_end|>\n" - "<|im_start|>assistant\n" - "Yeah — vLLM has `gdn_attention_core` which is a custom op " - "that does the whole GDN layer's core in one dispatch.<|im_end|>\n" - "<|im_start|>user\n" - "Why wouldn't we just use that?<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - # The CORRECT response (accept direction, don't suggest alternatives) - continuation = ( - "We should. Let me pull in their kernel and wire it into " - "our Rust orchestration. Which file should I start with?" - ) - - context_ids = tokenizer.encode(context, add_special_tokens=False) - continuation_ids = tokenizer.encode(continuation, add_special_tokens=False) - all_ids = context_ids + continuation_ids - context_len = len(context_ids) - - print(f" Context: {context_len} tokens") - print(f" Continuation: {len(continuation_ids)} tokens") - print(f" Total: {len(all_ids)} tokens") - - input_ids = torch.tensor([all_ids], device='cuda:0') - - # 6. Initialize Apollo optimizer - print(f"6. 
Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...") - apollo_params = [] - standard_params = [] - for p in model.parameters(): - if p.requires_grad: - if p.ndim >= 2 and min(p.shape) >= args.rank: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({'params': apollo_params}) - if standard_params: - groups.append({'params': standard_params}) - - optimizer = Apollo(groups, lr=args.lr, rank=args.rank) - print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard") - - # 7. Forward pass - print("7. Forward pass...") - model.train() - optimizer.zero_grad() - - # Context-frozen: no grad for context, grad for continuation - with torch.no_grad(): - ctx_output = model(input_ids[:, :context_len], use_cache=True) - past_kv = ctx_output.past_key_values - - with torch.enable_grad(): - output = model(input_ids[:, context_len:], - past_key_values=past_kv, use_cache=False) - logits = output.logits - # Shift for next-token prediction - shift_logits = logits[:, :-1].contiguous() - shift_labels = input_ids[:, context_len + 1:].contiguous() - loss = F.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - ) - print(f" Loss: {loss.item():.4f}") - - # 8. Backward pass - print("8. Backward pass...") - loss.backward() - n_grads = sum(1 for p in model.parameters() if p.grad is not None) - print(f" {n_grads} parameters have gradients") - - # 9. Apollo step (or dry run) - if args.dry_run: - print("\n9. DRY RUN — skipping optimizer step") - print(" (run without --dry-run to apply the update)") - else: - print("9. Applying Apollo optimizer step...") - # Record a few weight norms before - sample_norms_before = {} - for name, p in model.named_parameters(): - if 'layers.0.' 
in name and p.grad is not None: - sample_norms_before[name] = p.data.norm().item() - - optimizer.step() - - # Check weight changes - print(" Weight changes (layer 0):") - for name, before in sample_norms_before.items(): - p = dict(model.named_parameters())[name] - after = p.data.norm().item() - delta = abs(after - before) - pct = delta / before * 100 if before > 0 else 0 - print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)") - - optimizer.zero_grad() - - # 10. Verify vLLM still works - print("\n10. Verifying vLLM still serves...") - import subprocess - result = subprocess.run( - ['curl', '-s', '--max-time', '30', - '-X', 'POST', 'http://localhost:8000/v1/chat/completions', - '-H', 'Content-Type: application/json', - '-H', 'Authorization: Bearer bcachefs-agents-2026', - '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'], - capture_output=True, text=True, timeout=45 - ) - if result.returncode == 0 and 'choices' in result.stdout: - print(" vLLM still serving ✓") - else: - print(" WARNING: vLLM may not be responding") - print(f" stdout: {result.stdout[:200]}") - - print("\n=== COMPLETE ===") - if args.dry_run: - print("Run without --dry-run to apply the first real training step.") - else: - print("First Apollo training step applied to vLLM's live weights.") - print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - -if __name__ == '__main__': - main() diff --git a/training/pyproject.toml b/training/pyproject.toml new file mode 100644 index 0000000..7cf0581 --- /dev/null +++ b/training/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "apollo-plugin" +version = "0.1.0" +description = "Apollo training plugin for vLLM" +requires-python = ">=3.10" +dependencies = [ + "torch", + "aiohttp", + "safetensors", + "pyzmq", +] + +[project.optional-dependencies] +dev = ["pytest"] + 
+[project.entry-points."vllm.general_plugins"] +apollo = "apollo_plugin:register" + +[project.scripts] +apollo-checkpoint = "apollo_plugin.checkpoint_sync:main" +apollo-worker = "apollo_plugin.training_worker:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["apollo_plugin*"] diff --git a/training/start_vllm_with_apollo.sh b/training/start_vllm_with_apollo.sh deleted file mode 100755 index 98dfedb..0000000 --- a/training/start_vllm_with_apollo.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# Start vLLM with Apollo weight export hook. -# -# The hook patches vLLM's model runner to export CUDA IPC handles -# after loading, so the Apollo training process can share the same -# GPU memory. - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" - -exec python3 -c " -import sys -sys.path.insert(0, '$SCRIPT_DIR') -import vllm_export_hook # patches model runner before vLLM loads - -sys.argv = ['vllm'] + sys.argv[1:] -from vllm.entrypoints.cli.main import main -main() -" serve "$@" diff --git a/training/train.py b/training/train.py deleted file mode 100644 index a5fbe2c..0000000 --- a/training/train.py +++ /dev/null @@ -1,269 +0,0 @@ -#!/usr/bin/env python3 -"""Nightly training process for Apollo-Mini fine-tuning. - -Imports vLLM's model weights via CUDA IPC, runs context-frozen -training on flagged conversation segments, saves updated checkpoint. 
- -Usage: - python3 train.py \ - --weights /tmp/vllm_weight_handles.pt \ - --examples training-examples.jsonl \ - --checkpoint-dir checkpoints/ \ - --lr 1e-5 -""" - -import argparse -import json -import os -import sys -import time -from datetime import datetime -from pathlib import Path - -import torch -from safetensors.torch import save_file - -from apollo_mini import ApolloMini - - -def import_weights(handle_path: str) -> dict[str, torch.Tensor]: - """Import weight tensors from CUDA IPC handles.""" - handles = torch.load(handle_path, weights_only=False) - params = {} - for name, info in handles.items(): - func, args = info['handle'] - tensor = func(*args) - params[name] = tensor - return params - - -def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]: - """Split parameters into Apollo-Mini and standard groups. - - Apollo-Mini needs 2D+ matrices with min dimension >= 2. - Small tensors (norms, biases, conv1d 3D weights) use standard Adam. - """ - apollo_params = [] - standard_params = [] - - for name, p in params.items(): - p.requires_grad_(True) - if p.ndim >= 2 and min(p.shape) >= 2: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({ - 'params': apollo_params, - 'name': 'apollo', - }) - if standard_params: - groups.append({ - 'params': standard_params, - 'name': 'standard', - }) - - n_apollo = sum(p.nelement() for p in apollo_params) - n_standard = sum(p.nelement() for p in standard_params) - print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M") - return groups - - -def forward_pass(params, input_ids, context_len, device): - """Run context-frozen forward pass. 
- - Args: - params: dict of name -> tensor (shared with vLLM) - input_ids: full sequence [1, seq_len] - context_len: number of context tokens (no gradient) - device: CUDA device - - Returns: - logits for decision tokens, target ids for loss - """ - # TODO: Build proper forward model matching vLLM's weight layout. - # For now this is a placeholder — the real implementation needs - # to replicate vLLM's model architecture (merged projections, - # GDN recurrence, full attention, MLP) using the shared weights. - raise NotImplementedError( - "Forward model not yet implemented. " - "Need to build a model that matches vLLM's merged weight layout " - "(MergedColumnParallelLinear for qkvz/ba/gate_up, " - "RowParallelLinear for out_proj/down) and computes the same " - "forward pass with autograd enabled." - ) - - -def save_checkpoint(params: dict[str, torch.Tensor], - checkpoint_dir: str, - config_path: str = None): - """Save model checkpoint in HuggingFace safetensors format. - - Saves weights split across shards matching the original model layout, - archives the previous checkpoint, and updates the 'latest' symlink. - """ - date_str = datetime.now().strftime("%Y-%m-%d") - out_dir = Path(checkpoint_dir) / date_str - out_dir.mkdir(parents=True, exist_ok=True) - - # Save all weights in a single safetensors file for now. - # TODO: split across shards matching HF model index for large models. 
- tensors = {} - for name, param in params.items(): - tensors[name] = param.data.contiguous().cpu() - - save_path = out_dir / "model.safetensors" - save_file(tensors, str(save_path)) - print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)") - - # Copy config files if provided - if config_path: - import shutil - config_dir = Path(config_path) - for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json', - 'special_tokens_map.json', 'generation_config.json']: - src = config_dir / f - if src.exists(): - shutil.copy2(src, out_dir / f) - - # Update latest symlink - latest = Path(checkpoint_dir) / "latest" - if latest.is_symlink(): - latest.unlink() - latest.symlink_to(date_str) - print(f"Updated {latest} -> {date_str}") - - return str(out_dir) - - -def train_step(params, example, optimizer, device, log_entries): - """Run one training step on a single example. - - Args: - params: dict of name -> tensor - example: dict with 'input_ids', 'context_len', 'target_ids' - optimizer: ApolloMini instance - device: CUDA device - log_entries: list to append log dicts to - - Returns: - loss value - """ - optimizer.zero_grad() - - input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0) - context_len = example['context_len'] - - # Forward pass (context frozen, decision tokens with grad) - logits, targets = forward_pass(params, input_ids, context_len, device) - - # Cross-entropy loss on decision tokens - loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.shape[-1]), - targets.view(-1), - ) - - # Backward - loss.backward() - - # Compute gradient stats before optimizer step - total_grad_norm = 0.0 - for p in params.values(): - if p.grad is not None: - total_grad_norm += p.grad.norm().item() ** 2 - total_grad_norm = total_grad_norm ** 0.5 - - # Optimizer step - optimizer.step() - - # Log - log_entries.append({ - 'example_id': example.get('id', 'unknown'), - 'loss': loss.item(), - 'grad_norm': total_grad_norm, - 
'timestamp': datetime.now().isoformat(), - }) - - return loss.item() - - -def main(): - parser = argparse.ArgumentParser(description="Apollo-Mini training") - parser.add_argument("--weights", required=True, - help="Path to exported weight IPC handles") - parser.add_argument("--examples", required=True, - help="Path to training examples JSONL") - parser.add_argument("--checkpoint-dir", default="checkpoints", - help="Directory for saving checkpoints") - parser.add_argument("--config-path", default=None, - help="Path to model config files (for checkpoint)") - parser.add_argument("--lr", type=float, default=1e-5, - help="Learning rate") - parser.add_argument("--warmup-steps", type=int, default=10, - help="Learning rate warmup steps") - parser.add_argument("--weight-decay", type=float, default=0.01) - parser.add_argument("--dry-run", action="store_true", - help="Load weights and validate, don't train") - args = parser.parse_args() - - print(f"Apollo-Mini Training") - print(f" weights: {args.weights}") - print(f" examples: {args.examples}") - print(f" lr: {args.lr}") - print() - - # Import weights - print("Importing weights via CUDA IPC...") - params = import_weights(args.weights) - print(f" {len(params)} parameters imported") - - # Make parameter groups - param_groups = make_param_groups(params) - - # Initialize optimizer - optimizer = ApolloMini(param_groups, lr=args.lr, - weight_decay=args.weight_decay, - warmup_steps=args.warmup_steps) - print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - if args.dry_run: - print("\nDry run — weights imported and validated successfully.") - return - - # Load training examples - examples = [] - with open(args.examples) as f: - for line in f: - examples.append(json.loads(line)) - print(f" {len(examples)} training examples") - - # Training loop - log_entries = [] - print(f"\nTraining...") - t0 = time.time() - - for i, example in enumerate(examples): - loss = train_step(params, example, optimizer, 'cuda:0', 
log_entries) - print(f" [{i+1}/{len(examples)}] loss={loss:.4f}") - - elapsed = time.time() - t0 - print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s") - print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - # Save checkpoint - print("\nSaving checkpoint...") - save_checkpoint(params, args.checkpoint_dir, args.config_path) - - # Save training log - date_str = datetime.now().strftime("%Y-%m-%d") - log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl" - with open(log_path, 'w') as f: - for entry in log_entries: - f.write(json.dumps(entry) + '\n') - print(f"Training log: {log_path}") - - -if __name__ == '__main__': - main() diff --git a/training/training_example.py b/training/training_example.py deleted file mode 100644 index b5779e0..0000000 --- a/training/training_example.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Training example construction and tokenization. - -Takes raw conversation context + improved continuation, produces -tokenized tensors ready for context-frozen forward+backward. -""" - -import json -from dataclasses import dataclass, field -from pathlib import Path - -import torch -from transformers import AutoTokenizer - - -@dataclass -class TrainingExample: - """A single training example for context-frozen training.""" - id: str - context: str # conversation up to decision point - continuation: str # the better response - reason: str = "" # why this is a training target - memories: list[str] = field(default_factory=list) # memories that were in context - - # Computed after tokenization - input_ids: torch.Tensor | None = None - context_len: int = 0 - total_len: int = 0 - - def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"): - """Tokenize context + continuation into training-ready tensors. - - The chat template is applied to make the token distribution - match what the model sees during inference. 
- """ - # Build messages for context (everything up to the decision) - # The context should already be in chat format - context_ids = tokenizer.encode(self.context, add_special_tokens=False) - continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False) - - self.context_len = len(context_ids) - self.total_len = len(context_ids) + len(continuation_ids) - - if self.total_len > max_len: - # Truncate context from the left, keep continuation intact - excess = self.total_len - max_len - context_ids = context_ids[excess:] - self.context_len = len(context_ids) - self.total_len = len(context_ids) + len(continuation_ids) - - all_ids = context_ids + continuation_ids - self.input_ids = torch.tensor(all_ids, device=device) - return self - - def to_dict(self) -> dict: - return { - 'id': self.id, - 'context': self.context, - 'continuation': self.continuation, - 'reason': self.reason, - 'memories': self.memories, - 'context_len': self.context_len, - 'total_len': self.total_len, - } - - @classmethod - def from_dict(cls, d: dict) -> 'TrainingExample': - return cls( - id=d['id'], - context=d['context'], - continuation=d['continuation'], - reason=d.get('reason', ''), - memories=d.get('memories', []), - ) - - -def load_examples(path: str) -> list[TrainingExample]: - """Load training examples from JSONL file.""" - examples = [] - with open(path) as f: - for line in f: - if line.strip(): - examples.append(TrainingExample.from_dict(json.loads(line))) - return examples - - -def save_examples(examples: list[TrainingExample], path: str): - """Save training examples to JSONL file.""" - with open(path, 'w') as f: - for ex in examples: - f.write(json.dumps(ex.to_dict()) + '\n') - - -class ExampleTokenizer: - """Handles tokenization with the model's chat template. - - Applies the same chat template that vLLM uses during inference, - so the token distribution matches what the model expects. 
- """ - - def __init__(self, model_path: str): - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True) - - def prepare_example(self, example: TrainingExample, - max_len: int = 8192, - device: str = "cuda:0") -> TrainingExample: - """Tokenize an example using the chat template. - - For proper training, the context should be formatted exactly - as vLLM would format it — with chat template applied. - """ - # Apply chat template to get the exact token sequence - # the model would see during inference - # - # Context: everything up to the decision point - # Continuation: the improved response - # - # We tokenize them separately to know where context ends - # and continuation begins. - context_ids = self.tokenizer.encode( - example.context, add_special_tokens=True) - continuation_ids = self.tokenizer.encode( - example.continuation, add_special_tokens=False) - - example.context_len = len(context_ids) - example.total_len = len(context_ids) + len(continuation_ids) - - if example.total_len > max_len: - excess = example.total_len - max_len - context_ids = context_ids[excess:] - example.context_len = len(context_ids) - example.total_len = example.context_len + len(continuation_ids) - - all_ids = context_ids + continuation_ids - example.input_ids = torch.tensor(all_ids, device=device) - return example - - def prepare_from_messages(self, example_id: str, - messages: list[dict], - decision_idx: int, - better_response: str, - reason: str = "", - memories: list[str] | None = None, - max_len: int = 8192, - device: str = "cuda:0") -> TrainingExample: - """Build a training example from a chat message list. 
- - Args: - example_id: unique identifier - messages: list of {"role": ..., "content": ...} dicts - decision_idx: index of the assistant message to replace - better_response: the improved response text - reason: why this is a training target - memories: memory keys that were in context - max_len: maximum sequence length - device: target device - - Returns: - Tokenized TrainingExample - """ - # Context: all messages up to (not including) the decision - context_messages = messages[:decision_idx] - context_text = self.tokenizer.apply_chat_template( - context_messages, tokenize=False, add_generation_prompt=True) - - # Build the example - example = TrainingExample( - id=example_id, - context=context_text, - continuation=better_response, - reason=reason, - memories=memories or [], - ) - - return self.prepare_example(example, max_len=max_len, device=device)