// training.rs — Memory importance scoring via /v1/score
//
// Three scoring modes, all built on the same call_score() primitive:
//
//   score_memories()  — Full N×M matrix (memories × responses) for the
//                       debug screen. Expensive: N+1 API calls.
//
//   score_memory()    — Single memory importance. Scores the 50 messages
//                       after it was surfaced, with/without that memory.
//                       2 API calls.
//
//   score_finetune()  — Identifies training candidates. Scores recent
//                       messages with all memories stripped. Responses
//                       with high divergence depend on memories the model
//                       hasn't internalized. 2 API calls.
//
// score_memories_incremental() drives the single-memory mode in the
// background, resuming from a cursor as the conversation grows.

use super::api::ApiClient;
use crate::agent::api::types::*;
use crate::agent::context::{ConversationEntry, ContextState};
use crate::user::ui_channel::{UiMessage, UiSender};

const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

// ── Message building ────────────────────────────────────────────

/// What to filter when building the message array for scoring.
enum Filter<'a> {
    None,
    SkipIndex(usize),
    SkipKey(&'a str),
    SkipAllMemories,
}

/// Build the messages array for a scoring call.
///
/// Always includes system prompt + context message as prefix, then
/// entries from `range` filtered by `filter`.
fn build_messages(
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter,
) -> Vec<serde_json::Value> {
    let mut msgs = vec![
        serde_json::json!({"role": "system", "content": &context.system_prompt}),
    ];
    let ctx = context.render_context_message();
    if !ctx.is_empty() {
        msgs.push(serde_json::json!({"role": "user", "content": ctx}));
    }
    for i in range {
        let entry = &context.entries[i];
        let skip = match &filter {
            Filter::None => false,
            Filter::SkipIndex(idx) => i == *idx,
            Filter::SkipKey(key) => {
                matches!(entry, ConversationEntry::Memory { key: k, .. } if k == key)
            }
            Filter::SkipAllMemories => entry.is_memory(),
        };
        if skip {
            continue;
        }
        let m = entry.api_message();
        msgs.push(serde_json::json!({
            "role": m.role_str(),
            "content": m.content_text(),
        }));
    }
    msgs
}

// ── Score API ───────────────────────────────────────────────────

#[derive(serde::Deserialize)]
struct ScoreResult {
    total_logprob: f64,
}

#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}

fn http_client() -> reqwest::Client {
    reqwest::Client::builder()
        .timeout(SCORE_TIMEOUT)
        .pool_max_idle_per_host(2)
        .build()
        .unwrap_or_default()
}

/// POST the messages to the score endpoint. Callers rely on the endpoint
/// returning one ScoreResult per assistant message, in conversation order.
async fn call_score(
    http: &reqwest::Client,
    client: &ApiClient,
    messages: &[serde_json::Value],
) -> anyhow::Result<Vec<ScoreResult>> {
    let response = http
        .post(format!("{}/score", client.base_url()))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", client.api_key()))
        .json(&serde_json::json!({
            "model": client.model,
            "messages": messages,
            "logprobs": 1,
        }))
        .send()
        .await
        .map_err(|e| if e.is_timeout() {
            anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
        } else {
            anyhow::anyhow!("score request failed: {}", e)
        })?;

    let status = response.status();
    let body: serde_json::Value = response.json().await?;
    if !status.is_success() {
        let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
        anyhow::bail!("score API HTTP {}: {}", status, msg);
    }
    if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
        anyhow::bail!("score API error: {}", err);
    }

    let result: ScoreResponse = serde_json::from_value(body)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    Ok(result.scores)
}
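// A minimal sanity check for the response shape this module deserializes.
// The sample JSON is an assumption inferred from the ScoreResponse/ScoreResult
// structs above, not a captured payload from the real endpoint.
#[cfg(test)]
mod score_response_tests {
    use super::*;

    #[test]
    fn parses_assumed_score_response_shape() {
        let body = serde_json::json!({
            "scores": [
                {"total_logprob": -12.5},
                {"total_logprob": -3.0},
            ]
        });
        let parsed: ScoreResponse = serde_json::from_value(body).expect("shape should parse");
        assert_eq!(parsed.scores.len(), 2);
        assert!((parsed.scores[0].total_logprob - (-12.5)).abs() < 1e-9);
    }
}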
/// Compute per-position logprob divergence: how much worse the model
/// scores each response without something vs with it.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    baseline.iter().enumerate()
        .map(|(i, base)| {
            let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob);
            (base.total_logprob - without_lp).max(0.0)
        })
        .collect()
}

/// Score two message sets and return the per-position divergences plus
/// the baseline scores.
async fn score_divergence(
    http: &reqwest::Client,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
    let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?;
    let without = call_score(http, client, &build_messages(context, range, filter)).await?;
    let divs = divergence(&baseline, &without);
    Ok((divs, baseline))
}

// ── Full matrix scoring (debug screen) ──────────────────────────

/// Result of scoring one conversation's memory usage.
pub struct MemoryScore {
    pub memory_weights: Vec<(String, f64)>,
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    pub memory_keys: Vec<String>,
    pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else {
            return Vec::new();
        };
        let mut result: Vec<(&str, f64)> = self.memory_keys.iter()
            .zip(self.matrix.iter())
            .filter_map(|(key, row)| {
                let score = row.get(resp_idx).copied().unwrap_or(0.0);
                if score > 0.01 { Some((key.as_str(), score)) } else { None }
            })
            .collect();
        result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        result
    }
}
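// A self-contained check of the lookup/threshold/sort behaviour above.
// The keys, matrix values, and entry indices are made-up fixture data,
// not from a real conversation.
#[cfg(test)]
mod memory_score_tests {
    use super::*;

    #[test]
    fn important_memories_are_filtered_and_sorted() {
        let score = MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            // Three memories × two responses; entry 7 is the second response.
            matrix: vec![vec![0.2, 0.5], vec![0.3, 1.5], vec![0.0, 0.005]],
            memory_keys: vec!["alpha".into(), "beta".into(), "gamma".into()],
            response_entry_indices: vec![3, 7],
        };
        // "gamma" falls under the 0.01 threshold for entry 7; the survivors
        // come back sorted by divergence, descending.
        assert_eq!(
            score.important_memories_for_entry(7),
            vec![("beta", 1.5), ("alpha", 0.5)],
        );
        // Entries that aren't responses return nothing.
        assert!(score.important_memories_for_entry(99).is_empty());
    }
}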
/// Score how important each memory is to the conversation (full matrix).
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
    // Unique memory keys in first-appearance order. Vec::dedup only drops
    // *consecutive* duplicates and memory entries can repeat anywhere in the
    // conversation, so dedupe via a seen-set instead.
    let mut seen = std::collections::HashSet::new();
    let memory_keys: Vec<String> = context.entries.iter()
        .filter_map(|e| match e {
            ConversationEntry::Memory { key, .. } => Some(key.clone()),
            _ => None,
        })
        .filter(|k| seen.insert(k.clone()))
        .collect();

    let response_indices: Vec<usize> = context.entries.iter().enumerate()
        .filter(|(_, e)| e.message().role == Role::Assistant)
        .map(|(i, _)| i)
        .collect();

    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            matrix: Vec::new(),
            memory_keys: Vec::new(),
            response_entry_indices: Vec::new(),
        });
    }

    let _ = ui_tx.send(UiMessage::Info(format!(
        "[scoring {} memories × {} responses]",
        memory_keys.len(),
        response_indices.len(),
    )));

    let http = http_client();
    let range = 0..context.entries.len();

    let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
    let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?;

    let total = memory_keys.len();
    let mut matrix: Vec<Vec<f64>> = Vec::new();
    for (mem_idx, key) in memory_keys.iter().enumerate() {
        let _ = ui_tx.send(UiMessage::Activity(format!(
            "scoring {}/{}: {}...",
            mem_idx + 1, total, key,
        )));
        let msgs = build_messages(context, range.clone(), Filter::SkipKey(key));
        match call_score(&http, client, &msgs).await {
            Ok(without) => matrix.push(divergence(&baseline, &without)),
            Err(e) => {
                let _ = ui_tx.send(UiMessage::Debug(format!(
                    "[training] {} FAILED: {:#}", key, e,
                )));
                matrix.push(vec![0.0; baseline.len()]);
            }
        }
    }
    let _ = ui_tx.send(UiMessage::Activity(String::new()));

    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(matrix.iter())
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();

    let mut response_scores = vec![0.0; response_indices.len()];
    for row in &matrix {
        for (j, &v) in row.iter().enumerate() {
            if j < response_scores.len() {
                response_scores[j] += v;
            }
        }
    }

    Ok(MemoryScore {
        memory_weights,
        response_scores,
        matrix,
        memory_keys,
        response_entry_indices: response_indices,
    })
}

// ── Single memory scoring ───────────────────────────────────────

/// Score how important a single memory is to the conversation.
///
/// Scores the 50 messages after the memory was surfaced — the window
/// where it could have influenced responses. Returns the sum of
/// divergence, or 0.0 if the memory isn't in the conversation.
pub async fn score_memory(
    context: &ContextState,
    key: &str,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<f64> {
    const WINDOW: usize = 50;

    let first_pos = match context.entries.iter().position(|e| {
        matches!(e, ConversationEntry::Memory { key: k, .. } if k == key)
    }) {
        Some(p) => p,
        None => return Ok(0.0),
    };

    let range = first_pos..(first_pos + WINDOW).min(context.entries.len());
    if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
        return Ok(0.0);
    }

    let http = http_client();
    let _ = ui_tx.send(UiMessage::Activity(format!("scoring memory: {}...", key)));
    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?;
    let _ = ui_tx.send(UiMessage::Activity(String::new()));
    Ok(divs.iter().sum())
}
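// A worked example of the importance arithmetic behind score_memory():
// if removing a memory drops two responses from -10.0 → -13.0 and
// -5.0 → -5.5 logprob, the summed divergence is 3.0 + 0.5 = 3.5.
// The fixture numbers are illustrative, not from a real scoring run.
#[cfg(test)]
mod divergence_tests {
    use super::*;

    fn scores(lps: &[f64]) -> Vec<ScoreResult> {
        lps.iter().map(|&lp| ScoreResult { total_logprob: lp }).collect()
    }

    #[test]
    fn summed_divergence_is_the_importance() {
        let baseline = scores(&[-10.0, -5.0]);
        let without = scores(&[-13.0, -5.5]);
        let divs = divergence(&baseline, &without);
        assert!((divs.iter().sum::<f64>() - 3.5).abs() < 1e-9);
    }

    #[test]
    fn improvements_and_missing_positions_count_as_zero() {
        // A response that gets *better* without the memory is clamped to 0,
        // and positions missing from `without` fall back to the baseline.
        let divs = divergence(&scores(&[-4.0, -6.0]), &scores(&[-3.0]));
        assert_eq!(divs, vec![0.0, 0.0]);
    }
}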
// ── Background memory scoring ───────────────────────────────────

/// Incrementally score memories through the conversation.
///
/// Walks memory entries in conversation order starting from `cursor`.
/// For each memory with a full WINDOW after it, runs the same
/// with/without comparison as score_memory() and yields the result.
/// Stops at the first memory that doesn't have enough messages yet —
/// the conversation needs to grow before we can score it.
///
/// Returns the updated cursor (entry index to resume from next time)
/// and the scores for each memory that was scored this round.
pub async fn score_memories_incremental(
    context: &ContextState,
    cursor: usize,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<(usize, Vec<(String, f64)>)> {
    const WINDOW: usize = 50;

    // Collect unique memory keys with their first position, starting from cursor.
    let mut seen = std::collections::HashSet::new();
    let mut to_score: Vec<(usize, String)> = Vec::new();
    for (i, entry) in context.entries.iter().enumerate().skip(cursor) {
        if let ConversationEntry::Memory { key, .. } = entry {
            if seen.insert(key.clone()) {
                to_score.push((i, key.clone()));
            }
        }
    }

    let http = http_client();
    let mut new_cursor = cursor;
    let mut results = Vec::new();

    for (pos, key) in &to_score {
        let end = pos + WINDOW;
        // Not enough conversation after this memory yet — stop here, and make
        // sure the cursor doesn't jump past it: an earlier memory's window may
        // already have advanced new_cursor beyond this position, which would
        // skip this memory forever on the next round.
        if end > context.entries.len() {
            new_cursor = new_cursor.min(*pos);
            break;
        }
        // Need at least one assistant response in the window.
        let range = *pos..end;
        if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
            new_cursor = end;
            continue;
        }
        let _ = ui_tx.send(UiMessage::Activity(format!("scoring memory: {}...", key)));
        match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await {
            Ok((divs, _)) => {
                let importance: f64 = divs.iter().sum();
                let _ = ui_tx.send(UiMessage::Debug(format!(
                    "[scoring] {} → {:.2}", key, importance,
                )));
                results.push((key.clone(), importance));
            }
            Err(e) => {
                let _ = ui_tx.send(UiMessage::Debug(format!(
                    "[scoring] {} FAILED: {:#}", key, e,
                )));
            }
        }
        new_cursor = end;
    }

    let _ = ui_tx.send(UiMessage::Activity(String::new()));
    Ok((new_cursor, results))
}

// ── Fine-tuning scoring ─────────────────────────────────────────

/// Score which recent responses are candidates for fine-tuning.
///
/// Removes all memories and scores the most recent `count` messages.
/// Responses with high divergence depend on memories the model hasn't
/// internalized — these are fine-tuning candidates.
///
/// Returns (entry_index, divergence) pairs, sorted by divergence descending.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let range = context.entries.len().saturating_sub(count)..context.entries.len();
    let response_positions: Vec<usize> = range.clone()
        .filter(|&i| context.entries[i].message().role == Role::Assistant)
        .collect();
    if response_positions.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let _ = ui_tx.send(UiMessage::Activity("scoring for fine-tuning...".into()));
    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?;
    let _ = ui_tx.send(UiMessage::Activity(String::new()));

    let mut results: Vec<(usize, f64)> = response_positions.iter()
        .enumerate()
        .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0)))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(results)
}
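// A minimal caller sketch for the fine-tuning flow: keep only entries whose
// divergence clears a threshold. This helper is not part of the module's
// API — the 50-message lookback and the `threshold` parameter are
// illustrative choices made here, not values from the original code.
#[allow(dead_code)]
async fn finetune_candidates_above(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
    threshold: f64,
) -> anyhow::Result<Vec<usize>> {
    let scored = score_finetune(context, 50, client, ui_tx).await?;
    // score_finetune returns pairs sorted descending by divergence, so we
    // can stop at the first entry below the threshold.
    Ok(scored.into_iter()
        .take_while(|&(_, div)| div > threshold)
        .map(|(entry_idx, _)| entry_idx)
        .collect())
}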