diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs
index 951557b..b70140d 100644
--- a/src/agent/api/mod.rs
+++ b/src/agent/api/mod.rs
@@ -166,6 +166,9 @@ impl ApiClient {
         Ok((build_response_message(content, tool_calls), usage))
     }
 
+    pub fn base_url(&self) -> &str { &self.base_url }
+    pub fn api_key(&self) -> &str { &self.api_key }
+
     /// Return a label for the active backend, used in startup info.
     pub fn backend_label(&self) -> &str {
         if self.base_url.contains("openrouter") {
diff --git a/src/thought/mod.rs b/src/thought/mod.rs
index d19caab..7a2e722 100644
--- a/src/thought/mod.rs
+++ b/src/thought/mod.rs
@@ -15,6 +15,7 @@ pub mod glob_tool;
 pub mod grep;
 pub mod memory;
 pub mod read;
+pub mod training;
 pub mod write;
 
 pub use bash::ProcessTracker;
diff --git a/src/thought/training.rs b/src/thought/training.rs
new file mode 100644
index 0000000..35f1d27
--- /dev/null
+++ b/src/thought/training.rs
@@ -0,0 +1,226 @@
+// training.rs — Memory importance scoring via prompt logprobs
+//
+// Drops each memory from the context one at a time and re-scores the
+// conversation with prompt_logprobs to see how the model's confidence
+// in its responses changes. Produces a divergence matrix:
+// memories × responses.
+//
+// Row sums    = memory importance (for graph weight updates)
+// Column sums = response memory-dependence (training candidates)
+
+use crate::agent::api::ApiClient;
+use crate::agent::types::*;
+use crate::agent::ui_channel::UiSender;
+
+/// Result of scoring one conversation's memory usage.
+pub struct MemoryScore {
+    /// memory_key → importance score (sum of divergence across all responses)
+    pub memory_weights: Vec<(String, f64)>,
+    /// response_index → memory-dependence score (sum of divergence across all memories)
+    pub response_scores: Vec<f64>,
+    /// Full matrix: divergence[memory_idx][response_idx]
+    pub matrix: Vec<Vec<f64>>,
+    /// Keys of memories that were scored
+    pub memory_keys: Vec<String>,
+}
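+
+// A minimal usage sketch (hypothetical call site — `ctx`, `client`, and
+// `ui_tx` are assumed to already exist in the agent's run loop):
+//
+//     let score = score_memories(&ctx, &client, &ui_tx).await?;
+//     for (key, weight) in &score.memory_weights {
+//         println!("{key}: {weight:.2}");
+//     }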
+
+/// Score how important each memory is to the conversation.
+///
+/// For each Memory entry in the context, builds a version without it
+/// and checks how the model's logprobs change for assistant responses.
+pub async fn score_memories(
+    context: &ContextState,
+    client: &ApiClient,
+    ui_tx: &UiSender,
+) -> anyhow::Result<MemoryScore> {
+    use crate::agent::ui_channel::UiMessage;
+
+    // Identify memory entries and assistant response positions
+    let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
+        .filter_map(|(i, e)| match e {
+            ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
+            _ => None,
+        })
+        .collect();
+
+    let response_indices: Vec<usize> = context.entries.iter().enumerate()
+        .filter(|(_, e)| e.message().role == Role::Assistant)
+        .map(|(i, _)| i)
+        .collect();
+
+    if memories.is_empty() || response_indices.is_empty() {
+        return Ok(MemoryScore {
+            memory_weights: Vec::new(),
+            response_scores: Vec::new(),
+            matrix: Vec::new(),
+            memory_keys: Vec::new(),
+        });
+    }
+
+    let _ = ui_tx.send(UiMessage::Debug(format!(
+        "[training] scoring {} memories × {} responses",
+        memories.len(), response_indices.len(),
+    )));
+
+    // Baseline: logprobs with all memories present
+    let baseline = get_response_logprobs(context, &context.entries, client).await?;
+
+    let _ = ui_tx.send(UiMessage::Debug(format!(
+        "[training] baseline: {} response tokens scored",
+        baseline.iter().map(|r| r.len()).sum::<usize>(),
+    )));
+
+    // For each memory, drop it and measure divergence
+    let mut matrix: Vec<Vec<f64>> = Vec::new();
+    let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
+
+    for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
+        let _ = ui_tx.send(UiMessage::Debug(format!(
+            "[training] scoring memory {}/{}: {}",
+            mem_idx + 1, memories.len(), key,
+        )));
+
+        // Build entries without this memory
+        let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
+            .filter(|(i, _)| *i != *entry_idx)
+            .map(|(_, e)| e.clone())
+            .collect();
+
+        let without = get_response_logprobs(context, &filtered, client).await?;
+
+        // Compute per-response divergence. The response text is identical
+        // in both runs, so the token streams zip one-to-one.
+        let mut row = Vec::new();
+        for (base_lps, without_lps) in baseline.iter().zip(without.iter()) {
+            // Sum of logprob drops across tokens in this response.
+            // Positive = memory helped (logprob was higher with it).
+            let divergence: f64 = base_lps.iter().zip(without_lps.iter())
+                .map(|(b, w)| b - w)   // positive when baseline was more confident
+                .filter(|d| *d > 0.0)  // only count where the memory helped
+                .sum();
+            row.push(divergence);
+        }
+        matrix.push(row);
+    }
+
+    // Row sums → memory importance
+    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
+        .zip(matrix.iter())
+        .map(|(key, row)| (key.clone(), row.iter().sum()))
+        .collect();
+
+    // Column sums → response memory-dependence
+    let n_responses = response_indices.len();
+    let mut response_scores = vec![0.0; n_responses];
+    for row in &matrix {
+        for (j, &v) in row.iter().enumerate() {
+            if j < n_responses {
+                response_scores[j] += v;
+            }
+        }
+    }
+
+    let _ = ui_tx.send(UiMessage::Debug(format!(
+        "[training] done. top memory: {:?}",
+        memory_weights.iter()
+            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
+            .map(|(k, v)| format!("{}: {:.1}", k, v)),
+    )));
+
+    Ok(MemoryScore {
+        memory_weights,
+        response_scores,
+        matrix,
+        memory_keys,
+    })
+}
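+
+// Worked example with made-up numbers: if a response's tokens score
+// [-0.1, -2.0, -0.5] with a memory present and [-0.1, -4.0, -0.3]
+// without it, the per-token drops (base - without) are [0.0, 2.0, -0.2].
+// Only the positive drops count, so that (memory, response) cell gets
+// a divergence of 2.0.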
+
+/// Get logprobs for all assistant response tokens in a conversation.
+/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
+/// containing logprobs for each token in that response.
+async fn get_response_logprobs(
+    context: &ContextState,
+    entries: &[ConversationEntry],
+    client: &ApiClient,
+) -> anyhow::Result<Vec<Vec<f64>>> {
+    // Assemble messages the same way the runner does
+    let mut msgs = Vec::new();
+    msgs.push(Message::system(&context.system_prompt));
+    let ctx = context.render_context_message();
+    if !ctx.is_empty() {
+        msgs.push(Message::user(ctx));
+    }
+    msgs.extend(entries.iter().map(|e| e.api_message().clone()));
+
+    // Call the API with prompt_logprobs; max_tokens is 1 because we only
+    // need the prompt scored, not a fresh generation.
+    let request = serde_json::json!({
+        "model": client.model,
+        "messages": msgs,
+        "max_tokens": 1,
+        "prompt_logprobs": 1,
+        "stream": false,
+    });
+
+    let response = reqwest::Client::new()
+        .post(format!("{}/chat/completions", client.base_url()))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", client.api_key()))
+        .json(&request)
+        .send()
+        .await?;
+
+    let body: serde_json::Value = response.json().await?;
+
+    if let Some(err) = body.get("error") {
+        anyhow::bail!("API error: {}", err);
+    }
+
+    let prompt_logprobs = body.get("prompt_logprobs")
+        .and_then(|v| v.as_array())
+        .ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;
+
+    // Find assistant response boundaries using special tokens.
+    // Pattern: <|im_start|>assistant \n <think>…</think> response <|im_end|>
+    let mut responses: Vec<Vec<f64>> = Vec::new();
+    let mut awaiting_role = false;
+    let mut in_assistant = false;
+    let mut in_think = false;
+    let mut current_response: Vec<f64> = Vec::new();
+
+    for entry in prompt_logprobs {
+        let Some(obj) = entry.as_object() else { continue };
+        // Each entry maps token_id → info; the first value describes the
+        // token actually present in the prompt.
+        let first = obj.values().next();
+        let Some(info) = first.and_then(|v| v.as_object()) else { continue };
+        let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
+        let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
+
+        match token {
+            "<|im_start|>" => {
+                in_assistant = false;
+                in_think = false;
+                awaiting_role = true;
+            }
+            // The token right after <|im_start|> names the turn's role;
+            // matching on the bare word "assistant" anywhere else would
+            // misfire on ordinary prose.
+            "assistant" if awaiting_role => {
+                awaiting_role = false;
+                in_assistant = true;
+                current_response.clear();
+            }
+            _ if awaiting_role => {
+                // user / system / tool turn
+                awaiting_role = false;
+            }
+            "<think>" if in_assistant => {
+                in_think = true;
+            }
+            "</think>" if in_assistant => {
+                in_think = false;
+            }
+            "<|im_end|>" if in_assistant => {
+                if !current_response.is_empty() {
+                    responses.push(std::mem::take(&mut current_response));
+                }
+                in_assistant = false;
+            }
+            "\n" if in_assistant && current_response.is_empty() => {
+                // Skip the newline right after "assistant"
+            }
+            _ if in_assistant && !in_think => {
+                current_response.push(logprob);
+            }
+            _ => {}
+        }
+    }
+
+    Ok(responses)
+}
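+
+// A toy check of the row/column-sum bookkeeping above. Illustrative
+// sketch only: the divergence values are made up, and this exercises
+// just the aggregation arithmetic, not the API path.
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn row_and_column_sums() {
+        // 2 memories × 3 responses
+        let matrix = vec![
+            vec![1.0, 0.0, 2.0], // memory "a"
+            vec![0.5, 3.0, 0.0], // memory "b"
+        ];
+
+        // Row sums = memory importance
+        let weights: Vec<f64> = matrix.iter().map(|row| row.iter().sum()).collect();
+        assert_eq!(weights, vec![3.0, 3.5]);
+
+        // Column sums = response memory-dependence
+        let mut scores = vec![0.0_f64; 3];
+        for row in &matrix {
+            for (j, &v) in row.iter().enumerate() {
+                scores[j] += v;
+            }
+        }
+        assert_eq!(scores, vec![1.5, 3.0, 2.0]);
+    }
+}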