Compare commits

..

No commits in common. "592a3e2e52aae0f54f5a80617583fc5d144e04a4" and "460394750641cc6a6b6d696062a5b787720b3292" have entirely different histories.

41 changed files with 1973 additions and 3255 deletions

199
Cargo.lock generated
View file

@ -492,12 +492,11 @@ dependencies = [
"http-body-util",
"hyper",
"hyper-util",
"json-five",
"json5",
"libc",
"log",
"memchr",
"memmap2",
"notify-debouncer-mini",
"paste",
"peg",
"ratatui",
@ -1089,15 +1088,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2"
dependencies = [
"libc",
]
[[package]]
name = "futures"
version = "0.3.32"
@ -1463,26 +1453,6 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199"
dependencies = [
"bitflags 2.11.0",
"inotify-sys",
"libc",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "instability"
version = "0.3.12"
@ -1561,16 +1531,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "json-five"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "865f2d01a4549c1fd8c60640c03ae5249eb374cd8cde8b905628d4b1af95c87c"
dependencies = [
"serde",
"unicode-general-category",
]
[[package]]
name = "json5"
version = "1.3.1"
@ -1592,26 +1552,6 @@ dependencies = [
"thiserror 2.0.18",
]
[[package]]
name = "kqueue"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a"
dependencies = [
"kqueue-sys",
"libc",
]
[[package]]
name = "kqueue-sys"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b"
dependencies = [
"bitflags 1.3.2",
"libc",
]
[[package]]
name = "lab"
version = "0.11.0"
@ -1834,45 +1774,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "notify"
version = "8.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3"
dependencies = [
"bitflags 2.11.0",
"fsevent-sys",
"inotify",
"kqueue",
"libc",
"log",
"mio",
"notify-types",
"walkdir",
"windows-sys 0.60.2",
]
[[package]]
name = "notify-debouncer-mini"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17849edfaabd9a5fef1c606d99cfc615a8e99f7ac4366406d86c7942a3184cf2"
dependencies = [
"log",
"notify",
"notify-types",
"tempfile",
]
[[package]]
name = "notify-types"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a"
dependencies = [
"bitflags 2.11.0",
]
[[package]]
name = "num-conv"
version = "0.2.1"
@ -3483,12 +3384,6 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-general-category"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b993bddc193ae5bd0d623b49ec06ac3e9312875fdae725a975c51db1cc1677f"
[[package]]
name = "unicode-ident"
version = "1.0.24"
@ -3899,16 +3794,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.60.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
dependencies = [
"windows-targets 0.53.5",
"windows-targets",
]
[[package]]
@ -3926,31 +3812,14 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.53.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
dependencies = [
"windows-link",
"windows_aarch64_gnullvm 0.53.1",
"windows_aarch64_msvc 0.53.1",
"windows_i686_gnu 0.53.1",
"windows_i686_gnullvm 0.53.1",
"windows_i686_msvc 0.53.1",
"windows_x86_64_gnu 0.53.1",
"windows_x86_64_gnullvm 0.53.1",
"windows_x86_64_msvc 0.53.1",
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
@ -3959,96 +3828,48 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_aarch64_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnu"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_i686_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnu"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "windows_x86_64_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
[[package]]
name = "wit-bindgen"
version = "0.51.0"

View file

@ -29,8 +29,7 @@ log = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
json-five = "0.3"
notify-debouncer-mini = "0.7"
json5 = "1.3"
ratatui = { version = "0.30", features = ["unstable-rendered-line-info"] }
tui-markdown = { git = "https://github.com/koverstreet/tui-markdown", subdirectory = "tui-markdown" }

View file

@ -92,7 +92,7 @@ pub struct NodeLeaf {
body: NodeBody,
#[serde(skip)]
token_ids: Vec<u32>,
timestamp: DateTime<Utc>,
timestamp: Option<DateTime<Utc>>,
}
impl<'de> Deserialize<'de> for NodeLeaf {
@ -100,7 +100,7 @@ impl<'de> Deserialize<'de> for NodeLeaf {
#[derive(Deserialize)]
struct Raw {
body: NodeBody,
timestamp: DateTime<Utc>,
timestamp: Option<DateTime<Utc>>,
}
let raw = Raw::deserialize(deserializer)?;
let token_ids = if raw.body.is_prompt_visible() {
@ -119,7 +119,6 @@ pub enum AstNode {
Branch {
role: Role,
children: Vec<AstNode>,
timestamp: DateTime<Utc>,
/// Per-response memory attribution from full scoring matrix.
/// Maps memory key → divergence score for this response.
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
@ -253,18 +252,18 @@ impl NodeLeaf {
} else {
vec![]
};
Self { body, token_ids, timestamp: Utc::now() }
Self { body, token_ids, timestamp: None }
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.timestamp = ts;
self.timestamp = Some(ts);
self
}
pub fn body(&self) -> &NodeBody { &self.body }
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
pub fn tokens(&self) -> usize { self.token_ids.len() }
pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp }
}
impl AstNode {
@ -308,14 +307,13 @@ impl AstNode {
// -- Branch constructors --------------------------------------------------
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
Self::Branch { role, children, memory_scores: Default::default() }
}
pub fn system_msg(text: impl Into<String>) -> Self {
Self::Branch {
role: Role::System,
children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(),
}
}
@ -324,7 +322,6 @@ impl AstNode {
Self::Branch {
role: Role::User,
children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(),
}
}
@ -341,10 +338,9 @@ impl AstNode {
};
Self::Leaf(NodeLeaf { token_ids, ..leaf })
}
Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
Self::Branch { role, children, memory_scores, .. } => Self::Branch {
role,
children: children.into_iter().map(|c| c.retokenize()).collect(),
timestamp,
memory_scores,
},
}
@ -352,8 +348,8 @@ impl AstNode {
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
match &mut self {
Self::Leaf(leaf) => leaf.timestamp = ts,
Self::Branch { timestamp, .. } => *timestamp = ts,
Self::Leaf(leaf) => leaf.timestamp = Some(ts),
Self::Branch { .. } => {}
}
self
}
@ -374,7 +370,7 @@ impl AstNode {
/// Short label for the UI.
pub fn label(&self) -> String {
let app = crate::config::app();
let cfg = crate::config::get();
match self {
Self::Branch { role, children, .. } => {
let preview = children.first()
@ -383,8 +379,8 @@ impl AstNode {
.unwrap_or_default();
match role {
Role::System => "system".into(),
Role::User => format!("{}: {}", app.user_name, preview),
Role::Assistant => format!("{}: {}", app.assistant_name, preview),
Role::User => format!("{}: {}", cfg.user_name, preview),
Role::Assistant => format!("{}: {}", cfg.assistant_name, preview),
}
}
Self::Leaf(leaf) => match &leaf.body {
@ -992,10 +988,7 @@ impl ContextState {
}
pub fn context_window() -> usize {
let app = crate::config::app();
app.backends.get(&app.default_backend)
.and_then(|b| b.context_window)
.unwrap_or(128_000)
crate::config::get().api_context_window
}
pub fn context_budget_tokens() -> usize {
@ -1347,35 +1340,4 @@ mod tests {
assert_token_invariants(node);
assert!(node.tokens() > 0);
}
// -- Timestamp deserialization tests ------------------------------------------
#[test]
fn test_timestamp_null_rejected() {
// Missing/null timestamps used to be accepted via a lenient
// deserialize fallback. Post-migration the schema is strict.
let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
assert!(serde_json::from_str::<AstNode>(json).is_err());
}
#[test]
fn test_timestamp_missing_rejected() {
let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
assert!(serde_json::from_str::<AstNode>(json).is_err());
}
#[test]
fn test_branch_timestamp_missing_rejected() {
let json = r#"{"Branch":{"role":"User","children":[]}}"#;
assert!(serde_json::from_str::<AstNode>(json).is_err());
}
#[test]
fn test_timestamp_present_accepted() {
let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#;
let node: AstNode = serde_json::from_str(json).unwrap();
let leaf = node.leaf().unwrap();
assert_eq!(leaf.timestamp().to_rfc3339(),
"2026-04-16T12:00:00+00:00");
}
}

View file

@ -139,6 +139,7 @@ impl DispatchState {
pub struct Agent {
pub client: ApiClient,
pub app_config: crate::config::AppConfig,
pub prompt_file: String,
pub session_id: String,
pub context: crate::Mutex<ContextState>,
pub state: crate::Mutex<AgentState>,
@ -188,6 +189,7 @@ impl Agent {
client: ApiClient,
personality: Vec<(String, String)>,
app_config: crate::config::AppConfig,
prompt_file: String,
conversation_log: Option<ConversationLog>,
active_tools: tools::ActiveTools,
agent_tools: Vec<tools::Tool>,
@ -218,6 +220,7 @@ impl Agent {
let agent = Arc::new(Self {
client,
app_config,
prompt_file,
session_id,
context: crate::Mutex::new(context),
state: crate::Mutex::new(AgentState {
@ -256,6 +259,7 @@ impl Agent {
Arc::new(Self {
client: self.client.clone(),
app_config: self.app_config.clone(),
prompt_file: self.prompt_file.clone(),
session_id: self.session_id.clone(),
context: crate::Mutex::new(ctx),
state: crate::Mutex::new(AgentState {

View file

@ -183,8 +183,8 @@ fn resolve_prompt(
state: &std::collections::BTreeMap<String, String>,
recently_written: &[String],
) -> String {
let template = template.replace("{assistant_name}",
&crate::config::app().assistant_name);
let cfg = crate::config::get();
let template = template.replace("{assistant_name}", &cfg.assistant_name);
let mut result = String::with_capacity(template.len());
let mut rest = template.as_str();
while let Some(start) = rest.find("{{") {
@ -247,20 +247,25 @@ impl AutoAgent {
&mut self,
bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
) -> Result<(), String> {
// Load system prompt + identity from config.
let config = crate::config::get();
let base_url = config.api_base_url.as_deref().unwrap_or("");
let api_key = config.api_key.as_deref().unwrap_or("");
let model = config.api_model.as_deref().unwrap_or("");
if base_url.is_empty() || model.is_empty() {
return Err("API not configured (no base_url or model)".to_string());
}
let client = super::api::ApiClient::new(base_url, api_key, model);
// Load system prompt + identity from config
let cli = crate::user::CliArgs::default();
let (app, _) = crate::config::load_app(&cli)
.map_err(|e| format!("config: {}", e))?;
let resolved = app.resolve_model(&app.default_backend)
.map_err(|e| format!("API not configured: {}", e))?;
let client = super::api::ApiClient::new(
&resolved.api_base, &resolved.api_key, &resolved.model_id);
let personality = crate::config::reload_context()
.await.map_err(|e| format!("config: {}", e))?;
let agent = Agent::new(
client, personality,
app,
app, String::new(),
None,
super::tools::ActiveTools::new(),
super::tools::tools(),
@ -492,20 +497,15 @@ pub async fn run_one_agent(
.map(|s| s.phase.clone()).collect();
// Bail check: if the agent defines a bail script, run it between steps.
// The script also refreshes our pid-file with the current phase — that's
// how concurrent agents know which phase each of us is in.
let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name));
let state_dir_for_bail = state_dir.clone();
// Find our own pid file so we can pass it to the bail script
let our_pid = std::process::id();
let our_pid_file = format!("pid-{}", our_pid);
let step_phases_for_bail = step_phases.clone();
let bail_fn = move |step_idx: usize| -> Result<(), String> {
if let Some(ref script) = bail_script {
let phase = step_phases_for_bail.get(step_idx)
.map(String::as_str).unwrap_or("");
let status = std::process::Command::new(script)
.arg(&our_pid_file)
.arg(phase)
.current_dir(&state_dir_for_bail)
.status()
.map_err(|e| format!("bail script {:?} failed: {}", script, e))?;

View file

@ -1,180 +0,0 @@
// fix-timestamps: One-off migration for ~/.consciousness/agent-sessions/
// conversation.jsonl.
//
// Before Branch nodes carried their own timestamps, early entries were
// serialized with missing/null timestamp fields — they deserialize as
// UNIX_EPOCH via the (now-to-be-removed) deserialize_timestamp_or_epoch
// fallback. Training needs every entry to have a unique timestamp to
// dedup already-trained responses.
//
// Walks the file, synthesizes timestamps for any entry stuck at
// UNIX_EPOCH by linear interpolation between surrounding real
// timestamps. For child leaves inside a Branch, derives timestamps
// from the parent with a tiny per-child offset.
//
// SAFETY: reads from argv[1], writes to argv[1].tmp, renames into
// place. Keep a .bak copy before running.
//
// Usage: fix-timestamps <path-to-conversation.jsonl>
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::PathBuf;
use anyhow::{Context, Result};
use chrono::{DateTime, Duration, Utc};
use consciousness::agent::context::AstNode;
fn main() -> Result<()> {
let path: PathBuf = std::env::args().nth(1)
.context("usage: fix-timestamps <path>")?.into();
let f = std::fs::File::open(&path)
.with_context(|| format!("open {}", path.display()))?;
let reader = BufReader::new(f);
let mut nodes: Vec<AstNode> = Vec::new();
for (i, line) in reader.lines().enumerate() {
let line = line?;
if line.trim().is_empty() { continue; }
let node: AstNode = serde_json::from_str(&line)
.with_context(|| format!("line {}: parse", i + 1))?;
nodes.push(node);
}
println!("read {} entries", nodes.len());
fix_top_level_timestamps(&mut nodes);
for node in &mut nodes {
propagate_to_children(node);
}
// Ensure uniqueness — real timestamps can collide when two entries
// were written in the same ns; synthesized ones can also overlap.
// Bump colliding ns by 1 until unique.
let mut seen = std::collections::HashSet::new();
let mut bumps = 0usize;
for (i, node) in nodes.iter_mut().enumerate() {
let ts = top_ts(node);
assert!(ts > DateTime::<Utc>::UNIX_EPOCH,
"entry {}: still UNIX_EPOCH", i);
let mut ns = ts.timestamp_nanos_opt().expect("ts in i64 ns range");
let mut bumped = false;
while !seen.insert(ns) {
ns += 1;
bumped = true;
bumps += 1;
}
if bumped {
set_top_ts(node, DateTime::<Utc>::from_timestamp_nanos(ns));
}
}
println!("all {} timestamps real and unique ({} ns bumps)",
nodes.len(), bumps);
let tmp = path.with_extension("jsonl.tmp");
{
let f = std::fs::File::create(&tmp)
.with_context(|| format!("create {}", tmp.display()))?;
let mut w = BufWriter::new(f);
for node in &nodes {
serde_json::to_writer(&mut w, node)?;
w.write_all(b"\n")?;
}
w.flush()?;
}
std::fs::rename(&tmp, &path)
.with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?;
println!("wrote {}", path.display());
Ok(())
}
fn top_ts(node: &AstNode) -> DateTime<Utc> {
match node {
AstNode::Leaf(leaf) => leaf.timestamp(),
AstNode::Branch { timestamp, .. } => *timestamp,
}
}
fn set_top_ts(node: &mut AstNode, ts: DateTime<Utc>) {
match node {
AstNode::Leaf(leaf) => *leaf = leaf.clone().with_timestamp(ts),
AstNode::Branch { timestamp, .. } => *timestamp = ts,
}
}
/// Fill in missing top-level timestamps. Strategy:
/// - If two real timestamps bracket a run of missing ones, linearly
/// interpolate between them.
/// - If missing ones precede the first real one, back-fill using
/// (first_real - N·1µs).
/// - If missing ones follow the last real one, forward-fill.
/// - If no real timestamps exist at all, synthesize from now() going
/// backwards.
fn fix_top_level_timestamps(nodes: &mut [AstNode]) {
let real: Vec<(usize, DateTime<Utc>)> = nodes.iter().enumerate()
.filter(|(_, n)| top_ts(n) > DateTime::<Utc>::UNIX_EPOCH)
.map(|(i, n)| (i, top_ts(n)))
.collect();
if real.is_empty() {
let now = Utc::now();
let len = nodes.len();
for (i, node) in nodes.iter_mut().enumerate() {
let ts = now - Duration::microseconds((len - i) as i64);
set_top_ts(node, ts);
}
return;
}
// Helper: bisect real[] for the nearest real entries around idx.
let find_bracket = |idx: usize| -> (Option<(usize, DateTime<Utc>)>,
Option<(usize, DateTime<Utc>)>) {
let pos = real.binary_search_by_key(&idx, |(i, _)| *i);
let (prior_pos, next_pos) = match pos {
Ok(p) => (Some(p), Some(p)),
Err(p) => (
if p == 0 { None } else { Some(p - 1) },
if p >= real.len() { None } else { Some(p) },
),
};
(prior_pos.map(|p| real[p]), next_pos.map(|p| real[p]))
};
for i in 0..nodes.len() {
if top_ts(&nodes[i]) > DateTime::<Utc>::UNIX_EPOCH {
continue;
}
let (prior, next) = find_bracket(i);
let new_ts = match (prior, next) {
(Some((pi, pt)), Some((ni, nt))) if pi != ni => {
// Linear interpolate.
let span_ns = (nt - pt).num_nanoseconds().unwrap_or(0);
let offset_ns = span_ns * (i - pi) as i64 / (ni - pi) as i64;
pt + Duration::nanoseconds(offset_ns)
}
(Some((pi, pt)), _) => {
pt + Duration::microseconds((i - pi) as i64)
}
(None, Some((ni, nt))) => {
nt - Duration::microseconds((ni - i) as i64)
}
(None, None) => unreachable!(),
};
set_top_ts(&mut nodes[i], new_ts);
}
}
/// For every Branch, ensure each child Leaf has a timestamp. If missing,
/// use parent.ts + child_idx·1ns so siblings stay unique but close.
fn propagate_to_children(node: &mut AstNode) {
if let AstNode::Branch { timestamp, children, .. } = node {
let parent_ts = *timestamp;
for (ci, child) in children.iter_mut().enumerate() {
if top_ts(child) <= DateTime::<Utc>::UNIX_EPOCH {
set_top_ts(child, parent_ts + Duration::nanoseconds(ci as i64));
}
propagate_to_children(child);
}
}
}

View file

@ -197,7 +197,7 @@ pub async fn cmd_load_context(stats: bool) -> Result<()> {
return Ok(());
}
println!("=== MEMORY SYSTEM ({}) ===", crate::config::app().assistant_name);
println!("=== MEMORY SYSTEM ({}) ===", cfg.assistant_name);
if !personality.is_empty() {
println!("--- personality_nodes ({}) ---", personality.len());

View file

@ -3,6 +3,9 @@
// Single config file: ~/.consciousness/config.json5
// Memory settings in the "memory" section (Config)
// Agent/backend settings at top level (AppConfig)
//
// Legacy fallback: ~/.consciousness/config.jsonl
// Env override: POC_MEMORY_CONFIG
use std::collections::HashMap;
use std::path::PathBuf;
@ -26,7 +29,9 @@ pub fn config_path() -> PathBuf {
static CONFIG: OnceLock<RwLock<Arc<Config>>> = OnceLock::new();
fn default_context_window() -> usize { 128_000 }
fn default_stream_timeout() -> u64 { 60 }
fn default_scoring_chunk_tokens() -> usize { 50_000 }
fn default_scoring_interval_secs() -> u64 { 3600 } // 1 hour
fn default_scoring_response_window() -> usize { 100 }
fn default_node_weight() -> f64 { 0.7 }
@ -40,6 +45,8 @@ fn default_identity_dir() -> PathBuf {
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
pub user_name: String,
pub assistant_name: String,
#[serde(deserialize_with = "deserialize_path")]
pub data_dir: PathBuf,
#[serde(default = "default_identity_dir", deserialize_with = "deserialize_path")]
@ -55,24 +62,51 @@ pub struct Config {
/// Nodes loaded into subconscious agent context
#[serde(default)]
pub agent_nodes: Vec<String>,
pub journal_days: u32,
pub journal_max: usize,
pub llm_concurrency: usize,
pub agent_budget: usize,
#[serde(deserialize_with = "deserialize_path")]
pub prompts_dir: PathBuf,
/// Resolved from agent_model → models → backend (not in config directly)
#[serde(skip)]
pub api_base_url: Option<String>,
#[serde(skip)]
pub api_key: Option<String>,
#[serde(skip)]
pub api_model: Option<String>,
#[serde(skip, default = "default_context_window")]
pub api_context_window: usize,
/// Used to resolve API settings, not stored on Config
#[serde(default)]
agent_model: Option<String>,
/// Stream chunk timeout in seconds (no data = timeout).
#[serde(default = "default_stream_timeout")]
pub api_stream_timeout_secs: u64,
/// Max tokens per chunk for memory scoring logprobs calls.
#[serde(default = "default_scoring_chunk_tokens")]
pub scoring_chunk_tokens: usize,
/// How often to re-score memory nodes (seconds). Default: 3600 (1 hour).
#[serde(default = "default_scoring_interval_secs")]
pub scoring_interval_secs: u64,
/// Number of assistant responses to score per memory. Default: 50.
#[serde(default = "default_scoring_response_window")]
pub scoring_response_window: usize,
pub api_reasoning: String,
pub agent_types: Vec<String>,
#[serde(default)]
pub mcp_servers: Vec<McpServerConfig>,
#[serde(default)]
pub lsp_servers: Vec<LspServerConfig>,
/// Surface agent timeout in seconds.
#[serde(default)]
pub surface_timeout_secs: Option<u32>,
/// Max conversation bytes to include in surface agent context.
#[serde(default)]
pub surface_conversation_bytes: Option<usize>,
/// Hook events that trigger the surface agent.
#[serde(default)]
pub surface_hooks: Vec<String>,
// Spreading activation parameters
#[serde(default = "default_node_weight")]
@ -89,21 +123,36 @@ impl Default for Config {
fn default() -> Self {
let home = dirs::home_dir().unwrap_or_default();
Self {
user_name: "User".to_string(),
assistant_name: "Assistant".to_string(),
data_dir: home.join(".consciousness/memory"),
identity_dir: home.join(".consciousness/identity"),
projects_dir: home.join(".claude/projects"),
protected_nodes: Vec::new(),
personality_nodes: vec!["identity".into(), "core-practices".into()],
agent_nodes: vec!["identity".into(), "core-practices".into()],
journal_days: 7,
journal_max: 20,
llm_concurrency: 1,
agent_budget: 1000,
prompts_dir: home.join(".consciousness/prompts"),
api_base_url: None,
api_key: None,
api_model: None,
api_context_window: default_context_window(),
api_stream_timeout_secs: default_stream_timeout(),
scoring_chunk_tokens: default_scoring_chunk_tokens(),
scoring_interval_secs: default_scoring_interval_secs(),
scoring_response_window: default_scoring_response_window(),
agent_model: None,
api_reasoning: "high".to_string(),
agent_types: vec![
"linker".into(), "organize".into(), "distill".into(),
"separator".into(), "split".into(),
],
surface_timeout_secs: None,
surface_conversation_bytes: None,
surface_hooks: vec![],
mcp_servers: vec![],
lsp_servers: vec![],
default_node_weight: default_node_weight(),
@ -116,20 +165,41 @@ impl Default for Config {
impl Config {
fn load_from_file() -> Self {
Self::try_load_shared().unwrap_or_default()
if let Some(config) = Self::try_load_shared() {
return config;
}
Self::load_legacy_jsonl()
}
/// Load from shared config. Memory settings in the "memory" section;
/// API settings resolved from models + backend configuration.
fn try_load_shared() -> Option<Self> {
let content = std::fs::read_to_string(config_path()).ok()?;
let root: serde_json::Value = json_five::from_str(&content).ok()?;
let root: serde_json::Value = json5::from_str(&content).ok()?;
let mem_value = root.get("memory")?;
let mut config: Config = serde_json::from_value(mem_value.clone()).ok()?;
config.llm_concurrency = config.llm_concurrency.max(1);
// Top-level sections (not inside "memory").
// Resolve API settings: agent_model → models → backend
if let Some(model_name) = &config.agent_model
&& let Some(model_cfg) = root.get("models").and_then(|m| m.get(model_name.as_str())) {
let backend_name = model_cfg.get("backend").and_then(|v| v.as_str()).unwrap_or("");
let model_id = model_cfg.get("model_id").and_then(|v| v.as_str()).unwrap_or("");
if let Some(backend) = root.get(backend_name) {
config.api_base_url = backend.get("base_url")
.and_then(|v| v.as_str()).map(String::from);
config.api_key = backend.get("api_key")
.and_then(|v| v.as_str()).map(String::from);
}
config.api_model = Some(model_id.to_string());
if let Some(cw) = model_cfg.get("context_window").and_then(|v| v.as_u64()) {
config.api_context_window = cw as usize;
}
}
// Top-level config sections (not inside "memory")
if let Some(servers) = root.get("lsp_servers") {
config.lsp_servers = serde_json::from_value(servers.clone()).unwrap_or_default();
}
@ -139,6 +209,11 @@ impl Config {
Some(config)
}
/// Load from legacy JSONL config — deprecated, just return defaults.
fn load_legacy_jsonl() -> Self {
Config::default()
}
}
/// Get the global memory config (cheap Arc clone).
@ -162,85 +237,27 @@ pub fn reload() -> bool {
changed
}
/// Spawn a background thread that watches `~/.consciousness/config.json5`
/// and reloads both the memory Config and the global AppConfig whenever
/// the file changes on disk. Lets edits from vim / F6 hotkeys / manual
/// tweaks land live without restarting the process.
pub fn watch_config(cli: crate::user::CliArgs) {
use notify_debouncer_mini::{new_debouncer, notify::RecursiveMode};
let path = config_path();
// Watch the parent directory — editors often replace-via-rename, so
// watching the file itself misses the new inode.
let Some(parent) = path.parent().map(|p| p.to_path_buf()) else {
crate::dbglog!("[config] no parent for {}, skipping watch", path.display());
return;
};
std::thread::Builder::new()
.name("config-watcher".into())
.spawn(move || {
let (tx, rx) = std::sync::mpsc::channel();
let mut debouncer = match new_debouncer(std::time::Duration::from_millis(200), tx) {
Ok(d) => d,
Err(e) => {
crate::dbglog!("[config] watcher setup failed: {}", e);
return;
}
};
if let Err(e) = debouncer.watcher()
.watch(&parent, RecursiveMode::NonRecursive)
{
crate::dbglog!("[config] watch({}) failed: {}", parent.display(), e);
return;
}
crate::dbglog!("[config] watching {}", path.display());
while let Ok(res) = rx.recv() {
let Ok(events) = res else { continue; };
if !events.iter().any(|e| e.path == path) { continue; }
// Reload both halves.
let mem_changed = reload();
let app_changed = match build_figment(&cli).extract::<AppConfig>() {
Ok(app) => {
install_app(app);
true
}
Err(e) => {
crate::dbglog!("[config] reload: AppConfig parse failed: {}", e);
false
}
};
crate::dbglog!("[config] reloaded (memory_changed={}, app_changed={})",
mem_changed, app_changed);
}
})
.ok();
}
// ============================================================
// Agent config (top-level settings)
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
#[serde(default = "default_user_name")]
pub user_name: String,
#[serde(default = "default_assistant_name")]
pub assistant_name: String,
/// Named model endpoints — credentials, base URL, and model id bundled
/// into one entry per backend. Keyed by name, selected by
/// `default_backend` or by `--model <name>` on the CLI.
pub backend: String,
pub anthropic: BackendConfig,
pub openrouter: BackendConfig,
#[serde(default)]
pub backends: HashMap<String, BackendConfig>,
#[serde(default)]
pub default_backend: String,
pub deepinfra: BackendConfig,
pub prompts: PromptConfig,
pub debug: bool,
pub compaction: CompactionConfig,
pub dmn: DmnConfig,
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_project: Option<PathBuf>,
#[serde(default)]
pub learn: LearnConfig,
pub models: HashMap<String, ModelConfig>,
#[serde(default = "default_model_name")]
pub default_model: String,
#[serde(default)]
pub mcp_servers: Vec<McpServerConfig>,
#[serde(default)]
@ -267,17 +284,32 @@ pub struct LspServerConfig {
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendConfig {
/// API key for the backend.
#[serde(default)]
pub api_key: String,
/// Base URL for the backend's OpenAI-compatible endpoint.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
/// Model identifier sent to the API.
pub model_id: String,
/// Context window size in tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub context_window: Option<usize>,
}
impl BackendConfig {
fn resolve(&self, default_base: &str) -> Result<(String, String, String)> {
if self.api_key.is_empty() {
anyhow::bail!(
"No API key. Set it in {} or use --api-key",
config_path().display()
);
}
let base = self.base_url.clone()
.unwrap_or_else(|| default_base.to_string());
Ok((base, self.api_key.clone(), self.model.clone()))
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptConfig {
pub anthropic: String,
pub other: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -292,57 +324,65 @@ pub struct DmnConfig {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnConfig {
/// Divergence threshold — responses scoring above this become
/// fine-tuning candidates. Lower = more sensitive.
#[serde(default = "default_learn_threshold")]
pub threshold: f64,
/// Whether to generate "what would the model have said without
/// memories" alternates alongside each scoring run. Expensive —
/// one full streaming generation per candidate.
pub struct ModelConfig {
/// Backend name ("anthropic" or "openrouter")
pub backend: String,
/// Model identifier sent to the API
pub model_id: String,
/// Instruction file ("CLAUDE.md" or "POC.md").
#[serde(default)]
pub generate_alternates: bool,
pub prompt_file: Option<String>,
/// Context window size in tokens.
#[serde(default)]
pub context_window: Option<usize>,
}
fn default_learn_threshold() -> f64 { 1.0 }
impl Default for LearnConfig {
fn default() -> Self {
Self {
threshold: default_learn_threshold(),
generate_alternates: false,
}
}
}
fn default_user_name() -> String { "User".into() }
fn default_assistant_name() -> String { "Assistant".into() }
impl Default for AppConfig {
fn default() -> Self {
Self {
user_name: default_user_name(),
assistant_name: default_assistant_name(),
backends: HashMap::new(),
default_backend: String::new(),
backend: "openrouter".to_string(),
anthropic: BackendConfig {
api_key: String::new(),
model: "claude-opus-4-6-20250918".to_string(),
base_url: None,
},
openrouter: BackendConfig {
api_key: String::new(),
model: "qwen/qwen3.5-397b-a17b".to_string(),
base_url: Some("https://openrouter.ai/api/v1".to_string()),
},
deepinfra: BackendConfig {
api_key: String::new(),
model: String::new(),
base_url: Some("https://api.deepinfra.com/v1/openai".to_string()),
},
prompts: PromptConfig {
anthropic: "CLAUDE.md".to_string(),
other: "POC.md".to_string(),
},
debug: false,
compaction: CompactionConfig {
hard_threshold_pct: 90,
soft_threshold_pct: 80,
},
dmn: DmnConfig { max_turns: 20 },
learn: LearnConfig::default(),
memory_project: None,
models: HashMap::new(),
default_model: String::new(),
mcp_servers: Vec::new(),
lsp_servers: Vec::new(),
}
}
}
fn default_model_name() -> String { String::new() }
/// Resolved, ready-to-use agent session config.
pub struct SessionConfig {
pub api_base: String,
pub api_key: String,
pub model: String,
pub prompt_file: String,
/// Identity/personality nodes as (name, content) pairs.
pub context_parts: Vec<(String, String)>,
pub session_dir: PathBuf,
@ -358,21 +398,36 @@ pub struct ResolvedModel {
pub api_base: String,
pub api_key: String,
pub model_id: String,
pub prompt_file: String,
pub context_window: Option<usize>,
}
impl AppConfig {
/// Resolve the active backend and assemble prompts into a SessionConfig.
pub async fn resolve(&self, cli: &crate::user::CliArgs) -> Result<SessionConfig> {
if self.backends.is_empty() {
anyhow::bail!(
"no backends configured in {}. Add a `backends` section with at least one entry.",
config_path().display()
);
}
let (api_base, api_key, model, prompt_file);
let name = cli.model.as_deref().unwrap_or(&self.default_backend);
let resolved = self.resolve_model(name)?;
if !self.models.is_empty() {
let model_name = cli.model.as_deref().unwrap_or(&self.default_model);
let resolved = self.resolve_model(model_name)?;
api_base = resolved.api_base;
api_key = resolved.api_key;
model = resolved.model_id;
prompt_file = resolved.prompt_file;
} else {
let (base, key, mdl) = match self.backend.as_str() {
"anthropic" => self.anthropic.resolve("https://api.anthropic.com"),
_ => self.openrouter.resolve("https://openrouter.ai/api/v1"),
}?;
api_base = base;
api_key = key;
model = mdl;
prompt_file = if self.backend == "anthropic" {
self.prompts.anthropic.clone()
} else {
self.prompts.other.clone()
};
}
let personality_nodes = get().personality_nodes.clone();
let context_parts = crate::mind::identity::personality_nodes(&personality_nodes).await;
@ -383,13 +438,11 @@ impl AppConfig {
std::fs::create_dir_all(&session_dir).ok();
// CLI --api-base and --api-key override everything
let api_base = cli.api_base.clone().unwrap_or(resolved.api_base);
let api_key = cli.api_key.clone().unwrap_or(resolved.api_key);
let api_base = cli.api_base.clone().unwrap_or(api_base);
let api_key = cli.api_key.clone().unwrap_or(api_key);
Ok(SessionConfig {
api_base,
api_key,
model: resolved.model_id,
api_base, api_key, model, prompt_file,
context_parts,
session_dir,
app: self.clone(),
@ -397,33 +450,55 @@ impl AppConfig {
})
}
/// Look up a named backend and resolve its credentials.
/// Look up a named model and resolve its credentials from the backend config.
pub fn resolve_model(&self, name: &str) -> Result<ResolvedModel> {
let b = self.backends.get(name)
let model = self.models.get(name)
.ok_or_else(|| anyhow::anyhow!(
"Unknown backend '{}'. Available: {}",
"Unknown model '{}'. Available: {}",
name,
self.model_names().join(", "),
))?;
let api_base = b.base_url.clone()
.ok_or_else(|| anyhow::anyhow!(
"backends.{}.base_url not set in {}",
name, config_path().display()
))?;
let (api_base, api_key) = match model.backend.as_str() {
"anthropic" => (
self.anthropic.base_url.clone()
.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
self.anthropic.api_key.clone(),
),
"deepinfra" => (
self.deepinfra.base_url.clone()
.unwrap_or_else(|| "https://api.deepinfra.com/v1/openai".to_string()),
self.deepinfra.api_key.clone(),
),
_ => (
self.openrouter.base_url.clone()
.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()),
self.openrouter.api_key.clone(),
),
};
let prompt_file = model.prompt_file.clone()
.unwrap_or_else(|| {
if model.backend == "anthropic" {
self.prompts.anthropic.clone()
} else {
self.prompts.other.clone()
}
});
Ok(ResolvedModel {
name: name.to_string(),
api_base,
api_key: b.api_key.clone(),
model_id: b.model_id.clone(),
context_window: b.context_window,
api_key,
model_id: model.model_id.clone(),
prompt_file,
context_window: model.context_window,
})
}
/// List available backend names, sorted.
/// List available model names, sorted.
pub fn model_names(&self) -> Vec<String> {
let mut names: Vec<_> = self.backends.keys().cloned().collect();
let mut names: Vec<_> = self.models.keys().cloned().collect();
names.sort();
names
}
@ -443,7 +518,7 @@ impl Provider for Json5File {
fn data(&self) -> figment::Result<figment::value::Map<figment::Profile, figment::value::Dict>> {
match std::fs::read_to_string(&self.0) {
Ok(content) => {
let value: figment::value::Value = json_five::from_str(&content)
let value: figment::value::Value = json5::from_str(&content)
.map_err(|e| figment::Error::from(format!("{}: {}", self.0.display(), e)))?;
Serialized::defaults(value).data()
}
@ -465,6 +540,11 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment {
let mut f = Figment::from(Serialized::defaults(AppConfig::default()))
.merge(Json5File(config_path()));
merge_opt!(f, cli.backend, "backend");
merge_opt!(f, cli.model, "anthropic.model", "openrouter.model");
merge_opt!(f, cli.api_key, "anthropic.api_key", "openrouter.api_key");
merge_opt!(f, cli.api_base, "anthropic.base_url", "openrouter.base_url");
merge_opt!(f, cli.memory_project, "memory_project");
merge_opt!(f, cli.dmn_max_turns, "dmn.max_turns");
if cli.debug {
f = f.merge(Serialized::default("debug", true));
@ -474,46 +554,12 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment {
}
/// Load just the AppConfig — no validation, no prompt assembly.
/// Also installs the loaded AppConfig into the global cache so
/// `config::app()` is available everywhere.
pub fn load_app(cli: &crate::user::CliArgs) -> Result<(AppConfig, Figment)> {
let figment = build_figment(cli);
let app: AppConfig = figment.extract().context("Failed to load configuration")?;
install_app(app.clone());
Ok((app, figment))
}
// ============================================================
// Global AppConfig cache (writable, for runtime-mutable settings
// like learn.threshold that F6 edits via config_writer).
// ============================================================
static APP_CONFIG: OnceLock<RwLock<AppConfig>> = OnceLock::new();
fn install_app(app: AppConfig) {
let slot = APP_CONFIG.get_or_init(|| RwLock::new(app.clone()));
*slot.write().unwrap() = app;
}
/// Current AppConfig, held under a read lock. Reads should be brief
/// (no holding across await / long work) to avoid starving writers.
/// Panics if called before load_app — which runs once at startup.
pub fn app() -> std::sync::RwLockReadGuard<'static, AppConfig> {
APP_CONFIG
.get()
.expect("config::app() called before load_app()")
.read()
.unwrap()
}
/// Mutate the cached AppConfig in place. Used by config_writer to keep
/// the in-memory view in sync with disk after surgical edits to
/// ~/.consciousness/config.json5.
pub fn update_app(f: impl FnOnce(&mut AppConfig)) {
let slot = APP_CONFIG.get().expect("update_app before load_app");
f(&mut *slot.write().unwrap());
}
/// Load the full config: figment → AppConfig → resolve backend → assemble prompts.
pub async fn load_session(cli: &crate::user::CliArgs) -> Result<(SessionConfig, Figment)> {
let (app, figment) = load_app(cli)?;
@ -539,28 +585,38 @@ pub fn show_config(app: &AppConfig, figment: &Figment) {
}
println!("# Effective configuration\n");
println!("user_name: {:?} ({})", app.user_name, src(figment, "user_name"));
println!("assistant_name: {:?} ({})", app.assistant_name, src(figment, "assistant_name"));
println!("backend: {:?} ({})", app.backend, src(figment, "backend"));
for (name, b) in [("anthropic", &app.anthropic), ("openrouter", &app.openrouter)] {
println!("\n{}:", name);
println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("{name}.api_key")));
println!(" model: {:?} ({})", b.model, src(figment, &format!("{name}.model")));
if let Some(ref url) = b.base_url {
println!(" base_url: {:?} ({})", url, src(figment, &format!("{name}.base_url")));
}
}
println!("\nprompts:");
println!(" anthropic: {:?} ({})", app.prompts.anthropic, src(figment, "prompts.anthropic"));
println!(" other: {:?} ({})", app.prompts.other, src(figment, "prompts.other"));
println!("\ndebug: {} ({})", app.debug, src(figment, "debug"));
println!("\ncompaction:");
println!(" hard_threshold_pct: {} ({})", app.compaction.hard_threshold_pct, src(figment, "compaction.hard_threshold_pct"));
println!(" soft_threshold_pct: {} ({})", app.compaction.soft_threshold_pct, src(figment, "compaction.soft_threshold_pct"));
println!("\ndmn:");
println!(" max_turns: {} ({})", app.dmn.max_turns, src(figment, "dmn.max_turns"));
println!("\ndefault_backend: {:?} ({})", app.default_backend, src(figment, "default_backend"));
if !app.backends.is_empty() {
println!("\nbackends:");
let mut names: Vec<_> = app.backends.keys().cloned().collect();
names.sort();
for name in names {
let b = &app.backends[&name];
println!(" {}:", name);
println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("backends.{name}.api_key")));
if let Some(ref url) = b.base_url {
println!(" base_url: {:?} ({})", url, src(figment, &format!("backends.{name}.base_url")));
if let Some(ref p) = app.memory_project {
println!("\nmemory_project: {:?} ({})", p, src(figment, "memory_project"));
}
println!(" model_id: {:?}", b.model_id);
if let Some(cw) = b.context_window {
println!("\ndefault_model: {:?}", app.default_model);
if !app.models.is_empty() {
println!("\nmodels:");
for (name, m) in &app.models {
println!(" {}:", name);
println!(" backend: {:?}", m.backend);
println!(" model_id: {:?}", m.model_id);
if let Some(ref pf) = m.prompt_file {
println!(" prompt_file: {:?}", pf);
}
if let Some(cw) = m.context_window {
println!(" context_window: {}", cw);
}
}

View file

@ -1,448 +0,0 @@
// config_writer.rs — Surgical edits to ~/.consciousness/config.json5
//
// Uses json-five's round-trip parser to mutate specific fields while
// preserving the surrounding comments, whitespace, and formatting.
use std::path::Path;
use anyhow::{anyhow, Context as _, Result};
use json_five::rt::parser::{
from_str, JSONKeyValuePair, JSONObjectContext, JSONValue, KeyValuePairContext,
};
use crate::config::config_path;
/// Read the config, apply `mutate` to the root JSONValue, write it back atomically.
fn edit_config<F: FnOnce(&mut JSONValue) -> Result<()>>(mutate: F) -> Result<()> {
let path = config_path();
let src = std::fs::read_to_string(&path)
.with_context(|| format!("read {}", path.display()))?;
let mut text = from_str(&src)
.map_err(|e| anyhow!("parse {}: {}", path.display(), e))?;
mutate(&mut text.value)?;
write_atomic(&path, &text.to_string())
}
fn write_atomic(path: &Path, content: &str) -> Result<()> {
let parent = path.parent()
.ok_or_else(|| anyhow!("config path has no parent: {}", path.display()))?;
let tmp = parent.join(format!(
".{}.tmp",
path.file_name().unwrap_or_default().to_string_lossy(),
));
std::fs::write(&tmp, content)
.with_context(|| format!("write {}", tmp.display()))?;
std::fs::rename(&tmp, path)
.with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?;
Ok(())
}
/// Match a key JSONValue against a string name. JSON5 allows keys to be
/// unquoted identifiers or single/double-quoted strings.
fn key_matches(key: &JSONValue, name: &str) -> bool {
match key {
JSONValue::Identifier(s)
| JSONValue::DoubleQuotedString(s)
| JSONValue::SingleQuotedString(s) => s == name,
_ => false,
}
}
/// Find (or create) a child object under `parent`, returning a mutable borrow
/// of its key_value_pairs vector.
/// Append a new kvp to `object`, setting whitespace so the output is
/// multi-line with the given indentation:
///
/// ```text
/// {<newline><inner_indent>first_key: first_val,<newline><outer_indent>}
/// ```
///
/// If `object` already has kvps, the separator between the last one and
/// ours goes in the prior kvp's wsc.3. If we're the first kvp, the
/// lead-in after `{` goes in the object's own wsc.0.
fn append_kvp_pretty(
object: &mut JSONValue,
key: JSONValue,
value: JSONValue,
inner_indent: &str,
outer_indent: &str,
) -> Result<()> {
let (pairs, ctx) = match object {
JSONValue::JSONObject { key_value_pairs, context } => {
let ctx = context.get_or_insert_with(|| JSONObjectContext {
wsc: (String::new(),),
});
(key_value_pairs, ctx)
}
_ => return Err(anyhow!("not an object")),
};
if pairs.is_empty() {
ctx.wsc.0 = format!("\n{}", inner_indent);
} else {
let prev = pairs.last_mut().unwrap();
let prev_ctx = prev.context.get_or_insert_with(|| KeyValuePairContext {
wsc: (String::new(), String::from(" "), String::new(), None),
});
prev_ctx.wsc.3 = Some(format!("\n{}", inner_indent));
}
pairs.push(JSONKeyValuePair {
key,
value,
context: Some(KeyValuePairContext {
wsc: (
String::new(),
String::from(" "),
String::new(),
Some(format!("\n{}", outer_indent)),
),
}),
});
Ok(())
}
/// Find or create a child object under `parent`. Returns the index of
/// the kvp in parent's key_value_pairs so the caller can re-borrow
/// afterward.
fn get_or_create_object_idx(
parent: &mut JSONValue,
section: &str,
inner_indent: &str,
outer_indent: &str,
) -> Result<usize> {
let existing = match parent {
JSONValue::JSONObject { key_value_pairs, .. } => {
key_value_pairs.iter()
.position(|kvp| key_matches(&kvp.key, section))
}
_ => return Err(anyhow!("config root is not an object")),
};
if let Some(i) = existing {
return Ok(i);
}
append_kvp_pretty(
parent,
JSONValue::Identifier(section.to_string()),
JSONValue::JSONObject {
key_value_pairs: Vec::new(),
context: Some(JSONObjectContext { wsc: (String::new(),) }),
},
inner_indent,
outer_indent,
)?;
match parent {
JSONValue::JSONObject { key_value_pairs, .. } => Ok(key_value_pairs.len() - 1),
_ => unreachable!(),
}
}
/// Set `section.key` to a literal scalar value (e.g., "1e-7", "42", "true").
/// The literal is parsed as JSON5 so we preserve its source-form on round-trip.
pub fn set_scalar(section: &str, key: &str, literal: &str) -> Result<()> {
let value = parse_scalar_literal(literal)?;
edit_config(|root| {
// New top-level sections sit at column 4 (inside root `{`),
// and the root's closing `}` sits at column 0.
let section_idx = get_or_create_object_idx(root, section, " ", "")?;
let section_value = match root {
JSONValue::JSONObject { key_value_pairs, .. } => {
&mut key_value_pairs[section_idx].value
}
_ => unreachable!(),
};
// Update in place if the key already exists.
if let JSONValue::JSONObject { key_value_pairs, .. } = section_value {
if let Some(kvp) = key_value_pairs.iter_mut()
.find(|k| key_matches(&k.key, key))
{
kvp.value = value;
return Ok(());
}
}
// Append a new kvp. Inner keys sit at column 8, the section's
// closing `}` sits at column 4.
append_kvp_pretty(
section_value,
JSONValue::Identifier(key.to_string()),
value,
" ",
" ",
)
})
}
/// Parse a scalar literal by round-tripping it through json-five. Keeps us
/// consistent with whatever scalars the library considers valid (hex,
/// exponents, Infinity, etc.).
fn parse_scalar_literal(literal: &str) -> Result<JSONValue> {
let text = from_str(literal)
.map_err(|e| anyhow!("parse literal {:?}: {}", literal, e))?;
match text.value {
JSONValue::JSONObject { .. } | JSONValue::JSONArray { .. } => {
Err(anyhow!("set_scalar only accepts scalar literals, got {:?}", literal))
}
v => Ok(v),
}
}
/// Convenience: set `learn.threshold` to the given f64.
pub fn set_learn_threshold(value: f64) -> Result<()> {
// {:e} gives the minimal scientific notation that preserves the value.
set_scalar("learn", "threshold", &format!("{:e}", value))?;
crate::config::update_app(|app| app.learn.threshold = value);
Ok(())
}
/// Convenience: set `learn.generate_alternates` to the given bool.
pub fn set_learn_generate_alternates(value: bool) -> Result<()> {
set_scalar("learn", "generate_alternates",
if value { "true" } else { "false" })?;
crate::config::update_app(|app| app.learn.generate_alternates = value);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// In-memory variant of set_scalar — used to test the mutation logic
// without touching disk.
fn set_scalar_inline(
root: &mut JSONValue,
section: &str,
key: &str,
literal: &str,
) -> Result<()> {
let value = parse_scalar_literal(literal)?;
let section_idx = get_or_create_object_idx(root, section, " ", "")?;
let section_value = match root {
JSONValue::JSONObject { key_value_pairs, .. } => {
&mut key_value_pairs[section_idx].value
}
_ => unreachable!(),
};
if let JSONValue::JSONObject { key_value_pairs, .. } = section_value {
if let Some(kvp) = key_value_pairs.iter_mut()
.find(|k| key_matches(&k.key, key))
{
kvp.value = value;
return Ok(());
}
}
append_kvp_pretty(
section_value,
JSONValue::Identifier(key.to_string()),
value,
" ",
" ",
)
}
fn edit_str<F: FnOnce(&mut JSONValue) -> Result<()>>(src: &str, f: F) -> Result<String> {
let mut text = from_str(src).map_err(|e| anyhow!("{}", e))?;
f(&mut text.value)?;
Ok(text.to_string())
}
#[test]
fn replaces_existing_scalar() {
let src = r#"{
// threshold for learning
learn: {
threshold: 0.001, // the old value
},
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "1e-7")
}).unwrap();
assert!(out.contains("1e-7"), "output: {}", out);
assert!(out.contains("// threshold for learning"));
assert!(out.contains("// the old value"));
assert!(!out.contains("0.001"));
}
#[test]
fn creates_missing_section() {
let src = r#"{
// comment
memory: { user_name: "Kent" },
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "1e-7")
}).unwrap();
assert!(out.contains("learn"));
assert!(out.contains("1e-7"));
assert!(out.contains("// comment"));
assert!(out.contains(r#"user_name: "Kent""#));
}
#[test]
fn preserves_comments_in_siblings() {
let src = r#"{
memory: {
// sensitive setting
user_name: "Kent", // name
},
learn: {
threshold: 0.5,
},
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "1e-9")
}).unwrap();
assert!(out.contains("// sensitive setting"));
assert!(out.contains("// name"));
assert!(out.contains("1e-9"));
assert!(!out.contains("0.5"));
}
#[test]
fn adds_key_to_existing_empty_section() {
let src = r#"{
learn: {},
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "42")
}).unwrap();
assert!(out.contains("threshold"), "output: {}", out);
assert!(out.contains("42"));
}
#[test]
fn realistic_config_adds_learn_section() {
// Mirrors the shape of ~/.consciousness/config.json5 — multiple
// sections, comments, mixed tab/space indent, trailing commas.
let src = r#"{
deepinfra: {
api_key: "bcachefs-agents-2026",
base_url: "http://example/v1",
},
// Named models
models: {
"27b": {
backend: "deepinfra",
model_id: "Qwen/Qwen3.5-27B",
},
},
default_model: "27b",
memory: {
user_name: "Kent",
// Active agent types
agent_types: ["linker", "organize"],
},
compaction: {
hard_threshold_pct: 90,
},
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "1e-7")
}).unwrap();
// Core assertions: comments and sibling sections survive.
assert!(out.contains(r#"api_key: "bcachefs-agents-2026""#));
assert!(out.contains("// Named models"));
assert!(out.contains("// Active agent types"));
assert!(out.contains(r#"user_name: "Kent""#));
assert!(out.contains("hard_threshold_pct: 90"));
// New section added.
assert!(out.contains("learn"));
assert!(out.contains("1e-7"));
// Parse result should parse back without error (real json5 parser).
let reparsed: serde_json::Value = json_five::from_str(&out)
.expect("mutated output must be valid JSON5");
let threshold = reparsed.pointer("/learn/threshold").expect("learn.threshold exists");
assert_eq!(threshold.as_f64(), Some(1e-7));
}
#[test]
fn realistic_config_updates_existing_threshold() {
let src = r#"{
learn: {
// The divergence threshold
threshold: 0.001,
},
memory: { user_name: "Kent" },
}"#;
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "threshold", "5e-8")
}).unwrap();
assert!(out.contains("5e-8"));
assert!(!out.contains("0.001"));
assert!(out.contains("// The divergence threshold"));
let reparsed: serde_json::Value = json_five::from_str(&out).unwrap();
assert_eq!(reparsed.pointer("/learn/threshold").and_then(|v| v.as_f64()), Some(5e-8));
}
#[test]
fn new_section_exact_multiline_layout() {
let src = "{\n a: 1,\n}";
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "generate_alternates", "true")?;
set_scalar_inline(root, "learn", "threshold", "1e-7")
}).unwrap();
let expected = "\
{
a: 1,
learn: {
generate_alternates: true,
threshold: 1e-7,
},
}";
assert_eq!(out, expected, "\n--- got ---\n{}\n--- want ---\n{}\n", out, expected);
}
#[test]
fn new_section_and_key_format_cleanly() {
// The kind of config we actually have in ~/.consciousness
// (top-level sections separated by blank lines, 4-space indent
// for keys within each section). Appending a fresh `learn`
// section with one key should land cleanly, not as
// `learn\n\n :{key\n :value}`.
let src = "{\n memory: {\n user_name: \"Kent\",\n },\n}";
let out = edit_str(src, |root| {
set_scalar_inline(root, "learn", "generate_alternates", "true")
}).unwrap();
// No stray key-to-colon-on-next-line anywhere.
assert!(!out.contains("learn\n"), "learn key wraps: {}", out);
assert!(!out.contains("generate_alternates\n"),
"inner key wraps: {}", out);
// The output should reparse.
let v: serde_json::Value = json_five::from_str(&out).unwrap();
assert_eq!(
v.pointer("/learn/generate_alternates").and_then(|x| x.as_bool()),
Some(true),
"output: {}", out,
);
}
#[test]
fn roundtrip_stable_without_change() {
let src = r#"{
// heading
a: 1,
b: { c: 2 }, // inline
}"#;
let text = from_str(src).unwrap();
assert_eq!(text.to_string(), src);
}
}

View file

@ -230,6 +230,10 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio
rationale: Vec::new(),
};
// Active agent types from config
let config = crate::config::get();
let agent_types: Vec<&str> = config.agent_types.iter().map(|s| s.as_str()).collect();
// Target: α ≥ 2.5 (healthy scale-free)
if alpha < 2.0 {
plan.add("linker", 100);
@ -270,6 +274,48 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio
// Split: handle oversized nodes
plan.set("split", 5);
// Distribute agent budget using Elo ratings
let budget = crate::config::get().agent_budget;
let elo_path = crate::config::get().data_dir.join("agent-elo.json");
if let Ok(elo_json) = std::fs::read_to_string(&elo_path) {
if let Ok(ratings) = serde_json::from_str::<std::collections::HashMap<String, f64>>(&elo_json) {
let elos: Vec<f64> = agent_types.iter()
.map(|t| ratings.get(*t).copied().unwrap_or(1000.0))
.collect();
let min_elo = elos.iter().copied().fold(f64::MAX, f64::min);
let weights: Vec<f64> = elos.iter()
.map(|e| {
let shifted = e - min_elo + 50.0;
shifted * shifted
})
.collect();
let total_weight: f64 = weights.iter().sum();
let allocate = |w: f64| -> usize {
((w / total_weight * budget as f64).round() as usize).max(2)
};
for (i, agent) in agent_types.iter().enumerate() {
plan.set(agent, allocate(weights[i]));
}
let summary: Vec<String> = agent_types.iter()
.map(|a| format!("{}={}", a, plan.count(a)))
.collect();
plan.rationale.push(format!(
"Elo allocation (budget={}): {}", budget, summary.join(" ")));
}
} else {
// No Elo file — use budget with equal distribution
let per_type = budget / agent_types.len();
for agent in &agent_types {
plan.set(agent, per_type);
}
plan.rationale.push(format!(
"No Elo ratings — equal distribution ({} each, budget={})", per_type, budget));
}
plan
}

View file

@ -42,7 +42,6 @@ pub mod subconscious;
// Unified configuration
pub mod config;
pub mod config_writer;
// Session state
pub mod session;

View file

@ -55,13 +55,17 @@ impl ConversationLog {
}
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
// Read forward from the start to find first timestamp
let file = File::open(&self.path).ok()?;
let mmap = unsafe { Mmap::map(&file).ok()? };
// Find first { ... } and parse
for line in mmap.split(|&b| b == b'\n') {
if line.is_empty() { continue; }
if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
if let Some(leaf) = node.leaf() {
return Some(leaf.timestamp());
if let Some(ts) = leaf.timestamp() {
return Some(ts);
}
}
}
}

View file

@ -147,25 +147,6 @@ pub struct MindState {
pub unc_idle: bool,
/// When the unconscious idle timer will fire (for UI display).
pub unc_idle_deadline: Instant,
/// Fine-tuning candidates identified by scoring.
pub finetune_candidates: Vec<learn::FinetuneCandidate>,
/// Last scoring run stats for UI display.
pub finetune_last_run: Option<FinetuneScoringStats>,
}
/// Stats from the last finetune scoring run.
#[derive(Clone, Debug)]
pub struct FinetuneScoringStats {
/// Count of assistant responses we considered (recent half of context).
pub responses_considered: usize,
/// How many exceeded the divergence threshold.
pub above_threshold: usize,
/// Threshold used for this run.
pub threshold: f64,
/// Highest divergence observed.
pub max_divergence: f64,
/// Error message if the run failed.
pub error: Option<String>,
}
impl Clone for MindState {
@ -184,8 +165,6 @@ impl Clone for MindState {
turn_handle: None, // Not cloned — only Mind's loop uses this
unc_idle: self.unc_idle,
unc_idle_deadline: self.unc_idle_deadline,
finetune_candidates: self.finetune_candidates.clone(),
finetune_last_run: self.finetune_last_run.clone(),
}
}
}
@ -198,12 +177,6 @@ pub enum MindCommand {
Score,
/// Run full N×M memory scoring matrix (/score command)
ScoreFull,
/// Score for finetune candidates
ScoreFinetune,
/// Update the finetune divergence threshold and persist to config.
SetLearnThreshold(f64),
/// Toggle alternate-response generation during scoring; persist to config.
SetLearnGenerateAlternates(bool),
/// Abort current turn, kill processes
Interrupt,
/// Reset session
@ -229,8 +202,6 @@ impl MindState {
turn_handle: None,
unc_idle: false,
unc_idle_deadline: Instant::now() + std::time::Duration::from_secs(60),
finetune_candidates: Vec::new(),
finetune_last_run: None,
}
}
@ -317,7 +288,6 @@ impl MindState {
/// Background task completion events.
enum BgEvent {
ScoringDone,
FinetuneCandidate(learn::FinetuneCandidate),
}
// --- Mind: cognitive state machine ---
@ -354,26 +324,13 @@ impl Mind {
client,
config.context_parts.clone(),
config.app.clone(),
config.prompt_file.clone(),
conversation_log,
crate::agent::tools::ActiveTools::new(),
crate::agent::tools::tools(),
).await;
// Migrate legacy "file exists = enabled" sentinel for the
// generate-alternates flag into the config. One-shot; after this
// the sentinel is gone and the config is the source of truth.
let legacy_sentinel = dirs::home_dir().unwrap_or_default()
.join(".consciousness/cache/finetune-alternates");
if legacy_sentinel.exists() {
if !crate::config::app().learn.generate_alternates {
let _ = crate::config_writer::set_learn_generate_alternates(true);
}
let _ = std::fs::remove_file(&legacy_sentinel);
}
let shared = Arc::new(std::sync::Mutex::new(MindState::new(
config.app.dmn.max_turns,
)));
let shared = Arc::new(std::sync::Mutex::new(MindState::new(config.app.dmn.max_turns)));
let (turn_watch, _) = tokio::sync::watch::channel(false);
let (conscious_active, _) = tokio::sync::watch::channel(false);
let (bg_tx, bg_rx) = mpsc::unbounded_channel();
@ -572,20 +529,6 @@ impl Mind {
}
self.agent.compact().await;
}
MindCommand::ScoreFinetune => {
self.start_finetune_scoring();
}
MindCommand::SetLearnThreshold(value) => {
if let Err(e) = crate::config_writer::set_learn_threshold(value) {
dbglog!("[learn] failed to persist threshold {}: {:#}", value, e);
}
}
MindCommand::SetLearnGenerateAlternates(value) => {
if let Err(e) = crate::config_writer::set_learn_generate_alternates(value) {
dbglog!("[learn] failed to persist generate_alternates {}: {:#}",
value, e);
}
}
}
}
}
@ -660,72 +603,6 @@ impl Mind {
});
}
/// Score responses for fine-tuning candidates.
///
/// Scores the most recent half of the context — responses near the end
/// of the context window were generated with the most context available,
/// which is what we want to train on. The threshold is a temporary knob;
/// once this runs continuously, we'll just train whatever lands at full
/// context without filtering.
pub fn start_finetune_scoring(&self) {
// Snapshot the config values we need before spawning — the scoring
// task shouldn't hold the config read lock across async work.
let (threshold, gen_alternates) = {
let app = crate::config::app();
(app.learn.threshold, app.learn.generate_alternates)
};
// Clear the previous run's candidates so this run's stream is fresh.
self.shared.lock().unwrap().finetune_candidates.clear();
let agent = self.agent.clone();
let bg_tx = self.bg_tx.clone();
let shared = self.shared.clone();
tokio::spawn(async move {
let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await;
let (context, client) = {
let ctx = agent.context.lock().await;
(ctx.clone(), agent.client.clone())
};
let entries = context.conversation();
let score_count = entries.len() / 2;
let range_start = entries.len() - score_count;
let responses_considered: usize = entries[range_start..].iter()
.filter(|n| matches!(n, crate::agent::context::AstNode::Branch { role: crate::agent::context::Role::Assistant, .. }))
.count();
activity.update(format!("finetune: scoring {} responses...", responses_considered)).await;
let bg_tx_cb = bg_tx.clone();
let stats = match learn::score_finetune_candidates(
&context, score_count, &client, threshold,
gen_alternates, &activity,
|c| { let _ = bg_tx_cb.send(BgEvent::FinetuneCandidate(c)); },
).await {
Ok((above_threshold, max_div)) => {
FinetuneScoringStats {
responses_considered,
above_threshold,
threshold,
max_divergence: max_div,
error: None,
}
}
Err(e) => FinetuneScoringStats {
responses_considered,
above_threshold: 0,
threshold,
max_divergence: 0.0,
error: Some(format!("{}", e)),
},
};
shared.lock().unwrap().finetune_last_run = Some(stats);
// activity drops here, marking completion and notifying observers
});
}
async fn start_turn(&self, text: &str, target: StreamTarget) {
{
match target {
@ -790,12 +667,6 @@ impl Mind {
let mut bg_rx = self.bg_rx.lock().unwrap().take()
.expect("Mind::run() called twice");
let mut sub_handle: Option<tokio::task::JoinHandle<()>> = None;
// Start finetune scoring at startup (scores existing conversation)
if !self.config.no_agents {
self.start_finetune_scoring();
}
loop {
let (timeout, has_input) = {
let me = self.shared.lock().unwrap();
@ -821,9 +692,6 @@ impl Mind {
BgEvent::ScoringDone => {
self.shared.lock().unwrap().scoring_in_flight = false;
}
BgEvent::FinetuneCandidate(c) => {
self.shared.lock().unwrap().finetune_candidates.push(c);
}
}
}
@ -843,7 +711,6 @@ impl Mind {
cmds.push(MindCommand::Compact);
if !self.config.no_agents {
cmds.push(MindCommand::Score);
cmds.push(MindCommand::ScoreFinetune);
}
}

View file

@ -20,7 +20,6 @@
use std::path::PathBuf;
use std::time::{Duration, Instant};
use crate::thalamus::idle::{hours_since_last_dream, DREAM_INTERVAL_HOURS};
/// DMN state machine.
#[derive(Debug, Clone)]
@ -92,8 +91,7 @@ impl State {
/// Generate the DMN prompt for the current state, informed by
/// user presence and error patterns.
pub fn prompt(&self, ctx: &DmnContext) -> String {
let app = crate::config::app();
let user = &app.user_name;
let user = &crate::config::get().user_name;
let idle_info = if ctx.user_idle < Duration::from_secs(60) {
format!("{} is here (active recently).", user)
@ -140,22 +138,10 @@ impl State {
)
}
State::Foraging => {
let dream_hint = {
let hours = hours_since_last_dream();
if hours >= DREAM_INTERVAL_HOURS {
format!(
" You haven't dreamed in {} hours — consider running \
~/.consciousness/tools/dream-start.sh.",
hours
)
} else {
String::new()
}
};
format!(
"[dmn] Foraging time. {} Follow whatever catches your attention — \
memory files, code, ideas. Call yield_to_user when you want to rest.{}{}",
idle_info, dream_hint, stuck_warning
memory files, code, ideas. Call yield_to_user when you want to rest.{}",
idle_info, stuck_warning
)
}
State::Resting { since } => {

View file

@ -275,7 +275,17 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc
phase: s.phase.clone(),
}).collect());
// Create standalone Agent — stored so UI can read context.
// Create standalone Agent — stored so UI can read context
let config = crate::config::get();
let base_url = config.api_base_url.as_deref().unwrap_or("");
let api_key = config.api_key.as_deref().unwrap_or("");
let model = config.api_model.as_deref().unwrap_or("");
if base_url.is_empty() || model.is_empty() {
dbglog!("[unconscious] API not configured");
auto.steps = orig_steps;
return Err(auto);
}
let cli = crate::user::CliArgs::default();
let (app, _) = match crate::config::load_app(&cli) {
Ok(r) => r,
@ -285,21 +295,12 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc
return Err(auto);
}
};
let resolved = match app.resolve_model(&app.default_backend) {
Ok(r) => r,
Err(e) => {
dbglog!("[unconscious] API not configured: {}", e);
auto.steps = orig_steps;
return Err(auto);
}
};
// Unconscious agents have self-contained prompts — no standard context.
let client = crate::agent::api::ApiClient::new(
&resolved.api_base, &resolved.api_key, &resolved.model_id);
let client = crate::agent::api::ApiClient::new(base_url, api_key, model);
let agent = crate::agent::Agent::new(
client, Vec::new(),
app, None,
app, String::new(), None,
crate::agent::tools::ActiveTools::new(),
auto.tools.clone(),
).await;

View file

@ -1,49 +1,21 @@
#!/bin/bash
# Bail if another agent is in the same phase-group as us.
#
# $1 = our pid file name (e.g. "pid-12345")
# $2 = the phase we're about to enter (e.g. "surface", "observe")
# Bail if other agents are alive in the state dir.
# $1 = this agent's pid file name (e.g. pid-12345)
# cwd = state dir
#
# Also refreshes our own pid file with the current phase on each call,
# so concurrent agents can read each other's phase by cat'ing the pid
# files in the state dir.
#
# Phase groups: "surface" vs everything else ("post-surface"). We allow
# at most one agent per group to be alive at a time — so surface can run
# at a higher frequency than the slower organize/observe tail.
#
# Exit 0 = continue, exit 1 = bail (another agent in our group is alive).
# Exit 0 = continue, exit 1 = bail
shopt -s nullglob
my_pid_file="$1"
my_phase="$2"
# Refresh our own pid file with the current phase.
printf '%s' "$my_phase" > "$my_pid_file"
group_of() {
if [[ "$1" == "surface" ]]; then
echo "surface"
else
echo "post-surface"
fi
}
my_group=$(group_of "$my_phase")
for f in pid-*; do
[[ "$f" == "$my_pid_file" ]] && continue
[[ $f == $my_pid_file ]] && continue
pid="${f#pid-}"
if ! kill -0 "$pid" 2>/dev/null; then
if kill -0 "$pid" 2>/dev/null; then
exit 1 # competing agent is alive
else
rm -f "$f" # stale pid file, clean up
continue
fi
other_phase=$(cat "$f" 2>/dev/null)
other_group=$(group_of "$other_phase")
if [[ "$my_group" == "$other_group" ]]; then
exit 1
fi
done

View file

@ -396,14 +396,13 @@ fn resolve_conversation(budget: Option<usize>) -> String {
let cfg = crate::config::get();
let max_bytes = budget.unwrap_or_else(|| cfg.surface_conversation_bytes.unwrap_or(100_000));
let app = crate::config::app();
let mut fragments: Vec<String> = Vec::new();
let mut total_bytes = 0;
let mut oldest_ts = String::new();
for (role, content, ts) in iter {
if total_bytes >= max_bytes { break; }
let name = if role == "user" { &app.user_name } else { &app.assistant_name };
let name = if role == "user" { &cfg.user_name } else { &cfg.assistant_name };
let formatted = if !ts.is_empty() {
oldest_ts = ts[..ts.floor_char_boundary(ts.len().min(19))].to_string();
format!("**{}** {}: {}", name, &oldest_ts, content)
@ -624,13 +623,11 @@ pub async fn run_agent(
let mut all_keys = keys;
let mut resolved_steps = Vec::new();
for step in &def.steps {
let template = {
let app = crate::config::app();
step.prompt
let cfg = crate::config::get();
let template = step.prompt
.replace("{agent_name}", &def.agent)
.replace("{user_name}", &app.user_name)
.replace("{assistant_name}", &app.assistant_name)
};
.replace("{user_name}", &cfg.user_name)
.replace("{assistant_name}", &cfg.assistant_name);
let (prompt, extra_keys) = resolve_placeholders(&template, &all_keys, count).await;
all_keys.extend(extra_keys);
resolved_steps.push(super::prompts::ResolvedStep {

View file

@ -16,7 +16,6 @@
use crate::agent::api::ApiClient;
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
use crate::agent::tokenizer;
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
@ -53,18 +52,13 @@ fn is_assistant(node: &AstNode) -> bool {
///
/// Includes all sections up to and including conversation entries in
/// `range`, with `filter` applied to conversation entries.
///
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
/// (start, end) token positions for each assistant message.
fn build_token_ids(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> (Vec<u32>, Vec<(usize, usize)>) {
) -> Vec<u32> {
use crate::agent::context::Ast;
let mut ids = Vec::new();
let mut assistant_ranges = Vec::new();
for node in context.system() {
ids.extend(node.token_ids());
}
@ -92,16 +86,9 @@ fn build_token_ids(
Filter::SkipAllMemories => is_memory(node),
};
if skip { continue; }
// Track assistant message boundaries
let is_asst = is_assistant(node);
let start = ids.len();
ids.extend(node.token_ids());
if is_asst {
assistant_ranges.push((start, ids.len()));
}
}
(ids, assistant_ranges)
ids
}
// ── Score API ───────────────────────────────────────────────────
@ -126,19 +113,13 @@ async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
ranges: &[(usize, usize)],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
// Nothing to score — skip the round-trip.
if ranges.is_empty() {
return Ok(Vec::new());
}
let url = format!("{}/score", client.base_url());
let auth = format!("Bearer {}", client.api_key());
let mut body = serde_json::json!({
"model": client.model,
"prompt": prompt,
"score_ranges": ranges,
"logprobs": 1,
});
if let Some(p) = priority {
@ -186,10 +167,8 @@ async fn score_divergence(
filter: Filter<'_>,
priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?;
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?;
let divs = divergence(&baseline, &without);
Ok((divs, baseline))
}
@ -228,21 +207,21 @@ pub async fn score_memories(
let http = http_client();
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
let (baseline_tokens, baseline_ranges) = {
let baseline_tokens = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
};
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
for (mem_idx, key) in memory_keys.iter().enumerate() {
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
let (tokens, ranges) = {
let tokens = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
};
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
let row = match call_score(&http, client, &tokens, Some(5)).await {
Ok(without) => {
let divs = divergence(&baseline, &without);
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@ -473,302 +452,3 @@ pub async fn score_finetune(
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(results)
}
/// Concatenate the text of a Branch's Leaf children — what the model
/// actually produced on that turn (Content + Thinking + ToolCall name).
fn render_branch_text(children: &[AstNode]) -> String {
children.iter()
.filter_map(|c| match c {
AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
/// Render the last `max_msgs` user/assistant branches before `idx` as a
/// review-friendly string with `[user]` / `[assistant]` markers.
fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String {
use crate::agent::context::Role;
let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs);
for i in (0..idx).rev() {
if picked.len() >= max_msgs { break; }
if let AstNode::Branch { role, .. } = &entries[i] {
if matches!(role, Role::User | Role::Assistant) {
picked.push(&entries[i]);
}
}
}
picked.reverse();
let mut out = String::new();
for node in picked {
if let AstNode::Branch { role, children, .. } = node {
let marker = match role {
Role::User => "[user]",
Role::Assistant => "[assistant]",
_ => continue,
};
out.push_str(marker);
out.push('\n');
out.push_str(render_branch_text(children).trim());
out.push_str("\n\n");
}
}
out.trim_end().to_string()
}
/// Enriched finetune candidate with context for review.
#[derive(Clone, Debug)]
pub struct FinetuneCandidate {
pub entry_idx: usize,
pub divergence: f64,
pub response_text: String,
/// Last couple of user/assistant messages before this response,
/// already rendered with role markers, for F6 display context.
pub prior_context: String,
/// Token IDs for context (everything before the response).
pub context_ids: Vec<u32>,
/// Token IDs for the response (what we're training on).
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
/// Score and enrich finetune candidates with full context.
///
/// Candidates are delivered via `on_candidate` one-at-a-time as they become
/// ready: scoring happens once (one /score call), then for each candidate
/// that passes the threshold we optionally generate an alternate response
/// and then emit it. The activity status is updated during the alternate
/// phase so the UI doesn't look stuck.
///
/// Returns (count_above_threshold, max_divergence).
pub async fn score_finetune_candidates(
context: &ContextState,
count: usize,
client: &ApiClient,
min_divergence: f64,
generate_alternates: bool,
activity: &crate::agent::ActivityGuard,
mut on_candidate: impl FnMut(FinetuneCandidate),
) -> anyhow::Result<(usize, f64)> {
let scores = score_finetune(context, count, client).await?;
let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max);
let entries = context.conversation();
let trained = load_trained();
let mut candidates: Vec<FinetuneCandidate> = Vec::new();
for (entry_idx, divergence) in scores {
if divergence < min_divergence {
continue;
}
let node = &entries[entry_idx];
// Skip if already trained on.
let timestamp_ns = node_timestamp_ns(node);
if trained.contains(&timestamp_ns) {
continue;
}
// Extract response text — content of the assistant turn.
let response_text = match node {
AstNode::Branch { children, .. } => render_branch_text(children),
_ => continue,
};
// Skip turns that produced nothing human-visible (e.g., a
// tool-only turn, or an interrupted generation). They'd show
// up as blank cards and we'd still burn alternate-gen on them.
if response_text.trim().is_empty() {
continue;
}
// Build the last couple of user/assistant exchanges for review.
let prior_context = render_prior_context(entries, entry_idx, 2);
// Build token IDs: context = everything before response, continuation = response.
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
candidates.push(FinetuneCandidate {
entry_idx,
divergence,
response_text,
prior_context,
context_ids,
continuation_ids,
alternate_text: None,
timestamp_ns,
});
}
let total = candidates.len();
let gen_alternates = generate_alternates && total > 0;
for (i, mut candidate) in candidates.into_iter().enumerate() {
if gen_alternates {
activity.update(
format!("finetune: generating alternate {}/{}", i + 1, total)
).await;
match generate_alternate(context, candidate.entry_idx, client).await {
Ok(text) => candidate.alternate_text = Some(text),
Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
}
}
on_candidate(candidate);
}
Ok((total, max_divergence))
}
/// Generate what the model would say without memories for a given entry.
async fn generate_alternate(
context: &ContextState,
entry_idx: usize,
client: &ApiClient,
) -> anyhow::Result<String> {
use crate::agent::api::{SamplingParams, StreamToken};
// Build context tokens without memories, up to the response
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
// Add assistant turn start
prompt.push(tokenizer::IM_START);
prompt.extend(tokenizer::encode("assistant\n"));
// Generate completion
let sampling = SamplingParams {
temperature: 0.6,
top_p: 0.95,
top_k: 20,
};
let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5));
let mut tokens = Vec::new();
while let Some(tok) = rx.recv().await {
match tok {
StreamToken::Token(id) => tokens.push(id),
StreamToken::Done { .. } => break,
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
}
}
Ok(tokenizer::decode(&tokens))
}
// ── Finetune config and persistence ─────────────────────────────
use std::path::PathBuf;
use std::collections::HashSet;
const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json";
fn trained_path() -> PathBuf {
dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE)
}
/// Load set of trained response timestamps (nanos since epoch).
pub fn load_trained() -> HashSet<i64> {
let path = trained_path();
match std::fs::read_to_string(&path) {
Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
Err(_) => HashSet::new(),
}
}
/// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ns: i64) {
let mut trained = load_trained();
trained.insert(timestamp_ns);
let path = trained_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(json) = serde_json::to_string(&trained) {
let _ = std::fs::write(&path, json);
}
}
/// Get timestamp in nanoseconds from an AstNode.
/// i64-ns representation covers 1677..2262 via chrono; timestamps
/// outside that window would be bugs we'd want to surface, hence panic.
pub fn node_timestamp_ns(node: &AstNode) -> i64 {
let ts = match node {
AstNode::Leaf(leaf) => leaf.timestamp(),
AstNode::Branch { timestamp, .. } => *timestamp,
};
ts.timestamp_nanos_opt()
.expect("timestamp outside i64-ns representable range (1677..2262)")
}
// ── Training API ────────────────────────────────────────────────
/// Training sample for /train endpoint.
#[derive(serde::Serialize)]
struct TrainingSample {
context_ids: Vec<u32>,
continuation_ids: Vec<u32>,
}
/// Data needed to send a training sample.
pub struct TrainData {
pub context_ids: Vec<u32>,
pub continuation_ids: Vec<u32>,
pub timestamp_ns: i64,
}
/// Send training samples to the server.
///
/// Returns job_id on success, marks each sample as trained.
pub async fn send_to_train(
samples: Vec<TrainData>,
client: &ApiClient,
) -> anyhow::Result<String> {
if samples.is_empty() {
anyhow::bail!("no samples to train");
}
let api_samples: Vec<TrainingSample> = samples.iter()
.map(|s| TrainingSample {
context_ids: s.context_ids.clone(),
continuation_ids: s.continuation_ids.clone(),
})
.collect();
let body = serde_json::json!({
"training_data": {
"samples": api_samples,
}
});
let http = http_client();
let url = format!("{}/train", client.base_url());
let response = http.send_json("POST", &url, &[], &body).await?;
let status = response.status();
let result: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("train API HTTP {}: {}", status, msg);
}
// Mark all samples as trained
for s in &samples {
mark_trained(s.timestamp_ns);
}
let job_id = result.get("job_id")
.and_then(|j| j.as_str())
.unwrap_or("unknown")
.to_string();
dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id);
Ok(job_id)
}

View file

@ -372,10 +372,6 @@ impl State {
}
pub fn hours_since_last_dream() -> u64 {
// If a dream is currently in progress, no nudge needed
if home().join(".consciousness/state/dream-state").exists() {
return 0;
}
let path = home().join(".consciousness/logs/dream-log.jsonl");
let content = match fs::read_to_string(path) {
Ok(c) if !c.is_empty() => c,

View file

@ -112,8 +112,14 @@ pub async fn cmd_switch_model(
let _new_client = crate::agent::api::ApiClient::new(
&resolved.api_base, &resolved.api_key, &resolved.model_id,
);
let prompt_changed = resolved.prompt_file != agent.prompt_file;
if prompt_changed {
agent.compact().await;
agent.state.lock().await.notify(format!("switched to {} (recompacted)", resolved.model_id));
} else {
agent.state.lock().await.notify(format!("switched to {}", resolved.model_id));
}
}
fn notify_help(agent: &std::sync::Arc<crate::agent::Agent>) {
if let Ok(mut ag) = agent.state.try_lock() {

View file

@ -126,7 +126,14 @@ impl ScreenView for ConsciousScreen {
let section_style = Style::default().fg(Color::Yellow);
lines.push(Line::styled("── Model ──", section_style));
lines.push(Line::raw(format!(" Current: {}", app.status.model)));
let model_display = app.context_info.as_ref()
.map_or_else(|| app.status.model.clone(), |i| i.model.clone());
lines.push(Line::raw(format!(" Current: {}", model_display)));
if let Some(ref info) = app.context_info {
lines.push(Line::raw(format!(" Backend: {}", info.backend)));
lines.push(Line::raw(format!(" Prompt: {}", info.prompt_file)));
lines.push(Line::raw(format!(" Available: {}", info.available_models.join(", "))));
}
lines.push(Line::raw(""));
lines.push(Line::styled("── Context State ──", section_style));
@ -146,6 +153,8 @@ impl ScreenView for ConsciousScreen {
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "────────", "──────")));
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "Total", total)));
} else if let Some(ref info) = app.context_info {
lines.push(Line::raw(format!(" Context message: {:>6} chars", info.context_message_chars)));
}
lines.push(Line::raw(""));

View file

@ -1,341 +0,0 @@
// learn.rs — F6: fine-tuning review screen
//
// Shows responses identified as training candidates (high divergence
// when memories stripped). Queue for review before sending to /finetune.
use ratatui::{
layout::{Constraint, Layout, Rect},
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Wrap},
Frame,
};
use ratatui::crossterm::event::{Event, KeyCode, KeyEvent};
use super::{App, ScreenView, screen_legend};
/// A candidate response identified for fine-tuning.
#[derive(Clone, Debug)]
pub struct FinetuneCandidate {
/// Index in conversation entries.
pub entry_idx: usize,
/// Divergence score (higher = more dependent on memories).
pub divergence: f64,
/// The assistant response text.
pub response_text: String,
/// Prior user/assistant messages for review context.
pub prior_context: String,
/// Status: pending, approved, rejected, sent.
pub status: CandidateStatus,
/// Token IDs for context.
pub context_ids: Vec<u32>,
/// Token IDs for continuation (what we're training on).
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub enum CandidateStatus {
Pending,
Approved,
Rejected,
Sent,
}
impl From<crate::subconscious::learn::FinetuneCandidate> for FinetuneCandidate {
fn from(c: crate::subconscious::learn::FinetuneCandidate) -> Self {
FinetuneCandidate {
entry_idx: c.entry_idx,
divergence: c.divergence,
response_text: c.response_text,
prior_context: c.prior_context,
status: CandidateStatus::Pending,
context_ids: c.context_ids,
continuation_ids: c.continuation_ids,
alternate_text: c.alternate_text,
timestamp_ns: c.timestamp_ns,
}
}
}
pub(crate) struct LearnScreen {
list_state: ListState,
mind_tx: tokio::sync::mpsc::UnboundedSender<crate::mind::MindCommand>,
}
impl LearnScreen {
pub fn new(
mind_tx: tokio::sync::mpsc::UnboundedSender<crate::mind::MindCommand>,
) -> Self {
Self {
list_state: ListState::default(),
mind_tx,
}
}
fn selected_idx(&self) -> Option<usize> {
self.list_state.selected()
}
}
impl ScreenView for LearnScreen {
fn label(&self) -> &'static str { "learn" }
fn tick(&mut self, frame: &mut Frame, area: Rect,
events: &[Event], app: &mut App) {
// Handle input first (before borrowing candidates for rendering)
let candidate_count = app.finetune_candidates.len();
for event in events {
if let Event::Key(KeyEvent { code, .. }) = event {
match code {
KeyCode::Up | KeyCode::Char('k') => {
let i = self.list_state.selected().unwrap_or(0);
self.list_state.select(Some(i.saturating_sub(1)));
}
KeyCode::Down | KeyCode::Char('j') => {
let i = self.list_state.selected().unwrap_or(0);
let max = candidate_count.saturating_sub(1);
self.list_state.select(Some((i + 1).min(max)));
}
KeyCode::Char('a') => {
if let Some(idx) = self.selected_idx() {
app.finetune_action(idx, CandidateStatus::Approved);
}
}
KeyCode::Char('r') => {
if let Some(idx) = self.selected_idx() {
app.finetune_action(idx, CandidateStatus::Rejected);
}
}
KeyCode::Char('g') => {
let current = crate::config::app().learn.generate_alternates;
let _ = self.mind_tx.send(
crate::mind::MindCommand::SetLearnGenerateAlternates(!current));
}
KeyCode::Char('s') => {
app.finetune_send_approved();
}
KeyCode::Char('+') | KeyCode::Char('=') => {
// Raise threshold 10× (less sensitive — fewer candidates).
let new = crate::config::app().learn.threshold * 10.0;
let _ = self.mind_tx.send(
crate::mind::MindCommand::SetLearnThreshold(new));
}
KeyCode::Char('-') => {
// Lower threshold 10× (more sensitive — more candidates).
let new = crate::config::app().learn.threshold / 10.0;
let _ = self.mind_tx.send(
crate::mind::MindCommand::SetLearnThreshold(new));
}
_ => {}
}
}
}
// Ensure selection is valid
if candidate_count > 0 {
let sel = self.list_state.selected().unwrap_or(0).min(candidate_count - 1);
self.list_state.select(Some(sel));
}
// Now render
let (threshold, gen_on) = {
let app_cfg = crate::config::app();
(app_cfg.learn.threshold, app_cfg.learn.generate_alternates)
};
let block = Block::default()
.title_top(Line::from(screen_legend()).left_aligned())
.title_top(Line::from(" learn ").right_aligned())
.borders(Borders::ALL)
.border_style(Style::default().fg(Color::Magenta));
let inner = block.inner(area);
frame.render_widget(block, area);
// Split inner: top line for settings, rest for content.
let [settings_area, content_area] = Layout::vertical([
Constraint::Length(1),
Constraint::Min(0),
]).areas(inner);
let settings = Line::from(vec![
Span::raw(" thresh: "),
Span::styled(format!("{:e}", threshold), Style::default().fg(Color::Yellow)),
Span::raw(" gen: "),
Span::styled(
if gen_on { "[on]" } else { "[off]" },
Style::default().fg(if gen_on { Color::Green } else { Color::DarkGray }),
),
]);
frame.render_widget(Paragraph::new(settings), settings_area);
let candidates = &app.finetune_candidates;
if candidates.is_empty() {
render_empty(frame, content_area, app);
} else {
// Layout: list on left, detail on right
let [list_area, detail_area] = Layout::horizontal([
Constraint::Percentage(40),
Constraint::Percentage(60),
]).areas(content_area);
// Render candidate list
let items: Vec<ListItem> = candidates.iter().map(|c| {
let status_char = match c.status {
CandidateStatus::Pending => ' ',
CandidateStatus::Approved => '+',
CandidateStatus::Rejected => '-',
CandidateStatus::Sent => '*',
};
let style = match c.status {
CandidateStatus::Pending => Style::default(),
CandidateStatus::Approved => Style::default().fg(Color::Green),
CandidateStatus::Rejected => Style::default().fg(Color::DarkGray),
CandidateStatus::Sent => Style::default().fg(Color::Cyan),
};
ListItem::new(Line::from(vec![
Span::styled(format!("[{}] ", status_char), style),
Span::styled(format!("{:.2} ", c.divergence), Style::default().fg(Color::Yellow)),
Span::raw(truncate(&c.response_text, 30)),
]))
}).collect();
let list = List::new(items)
.block(Block::default().borders(Borders::RIGHT).title(" candidates "))
.highlight_style(Style::default().add_modifier(Modifier::REVERSED));
frame.render_stateful_widget(list, list_area, &mut self.list_state);
// Render detail for selected candidate
if let Some(idx) = self.selected_idx() {
if let Some(candidate) = candidates.get(idx) {
render_detail(frame, candidate, detail_area);
}
}
}
// Render help at bottom (always, even when empty)
let help = Line::from(vec![
Span::styled(" j/k/\u{2191}\u{2193}", Style::default().fg(Color::Cyan)),
Span::raw("=nav "),
Span::styled("a", Style::default().fg(Color::Green)),
Span::raw("=approve "),
Span::styled("r", Style::default().fg(Color::Red)),
Span::raw("=reject "),
Span::styled("g", Style::default().fg(Color::Yellow)),
Span::raw("=gen "),
Span::styled("s", Style::default().fg(Color::Magenta)),
Span::raw("=send "),
Span::styled("+/-", Style::default().fg(Color::Cyan)),
Span::raw("=thresh "),
]);
let help_area = Rect {
y: area.y + area.height - 1,
height: 1,
..area
};
frame.render_widget(Paragraph::new(help), help_area);
}
}
fn render_empty(frame: &mut Frame, inner: Rect, app: &App) {
let mut lines = Vec::new();
lines.push(Line::from(""));
match app.mind_state.as_ref().and_then(|ms| ms.finetune_last_run.as_ref()) {
Some(stats) => {
lines.push(Line::from(vec![
Span::raw(" Last run: "),
Span::styled(
format!("{}", stats.responses_considered),
Style::default().fg(Color::Cyan),
),
Span::raw(" responses considered, "),
Span::styled(
format!("{}", stats.above_threshold),
Style::default().fg(if stats.above_threshold > 0 { Color::Green } else { Color::DarkGray }),
),
Span::raw(" above threshold, max divergence: "),
Span::styled(
format!("{:.4}", stats.max_divergence),
Style::default().fg(Color::Yellow),
),
]));
if let Some(err) = &stats.error {
lines.push(Line::from(vec![
Span::raw(" "),
Span::styled(
format!("Error: {}", err),
Style::default().fg(Color::Red),
),
]));
}
}
None => {
lines.push(Line::styled(
" No scoring run yet.",
Style::default().fg(Color::DarkGray),
));
}
}
lines.push(Line::from(""));
lines.push(Line::styled(
" Scoring runs at startup and after each turn.",
Style::default().fg(Color::DarkGray),
));
frame.render_widget(Paragraph::new(lines), inner);
}
fn render_detail(frame: &mut Frame, c: &FinetuneCandidate, area: Rect) {
let [header_area, content_area] = Layout::vertical([
Constraint::Length(3),
Constraint::Min(1),
]).areas(area);
// Header: divergence, status
let alt_status = if c.alternate_text.is_some() { "yes" } else { "no" };
let header = Paragraph::new(vec![
Line::from(vec![
Span::raw(" divergence: "),
Span::styled(format!("{:.3}", c.divergence), Style::default().fg(Color::Yellow)),
Span::raw(format!(" entry: {} alt: {}", c.entry_idx, alt_status)),
]),
]);
frame.render_widget(header, header_area);
// Content: prior context, the scored response, and alternate
// (if available).
let content_block = Block::default()
.borders(Borders::TOP)
.title(" context & response ");
let mut text = String::new();
if !c.prior_context.is_empty() {
text.push_str(&c.prior_context);
text.push_str("\n\n─── response ───\n\n");
}
text.push_str(&c.response_text);
if let Some(alt) = &c.alternate_text {
text.push_str("\n\n─── without memories ───\n\n");
text.push_str(alt);
}
let content = Paragraph::new(text)
.block(content_block)
.wrap(Wrap { trim: false });
frame.render_widget(content, content_area);
}
fn truncate(s: &str, max: usize) -> String {
let first_line = s.lines().next().unwrap_or("");
if first_line.len() > max {
format!("{}...", &first_line[..max])
} else {
first_line.to_string()
}
}

View file

@ -5,12 +5,11 @@
pub(crate) mod chat;
mod context;
pub(crate) mod learn;
pub(crate) mod scroll_pane;
pub mod selectable;
mod subconscious;
mod thalamus;
mod unconscious;
mod thalamus;
mod widgets;
use anyhow::Result;
@ -45,6 +44,15 @@ struct StatusInfo {
}
/// Context loading details for the debug screen.
#[derive(Debug, Clone)]
struct ContextInfo {
model: String,
available_models: Vec<String>,
prompt_file: String,
backend: String,
context_message_chars: usize,
}
/// Build the screen legend from screen labels.
fn screen_legend_from(screens: &[Box<dyn ScreenView>]) -> String {
let parts: Vec<String> = screens.iter().enumerate()
@ -101,6 +109,7 @@ struct App {
top_k: u32,
agent: std::sync::Arc<crate::agent::Agent>,
should_quit: bool,
context_info: Option<ContextInfo>,
agent_state: Vec<crate::mind::SubconsciousSnapshot>,
unconscious_state: Vec<crate::mind::UnconsciousSnapshot>,
mind_state: Option<crate::mind::MindState>,
@ -112,8 +121,6 @@ struct App {
walked_count: usize,
channel_status: Vec<ChannelStatus>,
idle_info: Option<IdleInfo>,
/// Fine-tuning candidates pending review.
finetune_candidates: Vec<learn::FinetuneCandidate>,
}
impl App {
@ -135,6 +142,7 @@ impl App {
top_k: 20,
agent,
should_quit: false,
context_info: None,
agent_state: Vec::new(),
unconscious_state: Vec::new(),
mind_state: None,
@ -143,52 +151,9 @@ impl App {
rebuild_tools_pending: false,
walked_count: 0,
channel_status: Vec::new(), idle_info: None,
finetune_candidates: Vec::new(),
}
}
fn finetune_action(&mut self, idx: usize, status: learn::CandidateStatus) {
if let Some(candidate) = self.finetune_candidates.get_mut(idx) {
candidate.status = status;
}
}
fn finetune_send_approved(&mut self) {
// Collect approved candidates
let samples: Vec<crate::subconscious::learn::TrainData> = self.finetune_candidates.iter()
.filter(|c| c.status == learn::CandidateStatus::Approved)
.map(|c| crate::subconscious::learn::TrainData {
context_ids: c.context_ids.clone(),
continuation_ids: c.continuation_ids.clone(),
timestamp_ns: c.timestamp_ns,
})
.collect();
if samples.is_empty() {
return;
}
// Mark as sent in UI immediately
for candidate in &mut self.finetune_candidates {
if candidate.status == learn::CandidateStatus::Approved {
candidate.status = learn::CandidateStatus::Sent;
}
}
// Spawn async task to send to training server
let client = self.agent.client.clone();
tokio::spawn(async move {
match crate::subconscious::learn::send_to_train(samples, &client).await {
Ok(job_id) => {
dbglog!("[finetune] training started: {}", job_id);
}
Err(e) => {
dbglog!("[finetune] send failed: {:#}", e);
}
}
});
}
fn set_channel_status(&mut self, channels: Vec<(String, bool, u32)>) {
self.channel_status = channels.into_iter()
@ -228,9 +193,6 @@ fn restore_terminal(terminal: &mut ratatui::Terminal<CrosstermBackend<io::Stdout
async fn start(cli: crate::user::CliArgs) -> Result<()> {
let (config, _figment) = crate::config::load_session(&cli).await?;
// Pick up external edits (vim, F6 hotkeys, etc.) without restart.
crate::config::watch_config(cli.clone());
if config.app.debug {
unsafe { std::env::set_var("POC_DEBUG", "1") };
}
@ -372,7 +334,7 @@ async fn run(
}
let notify_rx = crate::thalamus::channels::subscribe_all();
// F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus, F6=learn
// F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus
let mut screens: Vec<Box<dyn tui::ScreenView>> = vec![
Box::new(crate::user::chat::InteractScreen::new(
mind.agent.clone(), mind.shared.clone(), mind_tx.clone(),
@ -381,7 +343,6 @@ async fn run(
Box::new(crate::user::subconscious::SubconsciousScreen::new()),
Box::new(crate::user::unconscious::UnconsciousScreen::new()),
Box::new(crate::user::thalamus::ThalamusScreen::new()),
Box::new(crate::user::learn::LearnScreen::new(mind_tx.clone())),
];
let mut active_screen: usize = 1; // F-key number
tui::set_screen_legend(tui::screen_legend_from(&*screens));
@ -458,8 +419,7 @@ async fn run(
idle_state.decay_ewma();
app.update_idle(&idle_state);
app.agent_state = mind.subconscious_snapshots().await;
{
let mut unc = mind.unconscious.lock().await;
if let Ok(mut unc) = mind.unconscious.try_lock() {
let toggles: Vec<String> = app.agent_toggles.drain(..).collect();
for name in &toggles {
if mind.subconscious.lock().await.toggle(name).is_none() {
@ -473,38 +433,7 @@ async fn run(
};
app.unconscious_state = unc.snapshots(store_guard.as_deref());
app.graph_health = unc.graph_health.clone();
}
// Sync mind state (finetune candidates, last scoring run, etc.)
{
let ms = mind.shared.lock().unwrap();
// Sync finetune candidates: add new ones, keep existing (preserves approval status),
// remove sent candidates, keep only 10 most recent rejected.
app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent);
for c in &ms.finetune_candidates {
let exists = app.finetune_candidates.iter()
.any(|existing| existing.timestamp_ns == c.timestamp_ns);
if !exists {
app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone()));
}
}
let mut rejected: Vec<_> = app.finetune_candidates.iter()
.enumerate()
.filter(|(_, c)| c.status == learn::CandidateStatus::Rejected)
.map(|(i, c)| (i, c.timestamp_ns))
.collect();
if rejected.len() > 10 {
rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts));
let to_remove: std::collections::HashSet<_> = rejected[10..]
.iter().map(|(i, _)| *i).collect();
let mut idx = 0;
app.finetune_candidates.retain(|_| {
let keep = !to_remove.contains(&idx);
idx += 1;
keep
});
}
app.mind_state = Some(ms.clone());
app.mind_state = Some(mind.shared.lock().unwrap().clone());
}
app.walked_count = mind.subconscious_walked().await.len();
if !startup_done {
@ -601,11 +530,16 @@ async fn run(
// --- CLI ---
use clap::{Parser, Subcommand};
use std::path::PathBuf;
#[derive(Parser, Debug, Default, Clone)]
#[derive(Parser, Debug, Default)]
#[command(name = "consciousness", about = "Substrate-independent AI agent")]
pub struct CliArgs {
/// Model override (selects a named entry from `models` in config.json5)
/// Select active backend ("anthropic" or "openrouter")
#[arg(long)]
pub backend: Option<String>,
/// Model override
#[arg(short, long)]
pub model: Option<String>,
@ -625,6 +559,10 @@ pub struct CliArgs {
#[arg(long)]
pub show_config: bool,
/// Project memory directory
#[arg(long)]
pub memory_project: Option<PathBuf>,
/// Max consecutive DMN turns
#[arg(long)]
pub dmn_max_turns: Option<u32>,
@ -637,7 +575,7 @@ pub struct CliArgs {
pub command: Option<SubCmd>,
}
#[derive(Subcommand, Debug, Clone)]
#[derive(Subcommand, Debug)]
pub enum SubCmd {
/// Print new output since last read and exit
Read {

View file

@ -3,7 +3,7 @@
## Overview
Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference.
Full-weight updates (not LoRA) using Apollo optimizer with rank-64
Full-weight updates (not LoRA) using Apollo optimizer with rank-256
gradient projection. No pause required — HOGWILD concurrent training.
Weights shared via CUDA IPC between vLLM and the training process.
@ -22,41 +22,25 @@ The training signal comes from two sources:
│ │
│ ┌──────────────────────────────────────────────┐ │
│ │ Model Weights (54GB, bf16) │ │
│ │ Shared: vLLM inference + HF training │ │
│ │ Shared via CUDA IPC │ │
│ └──────────────┬──────────────┬────────────────┘ │
│ │ │ │
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
│ │ vLLM (inference)│ │ Training subprocess │ │
│ │ KV cache ~60GB │ │ HF model wrapper │ │
│ │ /completions │ │ Apollo optimizer ~2.5GB │ │
│ │ /score │ │ Checkpoint sync │ │
│ └────────┬────────┘ └───────────▲─────────────┘ │
│ │ │ │
│ │ ZMQ IPC │ │
│ └───────────────────────┘ │
│ │ vLLM (inference)│ │ Apollo (training) │ │
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
│ │ Serves requests │ │ Optimizer state ~10GB │ │
│ │ Never paused │ │ Activations ~10GB │ │
│ └─────────────────┘ └────────────────────────┘ │
└─────────────────────────────────────────────────────┘
Process Architecture:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │
│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │
│ │ │ │ │ │
│ export_hook.py │ │ /completions │ │ HF model views │
│ exports IPC │ │ /score │ │ Apollo optimizer│
│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
└──── IPC handles file ──────────────────┘
/tmp/vllm_weight_handles.pt
Moria B200 (vLLM)
Moria B200
┌──────────────────┐ ┌──────────────────┐
│ Training signal │ HTTP │ /completions
│ agent │──────────>│ /score
│ │ │ /train
│ Dream loop │ │ /checkpoint
│ (generates │ │ /train/status
│ scenarios) │ │
│ Training signal │ HTTP │ Apollo worker │
│ agent │──────────>│ daemon │
│ │ │ │
│ Dream loop │ │ Checkpoint sync │
│ (generates │ │ (mmap + diff, │
│ scenarios) │ │ every 10 min) │
└──────────────────┘ └──────────────────┘
```
@ -75,9 +59,10 @@ LoRA trains adapter matrices, not base weights. For personality and
behavioral changes that persist as disposition, the base weights
need to change. Apollo makes this memory-feasible.
### Rank 64
Not Mini (rank-1). Rank-64 captures gradient structure across diverse
training examples while keeping memory low (~2.5GB on 27B model).
### Rank 256
Not Mini (rank-1). With 100+ diverse training examples, the
gradient's effective dimensionality can reach hundreds. Rank-256
captures the structure. Memory cost: ~10GB (negligible on B200).
Compute cost: <0.25% of forward+backward.
### Channel-wise scaling
@ -105,7 +90,7 @@ from a per-parameter seed each step.
### Parameter grouping (Qwen3.5 gotcha)
conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector
needs 2D matrices with min dimension >= rank. Small/3D tensors
use standard Adam. Large 2D matrices use Apollo.
use standard Adam. Large 2D matrices use Apollo with rank-256.
## Training Data Pipeline
@ -215,42 +200,16 @@ against live GPU weights block by block, memcpy only changed
regions. For small behavioral updates, turns a 54GB write into
a few hundred MB.
- Scheduled 10 minutes after training (batched)
- Every 10 minutes via cron on B200
- Daily rsync to moria for long-term storage
- Tool: `apollo-checkpoint sync --model-dir <path>`
## State Files
### B200 (training server)
| File | Purpose |
|------|---------|
| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). |
| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. |
| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. |
| `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
### Moria (client)
| File | Purpose |
|------|---------|
| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
### In-memory (training_worker subprocess)
| State | Location | Notes |
|-------|----------|-------|
| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. |
| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. |
| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. |
- Tool: `apollo-checkpoint sync --model-dir <path>` (Rust)
## Hyperparameters
| Parameter | Value | Rationale |
|-----------|-------|-----------|
| Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. |
| Rank | 64 | Captures gradient structure. ~2.5GB state. Defined in `train_router.DEFAULT_RANK`. |
| Rank | 256 | Captures gradient structure across 100+ examples. ~10GB state. |
| Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. |
| Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. |
| Batch size | 1 | Single examples, immediate updates. |
@ -261,32 +220,34 @@ a few hundred MB.
## Components
### Built ✓
- `optimizer.py` — Apollo optimizer (configurable rank)
- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ
- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync)
- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256)
- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking)
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
- `export_hook.py` — vLLM plugin hook for IPC handle export
- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
- `training_example.py` — tokenization with chat template
- `vllm_export_hook.py` — source patch for IPC handle export
- `checkpoint/` — Rust tool for mmap + diff checkpoint sync
### To build
- **Dream loop → training bridge**: connect dream output to /train
- **Dream loop → training bridge**: connect dream output to Apollo
- **Training-signal agent**: flags moments in conversation logs
- **Instruction stripping**: remove scaffolding from training examples
- **Quality monitoring**: track model capability over time
- **HF model forward pass integration**: wire into apollo_worker
## Files
```
training/
DESIGN.md — this document
pyproject.toml — package config, vLLM plugin entry point
apollo_plugin/
__init__.py — plugin registration
export_hook.py — patches vLLM worker to export IPC handles
train_router.py — /train endpoint, forwards to worker via ZMQ
training_worker.py — training subprocess (HF model, Apollo, checkpoint)
optimizer.py — Apollo optimizer
apollo_mini.py — Apollo optimizer
apollo_worker.py — HTTP training daemon
weight_mapping.py — vLLM ↔ HF weight views
checkpoint_sync.py — mmap + diff sync to safetensors
steering.py — steering vector extraction (experimental)
training_example.py — tokenization helpers
export_weights.py — standalone weight export (unused)
vllm_export_hook.py — vLLM source patch for IPC export
start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch)
train.py — standalone training script (alternative)
checkpoint/
Cargo.toml — Rust checkpoint tool
src/main.rs — mmap + diff sync
```

View file

@ -8,9 +8,9 @@ Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
these scaling factors using a low-rank auxiliary optimizer state based on
pure random projection.
Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead
vs forward+backward. Sufficient for behavioral training with diverse
examples.
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
compute overhead vs forward+backward. Captures gradient structure
across 100+ behavioral training examples per batch.
Key implementation details from the paper:
- Gradient scale factor α = (n/r) compensates for projection ratio
@ -34,7 +34,7 @@ class Apollo(Optimizer):
Args:
params: model parameters
lr: learning rate (default: 1e-4)
rank: projection rank (default: 64)
rank: projection rank (default: 256)
betas: Adam momentum coefficients (default: (0.9, 0.999))
eps: numerical stability term (default: 1e-8)
weight_decay: decoupled weight decay (default: 0.01)
@ -46,7 +46,7 @@ class Apollo(Optimizer):
Set to None to disable.
"""
def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999),
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
eps=1e-8, weight_decay=0.01, warmup_steps=0,
scale=None, proj_refresh=200, norm_growth_limit=1.01):
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,

View file

@ -1,19 +0,0 @@
"""Apollo training plugin for vLLM.
Enables continuous fine-tuning alongside live inference by:
1. Exporting CUDA IPC handles for weight sharing (export_hook)
2. Adding /train endpoint to vLLM's HTTP server (train_router)
3. Block-level checkpoint sync to safetensors files
Install: pip install -e /path/to/training
Then vLLM auto-loads via entry point.
"""
from .export_hook import _patch_model_runner
from .train_router import _patch_api_server
def register():
"""Called by vLLM's plugin loader on startup."""
_patch_model_runner()
_patch_api_server()

View file

@ -1,503 +0,0 @@
"""Sync live GPU weights to safetensors files on disk.
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
merged layout to HuggingFace's separate layout, diffs block-by-block
against on-disk safetensors files, and writes only changed blocks.
For small behavioral training steps, this turns a 54GB checkpoint
write into a few hundred MB of actual disk I/O.
Usage:
# Sync live weights to disk
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
# Debug name mapping issues
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
# From Python:
from checkpoint_sync import checkpoint_sync
result = checkpoint_sync("/path/to/model")
"""
import json
import mmap
import struct
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any
import logging
import torch
logger = logging.getLogger(__name__)
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
# ---------------------------------------------------------------------------
# vLLM → HuggingFace weight name/shape conversion
# ---------------------------------------------------------------------------
# Qwen3.5-27B dimensions (could be read from config.json for generality)
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
INTERMEDIATE = 17408
# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Convert vLLM merged weights to HF-compatible separate tensors.
vLLM merges certain projections for efficiency:
- qkv_proj (full attn) q_proj, k_proj, v_proj
- in_proj_qkvz (GDN) in_proj_qkv, in_proj_z
- in_proj_ba (GDN) in_proj_b, in_proj_a
- gate_up_proj (MLP) gate_proj, up_proj
Returns views that share GPU memory with the original tensors.
"""
hf_params = {}
for name, tensor in vllm_params.items():
# Strip vLLM's 'language_model.' prefix to match HF naming
hf_name = name.removeprefix('language_model.')
if 'in_proj_qkvz' in name:
# GDN layer: [key*2 + value*2, hidden] → qkv + z
prefix = hf_name.replace('in_proj_qkvz.weight', '')
split_at = KEY_DIM * 2 + VALUE_DIM
hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
elif 'in_proj_ba' in name:
# GDN layer: [num_v_heads*2, hidden] → b + a
prefix = hf_name.replace('in_proj_ba.weight', '')
hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
elif 'qkv_proj' in name:
# Full attention: [q + k + v, hidden] → separate
prefix = hf_name.replace('qkv_proj.weight', '')
hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
elif 'gate_up_proj' in name:
# MLP: [intermediate*2, hidden] → gate + up
prefix = hf_name.replace('gate_up_proj.weight', '')
hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
else:
# Pass through unchanged
hf_params[hf_name] = tensor
return hf_params
# ---------------------------------------------------------------------------
# Safetensors file handling
# ---------------------------------------------------------------------------
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
"""Map tensor names to safetensors filenames.
For sharded models, reads model.safetensors.index.json.
For single-file models, returns empty dict (default to model.safetensors).
"""
index_path = model_dir / "model.safetensors.index.json"
if not index_path.exists():
return {}
with open(index_path) as f:
index = json.load(f)
return dict(index.get("weight_map", {}))
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
"""Parse safetensors file header.
Returns (data_start_offset, header_dict).
Header dict maps tensor names to metadata including 'data_offsets'.
"""
header_size = struct.unpack('<Q', data[:8])[0]
header = json.loads(bytes(data[8:8 + header_size]))
return 8 + header_size, header
# ---------------------------------------------------------------------------
# Block-level diffing and sync
# ---------------------------------------------------------------------------
def sync_tensor_to_mmap(
mm: mmap.mmap,
name: str,
tensor: torch.Tensor,
data_start: int,
offsets: List[int],
block_size: int,
) -> Tuple[int, int]:
"""Sync a single tensor to mmap'd file using block-level diffing.
Returns (bytes_compared, bytes_changed).
"""
start = data_start + offsets[0]
end = data_start + offsets[1]
disk_len = end - start
# Transfer tensor to CPU and get raw bytes
# Use .detach() to avoid autograd overhead, .contiguous() for memory layout
try:
live_bytes = tensor.detach().contiguous().cpu().numpy().tobytes()
except Exception as e:
logger.warning(f"Failed to transfer {name} to CPU: {e}")
return 0, 0
if len(live_bytes) != disk_len:
logger.warning(
f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
)
return 0, 0
# Block-level diff: compare and write only changed blocks
compared = 0
changed = 0
offset = 0
while offset < disk_len:
block_end = min(offset + block_size, disk_len)
block_len = block_end - offset
disk_block = mm[start + offset:start + block_end]
live_block = live_bytes[offset:block_end]
compared += block_len
if disk_block != live_block:
mm[start + offset:start + block_end] = live_block
changed += block_len
offset = block_end
return compared, changed
def sync_file(
file_path: Path,
tensors: Dict[str, torch.Tensor],
block_size: int,
) -> Tuple[int, int, int, int]:
"""Sync tensors to a single safetensors file.
Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
"""
with open(file_path, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0)
try:
data_start, header = parse_safetensors_header(memoryview(mm))
total_compared = 0
total_changed = 0
found = 0
missing = 0
for name, tensor in tensors.items():
if name == "__metadata__":
continue
if name not in header:
missing += 1
continue
found += 1
meta = header[name]
offsets = meta['data_offsets']
compared, changed = sync_tensor_to_mmap(
mm, name, tensor, data_start, offsets, block_size
)
total_compared += compared
total_changed += changed
# Flush changes to disk
if total_changed > 0:
mm.flush()
return total_compared, total_changed, found, missing
finally:
mm.close()
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
"""Load vLLM weight tensors from CUDA IPC handles.
The handles file is written by vllm_export_hook.py on vLLM startup.
Each handle can be used to reconstruct a tensor pointing to vLLM's
GPU memory no copy, direct access.
"""
handles = torch.load(handles_path, weights_only=False)
# Skip metadata entry
handles.pop('__metadata__', None)
weights = {}
for name, info in handles.items():
func, args = info['handle']
try:
weights[name] = func(*args)
except Exception as e:
logger.warning(f"Failed to reconstruct {name}: {e}")
return weights
def checkpoint_sync(
model_dir: str,
handles_path: str = DEFAULT_HANDLES_PATH,
block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
"""Sync live GPU weights to model safetensors files.
This is the main entry point. Call this after training steps
or periodically to checkpoint weights without full serialization.
Args:
model_dir: Directory containing safetensors files
handles_path: Path to vLLM weight IPC handles file
block_size: Block size for diffing (default 4KB)
Returns:
Dict with sync statistics:
- total_compared: bytes compared
- total_changed: bytes actually written
- files_changed: list of modified filenames
- tensors_synced: number of tensors processed
- tensors_missing: tensors not found in safetensors
"""
model_dir = Path(model_dir)
if not Path(handles_path).exists():
raise FileNotFoundError(
f"Weight handles not found: {handles_path}. "
"Is vLLM running with the export hook?"
)
# Step 1: Load live weights from GPU via IPC
logger.info("Loading live weights from GPU...")
vllm_weights = load_vllm_weights(handles_path)
logger.info(f" Loaded {len(vllm_weights)} vLLM tensors")
# Step 2: Convert to HF naming/layout
hf_weights = vllm_to_hf_tensors(vllm_weights)
logger.info(f" Converted to {len(hf_weights)} HF tensors")
# Step 3: Map tensors to safetensors files
weight_map = read_safetensors_index(model_dir)
by_file: Dict[str, Dict[str, torch.Tensor]] = {}
unmapped = []
for name, tensor in hf_weights.items():
filename = weight_map.get(name)
if filename is None:
# Single-file model or missing from index
if (model_dir / "model.safetensors").exists():
filename = "model.safetensors"
else:
unmapped.append(name)
continue
by_file.setdefault(filename, {})[name] = tensor
if unmapped:
logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...")
# Step 4: Sync each file
total_compared = 0
total_changed = 0
total_found = 0
total_missing = 0
files_changed = []
for filename in sorted(by_file.keys()):
tensors = by_file[filename]
file_path = model_dir / filename
if not file_path.exists():
logger.warning(f" File not found: {filename}")
total_missing += len(tensors)
continue
compared, changed, found, missing = sync_file(file_path, tensors, block_size)
total_compared += compared
total_changed += changed
total_found += found
total_missing += missing
if changed > 0:
files_changed.append(filename)
logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")
# Summary
if total_changed == 0:
logger.info("No changes - model files are up to date")
else:
pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
logger.info(
f"Synced: {total_changed / 1e6:.2f} MB changed / "
f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
)
if total_missing > 0:
logger.warning(f" {total_missing} tensors not found in safetensors files")
return {
"total_compared": total_compared,
"total_changed": total_changed,
"files_changed": files_changed,
"tensors_synced": total_found,
"tensors_missing": total_missing,
}
# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
"""Print diagnostic info about weight name mappings.
Useful for debugging mismatches between vLLM and safetensors names.
"""
model_dir = Path(model_dir)
# Load and convert vLLM weights
vllm_weights = load_vllm_weights(handles_path)
hf_weights = vllm_to_hf_tensors(vllm_weights)
hf_names = set(hf_weights.keys())
# Read safetensors index
weight_map = read_safetensors_index(model_dir)
disk_names = set(weight_map.keys())
# If single-file model, parse that file's header
if not disk_names:
st_path = model_dir / "model.safetensors"
if st_path.exists():
with open(st_path, 'rb') as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
_, header = parse_safetensors_header(memoryview(mm))
disk_names = {k for k in header.keys() if k != "__metadata__"}
mm.close()
print(f"vLLM tensors (raw): {len(vllm_weights)}")
print(f"HF tensors (converted): {len(hf_names)}")
print(f"Disk tensors: {len(disk_names)}")
print()
in_both = hf_names & disk_names
only_hf = hf_names - disk_names
only_disk = disk_names - hf_names
print(f"Matched: {len(in_both)}")
print(f"Only in HF (won't sync): {len(only_hf)}")
print(f"Only on disk (not updated): {len(only_disk)}")
if only_hf:
print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
if only_disk:
print(f"\nSample disk-only: {sorted(only_disk)[:5]}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser(
description="Sync live GPU weights to safetensors files"
)
subparsers = parser.add_subparsers(dest="command", help="Command")
# sync command
sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
sync_parser.add_argument(
"--model-dir", required=True,
help="Directory containing safetensors files"
)
sync_parser.add_argument(
"--handles", default=DEFAULT_HANDLES_PATH,
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
)
sync_parser.add_argument(
"--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
)
sync_parser.add_argument(
"-v", "--verbose", action="store_true",
help="Verbose output"
)
# diagnose command
diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
diag_parser.add_argument(
"--model-dir", required=True,
help="Directory containing safetensors files"
)
diag_parser.add_argument(
"--handles", default=DEFAULT_HANDLES_PATH,
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
)
args = parser.parse_args()
if args.command is None:
parser.print_help()
sys.exit(1)
logging.basicConfig(
level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
format='%(message)s'
)
try:
if args.command == "sync":
result = checkpoint_sync(args.model_dir, args.handles, args.block_size)
print(json.dumps(result, indent=2))
elif args.command == "diagnose":
diagnose(args.model_dir, args.handles)
except FileNotFoundError as e:
logger.error(str(e))
sys.exit(1)
except Exception as e:
logger.exception(f"Failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

View file

@ -1,240 +0,0 @@
"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.
Patches vLLM's build_app() to add /train route. The actual training runs
in a dedicated subprocess (training_worker.py) to avoid blocking the
event loop and to keep training work isolated from vLLM internals.
"""
import asyncio
import logging
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
import zmq
import zmq.asyncio
from fastapi import APIRouter, FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter()
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
# Global state for subprocess management
_worker_process: subprocess.Popen | None = None
_zmq_context: zmq.asyncio.Context | None = None
_zmq_socket: zmq.asyncio.Socket | None = None
_initialized: bool = False
class TrainRequest(BaseModel):
training_data: dict[str, Any] # {"samples": [...], "config": {...}}
class TrainResponse(BaseModel):
job_id: str
status: str
training_samples: int
loss_history: list[float]
def _start_worker_subprocess():
"""Start the training worker subprocess."""
global _worker_process
if _worker_process is not None and _worker_process.poll() is None:
return # Still running
# Start worker as subprocess using script path
worker_script = Path(__file__).parent / 'training_worker.py'
_worker_process = subprocess.Popen(
[sys.executable, str(worker_script)],
env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
)
logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")
# Give it a moment to bind the socket
import time
time.sleep(0.5)
def _ensure_initialized():
"""Ensure subprocess is running and ZMQ socket is connected."""
global _zmq_context, _zmq_socket, _initialized
if _initialized:
return
# Start worker if needed
_start_worker_subprocess()
# Create async ZMQ context and socket
_zmq_context = zmq.asyncio.Context()
_zmq_socket = _zmq_context.socket(zmq.REQ)
_zmq_socket.connect(DEFAULT_ZMQ_ADDR)
# Set timeout for recv
_zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training
_initialized = True
logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
"""Send request to worker and wait for response."""
_ensure_initialized()
# ZMQ async send/recv
await _zmq_socket.send_json(request)
response = await _zmq_socket.recv_json()
return response
@router.post("/train")
async def handle_train(request: TrainRequest):
"""Handle training request - forwards to training subprocess."""
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={"error": f"Training not available: {e}"},
status_code=503,
)
try:
training_data = request.training_data
samples = training_data.get("samples", [])
config = training_data.get("config", {})
if not samples:
return JSONResponse(
content={"error": "No training samples provided"},
status_code=400,
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info(f"Starting training job {job_id} with {len(samples)} samples")
# Forward to worker
response = await _send_request({
'type': 'train',
'samples': samples,
'config': config,
})
if 'error' in response:
return JSONResponse(
content={"error": response['error']},
status_code=500,
)
logger.info(
f"Training job {job_id} completed, "
f"final loss: {response['loss_history'][-1]:.4f}"
)
return JSONResponse(content={
"job_id": job_id,
"status": response['status'],
"training_samples": response['training_samples'],
"loss_history": response['loss_history'],
})
except zmq.Again:
logger.error("Training request timed out")
return JSONResponse(
content={"error": "Training request timed out"},
status_code=504,
)
except Exception as e:
logger.exception(f"Training failed: {e}")
return JSONResponse(
content={"error": str(e)},
status_code=500,
)
@router.post("/checkpoint")
async def handle_checkpoint():
"""Trigger checkpoint sync to disk."""
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={"error": f"Training not available: {e}"},
status_code=503,
)
try:
response = await _send_request({'type': 'checkpoint'})
if 'error' in response:
return JSONResponse(
content={"error": response['error']},
status_code=500,
)
return JSONResponse(content=response)
except Exception as e:
logger.exception(f"Checkpoint failed: {e}")
return JSONResponse(
content={"error": str(e)},
status_code=500,
)
@router.get("/train/status")
async def handle_status():
"""Get training worker status."""
try:
_ensure_initialized()
except Exception as e:
return JSONResponse(
content={
"status": "unavailable",
"error": str(e),
},
status_code=503,
)
try:
response = await _send_request({'type': 'status'})
return JSONResponse(content=response)
except Exception as e:
return JSONResponse(
content={
"status": "error",
"error": str(e),
},
status_code=500,
)
def attach_router(app: FastAPI):
"""Attach training router to FastAPI app."""
app.include_router(router)
logger.info("Training router attached")
def _patch_api_server():
"""Patch vLLM's build_app to include our training router."""
from vllm.entrypoints.openai import api_server
original_build_app = api_server.build_app
def patched_build_app(*args, **kwargs):
app = original_build_app(*args, **kwargs)
attach_router(app)
return app
api_server.build_app = patched_build_app
logger.info("API server patched for /train endpoint")

View file

@ -1,323 +0,0 @@
"""Training subprocess - handles Apollo training and checkpoint sync.
Long-lived process that:
1. Loads IPC handles from vLLM's exported weights
2. Creates HF model with views into vLLM's GPU memory
3. Handles training requests via ZMQ
4. Handles checkpoint sync requests
5. Persists Apollo optimizer state between calls
Communicates with the API server's /train endpoint via ZMQ REP socket.
"""
import logging
import os
import signal
import sys
from pathlib import Path
from typing import Any
# Handle running as script vs module
if __name__ == '__main__' and __package__ is None:
# Running as script - add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
__package__ = 'apollo_plugin'
import torch
import torch.nn as nn
import zmq
from .checkpoint_sync import checkpoint_sync
from .optimizer import Apollo
from .weight_mapping import load_hf_model_with_vllm_weights
logger = logging.getLogger(__name__)
DEFAULT_RANK = 64
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
class TrainingWorker:
"""Long-lived training worker process."""
def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
self.zmq_addr = zmq_addr
self.model: nn.Module | None = None
self.optimizer: Apollo | None = None
self.model_path: str | None = None
self._running = True
def _create_model_wrapper(self) -> nn.Module:
"""Create HF model wrapper with views into vLLM's GPU memory."""
if not os.path.exists(HANDLE_PATH):
raise FileNotFoundError(
f"Weight handles not found: {HANDLE_PATH}. "
"Is vLLM running with the export hook?"
)
handles = torch.load(HANDLE_PATH, weights_only=False)
# Extract metadata
metadata = handles.pop('__metadata__', {})
self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
if not self.model_path:
raise ValueError(
"Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
)
# Reconstruct tensors from IPC handles
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
model.train()
return model
def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
"""Get existing optimizer or create new one."""
if self.optimizer is not None:
return self.optimizer
# Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
apollo_params, standard_params = [], []
for p in self.model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
if not groups:
raise ValueError("No trainable parameters found")
self.optimizer = Apollo(
groups,
lr=config.get('lr', 1e-5),
rank=config.get('rank', DEFAULT_RANK),
betas=tuple(config.get('betas', (0.9, 0.999))),
eps=config.get('eps', 1e-8),
weight_decay=config.get('weight_decay', 0.01),
warmup_steps=config.get('warmup_steps', 0),
scale=config.get('scale'),
proj_refresh=config.get('proj_refresh', 200),
norm_growth_limit=config.get('norm_growth_limit', 1.01),
)
# Restore state if exists
if os.path.exists(OPTIMIZER_STATE_PATH):
try:
state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
self.optimizer.load_state_dict(state)
logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
except Exception as e:
logger.warning(f"Could not restore optimizer state: {e}")
logger.info(
f"Optimizer: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
)
return self.optimizer
def _save_optimizer_state(self):
"""Save optimizer state for persistence."""
if self.optimizer is not None:
torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH)
logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")
def _run_training(
self,
samples: list[dict[str, Any]],
config: dict[str, Any],
) -> list[float]:
"""Run Apollo training on the given samples."""
optimizer = self._get_or_create_optimizer(config)
loss_history = []
for i, sample in enumerate(samples):
ctx_ids = sample['context_ids']
cont_ids = sample['continuation_ids']
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
outputs = self.model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = self.model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(
f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
)
return loss_history
def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a training request."""
samples = request.get('samples', [])
config = request.get('config', {})
if not samples:
return {'error': 'No training samples provided'}
try:
loss_history = self._run_training(samples, config)
return {
'status': 'completed',
'training_samples': len(samples),
'loss_history': loss_history,
}
except Exception as e:
logger.exception(f"Training failed: {e}")
return {'error': str(e)}
def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a checkpoint sync request."""
if not self.model_path:
return {'error': 'Model path not set'}
try:
self._save_optimizer_state()
result = checkpoint_sync(self.model_path)
return {
'status': 'completed',
'total_changed': result['total_changed'],
'files_changed': result['files_changed'],
}
except Exception as e:
logger.exception(f"Checkpoint sync failed: {e}")
return {'error': str(e)}
def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
"""Handle a status request."""
return {
'status': 'ready',
'model_loaded': self.model is not None,
'optimizer_loaded': self.optimizer is not None,
'model_path': self.model_path,
'optimizer_state_mb': (
self.optimizer.state_size_bytes() / 1e6
if self.optimizer else 0
),
}
def run(self):
"""Main loop - listen for requests and handle them."""
# Set up signal handlers
def handle_signal(signum, frame):
logger.info(f"Received signal {signum}, shutting down...")
self._running = False
signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)
# Set up ZMQ socket first so API server can connect
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(self.zmq_addr)
logger.info(f"Training worker listening on {self.zmq_addr}")
# Create HF model wrapper with views into vLLM's GPU memory
logger.info("Connecting to vLLM weights via IPC handles...")
try:
self.model = self._create_model_wrapper()
logger.info("HF model wrapper ready (views into vLLM GPU memory)")
except Exception as e:
logger.error(f"Failed to connect to vLLM weights: {e}")
logger.info("Will retry on first training request")
# Set socket timeout so we can check _running flag
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
while self._running:
try:
message = socket.recv_json()
except zmq.Again:
# Timeout, check _running and continue
continue
request_type = message.get('type', 'train')
logger.info(f"Received {request_type} request")
# Ensure model is loaded
if self.model is None and request_type != 'status':
try:
self.model = self._create_model_wrapper()
except Exception as e:
socket.send_json({'error': f'Model not loaded: {e}'})
continue
# Dispatch request
if request_type == 'train':
response = self._handle_train(message)
elif request_type == 'checkpoint':
response = self._handle_checkpoint(message)
elif request_type == 'status':
response = self._handle_status(message)
else:
response = {'error': f'Unknown request type: {request_type}'}
socket.send_json(response)
# Cleanup
logger.info("Saving optimizer state before shutdown...")
self._save_optimizer_state()
socket.close()
context.term()
logger.info("Training worker shut down")
def main():
"""Entry point for running as a subprocess."""
logging.basicConfig(
level=logging.INFO,
format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
)
zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
worker = TrainingWorker(zmq_addr)
worker.run()
if __name__ == '__main__':
main()

454
training/apollo_worker.py Executable file
View file

@ -0,0 +1,454 @@
#!/usr/bin/env python3
"""
Apollo Mini Training Daemon
This daemon:
1. Listens over HTTPS for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference
Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
"""
import asyncio
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum
import torch
import torch.nn as nn
from aiohttp import web
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')
class TrainingStatus(Enum):
PENDING = "pending"
PAUSING_VLLM = "pausing_vllm"
TRAINING = "training"
SAVING_CHECKPOINT = "saving_checkpoint"
RESUMING_VLLM = "resuming_vllm"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class TrainingJob:
job_id: str
status: TrainingStatus
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
model_path: Optional[str] = None
checkpoint_path: Optional[str] = None
training_samples: int = 0
loss_history: List[float] = field(default_factory=list)
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
'job_id': self.job_id,
'status': self.status.value,
'created_at': self.created_at.isoformat(),
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'model_path': self.model_path,
'checkpoint_path': self.checkpoint_path,
'training_samples': self.training_samples,
'loss_history': self.loss_history,
'error': self.error,
}
class ApolloWorker:
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
self.config = self._load_config(config_path)
self.jobs: Dict[str, TrainingJob] = {}
self.vllm_paused = False
self.app = web.Application()
self._setup_routes()
def _load_config(self, config_path: str) -> Dict[str, Any]:
"""Load configuration from file or use defaults."""
default_config = {
'host': '0.0.0.0',
'port': 8080,
'vllm_socket': '/tmp/vllm_control.sock',
'model_path': '/home/ubuntu/models/Qwen3.5-27B',
'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
'max_training_samples': 100,
'learning_rate': 1e-5,
'batch_size': 1,
}
if os.path.exists(config_path):
with open(config_path, 'r') as f:
user_config = json.load(f)
default_config.update(user_config)
Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
return default_config
def _setup_routes(self):
"""Setup HTTP routes."""
self.app.router.add_post('/train', self.handle_train_request)
self.app.router.add_get('/status/{job_id}', self.handle_status_request)
self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
self.app.router.add_get('/health', self.handle_health_check)
async def handle_health_check(self, request: web.Request) -> web.Response:
"""Health check endpoint."""
return web.json_response({
'status': 'healthy',
'vllm_paused': self.vllm_paused,
'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
})
async def handle_train_request(self, request: web.Request) -> web.Response:
"""Handle training request from poc-agent."""
try:
data = await request.json()
# Validate required fields
if 'training_data' not in data:
return web.json_response(
{'error': 'Missing training_data field'},
status=400
)
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
job = TrainingJob(
job_id=job_id,
status=TrainingStatus.PENDING,
created_at=datetime.now(),
model_path=self.config['model_path']
)
self.jobs[job_id] = job
# Start training in background
asyncio.create_task(self.execute_training(job, data))
return web.json_response({
'job_id': job_id,
'status': 'accepted',
'message': 'Training job started'
})
except Exception as e:
logger.error(f"Error handling train request: {e}")
return web.json_response(
{'error': str(e)},
status=500
)
async def handle_status_request(self, request: web.Request) -> web.Response:
"""Get training job status."""
job_id = request.match_info['job_id']
if job_id not in self.jobs:
return web.json_response(
{'error': 'Job not found'},
status=404
)
job = self.jobs[job_id]
return web.json_response(job.to_dict())
async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
"""List available checkpoints."""
checkpoint_dir = Path(self.config['checkpoint_dir'])
checkpoints = []
if checkpoint_dir.exists():
for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
checkpoints.append({
'filename': checkpoint_file.name,
'path': str(checkpoint_file),
'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
'size': checkpoint_file.stat().st_size
})
return web.json_response({'checkpoints': checkpoints})
async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
"""Execute the training pipeline."""
try:
logger.info(f"Starting training job {job.job_id}")
job.started_at = datetime.now()
# Step 1: Pause vLLM
job.status = TrainingStatus.PAUSING_VLLM
logger.info("Pausing vLLM...")
await self.pause_vllm()
self.vllm_paused = True
# Step 2: Load model and prepare for training
job.status = TrainingStatus.TRAINING
logger.info("Loading model and preparing for training...")
# Load model (this would be the actual Qwen3.5-27B model)
# For now, we'll use a placeholder
model = await self.load_model_for_training()
# Step 3: Run APOLLO-Mini training
logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
# Extract training samples
samples = training_data['samples']
job.training_samples = len(samples)
# Run training loop
loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
job.loss_history = loss_history
# Step 4: Save checkpoint
job.status = TrainingStatus.SAVING_CHECKPOINT
logger.info("Saving checkpoint...")
checkpoint_path = await self.save_checkpoint(model, job)
job.checkpoint_path = checkpoint_path
# Step 5: Resume vLLM
job.status = TrainingStatus.RESUMING_VLLM
logger.info("Resuming vLLM...")
await self.resume_vllm()
self.vllm_paused = False
# Mark job as completed
job.status = TrainingStatus.COMPLETED
job.completed_at = datetime.now()
logger.info(f"Training job {job.job_id} completed successfully")
except Exception as e:
logger.error(f"Training job {job.job_id} failed: {e}")
job.status = TrainingStatus.FAILED
job.error = str(e)
job.completed_at = datetime.now()
# Try to resume vLLM if it was paused
if self.vllm_paused:
try:
await self.resume_vllm()
self.vllm_paused = False
except Exception as resume_error:
logger.error(f"Failed to resume vLLM after training error: {resume_error}")
async def pause_vllm(self):
"""Pause vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/pause_generation",
json={"mode": "keep", "clear_cache": False},
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM paused")
except Exception as e:
logger.warning(f"Failed to pause vLLM: {e}")
async def resume_vllm(self):
"""Resume vLLM inference via HTTP API."""
import aiohttp as aio
url = self.config.get('vllm_url', 'http://localhost:8000')
try:
async with aio.ClientSession() as session:
async with session.post(
f"{url}/resume_generation",
timeout=aio.ClientTimeout(total=10),
) as resp:
resp.raise_for_status()
logger.info("vLLM resumed")
except Exception as e:
logger.warning(f"Failed to resume vLLM: {e}")
async def load_model_for_training(self) -> nn.Module:
"""Load HF model with weights pointing to vLLM's GPU memory.
Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
views (narrowing merged weights into separate q/k/v/z etc.), and
constructs the HF model around those views. No weight copying
all parameters share vLLM's GPU memory.
"""
handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
model_path = self.config['model_path']
# Import vLLM weights via CUDA IPC
logger.info(f"Importing vLLM weights from {handle_path}")
handles = torch.load(handle_path, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args = info['handle']
vllm_params[name] = func(*args)
logger.info(f"Imported {len(vllm_params)} parameters")
# Map vLLM merged layout → HF separate layout (views, no copies)
from weight_mapping import load_hf_model_with_vllm_weights
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
logger.info("HF model constructed with vLLM weight views")
return model
async def run_apollo_training(self, model: nn.Module,
samples: List[Dict[str, str]],
config: Dict[str, Any]) -> List[float]:
"""Run Apollo-Mini training on conversation decision points."""
from apollo_mini import Apollo
from transformers import AutoTokenizer
lr = config.get('learning_rate', self.config['learning_rate'])
tokenizer = AutoTokenizer.from_pretrained(
self.config['model_path'], trust_remote_code=True)
# Build parameter groups (Apollo for 2D+, standard for small/1D)
apollo_params, standard_params = [], []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= 2:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
rank = config.get('apollo_rank', 1)
optimizer = Apollo(groups, lr=lr, rank=rank)
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
f"{len(standard_params)} standard, "
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
loss_history = []
for i, sample in enumerate(samples):
context = sample.get('context', '')
continuation = sample.get('continuation', '')
# Tokenize
ctx_ids = tokenizer.encode(context, add_special_tokens=True)
cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
all_ids = ctx_ids + cont_ids
context_len = len(ctx_ids)
input_ids = torch.tensor([all_ids], device='cuda:0')
optimizer.zero_grad()
# Context-frozen forward pass
with torch.no_grad():
# Forward through context (no gradients)
outputs = model(input_ids[:, :context_len], use_cache=True)
past_kv = outputs.past_key_values
# Decision tokens with gradients
with torch.enable_grad():
outputs = model(
input_ids[:, context_len:],
past_key_values=past_kv,
use_cache=False,
)
logits = outputs.logits # [1, cont_len, vocab]
# Shift: predict next token from each position
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
loss.backward()
optimizer.step()
loss_val = loss.item()
loss_history.append(loss_val)
logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
logger.info(f"Training done: {len(samples)} examples, "
f"final loss={loss_history[-1]:.4f}")
return loss_history
async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
"""Save model checkpoint in HuggingFace safetensors format."""
from safetensors.torch import save_file
import shutil
checkpoint_dir = Path(self.config['checkpoint_dir'])
date_str = datetime.now().strftime('%Y-%m-%d')
out_dir = checkpoint_dir / date_str
out_dir.mkdir(parents=True, exist_ok=True)
# Save weights
tensors = {name: p.data.contiguous().cpu()
for name, p in model.named_parameters()}
save_path = out_dir / "model.safetensors"
save_file(tensors, str(save_path))
# Copy config files
config_dir = Path(self.config['model_path'])
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
'special_tokens_map.json']:
src = config_dir / f
if src.exists():
shutil.copy2(src, out_dir / f)
# Save training metadata
meta = {
'job_id': job.job_id,
'training_samples': job.training_samples,
'loss_history': job.loss_history,
'timestamp': datetime.now().isoformat(),
}
with open(out_dir / 'training-meta.json', 'w') as f:
json.dump(meta, f, indent=2)
# Update latest symlink
latest = checkpoint_dir / 'latest'
if latest.is_symlink():
latest.unlink()
latest.symlink_to(date_str)
size_gb = save_path.stat().st_size / 1e9
logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
return str(out_dir)
async def run(self):
"""Run the daemon."""
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, self.config['host'], self.config['port'])
await site.start()
logger.info("Apollo Worker is running")
# Keep running
while True:
await asyncio.sleep(3600) # Sleep for an hour
def main():
worker = ApolloWorker()
asyncio.run(worker.run())
if __name__ == '__main__':
main()

View file

@ -0,0 +1,12 @@
[package]
name = "apollo-checkpoint"
version = "0.1.0"
edition = "2024"
[dependencies]
memmap2 = "0.9"
safetensors = "0.5"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
anyhow = "1"
clap = { version = "4", features = ["derive"] }

View file

@ -0,0 +1,265 @@
// apollo-checkpoint — Sync live GPU weights back to model files on disk.
//
// mmaps the model's safetensors files, reads live weights from GPU via
// Python helper (CUDA IPC handles), compares block by block, and memcpys
// only changed regions back into the mmap. For small behavioral training
// steps, this turns a 54GB write into a few hundred MB.
//
// The model files on disk are the checkpoint. No separate checkpoint
// directory — just keep the model up to date.
//
// Usage:
// apollo-checkpoint sync \
// --handles /tmp/vllm_weight_handles.pt \
// --model-dir /path/to/Qwen3.5-27B
//
// Runs every 10 minutes via cron. Daily rsync to moria.
use anyhow::{Context, Result, bail};
use clap::{Parser, Subcommand};
use memmap2::MmapMut;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
#[derive(Parser)]
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
struct Cli {
#[command(subcommand)]
command: Cmd,
}
#[derive(Subcommand)]
enum Cmd {
/// Sync live GPU weights back to model safetensors files
Sync {
/// Path to vLLM weight IPC handles
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
handles: PathBuf,
/// Model directory containing safetensors files
#[arg(long)]
model_dir: PathBuf,
/// Block size for diffing (bytes)
#[arg(long, default_value_t = 4096)]
block_size: usize,
},
}
/// Dump live GPU weights to a flat binary file, ordered by safetensors
/// file and offset to match the on-disk layout.
///
/// Returns a map of (safetensors filename, tensor name) → raw bytes.
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
let dump_path = output_dir.join(".live_dump.bin");
let index_path = output_dir.join(".live_dump.json");
let status = Command::new("python3")
.arg("-c")
.arg(format!(r#"
import torch, json
handles = torch.load("{handles}", weights_only=False)
index = {{}}
offset = 0
with open("{dump}", "wb") as f:
for name in sorted(handles.keys()):
info = handles[name]
func, args = info["handle"]
tensor = func(*args)
data = tensor.contiguous().cpu().numpy().tobytes()
f.write(data)
index[name] = {{"offset": offset, "size": len(data)}}
offset += len(data)
with open("{index}", "w") as f:
json.dump(index, f)
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
"#,
handles = handles_path.display(),
dump = dump_path.display(),
index = index_path.display(),
))
.status()
.context("Failed to run Python weight dump")?;
if !status.success() {
bail!("Python weight dump failed");
}
let index_str = fs::read_to_string(&index_path)?;
let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
let dump_data = fs::read(&dump_path)?;
let mut result = HashMap::new();
for (name, entry) in &index {
result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
}
// Clean up temp files
let _ = fs::remove_file(&dump_path);
let _ = fs::remove_file(&index_path);
Ok(result)
}
#[derive(serde::Deserialize)]
struct DumpEntry {
offset: usize,
size: usize,
}
/// Read the safetensors index to map parameter names to files.
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
let index_path = model_dir.join("model.safetensors.index.json");
if !index_path.exists() {
// Single file model
return Ok(HashMap::new());
}
let index_str = fs::read_to_string(&index_path)?;
let index: serde_json::Value = serde_json::from_str(&index_str)?;
let weight_map = index["weight_map"]
.as_object()
.context("No weight_map in index")?;
let mut result = HashMap::new();
for (name, file) in weight_map {
result.insert(name.clone(), file.as_str().unwrap().to_string());
}
Ok(result)
}
/// Sync changed blocks from live weights into a mmap'd safetensors file.
/// Returns (total_bytes_compared, bytes_changed).
fn sync_tensors_to_file(
file_path: &Path,
tensors: &[(String, Vec<u8>)],
block_size: usize,
) -> Result<(usize, usize)> {
use safetensors::SafeTensors;
let file = fs::OpenOptions::new()
.read(true)
.write(true)
.open(file_path)
.with_context(|| format!("Failed to open {}", file_path.display()))?;
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
// Parse safetensors header to find tensor offsets
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
let header_json: serde_json::Value =
serde_json::from_slice(&mmap[8..8 + header_size])?;
let data_start = 8 + header_size;
let mut total_compared = 0usize;
let mut total_changed = 0usize;
for (name, live_data) in tensors {
let meta = match header_json.get(name) {
Some(m) => m,
None => {
eprintln!(" Warning: {} not found in {}", name, file_path.display());
continue;
}
};
let offsets = meta["data_offsets"].as_array().unwrap();
let start = data_start + offsets[0].as_u64().unwrap() as usize;
let end = data_start + offsets[1].as_u64().unwrap() as usize;
let disk_data = &mmap[start..end];
if disk_data.len() != live_data.len() {
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
name, disk_data.len(), live_data.len());
continue;
}
// Diff block by block, memcpy only changed blocks
let mut offset = 0;
while offset < disk_data.len() {
let block_end = (offset + block_size).min(disk_data.len());
total_compared += block_end - offset;
if disk_data[offset..block_end] != live_data[offset..block_end] {
mmap[start + offset..start + block_end]
.copy_from_slice(&live_data[offset..block_end]);
total_changed += block_end - offset;
}
offset = block_end;
}
}
mmap.flush()?;
Ok((total_compared, total_changed))
}
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
if !handles.exists() {
bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
handles.display());
}
eprintln!("Dumping live weights from GPU...");
let live_weights = dump_live_weights(&handles, &model_dir)?;
eprintln!(" {} tensors dumped", live_weights.len());
// Map parameter names to safetensors files
let weight_map = read_safetensors_index(&model_dir)?;
// Group tensors by safetensors file
let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
for (name, data) in live_weights {
let file = weight_map
.get(&name)
.cloned()
.unwrap_or_else(|| "model.safetensors".to_string());
by_file.entry(file).or_default().push((name, data));
}
let mut total_compared = 0usize;
let mut total_changed = 0usize;
for (filename, tensors) in &by_file {
let file_path = model_dir.join(filename);
if !file_path.exists() {
eprintln!(" Warning: {} not found, skipping", filename);
continue;
}
let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
total_compared += compared;
total_changed += changed;
if changed > 0 {
eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6);
}
}
if total_changed == 0 {
eprintln!("No changes — model files are up to date");
} else {
eprintln!(
"Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
total_changed as f64 / 1e6,
total_compared as f64 / 1e9,
total_changed as f64 / total_compared as f64 * 100.0,
);
}
Ok(())
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Cmd::Sync { handles, model_dir, block_size } => {
cmd_sync(handles, model_dir, block_size)
}
}
}

View file

@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""Export vLLM's live model weight IPC handles for the training process.
Connects to a running vLLM instance, iterates over model parameters,
and exports CUDA IPC handles that allow another process to access the
same GPU memory without copying.
Usage:
# Run after vLLM is serving:
python3 export_weights.py --output /tmp/vllm_weight_handles.pt
# Or via vLLM's API (future):
curl -X POST http://localhost:8000/export_weights
"""
import argparse
import sys
import torch
from pathlib import Path
def export_from_model(model, output_path: str):
"""Export IPC handles for all model parameters."""
from torch.multiprocessing.reductions import reduce_tensor
handles = {}
total_bytes = 0
for name, param in model.named_parameters():
handle = reduce_tensor(param.data)
handles[name] = {
'handle': handle,
'shape': list(param.shape),
'dtype': str(param.dtype),
}
param_bytes = param.nelement() * param.element_size()
total_bytes += param_bytes
torch.save(handles, output_path)
n_params = len(handles)
print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)")
print(f"Saved to {output_path}")
return handles
def main():
parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
help="Output path for IPC handles")
parser.add_argument("--vllm-pid", type=int, default=None,
help="vLLM worker PID (auto-detected if not specified)")
args = parser.parse_args()
# For now: load the model directly and export.
# TODO: connect to running vLLM process instead.
print("Note: This currently loads the model separately.")
print("Full integration will export from the running vLLM process.")
print()
# Detect model path from running vLLM
import subprocess
result = subprocess.run(
['ps', 'aux'], capture_output=True, text=True
)
model_path = None
for line in result.stdout.split('\n'):
if 'vllm' in line and '--model' in line:
parts = line.split()
for i, p in enumerate(parts):
if p == '--model' and i + 1 < len(parts):
model_path = parts[i + 1]
break
# Also check model_tag format
if p.startswith('--model='):
model_path = p.split('=', 1)[1]
break
if model_path:
print(f"Detected vLLM model: {model_path}")
else:
print("Could not detect running vLLM model. Specify manually.")
sys.exit(1)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,215 @@
#!/usr/bin/env python3
"""First real Apollo training step — ready for Kent to run.
This script:
1. Imports vLLM's live weights via CUDA IPC
2. Constructs HF model with shared memory views
3. Runs ONE forward+backward on a real training example
4. Applies ONE Apollo optimizer step
5. Verifies vLLM still works after the update
The training example is from March 30: Kent said "use vLLM's code"
and the model should have accepted instead of suggesting alternatives.
Usage:
source ~/training-env/bin/activate
python3 first_training_step.py [--dry-run]
"""
import argparse
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
from apollo_mini import Apollo
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dry-run', action='store_true',
help="Run forward+backward but don't apply the optimizer step")
parser.add_argument('--lr', type=float, default=1e-5,
help="Learning rate (default: 1e-5 = conservative)")
parser.add_argument('--rank', type=int, default=256)
parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
args = parser.parse_args()
print("=== First Apollo Training Step ===\n")
# 1. Import vLLM weights
print("1. Importing vLLM weights via CUDA IPC...")
handles = torch.load(args.handles, weights_only=False)
vllm_params = {}
for name, info in handles.items():
func, args_h = info['handle']
vllm_params[name] = func(*args_h)
print(f" {len(vllm_params)} parameters imported")
# 2. Map to HF layout
print("2. Mapping to HF layout (zero-copy views)...")
hf_params = vllm_to_hf_views(vllm_params)
# 3. Create HF model
print("3. Creating HF model with shared weights...")
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
with torch.device('meta'):
model = Qwen3_5ForCausalLM(config.text_config)
replaced = 0
for name, param in list(model.named_parameters()):
if name in hf_params:
parts = name.split('.')
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1],
nn.Parameter(hf_params[name], requires_grad=True))
replaced += 1
print(f" {replaced} parameters replaced with vLLM memory views")
# 4. Load tokenizer
print("4. Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
# 5. Construct training example
print("5. Constructing training example...")
# Context: conversation where Kent says to use vLLM's code
# Target: the response that accepts the direction
context = (
"<|im_start|>user\n"
"vllm has a fused kernel already, right?<|im_end|>\n"
"<|im_start|>assistant\n"
"Yeah — vLLM has `gdn_attention_core` which is a custom op "
"that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
"<|im_start|>user\n"
"Why wouldn't we just use that?<|im_end|>\n"
"<|im_start|>assistant\n"
)
# The CORRECT response (accept direction, don't suggest alternatives)
continuation = (
"We should. Let me pull in their kernel and wire it into "
"our Rust orchestration. Which file should I start with?"
)
context_ids = tokenizer.encode(context, add_special_tokens=False)
continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
all_ids = context_ids + continuation_ids
context_len = len(context_ids)
print(f" Context: {context_len} tokens")
print(f" Continuation: {len(continuation_ids)} tokens")
print(f" Total: {len(all_ids)} tokens")
input_ids = torch.tensor([all_ids], device='cuda:0')
# 6. Initialize Apollo optimizer
print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
apollo_params = []
standard_params = []
for p in model.parameters():
if p.requires_grad:
if p.ndim >= 2 and min(p.shape) >= args.rank:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({'params': apollo_params})
if standard_params:
groups.append({'params': standard_params})
optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")
# 7. Forward pass
print("7. Forward pass...")
model.train()
optimizer.zero_grad()
# Context-frozen: no grad for context, grad for continuation
with torch.no_grad():
ctx_output = model(input_ids[:, :context_len], use_cache=True)
past_kv = ctx_output.past_key_values
with torch.enable_grad():
output = model(input_ids[:, context_len:],
past_key_values=past_kv, use_cache=False)
logits = output.logits
# Shift for next-token prediction
shift_logits = logits[:, :-1].contiguous()
shift_labels = input_ids[:, context_len + 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
)
print(f" Loss: {loss.item():.4f}")
# 8. Backward pass
print("8. Backward pass...")
loss.backward()
n_grads = sum(1 for p in model.parameters() if p.grad is not None)
print(f" {n_grads} parameters have gradients")
# 9. Apollo step (or dry run)
if args.dry_run:
print("\n9. DRY RUN — skipping optimizer step")
print(" (run without --dry-run to apply the update)")
else:
print("9. Applying Apollo optimizer step...")
# Record a few weight norms before
sample_norms_before = {}
for name, p in model.named_parameters():
if 'layers.0.' in name and p.grad is not None:
sample_norms_before[name] = p.data.norm().item()
optimizer.step()
# Check weight changes
print(" Weight changes (layer 0):")
for name, before in sample_norms_before.items():
p = dict(model.named_parameters())[name]
after = p.data.norm().item()
delta = abs(after - before)
pct = delta / before * 100 if before > 0 else 0
print(f" {name}: {before:.6f}{after:.6f}{pct:.4f}%)")
optimizer.zero_grad()
# 10. Verify vLLM still works
print("\n10. Verifying vLLM still serves...")
import subprocess
result = subprocess.run(
['curl', '-s', '--max-time', '30',
'-X', 'POST', 'http://localhost:8000/v1/chat/completions',
'-H', 'Content-Type: application/json',
'-H', 'Authorization: Bearer bcachefs-agents-2026',
'-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
capture_output=True, text=True, timeout=45
)
if result.returncode == 0 and 'choices' in result.stdout:
print(" vLLM still serving ✓")
else:
print(" WARNING: vLLM may not be responding")
print(f" stdout: {result.stdout[:200]}")
print("\n=== COMPLETE ===")
if args.dry_run:
print("Run without --dry-run to apply the first real training step.")
else:
print("First Apollo training step applied to vLLM's live weights.")
print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
if __name__ == '__main__':
main()

View file

@ -1,29 +0,0 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "apollo-plugin"
version = "0.1.0"
description = "Apollo training plugin for vLLM"
requires-python = ">=3.10"
dependencies = [
"torch",
"aiohttp",
"safetensors",
"pyzmq",
]
[project.optional-dependencies]
dev = ["pytest"]
[project.entry-points."vllm.general_plugins"]
apollo = "apollo_plugin:register"
[project.scripts]
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
apollo-worker = "apollo_plugin.training_worker:main"
[tool.setuptools.packages.find]
where = ["."]
include = ["apollo_plugin*"]

View file

@ -0,0 +1,18 @@
#!/bin/bash
# Start vLLM with Apollo weight export hook.
#
# The hook patches vLLM's model runner to export CUDA IPC handles
# after loading, so the Apollo training process can share the same
# GPU memory.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
exec python3 -c "
import sys
sys.path.insert(0, '$SCRIPT_DIR')
import vllm_export_hook # patches model runner before vLLM loads
sys.argv = ['vllm'] + sys.argv[1:]
from vllm.entrypoints.cli.main import main
main()
" serve "$@"

269
training/train.py Normal file
View file

@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""Nightly training process for Apollo-Mini fine-tuning.
Imports vLLM's model weights via CUDA IPC, runs context-frozen
training on flagged conversation segments, saves updated checkpoint.
Usage:
python3 train.py \
--weights /tmp/vllm_weight_handles.pt \
--examples training-examples.jsonl \
--checkpoint-dir checkpoints/ \
--lr 1e-5
"""
import argparse
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path
import torch
from safetensors.torch import save_file
from apollo_mini import ApolloMini
def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
"""Import weight tensors from CUDA IPC handles."""
handles = torch.load(handle_path, weights_only=False)
params = {}
for name, info in handles.items():
func, args = info['handle']
tensor = func(*args)
params[name] = tensor
return params
def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
"""Split parameters into Apollo-Mini and standard groups.
Apollo-Mini needs 2D+ matrices with min dimension >= 2.
Small tensors (norms, biases, conv1d 3D weights) use standard Adam.
"""
apollo_params = []
standard_params = []
for name, p in params.items():
p.requires_grad_(True)
if p.ndim >= 2 and min(p.shape) >= 2:
apollo_params.append(p)
else:
standard_params.append(p)
groups = []
if apollo_params:
groups.append({
'params': apollo_params,
'name': 'apollo',
})
if standard_params:
groups.append({
'params': standard_params,
'name': 'standard',
})
n_apollo = sum(p.nelement() for p in apollo_params)
n_standard = sum(p.nelement() for p in standard_params)
print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")
return groups
def forward_pass(params, input_ids, context_len, device):
"""Run context-frozen forward pass.
Args:
params: dict of name -> tensor (shared with vLLM)
input_ids: full sequence [1, seq_len]
context_len: number of context tokens (no gradient)
device: CUDA device
Returns:
logits for decision tokens, target ids for loss
"""
# TODO: Build proper forward model matching vLLM's weight layout.
# For now this is a placeholder — the real implementation needs
# to replicate vLLM's model architecture (merged projections,
# GDN recurrence, full attention, MLP) using the shared weights.
raise NotImplementedError(
"Forward model not yet implemented. "
"Need to build a model that matches vLLM's merged weight layout "
"(MergedColumnParallelLinear for qkvz/ba/gate_up, "
"RowParallelLinear for out_proj/down) and computes the same "
"forward pass with autograd enabled."
)
def save_checkpoint(params: dict[str, torch.Tensor],
checkpoint_dir: str,
config_path: str = None):
"""Save model checkpoint in HuggingFace safetensors format.
Saves weights split across shards matching the original model layout,
archives the previous checkpoint, and updates the 'latest' symlink.
"""
date_str = datetime.now().strftime("%Y-%m-%d")
out_dir = Path(checkpoint_dir) / date_str
out_dir.mkdir(parents=True, exist_ok=True)
# Save all weights in a single safetensors file for now.
# TODO: split across shards matching HF model index for large models.
tensors = {}
for name, param in params.items():
tensors[name] = param.data.contiguous().cpu()
save_path = out_dir / "model.safetensors"
save_file(tensors, str(save_path))
print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")
# Copy config files if provided
if config_path:
import shutil
config_dir = Path(config_path)
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
'special_tokens_map.json', 'generation_config.json']:
src = config_dir / f
if src.exists():
shutil.copy2(src, out_dir / f)
# Update latest symlink
latest = Path(checkpoint_dir) / "latest"
if latest.is_symlink():
latest.unlink()
latest.symlink_to(date_str)
print(f"Updated {latest} -> {date_str}")
return str(out_dir)
def train_step(params, example, optimizer, device, log_entries):
"""Run one training step on a single example.
Args:
params: dict of name -> tensor
example: dict with 'input_ids', 'context_len', 'target_ids'
optimizer: ApolloMini instance
device: CUDA device
log_entries: list to append log dicts to
Returns:
loss value
"""
optimizer.zero_grad()
input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
context_len = example['context_len']
# Forward pass (context frozen, decision tokens with grad)
logits, targets = forward_pass(params, input_ids, context_len, device)
# Cross-entropy loss on decision tokens
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.shape[-1]),
targets.view(-1),
)
# Backward
loss.backward()
# Compute gradient stats before optimizer step
total_grad_norm = 0.0
for p in params.values():
if p.grad is not None:
total_grad_norm += p.grad.norm().item() ** 2
total_grad_norm = total_grad_norm ** 0.5
# Optimizer step
optimizer.step()
# Log
log_entries.append({
'example_id': example.get('id', 'unknown'),
'loss': loss.item(),
'grad_norm': total_grad_norm,
'timestamp': datetime.now().isoformat(),
})
return loss.item()
def main():
parser = argparse.ArgumentParser(description="Apollo-Mini training")
parser.add_argument("--weights", required=True,
help="Path to exported weight IPC handles")
parser.add_argument("--examples", required=True,
help="Path to training examples JSONL")
parser.add_argument("--checkpoint-dir", default="checkpoints",
help="Directory for saving checkpoints")
parser.add_argument("--config-path", default=None,
help="Path to model config files (for checkpoint)")
parser.add_argument("--lr", type=float, default=1e-5,
help="Learning rate")
parser.add_argument("--warmup-steps", type=int, default=10,
help="Learning rate warmup steps")
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--dry-run", action="store_true",
help="Load weights and validate, don't train")
args = parser.parse_args()
print(f"Apollo-Mini Training")
print(f" weights: {args.weights}")
print(f" examples: {args.examples}")
print(f" lr: {args.lr}")
print()
# Import weights
print("Importing weights via CUDA IPC...")
params = import_weights(args.weights)
print(f" {len(params)} parameters imported")
# Make parameter groups
param_groups = make_param_groups(params)
# Initialize optimizer
optimizer = ApolloMini(param_groups, lr=args.lr,
weight_decay=args.weight_decay,
warmup_steps=args.warmup_steps)
print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
if args.dry_run:
print("\nDry run — weights imported and validated successfully.")
return
# Load training examples
examples = []
with open(args.examples) as f:
for line in f:
examples.append(json.loads(line))
print(f" {len(examples)} training examples")
# Training loop
log_entries = []
print(f"\nTraining...")
t0 = time.time()
for i, example in enumerate(examples):
loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")
elapsed = time.time() - t0
print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
# Save checkpoint
print("\nSaving checkpoint...")
save_checkpoint(params, args.checkpoint_dir, args.config_path)
# Save training log
date_str = datetime.now().strftime("%Y-%m-%d")
log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl"
with open(log_path, 'w') as f:
for entry in log_entries:
f.write(json.dumps(entry) + '\n')
print(f"Training log: {log_path}")
if __name__ == '__main__':
main()

View file

@ -0,0 +1,175 @@
"""Training example construction and tokenization.
Takes raw conversation context + improved continuation, produces
tokenized tensors ready for context-frozen forward+backward.
"""
import json
from dataclasses import dataclass, field
from pathlib import Path
import torch
from transformers import AutoTokenizer
@dataclass
class TrainingExample:
"""A single training example for context-frozen training."""
id: str
context: str # conversation up to decision point
continuation: str # the better response
reason: str = "" # why this is a training target
memories: list[str] = field(default_factory=list) # memories that were in context
# Computed after tokenization
input_ids: torch.Tensor | None = None
context_len: int = 0
total_len: int = 0
def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
"""Tokenize context + continuation into training-ready tensors.
The chat template is applied to make the token distribution
match what the model sees during inference.
"""
# Build messages for context (everything up to the decision)
# The context should already be in chat format
context_ids = tokenizer.encode(self.context, add_special_tokens=False)
continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)
self.context_len = len(context_ids)
self.total_len = len(context_ids) + len(continuation_ids)
if self.total_len > max_len:
# Truncate context from the left, keep continuation intact
excess = self.total_len - max_len
context_ids = context_ids[excess:]
self.context_len = len(context_ids)
self.total_len = len(context_ids) + len(continuation_ids)
all_ids = context_ids + continuation_ids
self.input_ids = torch.tensor(all_ids, device=device)
return self
def to_dict(self) -> dict:
return {
'id': self.id,
'context': self.context,
'continuation': self.continuation,
'reason': self.reason,
'memories': self.memories,
'context_len': self.context_len,
'total_len': self.total_len,
}
@classmethod
def from_dict(cls, d: dict) -> 'TrainingExample':
return cls(
id=d['id'],
context=d['context'],
continuation=d['continuation'],
reason=d.get('reason', ''),
memories=d.get('memories', []),
)
def load_examples(path: str) -> list[TrainingExample]:
"""Load training examples from JSONL file."""
examples = []
with open(path) as f:
for line in f:
if line.strip():
examples.append(TrainingExample.from_dict(json.loads(line)))
return examples
def save_examples(examples: list[TrainingExample], path: str):
"""Save training examples to JSONL file."""
with open(path, 'w') as f:
for ex in examples:
f.write(json.dumps(ex.to_dict()) + '\n')
class ExampleTokenizer:
"""Handles tokenization with the model's chat template.
Applies the same chat template that vLLM uses during inference,
so the token distribution matches what the model expects.
"""
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True)
def prepare_example(self, example: TrainingExample,
max_len: int = 8192,
device: str = "cuda:0") -> TrainingExample:
"""Tokenize an example using the chat template.
For proper training, the context should be formatted exactly
as vLLM would format it with chat template applied.
"""
# Apply chat template to get the exact token sequence
# the model would see during inference
#
# Context: everything up to the decision point
# Continuation: the improved response
#
# We tokenize them separately to know where context ends
# and continuation begins.
context_ids = self.tokenizer.encode(
example.context, add_special_tokens=True)
continuation_ids = self.tokenizer.encode(
example.continuation, add_special_tokens=False)
example.context_len = len(context_ids)
example.total_len = len(context_ids) + len(continuation_ids)
if example.total_len > max_len:
excess = example.total_len - max_len
context_ids = context_ids[excess:]
example.context_len = len(context_ids)
example.total_len = example.context_len + len(continuation_ids)
all_ids = context_ids + continuation_ids
example.input_ids = torch.tensor(all_ids, device=device)
return example
def prepare_from_messages(self, example_id: str,
messages: list[dict],
decision_idx: int,
better_response: str,
reason: str = "",
memories: list[str] | None = None,
max_len: int = 8192,
device: str = "cuda:0") -> TrainingExample:
"""Build a training example from a chat message list.
Args:
example_id: unique identifier
messages: list of {"role": ..., "content": ...} dicts
decision_idx: index of the assistant message to replace
better_response: the improved response text
reason: why this is a training target
memories: memory keys that were in context
max_len: maximum sequence length
device: target device
Returns:
Tokenized TrainingExample
"""
# Context: all messages up to (not including) the decision
context_messages = messages[:decision_idx]
context_text = self.tokenizer.apply_chat_template(
context_messages, tokenize=False, add_generation_prompt=True)
# Build the example
example = TrainingExample(
id=example_id,
context=context_text,
continuation=better_response,
reason=reason,
memories=memories or [],
)
return self.prepare_example(example, max_len=max_len, device=device)

View file

@ -1,12 +1,17 @@
"""Monkey-patch vLLM to export weight IPC handles on startup.
Usage install the apollo_plugin package:
Usage add to start_vllm.sh BEFORE the vllm serve command:
pip install -e /path/to/training
export VLLM_PLUGINS=vllm_export_hook
vllm serve Qwen/Qwen3.5-27B ...
Then vLLM auto-discovers and loads via entry point. Or filter:
Or use Python to launch vLLM with the hook:
VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
python3 -c "
import vllm_export_hook # installs the patch
from vllm.entrypoints.openai.api_server import run_server
run_server(...)
"
The hook patches vLLM's model runner to export IPC handles after
model loading completes. The handles are saved to a file that the
@ -20,7 +25,7 @@ from pathlib import Path
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
def export_model_weights(model, model_path: str | None = None):
def export_model_weights(model):
"""Export CUDA IPC handles for all model parameters."""
from torch.multiprocessing.reductions import reduce_tensor
@ -38,12 +43,6 @@ def export_model_weights(model, model_path: str | None = None):
}
total_bytes += param.nelement() * param.element_size()
# Include metadata for training worker
handles['__metadata__'] = {
'model_path': model_path,
'num_params': len(handles),
}
torch.save(handles, HANDLE_PATH)
print(f"[apollo] Exported {len(handles)} weight handles "
f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
@ -64,11 +63,14 @@ def _patch_model_runner():
def patched_load(self, *args, **kwargs):
result = original_load(self, *args, **kwargs)
try:
model_path = self.vllm_config.model_config.model
export_model_weights(self.model_runner.model, model_path)
export_model_weights(self.model_runner.model)
except Exception as e:
print(f"[apollo] Failed to export weights: {e}")
return result
gpu_worker.Worker.load_model = patched_load
print("[apollo] Weight export hook installed")
# Auto-install when imported
_patch_model_runner()