// training.rs — Memory importance scoring via /v1/score // // Drops each memory from the context one at a time, calls the vLLM // /v1/score endpoint to get logprobs for assistant responses. // Produces a divergence matrix: memories × responses. // // Row sums = memory importance (for graph weight updates) // Column sums = response memory-dependence (training candidates) use std::time::Instant; use super::api::ApiClient; use crate::agent::api::types::*; use crate::agent::context::{ConversationEntry, ContextState}; use crate::user::ui_channel::{UiMessage, UiSender}; /// Timeout for individual /v1/score API calls. const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); /// Result of scoring one conversation's memory usage. pub struct MemoryScore { /// memory_key → importance score (sum of divergence across all responses) pub memory_weights: Vec<(String, f64)>, /// response_index → memory-dependence score (sum of divergence across all memories) pub response_scores: Vec, /// Full matrix: divergence[memory_idx][response_idx] pub matrix: Vec>, /// Keys of memories that were scored pub memory_keys: Vec, /// Conversation entry indices of the assistant responses pub response_entry_indices: Vec, } impl MemoryScore { /// Get the most important memories for a given conversation entry index. pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; let mut result: Vec<(&str, f64)> = self.memory_keys.iter() .zip(self.matrix.iter()) .filter_map(|(key, row)| { let score = row.get(resp_idx).copied().unwrap_or(0.0); if score > 0.01 { Some((key.as_str(), score)) } else { None } }) .collect(); result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); result } } /// Score how important each memory is to the conversation. pub async fn score_memories( context: &ContextState, client: &ApiClient, ui_tx: &UiSender, ) -> anyhow::Result { let _ = ui_tx.send(UiMessage::Debug(format!( "[training] in score_memories" ))); let memories: Vec<(usize, String)> = context.entries.iter().enumerate() .filter_map(|(i, e)| match e { ConversationEntry::Memory { key, .. } => Some((i, key.clone())), _ => None, }) .collect(); let response_indices: Vec = context.entries.iter().enumerate() .filter(|(_, e)| e.message().role == Role::Assistant) .map(|(i, _)| i) .collect(); if memories.is_empty() || response_indices.is_empty() { let _ = ui_tx.send(UiMessage::Debug( "[training] nothing to score (no memories or no responses)".into() )); return Ok(MemoryScore { memory_weights: Vec::new(), response_scores: Vec::new(), matrix: Vec::new(), memory_keys: Vec::new(), response_entry_indices: Vec::new(), }); } let _ = ui_tx.send(UiMessage::Info(format!( "[scoring {} memories × {} responses]", memories.len(), response_indices.len(), ))); let http = reqwest::Client::builder() .timeout(SCORE_TIMEOUT) .pool_max_idle_per_host(2) .build() .unwrap_or_default(); let all_messages = build_messages(context); let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {} messages in context", all_messages.len(), ))); // Baseline: score with all memories present let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into())); let payload_size = serde_json::to_string(&all_messages) .map(|s| s.len()).unwrap_or(0); let _ = ui_tx.send(UiMessage::Debug(format!( "[training] payload size: {}KB", payload_size / 1024, ))); let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into())); let start = Instant::now(); let baseline = call_score(&http, client, &all_messages).await?; let _ = ui_tx.send(UiMessage::Debug(format!( "[training] baseline: {} responses scored in {:.1}s", baseline.len(), start.elapsed().as_secs_f64(), ))); // For each memory, drop it and measure divergence let mut matrix: Vec> = Vec::new(); let memory_keys: Vec = memories.iter().map(|(_, k)| k.clone()).collect(); let total = memories.len(); for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() { let _ = ui_tx.send(UiMessage::Activity(format!( "scoring {}/{}: {}...", mem_idx + 1, total, key, ))); let start = Instant::now(); let filtered_messages = build_messages_without(context, *entry_idx); let without = call_score(&http, client, &filtered_messages).await; match without { Ok(without) => { let elapsed = start.elapsed().as_secs_f64(); // Match scores by position (nth scored response), // not message_index — indices shift when a memory // is removed from the conversation. let mut row = Vec::new(); for (i, base_score) in baseline.iter().enumerate() { let base_lp = base_score.total_logprob; let without_lp = without.get(i) .map(|s| s.total_logprob) .unwrap_or(base_lp); let divergence = (base_lp - without_lp).max(0.0); row.push(divergence); } let importance: f64 = row.iter().sum(); let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {}/{} {} → {:.1} ({:.1}s)", mem_idx + 1, total, key, importance, elapsed, ))); matrix.push(row); } Err(e) => { let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {}/{} {} FAILED: {:#}", mem_idx + 1, total, key, e, ))); // Push zero row so matrix stays aligned matrix.push(vec![0.0; baseline.len()]); } } } let _ = ui_tx.send(UiMessage::Activity(String::new())); // Compute scores let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) .map(|(key, row)| (key.clone(), row.iter().sum())) .collect(); let n_responses = response_indices.len(); let mut response_scores = vec![0.0; n_responses]; for row in &matrix { for (j, &v) in row.iter().enumerate() { if j < n_responses { response_scores[j] += v; } } } let _ = ui_tx.send(UiMessage::Info(format!( "[scoring complete: {} memories scored]", memory_keys.len(), ))); Ok(MemoryScore { memory_weights, response_scores, matrix, memory_keys, response_entry_indices: response_indices, }) } /// Score response from the /v1/score endpoint. #[derive(serde::Deserialize)] struct ScoreMessageResult { #[allow(dead_code)] message_index: usize, total_logprob: f64, } #[derive(serde::Deserialize)] struct ScoreApiResponse { scores: Vec, } /// Build the messages array for the /v1/score endpoint from ContextState. fn build_messages(context: &ContextState) -> Vec { let mut msgs = Vec::new(); msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); let ctx = context.render_context_message(); if !ctx.is_empty() { msgs.push(serde_json::json!({"role": "user", "content": ctx})); } for entry in &context.entries { let m = entry.api_message(); msgs.push(serde_json::json!({ "role": m.role_str(), "content": m.content_text(), })); } msgs } /// Build messages with one entry removed. fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec { let mut msgs = Vec::new(); msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); let ctx = context.render_context_message(); if !ctx.is_empty() { msgs.push(serde_json::json!({"role": "user", "content": ctx})); } for (i, entry) in context.entries.iter().enumerate() { if i == skip_idx { continue; } let m = entry.api_message(); msgs.push(serde_json::json!({ "role": m.role_str(), "content": m.content_text(), })); } msgs } /// Call the /v1/score endpoint and return per-message logprobs. async fn call_score( http: &reqwest::Client, client: &ApiClient, messages: &[serde_json::Value], ) -> anyhow::Result> { let request = serde_json::json!({ "model": client.model, "messages": messages, "logprobs": 1, }); let response = http .post(format!("{}/score", client.base_url())) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", client.api_key())) .json(&request) .send() .await .map_err(|e| { if e.is_timeout() { anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs()) } else { anyhow::anyhow!("score request failed: {}", e) } })?; let status = response.status(); let body: serde_json::Value = response.json().await?; if !status.is_success() { let msg = body.get("error") .and_then(|e| e.as_str()) .unwrap_or("unknown error"); anyhow::bail!("score API HTTP {}: {}", status, msg); } // Check for error in body (score endpoint returns dict on error) if let Some(err) = body.get("error").and_then(|e| e.as_str()) { anyhow::bail!("score API error: {}", err); } let result: ScoreApiResponse = serde_json::from_value(body) .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; Ok(result.scores) }