From e8c3ed3d965008123a288e48a7b642697e6f41eb Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Fri, 3 Apr 2026 00:31:57 -0400 Subject: [PATCH] switch memory scoring to /v1/score endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace prompt_logprobs-based scoring with the new vLLM /v1/score endpoint. Much simpler: one API call per memory drop, returns per-message total_logprob directly. No chunking needed, no OOM risk — the endpoint only computes logits for scored tokens. Co-Authored-By: Proof of Concept --- src/agent/types.rs | 9 ++ src/thought/training.rs | 293 ++++++++++++---------------------------- 2 files changed, 99 insertions(+), 203 deletions(-) diff --git a/src/agent/types.rs b/src/agent/types.rs index a963556..0452a75 100644 --- a/src/agent/types.rs +++ b/src/agent/types.rs @@ -228,6 +228,15 @@ impl Message { self.content.as_ref().map_or("", |c| c.as_text()) } + pub fn role_str(&self) -> &str { + match self.role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + } + } + fn now() -> Option { Some(Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true)) } diff --git a/src/thought/training.rs b/src/thought/training.rs index 747bfcf..3399317 100644 --- a/src/thought/training.rs +++ b/src/thought/training.rs @@ -1,15 +1,15 @@ -// training.rs — Memory importance scoring via prompt logprobs +// training.rs — Memory importance scoring via /v1/score // -// Drops each memory from the context one at a time, runs prompt_logprobs -// to see how the model's confidence in its responses changes. Produces -// a divergence matrix: memories × responses. +// Drops each memory from the context one at a time, calls the vLLM +// /v1/score endpoint to get logprobs for assistant responses. +// Produces a divergence matrix: memories × responses. // // Row sums = memory importance (for graph weight updates) // Column sums = response memory-dependence (training candidates) use crate::agent::api::ApiClient; use crate::agent::types::*; -use crate::agent::ui_channel::UiSender; +use crate::agent::ui_channel::{UiMessage, UiSender}; /// Result of scoring one conversation's memory usage. pub struct MemoryScore { @@ -21,13 +21,12 @@ pub struct MemoryScore { pub matrix: Vec>, /// Keys of memories that were scored pub memory_keys: Vec, - /// Conversation entry indices of the assistant responses (maps response_idx → entry_idx) + /// Conversation entry indices of the assistant responses pub response_entry_indices: Vec, } impl MemoryScore { /// Get the most important memories for a given conversation entry index. - /// Returns (memory_key, divergence_score) sorted by importance. pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; @@ -45,17 +44,11 @@ impl MemoryScore { } /// Score how important each memory is to the conversation. -/// -/// For each Memory entry in the context, builds a version without it -/// and checks how the model's logprobs change for assistant responses. pub async fn score_memories( context: &ContextState, client: &ApiClient, ui_tx: &UiSender, ) -> anyhow::Result { - use crate::agent::ui_channel::UiMessage; - - // Identify memory entries and assistant response positions let memories: Vec<(usize, String)> = context.entries.iter().enumerate() .filter_map(|(i, e)| match e { ConversationEntry::Memory { key, .. } => Some((i, key.clone())), @@ -79,22 +72,32 @@ pub async fn score_memories( } let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] scoring {} memories × {} responses", + "[training] scoring {} memories × {} responses via /v1/score", memories.len(), response_indices.len(), ))); - // Shared HTTP client for connection reuse across all scoring calls let http = reqwest::Client::builder() .pool_max_idle_per_host(2) .build() .unwrap_or_default(); - // Baseline: logprobs with all memories present - let baseline = get_response_logprobs(context, &context.entries, client, &http, ui_tx).await?; + // Build the messages array from context + let all_messages = build_messages(context); + + let roles: Vec<&str> = all_messages.iter() + .map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or("?")) + .collect(); + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] sending {} messages, roles: {:?}", + all_messages.len(), roles, + ))); + + // Baseline: score with all memories present + let baseline = call_score(&http, client, &all_messages, ui_tx).await?; let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] baseline: {} response tokens scored", - baseline.iter().map(|r| r.len()).sum::(), + "[training] baseline: {} messages scored", + baseline.len(), ))); // For each memory, drop it and measure divergence @@ -110,28 +113,27 @@ pub async fn score_memories( mem_idx + 1, memories.len(), key, ))); - // Build entries without this memory - let filtered: Vec = context.entries.iter().enumerate() - .filter(|(i, _)| *i != *entry_idx) - .map(|(_, e)| e.clone()) - .collect(); + // Build messages without this memory + let filtered_messages = build_messages_without(context, *entry_idx); + let without = call_score(&http, client, &filtered_messages, ui_tx).await?; - let without = get_response_logprobs(context, &filtered, client, &http, ui_tx).await?; - - // Compute per-response divergence + // Match scores by message index and compute divergence let mut row = Vec::new(); - for (_resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() { - // Sum of logprob drops across tokens in this response + for base_score in &baseline { + let base_lp = base_score.total_logprob; + let without_lp = without.iter() + .find(|s| s.message_index == base_score.message_index) + .map(|s| s.total_logprob) + .unwrap_or(base_lp); // Positive = memory helped (logprob was higher with it) - let divergence: f64 = base_lps.iter().zip(without_lps.iter()) - .map(|(b, w)| b - w) // positive when baseline was more confident - .filter(|d| *d > 0.0) // only count where memory helped - .sum(); + let divergence = (base_lp - without_lp).max(0.0); row.push(divergence); } matrix.push(row); } + let _ = ui_tx.send(UiMessage::Activity(String::new())); + // Compute scores let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) @@ -148,35 +150,13 @@ pub async fn score_memories( } } - let _ = ui_tx.send(UiMessage::Activity(String::new())); - - // Log summary per memory + // Log summary for (key, score) in &memory_weights { let _ = ui_tx.send(UiMessage::Debug(format!( "[training] {} → importance {:.1}", key, score, ))); } - // Log per-response breakdown for the most important memories - let mut sorted_mems: Vec<(usize, &str, f64)> = memory_keys.iter().enumerate() - .map(|(i, k)| (i, k.as_str(), memory_weights[i].1)) - .collect(); - sorted_mems.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)); - - for (mem_i, key, total) in sorted_mems.iter().take(5) { - if *total <= 0.0 { continue; } - let row = &matrix[*mem_i]; - let top_responses: Vec = row.iter().enumerate() - .filter(|(_, v)| **v > 0.1) - .map(|(j, v)| format!("resp[{}]={:.1}", j, v)) - .collect(); - if !top_responses.is_empty() { - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {} ({:.1}): {}", key, total, top_responses.join(", "), - ))); - } - } - Ok(MemoryScore { memory_weights, response_scores, @@ -186,116 +166,70 @@ pub async fn score_memories( }) } -/// Rough token estimate: ~4 chars per token. -const CHARS_PER_TOKEN: usize = 4; +/// Score response from the /v1/score endpoint. +#[derive(serde::Deserialize)] +struct ScoreMessageResult { + message_index: usize, + total_logprob: f64, +} -/// Get logprobs for all assistant response tokens in a conversation. -/// Returns a Vec> — one inner vec per assistant response, -/// containing logprobs for each token in that response. -/// -/// Chunks the conversation into ~50K token segments (rounded to -/// assistant message boundaries) to avoid OOM from the logprobs -/// tensor allocation. -async fn get_response_logprobs( - context: &ContextState, - entries: &[ConversationEntry], - client: &ApiClient, - http: &reqwest::Client, - ui_tx: &UiSender, -) -> anyhow::Result>> { - // Build the fixed prefix (system prompt + personality) - let mut prefix = Vec::new(); - prefix.push(Message::system(&context.system_prompt)); +#[derive(serde::Deserialize)] +struct ScoreApiResponse { + scores: Vec, +} + +/// Build the messages array for the /v1/score endpoint from ContextState. +fn build_messages(context: &ContextState) -> Vec { + let mut msgs = Vec::new(); + msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); let ctx = context.render_context_message(); if !ctx.is_empty() { - prefix.push(Message::user(ctx)); + msgs.push(serde_json::json!({"role": "user", "content": ctx})); } - let prefix_chars: usize = prefix.iter() - .map(|m| m.content_text().len()) - .sum(); - - // Split entries into chunks that fit within the token budget, - // each ending at an assistant message boundary. - let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN; - let budget = max_chunk_chars.saturating_sub(prefix_chars); - let chunks = chunk_entries(entries, budget); - - let mut all_responses: Vec> = Vec::new(); - - use crate::agent::ui_channel::UiMessage; - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {} chunks, prefix={}K chars, budget={}K chars", - chunks.len(), prefix_chars / 1024, budget / 1024, - ))); - - for (chunk_idx, chunk) in chunks.iter().enumerate() { - let chunk_chars: usize = chunk.iter() - .map(|e| e.message().content_text().len()) - .sum(); - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] chunk {}/{}: {} entries, {}K chars", - chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024, - ))); + for entry in &context.entries { + let m = entry.api_message(); + msgs.push(serde_json::json!({ + "role": m.role_str(), + "content": m.content_text(), + })); } - - for chunk in &chunks { - let mut msgs = prefix.clone(); - msgs.extend(chunk.iter().map(|e| e.api_message().clone())); - - let result = call_prompt_logprobs(&msgs, client, http).await?; - all_responses.extend(result); - } - - Ok(all_responses) + msgs } -/// Split entries into chunks of approximately `budget_chars` each, -/// ending at assistant message boundaries. -fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec> { - let mut chunks = Vec::new(); - let mut current = Vec::new(); - let mut current_chars = 0; - - for entry in entries { - let entry_chars = entry.message().content_text().len(); - current_chars += entry_chars; - current.push(entry.clone()); - - // If over budget and we just added an assistant message, cut here - if current_chars >= budget_chars && entry.message().role == Role::Assistant { - chunks.push(std::mem::take(&mut current)); - current_chars = 0; - } +/// Build messages with one entry removed. +fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec { + let mut msgs = Vec::new(); + msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); + let ctx = context.render_context_message(); + if !ctx.is_empty() { + msgs.push(serde_json::json!({"role": "user", "content": ctx})); } - - if !current.is_empty() { - chunks.push(current); + for (i, entry) in context.entries.iter().enumerate() { + if i == skip_idx { continue; } + let m = entry.api_message(); + msgs.push(serde_json::json!({ + "role": m.role_str(), + "content": m.content_text(), + })); } - - // If everything fit in one chunk, just return it - if chunks.is_empty() { - chunks.push(entries.to_vec()); - } - - chunks + msgs } -/// Make a single prompt_logprobs API call and extract response logprobs. -async fn call_prompt_logprobs( - msgs: &[Message], - client: &ApiClient, +/// Call the /v1/score endpoint and return per-message logprobs. +async fn call_score( http: &reqwest::Client, -) -> anyhow::Result>> { + client: &ApiClient, + messages: &[serde_json::Value], + ui_tx: &UiSender, +) -> anyhow::Result> { let request = serde_json::json!({ "model": client.model, - "messages": msgs, - "max_tokens": 1, - "prompt_logprobs": 1, - "stream": false, + "messages": messages, + "logprobs": 1, }); let response = http - .post(format!("{}/chat/completions", client.base_url())) + .post(format!("{}/score", client.base_url())) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", client.api_key())) .json(&request) @@ -307,61 +241,14 @@ async fn call_prompt_logprobs( if !status.is_success() { let msg = body.get("error") - .and_then(|e| e.get("message")) - .and_then(|m| m.as_str()) + .and_then(|e| e.as_str()) .unwrap_or("unknown error"); - anyhow::bail!("HTTP {} from logprobs API: {}", status, msg); + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] score API error: {}", msg, + ))); + anyhow::bail!("score API error: {}", msg); } - let prompt_logprobs = body.get("prompt_logprobs") - .and_then(|v| v.as_array()) - .ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?; - - // Find assistant response boundaries using special tokens - // Pattern: <|im_start|> assistant \n [...] response <|im_end|> - let mut responses: Vec> = Vec::new(); - let mut in_assistant = false; - let mut in_think = false; - let mut current_response: Vec = Vec::new(); - - for entry in prompt_logprobs { - let Some(obj) = entry.as_object() else { continue }; - let first = obj.values().next(); - let Some(info) = first.and_then(|v| v.as_object()) else { continue }; - let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or(""); - let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0); - - match token { - "<|im_start|>" => { - in_assistant = false; - in_think = false; - } - "assistant" if !in_assistant => { - in_assistant = true; - in_think = false; - current_response.clear(); - } - "" if in_assistant => { - in_think = true; - } - "" if in_assistant => { - in_think = false; - } - "<|im_end|>" if in_assistant => { - if !current_response.is_empty() { - responses.push(std::mem::take(&mut current_response)); - } - in_assistant = false; - } - "\n" if in_assistant && current_response.is_empty() => { - // Skip the newline right after "assistant" - } - _ if in_assistant && !in_think => { - current_response.push(logprob); - } - _ => {} - } - } - - Ok(responses) + let result: ScoreApiResponse = serde_json::from_value(body)?; + Ok(result.scores) }