consciousness/src/subconscious/learn.rs
ProofOfConcept cfddb55ed9 Kill TextDelta, Info — UiMessage is dead. RAII ActivityGuards replace all status feedback
Streaming text now goes directly to agent entries via append_streaming().
sync_from_agent diffs the growing entry each tick. The streaming entry
is popped when the response completes; build_response_message pushes
the final version.

All status feedback uses RAII ActivityGuards:
- push_activity() for long-running work (thinking, streaming, scoring)
- notify() for instant feedback (compacted, DMN state changes, commands)
- Guards auto-remove on Drop, appending "(complete)" and lingering 5s
- expire_activities() cleans up timed-out notifications on render tick

UiMessage enum reduced to a single Info variant with zero sends.
The channel infrastructure remains for now (Mind/Agent still take
UiSender in signatures) — mechanical cleanup for a follow-up.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-05 22:18:07 -04:00

410 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// learn.rs — Memory importance scoring via /v1/score
//
// Four scoring modes, all built on the same call_score() primitive:
//
// score_memories() — Full N×M matrix (memories × responses) for the
// debug screen. Expensive: N+1 API calls.
//
// score_memory() — Single memory importance. Scores the window from
// where the memory was surfaced through the next 50
// assistant responses, with/without that memory.
// 2 API calls.
//
// score_memories_incremental() — Background re-scoring of memories whose
// last_scored timestamp is stale. 2 API calls per memory scored.
//
// score_finetune() — Identifies training candidates. Scores recent
// messages with all memories stripped. Responses
// with high divergence depend on memories the model
// hasn't internalized. 2 API calls.
use crate::agent::api::ApiClient;
use crate::agent::api::types::*;
use crate::agent::context::{ConversationEntry, ContextState};
use crate::user::ui_channel::UiSender;
/// Timeout configured on the HTTP client built by `http_client()` for
/// `/score` requests (120 seconds); also reported in timeout error messages.
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
// ── Message building ────────────────────────────────────────────
/// What to filter when building the message array for scoring.
///
/// `build_messages` applies this to each entry in its range. `Filter::None`
/// produces the baseline; the other variants produce the "without" side of a
/// with/without comparison.
enum Filter<'a> {
    /// Keep every entry in the range.
    None,
    /// Drop the entry at this absolute index into `context.entries`.
    SkipIndex(usize),
    /// Drop every `ConversationEntry::Memory` whose key equals this string.
    SkipKey(&'a str),
    /// Drop every entry for which `is_memory()` returns true.
    SkipAllMemories,
}
/// Build the messages array for a scoring call.
///
/// Always includes system prompt + context message as prefix, then
/// entries from `range` filtered by `filter`.
fn build_messages(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> Vec<serde_json::Value> {
let mut msgs = vec![
serde_json::json!({"role": "system", "content": &context.system_prompt}),
];
let ctx = context.render_context_message();
if !ctx.is_empty() {
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
}
for i in range {
let entry = &context.entries[i];
let skip = match &filter {
Filter::None => false,
Filter::SkipIndex(idx) => i == *idx,
Filter::SkipKey(key) => matches!(entry, ConversationEntry::Memory { key: k, .. } if k == key),
Filter::SkipAllMemories => entry.is_memory(),
};
if skip { continue; }
let m = entry.api_message();
msgs.push(serde_json::json!({
"role": m.role_str(),
"content": m.content_text(),
}));
}
msgs
}
// ── Score API ───────────────────────────────────────────────────
/// One element of the `scores` array in a `/score` response.
#[derive(serde::Deserialize)]
struct ScoreResult {
    /// Total log-probability the server reports for this element. Callers in
    /// this file pair index `i` of `scores` with the `i`-th assistant
    /// response among the scored messages.
    total_logprob: f64,
}
/// Body of a successful `/score` response. An `error` field, if present, is
/// checked by `call_score` before the body is deserialized into this type.
#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}
fn http_client() -> reqwest::Client {
reqwest::Client::builder()
.timeout(SCORE_TIMEOUT)
.pool_max_idle_per_host(2)
.build()
.unwrap_or_default()
}
async fn call_score(
http: &reqwest::Client,
client: &ApiClient,
messages: &[serde_json::Value],
) -> anyhow::Result<Vec<ScoreResult>> {
let response = http
.post(format!("{}/score", client.base_url()))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", client.api_key()))
.json(&serde_json::json!({
"model": client.model,
"messages": messages,
"logprobs": 1,
}))
.send()
.await
.map_err(|e| if e.is_timeout() {
anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
} else {
anyhow::anyhow!("score request failed: {}", e)
})?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("score API HTTP {}: {}", status, msg);
}
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
anyhow::bail!("score API error: {}", err);
}
let result: ScoreResponse = serde_json::from_value(body)
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
Ok(result.scores)
}
/// Per-position logprob divergence: how much worse the model scores each
/// response when something is removed.
///
/// For each index `i` of `baseline`, the result is
/// `baseline[i].total_logprob - without[i].total_logprob`, clamped below at
/// `0.0`. Positions missing from `without` contribute `0.0`. The output has
/// the same length as `baseline`; extra entries in `without` are ignored.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    let mut out = Vec::with_capacity(baseline.len());
    for (idx, with_it) in baseline.iter().enumerate() {
        let reference = with_it.total_logprob;
        let removed = match without.get(idx) {
            Some(score) => score.total_logprob,
            None => reference,
        };
        out.push((reference - removed).max(0.0));
    }
    out
}
/// Score `range` twice (once unfiltered, once with `filter` applied) and
/// return the per-response divergences together with the baseline scores.
///
/// The two `/score` requests are issued one after the other. An error from
/// either is propagated.
async fn score_divergence(
    http: &reqwest::Client,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
    let with_msgs = build_messages(context, range.clone(), Filter::None);
    let baseline = call_score(http, client, &with_msgs).await?;

    let without_msgs = build_messages(context, range, filter);
    let without = call_score(http, client, &without_msgs).await?;

    Ok((divergence(&baseline, &without), baseline))
}
// ── Full matrix scoring (debug screen) ──────────────────────────
/// Result of scoring one conversation's memory usage.
pub struct MemoryScore {
    /// `(key, row sum of that memory's divergences)`, as filled by
    /// `score_memories`.
    pub memory_weights: Vec<(String, f64)>,
    /// Per-response divergence summed over all memories.
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Memory key for each row of `matrix`.
    pub memory_keys: Vec<String>,
    /// Conversation entry index for each response column of `matrix`.
    pub response_entry_indices: Vec<usize>,
}
impl MemoryScore {
    /// Memories whose divergence for the response at conversation entry
    /// `entry_idx` exceeds 0.01, sorted by divergence, highest first.
    ///
    /// Returns an empty list if `entry_idx` is not a scored response. Matrix
    /// cells missing from a row count as `0.0`; ties keep `memory_keys` order.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        const MIN_DIVERGENCE: f64 = 0.01;

