// training.rs — Memory importance scoring via /v1/score // // Three scoring modes, all built on the same call_score() primitive: // // score_memories() — Full N×M matrix (memories × responses) for the // debug screen. Expensive: N+1 API calls. // // memory_score() — Single memory importance. Scores the 50 messages // after it was surfaced, with/without that memory. // 2 API calls. // // finetune_score() — Identifies training candidates. Scores recent // messages with all memories stripped. Responses // with high divergence depend on memories the model // hasn't internalized. 2 API calls. use crate::agent::api::ApiClient; use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role}; use crate::agent::tokenizer; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); // ── Message building ──────────────────────────────────────────── /// What to filter when building the message array for scoring. #[allow(dead_code)] enum Filter<'a> { None, SkipIndex(usize), SkipKey(&'a str), SkipAllMemories, } fn is_memory(node: &AstNode) -> bool { matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. })) } fn memory_key(node: &AstNode) -> Option<&str> { match node { AstNode::Leaf(leaf) => match leaf.body() { NodeBody::Memory { key, .. } => Some(key), _ => None, }, _ => None, } } fn is_assistant(node: &AstNode) -> bool { matches!(node, AstNode::Branch { role: Role::Assistant, .. }) } /// Build a token ID array for a scoring call. /// /// Includes all sections up to and including conversation entries in /// `range`, with `filter` applied to conversation entries. fn build_token_ids( context: &ContextState, range: std::ops::Range, filter: Filter, ) -> Vec { use crate::agent::context::Ast; let mut ids = Vec::new(); for node in context.system() { ids.extend(node.token_ids()); } // Identity nodes can be filtered by key for scoring for node in context.identity() { let skip = match &filter { Filter::SkipKey(key) => memory_key(node) == Some(*key), Filter::SkipAllMemories => is_memory(node), _ => false, }; if !skip { ids.extend(node.token_ids()); } } for node in context.journal() { ids.extend(node.token_ids()); } let entries = context.conversation(); for i in range { let node = &entries[i]; let skip = match &filter { Filter::None => false, Filter::SkipIndex(idx) => i == *idx, Filter::SkipKey(key) => memory_key(node) == Some(*key), Filter::SkipAllMemories => is_memory(node), }; if skip { continue; } ids.extend(node.token_ids()); } ids } // ── Score API ─────────────────────────────────────────────────── #[derive(serde::Deserialize)] struct ScoreResult { total_logprob: f64, } #[derive(serde::Deserialize)] struct ScoreResponse { scores: Vec, } fn http_client() -> crate::agent::api::http::HttpClient { crate::agent::api::http::HttpClient::builder() .timeout(SCORE_TIMEOUT) .build() } async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, prompt: &[u32], priority: Option, ) -> anyhow::Result> { let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); let mut body = serde_json::json!({ "model": client.model, "prompt": prompt, "logprobs": 1, }); if let Some(p) = priority { body["priority"] = serde_json::json!(p); } let response = http .send_json("POST", &url, &[ ("authorization", &auth), ], &body) .await?; let status = response.status(); let body: serde_json::Value = response.json().await?; if !status.is_success() { let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); anyhow::bail!("score API HTTP {}: {}", status, msg); } if let Some(err) = body.get("error").and_then(|e| e.as_str()) { anyhow::bail!("score API error: {}", err); } let result: ScoreResponse = serde_json::from_value(body) .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; Ok(result.scores) } /// Compute per-position logprob divergence: how much worse the model /// scores each response without something vs with it. fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec { baseline.iter().enumerate() .map(|(i, base)| { let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob); (base.total_logprob - without_lp).max(0.0) }) .collect() } /// Score two message sets and return total divergence. async fn score_divergence( http: &crate::agent::api::http::HttpClient, client: &ApiClient, context: &ContextState, range: std::ops::Range, filter: Filter<'_>, priority: Option, ) -> anyhow::Result<(Vec, Vec)> { let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?; let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } // ── Full matrix scoring (debug screen) ────────────────────────── /// Score how important each memory is to the conversation (full matrix). pub async fn score_memories( client: &ApiClient, agent: &std::sync::Arc, ) -> anyhow::Result<()> { // Collect memory keys and response indices under a brief lock let (memory_keys, response_indices) = { let ctx = agent.context.lock().await; // Include identity nodes and conversation memories let mut keys: Vec = ctx.identity().iter() .chain(ctx.conversation().iter()) .filter_map(|node| memory_key(node).map(String::from)) .collect(); keys.dedup(); let resp: Vec = ctx.conversation().iter().enumerate() .filter(|(_, node)| is_assistant(node)) .map(|(i, _)| i) .collect(); (keys, resp) }; if memory_keys.is_empty() || response_indices.is_empty() { return Ok(()); } let total = memory_keys.len(); dbglog!("[scoring-full] starting: {} memories × {} responses", total, response_indices.len()); let http = http_client(); let activity = crate::agent::start_activity(agent, "scoring: baseline").await; let baseline_tokens = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None) }; let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?; dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len()); for (mem_idx, key) in memory_keys.iter().enumerate() { activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await; dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key); let tokens = { let ctx = agent.context.lock().await; build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key)) }; let row = match call_score(&http, client, &tokens, Some(5)).await { Ok(without) => { let divs = divergence(&baseline, &without); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); dbglog!("[scoring-full] {}/{}: {} max_div={:.3}", mem_idx + 1, total, key, max_div); divs } Err(e) => { dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}", mem_idx + 1, total, key, e); vec![0.0; baseline.len()] } }; // Write this memory's scores to the live AST nodes { let mut ctx = agent.context.lock().await; let mut set_count = 0; for (resp_idx, &idx) in response_indices.iter().enumerate() { if idx >= ctx.conversation().len() { continue; } let node = &mut ctx.conversation_mut()[idx]; if let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node { if let Some(&score) = row.get(resp_idx) { if score > 0.01 { memory_scores.insert(key.clone(), score); set_count += 1; } else { memory_scores.remove(key.as_str()); } } } } dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count); } agent.state.lock().await.changed.notify_one(); } Ok(()) } /// Find the entry index after `start` that contains the Nth assistant response. /// Returns (end_index, true) if N responses were found, (entries.len(), false) if not. fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) { let mut count = 0; for i in start..entries.len() { if is_assistant(&entries[i]) { count += 1; if count >= n { return (i + 1, true); } } } (entries.len(), false) } // ── Single memory scoring ─────────────────────────────────────── /// Score how important a single memory is to the conversation. /// /// Scores the 50 messages after the memory was surfaced — the window /// where it could have influenced responses. Returns the sum of /// divergence, or 0.0 if the memory isn't in the conversation. pub async fn score_memory( context: &ContextState, key: &str, client: &ApiClient, ) -> anyhow::Result { const RESPONSE_WINDOW: usize = 50; let entries = context.conversation(); let first_pos = match entries.iter().position(|node| memory_key(node) == Some(key)) { Some(p) => p, None => return Ok(0.0), }; let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW); let range = first_pos..end; if !entries[range.clone()].iter().any(|node| is_assistant(node)) { return Ok(0.0); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await?; Ok(divs.iter().sum()) } // ── Background memory scoring ─────────────────────────────────── /// Score memories in the conversation that are due for re-scoring. /// /// Checks the graph for each memory's last_scored timestamp. Scores /// nodes that haven't been scored within `max_age_secs`, oldest first. /// Updates the graph weight (EWMA) and last_scored after each. /// /// Returns the number of nodes scored and their (key, score) pairs. pub async fn score_memories_incremental( context: &ContextState, max_age_secs: i64, response_window: usize, client: &ApiClient, agent: &std::sync::Arc, mut on_score: F, ) -> anyhow::Result where F: FnMut(String, f64) -> Fut, Fut: std::future::Future, { let now = chrono::Utc::now().timestamp(); // Collect unique memory keys with their first position let mut seen = std::collections::HashSet::new(); let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored) let store_arc = crate::hippocampus::access_local()?; { let store = &*store_arc; // Identity nodes always score at position 0; conversation nodes at their index let identity_nodes = context.identity().iter().map(|n| (0, n)); let conv_nodes = context.conversation().iter().enumerate(); for (pos, node) in identity_nodes.chain(conv_nodes) { if let Some(key) = memory_key(node) { if !seen.insert(key.to_owned()) { continue; } let last_scored = store.get_node(key) .ok() .flatten() .map(|n| n.last_scored) .unwrap_or(0); if now - last_scored >= max_age_secs { candidates.push((pos, key.to_owned(), last_scored)); } } } } // Score oldest-first candidates.sort_by_key(|&(_, _, last)| last); let http = http_client(); let mut scored = 0; let entries = context.conversation(); let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum(); let token_cutoff = total_tokens * 60 / 100; // Precompute cumulative token position for each entry let mut cumulative: Vec = Vec::with_capacity(entries.len()); let mut running = 0; for e in entries { running += e.tokens(); cumulative.push(running); } let total = candidates.len(); dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total); let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await; for (pos, key, _) in &candidates { // Only score memories in the first 60% of the conversation by tokens — // recent memories don't have enough responses to evaluate yet. let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens); if cum > token_cutoff { dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff); continue; } let (end, _) = nth_response_end(context.conversation(), *pos, response_window); let range = *pos..end; if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) { dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end); continue; } activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await; match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await { Ok((divs, _)) => { let n_responses = divs.len(); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); dbglog!( "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses, ); on_score(key.clone(), max_div).await; scored += 1; } Err(e) => { dbglog!( "[scoring] {} FAILED: {:#}", key, e, ); } } } Ok(scored) } // ── Fine-tuning scoring ───────────────────────────────────────── /// Score which recent responses are candidates for fine-tuning. /// /// Removes all memories and scores the most recent `count` messages. /// Responses with high divergence depend on memories the model hasn't /// internalized — these are fine-tuning candidates. /// /// Returns (entry_index, divergence) pairs, sorted by divergence descending. pub async fn score_finetune( context: &ContextState, count: usize, client: &ApiClient, ) -> anyhow::Result> { let entries = context.conversation(); let range = entries.len().saturating_sub(count)..entries.len(); let response_positions: Vec = range.clone() .filter(|&i| is_assistant(&entries[i])) .collect(); if response_positions.is_empty() { return Ok(Vec::new()); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories, Some(5)).await?; let mut results: Vec<(usize, f64)> = response_positions.iter() .enumerate() .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0))) .collect(); results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(results) } /// Enriched finetune candidate with context for review. #[derive(Clone, Debug)] pub struct FinetuneCandidate { pub entry_idx: usize, pub divergence: f64, pub response_text: String, /// Token IDs for context (everything before the response). pub context_ids: Vec, /// Token IDs for the response (what we're training on). pub continuation_ids: Vec, /// What the model would have said without memories (if generated). pub alternate_text: Option, /// Timestamp in millis for tracking trained status. pub timestamp_ms: i64, } /// Score and enrich finetune candidates with full context. /// /// Returns candidates ready for review, with context/continuation token IDs /// already computed for sending to /finetune. pub async fn score_finetune_candidates( context: &ContextState, count: usize, client: &ApiClient, min_divergence: f64, ) -> anyhow::Result> { let scores = score_finetune(context, count, client).await?; let entries = context.conversation(); let mut candidates = Vec::new(); let trained = load_trained(); for (entry_idx, divergence) in scores { if divergence < min_divergence { continue; } let node = &entries[entry_idx]; // Get timestamp and skip if already trained let timestamp_ms = match node_timestamp_ms(node) { Some(ts) => { if trained.contains(&ts) { continue; // Already trained, skip } ts } None => continue, // No timestamp, skip }; // Extract response text let response_text = match node { AstNode::Branch { children, .. } => { children.iter() .filter_map(|c| match c { AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()), _ => None, }) .collect::>() .join("") } _ => continue, }; // Build token IDs: context = everything before response, continuation = response let context_ids = build_token_ids(context, 0..entry_idx, Filter::None); let continuation_ids: Vec = node.token_ids().into_iter().collect(); candidates.push(FinetuneCandidate { entry_idx, divergence, response_text, context_ids, continuation_ids, alternate_text: None, timestamp_ms, }); } // Generate alternates if enabled if alternates_enabled() && !candidates.is_empty() { for candidate in &mut candidates { match generate_alternate(context, candidate.entry_idx, client).await { Ok(text) => candidate.alternate_text = Some(text), Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e), } } } Ok(candidates) } /// Generate what the model would say without memories for a given entry. async fn generate_alternate( context: &ContextState, entry_idx: usize, client: &ApiClient, ) -> anyhow::Result { use crate::agent::api::{SamplingParams, StreamToken}; // Build context tokens without memories, up to the response let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories); // Add assistant turn start prompt.push(tokenizer::IM_START); prompt.extend(tokenizer::encode("assistant\n")); // Generate completion let sampling = SamplingParams { temperature: 0.6, top_p: 0.95, top_k: 20, }; let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5)); let mut tokens = Vec::new(); while let Some(tok) = rx.recv().await { match tok { StreamToken::Token(id) => tokens.push(id), StreamToken::Done { .. } => break, StreamToken::Error(e) => anyhow::bail!("generation error: {}", e), } } Ok(tokenizer::decode(&tokens)) } // ── Finetune config and persistence ───────────────────────────── use std::path::PathBuf; use std::collections::HashSet; const FINETUNE_ALTERNATES_FILE: &str = ".consciousness/cache/finetune-alternates"; const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json"; fn alternates_path() -> PathBuf { dirs::home_dir().unwrap_or_default().join(FINETUNE_ALTERNATES_FILE) } fn trained_path() -> PathBuf { dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE) } /// Check if alternate response generation is enabled. pub fn alternates_enabled() -> bool { alternates_path().exists() } /// Toggle alternate response generation and persist the setting. pub fn set_alternates(enabled: bool) { let path = alternates_path(); if enabled { if let Some(parent) = path.parent() { let _ = std::fs::create_dir_all(parent); } let _ = std::fs::write(&path, ""); } else { let _ = std::fs::remove_file(&path); } } /// Load set of trained response timestamps (millis since epoch). pub fn load_trained() -> HashSet { let path = trained_path(); match std::fs::read_to_string(&path) { Ok(content) => serde_json::from_str(&content).unwrap_or_default(), Err(_) => HashSet::new(), } } /// Mark a response as trained by its timestamp. pub fn mark_trained(timestamp_ms: i64) { let mut trained = load_trained(); trained.insert(timestamp_ms); let path = trained_path(); if let Some(parent) = path.parent() { let _ = std::fs::create_dir_all(parent); } if let Ok(json) = serde_json::to_string(&trained) { let _ = std::fs::write(&path, json); } } /// Get timestamp in millis from an AstNode (for Branch, uses first child). pub fn node_timestamp_ms(node: &AstNode) -> Option { let ts = match node { AstNode::Leaf(leaf) => leaf.timestamp(), AstNode::Branch { children, .. } => { children.first()?.leaf()?.timestamp() } }?; Some(ts.timestamp_millis()) } // ── Training API ──────────────────────────────────────────────── /// Training sample for /train endpoint. #[derive(serde::Serialize)] struct TrainingSample { context_ids: Vec, continuation_ids: Vec, } /// Data needed to send a training sample. pub struct TrainData { pub context_ids: Vec, pub continuation_ids: Vec, pub timestamp_ms: i64, } /// Send training samples to the server. /// /// Returns job_id on success, marks each sample as trained. pub async fn send_to_train( samples: Vec, client: &ApiClient, ) -> anyhow::Result { if samples.is_empty() { anyhow::bail!("no samples to train"); } let api_samples: Vec = samples.iter() .map(|s| TrainingSample { context_ids: s.context_ids.clone(), continuation_ids: s.continuation_ids.clone(), }) .collect(); let body = serde_json::json!({ "training_data": { "samples": api_samples, } }); let http = http_client(); let url = format!("{}/train", client.base_url()); let response = http.send_json("POST", &url, &[], &body).await?; let status = response.status(); let result: serde_json::Value = response.json().await?; if !status.is_success() { let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); anyhow::bail!("train API HTTP {}: {}", status, msg); } // Mark all samples as trained for s in &samples { mark_trained(s.timestamp_ms); } let job_id = result.get("job_id") .and_then(|j| j.as_str()) .unwrap_or("unknown") .to_string(); dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id); Ok(job_id) }