// training.rs — Memory importance scoring via /v1/score // // Three scoring modes, all built on the same call_score() primitive: // // score_memories() — Full N×M matrix (memories × responses) for the // debug screen. Expensive: N+1 API calls. // // memory_score() — Single memory importance. Scores the 50 messages // after it was surfaced, with/without that memory. // 2 API calls. // // finetune_score() — Identifies training candidates. Scores recent // messages with all memories stripped. Responses // with high divergence depend on memories the model // hasn't internalized. 2 API calls. use crate::agent::api::ApiClient; use crate::agent::api::*; use crate::agent::context::{ConversationEntry, ContextEntry, ContextState}; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); // ── Message building ──────────────────────────────────────────── /// What to filter when building the message array for scoring. enum Filter<'a> { None, SkipIndex(usize), SkipKey(&'a str), SkipAllMemories, } /// Build the messages array for a scoring call. /// /// Always includes system prompt + context message as prefix, then /// entries from `range` filtered by `filter`. fn build_messages( context: &ContextState, range: std::ops::Range, filter: Filter, ) -> Vec { let mut msgs = Vec::new(); for e in context.system.entries() { msgs.push(serde_json::json!({"role": "system", "content": e.entry.message().content_text()})); } let ctx = context.render_context_message(); if !ctx.is_empty() { msgs.push(serde_json::json!({"role": "user", "content": ctx})); } let entries = context.conversation.entries(); for i in range { let ce = &entries[i]; let entry = &ce.entry; let skip = match &filter { Filter::None => false, Filter::SkipIndex(idx) => i == *idx, Filter::SkipKey(key) => matches!(entry, ConversationEntry::Memory { key: k, .. 
} if k == *key), Filter::SkipAllMemories => entry.is_memory(), }; if skip { continue; } let m = entry.api_message(); msgs.push(serde_json::json!({ "role": m.role_str(), "content": m.content_text(), })); } msgs } // ── Score API ─────────────────────────────────────────────────── #[derive(serde::Deserialize)] struct ScoreResult { total_logprob: f64, } #[derive(serde::Deserialize)] struct ScoreResponse { scores: Vec, } fn http_client() -> crate::agent::api::http::HttpClient { crate::agent::api::http::HttpClient::builder() .timeout(SCORE_TIMEOUT) .build() } async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, messages: &[serde_json::Value], ) -> anyhow::Result> { let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); let body = serde_json::json!({ "model": client.model, "messages": messages, "logprobs": 1, }); let response = http .send_json("POST", &url, &[ ("authorization", &auth), ], &body) .await?; let status = response.status(); let body: serde_json::Value = response.json().await?; if !status.is_success() { let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); anyhow::bail!("score API HTTP {}: {}", status, msg); } if let Some(err) = body.get("error").and_then(|e| e.as_str()) { anyhow::bail!("score API error: {}", err); } let result: ScoreResponse = serde_json::from_value(body) .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; Ok(result.scores) } /// Compute per-position logprob divergence: how much worse the model /// scores each response without something vs with it. fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec { baseline.iter().enumerate() .map(|(i, base)| { let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob); (base.total_logprob - without_lp).max(0.0) }) .collect() } /// Score two message sets and return total divergence. 
async fn score_divergence( http: &crate::agent::api::http::HttpClient, client: &ApiClient, context: &ContextState, range: std::ops::Range, filter: Filter<'_>, ) -> anyhow::Result<(Vec, Vec)> { let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?; let without = call_score(http, client, &build_messages(context, range, filter)).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } // ── Full matrix scoring (debug screen) ────────────────────────── /// Result of scoring one conversation's memory usage. pub struct MemoryScore { pub memory_weights: Vec<(String, f64)>, pub response_scores: Vec, /// Full matrix: divergence[memory_idx][response_idx] pub matrix: Vec>, pub memory_keys: Vec, pub response_entry_indices: Vec, } impl MemoryScore { pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; let mut result: Vec<(&str, f64)> = self.memory_keys.iter() .zip(self.matrix.iter()) .filter_map(|(key, row)| { let score = row.get(resp_idx).copied().unwrap_or(0.0); if score > 0.01 { Some((key.as_str(), score)) } else { None } }) .collect(); result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); result } } /// Score how important each memory is to the conversation (full matrix). pub async fn score_memories( context: &ContextState, client: &ApiClient, ) -> anyhow::Result { let mut memory_keys: Vec = context.conversation.entries().iter() .filter_map(|ce| match &ce.entry { ConversationEntry::Memory { key, .. 
} => Some(key.clone()), _ => None, }) .collect(); memory_keys.dedup(); let response_indices: Vec = context.conversation.entries().iter().enumerate() .filter(|(_, ce)| ce.entry.message().role == Role::Assistant) .map(|(i, _)| i) .collect(); if memory_keys.is_empty() || response_indices.is_empty() { return Ok(MemoryScore { memory_weights: Vec::new(), response_scores: Vec::new(), matrix: Vec::new(), memory_keys: Vec::new(), response_entry_indices: Vec::new(), }); } let http = http_client(); let range = 0..context.conversation.entries().len(); let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?; let total = memory_keys.len(); let mut matrix: Vec> = Vec::new(); for (mem_idx, key) in memory_keys.iter().enumerate() { dbglog!( "scoring {}/{}: {}...", mem_idx + 1, total, key, ); let msgs = build_messages(context, range.clone(), Filter::SkipKey(key)); match call_score(&http, client, &msgs).await { Ok(without) => matrix.push(divergence(&baseline, &without)), Err(e) => { dbglog!( "[training] {} FAILED: {:#}", key, e, ); matrix.push(vec![0.0; baseline.len()]); } } } let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) .map(|(key, row)| (key.clone(), row.iter().sum())) .collect(); let mut response_scores = vec![0.0; response_indices.len()]; for row in &matrix { for (j, &v) in row.iter().enumerate() { if j < response_scores.len() { response_scores[j] += v; } } } Ok(MemoryScore { memory_weights, response_scores, matrix, memory_keys, response_entry_indices: response_indices, }) } /// Find the entry index after `start` that contains the Nth assistant response. /// Returns (end_index, true) if N responses were found, (entries.len(), false) if not. 
fn nth_response_end(entries: &[ContextEntry], start: usize, n: usize) -> (usize, bool) { let mut count = 0; for i in start..entries.len() { if entries[i].entry.message().role == Role::Assistant { count += 1; if count >= n { return (i + 1, true); } } } (entries.len(), false) } // ── Single memory scoring ─────────────────────────────────────── /// Score how important a single memory is to the conversation. /// /// Scores the 50 messages after the memory was surfaced — the window /// where it could have influenced responses. Returns the sum of /// divergence, or 0.0 if the memory isn't in the conversation. pub async fn score_memory( context: &ContextState, key: &str, client: &ApiClient, ) -> anyhow::Result { const RESPONSE_WINDOW: usize = 50; let entries = context.conversation.entries(); let first_pos = match entries.iter().position(|ce| { matches!(&ce.entry, ConversationEntry::Memory { key: k, .. } if k == key) }) { Some(p) => p, None => return Ok(0.0), }; let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW); let range = first_pos..end; if !entries[range.clone()].iter().any(|ce| ce.entry.message().role == Role::Assistant) { return Ok(0.0); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?; Ok(divs.iter().sum()) } // ── Background memory scoring ─────────────────────────────────── /// Score memories in the conversation that are due for re-scoring. /// /// Checks the graph for each memory's last_scored timestamp. Scores /// nodes that haven't been scored within `max_age_secs`, oldest first. /// Updates the graph weight (EWMA) and last_scored after each. /// /// Returns the number of nodes scored and their (key, score) pairs. 
pub async fn score_memories_incremental(
    context: &ContextState,
    max_age_secs: i64,
    response_window: usize,
    client: &ApiClient,
    agent: &std::sync::Arc>,
) -> anyhow::Result> {
    let now = chrono::Utc::now().timestamp();

    // Collect unique memory keys with their first position
    let mut seen = std::collections::HashSet::new();
    let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
    // Best effort: a failed store load falls back to Store::default().
    // NOTE(review): if the default store has no nodes, every memory reads
    // last_scored = 0 and becomes a candidate — confirm that's intended.
    let store = crate::hippocampus::store::Store::load().unwrap_or_default();
    for (i, ce) in context.conversation.entries().iter().enumerate() {
        if let ConversationEntry::Memory { key, .. } = &ce.entry {
            if !seen.insert(key.clone()) { continue; }
            // Keys missing from the store are treated as never scored.
            let last_scored = store.nodes.get(key.as_str())
                .map(|n| n.last_scored)
                .unwrap_or(0);
            if now - last_scored >= max_age_secs {
                candidates.push((i, key.clone(), last_scored));
            }
        }
    }

    // Score oldest-first
    candidates.sort_by_key(|&(_, _, last)| last);

    let http = http_client();
    let mut results = Vec::new();

    let total_entries = context.conversation.entries().len();
    let first_quarter = total_entries / 4;

    for (pos, key, _) in &candidates {
        let (end, full_window) = nth_response_end(context.conversation.entries(), *pos, response_window);
        // Skip memories without a full window, unless they're in the
        // first quarter of the conversation (always score those).
        if !full_window && *pos >= first_quarter {
            continue;
        }

        let range = *pos..end;
        // Nothing to score if no assistant response follows the memory.
        if !context.conversation.entries()[range.clone()].iter().any(|ce| ce.entry.message().role == Role::Assistant) {
            continue;
        }

        // Named `_scoring` (not `_`) so the returned value lives until the
        // end of this iteration, i.e. across the scoring calls below.
        let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await;

        match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await {
            Ok((divs, _)) => {
                let n_responses = divs.len();
                // Records the peak per-response divergence (score_memory()
                // instead sums it).
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!(
                    "[scoring] {} max:{:.3} ({} responses)",
                    key, max_div, n_responses,
                );
                results.push((key.clone(), max_div));
            }
            Err(e) => {
                // Best effort: log and move on to the next candidate.
                dbglog!(
                    "[scoring] {} FAILED: {:#}",
                    key, e,
                );
            }
        }
    }

    Ok(results)
}

// ── Fine-tuning scoring ─────────────────────────────────────────

/// Score which recent responses are candidates for fine-tuning.
///
/// Scores the last `count` conversation entries twice: as-is, and with
/// every memory entry inside that window removed. Only the system prompt
/// and the rendered context message are prepended. Memories surfaced before
/// the window are therefore absent from both calls. Responses with high
/// divergence depend on memories the model hasn't internalized — these are
/// fine-tuning candidates.
///
/// Returns (entry_index, divergence) pairs for the assistant responses in
/// the window, sorted by divergence descending. Makes two /score calls, or
/// none (returning an empty list) if the window has no assistant response.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
) -> anyhow::Result> {
    let range = context.conversation.entries().len().saturating_sub(count)..context.conversation.entries().len();

    let response_positions: Vec = range.clone()
        .filter(|&i| context.conversation.entries()[i].entry.message().role == Role::Assistant)
        .collect();
    if response_positions.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?;

    // Pairs the i-th divergence with the i-th assistant entry; a missing
    // divergence counts as 0.0.
    // NOTE(review): assumes /score returns one result per assistant
    // message, in order — confirm.
    let mut results: Vec<(usize, f64)> = response_positions.iter()
        .enumerate()
        .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0)))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(results)
}