2026-04-03 00:31:57 -04:00
|
|
|
|
// training.rs — Memory importance scoring via /v1/score
|
2026-04-02 22:13:55 -04:00
|
|
|
|
//
|
2026-04-03 00:31:57 -04:00
|
|
|
|
// Drops each memory from the context one at a time, calls the vLLM
|
|
|
|
|
|
// /v1/score endpoint to get logprobs for assistant responses.
|
|
|
|
|
|
// Produces a divergence matrix: memories × responses.
|
2026-04-02 22:13:55 -04:00
|
|
|
|
//
|
|
|
|
|
|
// Row sums = memory importance (for graph weight updates)
|
|
|
|
|
|
// Column sums = response memory-dependence (training candidates)
|
|
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
use std::time::Instant;
|
2026-04-04 00:29:11 -04:00
|
|
|
|
|
|
|
|
|
|
use super::api::ApiClient;
|
|
|
|
|
|
use crate::agent::api::types::*;
|
|
|
|
|
|
use crate::agent::context::{ConversationEntry, ContextState};
|
2026-04-03 17:25:59 -04:00
|
|
|
|
use crate::user::ui_channel::{UiMessage, UiSender};
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
/// Upper bound (two minutes) on a single /v1/score HTTP request; used as the
/// client-wide timeout of the HTTP client that performs the scoring calls.
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
|
|
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
/// Result of scoring one conversation's memory usage.
///
/// Produced by `score_memories`. `matrix[m][r]` is how much less likely the
/// `r`-th scored assistant response became when memory `m` was removed from
/// the context (baseline logprob minus ablated logprob, clamped at zero).
#[derive(Debug, Clone, Default)]
pub struct MemoryScore {
    /// `(memory_key, importance)` pairs; importance is the row sum of `matrix`
    /// (divergence summed across all responses).
    pub memory_weights: Vec<(String, f64)>,
    /// Per-response memory dependence: column sums of `matrix` (divergence
    /// summed across all memories), indexed like `response_entry_indices`.
    pub response_scores: Vec<f64>,
    /// Full matrix: `divergence[memory_idx][response_idx]`.
    pub matrix: Vec<Vec<f64>>,
    /// Keys of the memories that were scored; row `i` of `matrix` belongs to
    /// `memory_keys[i]`.
    pub memory_keys: Vec<String>,
    /// Conversation entry indices of the assistant responses; column `j` of
    /// `matrix` belongs to entry `response_entry_indices[j]`.
    pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
    /// Divergences at or below this value are treated as noise and omitted
    /// from [`MemoryScore::important_memories_for_entry`].
    const MIN_REPORTED_DIVERGENCE: f64 = 0.01;

    /// Get the most important memories for a given conversation entry index.
    ///
    /// Returns `(memory_key, divergence)` pairs for the response stored at
    /// conversation entry `entry_idx`. Only divergences above 0.01 are kept.
    /// Results are sorted from most to least important, and ties keep
    /// `memory_keys` order. Returns an empty vector if `entry_idx` is not one
    /// of the scored responses.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
        else {
            return Vec::new();
        };

        let mut result: Vec<(&str, f64)> = self
            .memory_keys
            .iter()
            .zip(&self.matrix)
            .filter_map(|(key, row)| {
                // A row shorter than the number of responses counts as 0 here.
                let score = row.get(resp_idx).copied().unwrap_or(0.0);
                (score > Self::MIN_REPORTED_DIVERGENCE).then_some((key.as_str(), score))
            })
            .collect();

