// training.rs — Memory importance scoring via /v1/score // // Three scoring modes, all built on the same call_score() primitive: // // score_memories() — Full N×M matrix (memories × responses) for the // debug screen. Expensive: N+1 API calls. // // memory_score() — Single memory importance. Scores the 50 messages // after it was surfaced, with/without that memory. // 2 API calls. // // finetune_score() — Identifies training candidates. Scores recent // messages with all memories stripped. Responses // with high divergence depend on memories the model // hasn't internalized. 2 API calls. use crate::agent::api::ApiClient; use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role}; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); // ── Message building ──────────────────────────────────────────── /// What to filter when building the message array for scoring. #[allow(dead_code)] enum Filter<'a> { None, SkipIndex(usize), SkipKey(&'a str), SkipAllMemories, } fn is_memory(node: &AstNode) -> bool { matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. })) } fn memory_key(node: &AstNode) -> Option<&str> { match node { AstNode::Leaf(leaf) => match leaf.body() { NodeBody::Memory { key, .. } => Some(key), _ => None, }, _ => None, } } fn is_assistant(node: &AstNode) -> bool { matches!(node, AstNode::Branch { role: Role::Assistant, .. }) } /// Push an AstNode as one or more JSON messages for the scoring API. fn push_api_message(node: &AstNode, msgs: &mut Vec) { match node { AstNode::Branch { role, children } => { let content: String = children.iter().map(|c| c.render()).collect(); msgs.push(serde_json::json!({ "role": role.as_str(), "content": content, })); } AstNode::Leaf(leaf) => { let role = match leaf.body() { NodeBody::ToolResult(_) => "tool", _ => "user", }; msgs.push(serde_json::json!({ "role": role, "content": leaf.body().text(), })); } } } /// Build the messages array for a scoring call. 
///
/// Always includes system prompt as prefix, then entries from `range`
/// filtered by `filter`.
fn build_messages(
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter,
) -> Vec<serde_json::Value> {
    let mut msgs = Vec::new();
    // System prompt always comes first, unfiltered.
    for node in context.system() {
        push_api_message(node, &mut msgs);
    }
    let entries = context.conversation();
    for i in range {
        let node = &entries[i];
        let skip = match &filter {
            Filter::None => false,
            Filter::SkipIndex(idx) => i == *idx,
            Filter::SkipKey(key) => memory_key(node) == Some(*key),
            Filter::SkipAllMemories => is_memory(node),
        };
        if skip {
            continue;
        }
        push_api_message(node, &mut msgs);
    }
    msgs
}

// ── Score API ───────────────────────────────────────────────────

#[derive(serde::Deserialize)]
struct ScoreResult {
    total_logprob: f64,
}

#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}

/// Fresh HTTP client with the scoring timeout applied.
fn http_client() -> crate::agent::api::http::HttpClient {
    crate::agent::api::http::HttpClient::builder()
        .timeout(SCORE_TIMEOUT)
        .build()
}

/// POST `messages` to the /score endpoint and return the per-position
/// score results.
///
/// # Errors
/// Fails on transport errors, non-2xx status, an `error` field in the
/// response body, or an unparseable response.
async fn call_score(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    messages: &[serde_json::Value],
) -> anyhow::Result<Vec<ScoreResult>> {
    let url = format!("{}/score", client.base_url());
    let auth = format!("Bearer {}", client.api_key());
    let body = serde_json::json!({
        "model": client.model,
        "messages": messages,
        "logprobs": 1,
    });
    let response = http
        .send_json("POST", &url, &[
            ("authorization", &auth),
        ], &body)
        .await?;
    let status = response.status();
    let body: serde_json::Value = response.json().await?;
    if !status.is_success() {
        let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
        anyhow::bail!("score API HTTP {}: {}", status, msg);
    }
    // Some failures come back 200 with an error field in the body.
    if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
        anyhow::bail!("score API error: {}", err);
    }
    let result: ScoreResponse = serde_json::from_value(body)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    Ok(result.scores)
}

