2026-04-03 00:31:57 -04:00
|
|
|
|
// training.rs — Memory importance scoring via /v1/score
|
2026-04-02 22:13:55 -04:00
|
|
|
|
//
|
2026-04-04 01:33:31 -04:00
|
|
|
|
// Scoring modes, all built on the same call_score() primitive:
//
// score_memories() — Full N×M matrix (memories × responses) for the
// debug screen. Expensive: N+1 API calls.
//
// score_memory() — Single memory importance. Scores the window of up to
// 50 assistant responses after the memory was surfaced, with/without
// that memory. 2 API calls.
//
// score_memories_incremental() — Background variant of score_memory():
// scores each memory that is due for re-scoring (max divergence over its
// response window). 2 API calls per memory.
//
// score_finetune() — Identifies training candidates. Scores recent
// messages with all memories stripped. Responses
// with high divergence depend on memories the model
// hasn't internalized. 2 API calls.
|
2026-04-04 00:29:11 -04:00
|
|
|
|
|
2026-04-05 01:48:11 -04:00
|
|
|
|
use crate::agent::api::ApiClient;
|
2026-04-04 00:29:11 -04:00
|
|
|
|
use crate::agent::api::types::*;
|
|
|
|
|
|
use crate::agent::context::{ConversationEntry, ContextState};
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-03 01:07:47 -04:00
|
|
|
|
/// Client-wide request timeout for `/score` calls (two minutes); it is
/// configured on every client returned by `http_client()`.
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
// ── Message building ────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
/// Which conversation entries `build_messages` leaves out of a scoring
/// request (the system prompt and context message are never filtered).
enum Filter<'a> {
    /// Keep every entry in the range.
    None,
    /// Drop the entry at this absolute index into `context.entries`.
    SkipIndex(usize),
    /// Drop every `Memory` entry whose key equals this string.
    SkipKey(&'a str),
    /// Drop every entry for which `ConversationEntry::is_memory()` is true.
    SkipAllMemories,
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Build the messages array for a scoring call.
|
|
|
|
|
|
///
|
|
|
|
|
|
/// Always includes system prompt + context message as prefix, then
|
|
|
|
|
|
/// entries from `range` filtered by `filter`.
|
|
|
|
|
|
fn build_messages(
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
range: std::ops::Range<usize>,
|
|
|
|
|
|
filter: Filter,
|
|
|
|
|
|
) -> Vec<serde_json::Value> {
|
|
|
|
|
|
let mut msgs = vec![
|
|
|
|
|
|
serde_json::json!({"role": "system", "content": &context.system_prompt}),
|
|
|
|
|
|
];
|
|
|
|
|
|
let ctx = context.render_context_message();
|
|
|
|
|
|
if !ctx.is_empty() {
|
|
|
|
|
|
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
|
|
|
|
|
}
|
|
|
|
|
|
for i in range {
|
|
|
|
|
|
let entry = &context.entries[i];
|
|
|
|
|
|
let skip = match &filter {
|
|
|
|
|
|
Filter::None => false,
|
|
|
|
|
|
Filter::SkipIndex(idx) => i == *idx,
|
|
|
|
|
|
Filter::SkipKey(key) => matches!(entry, ConversationEntry::Memory { key: k, .. } if k == key),
|
|
|
|
|
|
Filter::SkipAllMemories => entry.is_memory(),
|
|
|
|
|
|
};
|
|
|
|
|
|
if skip { continue; }
|
|
|
|
|
|
let m = entry.api_message();
|
|
|
|
|
|
msgs.push(serde_json::json!({
|
|
|
|
|
|
"role": m.role_str(),
|
|
|
|
|
|
"content": m.content_text(),
|
|
|
|
|
|
}));
|
|
|
|
|
|
}
|
|
|
|
|
|
msgs
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ── Score API ───────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde::Deserialize)]
|
|
|
|
|
|
struct ScoreResult {
|
|
|
|
|
|
total_logprob: f64,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde::Deserialize)]
|
|
|
|
|
|
struct ScoreResponse {
|
|
|
|
|
|
scores: Vec<ScoreResult>,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn http_client() -> reqwest::Client {
|
|
|
|
|
|
reqwest::Client::builder()
|
|
|
|
|
|
.timeout(SCORE_TIMEOUT)
|
|
|
|
|
|
.pool_max_idle_per_host(2)
|
|
|
|
|
|
.build()
|
|
|
|
|
|
.unwrap_or_default()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn call_score(
|
|
|
|
|
|
http: &reqwest::Client,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
messages: &[serde_json::Value],
|
|
|
|
|
|
) -> anyhow::Result<Vec<ScoreResult>> {
|
|
|
|
|
|
let response = http
|
|
|
|
|
|
.post(format!("{}/score", client.base_url()))
|
|
|
|
|
|
.header("Content-Type", "application/json")
|
|
|
|
|
|
.header("Authorization", format!("Bearer {}", client.api_key()))
|
|
|
|
|
|
.json(&serde_json::json!({
|
|
|
|
|
|
"model": client.model,
|
|
|
|
|
|
"messages": messages,
|
|
|
|
|
|
"logprobs": 1,
|
|
|
|
|
|
}))
|
|
|
|
|
|
.send()
|
|
|
|
|
|
.await
|
|
|
|
|
|
.map_err(|e| if e.is_timeout() {
|
|
|
|
|
|
anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
|
|
|
|
|
|
} else {
|
|
|
|
|
|
anyhow::anyhow!("score request failed: {}", e)
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
let status = response.status();
|
|
|
|
|
|
let body: serde_json::Value = response.json().await?;
|
|
|
|
|
|
|
|
|
|
|
|
if !status.is_success() {
|
|
|
|
|
|
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
|
|
|
|
|
|
anyhow::bail!("score API HTTP {}: {}", status, msg);
|
|
|
|
|
|
}
|
|
|
|
|
|
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
|
|
|
|
|
|
anyhow::bail!("score API error: {}", err);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let result: ScoreResponse = serde_json::from_value(body)
|
|
|
|
|
|
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
|
|
|
|
|
|
Ok(result.scores)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Compute per-position logprob divergence: how much worse the model
|
|
|
|
|
|
/// scores each response without something vs with it.
|
|
|
|
|
|
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
|
|
|
|
|
|
baseline.iter().enumerate()
|
|
|
|
|
|
.map(|(i, base)| {
|
|
|
|
|
|
let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob);
|
|
|
|
|
|
(base.total_logprob - without_lp).max(0.0)
|
|
|
|
|
|
})
|
|
|
|
|
|
.collect()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Score two message sets and return total divergence.
|
|
|
|
|
|
async fn score_divergence(
|
|
|
|
|
|
http: &reqwest::Client,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
range: std::ops::Range<usize>,
|
|
|
|
|
|
filter: Filter<'_>,
|
|
|
|
|
|
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
|
|
|
|
|
|
let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?;
|
|
|
|
|
|
let without = call_score(http, client, &build_messages(context, range, filter)).await?;
|
|
|
|
|
|
let divs = divergence(&baseline, &without);
|
|
|
|
|
|
Ok((divs, baseline))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ── Full matrix scoring (debug screen) ──────────────────────────
|
|
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
/// Full memory-by-response influence matrix for one conversation, plus
/// its row and column totals.
pub struct MemoryScore {
    /// `(memory_key, row total)`: each memory's divergence summed over all
    /// responses.
    pub memory_weights: Vec<(String, f64)>,
    /// Each response's divergence summed over all memories.
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Row labels: the memory key for each row of `matrix`.
    pub memory_keys: Vec<String>,
    /// Column labels: the conversation entry index of each scored response.
    pub response_entry_indices: Vec<usize>,
}
|
|
|
|
|
|
|
|
|
|
|
|
impl MemoryScore {
|
|
|
|
|
|
pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
|
|
|
|
|
|
let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
|
|
|
|
|
|
else { return Vec::new() };
|
|
|
|
|
|
|
|
|
|
|
|
let mut result: Vec<(&str, f64)> = self.memory_keys.iter()
|
|
|
|
|
|
.zip(self.matrix.iter())
|
|
|
|
|
|
.filter_map(|(key, row)| {
|
|
|
|
|
|
let score = row.get(resp_idx).copied().unwrap_or(0.0);
|
|
|
|
|
|
if score > 0.01 { Some((key.as_str(), score)) } else { None }
|
|
|
|
|
|
})
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
|
|
result
|
|
|
|
|
|
}
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
/// Score how important each memory is to the conversation (full matrix).
|
2026-04-02 22:13:55 -04:00
|
|
|
|
pub async fn score_memories(
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
) -> anyhow::Result<MemoryScore> {
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let mut memory_keys: Vec<String> = context.entries.iter()
|
|
|
|
|
|
.filter_map(|e| match e {
|
|
|
|
|
|
ConversationEntry::Memory { key, .. } => Some(key.clone()),
|
2026-04-02 22:13:55 -04:00
|
|
|
|
_ => None,
|
|
|
|
|
|
})
|
|
|
|
|
|
.collect();
|
2026-04-04 01:33:31 -04:00
|
|
|
|
memory_keys.dedup();
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
|
|
|
|
|
let response_indices: Vec<usize> = context.entries.iter().enumerate()
|
|
|
|
|
|
.filter(|(_, e)| e.message().role == Role::Assistant)
|
|
|
|
|
|
.map(|(i, _)| i)
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
if memory_keys.is_empty() || response_indices.is_empty() {
|
2026-04-02 22:13:55 -04:00
|
|
|
|
return Ok(MemoryScore {
|
2026-04-04 01:33:31 -04:00
|
|
|
|
memory_weights: Vec::new(), response_scores: Vec::new(),
|
|
|
|
|
|
matrix: Vec::new(), memory_keys: Vec::new(),
|
2026-04-02 22:27:43 -04:00
|
|
|
|
response_entry_indices: Vec::new(),
|
2026-04-02 22:13:55 -04:00
|
|
|
|
});
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let http = http_client();
|
|
|
|
|
|
let range = 0..context.entries.len();
|
2026-04-03 00:31:57 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?;
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let total = memory_keys.len();
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
2026-04-05 21:45:55 -04:00
|
|
|
|
dbglog!(
|
2026-04-03 01:07:47 -04:00
|
|
|
|
"scoring {}/{}: {}...", mem_idx + 1, total, key,
|
2026-04-05 21:45:55 -04:00
|
|
|
|
);
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let msgs = build_messages(context, range.clone(), Filter::SkipKey(key));
|
|
|
|
|
|
match call_score(&http, client, &msgs).await {
|
|
|
|
|
|
Ok(without) => matrix.push(divergence(&baseline, &without)),
|
2026-04-03 01:07:47 -04:00
|
|
|
|
Err(e) => {
|
2026-04-05 21:45:55 -04:00
|
|
|
|
dbglog!(
|
2026-04-04 01:33:31 -04:00
|
|
|
|
"[training] {} FAILED: {:#}", key, e,
|
2026-04-05 21:45:55 -04:00
|
|
|
|
);
|
2026-04-03 01:07:47 -04:00
|
|
|
|
matrix.push(vec![0.0; baseline.len()]);
|
|
|
|
|
|
}
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 00:31:57 -04:00
|
|
|
|
|
2026-04-02 22:13:55 -04:00
|
|
|
|
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
|
|
|
|
|
.zip(matrix.iter())
|
|
|
|
|
|
.map(|(key, row)| (key.clone(), row.iter().sum()))
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let mut response_scores = vec![0.0; response_indices.len()];
|
2026-04-02 22:13:55 -04:00
|
|
|
|
for row in &matrix {
|
|
|
|
|
|
for (j, &v) in row.iter().enumerate() {
|
2026-04-04 01:33:31 -04:00
|
|
|
|
if j < response_scores.len() { response_scores[j] += v; }
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Ok(MemoryScore {
|
2026-04-04 01:33:31 -04:00
|
|
|
|
memory_weights, response_scores, matrix, memory_keys,
|
2026-04-02 22:27:43 -04:00
|
|
|
|
response_entry_indices: response_indices,
|
2026-04-02 22:13:55 -04:00
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
/// Find the entry index after `start` that contains the Nth assistant response.
|
|
|
|
|
|
/// Returns (end_index, true) if N responses were found, (entries.len(), false) if not.
|
|
|
|
|
|
fn nth_response_end(entries: &[ConversationEntry], start: usize, n: usize) -> (usize, bool) {
|
|
|
|
|
|
let mut count = 0;
|
|
|
|
|
|
for i in start..entries.len() {
|
|
|
|
|
|
if entries[i].message().role == Role::Assistant {
|
|
|
|
|
|
count += 1;
|
|
|
|
|
|
if count >= n { return (i + 1, true); }
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
(entries.len(), false)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
// ── Single memory scoring ───────────────────────────────────────
|
2026-04-02 22:35:29 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
/// Score how important a single memory is to the conversation.
|
|
|
|
|
|
///
|
|
|
|
|
|
/// Scores the 50 messages after the memory was surfaced — the window
|
|
|
|
|
|
/// where it could have influenced responses. Returns the sum of
|
|
|
|
|
|
/// divergence, or 0.0 if the memory isn't in the conversation.
|
|
|
|
|
|
pub async fn score_memory(
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
key: &str,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
) -> anyhow::Result<f64> {
|
2026-04-04 05:01:49 -04:00
|
|
|
|
const RESPONSE_WINDOW: usize = 50;
|
2026-04-04 01:33:31 -04:00
|
|
|
|
|
|
|
|
|
|
let first_pos = match context.entries.iter().position(|e| {
|
|
|
|
|
|
matches!(e, ConversationEntry::Memory { key: k, .. } if k == key)
|
|
|
|
|
|
}) {
|
|
|
|
|
|
Some(p) => p,
|
|
|
|
|
|
None => return Ok(0.0),
|
|
|
|
|
|
};
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
let (end, _) = nth_response_end(&context.entries, first_pos, RESPONSE_WINDOW);
|
|
|
|
|
|
let range = first_pos..end;
|
2026-04-04 01:33:31 -04:00
|
|
|
|
if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
|
|
|
|
|
|
return Ok(0.0);
|
2026-04-02 22:35:29 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let http = http_client();
|
|
|
|
|
|
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?;
|
2026-04-02 22:35:29 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
Ok(divs.iter().sum())
|
|
|
|
|
|
}
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-04 02:46:32 -04:00
|
|
|
|
// ── Background memory scoring ───────────────────────────────────
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
/// Score memories in the conversation that are due for re-scoring.
|
2026-04-04 02:46:32 -04:00
|
|
|
|
///
|
2026-04-04 05:01:49 -04:00
|
|
|
|
/// Checks the graph for each memory's last_scored timestamp. Scores
|
|
|
|
|
|
/// nodes that haven't been scored within `max_age_secs`, oldest first.
|
|
|
|
|
|
/// Updates the graph weight (EWMA) and last_scored after each.
|
2026-04-04 02:46:32 -04:00
|
|
|
|
///
|
2026-04-04 05:01:49 -04:00
|
|
|
|
/// Returns the number of nodes scored and their (key, score) pairs.
|
2026-04-04 02:46:32 -04:00
|
|
|
|
pub async fn score_memories_incremental(
|
|
|
|
|
|
context: &ContextState,
|
2026-04-04 05:01:49 -04:00
|
|
|
|
max_age_secs: i64,
|
|
|
|
|
|
response_window: usize,
|
2026-04-04 02:46:32 -04:00
|
|
|
|
client: &ApiClient,
|
Kill StatusUpdate, Activity, DmnAnnotation, ContextInfoUpdate, AgentUpdate
Status bar reads directly from Agent and MindState on each render tick.
Activity is now a field on Agent — set by agent code directly, read by
UI via try_lock. DmnAnnotation, ContextInfoUpdate, AgentUpdate were
already dead (no senders).
UiMessage down to 4 variants: TextDelta, Reasoning, Debug, Info.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-05 21:34:27 -04:00
|
|
|
|
agent: &std::sync::Arc<tokio::sync::Mutex<crate::agent::Agent>>,
|
2026-04-04 05:01:49 -04:00
|
|
|
|
) -> anyhow::Result<Vec<(String, f64)>> {
|
|
|
|
|
|
let now = chrono::Utc::now().timestamp();
|
2026-04-04 02:46:32 -04:00
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
// Collect unique memory keys with their first position
|
2026-04-04 02:46:32 -04:00
|
|
|
|
let mut seen = std::collections::HashSet::new();
|
2026-04-04 05:01:49 -04:00
|
|
|
|
let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
|
2026-04-04 02:46:32 -04:00
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
let store = crate::hippocampus::store::Store::load().unwrap_or_default();
|
|
|
|
|
|
|
|
|
|
|
|
for (i, entry) in context.entries.iter().enumerate() {
|
2026-04-04 02:46:32 -04:00
|
|
|
|
if let ConversationEntry::Memory { key, .. } = entry {
|
2026-04-04 05:01:49 -04:00
|
|
|
|
if !seen.insert(key.clone()) { continue; }
|
|
|
|
|
|
let last_scored = store.nodes.get(key.as_str())
|
|
|
|
|
|
.map(|n| n.last_scored)
|
|
|
|
|
|
.unwrap_or(0);
|
|
|
|
|
|
if now - last_scored >= max_age_secs {
|
|
|
|
|
|
candidates.push((i, key.clone(), last_scored));
|
2026-04-04 02:46:32 -04:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
// Score oldest-first
|
|
|
|
|
|
candidates.sort_by_key(|&(_, _, last)| last);
|
|
|
|
|
|
|
2026-04-04 02:46:32 -04:00
|
|
|
|
let http = http_client();
|
|
|
|
|
|
let mut results = Vec::new();
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
let total_entries = context.entries.len();
|
|
|
|
|
|
let first_quarter = total_entries / 4;
|
2026-04-04 02:46:32 -04:00
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
for (pos, key, _) in &candidates {
|
|
|
|
|
|
let (end, full_window) = nth_response_end(&context.entries, *pos, response_window);
|
|
|
|
|
|
// Skip memories without a full window, unless they're in the
|
|
|
|
|
|
// first quarter of the conversation (always score those).
|
|
|
|
|
|
if !full_window && *pos >= first_quarter {
|
|
|
|
|
|
continue;
|
2026-04-04 02:46:32 -04:00
|
|
|
|
}
|
|
|
|
|
|
let range = *pos..end;
|
|
|
|
|
|
if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-05 22:18:07 -04:00
|
|
|
|
let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await;
|
2026-04-04 02:46:32 -04:00
|
|
|
|
match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await {
|
|
|
|
|
|
Ok((divs, _)) => {
|
2026-04-04 05:01:49 -04:00
|
|
|
|
let n_responses = divs.len();
|
|
|
|
|
|
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
2026-04-05 21:45:55 -04:00
|
|
|
|
dbglog!(
|
2026-04-04 05:01:49 -04:00
|
|
|
|
"[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
|
2026-04-05 21:45:55 -04:00
|
|
|
|
);
|
2026-04-04 05:01:49 -04:00
|
|
|
|
results.push((key.clone(), max_div));
|
2026-04-04 02:46:32 -04:00
|
|
|
|
}
|
|
|
|
|
|
Err(e) => {
|
2026-04-05 21:45:55 -04:00
|
|
|
|
dbglog!(
|
2026-04-04 02:46:32 -04:00
|
|
|
|
"[scoring] {} FAILED: {:#}", key, e,
|
2026-04-05 21:45:55 -04:00
|
|
|
|
);
|
2026-04-04 02:46:32 -04:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 05:01:49 -04:00
|
|
|
|
Ok(results)
|
2026-04-04 02:46:32 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
// ── Fine-tuning scoring ─────────────────────────────────────────
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
/// Score which recent responses are candidates for fine-tuning.
|
|
|
|
|
|
///
|
|
|
|
|
|
/// Removes all memories and scores the most recent `count` messages.
|
|
|
|
|
|
/// Responses with high divergence depend on memories the model hasn't
|
|
|
|
|
|
/// internalized — these are fine-tuning candidates.
|
|
|
|
|
|
///
|
|
|
|
|
|
/// Returns (entry_index, divergence) pairs, sorted by divergence descending.
|
|
|
|
|
|
pub async fn score_finetune(
|
|
|
|
|
|
context: &ContextState,
|
|
|
|
|
|
count: usize,
|
|
|
|
|
|
client: &ApiClient,
|
|
|
|
|
|
) -> anyhow::Result<Vec<(usize, f64)>> {
|
|
|
|
|
|
let range = context.entries.len().saturating_sub(count)..context.entries.len();
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let response_positions: Vec<usize> = range.clone()
|
|
|
|
|
|
.filter(|&i| context.entries[i].message().role == Role::Assistant)
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
if response_positions.is_empty() {
|
|
|
|
|
|
return Ok(Vec::new());
|
2026-04-03 01:07:47 -04:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let http = http_client();
|
|
|
|
|
|
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?;
|
2026-04-02 22:13:55 -04:00
|
|
|
|
|
2026-04-04 01:33:31 -04:00
|
|
|
|
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
|
|
|
|
|
.enumerate()
|
|
|
|
|
|
.map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0)))
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
|
|
Ok(results)
|
2026-04-02 22:13:55 -04:00
|
|
|
|
}
|