// training.rs — Memory importance scoring via prompt logprobs // // Drops each memory from the context one at a time, runs prompt_logprobs // to see how the model's confidence in its responses changes. Produces // a divergence matrix: memories × responses. // // Row sums = memory importance (for graph weight updates) // Column sums = response memory-dependence (training candidates) use crate::agent::api::ApiClient; use crate::agent::types::*; use crate::agent::ui_channel::UiSender; /// Result of scoring one conversation's memory usage. pub struct MemoryScore { /// memory_key → importance score (sum of divergence across all responses) pub memory_weights: Vec<(String, f64)>, /// response_index → memory-dependence score (sum of divergence across all memories) pub response_scores: Vec, /// Full matrix: divergence[memory_idx][response_idx] pub matrix: Vec>, /// Keys of memories that were scored pub memory_keys: Vec, /// Conversation entry indices of the assistant responses (maps response_idx → entry_idx) pub response_entry_indices: Vec, } impl MemoryScore { /// Get the most important memories for a given conversation entry index. /// Returns (memory_key, divergence_score) sorted by importance. pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; let mut result: Vec<(&str, f64)> = self.memory_keys.iter() .zip(self.matrix.iter()) .filter_map(|(key, row)| { let score = row.get(resp_idx).copied().unwrap_or(0.0); if score > 0.01 { Some((key.as_str(), score)) } else { None } }) .collect(); result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); result } } /// Score how important each memory is to the conversation. /// /// For each Memory entry in the context, builds a version without it /// and checks how the model's logprobs change for assistant responses. pub async fn score_memories( context: &ContextState, client: &ApiClient, ui_tx: &UiSender, ) -> anyhow::Result { use crate::agent::ui_channel::UiMessage; // Identify memory entries and assistant response positions let memories: Vec<(usize, String)> = context.entries.iter().enumerate() .filter_map(|(i, e)| match e { ConversationEntry::Memory { key, .. } => Some((i, key.clone())), _ => None, }) .collect(); let response_indices: Vec = context.entries.iter().enumerate() .filter(|(_, e)| e.message().role == Role::Assistant) .map(|(i, _)| i) .collect(); if memories.is_empty() || response_indices.is_empty() { return Ok(MemoryScore { memory_weights: Vec::new(), response_scores: Vec::new(), matrix: Vec::new(), memory_keys: Vec::new(), response_entry_indices: Vec::new(), }); } let _ = ui_tx.send(UiMessage::Debug(format!( "[training] scoring {} memories × {} responses", memories.len(), response_indices.len(), ))); // Shared HTTP client for connection reuse across all scoring calls let http = reqwest::Client::builder() .pool_max_idle_per_host(2) .build() .unwrap_or_default(); // Baseline: logprobs with all memories present let baseline = get_response_logprobs(context, &context.entries, client, &http, ui_tx).await?; let _ = ui_tx.send(UiMessage::Debug(format!( "[training] baseline: {} response tokens scored", baseline.iter().map(|r| r.len()).sum::(), ))); // For each memory, drop it and measure divergence let mut matrix: Vec> = Vec::new(); let memory_keys: Vec = memories.iter().map(|(_, k)| k.clone()).collect(); for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() { let _ = ui_tx.send(UiMessage::Activity(format!( "scoring {}/{}...", mem_idx + 1, memories.len(), ))); let _ = ui_tx.send(UiMessage::Debug(format!( "[training] scoring memory {}/{}: {}", mem_idx + 1, memories.len(), key, ))); // Build entries without this memory let filtered: Vec = context.entries.iter().enumerate() .filter(|(i, _)| *i != *entry_idx) .map(|(_, e)| e.clone()) .collect(); let without = get_response_logprobs(context, &filtered, client, &http, ui_tx).await?; // Compute per-response divergence let mut row = Vec::new(); for (_resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() { // Sum of logprob drops across tokens in this response // Positive = memory helped (logprob was higher with it) let divergence: f64 = base_lps.iter().zip(without_lps.iter()) .map(|(b, w)| b - w) // positive when baseline was more confident .filter(|d| *d > 0.0) // only count where memory helped .sum(); row.push(divergence); } matrix.push(row); } // Compute scores let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) .map(|(key, row)| (key.clone(), row.iter().sum())) .collect(); let n_responses = response_indices.len(); let mut response_scores = vec![0.0; n_responses]; for row in &matrix { for (j, &v) in row.iter().enumerate() { if j < n_responses { response_scores[j] += v; } } } let _ = ui_tx.send(UiMessage::Activity(String::new())); // Log summary per memory for (key, score) in &memory_weights { let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {} → importance {:.1}", key, score, ))); } // Log per-response breakdown for the most important memories let mut sorted_mems: Vec<(usize, &str, f64)> = memory_keys.iter().enumerate() .map(|(i, k)| (i, k.as_str(), memory_weights[i].1)) .collect(); sorted_mems.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)); for (mem_i, key, total) in sorted_mems.iter().take(5) { if *total <= 0.0 { continue; } let row = &matrix[*mem_i]; let top_responses: Vec = row.iter().enumerate() .filter(|(_, v)| **v > 0.1) .map(|(j, v)| format!("resp[{}]={:.1}", j, v)) .collect(); if !top_responses.is_empty() { let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {} ({:.1}): {}", key, total, top_responses.join(", "), ))); } } Ok(MemoryScore { memory_weights, response_scores, matrix, memory_keys, response_entry_indices: response_indices, }) } /// Rough token estimate: ~4 chars per token. const CHARS_PER_TOKEN: usize = 4; /// Get logprobs for all assistant response tokens in a conversation. /// Returns a Vec> — one inner vec per assistant response, /// containing logprobs for each token in that response. /// /// Chunks the conversation into ~50K token segments (rounded to /// assistant message boundaries) to avoid OOM from the logprobs /// tensor allocation. async fn get_response_logprobs( context: &ContextState, entries: &[ConversationEntry], client: &ApiClient, http: &reqwest::Client, ui_tx: &UiSender, ) -> anyhow::Result>> { // Build the fixed prefix (system prompt + personality) let mut prefix = Vec::new(); prefix.push(Message::system(&context.system_prompt)); let ctx = context.render_context_message(); if !ctx.is_empty() { prefix.push(Message::user(ctx)); } let prefix_chars: usize = prefix.iter() .map(|m| m.content_text().len()) .sum(); // Split entries into chunks that fit within the token budget, // each ending at an assistant message boundary. let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN; let budget = max_chunk_chars.saturating_sub(prefix_chars); let chunks = chunk_entries(entries, budget); let mut all_responses: Vec> = Vec::new(); use crate::agent::ui_channel::UiMessage; let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {} chunks, prefix={}K chars, budget={}K chars", chunks.len(), prefix_chars / 1024, budget / 1024, ))); for (chunk_idx, chunk) in chunks.iter().enumerate() { let chunk_chars: usize = chunk.iter() .map(|e| e.message().content_text().len()) .sum(); let _ = ui_tx.send(UiMessage::Debug(format!( "[training] chunk {}/{}: {} entries, {}K chars", chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024, ))); } for chunk in &chunks { let mut msgs = prefix.clone(); msgs.extend(chunk.iter().map(|e| e.api_message().clone())); let result = call_prompt_logprobs(&msgs, client, http).await?; all_responses.extend(result); } Ok(all_responses) } /// Split entries into chunks of approximately `budget_chars` each, /// ending at assistant message boundaries. fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec> { let mut chunks = Vec::new(); let mut current = Vec::new(); let mut current_chars = 0; for entry in entries { let entry_chars = entry.message().content_text().len(); current_chars += entry_chars; current.push(entry.clone()); // If over budget and we just added an assistant message, cut here if current_chars >= budget_chars && entry.message().role == Role::Assistant { chunks.push(std::mem::take(&mut current)); current_chars = 0; } } if !current.is_empty() { chunks.push(current); } // If everything fit in one chunk, just return it if chunks.is_empty() { chunks.push(entries.to_vec()); } chunks } /// Make a single prompt_logprobs API call and extract response logprobs. async fn call_prompt_logprobs( msgs: &[Message], client: &ApiClient, http: &reqwest::Client, ) -> anyhow::Result>> { let request = serde_json::json!({ "model": client.model, "messages": msgs, "max_tokens": 1, "prompt_logprobs": 1, "stream": false, }); let response = http .post(format!("{}/chat/completions", client.base_url())) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", client.api_key())) .json(&request) .send() .await?; let status = response.status(); let body: serde_json::Value = response.json().await?; if !status.is_success() { let msg = body.get("error") .and_then(|e| e.get("message")) .and_then(|m| m.as_str()) .unwrap_or("unknown error"); anyhow::bail!("HTTP {} from logprobs API: {}", status, msg); } let prompt_logprobs = body.get("prompt_logprobs") .and_then(|v| v.as_array()) .ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?; // Find assistant response boundaries using special tokens // Pattern: <|im_start|> assistant \n [...] response <|im_end|> let mut responses: Vec> = Vec::new(); let mut in_assistant = false; let mut in_think = false; let mut current_response: Vec = Vec::new(); for entry in prompt_logprobs { let Some(obj) = entry.as_object() else { continue }; let first = obj.values().next(); let Some(info) = first.and_then(|v| v.as_object()) else { continue }; let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or(""); let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0); match token { "<|im_start|>" => { in_assistant = false; in_think = false; } "assistant" if !in_assistant => { in_assistant = true; in_think = false; current_response.clear(); } "" if in_assistant => { in_think = true; } "" if in_assistant => { in_think = false; } "<|im_end|>" if in_assistant => { if !current_response.is_empty() { responses.push(std::mem::take(&mut current_response)); } in_assistant = false; } "\n" if in_assistant && current_response.is_empty() => { // Skip the newline right after "assistant" } _ if in_assistant && !in_think => { current_response.push(logprob); } _ => {} } } Ok(responses) }