From ce045684545d4a62fb86d179303c75d6f93bb600 Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Sat, 4 Apr 2026 01:33:31 -0400 Subject: [PATCH] training: add score_memory() and score_finetune() Separate the scoring into two distinct functions: - score_memory(key): scores one memory's importance by measuring divergence in the 50 messages after it was surfaced. Two API calls (baseline vs without that memory). - score_finetune(count): scores recent messages with all memories stripped to identify fine-tuning candidates. Responses with high divergence depend on memories the model hasn't internalized yet. The existing score_memories() with the full NxM matrix is preserved for the debug screen. Co-Authored-By: Proof of Concept --- src/agent/training.rs | 419 +++++++++++++++++++++++------------------- 1 file changed, 225 insertions(+), 194 deletions(-) diff --git a/src/agent/training.rs b/src/agent/training.rs index 9d029f5..9fe419a 100644 --- a/src/agent/training.rs +++ b/src/agent/training.rs @@ -1,38 +1,166 @@ // training.rs — Memory importance scoring via /v1/score // -// Drops each memory from the context one at a time, calls the vLLM -// /v1/score endpoint to get logprobs for assistant responses. -// Produces a divergence matrix: memories × responses. +// Three scoring modes, all built on the same call_score() primitive: // -// Row sums = memory importance (for graph weight updates) -// Column sums = response memory-dependence (training candidates) - -use std::time::Instant; +// score_memories() — Full N×M matrix (memories × responses) for the +// debug screen. Expensive: N+1 API calls. +// +// score_memory() — Single memory importance. Scores the 50 messages +// after it was surfaced, with/without that memory. +// 2 API calls. +// +// score_finetune() — Identifies training candidates. Scores recent +// messages with all memories stripped. Responses +// with high divergence depend on memories the model +// hasn't internalized. 2 API calls. 
use super::api::ApiClient; use crate::agent::api::types::*; use crate::agent::context::{ConversationEntry, ContextState}; use crate::user::ui_channel::{UiMessage, UiSender}; -/// Timeout for individual /v1/score API calls. const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +// ── Message building ──────────────────────────────────────────── + +/// What to filter when building the message array for scoring. +enum Filter<'a> { + None, + SkipIndex(usize), + SkipKey(&'a str), + SkipAllMemories, +} + +/// Build the messages array for a scoring call. +/// +/// Always includes system prompt + context message as prefix, then +/// entries from `range` filtered by `filter`. +fn build_messages( + context: &ContextState, + range: std::ops::Range, + filter: Filter, +) -> Vec { + let mut msgs = vec![ + serde_json::json!({"role": "system", "content": &context.system_prompt}), + ]; + let ctx = context.render_context_message(); + if !ctx.is_empty() { + msgs.push(serde_json::json!({"role": "user", "content": ctx})); + } + for i in range { + let entry = &context.entries[i]; + let skip = match &filter { + Filter::None => false, + Filter::SkipIndex(idx) => i == *idx, + Filter::SkipKey(key) => matches!(entry, ConversationEntry::Memory { key: k, .. 
} if k == key), + Filter::SkipAllMemories => entry.is_memory(), + }; + if skip { continue; } + let m = entry.api_message(); + msgs.push(serde_json::json!({ + "role": m.role_str(), + "content": m.content_text(), + })); + } + msgs +} + +// ── Score API ─────────────────────────────────────────────────── + +#[derive(serde::Deserialize)] +struct ScoreResult { + total_logprob: f64, +} + +#[derive(serde::Deserialize)] +struct ScoreResponse { + scores: Vec, +} + +fn http_client() -> reqwest::Client { + reqwest::Client::builder() + .timeout(SCORE_TIMEOUT) + .pool_max_idle_per_host(2) + .build() + .unwrap_or_default() +} + +async fn call_score( + http: &reqwest::Client, + client: &ApiClient, + messages: &[serde_json::Value], +) -> anyhow::Result> { + let response = http + .post(format!("{}/score", client.base_url())) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", client.api_key())) + .json(&serde_json::json!({ + "model": client.model, + "messages": messages, + "logprobs": 1, + })) + .send() + .await + .map_err(|e| if e.is_timeout() { + anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs()) + } else { + anyhow::anyhow!("score request failed: {}", e) + })?; + + let status = response.status(); + let body: serde_json::Value = response.json().await?; + + if !status.is_success() { + let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); + anyhow::bail!("score API HTTP {}: {}", status, msg); + } + if let Some(err) = body.get("error").and_then(|e| e.as_str()) { + anyhow::bail!("score API error: {}", err); + } + + let result: ScoreResponse = serde_json::from_value(body) + .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; + Ok(result.scores) +} + +/// Compute per-position logprob divergence: how much worse the model +/// scores each response without something vs with it. 
+fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec { + baseline.iter().enumerate() + .map(|(i, base)| { + let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob); + (base.total_logprob - without_lp).max(0.0) + }) + .collect() +} + +/// Score `range` with and without `filter`; return per-position divergence and the baseline scores. +async fn score_divergence( + http: &reqwest::Client, + client: &ApiClient, + context: &ContextState, + range: std::ops::Range, + filter: Filter<'_>, +) -> anyhow::Result<(Vec, Vec)> { + let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?; + let without = call_score(http, client, &build_messages(context, range, filter)).await?; + let divs = divergence(&baseline, &without); + Ok((divs, baseline)) +} + +// ── Full matrix scoring (debug screen) ────────────────────────── + /// Result of scoring one conversation's memory usage. pub struct MemoryScore { - /// memory_key → importance score (sum of divergence across all responses) pub memory_weights: Vec<(String, f64)>, - /// response_index → memory-dependence score (sum of divergence across all memories) pub response_scores: Vec, /// Full matrix: divergence[memory_idx][response_idx] pub matrix: Vec>, - /// Keys of memories that were scored pub memory_keys: Vec, - /// Conversation entry indices of the assistant responses pub response_entry_indices: Vec, } impl MemoryScore { - /// Get the most important memories for a given conversation entry index. pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> { let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx) else { return Vec::new() }; @@ -49,117 +177,57 @@ impl MemoryScore { } } -/// Score how important each memory is to the conversation. +/// Score how important each memory is to the conversation (full matrix). 
pub async fn score_memories( context: &ContextState, client: &ApiClient, ui_tx: &UiSender, ) -> anyhow::Result { - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] in score_memories" - ))); - - let memories: Vec<(usize, String)> = context.entries.iter().enumerate() - .filter_map(|(i, e)| match e { - ConversationEntry::Memory { key, .. } => Some((i, key.clone())), + let mut memory_keys: Vec = context.entries.iter() + .filter_map(|e| match e { + ConversationEntry::Memory { key, .. } => Some(key.clone()), _ => None, }) .collect(); + memory_keys.dedup(); let response_indices: Vec = context.entries.iter().enumerate() .filter(|(_, e)| e.message().role == Role::Assistant) .map(|(i, _)| i) .collect(); - if memories.is_empty() || response_indices.is_empty() { - let _ = ui_tx.send(UiMessage::Debug( - "[training] nothing to score (no memories or no responses)".into() - )); + if memory_keys.is_empty() || response_indices.is_empty() { return Ok(MemoryScore { - memory_weights: Vec::new(), - response_scores: Vec::new(), - matrix: Vec::new(), - memory_keys: Vec::new(), + memory_weights: Vec::new(), response_scores: Vec::new(), + matrix: Vec::new(), memory_keys: Vec::new(), response_entry_indices: Vec::new(), }); } let _ = ui_tx.send(UiMessage::Info(format!( - "[scoring {} memories × {} responses]", - memories.len(), response_indices.len(), + "[scoring {} memories × {} responses]", memory_keys.len(), response_indices.len(), ))); - let http = reqwest::Client::builder() - .timeout(SCORE_TIMEOUT) - .pool_max_idle_per_host(2) - .build() - .unwrap_or_default(); + let http = http_client(); + let range = 0..context.entries.len(); - let all_messages = build_messages(context); - - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {} messages in context", - all_messages.len(), - ))); - - // Baseline: score with all memories present - let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into())); - let payload_size = serde_json::to_string(&all_messages) - 
.map(|s| s.len()).unwrap_or(0); - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] payload size: {}KB", - payload_size / 1024, - ))); let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into())); - let start = Instant::now(); - let baseline = call_score(&http, client, &all_messages).await?; - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] baseline: {} responses scored in {:.1}s", - baseline.len(), start.elapsed().as_secs_f64(), - ))); + let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?; - // For each memory, drop it and measure divergence + let total = memory_keys.len(); let mut matrix: Vec> = Vec::new(); - let memory_keys: Vec = memories.iter().map(|(_, k)| k.clone()).collect(); - let total = memories.len(); - for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() { + for (mem_idx, key) in memory_keys.iter().enumerate() { let _ = ui_tx.send(UiMessage::Activity(format!( "scoring {}/{}: {}...", mem_idx + 1, total, key, ))); - - let start = Instant::now(); - let filtered_messages = build_messages_without(context, *entry_idx); - let without = call_score(&http, client, &filtered_messages).await; - - match without { - Ok(without) => { - let elapsed = start.elapsed().as_secs_f64(); - // Match scores by position (nth scored response), - // not message_index — indices shift when a memory - // is removed from the conversation. 
- let mut row = Vec::new(); - for (i, base_score) in baseline.iter().enumerate() { - let base_lp = base_score.total_logprob; - let without_lp = without.get(i) - .map(|s| s.total_logprob) - .unwrap_or(base_lp); - let divergence = (base_lp - without_lp).max(0.0); - row.push(divergence); - } - let importance: f64 = row.iter().sum(); - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {}/{} {} → {:.1} ({:.1}s)", - mem_idx + 1, total, key, importance, elapsed, - ))); - matrix.push(row); - } + let msgs = build_messages(context, range.clone(), Filter::SkipKey(key)); + match call_score(&http, client, &msgs).await { + Ok(without) => matrix.push(divergence(&baseline, &without)), Err(e) => { let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {}/{} {} FAILED: {:#}", - mem_idx + 1, total, key, e, + "[training] {} FAILED: {:#}", key, e, ))); - // Push zero row so matrix stays aligned matrix.push(vec![0.0; baseline.len()]); } } @@ -167,129 +235,92 @@ pub async fn score_memories( let _ = ui_tx.send(UiMessage::Activity(String::new())); - // Compute scores let memory_weights: Vec<(String, f64)> = memory_keys.iter() .zip(matrix.iter()) .map(|(key, row)| (key.clone(), row.iter().sum())) .collect(); - let n_responses = response_indices.len(); - let mut response_scores = vec![0.0; n_responses]; + let mut response_scores = vec![0.0; response_indices.len()]; for row in &matrix { for (j, &v) in row.iter().enumerate() { - if j < n_responses { - response_scores[j] += v; - } + if j < response_scores.len() { response_scores[j] += v; } } } - let _ = ui_tx.send(UiMessage::Info(format!( - "[scoring complete: {} memories scored]", - memory_keys.len(), - ))); - Ok(MemoryScore { - memory_weights, - response_scores, - matrix, - memory_keys, + memory_weights, response_scores, matrix, memory_keys, response_entry_indices: response_indices, }) } -/// Score response from the /v1/score endpoint. 
-#[derive(serde::Deserialize)] -struct ScoreMessageResult { - #[allow(dead_code)] - message_index: usize, - total_logprob: f64, -} +// ── Single memory scoring ─────────────────────────────────────── -#[derive(serde::Deserialize)] -struct ScoreApiResponse { - scores: Vec, -} - -/// Build the messages array for the /v1/score endpoint from ContextState. -fn build_messages(context: &ContextState) -> Vec { - let mut msgs = Vec::new(); - msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); - let ctx = context.render_context_message(); - if !ctx.is_empty() { - msgs.push(serde_json::json!({"role": "user", "content": ctx})); - } - for entry in &context.entries { - let m = entry.api_message(); - msgs.push(serde_json::json!({ - "role": m.role_str(), - "content": m.content_text(), - })); - } - msgs -} - -/// Build messages with one entry removed. -fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec { - let mut msgs = Vec::new(); - msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt})); - let ctx = context.render_context_message(); - if !ctx.is_empty() { - msgs.push(serde_json::json!({"role": "user", "content": ctx})); - } - for (i, entry) in context.entries.iter().enumerate() { - if i == skip_idx { continue; } - let m = entry.api_message(); - msgs.push(serde_json::json!({ - "role": m.role_str(), - "content": m.content_text(), - })); - } - msgs -} - -/// Call the /v1/score endpoint and return per-message logprobs. -async fn call_score( - http: &reqwest::Client, +/// Score how important a single memory is to the conversation. +/// +/// Scores the 50 messages after the memory was surfaced — the window +/// where it could have influenced responses. Returns the sum of +/// divergence, or 0.0 if the memory isn't in the conversation. 
+pub async fn score_memory( + context: &ContextState, + key: &str, client: &ApiClient, - messages: &[serde_json::Value], -) -> anyhow::Result> { - let request = serde_json::json!({ - "model": client.model, - "messages": messages, - "logprobs": 1, - }); + ui_tx: &UiSender, +) -> anyhow::Result { + const WINDOW: usize = 50; - let response = http - .post(format!("{}/score", client.base_url())) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", client.api_key())) - .json(&request) - .send() - .await - .map_err(|e| { - if e.is_timeout() { - anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs()) - } else { - anyhow::anyhow!("score request failed: {}", e) - } - })?; + let first_pos = match context.entries.iter().position(|e| { + matches!(e, ConversationEntry::Memory { key: k, .. } if k == key) + }) { + Some(p) => p, + None => return Ok(0.0), + }; - let status = response.status(); - let body: serde_json::Value = response.json().await?; - - if !status.is_success() { - let msg = body.get("error") - .and_then(|e| e.as_str()) - .unwrap_or("unknown error"); - anyhow::bail!("score API HTTP {}: {}", status, msg); + let range = first_pos..(first_pos + WINDOW).min(context.entries.len()); + if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) { + return Ok(0.0); } - // Check for error in body (score endpoint returns dict on error) - if let Some(err) = body.get("error").and_then(|e| e.as_str()) { - anyhow::bail!("score API error: {}", err); - } + let http = http_client(); + let _ = ui_tx.send(UiMessage::Activity(format!("scoring memory: {}...", key))); + let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?; + let _ = ui_tx.send(UiMessage::Activity(String::new())); - let result: ScoreApiResponse = serde_json::from_value(body) - .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; - Ok(result.scores) + Ok(divs.iter().sum()) +} + 
+// ── Fine-tuning scoring ───────────────────────────────────────── + +/// Score which recent responses are candidates for fine-tuning. +/// +/// Removes all memories and scores the most recent `count` messages. +/// Responses with high divergence depend on memories the model hasn't +/// internalized — these are fine-tuning candidates. +/// +/// Returns (entry_index, divergence) pairs, sorted by divergence descending. +pub async fn score_finetune( + context: &ContextState, + count: usize, + client: &ApiClient, + ui_tx: &UiSender, +) -> anyhow::Result> { + let range = context.entries.len().saturating_sub(count)..context.entries.len(); + + let response_positions: Vec = range.clone() + .filter(|&i| context.entries[i].message().role == Role::Assistant) + .collect(); + if response_positions.is_empty() { + return Ok(Vec::new()); + } + + let http = http_client(); + let _ = ui_tx.send(UiMessage::Activity("scoring for fine-tuning...".into())); + let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?; + let _ = ui_tx.send(UiMessage::Activity(String::new())); + + let mut results: Vec<(usize, f64)> = response_positions.iter() + .enumerate() + .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0))) + .collect(); + results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + Ok(results) }