consciousness/src/subconscious/learn.rs

447 lines
15 KiB
Rust
Raw Normal View History

// learn.rs — Memory importance scoring via /v1/score
//
// Four scoring modes, all built on the same call_score() primitive:
//
// score_memories()             — Full N×M matrix (memories × responses) for the
//                                debug screen. Expensive: N+1 API calls.
//
// score_memory()               — Single memory importance. Scores the window of
//                                up to 50 assistant responses after the memory
//                                was surfaced, with/without that memory.
//                                2 API calls.
//
// score_memories_incremental() — Background re-scoring of memories whose last
//                                score is older than a max age. 2 API calls
//                                per memory scored.
//
// score_finetune()             — Identifies training candidates. Scores recent
//                                messages with all memories stripped. Responses
//                                with high divergence depend on memories the model
//                                hasn't internalized. 2 API calls.
2026-04-05 01:48:11 -04:00
use crate::agent::api::ApiClient;
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
/// Request timeout applied to the HTTP client that `http_client()` builds
/// for `/score` calls (two minutes).
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
// ── Message building ────────────────────────────────────────────
/// What to filter when building the message array for scoring.
///
/// Every variant only *removes* conversation entries; the system prompt
/// that `build_messages` emits first is never filtered.
enum Filter<'a> {
    /// Keep every entry (the baseline).
    None,
    /// Drop the single conversation entry at this index.
    // Not constructed anywhere in this module; the lint is silenced for this
    // variant only so that genuinely dead variants still get flagged.
    #[allow(dead_code)]
    SkipIndex(usize),
    /// Drop every memory entry whose key equals this string.
    SkipKey(&'a str),
    /// Drop every memory entry, regardless of key.
    SkipAllMemories,
}
/// True when `node` is a leaf whose body is a surfaced memory.
fn is_memory(node: &AstNode) -> bool {
    match node {
        AstNode::Leaf(leaf) => matches!(leaf.body(), NodeBody::Memory { .. }),
        _ => false,
    }
}
/// The key of a memory leaf, or `None` for any other kind of node.
fn memory_key(node: &AstNode) -> Option<&str> {
    let AstNode::Leaf(leaf) = node else { return None };
    if let NodeBody::Memory { key, .. } = leaf.body() {
        Some(key)
    } else {
        None
    }
}
/// True for branches with the assistant role, the responses that the
/// scorers in this module measure.
fn is_assistant(node: &AstNode) -> bool {
    match node {
        AstNode::Branch { role, .. } => matches!(role, Role::Assistant),
        AstNode::Leaf(_) => false,
    }
}
/// Append `node` to `msgs` as a chat-style `{role, content}` JSON message.
///
/// - A branch keeps its own role; its content is the concatenated rendering
///   of its children.
/// - A leaf becomes a `"tool"` message if it is a tool result, and a
///   `"user"` message otherwise (memories included).
fn push_api_message(node: &AstNode, msgs: &mut Vec<serde_json::Value>) {
    let message = match node {
        AstNode::Branch { role, children } => {
            let rendered: String = children.iter().map(|child| child.render()).collect();
            serde_json::json!({
                "role": role.as_str(),
                "content": rendered,
            })
        }
        AstNode::Leaf(leaf) => {
            let body = leaf.body();
            let speaker = if matches!(body, NodeBody::ToolResult(_)) { "tool" } else { "user" };
            serde_json::json!({
                "role": speaker,
                "content": body.text(),
            })
        }
    };
    msgs.push(message);
}
/// Build the messages array for a scoring call.
///
/// Always includes system prompt as prefix, then entries from `range`
/// filtered by `filter`.
fn build_messages(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> Vec<serde_json::Value> {
let mut msgs = Vec::new();
for node in context.system() {
push_api_message(node, &mut msgs);
}
let entries = context.conversation();
for i in range {
let node = &entries[i];
let skip = match &filter {
Filter::None => false,
Filter::SkipIndex(idx) => i == *idx,
Filter::SkipKey(key) => memory_key(node) == Some(*key),
Filter::SkipAllMemories => is_memory(node),
};
if skip { continue; }
push_api_message(node, &mut msgs);
}
msgs
}
// ── Score API ───────────────────────────────────────────────────
/// One element of the `/score` response's `scores` array.
///
/// Callers line these up, in order, with the assistant responses in the
/// request. `total_logprob` is the log-probability the server reports for
/// one scored response; it is presumably summed over that response's tokens
/// (TODO confirm against the server).
#[derive(serde::Deserialize)]
struct ScoreResult {
    total_logprob: f64,
}
/// Body of a successful `/score` response; only `scores` is read.
#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}
/// Build the HTTP client used for `/score` requests, with `SCORE_TIMEOUT`
/// as its request timeout.
fn http_client() -> crate::agent::api::http::HttpClient {
    use crate::agent::api::http::HttpClient;

    let builder = HttpClient::builder();
    builder.timeout(SCORE_TIMEOUT).build()
}
/// Describe the `error` field of a score API response body, if it has one.
///
/// Servers send this either as a plain string or as an object with a
/// `message` member (OpenAI-style). An object without `message` is rendered
/// as its JSON text. An absent field, `null`, or any other JSON type yields
/// `None`.
fn score_error_field(body: &serde_json::Value) -> Option<String> {
    let err = body.get("error")?;
    if let Some(text) = err.as_str() {
        return Some(text.to_owned());
    }
    if !err.is_object() {
        return None;
    }
    Some(
        err.get("message")
            .and_then(|m| m.as_str())
            .map(str::to_owned)
            .unwrap_or_else(|| err.to_string()),
    )
}

/// POST `messages` to the `/score` endpoint and return one `ScoreResult`
/// per element of the response's `scores` array.
///
/// # Errors
/// - Transport failure or timeout (see `SCORE_TIMEOUT`).
/// - A non-2xx status. The message includes the status plus whatever error
///   text the server sent, even when the body is not valid JSON.
/// - A 2xx body that carries an `error` field (string or object).
/// - A body that does not deserialize as `ScoreResponse`.
async fn call_score(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    messages: &[serde_json::Value],
) -> anyhow::Result<Vec<ScoreResult>> {
    let url = format!("{}/score", client.base_url());
    let auth = format!("Bearer {}", client.api_key());
    let body = serde_json::json!({
        "model": client.model,
        "messages": messages,
        "logprobs": 1,
    });
    let response = http
        .send_json("POST", &url, &[
            ("authorization", &auth),
        ], &body)
        .await?;
    let status = response.status();

    // Don't `?` the parse yet: error responses (e.g. from a proxy) are often
    // not JSON, and the HTTP status is the useful thing to report then.
    let parsed: Result<serde_json::Value, _> = response.json().await;
    if !status.is_success() {
        let msg = match &parsed {
            Ok(body) => score_error_field(body)
                .or_else(|| body.get("message").and_then(|m| m.as_str()).map(str::to_owned))
                .unwrap_or_else(|| "unknown error".to_owned()),
            Err(e) => format!("unparseable error body: {}", e),
        };
        anyhow::bail!("score API HTTP {}: {}", status, msg);
    }

    let body = parsed?;
    if let Some(err) = score_error_field(&body) {
        anyhow::bail!("score API error: {}", err);
    }
    let result: ScoreResponse = serde_json::from_value(body)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    Ok(result.scores)
}
/// Per-position logprob divergence: how much worse the model scores each
/// response without something than with it.
///
/// Each entry is `baseline[i] - without[i]`, clamped to be non-negative.
/// Positions with no counterpart in `without` count as no change (0.0).
/// The output always has `baseline.len()` entries.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    let mut out = Vec::with_capacity(baseline.len());
    for (idx, base) in baseline.iter().enumerate() {
        let with_lp = base.total_logprob;
        let without_lp = match without.get(idx) {
            Some(other) => other.total_logprob,
            None => with_lp,
        };
        out.push((with_lp - without_lp).max(0.0));
    }
    out
}
/// Score `range` twice and compare the results.
///
/// The first call is unfiltered (the baseline); the second applies `filter`.
/// Returns the per-response divergence vector (see [`divergence`]) together
/// with the baseline scores. It does not return a single total; callers sum
/// or take the max themselves.
///
/// The two `/score` calls run sequentially, and an error from either one is
/// propagated.
async fn score_divergence(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
    let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?;
    let without = call_score(http, client, &build_messages(context, range, filter)).await?;
    let divs = divergence(&baseline, &without);
    Ok((divs, baseline))
}
// ── Full matrix scoring (debug screen) ──────────────────────────
/// Result of scoring one conversation's memory usage.
#[derive(Debug, Clone, Default)]
pub struct MemoryScore {
    /// `(key, weight)` per memory, where the weight is the row sum of
    /// `matrix`: the memory's total divergence across all responses.
    pub memory_weights: Vec<(String, f64)>,
    /// Per-response column sums of `matrix`.
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Memory keys, in the same order as the rows of `matrix`.
    pub memory_keys: Vec<String>,
    /// Conversation entry index of each scored response, in column order.
    pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
    /// Divergence at or below this is treated as noise and not reported.
    const MIN_REPORTED_DIVERGENCE: f64 = 0.01;

