training: add score_memory() and score_finetune()

Separate the scoring into two distinct functions:

- score_memory(key): scores one memory's importance by measuring
  divergence in the 50 messages after it was surfaced. Two API calls
  (baseline vs. without that memory).

- score_finetune(count): scores the most recent count messages with
  all memories stripped to identify fine-tuning candidates. Responses
  with high divergence depend on memories the model hasn't
  internalized yet.

The existing score_memories() with the full N×M matrix is preserved
for the debug screen.
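
Rough usage sketch of the new entry points (the call sites are
hypothetical; ctx, client, and ui_tx stand in for whatever the caller
already holds, and the key and count are illustrative):

    // One memory's importance over its 50-message window (2 API calls).
    let importance: f64 =
        score_memory(&ctx, "user_prefers_rust", &client, &ui_tx).await?;

    // Fine-tuning candidates among the most recent 200 entries
    // (2 API calls), sorted by divergence descending.
    let candidates = score_finetune(&ctx, 200, &client, &ui_tx).await?;
    for (entry_idx, div) in candidates.iter().take(5) {
        println!("entry {entry_idx}: diverges {div:.2} without memories");
    }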

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
ProofOfConcept 2026-04-04 01:33:31 -04:00 committed by Kent Overstreet
parent a32dff06f9
commit ce04568454


@@ -1,38 +1,166 @@
 // training.rs — Memory importance scoring via /v1/score
 //
-// Drops each memory from the context one at a time, calls the vLLM
-// /v1/score endpoint to get logprobs for assistant responses.
-// Produces a divergence matrix: memories × responses.
-//
-// Row sums = memory importance (for graph weight updates)
-// Column sums = response memory-dependence (training candidates)
-
-use std::time::Instant;
+// Three scoring modes, all built on the same call_score() primitive:
+//
+// score_memories() — Full N×M matrix (memories × responses) for the
+//                    debug screen. Expensive: N+1 API calls.
+//
+// score_memory()   — Single memory importance. Scores the 50 messages
+//                    after it was surfaced, with/without that memory.
+//                    2 API calls.
+//
+// score_finetune() — Identifies training candidates. Scores recent
+//                    messages with all memories stripped. Responses
+//                    with high divergence depend on memories the model
+//                    hasn't internalized. 2 API calls.
 
 use super::api::ApiClient;
 use crate::agent::api::types::*;
 use crate::agent::context::{ConversationEntry, ContextState};
 use crate::user::ui_channel::{UiMessage, UiSender};
 
+/// Timeout for individual /v1/score API calls.
 const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
 
+// ── Message building ────────────────────────────────────────────
+
+/// What to filter when building the message array for scoring.
+enum Filter<'a> {
+    None,
+    SkipIndex(usize),
+    SkipKey(&'a str),
+    SkipAllMemories,
+}
+
+/// Build the messages array for a scoring call.
+///
+/// Always includes system prompt + context message as prefix, then
+/// entries from `range` filtered by `filter`.
+fn build_messages(
+    context: &ContextState,
+    range: std::ops::Range<usize>,
+    filter: Filter,
+) -> Vec<serde_json::Value> {
+    let mut msgs = vec![
+        serde_json::json!({"role": "system", "content": &context.system_prompt}),
+    ];
+    let ctx = context.render_context_message();
+    if !ctx.is_empty() {
+        msgs.push(serde_json::json!({"role": "user", "content": ctx}));
+    }
+    for i in range {
+        let entry = &context.entries[i];
+        let skip = match &filter {
+            Filter::None => false,
+            Filter::SkipIndex(idx) => i == *idx,
+            Filter::SkipKey(key) => matches!(entry, ConversationEntry::Memory { key: k, .. } if k == key),
+            Filter::SkipAllMemories => entry.is_memory(),
+        };
+        if skip { continue; }
+        let m = entry.api_message();
+        msgs.push(serde_json::json!({
+            "role": m.role_str(),
+            "content": m.content_text(),
+        }));
+    }
+    msgs
+}
+
+// ── Score API ───────────────────────────────────────────────────
+
+#[derive(serde::Deserialize)]
+struct ScoreResult {
+    total_logprob: f64,
+}
+
+#[derive(serde::Deserialize)]
+struct ScoreResponse {
+    scores: Vec<ScoreResult>,
+}
+
+fn http_client() -> reqwest::Client {
+    reqwest::Client::builder()
+        .timeout(SCORE_TIMEOUT)
+        .pool_max_idle_per_host(2)
+        .build()
+        .unwrap_or_default()
+}
+
+async fn call_score(
+    http: &reqwest::Client,
+    client: &ApiClient,
+    messages: &[serde_json::Value],
+) -> anyhow::Result<Vec<ScoreResult>> {
+    let response = http
+        .post(format!("{}/score", client.base_url()))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", client.api_key()))
+        .json(&serde_json::json!({
+            "model": client.model,
+            "messages": messages,
+            "logprobs": 1,
+        }))
+        .send()
+        .await
+        .map_err(|e| if e.is_timeout() {
+            anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
+        } else {
+            anyhow::anyhow!("score request failed: {}", e)
+        })?;
+
+    let status = response.status();
+    let body: serde_json::Value = response.json().await?;
+
+    if !status.is_success() {
+        let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
+        anyhow::bail!("score API HTTP {}: {}", status, msg);
+    }
+    if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
+        anyhow::bail!("score API error: {}", err);
+    }
+
+    let result: ScoreResponse = serde_json::from_value(body)
+        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
+    Ok(result.scores)
+}
+
+/// Compute per-position logprob divergence: how much worse the model
+/// scores each response without something vs with it.
+fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
+    baseline.iter().enumerate()
+        .map(|(i, base)| {
+            let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob);
+            (base.total_logprob - without_lp).max(0.0)
+        })
+        .collect()
+}
+
+/// Score a message range twice (baseline vs `filter` applied) and
+/// return the per-position divergences plus the baseline scores.
+async fn score_divergence(
+    http: &reqwest::Client,
+    client: &ApiClient,
+    context: &ContextState,
+    range: std::ops::Range<usize>,
+    filter: Filter<'_>,
+) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
+    let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?;
+    let without = call_score(http, client, &build_messages(context, range, filter)).await?;
+    let divs = divergence(&baseline, &without);
+    Ok((divs, baseline))
+}
+
+// ── Full matrix scoring (debug screen) ──────────────────────────
+
 /// Result of scoring one conversation's memory usage.
 pub struct MemoryScore {
     /// memory_key → importance score (sum of divergence across all responses)
     pub memory_weights: Vec<(String, f64)>,
     /// response_index → memory-dependence score (sum of divergence across all memories)
     pub response_scores: Vec<f64>,
     /// Full matrix: divergence[memory_idx][response_idx]
     pub matrix: Vec<Vec<f64>>,
     /// Keys of memories that were scored
     pub memory_keys: Vec<String>,
     /// Conversation entry indices of the assistant responses
     pub response_entry_indices: Vec<usize>,
 }
 
 impl MemoryScore {
     /// Get the most important memories for a given conversation entry index.
     pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
         let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
         else { return Vec::new() };
@@ -49,117 +177,57 @@ impl MemoryScore {
     }
 }
 
-/// Score how important each memory is to the conversation.
+/// Score how important each memory is to the conversation (full matrix).
 pub async fn score_memories(
     context: &ContextState,
     client: &ApiClient,
     ui_tx: &UiSender,
 ) -> anyhow::Result<MemoryScore> {
-    let _ = ui_tx.send(UiMessage::Debug(format!(
-        "[training] in score_memories"
-    )));
-
-    let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
-        .filter_map(|(i, e)| match e {
-            ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
+    let mut memory_keys: Vec<String> = context.entries.iter()
+        .filter_map(|e| match e {
+            ConversationEntry::Memory { key, .. } => Some(key.clone()),
             _ => None,
         })
         .collect();
+    memory_keys.dedup();
 
     let response_indices: Vec<usize> = context.entries.iter().enumerate()
         .filter(|(_, e)| e.message().role == Role::Assistant)
         .map(|(i, _)| i)
         .collect();
 
-    if memories.is_empty() || response_indices.is_empty() {
-        let _ = ui_tx.send(UiMessage::Debug(
-            "[training] nothing to score (no memories or no responses)".into()
-        ));
+    if memory_keys.is_empty() || response_indices.is_empty() {
         return Ok(MemoryScore {
-            memory_weights: Vec::new(),
-            response_scores: Vec::new(),
-            matrix: Vec::new(),
-            memory_keys: Vec::new(),
+            memory_weights: Vec::new(), response_scores: Vec::new(),
+            matrix: Vec::new(), memory_keys: Vec::new(),
             response_entry_indices: Vec::new(),
         });
     }
 
     let _ = ui_tx.send(UiMessage::Info(format!(
-        "[scoring {} memories × {} responses]",
-        memories.len(), response_indices.len(),
+        "[scoring {} memories × {} responses]", memory_keys.len(), response_indices.len(),
     )));
 
-    let http = reqwest::Client::builder()
-        .timeout(SCORE_TIMEOUT)
-        .pool_max_idle_per_host(2)
-        .build()
-        .unwrap_or_default();
-
-    let all_messages = build_messages(context);
-    let _ = ui_tx.send(UiMessage::Debug(format!(
-        "[training] {} messages in context",
-        all_messages.len(),
-    )));
-
-    // Baseline: score with all memories present
-    let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into()));
-    let payload_size = serde_json::to_string(&all_messages)
-        .map(|s| s.len()).unwrap_or(0);
-    let _ = ui_tx.send(UiMessage::Debug(format!(
-        "[training] payload size: {}KB",
-        payload_size / 1024,
-    )));
+    let http = http_client();
+    let range = 0..context.entries.len();
 
     let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
-    let start = Instant::now();
-    let baseline = call_score(&http, client, &all_messages).await?;
-    let _ = ui_tx.send(UiMessage::Debug(format!(
-        "[training] baseline: {} responses scored in {:.1}s",
-        baseline.len(), start.elapsed().as_secs_f64(),
-    )));
+    let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?;
 
-    // For each memory, drop it and measure divergence
+    let total = memory_keys.len();
     let mut matrix: Vec<Vec<f64>> = Vec::new();
-    let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
-    let total = memories.len();
 
-    for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
+    for (mem_idx, key) in memory_keys.iter().enumerate() {
         let _ = ui_tx.send(UiMessage::Activity(format!(
             "scoring {}/{}: {}...", mem_idx + 1, total, key,
         )));
 
-        let start = Instant::now();
-        let filtered_messages = build_messages_without(context, *entry_idx);
-        let without = call_score(&http, client, &filtered_messages).await;
-        match without {
-            Ok(without) => {
-                let elapsed = start.elapsed().as_secs_f64();
-                // Match scores by position (nth scored response),
-                // not message_index — indices shift when a memory
-                // is removed from the conversation.
-                let mut row = Vec::new();
-                for (i, base_score) in baseline.iter().enumerate() {
-                    let base_lp = base_score.total_logprob;
-                    let without_lp = without.get(i)
-                        .map(|s| s.total_logprob)
-                        .unwrap_or(base_lp);
-                    let divergence = (base_lp - without_lp).max(0.0);
-                    row.push(divergence);
-                }
-                let importance: f64 = row.iter().sum();
-                let _ = ui_tx.send(UiMessage::Debug(format!(
-                    "[training] {}/{} {} → {:.1} ({:.1}s)",
-                    mem_idx + 1, total, key, importance, elapsed,
-                )));
-                matrix.push(row);
-            }
+        let msgs = build_messages(context, range.clone(), Filter::SkipKey(key));
+        match call_score(&http, client, &msgs).await {
+            Ok(without) => matrix.push(divergence(&baseline, &without)),
             Err(e) => {
                 let _ = ui_tx.send(UiMessage::Debug(format!(
-                    "[training] {}/{} {} FAILED: {:#}",
-                    mem_idx + 1, total, key, e,
+                    "[training] {} FAILED: {:#}", key, e,
                 )));
-                // Push zero row so matrix stays aligned
                 matrix.push(vec![0.0; baseline.len()]);
             }
         }
@@ -167,129 +235,92 @@ pub async fn score_memories(
     let _ = ui_tx.send(UiMessage::Activity(String::new()));
 
-    // Compute scores
     let memory_weights: Vec<(String, f64)> = memory_keys.iter()
         .zip(matrix.iter())
         .map(|(key, row)| (key.clone(), row.iter().sum()))
         .collect();
 
-    let n_responses = response_indices.len();
-    let mut response_scores = vec![0.0; n_responses];
+    let mut response_scores = vec![0.0; response_indices.len()];
     for row in &matrix {
         for (j, &v) in row.iter().enumerate() {
-            if j < n_responses {
-                response_scores[j] += v;
-            }
+            if j < response_scores.len() { response_scores[j] += v; }
         }
     }
 
-    let _ = ui_tx.send(UiMessage::Info(format!(
-        "[scoring complete: {} memories scored]",
-        memory_keys.len(),
-    )));
-
     Ok(MemoryScore {
-        memory_weights,
-        response_scores,
-        matrix,
-        memory_keys,
+        memory_weights, response_scores, matrix, memory_keys,
         response_entry_indices: response_indices,
     })
 }
 
-/// Score response from the /v1/score endpoint.
-#[derive(serde::Deserialize)]
-struct ScoreMessageResult {
-    #[allow(dead_code)]
-    message_index: usize,
-    total_logprob: f64,
-}
-
-#[derive(serde::Deserialize)]
-struct ScoreApiResponse {
-    scores: Vec<ScoreMessageResult>,
-}
-
-/// Build the messages array for the /v1/score endpoint from ContextState.
-fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
-    let mut msgs = Vec::new();
-    msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
-    let ctx = context.render_context_message();
-    if !ctx.is_empty() {
-        msgs.push(serde_json::json!({"role": "user", "content": ctx}));
-    }
-    for entry in &context.entries {
-        let m = entry.api_message();
-        msgs.push(serde_json::json!({
-            "role": m.role_str(),
-            "content": m.content_text(),
-        }));
-    }
-    msgs
-}
-
-/// Build messages with one entry removed.
-fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
-    let mut msgs = Vec::new();
-    msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
-    let ctx = context.render_context_message();
-    if !ctx.is_empty() {
-        msgs.push(serde_json::json!({"role": "user", "content": ctx}));
-    }
-    for (i, entry) in context.entries.iter().enumerate() {
-        if i == skip_idx { continue; }
-        let m = entry.api_message();
-        msgs.push(serde_json::json!({
-            "role": m.role_str(),
-            "content": m.content_text(),
-        }));
-    }
-    msgs
-}
-
-/// Call the /v1/score endpoint and return per-message logprobs.
-async fn call_score(
-    http: &reqwest::Client,
-    client: &ApiClient,
-    messages: &[serde_json::Value],
-) -> anyhow::Result<Vec<ScoreMessageResult>> {
-    let request = serde_json::json!({
-        "model": client.model,
-        "messages": messages,
-        "logprobs": 1,
-    });
-
-    let response = http
-        .post(format!("{}/score", client.base_url()))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", client.api_key()))
-        .json(&request)
-        .send()
-        .await
-        .map_err(|e| {
-            if e.is_timeout() {
-                anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
-            } else {
-                anyhow::anyhow!("score request failed: {}", e)
-            }
-        })?;
-
-    let status = response.status();
-    let body: serde_json::Value = response.json().await?;
-
-    if !status.is_success() {
-        let msg = body.get("error")
-            .and_then(|e| e.as_str())
-            .unwrap_or("unknown error");
-        anyhow::bail!("score API HTTP {}: {}", status, msg);
-    }
-
-    // Check for error in body (score endpoint returns dict on error)
-    if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
-        anyhow::bail!("score API error: {}", err);
-    }
-
-    let result: ScoreApiResponse = serde_json::from_value(body)
-        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
-    Ok(result.scores)
+// ── Single memory scoring ───────────────────────────────────────
+
+/// Score how important a single memory is to the conversation.
+///
+/// Scores the 50 messages after the memory was surfaced — the window
+/// where it could have influenced responses. Returns the sum of
+/// divergence, or 0.0 if the memory isn't in the conversation.
+pub async fn score_memory(
+    context: &ContextState,
+    key: &str,
+    client: &ApiClient,
+    ui_tx: &UiSender,
+) -> anyhow::Result<f64> {
+    const WINDOW: usize = 50;
+
+    let first_pos = match context.entries.iter().position(|e| {
+        matches!(e, ConversationEntry::Memory { key: k, .. } if k == key)
+    }) {
+        Some(p) => p,
+        None => return Ok(0.0),
+    };
+
+    let range = first_pos..(first_pos + WINDOW).min(context.entries.len());
+    if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
+        return Ok(0.0);
+    }
+
+    let http = http_client();
+    let _ = ui_tx.send(UiMessage::Activity(format!("scoring memory: {}...", key)));
+    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?;
+    let _ = ui_tx.send(UiMessage::Activity(String::new()));
+    Ok(divs.iter().sum())
+}
+
+// ── Fine-tuning scoring ─────────────────────────────────────────
+
+/// Score which recent responses are candidates for fine-tuning.
+///
+/// Removes all memories and scores the most recent `count` messages.
+/// Responses with high divergence depend on memories the model hasn't
+/// internalized — these are fine-tuning candidates.
+///
+/// Returns (entry_index, divergence) pairs, sorted by divergence descending.
+pub async fn score_finetune(
+    context: &ContextState,
+    count: usize,
+    client: &ApiClient,
+    ui_tx: &UiSender,
+) -> anyhow::Result<Vec<(usize, f64)>> {
+    let range = context.entries.len().saturating_sub(count)..context.entries.len();
+    let response_positions: Vec<usize> = range.clone()
+        .filter(|&i| context.entries[i].message().role == Role::Assistant)
+        .collect();
+    if response_positions.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let http = http_client();
+    let _ = ui_tx.send(UiMessage::Activity("scoring for fine-tuning...".into()));
+    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?;
+    let _ = ui_tx.send(UiMessage::Activity(String::new()));
+
+    let mut results: Vec<(usize, f64)> = response_positions.iter()
+        .enumerate()
+        .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0)))
+        .collect();
+    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+    Ok(results)
 }
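
For reference, the wire format call_score() assumes, reconstructed from
its serialization code above. This is a sketch of what this codebase
expects, not official vLLM /v1/score documentation; "example-model" and
the message bodies are placeholders:

    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct ScoreResult { total_logprob: f64 }

    #[derive(Deserialize, Debug)]
    struct ScoreResponse { scores: Vec<ScoreResult> }

    fn main() -> anyhow::Result<()> {
        // Request body, as built in call_score().
        let request = serde_json::json!({
            "model": "example-model",
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ],
            "logprobs": 1
        });
        println!("{}", serde_json::to_string_pretty(&request)?);

        // Success body the parser expects: one entry per scored assistant
        // message, in conversation order. Error bodies instead carry
        // {"error": "..."}, which call_score() checks before parsing.
        let body = serde_json::json!({"scores": [{"total_logprob": -12.3}]});
        let parsed: ScoreResponse = serde_json::from_value(body)?;
        println!("{parsed:?}");
        Ok(())
    }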
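
And a worked example of the divergence arithmetic, with divergence()
copied from the diff and the logprobs made up. Standalone sketch:
ScoreResult is redeclared locally, without the serde derive.

    struct ScoreResult { total_logprob: f64 }

    fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
        baseline.iter().enumerate()
            .map(|(i, base)| {
                let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob);
                (base.total_logprob - without_lp).max(0.0)
            })
            .collect()
    }

    fn main() {
        // Two assistant responses, scored with and without some memory.
        let baseline = vec![ScoreResult { total_logprob: -12.3 }, ScoreResult { total_logprob: -4.0 }];
        let without  = vec![ScoreResult { total_logprob: -20.1 }, ScoreResult { total_logprob: -3.5 }];

        let divs = divergence(&baseline, &without);
        // Response 0: -12.3 - (-20.1) = 7.8 → much less likely without the memory.
        // Response 1: -4.0 - (-3.5) = -0.5 → more likely without it, clamped to 0.0.
        println!("{divs:?}");                      // ≈ [7.8, 0.0]
        println!("{}", divs.iter().sum::<f64>());  // row sum = the memory's importance
    }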