consciousness/src/subconscious/learn.rs
Kent Overstreet b3d0a3ab25 store: internal locking, remove Arc<Mutex<Store>> wrapper
Store now has internal Mutex for capnp appends and AtomicU64 for
size tracking. All methods take &self. The external Arc<Mutex<Store>>
is replaced with Arc<Store>.

- Store::append_lock protects file appends
- local.rs functions take &Store (not &mut Store)
- access_local() returns Arc<Store>
- All .lock().await calls removed from callers

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-13 21:49:54 -04:00

441 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// learn.rs — Memory importance scoring via /v1/score
//
// Four scoring entry points, all built on the same call_score() primitive:
//
// score_memories()             — Full N×M matrix (memories × responses) for the
//                                debug screen. Expensive: N+1 API calls.
//
// score_memory()               — Single memory importance. Scores up to 50
//                                assistant responses after the memory was first
//                                surfaced, with/without that memory.
//                                2 API calls.
//
// score_memories_incremental() — Background re-scoring of memories whose last
//                                score is older than a given age, oldest first.
//                                2 API calls per memory scored.
//
// score_finetune()             — Identifies training candidates. Scores recent
//                                messages with all memories stripped. Responses
//                                with high divergence depend on memories the model
//                                hasn't internalized. 2 API calls.
use crate::agent::api::ApiClient;
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
/// Timeout configured on the HTTP client used for score API calls
/// (five minutes — full-context scoring requests can be slow).
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
// ── Message building ────────────────────────────────────────────
/// What to filter when building the message array for scoring.
///
/// Filters apply only to conversation entries; the system, identity and
/// journal sections are always included by `build_token_ids`.
enum Filter<'a> {
    /// Keep every conversation entry (the unfiltered baseline).
    None,
    /// Drop the conversation entry at this index.
    // Not constructed anywhere in this file, so the dead-code lint is
    // silenced on this variant alone rather than on the whole enum.
    #[allow(dead_code)]
    SkipIndex(usize),
    /// Drop every memory node whose key equals this string.
    SkipKey(&'a str),
    /// Drop every memory node.
    SkipAllMemories,
}
/// True when `node` is a leaf whose body is a surfaced memory.
fn is_memory(node: &AstNode) -> bool {
    if let AstNode::Leaf(leaf) = node {
        matches!(leaf.body(), NodeBody::Memory { .. })
    } else {
        false
    }
}
/// The key of a memory leaf, or `None` for any other kind of node.
fn memory_key(node: &AstNode) -> Option<&str> {
    if let AstNode::Leaf(leaf) = node {
        if let NodeBody::Memory { key, .. } = leaf.body() {
            return Some(key);
        }
    }
    None
}
/// True when `node` is a branch with the assistant role (a model response).
fn is_assistant(node: &AstNode) -> bool {
    if let AstNode::Branch { role, .. } = node {
        matches!(role, Role::Assistant)
    } else {
        false
    }
}
/// Build a token ID array for a scoring call.
///
/// The system, identity and journal sections are always included in full,
/// in that order. After them come the conversation entries whose indices
/// fall in `range`, minus any entry rejected by `filter`.
///
/// # Panics
/// Panics if `range` extends past the end of the conversation.
fn build_token_ids(
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
) -> Vec<u32> {
    let mut ids = Vec::new();

    // Fixed prefix sections: never filtered.
    for node in context.system() {
        ids.extend(node.token_ids());
    }
    for node in context.identity() {
        ids.extend(node.token_ids());
    }
    for node in context.journal() {
        ids.extend(node.token_ids());
    }

    let entries = context.conversation();
    for i in range {
        let node = &entries[i];
        // All filter payloads are `Copy`, so matching by value doesn't
        // consume `filter` across loop iterations.
        let skip = match filter {
            Filter::None => false,
            Filter::SkipIndex(idx) => i == idx,
            Filter::SkipKey(key) => memory_key(node) == Some(key),
            Filter::SkipAllMemories => is_memory(node),
        };
        if !skip {
            ids.extend(node.token_ids());
        }
    }
    ids
}
// ── Score API ───────────────────────────────────────────────────
/// One entry of the score API's `scores` array.
///
/// Callers in this file treat the i-th entry as belonging to the i-th
/// assistant response in the submitted prompt (see `score_finetune` and
/// `score_memories`).
#[derive(serde::Deserialize)]
struct ScoreResult {
    // Total log-probability reported by the server for that response.
    total_logprob: f64,
}
/// Successful response body of the score API. Only `scores` is read; any
/// other fields are ignored (serde's default for unknown fields).
#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}
/// A fresh HTTP client configured with [`SCORE_TIMEOUT`] for score calls.
fn http_client() -> crate::agent::api::http::HttpClient {
    use crate::agent::api::http::HttpClient;
    HttpClient::builder().timeout(SCORE_TIMEOUT).build()
}
async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
let url = format!("{}/score", client.base_url());
let auth = format!("Bearer {}", client.api_key());
let mut body = serde_json::json!({
"model": client.model,
"prompt": prompt,
"logprobs": 1,
});
if let Some(p) = priority {
body["priority"] = serde_json::json!(p);
}
let response = http
.send_json("POST", &url, &[
("authorization", &auth),
], &body)
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("score API HTTP {}: {}", status, msg);
}
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
anyhow::bail!("score API error: {}", err);
}
let result: ScoreResponse = serde_json::from_value(body)
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
Ok(result.scores)
}
/// Per-response logprob divergence: how much worse the model scores each
/// response when something is removed from the prompt.
///
/// The output has one entry per `baseline` score. Each entry is
/// `baseline - without`, clamped at 0.0 so that responses that got *better*
/// count as zero. Positions missing from `without` also count as zero.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    let mut out = Vec::with_capacity(baseline.len());
    for (pos, with_it) in baseline.iter().enumerate() {
        let lp_with = with_it.total_logprob;
        let lp_without = match without.get(pos) {
            Some(score) => score.total_logprob,
            None => lp_with,
        };
        let drop = lp_with - lp_without;
        out.push(drop.max(0.0));
    }
    out
}
/// Score one conversation window twice: first unfiltered, then with
/// `filter` applied.
///
/// Returns the per-response divergences (see `divergence`) together with
/// the unfiltered baseline scores. The two score API calls are made one
/// after the other, both with `priority`.
async fn score_divergence(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
    priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
    let with_scores = {
        let prompt = build_token_ids(context, range.clone(), Filter::None);
        call_score(http, client, &prompt, priority).await?
    };
    let without_scores = {
        let prompt = build_token_ids(context, range, filter);
        call_score(http, client, &prompt, priority).await?
    };
    Ok((divergence(&with_scores, &without_scores), with_scores))
}
// ── Full matrix scoring (debug screen) ──────────────────────────
/// Score how important each memory is to the conversation (full matrix).
///
/// Scores the conversation once as a baseline, then once more per distinct
/// memory key with that key's nodes removed: N+1 score API calls at
/// priority 5. For every (memory, assistant response) pair, the divergence
/// is written into that response node's `memory_scores`. Values above 0.01
/// are stored; anything else removes the key. A failed per-memory call is
/// logged and treated as an all-zero row, which clears that memory's
/// entries. A failed baseline call aborts with its error.
///
/// The context lock is only taken to snapshot, build prompts and write
/// results. It is never held across an API call.
pub async fn score_memories(
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
) -> anyhow::Result<()> {
    // Collect distinct memory keys (first-appearance order), response
    // indices and the conversation length under a brief lock.
    let (memory_keys, response_indices, conv_len) = {
        let ctx = agent.context.lock().await;
        let conversation = ctx.conversation();
        // A memory can be surfaced more than once, not necessarily
        // back-to-back. `Vec::dedup` only drops adjacent repeats, which
        // would re-score (and re-bill) the same memory, so dedup via a set.
        let mut seen = std::collections::HashSet::new();
        let keys: Vec<String> = conversation
            .iter()
            .filter_map(|node| memory_key(node))
            .filter(|key| seen.insert(*key))
            .map(String::from)
            .collect();
        let resp: Vec<usize> = conversation
            .iter()
            .enumerate()
            .filter(|(_, node)| is_assistant(node))
            .map(|(i, _)| i)
            .collect();
        (keys, resp, conversation.len())
    };
    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(());
    }
    let total = memory_keys.len();
    dbglog!("[scoring-full] starting: {} memories × {} responses",
        total, response_indices.len());
    let http = http_client();
    let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
    // Every call scores the same snapshotted prefix, so entries appended
    // while scoring aren't re-sent N+1 times. (They couldn't receive scores
    // anyway: `response_indices` was fixed above.) The prefix is clamped in
    // case the conversation shrank in the meantime.
    let baseline_tokens = {
        let ctx = agent.context.lock().await;
        let end = conv_len.min(ctx.conversation().len());
        build_token_ids(&ctx, 0..end, Filter::None)
    };
    let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
    dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
    for (mem_idx, key) in memory_keys.iter().enumerate() {
        activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
        dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
        let tokens = {
            let ctx = agent.context.lock().await;
            let end = conv_len.min(ctx.conversation().len());
            build_token_ids(&ctx, 0..end, Filter::SkipKey(key))
        };
        let row = match call_score(&http, client, &tokens, Some(5)).await {
            Ok(without) => {
                let divs = divergence(&baseline, &without);
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!("[scoring-full] {}/{}: {} max_div={:.3}",
                    mem_idx + 1, total, key, max_div);
                divs
            }
            Err(e) => {
                dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}",
                    mem_idx + 1, total, key, e);
                vec![0.0; baseline.len()]
            }
        };
        // Write this memory's scores to the live AST nodes.
        {
            let mut ctx = agent.context.lock().await;
            let mut set_count = 0;
            for (resp_idx, &idx) in response_indices.iter().enumerate() {
                // The conversation may have shrunk since the snapshot.
                if idx >= ctx.conversation().len() { continue; }
                let node = &mut ctx.conversation_mut()[idx];
                if let AstNode::Branch {
                    role: Role::Assistant, memory_scores, ..
                } = node {
                    if let Some(&score) = row.get(resp_idx) {
                        if score > 0.01 {
                            memory_scores.insert(key.clone(), score);
                            set_count += 1;
                        } else {
                            memory_scores.remove(key.as_str());
                        }
                    }
                }
            }
            dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count);
        }
        agent.state.lock().await.changed.notify_one();
    }
    Ok(())
}
/// Locate the end of a window holding `n` assistant responses.
///
/// Walks `entries` from index `start`. When the `n`-th assistant entry is
/// reached, returns `(its index + 1, true)`. If fewer than `n` are found
/// before the end, returns `(entries.len(), false)`.
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) {
    let mut responses_seen = 0;
    let nth = entries
        .iter()
        .enumerate()
        .skip(start)
        .filter(|&(_, node)| is_assistant(node))
        .find(|_| {
            responses_seen += 1;
            responses_seen >= n
        });
    match nth {
        Some((idx, _)) => (idx + 1, true),
        None => (entries.len(), false),
    }
}
// ── Single memory scoring ───────────────────────────────────────
/// Score how important a single memory is to the conversation.
///
/// Scores the 50 messages after the memory was surfaced — the window
/// where it could have influenced responses. Returns the sum of
/// divergence, or 0.0 if the memory isn't in the conversation.
pub async fn score_memory(
context: &ContextState,
key: &str,
client: &ApiClient,
) -> anyhow::Result<f64> {
const RESPONSE_WINDOW: usize = 50;
let entries = context.conversation();
let first_pos = match entries.iter().position(|node| memory_key(node) == Some(key)) {
Some(p) => p,
None => return Ok(0.0),
};
let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW);
let range = first_pos..end;
if !entries[range.clone()].iter().any(|node| is_assistant(node)) {
return Ok(0.0);
}
let http = http_client();
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await?;
Ok(divs.iter().sum())
}
// ── Background memory scoring ───────────────────────────────────
/// Score memories in the conversation that are due for re-scoring.
///
/// Candidate selection:
/// - Each distinct memory key is considered once, at its first occurrence.
/// - Its `last_scored` timestamp is read from the local store. A missing
///   node or failed lookup counts as 0.
/// - The memory is due when that timestamp is at least `max_age_secs` old.
///
/// Scoring:
/// - Due memories are scored oldest-first.
/// - Each is scored over the window from its first occurrence through the
///   `response_window`-th following assistant response.
/// - Memories past the first 60% of the conversation's tokens are skipped,
///   as are memories with no assistant response in their window.
///
/// Each successful score is reported as `on_score(key, max_divergence)`.
/// This function persists nothing itself; updating graph weights and
/// `last_scored` is presumably the callback's job (verify at call sites).
/// Per-memory API failures are logged and skipped.
///
/// Returns the number of memories successfully scored.
///
/// # Errors
/// Only if the local store can't be accessed.
pub async fn score_memories_incremental<F, Fut>(
    context: &ContextState,
    max_age_secs: i64,
    response_window: usize,
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
    mut on_score: F,
) -> anyhow::Result<usize>
where
    F: FnMut(String, f64) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let now = chrono::Utc::now().timestamp();
    // Collect unique memory keys with their first position, keeping the due ones.
    let mut seen = std::collections::HashSet::new();
    let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
    {
        // The store handle is only needed for these lookups. Scope it here
        // so it isn't held across the long-running scoring awaits below.
        let store_arc = crate::hippocampus::access_local()?;
        let store = &*store_arc;
        for (i, node) in context.conversation().iter().enumerate() {
            if let Some(key) = memory_key(node) {
                if !seen.insert(key.to_owned()) { continue; }
                let last_scored = store.get_node(key)
                    .ok()
                    .flatten()
                    .map(|n| n.last_scored)
                    .unwrap_or(0);
                if now - last_scored >= max_age_secs {
                    candidates.push((i, key.to_owned(), last_scored));
                }
            }
        }
    }
    // Score oldest-first.
    candidates.sort_by_key(|&(_, _, last)| last);
    let http = http_client();
    let mut scored = 0;
    let entries = context.conversation();
    // Cumulative token count through each entry (inclusive); the last value
    // is the conversation total, so token counts are computed only once.
    let cumulative: Vec<usize> = entries
        .iter()
        .scan(0usize, |running, e| {
            *running += e.tokens();
            Some(*running)
        })
        .collect();
    let total_tokens = cumulative.last().copied().unwrap_or(0);
    let token_cutoff = total_tokens * 60 / 100;
    let total = candidates.len();
    dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total);
    let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await;
    for (pos, key, _) in &candidates {
        // Only score memories in the first 60% of the conversation by tokens —
        // recent memories don't have enough responses to evaluate yet.
        let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens);
        if cum > token_cutoff {
            dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff);
            continue;
        }
        let (end, _) = nth_response_end(entries, *pos, response_window);
        let range = *pos..end;
        if !entries[range.clone()].iter().any(|node| is_assistant(node)) {
            dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end);
            continue;
        }
        activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
        match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await {
            Ok((divs, _)) => {
                let n_responses = divs.len();
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!(
                    "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
                );
                on_score(key.clone(), max_div).await;
                scored += 1;
            }
            Err(e) => {
                dbglog!(
                    "[scoring] {} FAILED: {:#}", key, e,
                );
            }
        }
    }
    Ok(scored)
}
// ── Fine-tuning scoring ─────────────────────────────────────────
/// Score which recent responses are candidates for fine-tuning.
///
/// Takes the last `count` conversation entries and scores them with and
/// without every memory node. Responses whose score drops most without
/// memories depend on knowledge the model hasn't internalized; these are
/// the fine-tuning candidates.
///
/// Returns `(entry_index, divergence)` for each assistant response in the
/// window, sorted by divergence, highest first. Returns an empty vec
/// without calling the API when the window has no assistant responses.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let conversation = context.conversation();
    let len = conversation.len();
    let window = len.saturating_sub(count)..len;

    let responses: Vec<usize> = window
        .clone()
        .filter(|&idx| is_assistant(&conversation[idx]))
        .collect();
    if responses.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let (divs, _) = score_divergence(
        &http, client, context, window, Filter::SkipAllMemories, Some(5),
    ).await?;

    let mut ranked: Vec<(usize, f64)> = Vec::with_capacity(responses.len());
    for (nth, entry_idx) in responses.into_iter().enumerate() {
        let div = match divs.get(nth) {
            Some(&value) => value,
            None => 0.0,
        };
        ranked.push((entry_idx, div));
    }
    ranked.sort_by(|lhs, rhs| {
        rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(ranked)
}