    /// Memories that measurably influenced the response at conversation
    /// entry `entry_idx`, most influential first.
    ///
    /// Returns an empty list if `entry_idx` is not one of the scored
    /// responses. Scores at or below 0.01 are omitted. Ties keep
    /// `memory_keys` order.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
        else { return Vec::new() };
        let mut result: Vec<(&str, f64)> = self.memory_keys.iter()
            .zip(&self.matrix)
            .filter_map(|(key, row)| {
                let score = row.get(resp_idx).copied().unwrap_or(0.0);
                (score > Self::MIN_REPORTED_DIVERGENCE).then_some((key.as_str(), score))
            })
            .collect();
        // NaN never passes the `>` filter above, so a total order is safe and
        // needs no fallback.
        result.sort_by(|a, b| b.1.total_cmp(&a.1));
        result
    }
}
/// Score how important each memory is to the conversation (full matrix).
///
/// Makes one baseline `/score` call over the whole conversation, then one
/// call per distinct memory key with every occurrence of that key removed,
/// for N+1 calls in total. `matrix[m][r]` is the non-negative logprob drop for
/// response `r` when memory `m` is removed.
///
/// A failed per-memory call is logged and recorded as a row of zeros, so
/// one bad request doesn't abort the whole debug view. A failed baseline
/// call is returned as an error.
///
/// Returns an empty `MemoryScore` when the conversation has no memories or
/// no assistant responses.
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
) -> anyhow::Result<MemoryScore> {
    let entries = context.conversation();

    // Distinct memory keys in first-surfaced order. `Vec::dedup` would only
    // collapse *adjacent* repeats. A memory surfaced twice with other entries
    // in between would then get a duplicate row and a wasted API call, since
    // `Filter::SkipKey` already strips every occurrence of the key.
    let mut seen = std::collections::HashSet::new();
    let memory_keys: Vec<String> = entries.iter()
        .filter_map(|node| memory_key(node))
        .filter(|key| seen.insert(*key))
        .map(String::from)
        .collect();

    let response_indices: Vec<usize> = entries.iter().enumerate()
        .filter(|(_, node)| is_assistant(node))
        .map(|(i, _)| i)
        .collect();

    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            matrix: Vec::new(),
            memory_keys: Vec::new(),
            response_entry_indices: Vec::new(),
        });
    }

    let http = http_client();
    let range = 0..entries.len();
    let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?;

    let total = memory_keys.len();
    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(total);
    for (mem_idx, key) in memory_keys.iter().enumerate() {
        dbglog!(
            "scoring {}/{}: {}...", mem_idx + 1, total, key,
        );
        let msgs = build_messages(context, range.clone(), Filter::SkipKey(key));
        match call_score(&http, client, &msgs).await {
            Ok(without) => matrix.push(divergence(&baseline, &without)),
            Err(e) => {
                dbglog!(
                    "[training] {} FAILED: {:#}", key, e,
                );
                matrix.push(vec![0.0; baseline.len()]);
            }
        }
    }

    // Row sums: each memory's total influence across all responses.
    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(&matrix)
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();

    // Column sums: how much each response leaned on memories overall. Any
    // extra columns (more scores than assistant responses) are ignored.
    let mut response_scores = vec![0.0; response_indices.len()];
    for row in &matrix {
        for (acc, &v) in response_scores.iter_mut().zip(row) {
            *acc += v;
        }
    }

    Ok(MemoryScore {
        memory_weights, response_scores, matrix, memory_keys,
        response_entry_indices: response_indices,
    })
}
/// Find where a window of `n` assistant responses, starting at `start`, ends.
///
/// Returns `(i + 1, true)` where `i` is the index of the `n`th assistant
/// entry at or after `start`. If fewer than `n` exist, returns
/// `(entries.len(), false)`. `n == 0` behaves like `n == 1`.
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) {
    entries
        .iter()
        .enumerate()
        .skip(start)
        .filter(|(_, node)| is_assistant(node))
        .nth(n.saturating_sub(1))
        .map_or((entries.len(), false), |(idx, _)| (idx + 1, true))
}
// ── Single memory scoring ───────────────────────────────────────
/// Score how important a single memory is to the conversation.
///
/// The window starts at the memory's first appearance and runs through its
/// 50th following assistant response (`RESPONSE_WINDOW` responses, not
/// messages), or to the end of the conversation if there are fewer. These
/// are the responses it could have influenced. The window is scored with
/// and without every occurrence of `key` inside it, and the per-response
/// logprob drops are summed.
///
/// Returns `Ok(0.0)` if the memory isn't in the conversation or no assistant
/// response follows it. API errors are propagated.
pub async fn score_memory(
    context: &ContextState,
    key: &str,
    client: &ApiClient,
) -> anyhow::Result<f64> {
    const RESPONSE_WINDOW: usize = 50;
    let entries = context.conversation();
    let Some(first_pos) = entries.iter().position(|node| memory_key(node) == Some(key)) else {
        return Ok(0.0);
    };
    let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW);
    let range = first_pos..end;
    if !entries[range.clone()].iter().any(|node| is_assistant(node)) {
        return Ok(0.0);
    }
    let http = http_client();
    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?;
    Ok(divs.iter().sum())
}
2026-04-04 02:46:32 -04:00
// ── Background memory scoring ───────────────────────────────────
/// Score memories in the conversation that are due for re-scoring.
2026-04-04 02:46:32 -04:00
///
/// Checks the graph for each memory's last_scored timestamp. Scores
/// nodes that haven't been scored within `max_age_secs`, oldest first.
/// Updates the graph weight (EWMA) and last_scored after each.
2026-04-04 02:46:32 -04:00
///
/// Returns the number of nodes scored and their (key, score) pairs.
pub async fn score_memories_incremental<F, Fut>(
2026-04-04 02:46:32 -04:00
context: &ContextState,
max_age_secs: i64,
response_window: usize,
2026-04-04 02:46:32 -04:00
client: &ApiClient,
agent: &std::sync::Arc<crate::agent::Agent>,
mut on_score: F,
) -> anyhow::Result<usize>
where
F: FnMut(String, f64) -> Fut,
Fut: std::future::Future<Output = ()>,
{
let now = chrono::Utc::now().timestamp();
2026-04-04 02:46:32 -04:00
// Collect unique memory keys with their first position
2026-04-04 02:46:32 -04:00
let mut seen = std::collections::HashSet::new();
let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
2026-04-04 02:46:32 -04:00
let store = crate::hippocampus::store::Store::load().unwrap_or_default();
for (i, node) in context.conversation().iter().enumerate() {
if let Some(key) = memory_key(node) {
if !seen.insert(key.to_owned()) { continue; }
let last_scored = store.nodes.get(key)
.map(|n| n.last_scored)
.unwrap_or(0);
if now - last_scored >= max_age_secs {
candidates.push((i, key.to_owned(), last_scored));
2026-04-04 02:46:32 -04:00
}
}
}
// Score oldest-first
candidates.sort_by_key(|&(_, _, last)| last);
2026-04-04 02:46:32 -04:00
let http = http_client();
let mut scored = 0;
2026-04-04 02:46:32 -04:00
let entries = context.conversation();
let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum();
let token_cutoff = total_tokens * 60 / 100;
// Precompute cumulative token position for each entry
let mut cumulative: Vec<usize> = Vec::with_capacity(entries.len());
let mut running = 0;
for e in entries {
running += e.tokens();
cumulative.push(running);
}
2026-04-04 02:46:32 -04:00
for (pos, key, _) in &candidates {
// Only score memories in the first 70% of the conversation by tokens —
// recent memories don't have enough responses to evaluate yet.
if cumulative.get(*pos).copied().unwrap_or(total_tokens) > token_cutoff {
continue;
2026-04-04 02:46:32 -04:00
}
let (end, _) = nth_response_end(context.conversation(), *pos, response_window);
2026-04-04 02:46:32 -04:00
let range = *pos..end;
if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) {
2026-04-04 02:46:32 -04:00
continue;
}
let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await;
2026-04-04 02:46:32 -04:00
match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await {
Ok((divs, _)) => {
let n_responses = divs.len();
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
dbglog!(
"[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
);
on_score(key.clone(), max_div).await;
scored += 1;
2026-04-04 02:46:32 -04:00
}
Err(e) => {
dbglog!(
2026-04-04 02:46:32 -04:00
"[scoring] {} FAILED: {:#}", key, e,
);
2026-04-04 02:46:32 -04:00
}
}
}
Ok(scored)
2026-04-04 02:46:32 -04:00
}
// ── Fine-tuning scoring ─────────────────────────────────────────
/// Rank recent assistant responses as fine-tuning candidates.
///
/// Takes the last `count` conversation entries and scores them twice: once
/// as-is, and once with every memory removed. Responses whose logprob drops
/// the most without memories depend on knowledge the model hasn't
/// internalized; these are the fine-tuning candidates.
///
/// Returns `(entry_index, divergence)` pairs, highest divergence first. The
/// result is empty (with no API calls) if the window has no assistant
/// response.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let entries = context.conversation();
    let end = entries.len();
    let start = end.saturating_sub(count);

    let candidates: Vec<usize> = (start..end)
        .filter(|&idx| is_assistant(&entries[idx]))
        .collect();
    if candidates.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let (divs, _) = score_divergence(&http, client, context, start..end, Filter::SkipAllMemories).await?;

    let mut ranked: Vec<(usize, f64)> = Vec::with_capacity(candidates.len());
    for (slot, entry_idx) in candidates.into_iter().enumerate() {
        let score = divs.get(slot).copied().unwrap_or(0.0);
        ranked.push((entry_idx, score));
    }
    ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(ranked)
}