        // Descending by score. `total_cmp` is a total order. NaN never gets
        // here because `NaN > threshold` is false.
        result.sort_by(|a, b| b.1.total_cmp(&a.1));
        result
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Score how important each memory is to the conversation.
|
|
|
|
|
|
pub async fn score_memories(
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
ui_tx: &UiSender,
|
|
|
|
|
|
) -> anyhow::Result<MemoryScore> {
|
2026-04-03 17:25:59 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
|
|
|
|
"[training] in score_memories"
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
|
|
|
|
|
|
.filter_map(|(i, e)| match e {
|
|
|
|
|
|
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
|
|
|
|
|
|
_ => None,
|
|
|
|
|
|
})
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
|
|
let response_indices: Vec<usize> = context.entries.iter().enumerate()
|
|
|
|
|
|
.filter(|(_, e)| e.message().role == Role::Assistant)
|
|
|
|
|
|
.map(|(i, _)| i)
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
|
|
if memories.is_empty() || response_indices.is_empty() {
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(
|
|
|
|
|
|
"[training] nothing to score (no memories or no responses)".into()
|
|
|
|
|
|
));
|
2026-04-02 22:13:55 -04:00
|
|
|
|
return Ok(MemoryScore {
|
|
|
|
|
|
memory_weights: Vec::new(),
|
|
|
|
|
|
response_scores: Vec::new(),
|
|
|
|
|
|
matrix: Vec::new(),
|
|
|
|
|
|
memory_keys: Vec::new(),
|
2026-04-02 22:27:43 -04:00
|
|
|
|
response_entry_indices: Vec::new(),
|
2026-04-02 22:13:55 -04:00
|
|
|
|
});
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
|
|
|
|
|
"[scoring {} memories × {} responses]",
|
2026-04-02 22:13:55 -04:00
|
|
|
|
memories.len(), response_indices.len(),
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
2026-04-02 23:11:40 -04:00
|
|
|
|
let http = reqwest::Client::builder()
|
2026-04-03 01:07:47 -04:00
|
|
|
|
.timeout(SCORE_TIMEOUT)
|
2026-04-02 23:11:40 -04:00
|
|
|
|
.pool_max_idle_per_host(2)
|
|
|
|
|
|
.build()
|
|
|
|
|
|
.unwrap_or_default();
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
let all_messages = build_messages(context);
|
|
|
|
|
|
|
|
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
2026-04-03 01:07:47 -04:00
|
|
|
|
"[training] {} messages in context",
|
|
|
|
|
|
all_messages.len(),
|
2026-04-03 00:31:57 -04:00
|
|
|
|
)));
|
|
|
|
|
|
|
|
|
|
|
|
// Baseline: score with all memories present
|
2026-04-03 17:25:59 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into()));
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let payload_size = serde_json::to_string(&all_messages)
|
|
|
|
|
|
.map(|s| s.len()).unwrap_or(0);
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
2026-04-03 01:07:47 -04:00
|
|
|
|
"[training] payload size: {}KB",
|
|
|
|
|
|
payload_size / 1024,
|
|
|
|
|
|
)));
|
|
|
|
|
|
let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
|
|
|
|
|
|
let start = Instant::now();
|
|
|
|
|
|
let baseline = call_score(&http, client, &all_messages).await?;
|
|
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
|
|
|
|
"[training] baseline: {} responses scored in {:.1}s",
|
|
|
|
|
|
baseline.len(), start.elapsed().as_secs_f64(),
|
2026-04-02 22:13:55 -04:00
|
|
|
|
)));
|
|
|
|
|
|
|
|
|
|
|
|
// For each memory, drop it and measure divergence
|
|
|
|
|
|
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
|
|
|
|
|
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let total = memories.len();
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
|
|
|
|
|
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
|
2026-04-02 22:27:43 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Activity(format!(
|
2026-04-03 01:07:47 -04:00
|
|
|
|
"scoring {}/{}: {}...", mem_idx + 1, total, key,
|
2026-04-02 22:13:55 -04:00
|
|
|
|
)));
|
|
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let start = Instant::now();
|
2026-04-03 00:31:57 -04:00
|
|
|
|
let filtered_messages = build_messages_without(context, *entry_idx);
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let without = call_score(&http, client, &filtered_messages).await;
|
|
|
|
|
|
|
|
|
|
|
|
match without {
|
|
|
|
|
|
Ok(without) => {
|
|
|
|
|
|
let elapsed = start.elapsed().as_secs_f64();
|
2026-04-03 17:25:59 -04:00
|
|
|
|
// Match scores by position (nth scored response),
|
|
|
|
|
|
// not message_index — indices shift when a memory
|
|
|
|
|
|
// is removed from the conversation.
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let mut row = Vec::new();
|
2026-04-03 17:25:59 -04:00
|
|
|
|
for (i, base_score) in baseline.iter().enumerate() {
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let base_lp = base_score.total_logprob;
|
2026-04-03 17:25:59 -04:00
|
|
|
|
let without_lp = without.get(i)
|
2026-04-03 01:07:47 -04:00
|
|
|
|
.map(|s| s.total_logprob)
|
|
|
|
|
|
.unwrap_or(base_lp);
|
|
|
|
|
|
let divergence = (base_lp - without_lp).max(0.0);
|
|
|
|
|
|
row.push(divergence);
|
|
|
|
|
|
}
|
|
|
|
|
|
let importance: f64 = row.iter().sum();
|
|
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
|
|
|
|
"[training] {}/{} {} → {:.1} ({:.1}s)",
|
|
|
|
|
|
mem_idx + 1, total, key, importance, elapsed,
|
|
|
|
|
|
)));
|
|
|
|
|
|
matrix.push(row);
|
|
|
|
|
|
}
|
|
|
|
|
|
Err(e) => {
|
|
|
|
|
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
|
|
|
|
"[training] {}/{} {} FAILED: {:#}",
|
|
|
|
|
|
mem_idx + 1, total, key, e,
|
|
|
|
|
|
)));
|
|
|
|
|
|
// Push zero row so matrix stays aligned
|
|
|
|
|
|
matrix.push(vec![0.0; baseline.len()]);
|
|
|
|
|
|
}
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
|
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
// Compute scores
|
|
|
|
|
|
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
|
|
|
|
|
.zip(matrix.iter())
|
|
|
|
|
|
.map(|(key, row)| (key.clone(), row.iter().sum()))
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
|
|
let n_responses = response_indices.len();
|
|
|
|
|
|
let mut response_scores = vec![0.0; n_responses];
|
|
|
|
|
|
for row in &matrix {
|
|
|
|
|
|
for (j, &v) in row.iter().enumerate() {
|
|
|
|
|
|
if j < n_responses {
|
|
|
|
|
|
response_scores[j] += v;
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
|
|
|
|
|
"[scoring complete: {} memories scored]",
|
|
|
|
|
|
memory_keys.len(),
|
|
|
|
|
|
)));
|
2026-04-02 22:27:43 -04:00
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
Ok(MemoryScore {
|
|
|
|
|
|
memory_weights,
|
|
|
|
|
|
response_scores,
|
|
|
|
|
|
matrix,
|
|
|
|
|
|
memory_keys,
|
2026-04-02 22:27:43 -04:00
|
|
|
|
response_entry_indices: response_indices,
|
2026-04-02 22:13:55 -04:00
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
/// Score response from the /v1/score endpoint.
|
|
|
|
|
|
#[derive(serde::Deserialize)]
|
|
|
|
|
|
struct ScoreMessageResult {
|
2026-04-03 18:46:14 -04:00
|
|
|
|
#[allow(dead_code)]
|
2026-04-03 00:31:57 -04:00
|
|
|
|
message_index: usize,
|
|
|
|
|
|
total_logprob: f64,
|
|
|
|
|
|
}
|
2026-04-02 22:35:29 -04:00
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
#[derive(serde::Deserialize)]
|
|
|
|
|
|
struct ScoreApiResponse {
|
|
|
|
|
|
scores: Vec<ScoreMessageResult>,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Build the messages array for the /v1/score endpoint from ContextState.
|
|
|
|
|
|
fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
|
|
|
|
|
|
let mut msgs = Vec::new();
|
|
|
|
|
|
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let ctx = context.render_context_message();
|
|
|
|
|
|
if !ctx.is_empty() {
|
2026-04-03 00:31:57 -04:00
|
|
|
|
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
2026-04-02 22:47:44 -04:00
|
|
|
|
}
|
2026-04-03 00:31:57 -04:00
|
|
|
|
for entry in &context.entries {
|
|
|
|
|
|
let m = entry.api_message();
|
|
|
|
|
|
msgs.push(serde_json::json!({
|
|
|
|
|
|
"role": m.role_str(),
|
|
|
|
|
|
"content": m.content_text(),
|
|
|
|
|
|
}));
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
2026-04-03 00:31:57 -04:00
|
|
|
|
msgs
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
/// Build messages with one entry removed.
|
|
|
|
|
|
fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
|
|
|
|
|
|
let mut msgs = Vec::new();
|
|
|
|
|
|
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
|
|
|
|
|
let ctx = context.render_context_message();
|
|
|
|
|
|
if !ctx.is_empty() {
|
|
|
|
|
|
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
2026-04-03 00:31:57 -04:00
|
|
|
|
for (i, entry) in context.entries.iter().enumerate() {
|
|
|
|
|
|
if i == skip_idx { continue; }
|
|
|
|
|
|
let m = entry.api_message();
|
|
|
|
|
|
msgs.push(serde_json::json!({
|
|
|
|
|
|
"role": m.role_str(),
|
|
|
|
|
|
"content": m.content_text(),
|
|
|
|
|
|
}));
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
2026-04-03 00:31:57 -04:00
|
|
|
|
msgs
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
/// Call the /v1/score endpoint and return per-message logprobs.
|
|
|
|
|
|
async fn call_score(
|
2026-04-02 23:11:40 -04:00
|
|
|
|
http: &reqwest::Client,
|
2026-04-03 00:31:57 -04:00
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
messages: &[serde_json::Value],
|
|
|
|
|
|
) -> anyhow::Result<Vec<ScoreMessageResult>> {
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let request = serde_json::json!({
|
|
|
|
|
|
"model": client.model,
|
2026-04-03 00:31:57 -04:00
|
|
|
|
"messages": messages,
|
|
|
|
|
|
"logprobs": 1,
|
2026-04-02 22:13:55 -04:00
|
|
|
|
});
|
|
|
|
|
|
|
2026-04-02 23:11:40 -04:00
|
|
|
|
let response = http
|
2026-04-03 00:31:57 -04:00
|
|
|
|
.post(format!("{}/score", client.base_url()))
|
2026-04-02 22:13:55 -04:00
|
|
|
|
.header("Content-Type", "application/json")
|
|
|
|
|
|
.header("Authorization", format!("Bearer {}", client.api_key()))
|
|
|
|
|
|
.json(&request)
|
|
|
|
|
|
.send()
|
2026-04-03 01:07:47 -04:00
|
|
|
|
.await
|
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
|
if e.is_timeout() {
|
|
|
|
|
|
anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
|
|
|
|
|
|
} else {
|
|
|
|
|
|
anyhow::anyhow!("score request failed: {}", e)
|
|
|
|
|
|
}
|
|
|
|
|
|
})?;
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-02 22:47:44 -04:00
|
|
|
|
let status = response.status();
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let body: serde_json::Value = response.json().await?;
|
|
|
|
|
|
|
2026-04-02 22:47:44 -04:00
|
|
|
|
if !status.is_success() {
|
|
|
|
|
|
let msg = body.get("error")
|
2026-04-03 00:31:57 -04:00
|
|
|
|
.and_then(|e| e.as_str())
|
2026-04-02 22:47:44 -04:00
|
|
|
|
.unwrap_or("unknown error");
|
2026-04-03 01:07:47 -04:00
|
|
|
|
anyhow::bail!("score API HTTP {}: {}", status, msg);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Check for error in body (score endpoint returns dict on error)
|
|
|
|
|
|
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
|
|
|
|
|
|
anyhow::bail!("score API error: {}", err);
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
let result: ScoreApiResponse = serde_json::from_value(body)
|
|
|
|
|
|
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
|
2026-04-03 00:31:57 -04:00
|
|
|
|
Ok(result.scores)
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|