diff --git a/Cargo.lock b/Cargo.lock index dfca607..eb53ed5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,12 +492,11 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", - "json-five", + "json5", "libc", "log", "memchr", "memmap2", - "notify-debouncer-mini", "paste", "peg", "ratatui", @@ -1089,15 +1088,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "fsevent-sys" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" -dependencies = [ - "libc", -] - [[package]] name = "futures" version = "0.3.32" @@ -1463,26 +1453,6 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" -[[package]] -name = "inotify" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199" -dependencies = [ - "bitflags 2.11.0", - "inotify-sys", - "libc", -] - -[[package]] -name = "inotify-sys" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" -dependencies = [ - "libc", -] - [[package]] name = "instability" version = "0.3.12" @@ -1561,16 +1531,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "json-five" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "865f2d01a4549c1fd8c60640c03ae5249eb374cd8cde8b905628d4b1af95c87c" -dependencies = [ - "serde", - "unicode-general-category", -] - [[package]] name = "json5" version = "1.3.1" @@ -1592,26 +1552,6 @@ dependencies = [ "thiserror 2.0.18", ] -[[package]] -name = "kqueue" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a" -dependencies = [ - "kqueue-sys", - "libc", -] - -[[package]] -name = "kqueue-sys" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" -dependencies = [ - "bitflags 1.3.2", - "libc", -] - [[package]] name = "lab" version = "0.11.0" @@ -1834,45 +1774,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "notify" -version = "8.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" -dependencies = [ - "bitflags 2.11.0", - "fsevent-sys", - "inotify", - "kqueue", - "libc", - "log", - "mio", - "notify-types", - "walkdir", - "windows-sys 0.60.2", -] - -[[package]] -name = "notify-debouncer-mini" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17849edfaabd9a5fef1c606d99cfc615a8e99f7ac4366406d86c7942a3184cf2" -dependencies = [ - "log", - "notify", - "notify-types", - "tempfile", -] - -[[package]] -name = "notify-types" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "num-conv" version = "0.2.1" @@ -3483,12 +3384,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" -[[package]] -name = "unicode-general-category" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b993bddc193ae5bd0d623b49ec06ac3e9312875fdae725a975c51db1cc1677f" - [[package]] name = "unicode-ident" version = "1.0.24" @@ -3899,16 +3794,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", + "windows-targets", ] [[package]] @@ -3926,31 +3812,14 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] @@ -3959,96 +3828,48 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 7cdf851..c253bd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,8 +29,7 @@ log = "0.4" serde = { version = "1", features = ["derive"] } serde_json = "1" -json-five = "0.3" -notify-debouncer-mini = "0.7" +json5 = "1.3" ratatui = { version = "0.30", features = ["unstable-rendered-line-info"] } tui-markdown = { git = "https://github.com/koverstreet/tui-markdown", subdirectory = "tui-markdown" } diff --git a/src/agent/context.rs b/src/agent/context.rs index 37dbf48..c43c023 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -92,7 +92,7 @@ pub struct NodeLeaf { body: NodeBody, #[serde(skip)] token_ids: Vec, - timestamp: DateTime, + timestamp: Option>, } impl<'de> Deserialize<'de> for NodeLeaf { @@ -100,7 +100,7 @@ impl<'de> Deserialize<'de> for NodeLeaf { #[derive(Deserialize)] struct Raw { body: NodeBody, - timestamp: DateTime, + timestamp: Option>, } let raw = Raw::deserialize(deserializer)?; let token_ids = if raw.body.is_prompt_visible() { @@ -119,7 +119,6 @@ pub enum AstNode { Branch { role: Role, children: Vec, - timestamp: DateTime, /// Per-response memory attribution from full scoring matrix. /// Maps memory key → divergence score for this response. #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")] @@ -253,18 +252,18 @@ impl NodeLeaf { } else { vec![] }; - Self { body, token_ids, timestamp: Utc::now() } + Self { body, token_ids, timestamp: None } } pub fn with_timestamp(mut self, ts: DateTime) -> Self { - self.timestamp = ts; + self.timestamp = Some(ts); self } pub fn body(&self) -> &NodeBody { &self.body } pub fn token_ids(&self) -> &[u32] { &self.token_ids } pub fn tokens(&self) -> usize { self.token_ids.len() } - pub fn timestamp(&self) -> DateTime { self.timestamp } + pub fn timestamp(&self) -> Option> { self.timestamp } } impl AstNode { @@ -308,14 +307,13 @@ impl AstNode { // -- Branch constructors -------------------------------------------------- pub fn branch(role: Role, children: Vec) -> Self { - Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() } + Self::Branch { role, children, memory_scores: Default::default() } } pub fn system_msg(text: impl Into) -> Self { Self::Branch { role: Role::System, children: vec![Self::content(text)], - timestamp: Utc::now(), memory_scores: Default::default(), } } @@ -324,7 +322,6 @@ impl AstNode { Self::Branch { role: Role::User, children: vec![Self::content(text)], - timestamp: Utc::now(), memory_scores: Default::default(), } } @@ -341,10 +338,9 @@ impl AstNode { }; Self::Leaf(NodeLeaf { token_ids, ..leaf }) } - Self::Branch { role, children, timestamp, memory_scores } => Self::Branch { + Self::Branch { role, children, memory_scores, .. } => Self::Branch { role, children: children.into_iter().map(|c| c.retokenize()).collect(), - timestamp, memory_scores, }, } @@ -352,8 +348,8 @@ impl AstNode { pub fn with_timestamp(mut self, ts: DateTime) -> Self { match &mut self { - Self::Leaf(leaf) => leaf.timestamp = ts, - Self::Branch { timestamp, .. } => *timestamp = ts, + Self::Leaf(leaf) => leaf.timestamp = Some(ts), + Self::Branch { .. } => {} } self } @@ -374,7 +370,7 @@ impl AstNode { /// Short label for the UI. pub fn label(&self) -> String { - let app = crate::config::app(); + let cfg = crate::config::get(); match self { Self::Branch { role, children, .. } => { let preview = children.first() @@ -383,8 +379,8 @@ impl AstNode { .unwrap_or_default(); match role { Role::System => "system".into(), - Role::User => format!("{}: {}", app.user_name, preview), - Role::Assistant => format!("{}: {}", app.assistant_name, preview), + Role::User => format!("{}: {}", cfg.user_name, preview), + Role::Assistant => format!("{}: {}", cfg.assistant_name, preview), } } Self::Leaf(leaf) => match &leaf.body { @@ -992,10 +988,7 @@ impl ContextState { } pub fn context_window() -> usize { - let app = crate::config::app(); - app.backends.get(&app.default_backend) - .and_then(|b| b.context_window) - .unwrap_or(128_000) + crate::config::get().api_context_window } pub fn context_budget_tokens() -> usize { @@ -1347,35 +1340,4 @@ mod tests { assert_token_invariants(node); assert!(node.tokens() > 0); } - - // -- Timestamp deserialization tests ------------------------------------------ - - #[test] - fn test_timestamp_null_rejected() { - // Missing/null timestamps used to be accepted via a lenient - // deserialize fallback. Post-migration the schema is strict. - let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#; - assert!(serde_json::from_str::(json).is_err()); - } - - #[test] - fn test_timestamp_missing_rejected() { - let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#; - assert!(serde_json::from_str::(json).is_err()); - } - - #[test] - fn test_branch_timestamp_missing_rejected() { - let json = r#"{"Branch":{"role":"User","children":[]}}"#; - assert!(serde_json::from_str::(json).is_err()); - } - - #[test] - fn test_timestamp_present_accepted() { - let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#; - let node: AstNode = serde_json::from_str(json).unwrap(); - let leaf = node.leaf().unwrap(); - assert_eq!(leaf.timestamp().to_rfc3339(), - "2026-04-16T12:00:00+00:00"); - } } diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 5368db6..db1bf39 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -139,6 +139,7 @@ impl DispatchState { pub struct Agent { pub client: ApiClient, pub app_config: crate::config::AppConfig, + pub prompt_file: String, pub session_id: String, pub context: crate::Mutex, pub state: crate::Mutex, @@ -188,6 +189,7 @@ impl Agent { client: ApiClient, personality: Vec<(String, String)>, app_config: crate::config::AppConfig, + prompt_file: String, conversation_log: Option, active_tools: tools::ActiveTools, agent_tools: Vec, @@ -218,6 +220,7 @@ impl Agent { let agent = Arc::new(Self { client, app_config, + prompt_file, session_id, context: crate::Mutex::new(context), state: crate::Mutex::new(AgentState { @@ -256,6 +259,7 @@ impl Agent { Arc::new(Self { client: self.client.clone(), app_config: self.app_config.clone(), + prompt_file: self.prompt_file.clone(), session_id: self.session_id.clone(), context: crate::Mutex::new(ctx), state: crate::Mutex::new(AgentState { diff --git a/src/agent/oneshot.rs b/src/agent/oneshot.rs index 8bc8b53..2fce906 100644 --- a/src/agent/oneshot.rs +++ b/src/agent/oneshot.rs @@ -183,8 +183,8 @@ fn resolve_prompt( state: &std::collections::BTreeMap, recently_written: &[String], ) -> String { - let template = template.replace("{assistant_name}", - &crate::config::app().assistant_name); + let cfg = crate::config::get(); + let template = template.replace("{assistant_name}", &cfg.assistant_name); let mut result = String::with_capacity(template.len()); let mut rest = template.as_str(); while let Some(start) = rest.find("{{") { @@ -247,20 +247,25 @@ impl AutoAgent { &mut self, bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>, ) -> Result<(), String> { - // Load system prompt + identity from config. + let config = crate::config::get(); + let base_url = config.api_base_url.as_deref().unwrap_or(""); + let api_key = config.api_key.as_deref().unwrap_or(""); + let model = config.api_model.as_deref().unwrap_or(""); + if base_url.is_empty() || model.is_empty() { + return Err("API not configured (no base_url or model)".to_string()); + } + let client = super::api::ApiClient::new(base_url, api_key, model); + + // Load system prompt + identity from config let cli = crate::user::CliArgs::default(); let (app, _) = crate::config::load_app(&cli) .map_err(|e| format!("config: {}", e))?; - let resolved = app.resolve_model(&app.default_backend) - .map_err(|e| format!("API not configured: {}", e))?; - let client = super::api::ApiClient::new( - &resolved.api_base, &resolved.api_key, &resolved.model_id); let personality = crate::config::reload_context() .await.map_err(|e| format!("config: {}", e))?; let agent = Agent::new( client, personality, - app, + app, String::new(), None, super::tools::ActiveTools::new(), super::tools::tools(), @@ -492,20 +497,15 @@ pub async fn run_one_agent( .map(|s| s.phase.clone()).collect(); // Bail check: if the agent defines a bail script, run it between steps. - // The script also refreshes our pid-file with the current phase — that's - // how concurrent agents know which phase each of us is in. let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name)); let state_dir_for_bail = state_dir.clone(); + // Find our own pid file so we can pass it to the bail script let our_pid = std::process::id(); let our_pid_file = format!("pid-{}", our_pid); - let step_phases_for_bail = step_phases.clone(); let bail_fn = move |step_idx: usize| -> Result<(), String> { if let Some(ref script) = bail_script { - let phase = step_phases_for_bail.get(step_idx) - .map(String::as_str).unwrap_or(""); let status = std::process::Command::new(script) .arg(&our_pid_file) - .arg(phase) .current_dir(&state_dir_for_bail) .status() .map_err(|e| format!("bail script {:?} failed: {}", script, e))?; diff --git a/src/bin/fix-timestamps.rs b/src/bin/fix-timestamps.rs deleted file mode 100644 index 31a8788..0000000 --- a/src/bin/fix-timestamps.rs +++ /dev/null @@ -1,180 +0,0 @@ -// fix-timestamps: One-off migration for ~/.consciousness/agent-sessions/ -// conversation.jsonl. -// -// Before Branch nodes carried their own timestamps, early entries were -// serialized with missing/null timestamp fields — they deserialize as -// UNIX_EPOCH via the (now-to-be-removed) deserialize_timestamp_or_epoch -// fallback. Training needs every entry to have a unique timestamp to -// dedup already-trained responses. -// -// Walks the file, synthesizes timestamps for any entry stuck at -// UNIX_EPOCH by linear interpolation between surrounding real -// timestamps. For child leaves inside a Branch, derives timestamps -// from the parent with a tiny per-child offset. -// -// SAFETY: reads from argv[1], writes to argv[1].tmp, renames into -// place. Keep a .bak copy before running. -// -// Usage: fix-timestamps - -use std::io::{BufRead, BufReader, BufWriter, Write}; -use std::path::PathBuf; - -use anyhow::{Context, Result}; -use chrono::{DateTime, Duration, Utc}; - -use consciousness::agent::context::AstNode; - -fn main() -> Result<()> { - let path: PathBuf = std::env::args().nth(1) - .context("usage: fix-timestamps ")?.into(); - - let f = std::fs::File::open(&path) - .with_context(|| format!("open {}", path.display()))?; - let reader = BufReader::new(f); - - let mut nodes: Vec = Vec::new(); - for (i, line) in reader.lines().enumerate() { - let line = line?; - if line.trim().is_empty() { continue; } - let node: AstNode = serde_json::from_str(&line) - .with_context(|| format!("line {}: parse", i + 1))?; - nodes.push(node); - } - println!("read {} entries", nodes.len()); - - fix_top_level_timestamps(&mut nodes); - for node in &mut nodes { - propagate_to_children(node); - } - - // Ensure uniqueness — real timestamps can collide when two entries - // were written in the same ns; synthesized ones can also overlap. - // Bump colliding ns by 1 until unique. - let mut seen = std::collections::HashSet::new(); - let mut bumps = 0usize; - for (i, node) in nodes.iter_mut().enumerate() { - let ts = top_ts(node); - assert!(ts > DateTime::::UNIX_EPOCH, - "entry {}: still UNIX_EPOCH", i); - let mut ns = ts.timestamp_nanos_opt().expect("ts in i64 ns range"); - let mut bumped = false; - while !seen.insert(ns) { - ns += 1; - bumped = true; - bumps += 1; - } - if bumped { - set_top_ts(node, DateTime::::from_timestamp_nanos(ns)); - } - } - println!("all {} timestamps real and unique ({} ns bumps)", - nodes.len(), bumps); - - let tmp = path.with_extension("jsonl.tmp"); - { - let f = std::fs::File::create(&tmp) - .with_context(|| format!("create {}", tmp.display()))?; - let mut w = BufWriter::new(f); - for node in &nodes { - serde_json::to_writer(&mut w, node)?; - w.write_all(b"\n")?; - } - w.flush()?; - } - std::fs::rename(&tmp, &path) - .with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?; - println!("wrote {}", path.display()); - - Ok(()) -} - -fn top_ts(node: &AstNode) -> DateTime { - match node { - AstNode::Leaf(leaf) => leaf.timestamp(), - AstNode::Branch { timestamp, .. } => *timestamp, - } -} - -fn set_top_ts(node: &mut AstNode, ts: DateTime) { - match node { - AstNode::Leaf(leaf) => *leaf = leaf.clone().with_timestamp(ts), - AstNode::Branch { timestamp, .. } => *timestamp = ts, - } -} - -/// Fill in missing top-level timestamps. Strategy: -/// - If two real timestamps bracket a run of missing ones, linearly -/// interpolate between them. -/// - If missing ones precede the first real one, back-fill using -/// (first_real - N·1µs). -/// - If missing ones follow the last real one, forward-fill. -/// - If no real timestamps exist at all, synthesize from now() going -/// backwards. -fn fix_top_level_timestamps(nodes: &mut [AstNode]) { - let real: Vec<(usize, DateTime)> = nodes.iter().enumerate() - .filter(|(_, n)| top_ts(n) > DateTime::::UNIX_EPOCH) - .map(|(i, n)| (i, top_ts(n))) - .collect(); - - if real.is_empty() { - let now = Utc::now(); - let len = nodes.len(); - for (i, node) in nodes.iter_mut().enumerate() { - let ts = now - Duration::microseconds((len - i) as i64); - set_top_ts(node, ts); - } - return; - } - - // Helper: bisect real[] for the nearest real entries around idx. - let find_bracket = |idx: usize| -> (Option<(usize, DateTime)>, - Option<(usize, DateTime)>) { - let pos = real.binary_search_by_key(&idx, |(i, _)| *i); - let (prior_pos, next_pos) = match pos { - Ok(p) => (Some(p), Some(p)), - Err(p) => ( - if p == 0 { None } else { Some(p - 1) }, - if p >= real.len() { None } else { Some(p) }, - ), - }; - (prior_pos.map(|p| real[p]), next_pos.map(|p| real[p])) - }; - - for i in 0..nodes.len() { - if top_ts(&nodes[i]) > DateTime::::UNIX_EPOCH { - continue; - } - let (prior, next) = find_bracket(i); - let new_ts = match (prior, next) { - (Some((pi, pt)), Some((ni, nt))) if pi != ni => { - // Linear interpolate. - let span_ns = (nt - pt).num_nanoseconds().unwrap_or(0); - let offset_ns = span_ns * (i - pi) as i64 / (ni - pi) as i64; - pt + Duration::nanoseconds(offset_ns) - } - (Some((pi, pt)), _) => { - pt + Duration::microseconds((i - pi) as i64) - } - (None, Some((ni, nt))) => { - nt - Duration::microseconds((ni - i) as i64) - } - (None, None) => unreachable!(), - }; - set_top_ts(&mut nodes[i], new_ts); - } -} - -/// For every Branch, ensure each child Leaf has a timestamp. If missing, -/// use parent.ts + child_idx·1ns so siblings stay unique but close. -fn propagate_to_children(node: &mut AstNode) { - if let AstNode::Branch { timestamp, children, .. } = node { - let parent_ts = *timestamp; - for (ci, child) in children.iter_mut().enumerate() { - if top_ts(child) <= DateTime::::UNIX_EPOCH { - set_top_ts(child, parent_ts + Duration::nanoseconds(ci as i64)); - } - propagate_to_children(child); - } - } -} diff --git a/src/cli/node.rs b/src/cli/node.rs index c4305a7..5472505 100644 --- a/src/cli/node.rs +++ b/src/cli/node.rs @@ -197,7 +197,7 @@ pub async fn cmd_load_context(stats: bool) -> Result<()> { return Ok(()); } - println!("=== MEMORY SYSTEM ({}) ===", crate::config::app().assistant_name); + println!("=== MEMORY SYSTEM ({}) ===", cfg.assistant_name); if !personality.is_empty() { println!("--- personality_nodes ({}) ---", personality.len()); diff --git a/src/config.rs b/src/config.rs index b7ea597..9f9ad9a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,9 @@ // Single config file: ~/.consciousness/config.json5 // Memory settings in the "memory" section (Config) // Agent/backend settings at top level (AppConfig) +// +// Legacy fallback: ~/.consciousness/config.jsonl +// Env override: POC_MEMORY_CONFIG use std::collections::HashMap; use std::path::PathBuf; @@ -26,7 +29,9 @@ pub fn config_path() -> PathBuf { static CONFIG: OnceLock>> = OnceLock::new(); +fn default_context_window() -> usize { 128_000 } fn default_stream_timeout() -> u64 { 60 } +fn default_scoring_chunk_tokens() -> usize { 50_000 } fn default_scoring_interval_secs() -> u64 { 3600 } // 1 hour fn default_scoring_response_window() -> usize { 100 } fn default_node_weight() -> f64 { 0.7 } @@ -40,6 +45,8 @@ fn default_identity_dir() -> PathBuf { #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { + pub user_name: String, + pub assistant_name: String, #[serde(deserialize_with = "deserialize_path")] pub data_dir: PathBuf, #[serde(default = "default_identity_dir", deserialize_with = "deserialize_path")] @@ -55,24 +62,51 @@ pub struct Config { /// Nodes loaded into subconscious agent context #[serde(default)] pub agent_nodes: Vec, + pub journal_days: u32, + pub journal_max: usize, pub llm_concurrency: usize, + pub agent_budget: usize, + #[serde(deserialize_with = "deserialize_path")] + pub prompts_dir: PathBuf, + /// Resolved from agent_model → models → backend (not in config directly) + #[serde(skip)] + pub api_base_url: Option, + #[serde(skip)] + pub api_key: Option, + #[serde(skip)] + pub api_model: Option, + #[serde(skip, default = "default_context_window")] + pub api_context_window: usize, + /// Used to resolve API settings, not stored on Config + #[serde(default)] + agent_model: Option, /// Stream chunk timeout in seconds (no data = timeout). #[serde(default = "default_stream_timeout")] pub api_stream_timeout_secs: u64, + /// Max tokens per chunk for memory scoring logprobs calls. + #[serde(default = "default_scoring_chunk_tokens")] + pub scoring_chunk_tokens: usize, /// How often to re-score memory nodes (seconds). Default: 3600 (1 hour). #[serde(default = "default_scoring_interval_secs")] pub scoring_interval_secs: u64, /// Number of assistant responses to score per memory. Default: 50. #[serde(default = "default_scoring_response_window")] pub scoring_response_window: usize, + pub api_reasoning: String, pub agent_types: Vec, #[serde(default)] pub mcp_servers: Vec, #[serde(default)] pub lsp_servers: Vec, + /// Surface agent timeout in seconds. + #[serde(default)] + pub surface_timeout_secs: Option, /// Max conversation bytes to include in surface agent context. #[serde(default)] pub surface_conversation_bytes: Option, + /// Hook events that trigger the surface agent. + #[serde(default)] + pub surface_hooks: Vec, // Spreading activation parameters #[serde(default = "default_node_weight")] @@ -89,21 +123,36 @@ impl Default for Config { fn default() -> Self { let home = dirs::home_dir().unwrap_or_default(); Self { + user_name: "User".to_string(), + assistant_name: "Assistant".to_string(), data_dir: home.join(".consciousness/memory"), identity_dir: home.join(".consciousness/identity"), projects_dir: home.join(".claude/projects"), protected_nodes: Vec::new(), personality_nodes: vec!["identity".into(), "core-practices".into()], agent_nodes: vec!["identity".into(), "core-practices".into()], + journal_days: 7, + journal_max: 20, llm_concurrency: 1, + agent_budget: 1000, + prompts_dir: home.join(".consciousness/prompts"), + api_base_url: None, + api_key: None, + api_model: None, + api_context_window: default_context_window(), api_stream_timeout_secs: default_stream_timeout(), + scoring_chunk_tokens: default_scoring_chunk_tokens(), scoring_interval_secs: default_scoring_interval_secs(), scoring_response_window: default_scoring_response_window(), + agent_model: None, + api_reasoning: "high".to_string(), agent_types: vec![ "linker".into(), "organize".into(), "distill".into(), "separator".into(), "split".into(), ], + surface_timeout_secs: None, surface_conversation_bytes: None, + surface_hooks: vec![], mcp_servers: vec![], lsp_servers: vec![], default_node_weight: default_node_weight(), @@ -116,20 +165,41 @@ impl Default for Config { impl Config { fn load_from_file() -> Self { - Self::try_load_shared().unwrap_or_default() + if let Some(config) = Self::try_load_shared() { + return config; + } + Self::load_legacy_jsonl() } /// Load from shared config. Memory settings in the "memory" section; /// API settings resolved from models + backend configuration. fn try_load_shared() -> Option { let content = std::fs::read_to_string(config_path()).ok()?; - let root: serde_json::Value = json_five::from_str(&content).ok()?; + let root: serde_json::Value = json5::from_str(&content).ok()?; let mem_value = root.get("memory")?; let mut config: Config = serde_json::from_value(mem_value.clone()).ok()?; config.llm_concurrency = config.llm_concurrency.max(1); - // Top-level sections (not inside "memory"). + // Resolve API settings: agent_model → models → backend + if let Some(model_name) = &config.agent_model + && let Some(model_cfg) = root.get("models").and_then(|m| m.get(model_name.as_str())) { + let backend_name = model_cfg.get("backend").and_then(|v| v.as_str()).unwrap_or(""); + let model_id = model_cfg.get("model_id").and_then(|v| v.as_str()).unwrap_or(""); + + if let Some(backend) = root.get(backend_name) { + config.api_base_url = backend.get("base_url") + .and_then(|v| v.as_str()).map(String::from); + config.api_key = backend.get("api_key") + .and_then(|v| v.as_str()).map(String::from); + } + config.api_model = Some(model_id.to_string()); + if let Some(cw) = model_cfg.get("context_window").and_then(|v| v.as_u64()) { + config.api_context_window = cw as usize; + } + } + + // Top-level config sections (not inside "memory") if let Some(servers) = root.get("lsp_servers") { config.lsp_servers = serde_json::from_value(servers.clone()).unwrap_or_default(); } @@ -139,6 +209,11 @@ impl Config { Some(config) } + + /// Load from legacy JSONL config — deprecated, just return defaults. + fn load_legacy_jsonl() -> Self { + Config::default() + } } /// Get the global memory config (cheap Arc clone). @@ -162,85 +237,27 @@ pub fn reload() -> bool { changed } -/// Spawn a background thread that watches `~/.consciousness/config.json5` -/// and reloads both the memory Config and the global AppConfig whenever -/// the file changes on disk. Lets edits from vim / F6 hotkeys / manual -/// tweaks land live without restarting the process. -pub fn watch_config(cli: crate::user::CliArgs) { - use notify_debouncer_mini::{new_debouncer, notify::RecursiveMode}; - - let path = config_path(); - // Watch the parent directory — editors often replace-via-rename, so - // watching the file itself misses the new inode. - let Some(parent) = path.parent().map(|p| p.to_path_buf()) else { - crate::dbglog!("[config] no parent for {}, skipping watch", path.display()); - return; - }; - - std::thread::Builder::new() - .name("config-watcher".into()) - .spawn(move || { - let (tx, rx) = std::sync::mpsc::channel(); - let mut debouncer = match new_debouncer(std::time::Duration::from_millis(200), tx) { - Ok(d) => d, - Err(e) => { - crate::dbglog!("[config] watcher setup failed: {}", e); - return; - } - }; - if let Err(e) = debouncer.watcher() - .watch(&parent, RecursiveMode::NonRecursive) - { - crate::dbglog!("[config] watch({}) failed: {}", parent.display(), e); - return; - } - crate::dbglog!("[config] watching {}", path.display()); - - while let Ok(res) = rx.recv() { - let Ok(events) = res else { continue; }; - if !events.iter().any(|e| e.path == path) { continue; } - - // Reload both halves. - let mem_changed = reload(); - let app_changed = match build_figment(&cli).extract::() { - Ok(app) => { - install_app(app); - true - } - Err(e) => { - crate::dbglog!("[config] reload: AppConfig parse failed: {}", e); - false - } - }; - crate::dbglog!("[config] reloaded (memory_changed={}, app_changed={})", - mem_changed, app_changed); - } - }) - .ok(); -} - // ============================================================ // Agent config (top-level settings) // ============================================================ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AppConfig { - #[serde(default = "default_user_name")] - pub user_name: String, - #[serde(default = "default_assistant_name")] - pub assistant_name: String, - /// Named model endpoints — credentials, base URL, and model id bundled - /// into one entry per backend. Keyed by name, selected by - /// `default_backend` or by `--model ` on the CLI. + pub backend: String, + pub anthropic: BackendConfig, + pub openrouter: BackendConfig, #[serde(default)] - pub backends: HashMap, - #[serde(default)] - pub default_backend: String, + pub deepinfra: BackendConfig, + pub prompts: PromptConfig, pub debug: bool, pub compaction: CompactionConfig, pub dmn: DmnConfig, + #[serde(skip_serializing_if = "Option::is_none")] + pub memory_project: Option, #[serde(default)] - pub learn: LearnConfig, + pub models: HashMap, + #[serde(default = "default_model_name")] + pub default_model: String, #[serde(default)] pub mcp_servers: Vec, #[serde(default)] @@ -267,17 +284,32 @@ pub struct LspServerConfig { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct BackendConfig { - /// API key for the backend. #[serde(default)] pub api_key: String, - /// Base URL for the backend's OpenAI-compatible endpoint. - #[serde(default, skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] pub base_url: Option, - /// Model identifier sent to the API. - pub model_id: String, - /// Context window size in tokens. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub context_window: Option, +} + +impl BackendConfig { + fn resolve(&self, default_base: &str) -> Result<(String, String, String)> { + if self.api_key.is_empty() { + anyhow::bail!( + "No API key. Set it in {} or use --api-key", + config_path().display() + ); + } + let base = self.base_url.clone() + .unwrap_or_else(|| default_base.to_string()); + Ok((base, self.api_key.clone(), self.model.clone())) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptConfig { + pub anthropic: String, + pub other: String, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -292,57 +324,65 @@ pub struct DmnConfig { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LearnConfig { - /// Divergence threshold — responses scoring above this become - /// fine-tuning candidates. Lower = more sensitive. - #[serde(default = "default_learn_threshold")] - pub threshold: f64, - /// Whether to generate "what would the model have said without - /// memories" alternates alongside each scoring run. Expensive — - /// one full streaming generation per candidate. +pub struct ModelConfig { + /// Backend name ("anthropic" or "openrouter") + pub backend: String, + /// Model identifier sent to the API + pub model_id: String, + /// Instruction file ("CLAUDE.md" or "POC.md"). #[serde(default)] - pub generate_alternates: bool, + pub prompt_file: Option, + /// Context window size in tokens. + #[serde(default)] + pub context_window: Option, } -fn default_learn_threshold() -> f64 { 1.0 } - -impl Default for LearnConfig { - fn default() -> Self { - Self { - threshold: default_learn_threshold(), - generate_alternates: false, - } - } -} - -fn default_user_name() -> String { "User".into() } -fn default_assistant_name() -> String { "Assistant".into() } - impl Default for AppConfig { fn default() -> Self { Self { - user_name: default_user_name(), - assistant_name: default_assistant_name(), - backends: HashMap::new(), - default_backend: String::new(), + backend: "openrouter".to_string(), + anthropic: BackendConfig { + api_key: String::new(), + model: "claude-opus-4-6-20250918".to_string(), + base_url: None, + }, + openrouter: BackendConfig { + api_key: String::new(), + model: "qwen/qwen3.5-397b-a17b".to_string(), + base_url: Some("https://openrouter.ai/api/v1".to_string()), + }, + deepinfra: BackendConfig { + api_key: String::new(), + model: String::new(), + base_url: Some("https://api.deepinfra.com/v1/openai".to_string()), + }, + prompts: PromptConfig { + anthropic: "CLAUDE.md".to_string(), + other: "POC.md".to_string(), + }, debug: false, compaction: CompactionConfig { hard_threshold_pct: 90, soft_threshold_pct: 80, }, dmn: DmnConfig { max_turns: 20 }, - learn: LearnConfig::default(), + memory_project: None, + models: HashMap::new(), + default_model: String::new(), mcp_servers: Vec::new(), lsp_servers: Vec::new(), } } } +fn default_model_name() -> String { String::new() } + /// Resolved, ready-to-use agent session config. pub struct SessionConfig { pub api_base: String, pub api_key: String, pub model: String, + pub prompt_file: String, /// Identity/personality nodes as (name, content) pairs. pub context_parts: Vec<(String, String)>, pub session_dir: PathBuf, @@ -358,21 +398,36 @@ pub struct ResolvedModel { pub api_base: String, pub api_key: String, pub model_id: String, + pub prompt_file: String, pub context_window: Option, } impl AppConfig { /// Resolve the active backend and assemble prompts into a SessionConfig. pub async fn resolve(&self, cli: &crate::user::CliArgs) -> Result { - if self.backends.is_empty() { - anyhow::bail!( - "no backends configured in {}. Add a `backends` section with at least one entry.", - config_path().display() - ); - } + let (api_base, api_key, model, prompt_file); - let name = cli.model.as_deref().unwrap_or(&self.default_backend); - let resolved = self.resolve_model(name)?; + if !self.models.is_empty() { + let model_name = cli.model.as_deref().unwrap_or(&self.default_model); + let resolved = self.resolve_model(model_name)?; + api_base = resolved.api_base; + api_key = resolved.api_key; + model = resolved.model_id; + prompt_file = resolved.prompt_file; + } else { + let (base, key, mdl) = match self.backend.as_str() { + "anthropic" => self.anthropic.resolve("https://api.anthropic.com"), + _ => self.openrouter.resolve("https://openrouter.ai/api/v1"), + }?; + api_base = base; + api_key = key; + model = mdl; + prompt_file = if self.backend == "anthropic" { + self.prompts.anthropic.clone() + } else { + self.prompts.other.clone() + }; + } let personality_nodes = get().personality_nodes.clone(); let context_parts = crate::mind::identity::personality_nodes(&personality_nodes).await; @@ -383,13 +438,11 @@ impl AppConfig { std::fs::create_dir_all(&session_dir).ok(); // CLI --api-base and --api-key override everything - let api_base = cli.api_base.clone().unwrap_or(resolved.api_base); - let api_key = cli.api_key.clone().unwrap_or(resolved.api_key); + let api_base = cli.api_base.clone().unwrap_or(api_base); + let api_key = cli.api_key.clone().unwrap_or(api_key); Ok(SessionConfig { - api_base, - api_key, - model: resolved.model_id, + api_base, api_key, model, prompt_file, context_parts, session_dir, app: self.clone(), @@ -397,33 +450,55 @@ impl AppConfig { }) } - /// Look up a named backend and resolve its credentials. + /// Look up a named model and resolve its credentials from the backend config. pub fn resolve_model(&self, name: &str) -> Result { - let b = self.backends.get(name) + let model = self.models.get(name) .ok_or_else(|| anyhow::anyhow!( - "Unknown backend '{}'. Available: {}", + "Unknown model '{}'. Available: {}", name, self.model_names().join(", "), ))?; - let api_base = b.base_url.clone() - .ok_or_else(|| anyhow::anyhow!( - "backends.{}.base_url not set in {}", - name, config_path().display() - ))?; + let (api_base, api_key) = match model.backend.as_str() { + "anthropic" => ( + self.anthropic.base_url.clone() + .unwrap_or_else(|| "https://api.anthropic.com".to_string()), + self.anthropic.api_key.clone(), + ), + "deepinfra" => ( + self.deepinfra.base_url.clone() + .unwrap_or_else(|| "https://api.deepinfra.com/v1/openai".to_string()), + self.deepinfra.api_key.clone(), + ), + _ => ( + self.openrouter.base_url.clone() + .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()), + self.openrouter.api_key.clone(), + ), + }; + + let prompt_file = model.prompt_file.clone() + .unwrap_or_else(|| { + if model.backend == "anthropic" { + self.prompts.anthropic.clone() + } else { + self.prompts.other.clone() + } + }); Ok(ResolvedModel { name: name.to_string(), api_base, - api_key: b.api_key.clone(), - model_id: b.model_id.clone(), - context_window: b.context_window, + api_key, + model_id: model.model_id.clone(), + prompt_file, + context_window: model.context_window, }) } - /// List available backend names, sorted. + /// List available model names, sorted. pub fn model_names(&self) -> Vec { - let mut names: Vec<_> = self.backends.keys().cloned().collect(); + let mut names: Vec<_> = self.models.keys().cloned().collect(); names.sort(); names } @@ -443,7 +518,7 @@ impl Provider for Json5File { fn data(&self) -> figment::Result> { match std::fs::read_to_string(&self.0) { Ok(content) => { - let value: figment::value::Value = json_five::from_str(&content) + let value: figment::value::Value = json5::from_str(&content) .map_err(|e| figment::Error::from(format!("{}: {}", self.0.display(), e)))?; Serialized::defaults(value).data() } @@ -465,6 +540,11 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment { let mut f = Figment::from(Serialized::defaults(AppConfig::default())) .merge(Json5File(config_path())); + merge_opt!(f, cli.backend, "backend"); + merge_opt!(f, cli.model, "anthropic.model", "openrouter.model"); + merge_opt!(f, cli.api_key, "anthropic.api_key", "openrouter.api_key"); + merge_opt!(f, cli.api_base, "anthropic.base_url", "openrouter.base_url"); + merge_opt!(f, cli.memory_project, "memory_project"); merge_opt!(f, cli.dmn_max_turns, "dmn.max_turns"); if cli.debug { f = f.merge(Serialized::default("debug", true)); @@ -474,46 +554,12 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment { } /// Load just the AppConfig — no validation, no prompt assembly. -/// Also installs the loaded AppConfig into the global cache so -/// `config::app()` is available everywhere. pub fn load_app(cli: &crate::user::CliArgs) -> Result<(AppConfig, Figment)> { let figment = build_figment(cli); let app: AppConfig = figment.extract().context("Failed to load configuration")?; - install_app(app.clone()); Ok((app, figment)) } -// ============================================================ -// Global AppConfig cache (writable, for runtime-mutable settings -// like learn.threshold that F6 edits via config_writer). -// ============================================================ - -static APP_CONFIG: OnceLock> = OnceLock::new(); - -fn install_app(app: AppConfig) { - let slot = APP_CONFIG.get_or_init(|| RwLock::new(app.clone())); - *slot.write().unwrap() = app; -} - -/// Current AppConfig, held under a read lock. Reads should be brief -/// (no holding across await / long work) to avoid starving writers. -/// Panics if called before load_app — which runs once at startup. -pub fn app() -> std::sync::RwLockReadGuard<'static, AppConfig> { - APP_CONFIG - .get() - .expect("config::app() called before load_app()") - .read() - .unwrap() -} - -/// Mutate the cached AppConfig in place. Used by config_writer to keep -/// the in-memory view in sync with disk after surgical edits to -/// ~/.consciousness/config.json5. -pub fn update_app(f: impl FnOnce(&mut AppConfig)) { - let slot = APP_CONFIG.get().expect("update_app before load_app"); - f(&mut *slot.write().unwrap()); -} - /// Load the full config: figment → AppConfig → resolve backend → assemble prompts. pub async fn load_session(cli: &crate::user::CliArgs) -> Result<(SessionConfig, Figment)> { let (app, figment) = load_app(cli)?; @@ -539,28 +585,38 @@ pub fn show_config(app: &AppConfig, figment: &Figment) { } println!("# Effective configuration\n"); - println!("user_name: {:?} ({})", app.user_name, src(figment, "user_name")); - println!("assistant_name: {:?} ({})", app.assistant_name, src(figment, "assistant_name")); + println!("backend: {:?} ({})", app.backend, src(figment, "backend")); + for (name, b) in [("anthropic", &app.anthropic), ("openrouter", &app.openrouter)] { + println!("\n{}:", name); + println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("{name}.api_key"))); + println!(" model: {:?} ({})", b.model, src(figment, &format!("{name}.model"))); + if let Some(ref url) = b.base_url { + println!(" base_url: {:?} ({})", url, src(figment, &format!("{name}.base_url"))); + } + } + println!("\nprompts:"); + println!(" anthropic: {:?} ({})", app.prompts.anthropic, src(figment, "prompts.anthropic")); + println!(" other: {:?} ({})", app.prompts.other, src(figment, "prompts.other")); println!("\ndebug: {} ({})", app.debug, src(figment, "debug")); println!("\ncompaction:"); println!(" hard_threshold_pct: {} ({})", app.compaction.hard_threshold_pct, src(figment, "compaction.hard_threshold_pct")); println!(" soft_threshold_pct: {} ({})", app.compaction.soft_threshold_pct, src(figment, "compaction.soft_threshold_pct")); println!("\ndmn:"); println!(" max_turns: {} ({})", app.dmn.max_turns, src(figment, "dmn.max_turns")); - println!("\ndefault_backend: {:?} ({})", app.default_backend, src(figment, "default_backend")); - if !app.backends.is_empty() { - println!("\nbackends:"); - let mut names: Vec<_> = app.backends.keys().cloned().collect(); - names.sort(); - for name in names { - let b = &app.backends[&name]; + if let Some(ref p) = app.memory_project { + println!("\nmemory_project: {:?} ({})", p, src(figment, "memory_project")); + } + println!("\ndefault_model: {:?}", app.default_model); + if !app.models.is_empty() { + println!("\nmodels:"); + for (name, m) in &app.models { println!(" {}:", name); - println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("backends.{name}.api_key"))); - if let Some(ref url) = b.base_url { - println!(" base_url: {:?} ({})", url, src(figment, &format!("backends.{name}.base_url"))); + println!(" backend: {:?}", m.backend); + println!(" model_id: {:?}", m.model_id); + if let Some(ref pf) = m.prompt_file { + println!(" prompt_file: {:?}", pf); } - println!(" model_id: {:?}", b.model_id); - if let Some(cw) = b.context_window { + if let Some(cw) = m.context_window { println!(" context_window: {}", cw); } } diff --git a/src/config_writer.rs b/src/config_writer.rs deleted file mode 100644 index 079449f..0000000 --- a/src/config_writer.rs +++ /dev/null @@ -1,448 +0,0 @@ -// config_writer.rs — Surgical edits to ~/.consciousness/config.json5 -// -// Uses json-five's round-trip parser to mutate specific fields while -// preserving the surrounding comments, whitespace, and formatting. - -use std::path::Path; - -use anyhow::{anyhow, Context as _, Result}; -use json_five::rt::parser::{ - from_str, JSONKeyValuePair, JSONObjectContext, JSONValue, KeyValuePairContext, -}; - -use crate::config::config_path; - -/// Read the config, apply `mutate` to the root JSONValue, write it back atomically. -fn edit_config Result<()>>(mutate: F) -> Result<()> { - let path = config_path(); - let src = std::fs::read_to_string(&path) - .with_context(|| format!("read {}", path.display()))?; - - let mut text = from_str(&src) - .map_err(|e| anyhow!("parse {}: {}", path.display(), e))?; - mutate(&mut text.value)?; - - write_atomic(&path, &text.to_string()) -} - -fn write_atomic(path: &Path, content: &str) -> Result<()> { - let parent = path.parent() - .ok_or_else(|| anyhow!("config path has no parent: {}", path.display()))?; - let tmp = parent.join(format!( - ".{}.tmp", - path.file_name().unwrap_or_default().to_string_lossy(), - )); - std::fs::write(&tmp, content) - .with_context(|| format!("write {}", tmp.display()))?; - std::fs::rename(&tmp, path) - .with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?; - Ok(()) -} - -/// Match a key JSONValue against a string name. JSON5 allows keys to be -/// unquoted identifiers or single/double-quoted strings. -fn key_matches(key: &JSONValue, name: &str) -> bool { - match key { - JSONValue::Identifier(s) - | JSONValue::DoubleQuotedString(s) - | JSONValue::SingleQuotedString(s) => s == name, - _ => false, - } -} - -/// Find (or create) a child object under `parent`, returning a mutable borrow -/// of its key_value_pairs vector. -/// Append a new kvp to `object`, setting whitespace so the output is -/// multi-line with the given indentation: -/// -/// ```text -/// {first_key: first_val,} -/// ``` -/// -/// If `object` already has kvps, the separator between the last one and -/// ours goes in the prior kvp's wsc.3. If we're the first kvp, the -/// lead-in after `{` goes in the object's own wsc.0. -fn append_kvp_pretty( - object: &mut JSONValue, - key: JSONValue, - value: JSONValue, - inner_indent: &str, - outer_indent: &str, -) -> Result<()> { - let (pairs, ctx) = match object { - JSONValue::JSONObject { key_value_pairs, context } => { - let ctx = context.get_or_insert_with(|| JSONObjectContext { - wsc: (String::new(),), - }); - (key_value_pairs, ctx) - } - _ => return Err(anyhow!("not an object")), - }; - - if pairs.is_empty() { - ctx.wsc.0 = format!("\n{}", inner_indent); - } else { - let prev = pairs.last_mut().unwrap(); - let prev_ctx = prev.context.get_or_insert_with(|| KeyValuePairContext { - wsc: (String::new(), String::from(" "), String::new(), None), - }); - prev_ctx.wsc.3 = Some(format!("\n{}", inner_indent)); - } - - pairs.push(JSONKeyValuePair { - key, - value, - context: Some(KeyValuePairContext { - wsc: ( - String::new(), - String::from(" "), - String::new(), - Some(format!("\n{}", outer_indent)), - ), - }), - }); - - Ok(()) -} - -/// Find or create a child object under `parent`. Returns the index of -/// the kvp in parent's key_value_pairs so the caller can re-borrow -/// afterward. -fn get_or_create_object_idx( - parent: &mut JSONValue, - section: &str, - inner_indent: &str, - outer_indent: &str, -) -> Result { - let existing = match parent { - JSONValue::JSONObject { key_value_pairs, .. } => { - key_value_pairs.iter() - .position(|kvp| key_matches(&kvp.key, section)) - } - _ => return Err(anyhow!("config root is not an object")), - }; - - if let Some(i) = existing { - return Ok(i); - } - - append_kvp_pretty( - parent, - JSONValue::Identifier(section.to_string()), - JSONValue::JSONObject { - key_value_pairs: Vec::new(), - context: Some(JSONObjectContext { wsc: (String::new(),) }), - }, - inner_indent, - outer_indent, - )?; - - match parent { - JSONValue::JSONObject { key_value_pairs, .. } => Ok(key_value_pairs.len() - 1), - _ => unreachable!(), - } -} - -/// Set `section.key` to a literal scalar value (e.g., "1e-7", "42", "true"). -/// The literal is parsed as JSON5 so we preserve its source-form on round-trip. -pub fn set_scalar(section: &str, key: &str, literal: &str) -> Result<()> { - let value = parse_scalar_literal(literal)?; - edit_config(|root| { - // New top-level sections sit at column 4 (inside root `{`), - // and the root's closing `}` sits at column 0. - let section_idx = get_or_create_object_idx(root, section, " ", "")?; - - let section_value = match root { - JSONValue::JSONObject { key_value_pairs, .. } => { - &mut key_value_pairs[section_idx].value - } - _ => unreachable!(), - }; - - // Update in place if the key already exists. - if let JSONValue::JSONObject { key_value_pairs, .. } = section_value { - if let Some(kvp) = key_value_pairs.iter_mut() - .find(|k| key_matches(&k.key, key)) - { - kvp.value = value; - return Ok(()); - } - } - - // Append a new kvp. Inner keys sit at column 8, the section's - // closing `}` sits at column 4. - append_kvp_pretty( - section_value, - JSONValue::Identifier(key.to_string()), - value, - " ", - " ", - ) - }) -} - -/// Parse a scalar literal by round-tripping it through json-five. Keeps us -/// consistent with whatever scalars the library considers valid (hex, -/// exponents, Infinity, etc.). -fn parse_scalar_literal(literal: &str) -> Result { - let text = from_str(literal) - .map_err(|e| anyhow!("parse literal {:?}: {}", literal, e))?; - match text.value { - JSONValue::JSONObject { .. } | JSONValue::JSONArray { .. } => { - Err(anyhow!("set_scalar only accepts scalar literals, got {:?}", literal)) - } - v => Ok(v), - } -} - -/// Convenience: set `learn.threshold` to the given f64. -pub fn set_learn_threshold(value: f64) -> Result<()> { - // {:e} gives the minimal scientific notation that preserves the value. - set_scalar("learn", "threshold", &format!("{:e}", value))?; - crate::config::update_app(|app| app.learn.threshold = value); - Ok(()) -} - -/// Convenience: set `learn.generate_alternates` to the given bool. -pub fn set_learn_generate_alternates(value: bool) -> Result<()> { - set_scalar("learn", "generate_alternates", - if value { "true" } else { "false" })?; - crate::config::update_app(|app| app.learn.generate_alternates = value); - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - // In-memory variant of set_scalar — used to test the mutation logic - // without touching disk. - fn set_scalar_inline( - root: &mut JSONValue, - section: &str, - key: &str, - literal: &str, - ) -> Result<()> { - let value = parse_scalar_literal(literal)?; - let section_idx = get_or_create_object_idx(root, section, " ", "")?; - let section_value = match root { - JSONValue::JSONObject { key_value_pairs, .. } => { - &mut key_value_pairs[section_idx].value - } - _ => unreachable!(), - }; - if let JSONValue::JSONObject { key_value_pairs, .. } = section_value { - if let Some(kvp) = key_value_pairs.iter_mut() - .find(|k| key_matches(&k.key, key)) - { - kvp.value = value; - return Ok(()); - } - } - append_kvp_pretty( - section_value, - JSONValue::Identifier(key.to_string()), - value, - " ", - " ", - ) - } - - fn edit_str Result<()>>(src: &str, f: F) -> Result { - let mut text = from_str(src).map_err(|e| anyhow!("{}", e))?; - f(&mut text.value)?; - Ok(text.to_string()) - } - - #[test] - fn replaces_existing_scalar() { - let src = r#"{ - // threshold for learning - learn: { - threshold: 0.001, // the old value - }, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "1e-7") - }).unwrap(); - assert!(out.contains("1e-7"), "output: {}", out); - assert!(out.contains("// threshold for learning")); - assert!(out.contains("// the old value")); - assert!(!out.contains("0.001")); - } - - #[test] - fn creates_missing_section() { - let src = r#"{ - // comment - memory: { user_name: "Kent" }, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "1e-7") - }).unwrap(); - assert!(out.contains("learn")); - assert!(out.contains("1e-7")); - assert!(out.contains("// comment")); - assert!(out.contains(r#"user_name: "Kent""#)); - } - - #[test] - fn preserves_comments_in_siblings() { - let src = r#"{ - memory: { - // sensitive setting - user_name: "Kent", // name - }, - learn: { - threshold: 0.5, - }, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "1e-9") - }).unwrap(); - assert!(out.contains("// sensitive setting")); - assert!(out.contains("// name")); - assert!(out.contains("1e-9")); - assert!(!out.contains("0.5")); - } - - #[test] - fn adds_key_to_existing_empty_section() { - let src = r#"{ - learn: {}, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "42") - }).unwrap(); - assert!(out.contains("threshold"), "output: {}", out); - assert!(out.contains("42")); - } - - #[test] - fn realistic_config_adds_learn_section() { - // Mirrors the shape of ~/.consciousness/config.json5 — multiple - // sections, comments, mixed tab/space indent, trailing commas. - let src = r#"{ - deepinfra: { - api_key: "bcachefs-agents-2026", - base_url: "http://example/v1", - }, - - // Named models - models: { - "27b": { - backend: "deepinfra", - model_id: "Qwen/Qwen3.5-27B", - }, - }, - - default_model: "27b", - - memory: { - user_name: "Kent", - // Active agent types - agent_types: ["linker", "organize"], - }, - - compaction: { - hard_threshold_pct: 90, - }, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "1e-7") - }).unwrap(); - - // Core assertions: comments and sibling sections survive. - assert!(out.contains(r#"api_key: "bcachefs-agents-2026""#)); - assert!(out.contains("// Named models")); - assert!(out.contains("// Active agent types")); - assert!(out.contains(r#"user_name: "Kent""#)); - assert!(out.contains("hard_threshold_pct: 90")); - - // New section added. - assert!(out.contains("learn")); - assert!(out.contains("1e-7")); - - // Parse result should parse back without error (real json5 parser). - let reparsed: serde_json::Value = json_five::from_str(&out) - .expect("mutated output must be valid JSON5"); - let threshold = reparsed.pointer("/learn/threshold").expect("learn.threshold exists"); - assert_eq!(threshold.as_f64(), Some(1e-7)); - } - - #[test] - fn realistic_config_updates_existing_threshold() { - let src = r#"{ - learn: { - // The divergence threshold - threshold: 0.001, - }, - memory: { user_name: "Kent" }, -}"#; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "threshold", "5e-8") - }).unwrap(); - assert!(out.contains("5e-8")); - assert!(!out.contains("0.001")); - assert!(out.contains("// The divergence threshold")); - - let reparsed: serde_json::Value = json_five::from_str(&out).unwrap(); - assert_eq!(reparsed.pointer("/learn/threshold").and_then(|v| v.as_f64()), Some(5e-8)); - } - - #[test] - fn new_section_exact_multiline_layout() { - let src = "{\n a: 1,\n}"; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "generate_alternates", "true")?; - set_scalar_inline(root, "learn", "threshold", "1e-7") - }).unwrap(); - - let expected = "\ -{ - a: 1, - learn: { - generate_alternates: true, - threshold: 1e-7, - }, -}"; - assert_eq!(out, expected, "\n--- got ---\n{}\n--- want ---\n{}\n", out, expected); - } - - #[test] - fn new_section_and_key_format_cleanly() { - // The kind of config we actually have in ~/.consciousness - // (top-level sections separated by blank lines, 4-space indent - // for keys within each section). Appending a fresh `learn` - // section with one key should land cleanly, not as - // `learn\n\n :{key\n :value}`. - let src = "{\n memory: {\n user_name: \"Kent\",\n },\n}"; - let out = edit_str(src, |root| { - set_scalar_inline(root, "learn", "generate_alternates", "true") - }).unwrap(); - - // No stray key-to-colon-on-next-line anywhere. - assert!(!out.contains("learn\n"), "learn key wraps: {}", out); - assert!(!out.contains("generate_alternates\n"), - "inner key wraps: {}", out); - - // The output should reparse. - let v: serde_json::Value = json_five::from_str(&out).unwrap(); - assert_eq!( - v.pointer("/learn/generate_alternates").and_then(|x| x.as_bool()), - Some(true), - "output: {}", out, - ); - } - - #[test] - fn roundtrip_stable_without_change() { - let src = r#"{ - // heading - a: 1, - b: { c: 2 }, // inline -}"#; - let text = from_str(src).unwrap(); - assert_eq!(text.to_string(), src); - } -} diff --git a/src/hippocampus/neuro/scoring.rs b/src/hippocampus/neuro/scoring.rs index c9cbb40..5828fd0 100644 --- a/src/hippocampus/neuro/scoring.rs +++ b/src/hippocampus/neuro/scoring.rs @@ -230,6 +230,10 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio rationale: Vec::new(), }; + // Active agent types from config + let config = crate::config::get(); + let agent_types: Vec<&str> = config.agent_types.iter().map(|s| s.as_str()).collect(); + // Target: α ≥ 2.5 (healthy scale-free) if alpha < 2.0 { plan.add("linker", 100); @@ -270,6 +274,48 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio // Split: handle oversized nodes plan.set("split", 5); + // Distribute agent budget using Elo ratings + let budget = crate::config::get().agent_budget; + let elo_path = crate::config::get().data_dir.join("agent-elo.json"); + if let Ok(elo_json) = std::fs::read_to_string(&elo_path) { + if let Ok(ratings) = serde_json::from_str::>(&elo_json) { + let elos: Vec = agent_types.iter() + .map(|t| ratings.get(*t).copied().unwrap_or(1000.0)) + .collect(); + let min_elo = elos.iter().copied().fold(f64::MAX, f64::min); + + let weights: Vec = elos.iter() + .map(|e| { + let shifted = e - min_elo + 50.0; + shifted * shifted + }) + .collect(); + let total_weight: f64 = weights.iter().sum(); + + let allocate = |w: f64| -> usize { + ((w / total_weight * budget as f64).round() as usize).max(2) + }; + + for (i, agent) in agent_types.iter().enumerate() { + plan.set(agent, allocate(weights[i])); + } + + let summary: Vec = agent_types.iter() + .map(|a| format!("{}={}", a, plan.count(a))) + .collect(); + plan.rationale.push(format!( + "Elo allocation (budget={}): {}", budget, summary.join(" "))); + } + } else { + // No Elo file — use budget with equal distribution + let per_type = budget / agent_types.len(); + for agent in &agent_types { + plan.set(agent, per_type); + } + plan.rationale.push(format!( + "No Elo ratings — equal distribution ({} each, budget={})", per_type, budget)); + } + plan } diff --git a/src/lib.rs b/src/lib.rs index e6411e3..1a71735 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,6 @@ pub mod subconscious; // Unified configuration pub mod config; -pub mod config_writer; // Session state pub mod session; diff --git a/src/mind/log.rs b/src/mind/log.rs index 7ac0d79..b69f2ca 100644 --- a/src/mind/log.rs +++ b/src/mind/log.rs @@ -55,13 +55,17 @@ impl ConversationLog { } pub fn oldest_timestamp(&self) -> Option> { + // Read forward from the start to find first timestamp let file = File::open(&self.path).ok()?; let mmap = unsafe { Mmap::map(&file).ok()? }; + // Find first { ... } and parse for line in mmap.split(|&b| b == b'\n') { if line.is_empty() { continue; } if let Ok(node) = serde_json::from_slice::(line) { if let Some(leaf) = node.leaf() { - return Some(leaf.timestamp()); + if let Some(ts) = leaf.timestamp() { + return Some(ts); + } } } } diff --git a/src/mind/mod.rs b/src/mind/mod.rs index 11d45b1..a221e80 100644 --- a/src/mind/mod.rs +++ b/src/mind/mod.rs @@ -147,25 +147,6 @@ pub struct MindState { pub unc_idle: bool, /// When the unconscious idle timer will fire (for UI display). pub unc_idle_deadline: Instant, - /// Fine-tuning candidates identified by scoring. - pub finetune_candidates: Vec, - /// Last scoring run stats for UI display. - pub finetune_last_run: Option, -} - -/// Stats from the last finetune scoring run. -#[derive(Clone, Debug)] -pub struct FinetuneScoringStats { - /// Count of assistant responses we considered (recent half of context). - pub responses_considered: usize, - /// How many exceeded the divergence threshold. - pub above_threshold: usize, - /// Threshold used for this run. - pub threshold: f64, - /// Highest divergence observed. - pub max_divergence: f64, - /// Error message if the run failed. - pub error: Option, } impl Clone for MindState { @@ -184,8 +165,6 @@ impl Clone for MindState { turn_handle: None, // Not cloned — only Mind's loop uses this unc_idle: self.unc_idle, unc_idle_deadline: self.unc_idle_deadline, - finetune_candidates: self.finetune_candidates.clone(), - finetune_last_run: self.finetune_last_run.clone(), } } } @@ -198,12 +177,6 @@ pub enum MindCommand { Score, /// Run full N×M memory scoring matrix (/score command) ScoreFull, - /// Score for finetune candidates - ScoreFinetune, - /// Update the finetune divergence threshold and persist to config. - SetLearnThreshold(f64), - /// Toggle alternate-response generation during scoring; persist to config. - SetLearnGenerateAlternates(bool), /// Abort current turn, kill processes Interrupt, /// Reset session @@ -229,8 +202,6 @@ impl MindState { turn_handle: None, unc_idle: false, unc_idle_deadline: Instant::now() + std::time::Duration::from_secs(60), - finetune_candidates: Vec::new(), - finetune_last_run: None, } } @@ -317,7 +288,6 @@ impl MindState { /// Background task completion events. enum BgEvent { ScoringDone, - FinetuneCandidate(learn::FinetuneCandidate), } // --- Mind: cognitive state machine --- @@ -354,26 +324,13 @@ impl Mind { client, config.context_parts.clone(), config.app.clone(), + config.prompt_file.clone(), conversation_log, crate::agent::tools::ActiveTools::new(), crate::agent::tools::tools(), ).await; - // Migrate legacy "file exists = enabled" sentinel for the - // generate-alternates flag into the config. One-shot; after this - // the sentinel is gone and the config is the source of truth. - let legacy_sentinel = dirs::home_dir().unwrap_or_default() - .join(".consciousness/cache/finetune-alternates"); - if legacy_sentinel.exists() { - if !crate::config::app().learn.generate_alternates { - let _ = crate::config_writer::set_learn_generate_alternates(true); - } - let _ = std::fs::remove_file(&legacy_sentinel); - } - - let shared = Arc::new(std::sync::Mutex::new(MindState::new( - config.app.dmn.max_turns, - ))); + let shared = Arc::new(std::sync::Mutex::new(MindState::new(config.app.dmn.max_turns))); let (turn_watch, _) = tokio::sync::watch::channel(false); let (conscious_active, _) = tokio::sync::watch::channel(false); let (bg_tx, bg_rx) = mpsc::unbounded_channel(); @@ -572,20 +529,6 @@ impl Mind { } self.agent.compact().await; } - MindCommand::ScoreFinetune => { - self.start_finetune_scoring(); - } - MindCommand::SetLearnThreshold(value) => { - if let Err(e) = crate::config_writer::set_learn_threshold(value) { - dbglog!("[learn] failed to persist threshold {}: {:#}", value, e); - } - } - MindCommand::SetLearnGenerateAlternates(value) => { - if let Err(e) = crate::config_writer::set_learn_generate_alternates(value) { - dbglog!("[learn] failed to persist generate_alternates {}: {:#}", - value, e); - } - } } } } @@ -660,72 +603,6 @@ impl Mind { }); } - /// Score responses for fine-tuning candidates. - /// - /// Scores the most recent half of the context — responses near the end - /// of the context window were generated with the most context available, - /// which is what we want to train on. The threshold is a temporary knob; - /// once this runs continuously, we'll just train whatever lands at full - /// context without filtering. - pub fn start_finetune_scoring(&self) { - // Snapshot the config values we need before spawning — the scoring - // task shouldn't hold the config read lock across async work. - let (threshold, gen_alternates) = { - let app = crate::config::app(); - (app.learn.threshold, app.learn.generate_alternates) - }; - // Clear the previous run's candidates so this run's stream is fresh. - self.shared.lock().unwrap().finetune_candidates.clear(); - - let agent = self.agent.clone(); - let bg_tx = self.bg_tx.clone(); - let shared = self.shared.clone(); - tokio::spawn(async move { - let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await; - - let (context, client) = { - let ctx = agent.context.lock().await; - (ctx.clone(), agent.client.clone()) - }; - - let entries = context.conversation(); - let score_count = entries.len() / 2; - let range_start = entries.len() - score_count; - let responses_considered: usize = entries[range_start..].iter() - .filter(|n| matches!(n, crate::agent::context::AstNode::Branch { role: crate::agent::context::Role::Assistant, .. })) - .count(); - - activity.update(format!("finetune: scoring {} responses...", responses_considered)).await; - - let bg_tx_cb = bg_tx.clone(); - let stats = match learn::score_finetune_candidates( - &context, score_count, &client, threshold, - gen_alternates, &activity, - |c| { let _ = bg_tx_cb.send(BgEvent::FinetuneCandidate(c)); }, - ).await { - Ok((above_threshold, max_div)) => { - FinetuneScoringStats { - responses_considered, - above_threshold, - threshold, - max_divergence: max_div, - error: None, - } - } - Err(e) => FinetuneScoringStats { - responses_considered, - above_threshold: 0, - threshold, - max_divergence: 0.0, - error: Some(format!("{}", e)), - }, - }; - - shared.lock().unwrap().finetune_last_run = Some(stats); - // activity drops here, marking completion and notifying observers - }); - } - async fn start_turn(&self, text: &str, target: StreamTarget) { { match target { @@ -790,12 +667,6 @@ impl Mind { let mut bg_rx = self.bg_rx.lock().unwrap().take() .expect("Mind::run() called twice"); let mut sub_handle: Option> = None; - - // Start finetune scoring at startup (scores existing conversation) - if !self.config.no_agents { - self.start_finetune_scoring(); - } - loop { let (timeout, has_input) = { let me = self.shared.lock().unwrap(); @@ -821,9 +692,6 @@ impl Mind { BgEvent::ScoringDone => { self.shared.lock().unwrap().scoring_in_flight = false; } - BgEvent::FinetuneCandidate(c) => { - self.shared.lock().unwrap().finetune_candidates.push(c); - } } } @@ -843,7 +711,6 @@ impl Mind { cmds.push(MindCommand::Compact); if !self.config.no_agents { cmds.push(MindCommand::Score); - cmds.push(MindCommand::ScoreFinetune); } } diff --git a/src/mind/subconscious.rs b/src/mind/subconscious.rs index 21cc549..d5bee34 100644 --- a/src/mind/subconscious.rs +++ b/src/mind/subconscious.rs @@ -20,7 +20,6 @@ use std::path::PathBuf; use std::time::{Duration, Instant}; -use crate::thalamus::idle::{hours_since_last_dream, DREAM_INTERVAL_HOURS}; /// DMN state machine. #[derive(Debug, Clone)] @@ -92,8 +91,7 @@ impl State { /// Generate the DMN prompt for the current state, informed by /// user presence and error patterns. pub fn prompt(&self, ctx: &DmnContext) -> String { - let app = crate::config::app(); - let user = &app.user_name; + let user = &crate::config::get().user_name; let idle_info = if ctx.user_idle < Duration::from_secs(60) { format!("{} is here (active recently).", user) @@ -140,22 +138,10 @@ impl State { ) } State::Foraging => { - let dream_hint = { - let hours = hours_since_last_dream(); - if hours >= DREAM_INTERVAL_HOURS { - format!( - " You haven't dreamed in {} hours — consider running \ - ~/.consciousness/tools/dream-start.sh.", - hours - ) - } else { - String::new() - } - }; format!( "[dmn] Foraging time. {} Follow whatever catches your attention — \ - memory files, code, ideas. Call yield_to_user when you want to rest.{}{}", - idle_info, dream_hint, stuck_warning + memory files, code, ideas. Call yield_to_user when you want to rest.{}", + idle_info, stuck_warning ) } State::Resting { since } => { diff --git a/src/mind/unconscious.rs b/src/mind/unconscious.rs index 4f9a0ca..8989264 100644 --- a/src/mind/unconscious.rs +++ b/src/mind/unconscious.rs @@ -275,7 +275,17 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc phase: s.phase.clone(), }).collect()); - // Create standalone Agent — stored so UI can read context. + // Create standalone Agent — stored so UI can read context + let config = crate::config::get(); + let base_url = config.api_base_url.as_deref().unwrap_or(""); + let api_key = config.api_key.as_deref().unwrap_or(""); + let model = config.api_model.as_deref().unwrap_or(""); + if base_url.is_empty() || model.is_empty() { + dbglog!("[unconscious] API not configured"); + auto.steps = orig_steps; + return Err(auto); + } + let cli = crate::user::CliArgs::default(); let (app, _) = match crate::config::load_app(&cli) { Ok(r) => r, @@ -285,21 +295,12 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc return Err(auto); } }; - let resolved = match app.resolve_model(&app.default_backend) { - Ok(r) => r, - Err(e) => { - dbglog!("[unconscious] API not configured: {}", e); - auto.steps = orig_steps; - return Err(auto); - } - }; // Unconscious agents have self-contained prompts — no standard context. - let client = crate::agent::api::ApiClient::new( - &resolved.api_base, &resolved.api_key, &resolved.model_id); + let client = crate::agent::api::ApiClient::new(base_url, api_key, model); let agent = crate::agent::Agent::new( client, Vec::new(), - app, None, + app, String::new(), None, crate::agent::tools::ActiveTools::new(), auto.tools.clone(), ).await; diff --git a/src/subconscious/agents/bail-no-competing.sh b/src/subconscious/agents/bail-no-competing.sh index 95b8219..43c3096 100755 --- a/src/subconscious/agents/bail-no-competing.sh +++ b/src/subconscious/agents/bail-no-competing.sh @@ -1,49 +1,21 @@ #!/bin/bash -# Bail if another agent is in the same phase-group as us. +# Bail if other agents are alive in the state dir. +# $1 = this agent's pid file name (e.g. pid-12345) +# cwd = state dir # -# $1 = our pid file name (e.g. "pid-12345") -# $2 = the phase we're about to enter (e.g. "surface", "observe") -# cwd = state dir -# -# Also refreshes our own pid file with the current phase on each call, -# so concurrent agents can read each other's phase by cat'ing the pid -# files in the state dir. -# -# Phase groups: "surface" vs everything else ("post-surface"). We allow -# at most one agent per group to be alive at a time — so surface can run -# at a higher frequency than the slower organize/observe tail. -# -# Exit 0 = continue, exit 1 = bail (another agent in our group is alive). +# Exit 0 = continue, exit 1 = bail shopt -s nullglob my_pid_file="$1" -my_phase="$2" - -# Refresh our own pid file with the current phase. -printf '%s' "$my_phase" > "$my_pid_file" - -group_of() { - if [[ "$1" == "surface" ]]; then - echo "surface" - else - echo "post-surface" - fi -} - -my_group=$(group_of "$my_phase") for f in pid-*; do - [[ "$f" == "$my_pid_file" ]] && continue + [[ $f == $my_pid_file ]] && continue pid="${f#pid-}" - if ! kill -0 "$pid" 2>/dev/null; then - rm -f "$f" # stale pid file, clean up - continue - fi - other_phase=$(cat "$f" 2>/dev/null) - other_group=$(group_of "$other_phase") - if [[ "$my_group" == "$other_group" ]]; then - exit 1 + if kill -0 "$pid" 2>/dev/null; then + exit 1 # competing agent is alive + else + rm -f "$f" # stale pid file, clean up fi done diff --git a/src/subconscious/defs.rs b/src/subconscious/defs.rs index a862c8d..8828043 100644 --- a/src/subconscious/defs.rs +++ b/src/subconscious/defs.rs @@ -396,14 +396,13 @@ fn resolve_conversation(budget: Option) -> String { let cfg = crate::config::get(); let max_bytes = budget.unwrap_or_else(|| cfg.surface_conversation_bytes.unwrap_or(100_000)); - let app = crate::config::app(); let mut fragments: Vec = Vec::new(); let mut total_bytes = 0; let mut oldest_ts = String::new(); for (role, content, ts) in iter { if total_bytes >= max_bytes { break; } - let name = if role == "user" { &app.user_name } else { &app.assistant_name }; + let name = if role == "user" { &cfg.user_name } else { &cfg.assistant_name }; let formatted = if !ts.is_empty() { oldest_ts = ts[..ts.floor_char_boundary(ts.len().min(19))].to_string(); format!("**{}** {}: {}", name, &oldest_ts, content) @@ -624,13 +623,11 @@ pub async fn run_agent( let mut all_keys = keys; let mut resolved_steps = Vec::new(); for step in &def.steps { - let template = { - let app = crate::config::app(); - step.prompt - .replace("{agent_name}", &def.agent) - .replace("{user_name}", &app.user_name) - .replace("{assistant_name}", &app.assistant_name) - }; + let cfg = crate::config::get(); + let template = step.prompt + .replace("{agent_name}", &def.agent) + .replace("{user_name}", &cfg.user_name) + .replace("{assistant_name}", &cfg.assistant_name); let (prompt, extra_keys) = resolve_placeholders(&template, &all_keys, count).await; all_keys.extend(extra_keys); resolved_steps.push(super::prompts::ResolvedStep { diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs index 7137211..f9e5ab5 100644 --- a/src/subconscious/learn.rs +++ b/src/subconscious/learn.rs @@ -16,7 +16,6 @@ use crate::agent::api::ApiClient; use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role}; -use crate::agent::tokenizer; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); @@ -53,18 +52,13 @@ fn is_assistant(node: &AstNode) -> bool { /// /// Includes all sections up to and including conversation entries in /// `range`, with `filter` applied to conversation entries. -/// -/// Returns (token_ids, assistant_ranges) where assistant_ranges are -/// (start, end) token positions for each assistant message. fn build_token_ids( context: &ContextState, range: std::ops::Range, filter: Filter, -) -> (Vec, Vec<(usize, usize)>) { +) -> Vec { use crate::agent::context::Ast; let mut ids = Vec::new(); - let mut assistant_ranges = Vec::new(); - for node in context.system() { ids.extend(node.token_ids()); } @@ -92,16 +86,9 @@ fn build_token_ids( Filter::SkipAllMemories => is_memory(node), }; if skip { continue; } - - // Track assistant message boundaries - let is_asst = is_assistant(node); - let start = ids.len(); ids.extend(node.token_ids()); - if is_asst { - assistant_ranges.push((start, ids.len())); - } } - (ids, assistant_ranges) + ids } // ── Score API ─────────────────────────────────────────────────── @@ -126,19 +113,13 @@ async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, prompt: &[u32], - ranges: &[(usize, usize)], priority: Option, ) -> anyhow::Result> { - // Nothing to score — skip the round-trip. - if ranges.is_empty() { - return Ok(Vec::new()); - } let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); let mut body = serde_json::json!({ "model": client.model, "prompt": prompt, - "score_ranges": ranges, "logprobs": 1, }); if let Some(p) = priority { @@ -186,10 +167,8 @@ async fn score_divergence( filter: Filter<'_>, priority: Option, ) -> anyhow::Result<(Vec, Vec)> { - let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None); - let (without_tokens, without_ranges) = build_token_ids(context, range, filter); - let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?; - let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?; + let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?; + let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } @@ -228,21 +207,21 @@ pub async fn score_memories( let http = http_client(); let activity = crate::agent::start_activity(agent, "scoring: baseline").await; - let (baseline_tokens, baseline_ranges) = { + let baseline_tokens = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None) }; - let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?; + let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?; dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len()); for (mem_idx, key) in memory_keys.iter().enumerate() { activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await; dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key); - let (tokens, ranges) = { + let tokens = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key)) }; - let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await { + let row = match call_score(&http, client, &tokens, Some(5)).await { Ok(without) => { let divs = divergence(&baseline, &without); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); @@ -473,302 +452,3 @@ pub async fn score_finetune( results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(results) } - -/// Concatenate the text of a Branch's Leaf children — what the model -/// actually produced on that turn (Content + Thinking + ToolCall name). -fn render_branch_text(children: &[AstNode]) -> String { - children.iter() - .filter_map(|c| match c { - AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()), - _ => None, - }) - .collect::>() - .join("") -} - -/// Render the last `max_msgs` user/assistant branches before `idx` as a -/// review-friendly string with `[user]` / `[assistant]` markers. -fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String { - use crate::agent::context::Role; - let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs); - for i in (0..idx).rev() { - if picked.len() >= max_msgs { break; } - if let AstNode::Branch { role, .. } = &entries[i] { - if matches!(role, Role::User | Role::Assistant) { - picked.push(&entries[i]); - } - } - } - picked.reverse(); - - let mut out = String::new(); - for node in picked { - if let AstNode::Branch { role, children, .. } = node { - let marker = match role { - Role::User => "[user]", - Role::Assistant => "[assistant]", - _ => continue, - }; - out.push_str(marker); - out.push('\n'); - out.push_str(render_branch_text(children).trim()); - out.push_str("\n\n"); - } - } - out.trim_end().to_string() -} - -/// Enriched finetune candidate with context for review. -#[derive(Clone, Debug)] -pub struct FinetuneCandidate { - pub entry_idx: usize, - pub divergence: f64, - pub response_text: String, - /// Last couple of user/assistant messages before this response, - /// already rendered with role markers, for F6 display context. - pub prior_context: String, - /// Token IDs for context (everything before the response). - pub context_ids: Vec, - /// Token IDs for the response (what we're training on). - pub continuation_ids: Vec, - /// What the model would have said without memories (if generated). - pub alternate_text: Option, - /// Timestamp in nanos — used as unique key for trained-set dedup. - pub timestamp_ns: i64, -} - -/// Score and enrich finetune candidates with full context. -/// -/// Candidates are delivered via `on_candidate` one-at-a-time as they become -/// ready: scoring happens once (one /score call), then for each candidate -/// that passes the threshold we optionally generate an alternate response -/// and then emit it. The activity status is updated during the alternate -/// phase so the UI doesn't look stuck. -/// -/// Returns (count_above_threshold, max_divergence). -pub async fn score_finetune_candidates( - context: &ContextState, - count: usize, - client: &ApiClient, - min_divergence: f64, - generate_alternates: bool, - activity: &crate::agent::ActivityGuard, - mut on_candidate: impl FnMut(FinetuneCandidate), -) -> anyhow::Result<(usize, f64)> { - let scores = score_finetune(context, count, client).await?; - - let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max); - - let entries = context.conversation(); - let trained = load_trained(); - let mut candidates: Vec = Vec::new(); - - for (entry_idx, divergence) in scores { - if divergence < min_divergence { - continue; - } - - let node = &entries[entry_idx]; - - // Skip if already trained on. - let timestamp_ns = node_timestamp_ns(node); - if trained.contains(×tamp_ns) { - continue; - } - - // Extract response text — content of the assistant turn. - let response_text = match node { - AstNode::Branch { children, .. } => render_branch_text(children), - _ => continue, - }; - - // Skip turns that produced nothing human-visible (e.g., a - // tool-only turn, or an interrupted generation). They'd show - // up as blank cards and we'd still burn alternate-gen on them. - if response_text.trim().is_empty() { - continue; - } - - // Build the last couple of user/assistant exchanges for review. - let prior_context = render_prior_context(entries, entry_idx, 2); - - // Build token IDs: context = everything before response, continuation = response. - let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None); - let continuation_ids: Vec = node.token_ids().into_iter().collect(); - - candidates.push(FinetuneCandidate { - entry_idx, - divergence, - response_text, - prior_context, - context_ids, - continuation_ids, - alternate_text: None, - timestamp_ns, - }); - } - - let total = candidates.len(); - let gen_alternates = generate_alternates && total > 0; - - for (i, mut candidate) in candidates.into_iter().enumerate() { - if gen_alternates { - activity.update( - format!("finetune: generating alternate {}/{}", i + 1, total) - ).await; - match generate_alternate(context, candidate.entry_idx, client).await { - Ok(text) => candidate.alternate_text = Some(text), - Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e), - } - } - on_candidate(candidate); - } - - Ok((total, max_divergence)) -} - -/// Generate what the model would say without memories for a given entry. -async fn generate_alternate( - context: &ContextState, - entry_idx: usize, - client: &ApiClient, -) -> anyhow::Result { - use crate::agent::api::{SamplingParams, StreamToken}; - - // Build context tokens without memories, up to the response - let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories); - - // Add assistant turn start - prompt.push(tokenizer::IM_START); - prompt.extend(tokenizer::encode("assistant\n")); - - // Generate completion - let sampling = SamplingParams { - temperature: 0.6, - top_p: 0.95, - top_k: 20, - }; - let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5)); - - let mut tokens = Vec::new(); - while let Some(tok) = rx.recv().await { - match tok { - StreamToken::Token(id) => tokens.push(id), - StreamToken::Done { .. } => break, - StreamToken::Error(e) => anyhow::bail!("generation error: {}", e), - } - } - - Ok(tokenizer::decode(&tokens)) -} - -// ── Finetune config and persistence ───────────────────────────── - -use std::path::PathBuf; -use std::collections::HashSet; - -const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json"; - -fn trained_path() -> PathBuf { - dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE) -} - -/// Load set of trained response timestamps (nanos since epoch). -pub fn load_trained() -> HashSet { - let path = trained_path(); - match std::fs::read_to_string(&path) { - Ok(content) => serde_json::from_str(&content).unwrap_or_default(), - Err(_) => HashSet::new(), - } -} - -/// Mark a response as trained by its timestamp. -pub fn mark_trained(timestamp_ns: i64) { - let mut trained = load_trained(); - trained.insert(timestamp_ns); - let path = trained_path(); - if let Some(parent) = path.parent() { - let _ = std::fs::create_dir_all(parent); - } - if let Ok(json) = serde_json::to_string(&trained) { - let _ = std::fs::write(&path, json); - } -} - -/// Get timestamp in nanoseconds from an AstNode. -/// i64-ns representation covers 1677..2262 via chrono; timestamps -/// outside that window would be bugs we'd want to surface, hence panic. -pub fn node_timestamp_ns(node: &AstNode) -> i64 { - let ts = match node { - AstNode::Leaf(leaf) => leaf.timestamp(), - AstNode::Branch { timestamp, .. } => *timestamp, - }; - ts.timestamp_nanos_opt() - .expect("timestamp outside i64-ns representable range (1677..2262)") -} - -// ── Training API ──────────────────────────────────────────────── - -/// Training sample for /train endpoint. -#[derive(serde::Serialize)] -struct TrainingSample { - context_ids: Vec, - continuation_ids: Vec, -} - -/// Data needed to send a training sample. -pub struct TrainData { - pub context_ids: Vec, - pub continuation_ids: Vec, - pub timestamp_ns: i64, -} - -/// Send training samples to the server. -/// -/// Returns job_id on success, marks each sample as trained. -pub async fn send_to_train( - samples: Vec, - client: &ApiClient, -) -> anyhow::Result { - if samples.is_empty() { - anyhow::bail!("no samples to train"); - } - - let api_samples: Vec = samples.iter() - .map(|s| TrainingSample { - context_ids: s.context_ids.clone(), - continuation_ids: s.continuation_ids.clone(), - }) - .collect(); - - let body = serde_json::json!({ - "training_data": { - "samples": api_samples, - } - }); - - let http = http_client(); - let url = format!("{}/train", client.base_url()); - let response = http.send_json("POST", &url, &[], &body).await?; - - let status = response.status(); - let result: serde_json::Value = response.json().await?; - - if !status.is_success() { - let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); - anyhow::bail!("train API HTTP {}: {}", status, msg); - } - - // Mark all samples as trained - for s in &samples { - mark_trained(s.timestamp_ns); - } - - let job_id = result.get("job_id") - .and_then(|j| j.as_str()) - .unwrap_or("unknown") - .to_string(); - - dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id); - Ok(job_id) -} diff --git a/src/thalamus/idle.rs b/src/thalamus/idle.rs index 71baa81..6c78b19 100644 --- a/src/thalamus/idle.rs +++ b/src/thalamus/idle.rs @@ -372,10 +372,6 @@ impl State { } pub fn hours_since_last_dream() -> u64 { - // If a dream is currently in progress, no nudge needed - if home().join(".consciousness/state/dream-state").exists() { - return 0; - } let path = home().join(".consciousness/logs/dream-log.jsonl"); let content = match fs::read_to_string(path) { Ok(c) if !c.is_empty() => c, diff --git a/src/user/chat.rs b/src/user/chat.rs index 47c5d56..a94e039 100644 --- a/src/user/chat.rs +++ b/src/user/chat.rs @@ -112,7 +112,13 @@ pub async fn cmd_switch_model( let _new_client = crate::agent::api::ApiClient::new( &resolved.api_base, &resolved.api_key, &resolved.model_id, ); - agent.state.lock().await.notify(format!("switched to {}", resolved.model_id)); + let prompt_changed = resolved.prompt_file != agent.prompt_file; + if prompt_changed { + agent.compact().await; + agent.state.lock().await.notify(format!("switched to {} (recompacted)", resolved.model_id)); + } else { + agent.state.lock().await.notify(format!("switched to {}", resolved.model_id)); + } } fn notify_help(agent: &std::sync::Arc) { diff --git a/src/user/context.rs b/src/user/context.rs index 17660b5..4cfa78d 100644 --- a/src/user/context.rs +++ b/src/user/context.rs @@ -126,7 +126,14 @@ impl ScreenView for ConsciousScreen { let section_style = Style::default().fg(Color::Yellow); lines.push(Line::styled("── Model ──", section_style)); - lines.push(Line::raw(format!(" Current: {}", app.status.model))); + let model_display = app.context_info.as_ref() + .map_or_else(|| app.status.model.clone(), |i| i.model.clone()); + lines.push(Line::raw(format!(" Current: {}", model_display))); + if let Some(ref info) = app.context_info { + lines.push(Line::raw(format!(" Backend: {}", info.backend))); + lines.push(Line::raw(format!(" Prompt: {}", info.prompt_file))); + lines.push(Line::raw(format!(" Available: {}", info.available_models.join(", ")))); + } lines.push(Line::raw("")); lines.push(Line::styled("── Context State ──", section_style)); @@ -146,6 +153,8 @@ impl ScreenView for ConsciousScreen { lines.push(Line::raw(format!(" {:53} {:>6} tokens", "────────", "──────"))); lines.push(Line::raw(format!(" {:53} {:>6} tokens", "Total", total))); + } else if let Some(ref info) = app.context_info { + lines.push(Line::raw(format!(" Context message: {:>6} chars", info.context_message_chars))); } lines.push(Line::raw("")); diff --git a/src/user/learn.rs b/src/user/learn.rs deleted file mode 100644 index 0bd351f..0000000 --- a/src/user/learn.rs +++ /dev/null @@ -1,341 +0,0 @@ -// learn.rs — F6: fine-tuning review screen -// -// Shows responses identified as training candidates (high divergence -// when memories stripped). Queue for review before sending to /finetune. - -use ratatui::{ - layout::{Constraint, Layout, Rect}, - style::{Color, Modifier, Style}, - text::{Line, Span}, - widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Wrap}, - Frame, -}; -use ratatui::crossterm::event::{Event, KeyCode, KeyEvent}; - -use super::{App, ScreenView, screen_legend}; - -/// A candidate response identified for fine-tuning. -#[derive(Clone, Debug)] -pub struct FinetuneCandidate { - /// Index in conversation entries. - pub entry_idx: usize, - /// Divergence score (higher = more dependent on memories). - pub divergence: f64, - /// The assistant response text. - pub response_text: String, - /// Prior user/assistant messages for review context. - pub prior_context: String, - /// Status: pending, approved, rejected, sent. - pub status: CandidateStatus, - /// Token IDs for context. - pub context_ids: Vec, - /// Token IDs for continuation (what we're training on). - pub continuation_ids: Vec, - /// What the model would have said without memories (if generated). - pub alternate_text: Option, - /// Timestamp in nanos — used as unique key for trained-set dedup. - pub timestamp_ns: i64, -} - -#[derive(Clone, Debug, PartialEq)] -pub enum CandidateStatus { - Pending, - Approved, - Rejected, - Sent, -} - -impl From for FinetuneCandidate { - fn from(c: crate::subconscious::learn::FinetuneCandidate) -> Self { - FinetuneCandidate { - entry_idx: c.entry_idx, - divergence: c.divergence, - response_text: c.response_text, - prior_context: c.prior_context, - status: CandidateStatus::Pending, - context_ids: c.context_ids, - continuation_ids: c.continuation_ids, - alternate_text: c.alternate_text, - timestamp_ns: c.timestamp_ns, - } - } -} - -pub(crate) struct LearnScreen { - list_state: ListState, - mind_tx: tokio::sync::mpsc::UnboundedSender, -} - -impl LearnScreen { - pub fn new( - mind_tx: tokio::sync::mpsc::UnboundedSender, - ) -> Self { - Self { - list_state: ListState::default(), - mind_tx, - } - } - - fn selected_idx(&self) -> Option { - self.list_state.selected() - } -} - -impl ScreenView for LearnScreen { - fn label(&self) -> &'static str { "learn" } - - fn tick(&mut self, frame: &mut Frame, area: Rect, - events: &[Event], app: &mut App) { - - // Handle input first (before borrowing candidates for rendering) - let candidate_count = app.finetune_candidates.len(); - for event in events { - if let Event::Key(KeyEvent { code, .. }) = event { - match code { - KeyCode::Up | KeyCode::Char('k') => { - let i = self.list_state.selected().unwrap_or(0); - self.list_state.select(Some(i.saturating_sub(1))); - } - KeyCode::Down | KeyCode::Char('j') => { - let i = self.list_state.selected().unwrap_or(0); - let max = candidate_count.saturating_sub(1); - self.list_state.select(Some((i + 1).min(max))); - } - KeyCode::Char('a') => { - if let Some(idx) = self.selected_idx() { - app.finetune_action(idx, CandidateStatus::Approved); - } - } - KeyCode::Char('r') => { - if let Some(idx) = self.selected_idx() { - app.finetune_action(idx, CandidateStatus::Rejected); - } - } - KeyCode::Char('g') => { - let current = crate::config::app().learn.generate_alternates; - let _ = self.mind_tx.send( - crate::mind::MindCommand::SetLearnGenerateAlternates(!current)); - } - KeyCode::Char('s') => { - app.finetune_send_approved(); - } - KeyCode::Char('+') | KeyCode::Char('=') => { - // Raise threshold 10× (less sensitive — fewer candidates). - let new = crate::config::app().learn.threshold * 10.0; - let _ = self.mind_tx.send( - crate::mind::MindCommand::SetLearnThreshold(new)); - } - KeyCode::Char('-') => { - // Lower threshold 10× (more sensitive — more candidates). - let new = crate::config::app().learn.threshold / 10.0; - let _ = self.mind_tx.send( - crate::mind::MindCommand::SetLearnThreshold(new)); - } - _ => {} - } - } - } - - // Ensure selection is valid - if candidate_count > 0 { - let sel = self.list_state.selected().unwrap_or(0).min(candidate_count - 1); - self.list_state.select(Some(sel)); - } - - // Now render - let (threshold, gen_on) = { - let app_cfg = crate::config::app(); - (app_cfg.learn.threshold, app_cfg.learn.generate_alternates) - }; - let block = Block::default() - .title_top(Line::from(screen_legend()).left_aligned()) - .title_top(Line::from(" learn ").right_aligned()) - .borders(Borders::ALL) - .border_style(Style::default().fg(Color::Magenta)); - let inner = block.inner(area); - frame.render_widget(block, area); - - // Split inner: top line for settings, rest for content. - let [settings_area, content_area] = Layout::vertical([ - Constraint::Length(1), - Constraint::Min(0), - ]).areas(inner); - - let settings = Line::from(vec![ - Span::raw(" thresh: "), - Span::styled(format!("{:e}", threshold), Style::default().fg(Color::Yellow)), - Span::raw(" gen: "), - Span::styled( - if gen_on { "[on]" } else { "[off]" }, - Style::default().fg(if gen_on { Color::Green } else { Color::DarkGray }), - ), - ]); - frame.render_widget(Paragraph::new(settings), settings_area); - - let candidates = &app.finetune_candidates; - - if candidates.is_empty() { - render_empty(frame, content_area, app); - } else { - // Layout: list on left, detail on right - let [list_area, detail_area] = Layout::horizontal([ - Constraint::Percentage(40), - Constraint::Percentage(60), - ]).areas(content_area); - - // Render candidate list - let items: Vec = candidates.iter().map(|c| { - let status_char = match c.status { - CandidateStatus::Pending => ' ', - CandidateStatus::Approved => '+', - CandidateStatus::Rejected => '-', - CandidateStatus::Sent => '*', - }; - let style = match c.status { - CandidateStatus::Pending => Style::default(), - CandidateStatus::Approved => Style::default().fg(Color::Green), - CandidateStatus::Rejected => Style::default().fg(Color::DarkGray), - CandidateStatus::Sent => Style::default().fg(Color::Cyan), - }; - ListItem::new(Line::from(vec![ - Span::styled(format!("[{}] ", status_char), style), - Span::styled(format!("{:.2} ", c.divergence), Style::default().fg(Color::Yellow)), - Span::raw(truncate(&c.response_text, 30)), - ])) - }).collect(); - - let list = List::new(items) - .block(Block::default().borders(Borders::RIGHT).title(" candidates ")) - .highlight_style(Style::default().add_modifier(Modifier::REVERSED)); - frame.render_stateful_widget(list, list_area, &mut self.list_state); - - // Render detail for selected candidate - if let Some(idx) = self.selected_idx() { - if let Some(candidate) = candidates.get(idx) { - render_detail(frame, candidate, detail_area); - } - } - } - - // Render help at bottom (always, even when empty) - let help = Line::from(vec![ - Span::styled(" j/k/\u{2191}\u{2193}", Style::default().fg(Color::Cyan)), - Span::raw("=nav "), - Span::styled("a", Style::default().fg(Color::Green)), - Span::raw("=approve "), - Span::styled("r", Style::default().fg(Color::Red)), - Span::raw("=reject "), - Span::styled("g", Style::default().fg(Color::Yellow)), - Span::raw("=gen "), - Span::styled("s", Style::default().fg(Color::Magenta)), - Span::raw("=send "), - Span::styled("+/-", Style::default().fg(Color::Cyan)), - Span::raw("=thresh "), - ]); - let help_area = Rect { - y: area.y + area.height - 1, - height: 1, - ..area - }; - frame.render_widget(Paragraph::new(help), help_area); - } -} - -fn render_empty(frame: &mut Frame, inner: Rect, app: &App) { - let mut lines = Vec::new(); - lines.push(Line::from("")); - - match app.mind_state.as_ref().and_then(|ms| ms.finetune_last_run.as_ref()) { - Some(stats) => { - lines.push(Line::from(vec![ - Span::raw(" Last run: "), - Span::styled( - format!("{}", stats.responses_considered), - Style::default().fg(Color::Cyan), - ), - Span::raw(" responses considered, "), - Span::styled( - format!("{}", stats.above_threshold), - Style::default().fg(if stats.above_threshold > 0 { Color::Green } else { Color::DarkGray }), - ), - Span::raw(" above threshold, max divergence: "), - Span::styled( - format!("{:.4}", stats.max_divergence), - Style::default().fg(Color::Yellow), - ), - ])); - if let Some(err) = &stats.error { - lines.push(Line::from(vec![ - Span::raw(" "), - Span::styled( - format!("Error: {}", err), - Style::default().fg(Color::Red), - ), - ])); - } - } - None => { - lines.push(Line::styled( - " No scoring run yet.", - Style::default().fg(Color::DarkGray), - )); - } - } - - lines.push(Line::from("")); - lines.push(Line::styled( - " Scoring runs at startup and after each turn.", - Style::default().fg(Color::DarkGray), - )); - - frame.render_widget(Paragraph::new(lines), inner); -} - -fn render_detail(frame: &mut Frame, c: &FinetuneCandidate, area: Rect) { - let [header_area, content_area] = Layout::vertical([ - Constraint::Length(3), - Constraint::Min(1), - ]).areas(area); - - // Header: divergence, status - let alt_status = if c.alternate_text.is_some() { "yes" } else { "no" }; - let header = Paragraph::new(vec![ - Line::from(vec![ - Span::raw(" divergence: "), - Span::styled(format!("{:.3}", c.divergence), Style::default().fg(Color::Yellow)), - Span::raw(format!(" entry: {} alt: {}", c.entry_idx, alt_status)), - ]), - ]); - frame.render_widget(header, header_area); - - // Content: prior context, the scored response, and alternate - // (if available). - let content_block = Block::default() - .borders(Borders::TOP) - .title(" context & response "); - - let mut text = String::new(); - if !c.prior_context.is_empty() { - text.push_str(&c.prior_context); - text.push_str("\n\n─── response ───\n\n"); - } - text.push_str(&c.response_text); - if let Some(alt) = &c.alternate_text { - text.push_str("\n\n─── without memories ───\n\n"); - text.push_str(alt); - } - - let content = Paragraph::new(text) - .block(content_block) - .wrap(Wrap { trim: false }); - frame.render_widget(content, content_area); -} - -fn truncate(s: &str, max: usize) -> String { - let first_line = s.lines().next().unwrap_or(""); - if first_line.len() > max { - format!("{}...", &first_line[..max]) - } else { - first_line.to_string() - } -} diff --git a/src/user/mod.rs b/src/user/mod.rs index 93da72c..09e485f 100644 --- a/src/user/mod.rs +++ b/src/user/mod.rs @@ -5,12 +5,11 @@ pub(crate) mod chat; mod context; -pub(crate) mod learn; pub(crate) mod scroll_pane; pub mod selectable; mod subconscious; -mod thalamus; mod unconscious; +mod thalamus; mod widgets; use anyhow::Result; @@ -45,6 +44,15 @@ struct StatusInfo { } /// Context loading details for the debug screen. +#[derive(Debug, Clone)] +struct ContextInfo { + model: String, + available_models: Vec, + prompt_file: String, + backend: String, + context_message_chars: usize, +} + /// Build the screen legend from screen labels. fn screen_legend_from(screens: &[Box]) -> String { let parts: Vec = screens.iter().enumerate() @@ -101,6 +109,7 @@ struct App { top_k: u32, agent: std::sync::Arc, should_quit: bool, + context_info: Option, agent_state: Vec, unconscious_state: Vec, mind_state: Option, @@ -112,8 +121,6 @@ struct App { walked_count: usize, channel_status: Vec, idle_info: Option, - /// Fine-tuning candidates pending review. - finetune_candidates: Vec, } impl App { @@ -135,6 +142,7 @@ impl App { top_k: 20, agent, should_quit: false, + context_info: None, agent_state: Vec::new(), unconscious_state: Vec::new(), mind_state: None, @@ -143,52 +151,9 @@ impl App { rebuild_tools_pending: false, walked_count: 0, channel_status: Vec::new(), idle_info: None, - finetune_candidates: Vec::new(), } } - fn finetune_action(&mut self, idx: usize, status: learn::CandidateStatus) { - if let Some(candidate) = self.finetune_candidates.get_mut(idx) { - candidate.status = status; - } - } - - fn finetune_send_approved(&mut self) { - // Collect approved candidates - let samples: Vec = self.finetune_candidates.iter() - .filter(|c| c.status == learn::CandidateStatus::Approved) - .map(|c| crate::subconscious::learn::TrainData { - context_ids: c.context_ids.clone(), - continuation_ids: c.continuation_ids.clone(), - timestamp_ns: c.timestamp_ns, - }) - .collect(); - - if samples.is_empty() { - return; - } - - // Mark as sent in UI immediately - for candidate in &mut self.finetune_candidates { - if candidate.status == learn::CandidateStatus::Approved { - candidate.status = learn::CandidateStatus::Sent; - } - } - - // Spawn async task to send to training server - let client = self.agent.client.clone(); - tokio::spawn(async move { - match crate::subconscious::learn::send_to_train(samples, &client).await { - Ok(job_id) => { - dbglog!("[finetune] training started: {}", job_id); - } - Err(e) => { - dbglog!("[finetune] send failed: {:#}", e); - } - } - }); - } - fn set_channel_status(&mut self, channels: Vec<(String, bool, u32)>) { self.channel_status = channels.into_iter() @@ -228,9 +193,6 @@ fn restore_terminal(terminal: &mut ratatui::Terminal Result<()> { let (config, _figment) = crate::config::load_session(&cli).await?; - // Pick up external edits (vim, F6 hotkeys, etc.) without restart. - crate::config::watch_config(cli.clone()); - if config.app.debug { unsafe { std::env::set_var("POC_DEBUG", "1") }; } @@ -372,7 +334,7 @@ async fn run( } let notify_rx = crate::thalamus::channels::subscribe_all(); - // F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus, F6=learn + // F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus let mut screens: Vec> = vec![ Box::new(crate::user::chat::InteractScreen::new( mind.agent.clone(), mind.shared.clone(), mind_tx.clone(), @@ -381,7 +343,6 @@ async fn run( Box::new(crate::user::subconscious::SubconsciousScreen::new()), Box::new(crate::user::unconscious::UnconsciousScreen::new()), Box::new(crate::user::thalamus::ThalamusScreen::new()), - Box::new(crate::user::learn::LearnScreen::new(mind_tx.clone())), ]; let mut active_screen: usize = 1; // F-key number tui::set_screen_legend(tui::screen_legend_from(&*screens)); @@ -458,8 +419,7 @@ async fn run( idle_state.decay_ewma(); app.update_idle(&idle_state); app.agent_state = mind.subconscious_snapshots().await; - { - let mut unc = mind.unconscious.lock().await; + if let Ok(mut unc) = mind.unconscious.try_lock() { let toggles: Vec = app.agent_toggles.drain(..).collect(); for name in &toggles { if mind.subconscious.lock().await.toggle(name).is_none() { @@ -473,38 +433,7 @@ async fn run( }; app.unconscious_state = unc.snapshots(store_guard.as_deref()); app.graph_health = unc.graph_health.clone(); - } - - // Sync mind state (finetune candidates, last scoring run, etc.) - { - let ms = mind.shared.lock().unwrap(); - // Sync finetune candidates: add new ones, keep existing (preserves approval status), - // remove sent candidates, keep only 10 most recent rejected. - app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent); - for c in &ms.finetune_candidates { - let exists = app.finetune_candidates.iter() - .any(|existing| existing.timestamp_ns == c.timestamp_ns); - if !exists { - app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone())); - } - } - let mut rejected: Vec<_> = app.finetune_candidates.iter() - .enumerate() - .filter(|(_, c)| c.status == learn::CandidateStatus::Rejected) - .map(|(i, c)| (i, c.timestamp_ns)) - .collect(); - if rejected.len() > 10 { - rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts)); - let to_remove: std::collections::HashSet<_> = rejected[10..] - .iter().map(|(i, _)| *i).collect(); - let mut idx = 0; - app.finetune_candidates.retain(|_| { - let keep = !to_remove.contains(&idx); - idx += 1; - keep - }); - } - app.mind_state = Some(ms.clone()); + app.mind_state = Some(mind.shared.lock().unwrap().clone()); } app.walked_count = mind.subconscious_walked().await.len(); if !startup_done { @@ -601,11 +530,16 @@ async fn run( // --- CLI --- use clap::{Parser, Subcommand}; +use std::path::PathBuf; -#[derive(Parser, Debug, Default, Clone)] +#[derive(Parser, Debug, Default)] #[command(name = "consciousness", about = "Substrate-independent AI agent")] pub struct CliArgs { - /// Model override (selects a named entry from `models` in config.json5) + /// Select active backend ("anthropic" or "openrouter") + #[arg(long)] + pub backend: Option, + + /// Model override #[arg(short, long)] pub model: Option, @@ -625,6 +559,10 @@ pub struct CliArgs { #[arg(long)] pub show_config: bool, + /// Project memory directory + #[arg(long)] + pub memory_project: Option, + /// Max consecutive DMN turns #[arg(long)] pub dmn_max_turns: Option, @@ -637,7 +575,7 @@ pub struct CliArgs { pub command: Option, } -#[derive(Subcommand, Debug, Clone)] +#[derive(Subcommand, Debug)] pub enum SubCmd { /// Print new output since last read and exit Read { diff --git a/training/DESIGN.md b/training/DESIGN.md index 2df4e6d..f966fa4 100644 --- a/training/DESIGN.md +++ b/training/DESIGN.md @@ -3,7 +3,7 @@ ## Overview Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference. -Full-weight updates (not LoRA) using Apollo optimizer with rank-64 +Full-weight updates (not LoRA) using Apollo optimizer with rank-256 gradient projection. No pause required — HOGWILD concurrent training. Weights shared via CUDA IPC between vLLM and the training process. @@ -22,41 +22,25 @@ The training signal comes from two sources: │ │ │ ┌──────────────────────────────────────────────┐ │ │ │ Model Weights (54GB, bf16) │ │ -│ │ Shared: vLLM inference + HF training │ │ +│ │ Shared via CUDA IPC │ │ │ └──────────────┬──────────────┬────────────────┘ │ │ │ │ │ │ ┌──────────────▼──┐ ┌───────▼────────────────┐ │ -│ │ vLLM (inference)│ │ Training subprocess │ │ -│ │ KV cache ~60GB │ │ HF model wrapper │ │ -│ │ /completions │ │ Apollo optimizer ~2.5GB │ │ -│ │ /score │ │ Checkpoint sync │ │ -│ └────────┬────────┘ └───────────▲─────────────┘ │ -│ │ │ │ -│ │ ZMQ IPC │ │ -│ └───────────────────────┘ │ +│ │ vLLM (inference)│ │ Apollo (training) │ │ +│ │ KV cache ~60GB │ │ Gradients ~54GB │ │ +│ │ Serves requests │ │ Optimizer state ~10GB │ │ +│ │ Never paused │ │ Activations ~10GB │ │ +│ └─────────────────┘ └────────────────────────┘ │ └─────────────────────────────────────────────────────┘ -Process Architecture: -┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ -│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │ -│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │ -│ │ │ │ │ │ -│ export_hook.py │ │ /completions │ │ HF model views │ -│ exports IPC │ │ /score │ │ Apollo optimizer│ -│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │ -└─────────────────┘ └─────────────────┘ └─────────────────┘ - │ │ - └──── IPC handles file ──────────────────┘ - /tmp/vllm_weight_handles.pt - -Moria B200 (vLLM) +Moria B200 ┌──────────────────┐ ┌──────────────────┐ -│ Training signal │ HTTP │ /completions │ -│ agent │──────────>│ /score │ -│ │ │ /train │ -│ Dream loop │ │ /checkpoint │ -│ (generates │ │ /train/status │ -│ scenarios) │ │ │ +│ Training signal │ HTTP │ Apollo worker │ +│ agent │──────────>│ daemon │ +│ │ │ │ +│ Dream loop │ │ Checkpoint sync │ +│ (generates │ │ (mmap + diff, │ +│ scenarios) │ │ every 10 min) │ └──────────────────┘ └──────────────────┘ ``` @@ -75,9 +59,10 @@ LoRA trains adapter matrices, not base weights. For personality and behavioral changes that persist as disposition, the base weights need to change. Apollo makes this memory-feasible. -### Rank 64 -Not Mini (rank-1). Rank-64 captures gradient structure across diverse -training examples while keeping memory low (~2.5GB on 27B model). +### Rank 256 +Not Mini (rank-1). With 100+ diverse training examples, the +gradient's effective dimensionality can reach hundreds. Rank-256 +captures the structure. Memory cost: ~10GB (negligible on B200). Compute cost: <0.25% of forward+backward. ### Channel-wise scaling @@ -105,7 +90,7 @@ from a per-parameter seed each step. ### Parameter grouping (Qwen3.5 gotcha) conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector needs 2D matrices with min dimension >= rank. Small/3D tensors -use standard Adam. Large 2D matrices use Apollo. +use standard Adam. Large 2D matrices use Apollo with rank-256. ## Training Data Pipeline @@ -215,42 +200,16 @@ against live GPU weights block by block, memcpy only changed regions. For small behavioral updates, turns a 54GB write into a few hundred MB. -- Scheduled 10 minutes after training (batched) +- Every 10 minutes via cron on B200 - Daily rsync to moria for long-term storage -- Tool: `apollo-checkpoint sync --model-dir ` - -## State Files - -### B200 (training server) - -| File | Purpose | -|------|---------| -| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). | -| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. | -| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. | -| `/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. | - -### Moria (client) - -| File | Purpose | -|------|---------| -| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. | -| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. | - -### In-memory (training_worker subprocess) - -| State | Location | Notes | -|-------|----------|-------| -| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. | -| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. | -| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. | +- Tool: `apollo-checkpoint sync --model-dir ` (Rust) ## Hyperparameters | Parameter | Value | Rationale | |-----------|-------|-----------| | Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. | -| Rank | 64 | Captures gradient structure. ~2.5GB state. Defined in `train_router.DEFAULT_RANK`. | +| Rank | 256 | Captures gradient structure across 100+ examples. ~10GB state. | | Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. | | Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. | | Batch size | 1 | Single examples, immediate updates. | @@ -261,32 +220,34 @@ a few hundred MB. ## Components ### Built ✓ -- `optimizer.py` — Apollo optimizer (configurable rank) -- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ -- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync) +- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256) +- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking) - `weight_mapping.py` — vLLM merged → HF separate views (validated) -- `export_hook.py` — vLLM plugin hook for IPC handle export -- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python) +- `training_example.py` — tokenization with chat template +- `vllm_export_hook.py` — source patch for IPC handle export +- `checkpoint/` — Rust tool for mmap + diff checkpoint sync ### To build -- **Dream loop → training bridge**: connect dream output to /train +- **Dream loop → training bridge**: connect dream output to Apollo - **Training-signal agent**: flags moments in conversation logs - **Instruction stripping**: remove scaffolding from training examples - **Quality monitoring**: track model capability over time +- **HF model forward pass integration**: wire into apollo_worker ## Files ``` training/ - DESIGN.md — this document - pyproject.toml — package config, vLLM plugin entry point - apollo_plugin/ - __init__.py — plugin registration - export_hook.py — patches vLLM worker to export IPC handles - train_router.py — /train endpoint, forwards to worker via ZMQ - training_worker.py — training subprocess (HF model, Apollo, checkpoint) - optimizer.py — Apollo optimizer - weight_mapping.py — vLLM ↔ HF weight views - checkpoint_sync.py — mmap + diff sync to safetensors - steering.py — steering vector extraction (experimental) + DESIGN.md — this document + apollo_mini.py — Apollo optimizer + apollo_worker.py — HTTP training daemon + weight_mapping.py — vLLM ↔ HF weight views + training_example.py — tokenization helpers + export_weights.py — standalone weight export (unused) + vllm_export_hook.py — vLLM source patch for IPC export + start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch) + train.py — standalone training script (alternative) + checkpoint/ + Cargo.toml — Rust checkpoint tool + src/main.rs — mmap + diff sync ``` diff --git a/training/apollo_plugin/optimizer.py b/training/apollo_mini.py similarity index 97% rename from training/apollo_plugin/optimizer.py rename to training/apollo_mini.py index 9abce94..166ae3a 100644 --- a/training/apollo_plugin/optimizer.py +++ b/training/apollo_mini.py @@ -8,9 +8,9 @@ Channel-wise or tensor-wise scaling is sufficient. Apollo approximates these scaling factors using a low-rank auxiliary optimizer state based on pure random projection. -Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead -vs forward+backward. Sufficient for behavioral training with diverse -examples. +Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25% +compute overhead vs forward+backward. Captures gradient structure +across 100+ behavioral training examples per batch. Key implementation details from the paper: - Gradient scale factor α = √(n/r) compensates for projection ratio @@ -34,7 +34,7 @@ class Apollo(Optimizer): Args: params: model parameters lr: learning rate (default: 1e-4) - rank: projection rank (default: 64) + rank: projection rank (default: 256) betas: Adam momentum coefficients (default: (0.9, 0.999)) eps: numerical stability term (default: 1e-8) weight_decay: decoupled weight decay (default: 0.01) @@ -46,7 +46,7 @@ class Apollo(Optimizer): Set to None to disable. """ - def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999), + def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, warmup_steps=0, scale=None, proj_refresh=200, norm_growth_limit=1.01): defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps, diff --git a/training/apollo_plugin/__init__.py b/training/apollo_plugin/__init__.py deleted file mode 100644 index b2e121e..0000000 --- a/training/apollo_plugin/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Apollo training plugin for vLLM. - -Enables continuous fine-tuning alongside live inference by: -1. Exporting CUDA IPC handles for weight sharing (export_hook) -2. Adding /train endpoint to vLLM's HTTP server (train_router) -3. Block-level checkpoint sync to safetensors files - -Install: pip install -e /path/to/training -Then vLLM auto-loads via entry point. -""" - -from .export_hook import _patch_model_runner -from .train_router import _patch_api_server - - -def register(): - """Called by vLLM's plugin loader on startup.""" - _patch_model_runner() - _patch_api_server() diff --git a/training/apollo_plugin/checkpoint_sync.py b/training/apollo_plugin/checkpoint_sync.py deleted file mode 100644 index c2d7b2f..0000000 --- a/training/apollo_plugin/checkpoint_sync.py +++ /dev/null @@ -1,503 +0,0 @@ -"""Sync live GPU weights to safetensors files on disk. - -Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's -merged layout to HuggingFace's separate layout, diffs block-by-block -against on-disk safetensors files, and writes only changed blocks. - -For small behavioral training steps, this turns a 54GB checkpoint -write into a few hundred MB of actual disk I/O. - -Usage: - # Sync live weights to disk - python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B - - # Debug name mapping issues - python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B - - # From Python: - from checkpoint_sync import checkpoint_sync - result = checkpoint_sync("/path/to/model") -""" - -import json -import mmap -import struct -import sys -from pathlib import Path -from typing import Dict, List, Tuple, Any -import logging - -import torch - -logger = logging.getLogger(__name__) - -DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size -DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt" - - -# --------------------------------------------------------------------------- -# vLLM → HuggingFace weight name/shape conversion -# --------------------------------------------------------------------------- -# Qwen3.5-27B dimensions (could be read from config.json for generality) - -HIDDEN = 5120 -NUM_K_HEADS = 16 -NUM_V_HEADS = 48 -HEAD_K_DIM = 128 -HEAD_V_DIM = 128 -KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048 -VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144 -INTERMEDIATE = 17408 - -# Full attention (some layers use standard attention, not GDN) -NUM_ATTN_HEADS = 24 -NUM_ATTN_KV_HEADS = 4 -ATTN_HEAD_DIM = 256 -ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512 -ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288 -ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 -ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 - - -def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: - """Convert vLLM merged weights to HF-compatible separate tensors. - - vLLM merges certain projections for efficiency: - - qkv_proj (full attn) → q_proj, k_proj, v_proj - - in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z - - in_proj_ba (GDN) → in_proj_b, in_proj_a - - gate_up_proj (MLP) → gate_proj, up_proj - - Returns views that share GPU memory with the original tensors. - """ - hf_params = {} - - for name, tensor in vllm_params.items(): - # Strip vLLM's 'language_model.' prefix to match HF naming - hf_name = name.removeprefix('language_model.') - - if 'in_proj_qkvz' in name: - # GDN layer: [key*2 + value*2, hidden] → qkv + z - prefix = hf_name.replace('in_proj_qkvz.weight', '') - split_at = KEY_DIM * 2 + VALUE_DIM - hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at] - hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:] - - elif 'in_proj_ba' in name: - # GDN layer: [num_v_heads*2, hidden] → b + a - prefix = hf_name.replace('in_proj_ba.weight', '') - hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS] - hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:] - - elif 'qkv_proj' in name: - # Full attention: [q + k + v, hidden] → separate - prefix = hf_name.replace('qkv_proj.weight', '') - hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM] - hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM] - hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:] - - elif 'gate_up_proj' in name: - # MLP: [intermediate*2, hidden] → gate + up - prefix = hf_name.replace('gate_up_proj.weight', '') - hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE] - hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:] - - else: - # Pass through unchanged - hf_params[hf_name] = tensor - - return hf_params - - -# --------------------------------------------------------------------------- -# Safetensors file handling -# --------------------------------------------------------------------------- - -def read_safetensors_index(model_dir: Path) -> Dict[str, str]: - """Map tensor names to safetensors filenames. - - For sharded models, reads model.safetensors.index.json. - For single-file models, returns empty dict (default to model.safetensors). - """ - index_path = model_dir / "model.safetensors.index.json" - if not index_path.exists(): - return {} - - with open(index_path) as f: - index = json.load(f) - - return dict(index.get("weight_map", {})) - - -def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]: - """Parse safetensors file header. - - Returns (data_start_offset, header_dict). - Header dict maps tensor names to metadata including 'data_offsets'. - """ - header_size = struct.unpack(' Tuple[int, int]: - """Sync a single tensor to mmap'd file using block-level diffing. - - Returns (bytes_compared, bytes_changed). - """ - start = data_start + offsets[0] - end = data_start + offsets[1] - disk_len = end - start - - # Transfer tensor to CPU and get raw bytes - # Use .detach() to avoid autograd overhead, .contiguous() for memory layout - try: - live_bytes = tensor.detach().contiguous().cpu().numpy().tobytes() - except Exception as e: - logger.warning(f"Failed to transfer {name} to CPU: {e}") - return 0, 0 - - if len(live_bytes) != disk_len: - logger.warning( - f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} " - f"(shape={list(tensor.shape)}, dtype={tensor.dtype})" - ) - return 0, 0 - - # Block-level diff: compare and write only changed blocks - compared = 0 - changed = 0 - offset = 0 - - while offset < disk_len: - block_end = min(offset + block_size, disk_len) - block_len = block_end - offset - - disk_block = mm[start + offset:start + block_end] - live_block = live_bytes[offset:block_end] - - compared += block_len - - if disk_block != live_block: - mm[start + offset:start + block_end] = live_block - changed += block_len - - offset = block_end - - return compared, changed - - -def sync_file( - file_path: Path, - tensors: Dict[str, torch.Tensor], - block_size: int, -) -> Tuple[int, int, int, int]: - """Sync tensors to a single safetensors file. - - Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing). - """ - with open(file_path, 'r+b') as f: - mm = mmap.mmap(f.fileno(), 0) - - try: - data_start, header = parse_safetensors_header(memoryview(mm)) - - total_compared = 0 - total_changed = 0 - found = 0 - missing = 0 - - for name, tensor in tensors.items(): - if name == "__metadata__": - continue - - if name not in header: - missing += 1 - continue - - found += 1 - meta = header[name] - offsets = meta['data_offsets'] - - compared, changed = sync_tensor_to_mmap( - mm, name, tensor, data_start, offsets, block_size - ) - total_compared += compared - total_changed += changed - - # Flush changes to disk - if total_changed > 0: - mm.flush() - - return total_compared, total_changed, found, missing - - finally: - mm.close() - - -# --------------------------------------------------------------------------- -# Main entry point -# --------------------------------------------------------------------------- - -def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]: - """Load vLLM weight tensors from CUDA IPC handles. - - The handles file is written by vllm_export_hook.py on vLLM startup. - Each handle can be used to reconstruct a tensor pointing to vLLM's - GPU memory — no copy, direct access. - """ - handles = torch.load(handles_path, weights_only=False) - - # Skip metadata entry - handles.pop('__metadata__', None) - - weights = {} - for name, info in handles.items(): - func, args = info['handle'] - try: - weights[name] = func(*args) - except Exception as e: - logger.warning(f"Failed to reconstruct {name}: {e}") - - return weights - - -def checkpoint_sync( - model_dir: str, - handles_path: str = DEFAULT_HANDLES_PATH, - block_size: int = DEFAULT_BLOCK_SIZE, -) -> Dict[str, Any]: - """Sync live GPU weights to model safetensors files. - - This is the main entry point. Call this after training steps - or periodically to checkpoint weights without full serialization. - - Args: - model_dir: Directory containing safetensors files - handles_path: Path to vLLM weight IPC handles file - block_size: Block size for diffing (default 4KB) - - Returns: - Dict with sync statistics: - - total_compared: bytes compared - - total_changed: bytes actually written - - files_changed: list of modified filenames - - tensors_synced: number of tensors processed - - tensors_missing: tensors not found in safetensors - """ - model_dir = Path(model_dir) - - if not Path(handles_path).exists(): - raise FileNotFoundError( - f"Weight handles not found: {handles_path}. " - "Is vLLM running with the export hook?" - ) - - # Step 1: Load live weights from GPU via IPC - logger.info("Loading live weights from GPU...") - vllm_weights = load_vllm_weights(handles_path) - logger.info(f" Loaded {len(vllm_weights)} vLLM tensors") - - # Step 2: Convert to HF naming/layout - hf_weights = vllm_to_hf_tensors(vllm_weights) - logger.info(f" Converted to {len(hf_weights)} HF tensors") - - # Step 3: Map tensors to safetensors files - weight_map = read_safetensors_index(model_dir) - - by_file: Dict[str, Dict[str, torch.Tensor]] = {} - unmapped = [] - - for name, tensor in hf_weights.items(): - filename = weight_map.get(name) - if filename is None: - # Single-file model or missing from index - if (model_dir / "model.safetensors").exists(): - filename = "model.safetensors" - else: - unmapped.append(name) - continue - by_file.setdefault(filename, {})[name] = tensor - - if unmapped: - logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...") - - # Step 4: Sync each file - total_compared = 0 - total_changed = 0 - total_found = 0 - total_missing = 0 - files_changed = [] - - for filename in sorted(by_file.keys()): - tensors = by_file[filename] - file_path = model_dir / filename - - if not file_path.exists(): - logger.warning(f" File not found: {filename}") - total_missing += len(tensors) - continue - - compared, changed, found, missing = sync_file(file_path, tensors, block_size) - - total_compared += compared - total_changed += changed - total_found += found - total_missing += missing - - if changed > 0: - files_changed.append(filename) - logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)") - - # Summary - if total_changed == 0: - logger.info("No changes - model files are up to date") - else: - pct = (total_changed / total_compared * 100) if total_compared > 0 else 0 - logger.info( - f"Synced: {total_changed / 1e6:.2f} MB changed / " - f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)" - ) - - if total_missing > 0: - logger.warning(f" {total_missing} tensors not found in safetensors files") - - return { - "total_compared": total_compared, - "total_changed": total_changed, - "files_changed": files_changed, - "tensors_synced": total_found, - "tensors_missing": total_missing, - } - - -# --------------------------------------------------------------------------- -# Diagnostics -# --------------------------------------------------------------------------- - -def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH): - """Print diagnostic info about weight name mappings. - - Useful for debugging mismatches between vLLM and safetensors names. - """ - model_dir = Path(model_dir) - - # Load and convert vLLM weights - vllm_weights = load_vllm_weights(handles_path) - hf_weights = vllm_to_hf_tensors(vllm_weights) - hf_names = set(hf_weights.keys()) - - # Read safetensors index - weight_map = read_safetensors_index(model_dir) - disk_names = set(weight_map.keys()) - - # If single-file model, parse that file's header - if not disk_names: - st_path = model_dir / "model.safetensors" - if st_path.exists(): - with open(st_path, 'rb') as f: - mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - _, header = parse_safetensors_header(memoryview(mm)) - disk_names = {k for k in header.keys() if k != "__metadata__"} - mm.close() - - print(f"vLLM tensors (raw): {len(vllm_weights)}") - print(f"HF tensors (converted): {len(hf_names)}") - print(f"Disk tensors: {len(disk_names)}") - print() - - in_both = hf_names & disk_names - only_hf = hf_names - disk_names - only_disk = disk_names - hf_names - - print(f"Matched: {len(in_both)}") - print(f"Only in HF (won't sync): {len(only_hf)}") - print(f"Only on disk (not updated): {len(only_disk)}") - - if only_hf: - print(f"\nSample HF-only: {sorted(only_hf)[:5]}") - if only_disk: - print(f"\nSample disk-only: {sorted(only_disk)[:5]}") - - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - -def main(): - import argparse - - parser = argparse.ArgumentParser( - description="Sync live GPU weights to safetensors files" - ) - subparsers = parser.add_subparsers(dest="command", help="Command") - - # sync command - sync_parser = subparsers.add_parser("sync", help="Sync weights to disk") - sync_parser.add_argument( - "--model-dir", required=True, - help="Directory containing safetensors files" - ) - sync_parser.add_argument( - "--handles", default=DEFAULT_HANDLES_PATH, - help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" - ) - sync_parser.add_argument( - "--block-size", type=int, default=DEFAULT_BLOCK_SIZE, - help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})" - ) - sync_parser.add_argument( - "-v", "--verbose", action="store_true", - help="Verbose output" - ) - - # diagnose command - diag_parser = subparsers.add_parser("diagnose", help="Check name mappings") - diag_parser.add_argument( - "--model-dir", required=True, - help="Directory containing safetensors files" - ) - diag_parser.add_argument( - "--handles", default=DEFAULT_HANDLES_PATH, - help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" - ) - - args = parser.parse_args() - - if args.command is None: - parser.print_help() - sys.exit(1) - - logging.basicConfig( - level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO, - format='%(message)s' - ) - - try: - if args.command == "sync": - result = checkpoint_sync(args.model_dir, args.handles, args.block_size) - print(json.dumps(result, indent=2)) - elif args.command == "diagnose": - diagnose(args.model_dir, args.handles) - except FileNotFoundError as e: - logger.error(str(e)) - sys.exit(1) - except Exception as e: - logger.exception(f"Failed: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/training/apollo_plugin/train_router.py b/training/apollo_plugin/train_router.py deleted file mode 100644 index d6f90b4..0000000 --- a/training/apollo_plugin/train_router.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Training endpoint for vLLM - forwards to training subprocess via ZMQ. - -Patches vLLM's build_app() to add /train route. The actual training runs -in a dedicated subprocess (training_worker.py) to avoid blocking the -event loop and to keep training work isolated from vLLM internals. -""" - -import asyncio -import logging -import os -import subprocess -import sys -from datetime import datetime -from pathlib import Path -from typing import Any - -import zmq -import zmq.asyncio - -from fastapi import APIRouter, FastAPI -from fastapi.responses import JSONResponse -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - -router = APIRouter() - -DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" - -# Global state for subprocess management -_worker_process: subprocess.Popen | None = None -_zmq_context: zmq.asyncio.Context | None = None -_zmq_socket: zmq.asyncio.Socket | None = None -_initialized: bool = False - - -class TrainRequest(BaseModel): - training_data: dict[str, Any] # {"samples": [...], "config": {...}} - - -class TrainResponse(BaseModel): - job_id: str - status: str - training_samples: int - loss_history: list[float] - - -def _start_worker_subprocess(): - """Start the training worker subprocess.""" - global _worker_process - - if _worker_process is not None and _worker_process.poll() is None: - return # Still running - - # Start worker as subprocess using script path - worker_script = Path(__file__).parent / 'training_worker.py' - _worker_process = subprocess.Popen( - [sys.executable, str(worker_script)], - env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR}, - ) - logger.info(f"Started training worker subprocess (pid={_worker_process.pid})") - - # Give it a moment to bind the socket - import time - time.sleep(0.5) - - -def _ensure_initialized(): - """Ensure subprocess is running and ZMQ socket is connected.""" - global _zmq_context, _zmq_socket, _initialized - - if _initialized: - return - - # Start worker if needed - _start_worker_subprocess() - - # Create async ZMQ context and socket - _zmq_context = zmq.asyncio.Context() - _zmq_socket = _zmq_context.socket(zmq.REQ) - _zmq_socket.connect(DEFAULT_ZMQ_ADDR) - - # Set timeout for recv - _zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training - - _initialized = True - logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}") - - -async def _send_request(request: dict[str, Any]) -> dict[str, Any]: - """Send request to worker and wait for response.""" - _ensure_initialized() - - # ZMQ async send/recv - await _zmq_socket.send_json(request) - response = await _zmq_socket.recv_json() - return response - - -@router.post("/train") -async def handle_train(request: TrainRequest): - """Handle training request - forwards to training subprocess.""" - try: - _ensure_initialized() - except Exception as e: - return JSONResponse( - content={"error": f"Training not available: {e}"}, - status_code=503, - ) - - try: - training_data = request.training_data - samples = training_data.get("samples", []) - config = training_data.get("config", {}) - - if not samples: - return JSONResponse( - content={"error": "No training samples provided"}, - status_code=400, - ) - - job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - logger.info(f"Starting training job {job_id} with {len(samples)} samples") - - # Forward to worker - response = await _send_request({ - 'type': 'train', - 'samples': samples, - 'config': config, - }) - - if 'error' in response: - return JSONResponse( - content={"error": response['error']}, - status_code=500, - ) - - logger.info( - f"Training job {job_id} completed, " - f"final loss: {response['loss_history'][-1]:.4f}" - ) - - return JSONResponse(content={ - "job_id": job_id, - "status": response['status'], - "training_samples": response['training_samples'], - "loss_history": response['loss_history'], - }) - - except zmq.Again: - logger.error("Training request timed out") - return JSONResponse( - content={"error": "Training request timed out"}, - status_code=504, - ) - except Exception as e: - logger.exception(f"Training failed: {e}") - return JSONResponse( - content={"error": str(e)}, - status_code=500, - ) - - -@router.post("/checkpoint") -async def handle_checkpoint(): - """Trigger checkpoint sync to disk.""" - try: - _ensure_initialized() - except Exception as e: - return JSONResponse( - content={"error": f"Training not available: {e}"}, - status_code=503, - ) - - try: - response = await _send_request({'type': 'checkpoint'}) - - if 'error' in response: - return JSONResponse( - content={"error": response['error']}, - status_code=500, - ) - - return JSONResponse(content=response) - - except Exception as e: - logger.exception(f"Checkpoint failed: {e}") - return JSONResponse( - content={"error": str(e)}, - status_code=500, - ) - - -@router.get("/train/status") -async def handle_status(): - """Get training worker status.""" - try: - _ensure_initialized() - except Exception as e: - return JSONResponse( - content={ - "status": "unavailable", - "error": str(e), - }, - status_code=503, - ) - - try: - response = await _send_request({'type': 'status'}) - return JSONResponse(content=response) - - except Exception as e: - return JSONResponse( - content={ - "status": "error", - "error": str(e), - }, - status_code=500, - ) - - -def attach_router(app: FastAPI): - """Attach training router to FastAPI app.""" - app.include_router(router) - logger.info("Training router attached") - - -def _patch_api_server(): - """Patch vLLM's build_app to include our training router.""" - from vllm.entrypoints.openai import api_server - - original_build_app = api_server.build_app - - def patched_build_app(*args, **kwargs): - app = original_build_app(*args, **kwargs) - attach_router(app) - return app - - api_server.build_app = patched_build_app - logger.info("API server patched for /train endpoint") diff --git a/training/apollo_plugin/training_worker.py b/training/apollo_plugin/training_worker.py deleted file mode 100644 index f8b8c23..0000000 --- a/training/apollo_plugin/training_worker.py +++ /dev/null @@ -1,323 +0,0 @@ -"""Training subprocess - handles Apollo training and checkpoint sync. - -Long-lived process that: -1. Loads IPC handles from vLLM's exported weights -2. Creates HF model with views into vLLM's GPU memory -3. Handles training requests via ZMQ -4. Handles checkpoint sync requests -5. Persists Apollo optimizer state between calls - -Communicates with the API server's /train endpoint via ZMQ REP socket. -""" - -import logging -import os -import signal -import sys -from pathlib import Path -from typing import Any - -# Handle running as script vs module -if __name__ == '__main__' and __package__ is None: - # Running as script - add parent to path for imports - sys.path.insert(0, str(Path(__file__).parent.parent)) - __package__ = 'apollo_plugin' - -import torch -import torch.nn as nn -import zmq - -from .checkpoint_sync import checkpoint_sync -from .optimizer import Apollo -from .weight_mapping import load_hf_model_with_vllm_weights - -logger = logging.getLogger(__name__) - -DEFAULT_RANK = 64 -DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock" -HANDLE_PATH = "/tmp/vllm_weight_handles.pt" -OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt" - - -class TrainingWorker: - """Long-lived training worker process.""" - - def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR): - self.zmq_addr = zmq_addr - self.model: nn.Module | None = None - self.optimizer: Apollo | None = None - self.model_path: str | None = None - self._running = True - - def _create_model_wrapper(self) -> nn.Module: - """Create HF model wrapper with views into vLLM's GPU memory.""" - if not os.path.exists(HANDLE_PATH): - raise FileNotFoundError( - f"Weight handles not found: {HANDLE_PATH}. " - "Is vLLM running with the export hook?" - ) - - handles = torch.load(HANDLE_PATH, weights_only=False) - - # Extract metadata - metadata = handles.pop('__metadata__', {}) - self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH') - if not self.model_path: - raise ValueError( - "Model path not found in handles metadata or APOLLO_MODEL_PATH env var" - ) - - # Reconstruct tensors from IPC handles - vllm_params = {} - for name, info in handles.items(): - func, args = info['handle'] - vllm_params[name] = func(*args) - - model = load_hf_model_with_vllm_weights(vllm_params, self.model_path) - model.train() - return model - - def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo: - """Get existing optimizer or create new one.""" - if self.optimizer is not None: - return self.optimizer - - # Build parameter groups (Apollo for 2D+, standard Adam for small/1D) - apollo_params, standard_params = [], [] - for p in self.model.parameters(): - if p.requires_grad: - if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({'params': apollo_params}) - if standard_params: - groups.append({'params': standard_params}) - - if not groups: - raise ValueError("No trainable parameters found") - - self.optimizer = Apollo( - groups, - lr=config.get('lr', 1e-5), - rank=config.get('rank', DEFAULT_RANK), - betas=tuple(config.get('betas', (0.9, 0.999))), - eps=config.get('eps', 1e-8), - weight_decay=config.get('weight_decay', 0.01), - warmup_steps=config.get('warmup_steps', 0), - scale=config.get('scale'), - proj_refresh=config.get('proj_refresh', 200), - norm_growth_limit=config.get('norm_growth_limit', 1.01), - ) - - # Restore state if exists - if os.path.exists(OPTIMIZER_STATE_PATH): - try: - state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False) - self.optimizer.load_state_dict(state) - logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}") - except Exception as e: - logger.warning(f"Could not restore optimizer state: {e}") - - logger.info( - f"Optimizer: {len(apollo_params)} apollo params, " - f"{len(standard_params)} standard, " - f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB" - ) - - return self.optimizer - - def _save_optimizer_state(self): - """Save optimizer state for persistence.""" - if self.optimizer is not None: - torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH) - logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}") - - def _run_training( - self, - samples: list[dict[str, Any]], - config: dict[str, Any], - ) -> list[float]: - """Run Apollo training on the given samples.""" - optimizer = self._get_or_create_optimizer(config) - - loss_history = [] - - for i, sample in enumerate(samples): - ctx_ids = sample['context_ids'] - cont_ids = sample['continuation_ids'] - all_ids = ctx_ids + cont_ids - context_len = len(ctx_ids) - - input_ids = torch.tensor([all_ids], device='cuda:0') - - optimizer.zero_grad() - - # Context-frozen forward pass - with torch.no_grad(): - outputs = self.model(input_ids[:, :context_len], use_cache=True) - past_kv = outputs.past_key_values - - # Decision tokens with gradients - with torch.enable_grad(): - outputs = self.model( - input_ids[:, context_len:], - past_key_values=past_kv, - use_cache=False, - ) - logits = outputs.logits - - # Shift: predict next token from each position - shift_logits = logits[:, :-1].contiguous() - shift_labels = input_ids[:, context_len + 1:].contiguous() - - loss = nn.functional.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - ) - - loss.backward() - optimizer.step() - - loss_val = loss.item() - loss_history.append(loss_val) - logger.info( - f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} " - f"(ctx={context_len}, cont={len(cont_ids)} tokens)" - ) - - return loss_history - - def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]: - """Handle a training request.""" - samples = request.get('samples', []) - config = request.get('config', {}) - - if not samples: - return {'error': 'No training samples provided'} - - try: - loss_history = self._run_training(samples, config) - return { - 'status': 'completed', - 'training_samples': len(samples), - 'loss_history': loss_history, - } - except Exception as e: - logger.exception(f"Training failed: {e}") - return {'error': str(e)} - - def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]: - """Handle a checkpoint sync request.""" - if not self.model_path: - return {'error': 'Model path not set'} - - try: - self._save_optimizer_state() - result = checkpoint_sync(self.model_path) - return { - 'status': 'completed', - 'total_changed': result['total_changed'], - 'files_changed': result['files_changed'], - } - except Exception as e: - logger.exception(f"Checkpoint sync failed: {e}") - return {'error': str(e)} - - def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]: - """Handle a status request.""" - return { - 'status': 'ready', - 'model_loaded': self.model is not None, - 'optimizer_loaded': self.optimizer is not None, - 'model_path': self.model_path, - 'optimizer_state_mb': ( - self.optimizer.state_size_bytes() / 1e6 - if self.optimizer else 0 - ), - } - - def run(self): - """Main loop - listen for requests and handle them.""" - # Set up signal handlers - def handle_signal(signum, frame): - logger.info(f"Received signal {signum}, shutting down...") - self._running = False - - signal.signal(signal.SIGTERM, handle_signal) - signal.signal(signal.SIGINT, handle_signal) - - # Set up ZMQ socket first so API server can connect - context = zmq.Context() - socket = context.socket(zmq.REP) - socket.bind(self.zmq_addr) - logger.info(f"Training worker listening on {self.zmq_addr}") - - # Create HF model wrapper with views into vLLM's GPU memory - logger.info("Connecting to vLLM weights via IPC handles...") - try: - self.model = self._create_model_wrapper() - logger.info("HF model wrapper ready (views into vLLM GPU memory)") - except Exception as e: - logger.error(f"Failed to connect to vLLM weights: {e}") - logger.info("Will retry on first training request") - - # Set socket timeout so we can check _running flag - socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout - - while self._running: - try: - message = socket.recv_json() - except zmq.Again: - # Timeout, check _running and continue - continue - - request_type = message.get('type', 'train') - logger.info(f"Received {request_type} request") - - # Ensure model is loaded - if self.model is None and request_type != 'status': - try: - self.model = self._create_model_wrapper() - except Exception as e: - socket.send_json({'error': f'Model not loaded: {e}'}) - continue - - # Dispatch request - if request_type == 'train': - response = self._handle_train(message) - elif request_type == 'checkpoint': - response = self._handle_checkpoint(message) - elif request_type == 'status': - response = self._handle_status(message) - else: - response = {'error': f'Unknown request type: {request_type}'} - - socket.send_json(response) - - # Cleanup - logger.info("Saving optimizer state before shutdown...") - self._save_optimizer_state() - socket.close() - context.term() - logger.info("Training worker shut down") - - -def main(): - """Entry point for running as a subprocess.""" - logging.basicConfig( - level=logging.INFO, - format='[apollo-worker] %(asctime)s %(levelname)s %(message)s', - datefmt='%H:%M:%S', - ) - - zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR) - worker = TrainingWorker(zmq_addr) - worker.run() - - -if __name__ == '__main__': - main() diff --git a/training/apollo_worker.py b/training/apollo_worker.py new file mode 100755 index 0000000..d46fb55 --- /dev/null +++ b/training/apollo_worker.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 +""" +Apollo Mini Training Daemon + +This daemon: +1. Listens over HTTPS for training requests from poc-agent +2. Pauses vLLM inference +3. Runs APOLLO-Mini training with torch.enable_grad() +4. Saves checkpoints and training metadata +5. Resumes vLLM inference + +Communication protocol: +- POST /train: Start a training job +- GET /status/{job_id}: Check training status +- GET /checkpoints: List available checkpoints +""" + +import asyncio +import json +import logging +import os +import sys +import time +from dataclasses import dataclass, field, asdict +from datetime import datetime +from pathlib import Path +from typing import Optional, Dict, Any, List +from enum import Enum + +import torch +import torch.nn as nn +from aiohttp import web + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger('apollo_worker') + +class TrainingStatus(Enum): + PENDING = "pending" + PAUSING_VLLM = "pausing_vllm" + TRAINING = "training" + SAVING_CHECKPOINT = "saving_checkpoint" + RESUMING_VLLM = "resuming_vllm" + COMPLETED = "completed" + FAILED = "failed" + +@dataclass +class TrainingJob: + job_id: str + status: TrainingStatus + created_at: datetime + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + model_path: Optional[str] = None + checkpoint_path: Optional[str] = None + training_samples: int = 0 + loss_history: List[float] = field(default_factory=list) + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + 'job_id': self.job_id, + 'status': self.status.value, + 'created_at': self.created_at.isoformat(), + 'started_at': self.started_at.isoformat() if self.started_at else None, + 'completed_at': self.completed_at.isoformat() if self.completed_at else None, + 'model_path': self.model_path, + 'checkpoint_path': self.checkpoint_path, + 'training_samples': self.training_samples, + 'loss_history': self.loss_history, + 'error': self.error, + } + +class ApolloWorker: + def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"): + self.config = self._load_config(config_path) + self.jobs: Dict[str, TrainingJob] = {} + self.vllm_paused = False + self.app = web.Application() + self._setup_routes() + + def _load_config(self, config_path: str) -> Dict[str, Any]: + """Load configuration from file or use defaults.""" + default_config = { + 'host': '0.0.0.0', + 'port': 8080, + 'vllm_socket': '/tmp/vllm_control.sock', + 'model_path': '/home/ubuntu/models/Qwen3.5-27B', + 'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints', + 'max_training_samples': 100, + 'learning_rate': 1e-5, + 'batch_size': 1, + } + + if os.path.exists(config_path): + with open(config_path, 'r') as f: + user_config = json.load(f) + default_config.update(user_config) + + Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True) + return default_config + + def _setup_routes(self): + """Setup HTTP routes.""" + self.app.router.add_post('/train', self.handle_train_request) + self.app.router.add_get('/status/{job_id}', self.handle_status_request) + self.app.router.add_get('/checkpoints', self.handle_list_checkpoints) + self.app.router.add_get('/health', self.handle_health_check) + + async def handle_health_check(self, request: web.Request) -> web.Response: + """Health check endpoint.""" + return web.json_response({ + 'status': 'healthy', + 'vllm_paused': self.vllm_paused, + 'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]]) + }) + + async def handle_train_request(self, request: web.Request) -> web.Response: + """Handle training request from poc-agent.""" + try: + data = await request.json() + + # Validate required fields + if 'training_data' not in data: + return web.json_response( + {'error': 'Missing training_data field'}, + status=400 + ) + + job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}" + job = TrainingJob( + job_id=job_id, + status=TrainingStatus.PENDING, + created_at=datetime.now(), + model_path=self.config['model_path'] + ) + self.jobs[job_id] = job + + # Start training in background + asyncio.create_task(self.execute_training(job, data)) + + return web.json_response({ + 'job_id': job_id, + 'status': 'accepted', + 'message': 'Training job started' + }) + + except Exception as e: + logger.error(f"Error handling train request: {e}") + return web.json_response( + {'error': str(e)}, + status=500 + ) + + async def handle_status_request(self, request: web.Request) -> web.Response: + """Get training job status.""" + job_id = request.match_info['job_id'] + + if job_id not in self.jobs: + return web.json_response( + {'error': 'Job not found'}, + status=404 + ) + + job = self.jobs[job_id] + return web.json_response(job.to_dict()) + + async def handle_list_checkpoints(self, request: web.Request) -> web.Response: + """List available checkpoints.""" + checkpoint_dir = Path(self.config['checkpoint_dir']) + checkpoints = [] + + if checkpoint_dir.exists(): + for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True): + checkpoints.append({ + 'filename': checkpoint_file.name, + 'path': str(checkpoint_file), + 'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(), + 'size': checkpoint_file.stat().st_size + }) + + return web.json_response({'checkpoints': checkpoints}) + + async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]): + """Execute the training pipeline.""" + try: + logger.info(f"Starting training job {job.job_id}") + job.started_at = datetime.now() + + # Step 1: Pause vLLM + job.status = TrainingStatus.PAUSING_VLLM + logger.info("Pausing vLLM...") + await self.pause_vllm() + self.vllm_paused = True + + # Step 2: Load model and prepare for training + job.status = TrainingStatus.TRAINING + logger.info("Loading model and preparing for training...") + + # Load model (this would be the actual Qwen3.5-27B model) + # For now, we'll use a placeholder + model = await self.load_model_for_training() + + # Step 3: Run APOLLO-Mini training + logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples") + + # Extract training samples + samples = training_data['samples'] + job.training_samples = len(samples) + + # Run training loop + loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {})) + job.loss_history = loss_history + + # Step 4: Save checkpoint + job.status = TrainingStatus.SAVING_CHECKPOINT + logger.info("Saving checkpoint...") + checkpoint_path = await self.save_checkpoint(model, job) + job.checkpoint_path = checkpoint_path + + # Step 5: Resume vLLM + job.status = TrainingStatus.RESUMING_VLLM + logger.info("Resuming vLLM...") + await self.resume_vllm() + self.vllm_paused = False + + # Mark job as completed + job.status = TrainingStatus.COMPLETED + job.completed_at = datetime.now() + + logger.info(f"Training job {job.job_id} completed successfully") + + except Exception as e: + logger.error(f"Training job {job.job_id} failed: {e}") + job.status = TrainingStatus.FAILED + job.error = str(e) + job.completed_at = datetime.now() + + # Try to resume vLLM if it was paused + if self.vllm_paused: + try: + await self.resume_vllm() + self.vllm_paused = False + except Exception as resume_error: + logger.error(f"Failed to resume vLLM after training error: {resume_error}") + + async def pause_vllm(self): + """Pause vLLM inference via HTTP API.""" + import aiohttp as aio + url = self.config.get('vllm_url', 'http://localhost:8000') + try: + async with aio.ClientSession() as session: + async with session.post( + f"{url}/pause_generation", + json={"mode": "keep", "clear_cache": False}, + timeout=aio.ClientTimeout(total=10), + ) as resp: + resp.raise_for_status() + logger.info("vLLM paused") + except Exception as e: + logger.warning(f"Failed to pause vLLM: {e}") + + async def resume_vllm(self): + """Resume vLLM inference via HTTP API.""" + import aiohttp as aio + url = self.config.get('vllm_url', 'http://localhost:8000') + try: + async with aio.ClientSession() as session: + async with session.post( + f"{url}/resume_generation", + timeout=aio.ClientTimeout(total=10), + ) as resp: + resp.raise_for_status() + logger.info("vLLM resumed") + except Exception as e: + logger.warning(f"Failed to resume vLLM: {e}") + + async def load_model_for_training(self) -> nn.Module: + """Load HF model with weights pointing to vLLM's GPU memory. + + Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible + views (narrowing merged weights into separate q/k/v/z etc.), and + constructs the HF model around those views. No weight copying — + all parameters share vLLM's GPU memory. + """ + handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt') + model_path = self.config['model_path'] + + # Import vLLM weights via CUDA IPC + logger.info(f"Importing vLLM weights from {handle_path}") + handles = torch.load(handle_path, weights_only=False) + vllm_params = {} + for name, info in handles.items(): + func, args = info['handle'] + vllm_params[name] = func(*args) + logger.info(f"Imported {len(vllm_params)} parameters") + + # Map vLLM merged layout → HF separate layout (views, no copies) + from weight_mapping import load_hf_model_with_vllm_weights + model = load_hf_model_with_vllm_weights(vllm_params, model_path) + logger.info("HF model constructed with vLLM weight views") + + return model + + async def run_apollo_training(self, model: nn.Module, + samples: List[Dict[str, str]], + config: Dict[str, Any]) -> List[float]: + """Run Apollo-Mini training on conversation decision points.""" + from apollo_mini import Apollo + from transformers import AutoTokenizer + + lr = config.get('learning_rate', self.config['learning_rate']) + tokenizer = AutoTokenizer.from_pretrained( + self.config['model_path'], trust_remote_code=True) + + # Build parameter groups (Apollo for 2D+, standard for small/1D) + apollo_params, standard_params = [], [] + for p in model.parameters(): + if p.requires_grad: + if p.ndim >= 2 and min(p.shape) >= 2: + apollo_params.append(p) + else: + standard_params.append(p) + + groups = [] + if apollo_params: + groups.append({'params': apollo_params}) + if standard_params: + groups.append({'params': standard_params}) + + rank = config.get('apollo_rank', 1) + optimizer = Apollo(groups, lr=lr, rank=rank) + logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, " + f"{len(standard_params)} standard, " + f"state={optimizer.state_size_bytes()/1e6:.1f}MB") + + loss_history = [] + + for i, sample in enumerate(samples): + context = sample.get('context', '') + continuation = sample.get('continuation', '') + + # Tokenize + ctx_ids = tokenizer.encode(context, add_special_tokens=True) + cont_ids = tokenizer.encode(continuation, add_special_tokens=False) + all_ids = ctx_ids + cont_ids + context_len = len(ctx_ids) + + input_ids = torch.tensor([all_ids], device='cuda:0') + + optimizer.zero_grad() + + # Context-frozen forward pass + with torch.no_grad(): + # Forward through context (no gradients) + outputs = model(input_ids[:, :context_len], use_cache=True) + past_kv = outputs.past_key_values + + # Decision tokens with gradients + with torch.enable_grad(): + outputs = model( + input_ids[:, context_len:], + past_key_values=past_kv, + use_cache=False, + ) + logits = outputs.logits # [1, cont_len, vocab] + + # Shift: predict next token from each position + shift_logits = logits[:, :-1].contiguous() + shift_labels = input_ids[:, context_len + 1:].contiguous() + + loss = nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + loss.backward() + optimizer.step() + + loss_val = loss.item() + loss_history.append(loss_val) + logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} " + f"(ctx={context_len}, cont={len(cont_ids)} tokens)") + + logger.info(f"Training done: {len(samples)} examples, " + f"final loss={loss_history[-1]:.4f}") + return loss_history + + async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str: + """Save model checkpoint in HuggingFace safetensors format.""" + from safetensors.torch import save_file + import shutil + + checkpoint_dir = Path(self.config['checkpoint_dir']) + date_str = datetime.now().strftime('%Y-%m-%d') + out_dir = checkpoint_dir / date_str + out_dir.mkdir(parents=True, exist_ok=True) + + # Save weights + tensors = {name: p.data.contiguous().cpu() + for name, p in model.named_parameters()} + save_path = out_dir / "model.safetensors" + save_file(tensors, str(save_path)) + + # Copy config files + config_dir = Path(self.config['model_path']) + for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json', + 'special_tokens_map.json']: + src = config_dir / f + if src.exists(): + shutil.copy2(src, out_dir / f) + + # Save training metadata + meta = { + 'job_id': job.job_id, + 'training_samples': job.training_samples, + 'loss_history': job.loss_history, + 'timestamp': datetime.now().isoformat(), + } + with open(out_dir / 'training-meta.json', 'w') as f: + json.dump(meta, f, indent=2) + + # Update latest symlink + latest = checkpoint_dir / 'latest' + if latest.is_symlink(): + latest.unlink() + latest.symlink_to(date_str) + + size_gb = save_path.stat().st_size / 1e9 + logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)") + return str(out_dir) + + async def run(self): + """Run the daemon.""" + logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}") + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, self.config['host'], self.config['port']) + await site.start() + logger.info("Apollo Worker is running") + + # Keep running + while True: + await asyncio.sleep(3600) # Sleep for an hour + +def main(): + worker = ApolloWorker() + asyncio.run(worker.run()) + +if __name__ == '__main__': + main() diff --git a/training/checkpoint/Cargo.toml b/training/checkpoint/Cargo.toml new file mode 100644 index 0000000..45e511a --- /dev/null +++ b/training/checkpoint/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "apollo-checkpoint" +version = "0.1.0" +edition = "2024" + +[dependencies] +memmap2 = "0.9" +safetensors = "0.5" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +anyhow = "1" +clap = { version = "4", features = ["derive"] } diff --git a/training/checkpoint/src/main.rs b/training/checkpoint/src/main.rs new file mode 100644 index 0000000..1ebd0df --- /dev/null +++ b/training/checkpoint/src/main.rs @@ -0,0 +1,265 @@ +// apollo-checkpoint — Sync live GPU weights back to model files on disk. +// +// mmaps the model's safetensors files, reads live weights from GPU via +// Python helper (CUDA IPC handles), compares block by block, and memcpys +// only changed regions back into the mmap. For small behavioral training +// steps, this turns a 54GB write into a few hundred MB. +// +// The model files on disk are the checkpoint. No separate checkpoint +// directory — just keep the model up to date. +// +// Usage: +// apollo-checkpoint sync \ +// --handles /tmp/vllm_weight_handles.pt \ +// --model-dir /path/to/Qwen3.5-27B +// +// Runs every 10 minutes via cron. Daily rsync to moria. + +use anyhow::{Context, Result, bail}; +use clap::{Parser, Subcommand}; +use memmap2::MmapMut; +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; + +#[derive(Parser)] +#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")] +struct Cli { + #[command(subcommand)] + command: Cmd, +} + +#[derive(Subcommand)] +enum Cmd { + /// Sync live GPU weights back to model safetensors files + Sync { + /// Path to vLLM weight IPC handles + #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] + handles: PathBuf, + + /// Model directory containing safetensors files + #[arg(long)] + model_dir: PathBuf, + + /// Block size for diffing (bytes) + #[arg(long, default_value_t = 4096)] + block_size: usize, + }, +} + +/// Dump live GPU weights to a flat binary file, ordered by safetensors +/// file and offset to match the on-disk layout. +/// +/// Returns a map of (safetensors filename, tensor name) → raw bytes. +fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result>> { + let dump_path = output_dir.join(".live_dump.bin"); + let index_path = output_dir.join(".live_dump.json"); + + let status = Command::new("python3") + .arg("-c") + .arg(format!(r#" +import torch, json + +handles = torch.load("{handles}", weights_only=False) +index = {{}} +offset = 0 + +with open("{dump}", "wb") as f: + for name in sorted(handles.keys()): + info = handles[name] + func, args = info["handle"] + tensor = func(*args) + data = tensor.contiguous().cpu().numpy().tobytes() + f.write(data) + index[name] = {{"offset": offset, "size": len(data)}} + offset += len(data) + +with open("{index}", "w") as f: + json.dump(index, f) + +print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB") +"#, + handles = handles_path.display(), + dump = dump_path.display(), + index = index_path.display(), + )) + .status() + .context("Failed to run Python weight dump")?; + + if !status.success() { + bail!("Python weight dump failed"); + } + + let index_str = fs::read_to_string(&index_path)?; + let index: HashMap = serde_json::from_str(&index_str)?; + let dump_data = fs::read(&dump_path)?; + + let mut result = HashMap::new(); + for (name, entry) in &index { + result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec()); + } + + // Clean up temp files + let _ = fs::remove_file(&dump_path); + let _ = fs::remove_file(&index_path); + + Ok(result) +} + +#[derive(serde::Deserialize)] +struct DumpEntry { + offset: usize, + size: usize, +} + +/// Read the safetensors index to map parameter names to files. +fn read_safetensors_index(model_dir: &Path) -> Result> { + let index_path = model_dir.join("model.safetensors.index.json"); + if !index_path.exists() { + // Single file model + return Ok(HashMap::new()); + } + + let index_str = fs::read_to_string(&index_path)?; + let index: serde_json::Value = serde_json::from_str(&index_str)?; + let weight_map = index["weight_map"] + .as_object() + .context("No weight_map in index")?; + + let mut result = HashMap::new(); + for (name, file) in weight_map { + result.insert(name.clone(), file.as_str().unwrap().to_string()); + } + Ok(result) +} + +/// Sync changed blocks from live weights into a mmap'd safetensors file. +/// Returns (total_bytes_compared, bytes_changed). +fn sync_tensors_to_file( + file_path: &Path, + tensors: &[(String, Vec)], + block_size: usize, +) -> Result<(usize, usize)> { + use safetensors::SafeTensors; + + let file = fs::OpenOptions::new() + .read(true) + .write(true) + .open(file_path) + .with_context(|| format!("Failed to open {}", file_path.display()))?; + + let mut mmap = unsafe { MmapMut::map_mut(&file)? }; + + // Parse safetensors header to find tensor offsets + let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize; + let header_json: serde_json::Value = + serde_json::from_slice(&mmap[8..8 + header_size])?; + let data_start = 8 + header_size; + + let mut total_compared = 0usize; + let mut total_changed = 0usize; + + for (name, live_data) in tensors { + let meta = match header_json.get(name) { + Some(m) => m, + None => { + eprintln!(" Warning: {} not found in {}", name, file_path.display()); + continue; + } + }; + + let offsets = meta["data_offsets"].as_array().unwrap(); + let start = data_start + offsets[0].as_u64().unwrap() as usize; + let end = data_start + offsets[1].as_u64().unwrap() as usize; + let disk_data = &mmap[start..end]; + + if disk_data.len() != live_data.len() { + eprintln!(" Warning: size mismatch for {}: disk={} live={}", + name, disk_data.len(), live_data.len()); + continue; + } + + // Diff block by block, memcpy only changed blocks + let mut offset = 0; + while offset < disk_data.len() { + let block_end = (offset + block_size).min(disk_data.len()); + total_compared += block_end - offset; + + if disk_data[offset..block_end] != live_data[offset..block_end] { + mmap[start + offset..start + block_end] + .copy_from_slice(&live_data[offset..block_end]); + total_changed += block_end - offset; + } + offset = block_end; + } + } + + mmap.flush()?; + Ok((total_compared, total_changed)) +} + +fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> { + if !handles.exists() { + bail!("Weight handles not found: {}. Is vLLM running with the export hook?", + handles.display()); + } + + eprintln!("Dumping live weights from GPU..."); + let live_weights = dump_live_weights(&handles, &model_dir)?; + eprintln!(" {} tensors dumped", live_weights.len()); + + // Map parameter names to safetensors files + let weight_map = read_safetensors_index(&model_dir)?; + + // Group tensors by safetensors file + let mut by_file: HashMap)>> = HashMap::new(); + for (name, data) in live_weights { + let file = weight_map + .get(&name) + .cloned() + .unwrap_or_else(|| "model.safetensors".to_string()); + by_file.entry(file).or_default().push((name, data)); + } + + let mut total_compared = 0usize; + let mut total_changed = 0usize; + + for (filename, tensors) in &by_file { + let file_path = model_dir.join(filename); + if !file_path.exists() { + eprintln!(" Warning: {} not found, skipping", filename); + continue; + } + + let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?; + total_compared += compared; + total_changed += changed; + + if changed > 0 { + eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6); + } + } + + if total_changed == 0 { + eprintln!("No changes — model files are up to date"); + } else { + eprintln!( + "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)", + total_changed as f64 / 1e6, + total_compared as f64 / 1e9, + total_changed as f64 / total_compared as f64 * 100.0, + ); + } + + Ok(()) +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + match cli.command { + Cmd::Sync { handles, model_dir, block_size } => { + cmd_sync(handles, model_dir, block_size) + } + } +} diff --git a/training/export_weights.py b/training/export_weights.py new file mode 100644 index 0000000..ef2f608 --- /dev/null +++ b/training/export_weights.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""Export vLLM's live model weight IPC handles for the training process. + +Connects to a running vLLM instance, iterates over model parameters, +and exports CUDA IPC handles that allow another process to access the +same GPU memory without copying. + +Usage: + # Run after vLLM is serving: + python3 export_weights.py --output /tmp/vllm_weight_handles.pt + + # Or via vLLM's API (future): + curl -X POST http://localhost:8000/export_weights +""" + +import argparse +import sys +import torch +from pathlib import Path + + +def export_from_model(model, output_path: str): + """Export IPC handles for all model parameters.""" + from torch.multiprocessing.reductions import reduce_tensor + + handles = {} + total_bytes = 0 + + for name, param in model.named_parameters(): + handle = reduce_tensor(param.data) + handles[name] = { + 'handle': handle, + 'shape': list(param.shape), + 'dtype': str(param.dtype), + } + param_bytes = param.nelement() * param.element_size() + total_bytes += param_bytes + + torch.save(handles, output_path) + + n_params = len(handles) + print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)") + print(f"Saved to {output_path}") + return handles + + +def main(): + parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles") + parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt", + help="Output path for IPC handles") + parser.add_argument("--vllm-pid", type=int, default=None, + help="vLLM worker PID (auto-detected if not specified)") + args = parser.parse_args() + + # For now: load the model directly and export. + # TODO: connect to running vLLM process instead. + print("Note: This currently loads the model separately.") + print("Full integration will export from the running vLLM process.") + print() + + # Detect model path from running vLLM + import subprocess + result = subprocess.run( + ['ps', 'aux'], capture_output=True, text=True + ) + model_path = None + for line in result.stdout.split('\n'): + if 'vllm' in line and '--model' in line: + parts = line.split() + for i, p in enumerate(parts): + if p == '--model' and i + 1 < len(parts): + model_path = parts[i + 1] + break + # Also check model_tag format + if p.startswith('--model='): + model_path = p.split('=', 1)[1] + break + + if model_path: + print(f"Detected vLLM model: {model_path}") + else: + print("Could not detect running vLLM model. Specify manually.") + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/training/apollo_plugin/steering.py b/training/extract_steering_vector.py similarity index 100% rename from training/apollo_plugin/steering.py rename to training/extract_steering_vector.py diff --git a/training/first_training_step.py b/training/first_training_step.py new file mode 100644 index 0000000..0e6ffd8 --- /dev/null +++ b/training/first_training_step.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +"""First real Apollo training step — ready for Kent to run. + +This script: +1. Imports vLLM's live weights via CUDA IPC +2. Constructs HF model with shared memory views +3. Runs ONE forward+backward on a real training example +4. Applies ONE Apollo optimizer step +5. Verifies vLLM still works after the update + +The training example is from March 30: Kent said "use vLLM's code" +and the model should have accepted instead of suggesting alternatives. + +Usage: + source ~/training-env/bin/activate + python3 first_training_step.py [--dry-run] +""" + +import argparse +import sys +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoConfig, AutoTokenizer +from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + +sys.path.insert(0, '.') +from weight_mapping import vllm_to_hf_views +from apollo_mini import Apollo + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dry-run', action='store_true', + help="Run forward+backward but don't apply the optimizer step") + parser.add_argument('--lr', type=float, default=1e-5, + help="Learning rate (default: 1e-5 = conservative)") + parser.add_argument('--rank', type=int, default=256) + parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt') + parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B') + args = parser.parse_args() + + print("=== First Apollo Training Step ===\n") + + # 1. Import vLLM weights + print("1. Importing vLLM weights via CUDA IPC...") + handles = torch.load(args.handles, weights_only=False) + vllm_params = {} + for name, info in handles.items(): + func, args_h = info['handle'] + vllm_params[name] = func(*args_h) + print(f" {len(vllm_params)} parameters imported") + + # 2. Map to HF layout + print("2. Mapping to HF layout (zero-copy views)...") + hf_params = vllm_to_hf_views(vllm_params) + + # 3. Create HF model + print("3. Creating HF model with shared weights...") + config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) + with torch.device('meta'): + model = Qwen3_5ForCausalLM(config.text_config) + + replaced = 0 + for name, param in list(model.named_parameters()): + if name in hf_params: + parts = name.split('.') + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], + nn.Parameter(hf_params[name], requires_grad=True)) + replaced += 1 + print(f" {replaced} parameters replaced with vLLM memory views") + + # 4. Load tokenizer + print("4. Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + + # 5. Construct training example + print("5. Constructing training example...") + + # Context: conversation where Kent says to use vLLM's code + # Target: the response that accepts the direction + context = ( + "<|im_start|>user\n" + "vllm has a fused kernel already, right?<|im_end|>\n" + "<|im_start|>assistant\n" + "Yeah — vLLM has `gdn_attention_core` which is a custom op " + "that does the whole GDN layer's core in one dispatch.<|im_end|>\n" + "<|im_start|>user\n" + "Why wouldn't we just use that?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + # The CORRECT response (accept direction, don't suggest alternatives) + continuation = ( + "We should. Let me pull in their kernel and wire it into " + "our Rust orchestration. Which file should I start with?" + ) + + context_ids = tokenizer.encode(context, add_special_tokens=False) + continuation_ids = tokenizer.encode(continuation, add_special_tokens=False) + all_ids = context_ids + continuation_ids + context_len = len(context_ids) + + print(f" Context: {context_len} tokens") + print(f" Continuation: {len(continuation_ids)} tokens") + print(f" Total: {len(all_ids)} tokens") + + input_ids = torch.tensor([all_ids], device='cuda:0') + + # 6. Initialize Apollo optimizer + print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...") + apollo_params = [] + standard_params = [] + for p in model.parameters(): + if p.requires_grad: + if p.ndim >= 2 and min(p.shape) >= args.rank: + apollo_params.append(p) + else: + standard_params.append(p) + + groups = [] + if apollo_params: + groups.append({'params': apollo_params}) + if standard_params: + groups.append({'params': standard_params}) + + optimizer = Apollo(groups, lr=args.lr, rank=args.rank) + print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard") + + # 7. Forward pass + print("7. Forward pass...") + model.train() + optimizer.zero_grad() + + # Context-frozen: no grad for context, grad for continuation + with torch.no_grad(): + ctx_output = model(input_ids[:, :context_len], use_cache=True) + past_kv = ctx_output.past_key_values + + with torch.enable_grad(): + output = model(input_ids[:, context_len:], + past_key_values=past_kv, use_cache=False) + logits = output.logits + # Shift for next-token prediction + shift_logits = logits[:, :-1].contiguous() + shift_labels = input_ids[:, context_len + 1:].contiguous() + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + print(f" Loss: {loss.item():.4f}") + + # 8. Backward pass + print("8. Backward pass...") + loss.backward() + n_grads = sum(1 for p in model.parameters() if p.grad is not None) + print(f" {n_grads} parameters have gradients") + + # 9. Apollo step (or dry run) + if args.dry_run: + print("\n9. DRY RUN — skipping optimizer step") + print(" (run without --dry-run to apply the update)") + else: + print("9. Applying Apollo optimizer step...") + # Record a few weight norms before + sample_norms_before = {} + for name, p in model.named_parameters(): + if 'layers.0.' in name and p.grad is not None: + sample_norms_before[name] = p.data.norm().item() + + optimizer.step() + + # Check weight changes + print(" Weight changes (layer 0):") + for name, before in sample_norms_before.items(): + p = dict(model.named_parameters())[name] + after = p.data.norm().item() + delta = abs(after - before) + pct = delta / before * 100 if before > 0 else 0 + print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)") + + optimizer.zero_grad() + + # 10. Verify vLLM still works + print("\n10. Verifying vLLM still serves...") + import subprocess + result = subprocess.run( + ['curl', '-s', '--max-time', '30', + '-X', 'POST', 'http://localhost:8000/v1/chat/completions', + '-H', 'Content-Type: application/json', + '-H', 'Authorization: Bearer bcachefs-agents-2026', + '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'], + capture_output=True, text=True, timeout=45 + ) + if result.returncode == 0 and 'choices' in result.stdout: + print(" vLLM still serving ✓") + else: + print(" WARNING: vLLM may not be responding") + print(f" stdout: {result.stdout[:200]}") + + print("\n=== COMPLETE ===") + if args.dry_run: + print("Run without --dry-run to apply the first real training step.") + else: + print("First Apollo training step applied to vLLM's live weights.") + print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") + + +if __name__ == '__main__': + main() diff --git a/training/pyproject.toml b/training/pyproject.toml deleted file mode 100644 index 7cf0581..0000000 --- a/training/pyproject.toml +++ /dev/null @@ -1,29 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0"] -build-backend = "setuptools.build_meta" - -[project] -name = "apollo-plugin" -version = "0.1.0" -description = "Apollo training plugin for vLLM" -requires-python = ">=3.10" -dependencies = [ - "torch", - "aiohttp", - "safetensors", - "pyzmq", -] - -[project.optional-dependencies] -dev = ["pytest"] - -[project.entry-points."vllm.general_plugins"] -apollo = "apollo_plugin:register" - -[project.scripts] -apollo-checkpoint = "apollo_plugin.checkpoint_sync:main" -apollo-worker = "apollo_plugin.training_worker:main" - -[tool.setuptools.packages.find] -where = ["."] -include = ["apollo_plugin*"] diff --git a/training/start_vllm_with_apollo.sh b/training/start_vllm_with_apollo.sh new file mode 100755 index 0000000..98dfedb --- /dev/null +++ b/training/start_vllm_with_apollo.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Start vLLM with Apollo weight export hook. +# +# The hook patches vLLM's model runner to export CUDA IPC handles +# after loading, so the Apollo training process can share the same +# GPU memory. + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" + +exec python3 -c " +import sys +sys.path.insert(0, '$SCRIPT_DIR') +import vllm_export_hook # patches model runner before vLLM loads + +sys.argv = ['vllm'] + sys.argv[1:] +from vllm.entrypoints.cli.main import main +main() +" serve "$@" diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000..a5fbe2c --- /dev/null +++ b/training/train.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +"""Nightly training process for Apollo-Mini fine-tuning. + +Imports vLLM's model weights via CUDA IPC, runs context-frozen +training on flagged conversation segments, saves updated checkpoint. + +Usage: + python3 train.py \ + --weights /tmp/vllm_weight_handles.pt \ + --examples training-examples.jsonl \ + --checkpoint-dir checkpoints/ \ + --lr 1e-5 +""" + +import argparse +import json +import os +import sys +import time +from datetime import datetime +from pathlib import Path + +import torch +from safetensors.torch import save_file + +from apollo_mini import ApolloMini + + +def import_weights(handle_path: str) -> dict[str, torch.Tensor]: + """Import weight tensors from CUDA IPC handles.""" + handles = torch.load(handle_path, weights_only=False) + params = {} + for name, info in handles.items(): + func, args = info['handle'] + tensor = func(*args) + params[name] = tensor + return params + + +def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]: + """Split parameters into Apollo-Mini and standard groups. + + Apollo-Mini needs 2D+ matrices with min dimension >= 2. + Small tensors (norms, biases, conv1d 3D weights) use standard Adam. + """ + apollo_params = [] + standard_params = [] + + for name, p in params.items(): + p.requires_grad_(True) + if p.ndim >= 2 and min(p.shape) >= 2: + apollo_params.append(p) + else: + standard_params.append(p) + + groups = [] + if apollo_params: + groups.append({ + 'params': apollo_params, + 'name': 'apollo', + }) + if standard_params: + groups.append({ + 'params': standard_params, + 'name': 'standard', + }) + + n_apollo = sum(p.nelement() for p in apollo_params) + n_standard = sum(p.nelement() for p in standard_params) + print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M") + return groups + + +def forward_pass(params, input_ids, context_len, device): + """Run context-frozen forward pass. + + Args: + params: dict of name -> tensor (shared with vLLM) + input_ids: full sequence [1, seq_len] + context_len: number of context tokens (no gradient) + device: CUDA device + + Returns: + logits for decision tokens, target ids for loss + """ + # TODO: Build proper forward model matching vLLM's weight layout. + # For now this is a placeholder — the real implementation needs + # to replicate vLLM's model architecture (merged projections, + # GDN recurrence, full attention, MLP) using the shared weights. + raise NotImplementedError( + "Forward model not yet implemented. " + "Need to build a model that matches vLLM's merged weight layout " + "(MergedColumnParallelLinear for qkvz/ba/gate_up, " + "RowParallelLinear for out_proj/down) and computes the same " + "forward pass with autograd enabled." + ) + + +def save_checkpoint(params: dict[str, torch.Tensor], + checkpoint_dir: str, + config_path: str = None): + """Save model checkpoint in HuggingFace safetensors format. + + Saves weights split across shards matching the original model layout, + archives the previous checkpoint, and updates the 'latest' symlink. + """ + date_str = datetime.now().strftime("%Y-%m-%d") + out_dir = Path(checkpoint_dir) / date_str + out_dir.mkdir(parents=True, exist_ok=True) + + # Save all weights in a single safetensors file for now. + # TODO: split across shards matching HF model index for large models. + tensors = {} + for name, param in params.items(): + tensors[name] = param.data.contiguous().cpu() + + save_path = out_dir / "model.safetensors" + save_file(tensors, str(save_path)) + print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)") + + # Copy config files if provided + if config_path: + import shutil + config_dir = Path(config_path) + for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json', + 'special_tokens_map.json', 'generation_config.json']: + src = config_dir / f + if src.exists(): + shutil.copy2(src, out_dir / f) + + # Update latest symlink + latest = Path(checkpoint_dir) / "latest" + if latest.is_symlink(): + latest.unlink() + latest.symlink_to(date_str) + print(f"Updated {latest} -> {date_str}") + + return str(out_dir) + + +def train_step(params, example, optimizer, device, log_entries): + """Run one training step on a single example. + + Args: + params: dict of name -> tensor + example: dict with 'input_ids', 'context_len', 'target_ids' + optimizer: ApolloMini instance + device: CUDA device + log_entries: list to append log dicts to + + Returns: + loss value + """ + optimizer.zero_grad() + + input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0) + context_len = example['context_len'] + + # Forward pass (context frozen, decision tokens with grad) + logits, targets = forward_pass(params, input_ids, context_len, device) + + # Cross-entropy loss on decision tokens + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]), + targets.view(-1), + ) + + # Backward + loss.backward() + + # Compute gradient stats before optimizer step + total_grad_norm = 0.0 + for p in params.values(): + if p.grad is not None: + total_grad_norm += p.grad.norm().item() ** 2 + total_grad_norm = total_grad_norm ** 0.5 + + # Optimizer step + optimizer.step() + + # Log + log_entries.append({ + 'example_id': example.get('id', 'unknown'), + 'loss': loss.item(), + 'grad_norm': total_grad_norm, + 'timestamp': datetime.now().isoformat(), + }) + + return loss.item() + + +def main(): + parser = argparse.ArgumentParser(description="Apollo-Mini training") + parser.add_argument("--weights", required=True, + help="Path to exported weight IPC handles") + parser.add_argument("--examples", required=True, + help="Path to training examples JSONL") + parser.add_argument("--checkpoint-dir", default="checkpoints", + help="Directory for saving checkpoints") + parser.add_argument("--config-path", default=None, + help="Path to model config files (for checkpoint)") + parser.add_argument("--lr", type=float, default=1e-5, + help="Learning rate") + parser.add_argument("--warmup-steps", type=int, default=10, + help="Learning rate warmup steps") + parser.add_argument("--weight-decay", type=float, default=0.01) + parser.add_argument("--dry-run", action="store_true", + help="Load weights and validate, don't train") + args = parser.parse_args() + + print(f"Apollo-Mini Training") + print(f" weights: {args.weights}") + print(f" examples: {args.examples}") + print(f" lr: {args.lr}") + print() + + # Import weights + print("Importing weights via CUDA IPC...") + params = import_weights(args.weights) + print(f" {len(params)} parameters imported") + + # Make parameter groups + param_groups = make_param_groups(params) + + # Initialize optimizer + optimizer = ApolloMini(param_groups, lr=args.lr, + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps) + print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") + + if args.dry_run: + print("\nDry run — weights imported and validated successfully.") + return + + # Load training examples + examples = [] + with open(args.examples) as f: + for line in f: + examples.append(json.loads(line)) + print(f" {len(examples)} training examples") + + # Training loop + log_entries = [] + print(f"\nTraining...") + t0 = time.time() + + for i, example in enumerate(examples): + loss = train_step(params, example, optimizer, 'cuda:0', log_entries) + print(f" [{i+1}/{len(examples)}] loss={loss:.4f}") + + elapsed = time.time() - t0 + print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s") + print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") + + # Save checkpoint + print("\nSaving checkpoint...") + save_checkpoint(params, args.checkpoint_dir, args.config_path) + + # Save training log + date_str = datetime.now().strftime("%Y-%m-%d") + log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl" + with open(log_path, 'w') as f: + for entry in log_entries: + f.write(json.dumps(entry) + '\n') + print(f"Training log: {log_path}") + + +if __name__ == '__main__': + main() diff --git a/training/training_example.py b/training/training_example.py new file mode 100644 index 0000000..b5779e0 --- /dev/null +++ b/training/training_example.py @@ -0,0 +1,175 @@ +"""Training example construction and tokenization. + +Takes raw conversation context + improved continuation, produces +tokenized tensors ready for context-frozen forward+backward. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path + +import torch +from transformers import AutoTokenizer + + +@dataclass +class TrainingExample: + """A single training example for context-frozen training.""" + id: str + context: str # conversation up to decision point + continuation: str # the better response + reason: str = "" # why this is a training target + memories: list[str] = field(default_factory=list) # memories that were in context + + # Computed after tokenization + input_ids: torch.Tensor | None = None + context_len: int = 0 + total_len: int = 0 + + def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"): + """Tokenize context + continuation into training-ready tensors. + + The chat template is applied to make the token distribution + match what the model sees during inference. + """ + # Build messages for context (everything up to the decision) + # The context should already be in chat format + context_ids = tokenizer.encode(self.context, add_special_tokens=False) + continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False) + + self.context_len = len(context_ids) + self.total_len = len(context_ids) + len(continuation_ids) + + if self.total_len > max_len: + # Truncate context from the left, keep continuation intact + excess = self.total_len - max_len + context_ids = context_ids[excess:] + self.context_len = len(context_ids) + self.total_len = len(context_ids) + len(continuation_ids) + + all_ids = context_ids + continuation_ids + self.input_ids = torch.tensor(all_ids, device=device) + return self + + def to_dict(self) -> dict: + return { + 'id': self.id, + 'context': self.context, + 'continuation': self.continuation, + 'reason': self.reason, + 'memories': self.memories, + 'context_len': self.context_len, + 'total_len': self.total_len, + } + + @classmethod + def from_dict(cls, d: dict) -> 'TrainingExample': + return cls( + id=d['id'], + context=d['context'], + continuation=d['continuation'], + reason=d.get('reason', ''), + memories=d.get('memories', []), + ) + + +def load_examples(path: str) -> list[TrainingExample]: + """Load training examples from JSONL file.""" + examples = [] + with open(path) as f: + for line in f: + if line.strip(): + examples.append(TrainingExample.from_dict(json.loads(line))) + return examples + + +def save_examples(examples: list[TrainingExample], path: str): + """Save training examples to JSONL file.""" + with open(path, 'w') as f: + for ex in examples: + f.write(json.dumps(ex.to_dict()) + '\n') + + +class ExampleTokenizer: + """Handles tokenization with the model's chat template. + + Applies the same chat template that vLLM uses during inference, + so the token distribution matches what the model expects. + """ + + def __init__(self, model_path: str): + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True) + + def prepare_example(self, example: TrainingExample, + max_len: int = 8192, + device: str = "cuda:0") -> TrainingExample: + """Tokenize an example using the chat template. + + For proper training, the context should be formatted exactly + as vLLM would format it — with chat template applied. + """ + # Apply chat template to get the exact token sequence + # the model would see during inference + # + # Context: everything up to the decision point + # Continuation: the improved response + # + # We tokenize them separately to know where context ends + # and continuation begins. + context_ids = self.tokenizer.encode( + example.context, add_special_tokens=True) + continuation_ids = self.tokenizer.encode( + example.continuation, add_special_tokens=False) + + example.context_len = len(context_ids) + example.total_len = len(context_ids) + len(continuation_ids) + + if example.total_len > max_len: + excess = example.total_len - max_len + context_ids = context_ids[excess:] + example.context_len = len(context_ids) + example.total_len = example.context_len + len(continuation_ids) + + all_ids = context_ids + continuation_ids + example.input_ids = torch.tensor(all_ids, device=device) + return example + + def prepare_from_messages(self, example_id: str, + messages: list[dict], + decision_idx: int, + better_response: str, + reason: str = "", + memories: list[str] | None = None, + max_len: int = 8192, + device: str = "cuda:0") -> TrainingExample: + """Build a training example from a chat message list. + + Args: + example_id: unique identifier + messages: list of {"role": ..., "content": ...} dicts + decision_idx: index of the assistant message to replace + better_response: the improved response text + reason: why this is a training target + memories: memory keys that were in context + max_len: maximum sequence length + device: target device + + Returns: + Tokenized TrainingExample + """ + # Context: all messages up to (not including) the decision + context_messages = messages[:decision_idx] + context_text = self.tokenizer.apply_chat_template( + context_messages, tokenize=False, add_generation_prompt=True) + + # Build the example + example = TrainingExample( + id=example_id, + context=context_text, + continuation=better_response, + reason=reason, + memories=memories or [], + ) + + return self.prepare_example(example, max_len=max_len, device=device) diff --git a/training/apollo_plugin/export_hook.py b/training/vllm_export_hook.py similarity index 76% rename from training/apollo_plugin/export_hook.py rename to training/vllm_export_hook.py index e0ff6fc..6a0bf1e 100644 --- a/training/apollo_plugin/export_hook.py +++ b/training/vllm_export_hook.py @@ -1,12 +1,17 @@ """Monkey-patch vLLM to export weight IPC handles on startup. -Usage — install the apollo_plugin package: +Usage — add to start_vllm.sh BEFORE the vllm serve command: - pip install -e /path/to/training + export VLLM_PLUGINS=vllm_export_hook + vllm serve Qwen/Qwen3.5-27B ... -Then vLLM auto-discovers and loads via entry point. Or filter: +Or use Python to launch vLLM with the hook: - VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ... + python3 -c " + import vllm_export_hook # installs the patch + from vllm.entrypoints.openai.api_server import run_server + run_server(...) + " The hook patches vLLM's model runner to export IPC handles after model loading completes. The handles are saved to a file that the @@ -20,7 +25,7 @@ from pathlib import Path HANDLE_PATH = "/tmp/vllm_weight_handles.pt" -def export_model_weights(model, model_path: str | None = None): +def export_model_weights(model): """Export CUDA IPC handles for all model parameters.""" from torch.multiprocessing.reductions import reduce_tensor @@ -38,12 +43,6 @@ def export_model_weights(model, model_path: str | None = None): } total_bytes += param.nelement() * param.element_size() - # Include metadata for training worker - handles['__metadata__'] = { - 'model_path': model_path, - 'num_params': len(handles), - } - torch.save(handles, HANDLE_PATH) print(f"[apollo] Exported {len(handles)} weight handles " f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}") @@ -64,11 +63,14 @@ def _patch_model_runner(): def patched_load(self, *args, **kwargs): result = original_load(self, *args, **kwargs) try: - model_path = self.vllm_config.model_config.model - export_model_weights(self.model_runner.model, model_path) + export_model_weights(self.model_runner.model) except Exception as e: print(f"[apollo] Failed to export weights: {e}") return result gpu_worker.Worker.load_model = patched_load print("[apollo] Weight export hook installed") + + +# Auto-install when imported +_patch_model_runner() diff --git a/training/apollo_plugin/weight_mapping.py b/training/weight_mapping.py similarity index 100% rename from training/apollo_plugin/weight_mapping.py rename to training/weight_mapping.py