        let column = match self.response_entry_indices.iter().position(|&idx| idx == entry_idx) {
            Some(col) => col,
            None => return Vec::new(),
        };

        let mut ranked = Vec::new();
        for (key, row) in self.memory_keys.iter().zip(&self.matrix) {
            let weight = match row.get(column) {
                Some(&w) => w,
                None => 0.0,
            };
            if weight > MIN_DIVERGENCE {
                ranked.push((key.as_str(), weight));
            }
        }
        ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked
    }
}
/// Score how important each memory is to the conversation (full matrix).
///
/// Scores the whole conversation once as a baseline, then once more per
/// distinct memory key with every entry for that key removed: N+1 `/score`
/// calls for N distinct memories. A memory whose request fails is logged and
/// gets a row of zeros rather than aborting the run. Only a failed baseline
/// request is returned as an error.
///
/// Returns an empty `MemoryScore` when the conversation has no memories or no
/// assistant responses. `ui_tx` is currently unused.
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
    // Distinct memory keys in first-appearance order. A memory surfaced more
    // than once shows up in several non-adjacent entries; `Vec::dedup` only
    // collapses adjacent repeats and would score such keys repeatedly
    // (duplicate API calls, duplicate matrix rows and weights).
    let memory_keys: Vec<String> = {
        let mut seen = std::collections::HashSet::new();
        let mut keys = Vec::new();
        for entry in context.entries.iter() {
            if let ConversationEntry::Memory { key, .. } = entry {
                if seen.insert(key.as_str()) {
                    keys.push(key.clone());
                }
            }
        }
        keys
    };
    let response_indices: Vec<usize> = context.entries.iter().enumerate()
        .filter(|(_, e)| e.message().role == Role::Assistant)
        .map(|(i, _)| i)
        .collect();
    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(), response_scores: Vec::new(),
            matrix: Vec::new(), memory_keys: Vec::new(),
            response_entry_indices: Vec::new(),
        });
    }
    let http = http_client();
    let range = 0..context.entries.len();
    let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?;
    let total = memory_keys.len();
    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(total);
    for (mem_idx, key) in memory_keys.iter().enumerate() {
        dbglog!(
            "scoring {}/{}: {}...", mem_idx + 1, total, key,
        );
        let msgs = build_messages(context, range.clone(), Filter::SkipKey(key));
        match call_score(&http, client, &msgs).await {
            Ok(without) => matrix.push(divergence(&baseline, &without)),
            Err(e) => {
                dbglog!(
                    "[training] {} FAILED: {:#}", key, e,
                );
                // Keep rows aligned with `memory_keys`.
                matrix.push(vec![0.0; baseline.len()]);
            }
        }
    }
    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(matrix.iter())
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();
    // Column sums; columns beyond the number of responses are ignored.
    let mut response_scores = vec![0.0; response_indices.len()];
    for row in &matrix {
        for (j, &v) in row.iter().enumerate() {
            if j < response_scores.len() { response_scores[j] += v; }
        }
    }
    Ok(MemoryScore {
        memory_weights, response_scores, matrix, memory_keys,
        response_entry_indices: response_indices,
    })
}
/// Locate the exclusive end of a window holding `n` assistant responses.
///
/// Scans `entries[start..]` (the entry at `start` included) and counts
/// entries whose message role is `Role::Assistant`. Returns `(i + 1, true)`,
/// where `i` is the index of the `n`-th such entry, or
/// `(entries.len(), false)` if fewer than `n` are found. `n == 0` behaves
/// like `n == 1`.
fn nth_response_end(entries: &[ConversationEntry], start: usize, n: usize) -> (usize, bool) {
    let wanted = n.max(1);
    let hit = entries
        .iter()
        .enumerate()
        .skip(start)
        .filter(|(_, entry)| entry.message().role == Role::Assistant)
        .nth(wanted - 1);
    match hit {
        Some((idx, _)) => (idx + 1, true),
        None => (entries.len(), false),
    }
}
// ── Single memory scoring ───────────────────────────────────────
/// Score how important a single memory is to the conversation.
///
/// Scores the 50 messages after the memory was surfaced — the window
/// where it could have influenced responses. Returns the sum of
/// divergence, or 0.0 if the memory isn't in the conversation.
pub async fn score_memory(
context: &ContextState,
key: &str,
client: &ApiClient,
ui_tx: &UiSender,
) -> anyhow::Result<f64> {
const RESPONSE_WINDOW: usize = 50;
let first_pos = match context.entries.iter().position(|e| {
matches!(e, ConversationEntry::Memory { key: k, .. } if k == key)
}) {
Some(p) => p,
None => return Ok(0.0),
};
let (end, _) = nth_response_end(&context.entries, first_pos, RESPONSE_WINDOW);
let range = first_pos..end;
if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
return Ok(0.0);
}
let http = http_client();
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?;
Ok(divs.iter().sum())
}
// ── Background memory scoring ───────────────────────────────────
/// Score memories in the conversation that are due for re-scoring.
///
/// Reads each memory's `last_scored` timestamp from the store returned by
/// `Store::load()`. Never-scored nodes count as timestamp 0. Selects the
/// distinct keys whose last score is at least `max_age_secs` old and scores
/// them oldest first. For each one, the window from the memory's first
/// appearance through the `response_window`-th following assistant response
/// is scored with and without that memory. The memory's result is the
/// largest per-response divergence.
///
/// Skips:
/// - memories without a full window, unless they sit in the first quarter of
///   the conversation;
/// - windows with no assistant response;
/// - memories whose API call fails (the failure is logged).
///
/// This function does not write anything back to the store. It only returns
/// the `(key, max divergence)` pairs. Updating graph weights or `last_scored`
/// is left to the caller. `ui_tx` is currently unused; an activity on
/// `agent` is held while each memory is being scored.
pub async fn score_memories_incremental(
    context: &ContextState,
    max_age_secs: i64,
    response_window: usize,
    client: &ApiClient,
    ui_tx: &UiSender,
    agent: &std::sync::Arc<tokio::sync::Mutex<crate::agent::Agent>>,
) -> anyhow::Result<Vec<(String, f64)>> {
    let now = chrono::Utc::now().timestamp();

    // Distinct memory keys at their first position, filtered to those due.
    // Scoped so the store is released before the network-bound loop below.
    let mut candidates: Vec<(usize, &str, i64)> = {
        // A store that fails to load is treated as empty (every memory looks
        // never-scored), but log it so the resulting mass re-score is explainable.
        let store = match crate::hippocampus::store::Store::load() {
            Ok(store) => store,
            Err(e) => {
                dbglog!("[scoring] failed to load store, treating all memories as unscored: {}", e);
                Default::default()
            }
        };
        let mut seen = std::collections::HashSet::new();
        let mut due = Vec::new(); // (pos, key, last_scored)
        for (pos, entry) in context.entries.iter().enumerate() {
            let ConversationEntry::Memory { key, .. } = entry else { continue };
            if !seen.insert(key.as_str()) { continue; }
            let last_scored = store.nodes.get(key.as_str())
                .map(|n| n.last_scored)
                .unwrap_or(0);
            if now - last_scored >= max_age_secs {
                due.push((pos, key.as_str(), last_scored));
            }
        }
        due
    };

    // Score oldest-first
    candidates.sort_by_key(|&(_, _, last)| last);

    let http = http_client();
    let first_quarter = context.entries.len() / 4;
    let mut results = Vec::new();
    for &(pos, key, _) in &candidates {
        let (end, full_window) = nth_response_end(&context.entries, pos, response_window);
        // Skip memories without a full window, unless they're in the
        // first quarter of the conversation (always score those).
        if !full_window && pos >= first_quarter {
            continue;
        }
        let range = pos..end;
        if !context.entries[range.clone()].iter().any(|e| e.message().role == Role::Assistant) {
            continue;
        }
        let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await;
        match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await {
            Ok((divs, _)) => {
                let max_div = divs.iter().copied().fold(0.0f64, f64::max);
                dbglog!(
                    "[scoring] {} max:{:.3} ({} responses)", key, max_div, divs.len(),
                );
                results.push((key.to_string(), max_div));
            }
            Err(e) => {
                dbglog!(
                    "[scoring] {} FAILED: {:#}", key, e,
                );
            }
        }
    }
    Ok(results)
}
// ── Fine-tuning scoring ─────────────────────────────────────────
/// Score which recent responses are candidates for fine-tuning.
///
/// Takes the last `count` conversation entries (all of them if `count`
/// exceeds the length). Scores them with and without every memory entry.
/// Responses with high divergence lean on memories the model hasn't
/// internalized; these are the fine-tuning candidates.
///
/// Returns `(entry_index, divergence)` for each assistant response in that
/// window, sorted by divergence, highest first. Returns an empty list if the
/// window holds no assistant response. `ui_tx` is currently unused.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let len = context.entries.len();
    let window = len.saturating_sub(count)..len;

    let mut responses = Vec::new();
    for idx in window.clone() {
        if context.entries[idx].message().role == Role::Assistant {
            responses.push(idx);
        }
    }
    if responses.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let (divs, _) =
        score_divergence(&http, client, context, window, Filter::SkipAllMemories).await?;

    let mut ranked: Vec<(usize, f64)> = Vec::with_capacity(responses.len());
    for (slot, entry_idx) in responses.into_iter().enumerate() {
        let score = match divs.get(slot) {
            Some(&d) => d,
            None => 0.0,
        };
        ranked.push((entry_idx, score));
    }
    ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(ranked)
}