/// Compute per-position logprob divergence: how much worse the model
/// scores each
response without something vs with it. fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec { baseline.iter().enumerate() .map(|(i, base)| { let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob); (base.total_logprob - without_lp).max(0.0) }) .collect() } /// Score two message sets and return total divergence. async fn score_divergence( http: &crate::agent::api::http::HttpClient, client: &ApiClient, context: &ContextState, range: std::ops::Range, filter: Filter<'_>, ) -> anyhow::Result<(Vec, Vec)> { let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?; let without = call_score(http, client, &build_messages(context, range, filter)).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } // ── Full matrix scoring (debug screen) ────────────────────────── /// Result of scoring one conversation's memory usage. pub struct MemoryScore { pub memory_weights: Vec<(String, f64)>, pub response_scores: Vec, /// Full matrix: divergence[memory_idx][response_idx] pub matrix: Vec>, pub memory_keys: Vec, pub response_entry_indices: Vec, } impl MemoryScore { pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; let mut result: Vec<(&str, f64)> = self.memory_keys.iter() .zip(self.matrix.iter()) .filter_map(|(key, row)| { let score = row.get(resp_idx).copied().unwrap_or(0.0); if score > 0.01 { Some((key.as_str(), score)) } else { None } }) .collect(); result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); result } } /// Score how important each memory is to the conversation (full matrix). 
pub async fn score_memories( context: &ContextState, client: &ApiClient, ) -> anyhow::Result { let mut memory_keys: Vec = context.conversation().iter() .filter_map(|node| memory_key(node).map(String::from)) .collect(); memory_keys.dedup(); let response_indices: Vec = context.conversation().iter().enumerate() .filter(|(_, node)| is_assistant(node)) .map(|(i, _)| i) .collect(); if memory_keys.is_empty() || response_indices.is_empty() { return Ok(MemoryScore { memory_weights: Vec::new(), response_scores: Vec::new(), matrix: Vec::new(), memory_keys: Vec::new(), response_entry_indices: Vec::new(), }); } let http = http_client(); let range = 0..context.conversation().len(); let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?; let total = memory_keys.len(); let mut matrix: Vec> = Vec::new(); for (mem_idx, key) in memory_keys.iter().enumerate() { dbglog!( "scoring {}/{}: {}...", mem_idx + 1, total, key, ); let msgs = build_messages(context, range.clone(), Filter::SkipKey(key)); match call_score(&http, client, &msgs).await { Ok(without) => matrix.push(divergence(&baseline, &without)), Err(e) => { dbglog!( "[training] {} FAILED: {:#}", key, e, ); matrix.push(vec![0.0; baseline.len()]); } } } let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) .map(|(key, row)| (key.clone(), row.iter().sum())) .collect(); let mut response_scores = vec![0.0; response_indices.len()]; for row in &matrix { for (j, &v) in row.iter().enumerate() { if j < response_scores.len() { response_scores[j] += v; } } } Ok(MemoryScore { memory_weights, response_scores, matrix, memory_keys, response_entry_indices: response_indices, }) } /// Find the entry index after `start` that contains the Nth assistant response. /// Returns (end_index, true) if N responses were found, (entries.len(), false) if not. 
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) { let mut count = 0; for i in start..entries.len() { if is_assistant(&entries[i]) { count += 1; if count >= n { return (i + 1, true); } } } (entries.len(), false) } // ── Single memory scoring ─────────────────────────────────────── /// Score how important a single memory is to the conversation. /// /// Scores the 50 messages after the memory was surfaced — the window /// where it could have influenced responses. Returns the sum of /// divergence, or 0.0 if the memory isn't in the conversation. pub async fn score_memory( context: &ContextState, key: &str, client: &ApiClient, ) -> anyhow::Result { const RESPONSE_WINDOW: usize = 50; let entries = context.conversation(); let first_pos = match entries.iter().position(|node| memory_key(node) == Some(key)) { Some(p) => p, None => return Ok(0.0), }; let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW); let range = first_pos..end; if !entries[range.clone()].iter().any(|node| is_assistant(node)) { return Ok(0.0); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?; Ok(divs.iter().sum()) } // ── Background memory scoring ─────────────────────────────────── /// Score memories in the conversation that are due for re-scoring. /// /// Checks the graph for each memory's last_scored timestamp. Scores /// nodes that haven't been scored within `max_age_secs`, oldest first. /// Updates the graph weight (EWMA) and last_scored after each. /// /// Returns the number of nodes scored and their (key, score) pairs. 
pub async fn score_memories_incremental( context: &ContextState, max_age_secs: i64, response_window: usize, client: &ApiClient, agent: &std::sync::Arc, mut on_score: F, ) -> anyhow::Result where F: FnMut(String, f64) -> Fut, Fut: std::future::Future, { let now = chrono::Utc::now().timestamp(); // Collect unique memory keys with their first position let mut seen = std::collections::HashSet::new(); let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored) let store = crate::hippocampus::store::Store::load().unwrap_or_default(); for (i, node) in context.conversation().iter().enumerate() { if let Some(key) = memory_key(node) { if !seen.insert(key.to_owned()) { continue; } let last_scored = store.nodes.get(key) .map(|n| n.last_scored) .unwrap_or(0); if now - last_scored >= max_age_secs { candidates.push((i, key.to_owned(), last_scored)); } } } // Score oldest-first candidates.sort_by_key(|&(_, _, last)| last); let http = http_client(); let mut scored = 0; let entries = context.conversation(); let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum(); let token_cutoff = total_tokens * 60 / 100; // Precompute cumulative token position for each entry let mut cumulative: Vec = Vec::with_capacity(entries.len()); let mut running = 0; for e in entries { running += e.tokens(); cumulative.push(running); } for (pos, key, _) in &candidates { // Only score memories in the first 70% of the conversation by tokens — // recent memories don't have enough responses to evaluate yet. 
if cumulative.get(*pos).copied().unwrap_or(total_tokens) > token_cutoff { continue; } let (end, _) = nth_response_end(context.conversation(), *pos, response_window); let range = *pos..end; if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) { continue; } let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await; match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await { Ok((divs, _)) => { let n_responses = divs.len(); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); dbglog!( "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses, ); on_score(key.clone(), max_div).await; scored += 1; } Err(e) => { dbglog!( "[scoring] {} FAILED: {:#}", key, e, ); } } } Ok(scored) } // ── Fine-tuning scoring ───────────────────────────────────────── /// Score which recent responses are candidates for fine-tuning. /// /// Removes all memories and scores the most recent `count` messages. /// Responses with high divergence depend on memories the model hasn't /// internalized — these are fine-tuning candidates. /// /// Returns (entry_index, divergence) pairs, sorted by divergence descending. pub async fn score_finetune( context: &ContextState, count: usize, client: &ApiClient, ) -> anyhow::Result> { let entries = context.conversation(); let range = entries.len().saturating_sub(count)..entries.len(); let response_positions: Vec = range.clone() .filter(|&i| is_assistant(&entries[i])) .collect(); if response_positions.is_empty() { return Ok(Vec::new()); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?; let mut results: Vec<(usize, f64)> = response_positions.iter() .enumerate() .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0))) .collect(); results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(results) }