consciousness/src/subconscious/learn.rs
Kent Overstreet 5d9d3ffc5b learn: wire up /train endpoint for approved candidates
When 's' is pressed on the learn screen, approved candidates are now
sent to the inference server's /train endpoint.

Samples are marked as sent immediately in the UI, and mark_trained()
is called after successful API response to prevent re-scoring.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-16 02:04:26 -04:00

716 lines
24 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// learn.rs — Memory importance scoring via /v1/score, plus selection and
// submission of fine-tuning candidates.
//
// Three scoring modes, all built on the same call_score() primitive:
//
// score_memories() — Full N×M matrix (memories × responses) for the
//                    debug screen. Expensive: N+1 API calls.
//
// score_memory()   — Single memory importance. Scores up to 50 assistant
//                    responses after the memory was first surfaced,
//                    with/without that memory. 2 API calls.
//
// score_finetune() — Identifies training candidates. Scores recent
//                    messages with all memories stripped. Responses
//                    with high divergence depend on memories the model
//                    hasn't internalized. 2 API calls.
use crate::agent::api::ApiClient;
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
use crate::agent::tokenizer;
/// Request timeout for every request made through `http_client()`.
/// This covers both the `/score` calls and the `/train` submission.
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
// ── Message building ────────────────────────────────────────────
/// What to filter when building the message array for scoring.
///
/// The system and journal sections are never filtered. Identity nodes honor
/// `SkipKey` and `SkipAllMemories`. Conversation entries honor every variant.
// `SkipIndex` is matched in build_token_ids() but not constructed in the
// code visible here, hence the allow. NOTE(review): confirm it is still wanted.
#[allow(dead_code)]
enum Filter<'a> {
    /// Keep every node.
    None,
    /// Drop the conversation entry at this index (conversation entries only).
    SkipIndex(usize),
    /// Drop memory nodes whose key equals this one.
    SkipKey(&'a str),
    /// Drop every memory node.
    SkipAllMemories,
}
/// Return the key of a memory leaf, or `None` for any other node.
fn memory_key(node: &AstNode) -> Option<&str> {
    if let AstNode::Leaf(leaf) = node {
        if let NodeBody::Memory { key, .. } = leaf.body() {
            return Some(key);
        }
    }
    None
}

/// True when `node` is a memory leaf (exactly the nodes `memory_key` names).
fn is_memory(node: &AstNode) -> bool {
    memory_key(node).is_some()
}

/// True when `node` is a branch with the assistant role.
fn is_assistant(node: &AstNode) -> bool {
    if let AstNode::Branch { role, .. } = node {
        matches!(role, Role::Assistant)
    } else {
        false
    }
}
/// Build a token ID array for a scoring call.
///
/// Includes all sections up to and including conversation entries in
/// `range`, with `filter` applied to conversation entries.
fn build_token_ids(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> Vec<u32> {
use crate::agent::context::Ast;
let mut ids = Vec::new();
for node in context.system() {
ids.extend(node.token_ids());
}
// Identity nodes can be filtered by key for scoring
for node in context.identity() {
let skip = match &filter {
Filter::SkipKey(key) => memory_key(node) == Some(*key),
Filter::SkipAllMemories => is_memory(node),
_ => false,
};
if !skip {
ids.extend(node.token_ids());
}
}
for node in context.journal() {
ids.extend(node.token_ids());
}
let entries = context.conversation();
for i in range {
let node = &entries[i];
let skip = match &filter {
Filter::None => false,
Filter::SkipIndex(idx) => i == *idx,
Filter::SkipKey(key) => memory_key(node) == Some(*key),
Filter::SkipAllMemories => is_memory(node),
};
if skip { continue; }
ids.extend(node.token_ids());
}
ids
}
// ── Score API ───────────────────────────────────────────────────
/// One element of the `/score` response's `scores` array.
#[derive(serde::Deserialize)]
struct ScoreResult {
    /// Log-probability the server reports for this position (presumably a
    /// sum over the response's tokens, given the name). divergence()
    /// compares it between the with/without runs.
    total_logprob: f64,
}
/// Top-level `/score` response body. Fields other than `scores` are ignored.
#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}
/// Create the HTTP client used for scoring and training requests.
/// Every request made with it is bounded by `SCORE_TIMEOUT`.
fn http_client() -> crate::agent::api::http::HttpClient {
    use crate::agent::api::http::HttpClient;
    HttpClient::builder().timeout(SCORE_TIMEOUT).build()
}
async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
let url = format!("{}/score", client.base_url());
let auth = format!("Bearer {}", client.api_key());
let mut body = serde_json::json!({
"model": client.model,
"prompt": prompt,
"logprobs": 1,
});
if let Some(p) = priority {
body["priority"] = serde_json::json!(p);
}
let response = http
.send_json("POST", &url, &[
("authorization", &auth),
], &body)
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("score API HTTP {}: {}", status, msg);
}
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
anyhow::bail!("score API error: {}", err);
}
let result: ScoreResponse = serde_json::from_value(body)
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
Ok(result.scores)
}
/// Per-position logprob divergence between two scoring runs.
///
/// For each position `i` of `baseline`, the result is
/// `baseline[i] - without[i]`, clamped at zero from below. This measures how
/// much worse the response scores once something was removed.
///
/// The output has `baseline.len()` elements. A position missing from
/// `without` yields 0. A NaN difference also yields 0.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    let mut out = Vec::with_capacity(baseline.len());
    for (idx, with_mem) in baseline.iter().enumerate() {
        let with_lp = with_mem.total_logprob;
        let without_lp = match without.get(idx) {
            Some(other) => other.total_logprob,
            None => with_lp,
        };
        let loss = with_lp - without_lp;
        out.push(if loss > 0.0 { loss } else { 0.0 });
    }
    out
}
/// Score the same window twice, as-is and with `filter` applied.
///
/// Returns the per-response divergences together with the unfiltered
/// baseline scores. Both requests are made with the same `priority`.
///
/// # Errors
/// Propagates any error from either `call_score` request.
async fn score_divergence(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    filter: Filter<'_>,
    priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
    let with_ids = build_token_ids(context, range.clone(), Filter::None);
    let without_ids = build_token_ids(context, range, filter);
    let baseline = call_score(http, client, &with_ids, priority).await?;
    let filtered = call_score(http, client, &without_ids, priority).await?;
    Ok((divergence(&baseline, &filtered), baseline))
}
// ── Full matrix scoring (debug screen) ──────────────────────────
/// Score how important each memory is to the conversation (full matrix).
///
/// The whole conversation is scored once as a baseline. It is then scored
/// once per distinct memory key with that memory removed, for N+1 `/score`
/// calls in total.
///
/// For every assistant response, the divergence caused by removing a memory
/// is written into that response's `memory_scores`. Values at or below 0.01
/// remove the key instead. The agent state's `changed` notifier fires after
/// each memory.
///
/// Returns `Ok(())` without calling the API when there are no memories or no
/// assistant responses. A failed per-memory call is logged and treated as
/// zero divergence, which clears that memory's scores.
///
/// # Errors
/// Fails only if the baseline scoring call fails.
pub async fn score_memories(
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
) -> anyhow::Result<()> {
    // Collect memory keys and response indices under a brief lock
    let (memory_keys, response_indices) = {
        let ctx = agent.context.lock().await;
        // Identity memories come first, then conversation memories, each key
        // once. `Vec::dedup` only drops *adjacent* repeats, so a key present in
        // both identity and conversation (or recurring later) was scored twice.
        let mut seen = std::collections::HashSet::new();
        let keys: Vec<String> = ctx.identity().iter()
            .chain(ctx.conversation().iter())
            .filter_map(|node| memory_key(node))
            .filter(|key| seen.insert(*key))
            .map(String::from)
            .collect();
        let resp: Vec<usize> = ctx.conversation().iter().enumerate()
            .filter(|(_, node)| is_assistant(node))
            .map(|(i, _)| i)
            .collect();
        (keys, resp)
    };
    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(());
    }
    let total = memory_keys.len();
    dbglog!("[scoring-full] starting: {} memories × {} responses",
        total, response_indices.len());
    let http = http_client();
    let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
    let baseline_tokens = {
        let ctx = agent.context.lock().await;
        build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
    };
    let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
    dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
    for (mem_idx, key) in memory_keys.iter().enumerate() {
        activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
        dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
        let tokens = {
            let ctx = agent.context.lock().await;
            build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
        };
        let row = match call_score(&http, client, &tokens, Some(5)).await {
            Ok(without) => {
                let divs = divergence(&baseline, &without);
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!("[scoring-full] {}/{}: {} max_div={:.3}",
                    mem_idx + 1, total, key, max_div);
                divs
            }
            Err(e) => {
                dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}",
                    mem_idx + 1, total, key, e);
                vec![0.0; baseline.len()]
            }
        };
        // Write this memory's scores to the live AST nodes.
        // NOTE(review): response_indices were captured once, under an earlier
        // lock. If the conversation changes meanwhile they may point at other
        // entries; the bounds check below only prevents panics.
        {
            let mut ctx = agent.context.lock().await;
            let mut set_count = 0;
            for (resp_idx, &idx) in response_indices.iter().enumerate() {
                if idx >= ctx.conversation().len() { continue; }
                let node = &mut ctx.conversation_mut()[idx];
                if let AstNode::Branch {
                    role: Role::Assistant, memory_scores, ..
                } = node {
                    if let Some(&score) = row.get(resp_idx) {
                        if score > 0.01 {
                            memory_scores.insert(key.clone(), score);
                            set_count += 1;
                        } else {
                            memory_scores.remove(key.as_str());
                        }
                    }
                }
            }
            dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count);
        }
        agent.state.lock().await.changed.notify_one();
    }
    Ok(())
}
/// Locate the end of a window holding `n` assistant responses.
///
/// Scans `entries[start..]` for assistant responses. When the `n`th one is
/// found at index `i`, returns `(i + 1, true)`, i.e. an exclusive end that
/// includes it. Otherwise returns `(entries.len(), false)`. `n == 0` behaves
/// like `n == 1`.
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) {
    let target = n.saturating_sub(1);
    let hit = entries
        .iter()
        .enumerate()
        .skip(start)
        .filter(|(_, node)| is_assistant(node))
        .nth(target);
    match hit {
        Some((idx, _)) => (idx + 1, true),
        None => (entries.len(), false),
    }
}
// ── Single memory scoring ───────────────────────────────────────
/// Score how important a single memory is to the conversation.
///
/// The window starts at the memory's first appearance in the conversation.
/// It runs through up to 50 assistant responses, the span where the memory
/// could have influenced replies. That window is scored with and without the
/// memory.
///
/// Returns the summed per-response divergence. Returns 0.0 without calling
/// the API when the memory is not in the conversation or no assistant
/// response follows it.
pub async fn score_memory(
    context: &ContextState,
    key: &str,
    client: &ApiClient,
) -> anyhow::Result<f64> {
    const RESPONSE_WINDOW: usize = 50;
    let conversation = context.conversation();

    let first_seen = conversation.iter().position(|node| memory_key(node) == Some(key));
    let start = if let Some(pos) = first_seen { pos } else { return Ok(0.0); };

    let (end, _) = nth_response_end(conversation, start, RESPONSE_WINDOW);
    let window = start..end;
    let has_response = conversation[window.clone()].iter().any(|node| is_assistant(node));
    if !has_response {
        return Ok(0.0);
    }

    let http = http_client();
    let (per_response, _baseline) =
        score_divergence(&http, client, context, window, Filter::SkipKey(key), Some(5)).await?;
    Ok(per_response.iter().sum())
}
// ── Background memory scoring ───────────────────────────────────
/// Score memories in the context that are due for re-scoring.
///
/// Each distinct memory key is considered once. Identity memories are
/// treated as position 0; conversation memories use their first index. The
/// key's `last_scored` timestamp is looked up in the local store, and keys
/// missing from the store count as never scored. Keys not scored within
/// `max_age_secs` are processed oldest first.
///
/// A key is skipped if either holds:
/// - its position lies past the first 60% of the conversation by tokens;
/// - no assistant response follows it within `response_window` responses.
///
/// For every key that is scored, `on_score(key, max_divergence)` is awaited.
/// This function does not write scores anywhere itself; persisting them
/// (e.g. updating graph weight / `last_scored`) is the callback's job.
/// Individual scoring failures are logged and skipped.
///
/// Returns the number of memories successfully scored.
///
/// # Errors
/// Fails if the local store cannot be accessed.
pub async fn score_memories_incremental<F, Fut>(
    context: &ContextState,
    max_age_secs: i64,
    response_window: usize,
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
    mut on_score: F,
) -> anyhow::Result<usize>
where
    F: FnMut(String, f64) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let now = chrono::Utc::now().timestamp();
    // Collect unique memory keys with their first position
    let mut seen = std::collections::HashSet::new();
    let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
    let store_arc = crate::hippocampus::access_local()?;
    {
        let store = &*store_arc;
        // Identity nodes always score at position 0; conversation nodes at their index
        let identity_nodes = context.identity().iter().map(|n| (0, n));
        let conv_nodes = context.conversation().iter().enumerate();
        for (pos, node) in identity_nodes.chain(conv_nodes) {
            if let Some(key) = memory_key(node) {
                if !seen.insert(key.to_owned()) { continue; }
                // Missing nodes and failed lookups count as "never scored".
                let last_scored = store.get_node(key)
                    .ok()
                    .flatten()
                    .map(|n| n.last_scored)
                    .unwrap_or(0);
                if now - last_scored >= max_age_secs {
                    candidates.push((pos, key.to_owned(), last_scored));
                }
            }
        }
    }
    // Score oldest-first (stable sort: ties keep discovery order)
    candidates.sort_by_key(|&(_, _, last)| last);
    let http = http_client();
    let mut scored = 0;
    let entries = context.conversation();
    let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum();
    let token_cutoff = total_tokens * 60 / 100;
    // Precompute cumulative token position for each entry
    let mut cumulative: Vec<usize> = Vec::with_capacity(entries.len());
    let mut running = 0;
    for e in entries {
        running += e.tokens();
        cumulative.push(running);
    }
    let total = candidates.len();
    dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total);
    let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await;
    for (pos, key, _) in &candidates {
        // Only score memories in the first 60% of the conversation by tokens —
        // recent memories don't have enough responses to evaluate yet.
        let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens);
        if cum > token_cutoff {
            dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff);
            continue;
        }
        let (end, _) = nth_response_end(entries, *pos, response_window);
        let range = *pos..end;
        if !entries[range.clone()].iter().any(|node| is_assistant(node)) {
            dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end);
            continue;
        }
        activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
        match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await {
            Ok((divs, _)) => {
                let n_responses = divs.len();
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!(
                    "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
                );
                on_score(key.clone(), max_div).await;
                scored += 1;
            }
            Err(e) => {
                dbglog!(
                    "[scoring] {} FAILED: {:#}", key, e,
                );
            }
        }
    }
    Ok(scored)
}
// ── Fine-tuning scoring ─────────────────────────────────────────
/// Score which recent responses are candidates for fine-tuning.
///
/// Only the last `count` conversation entries are scored, preceded by the
/// system, identity and journal sections as build_token_ids() lays them out.
/// They are scored once as-is and once with every memory stripped. A
/// response that scores much worse without memories relies on them: the
/// model hasn't internalized what they say, which makes it a fine-tuning
/// candidate.
///
/// Returns `(entry_index, divergence)` pairs, one per assistant response in
/// the window, sorted by divergence descending. A missing score counts as 0.
/// Returns an empty vector without calling the API when the window holds no
/// assistant responses.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let conversation = context.conversation();
    let stop = conversation.len();
    let start = stop.saturating_sub(count);

    let mut assistant_idxs = Vec::new();
    for idx in start..stop {
        if is_assistant(&conversation[idx]) {
            assistant_idxs.push(idx);
        }
    }
    if assistant_idxs.is_empty() {
        return Ok(Vec::new());
    }

    let http = http_client();
    let (per_response, _) = score_divergence(
        &http, client, context, start..stop, Filter::SkipAllMemories, Some(5),
    ).await?;

    // Pair responses with divergences; pad with zeros if the server returned fewer.
    let mut ranked: Vec<(usize, f64)> = assistant_idxs
        .into_iter()
        .zip(per_response.iter().copied().chain(std::iter::repeat(0.0)))
        .collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(ranked)
}
/// Enriched finetune candidate with context for review.
#[derive(Clone, Debug)]
pub struct FinetuneCandidate {
    /// Index of the assistant response in `context.conversation()`.
    pub entry_idx: usize,
    /// Divergence from score_finetune(): how much worse this response scores
    /// once every memory is stripped from the context.
    pub divergence: f64,
    /// Concatenated text of the response's leaf children.
    pub response_text: String,
    /// Token IDs for context: the system, identity and journal sections plus
    /// conversation entries before the response. Built with `Filter::None`,
    /// so memories are included.
    pub context_ids: Vec<u32>,
    /// Token IDs for the response (what we're training on).
    pub continuation_ids: Vec<u32>,
    /// What the model would have said without memories. Filled in only when
    /// alternate generation is enabled and succeeded.
    pub alternate_text: Option<String>,
    /// Response timestamp in millis (see node_timestamp_ms). It is the key
    /// recorded in the trained-responses set.
    pub timestamp_ms: i64,
}
/// Score and enrich finetune candidates with full context.
///
/// Runs score_finetune() over the last `count` entries, then keeps the
/// responses that pass all of these checks:
/// - divergence is at least `min_divergence`;
/// - the response has a timestamp;
/// - that timestamp is not already in the trained set;
/// - the node is a branch.
///
/// Each kept response gets its text and its context/continuation token IDs,
/// ready for submission via send_to_train(). When alternates are enabled,
/// an alternate memory-free reply is generated for each candidate. Failed
/// generations are logged and leave `alternate_text` as `None`.
///
/// Candidates keep score_finetune()'s order (divergence descending).
pub async fn score_finetune_candidates(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
    min_divergence: f64,
) -> anyhow::Result<Vec<FinetuneCandidate>> {
    let scored = score_finetune(context, count, client).await?;
    let conversation = context.conversation();
    let mut out: Vec<FinetuneCandidate> = Vec::new();
    let already_trained = load_trained();

    for (entry_idx, divergence) in scored {
        if divergence < min_divergence {
            continue;
        }
        let node = &conversation[entry_idx];

        // Untimestamped responses can't be tracked as trained; skip them,
        // along with anything already trained.
        let timestamp_ms = match node_timestamp_ms(node) {
            None => continue,
            Some(ts) if already_trained.contains(&ts) => continue,
            Some(ts) => ts,
        };

        let children = match node {
            AstNode::Branch { children, .. } => children,
            _ => continue,
        };
        let response_text: String = children
            .iter()
            .filter_map(|child| match child {
                AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
                _ => None,
            })
            .collect();

        // NOTE(review): the context is built with Filter::None, so it includes
        // the memories whose absence produced the divergence. Confirm this is
        // the intended training context.
        let context_ids = build_token_ids(context, 0..entry_idx, Filter::None);
        let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();

        out.push(FinetuneCandidate {
            entry_idx,
            divergence,
            response_text,
            context_ids,
            continuation_ids,
            alternate_text: None,
            timestamp_ms,
        });
    }

    if alternates_enabled() && !out.is_empty() {
        for candidate in out.iter_mut() {
            let generated = generate_alternate(context, candidate.entry_idx, client).await;
            match generated {
                Ok(text) => candidate.alternate_text = Some(text),
                Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
            }
        }
    }
    Ok(out)
}
/// Generate what the model would say without memories for a given entry.
async fn generate_alternate(
context: &ContextState,
entry_idx: usize,
client: &ApiClient,
) -> anyhow::Result<String> {
use crate::agent::api::{SamplingParams, StreamToken};
// Build context tokens without memories, up to the response
let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
// Add assistant turn start
prompt.push(tokenizer::IM_START);
prompt.extend(tokenizer::encode("assistant\n"));
// Generate completion
let sampling = SamplingParams {
temperature: 0.6,
top_p: 0.95,
top_k: 20,
};
let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5));
let mut tokens = Vec::new();
while let Some(tok) = rx.recv().await {
match tok {
StreamToken::Token(id) => tokens.push(id),
StreamToken::Done { .. } => break,
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
}
}
Ok(tokenizer::decode(&tokens))
}
// ── Finetune config and persistence ─────────────────────────────
use std::path::PathBuf;
use std::collections::HashSet;
const FINETUNE_ALTERNATES_FILE: &str = ".consciousness/cache/finetune-alternates";
const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json";

/// Marker file whose existence enables alternate generation.
/// Relative to the home directory, or to the current directory if the home
/// directory is unknown.
fn alternates_path() -> PathBuf {
    dirs::home_dir().map_or_else(
        || PathBuf::from(FINETUNE_ALTERNATES_FILE),
        |home| home.join(FINETUNE_ALTERNATES_FILE),
    )
}

/// JSON file holding the trained-response timestamps.
/// Relative to the home directory, or to the current directory if the home
/// directory is unknown.
fn trained_path() -> PathBuf {
    dirs::home_dir().map_or_else(
        || PathBuf::from(TRAINED_RESPONSES_FILE),
        |home| home.join(TRAINED_RESPONSES_FILE),
    )
}
/// Whether alternate response generation is on.
/// It is on exactly when the marker file's metadata can be read.
pub fn alternates_enabled() -> bool {
    let marker = alternates_path();
    std::fs::metadata(&marker).is_ok()
}
/// Turn alternate response generation on or off by creating or removing
/// its marker file.
///
/// Best-effort: filesystem errors are ignored.
pub fn set_alternates(enabled: bool) {
    let marker = alternates_path();
    if !enabled {
        let _ = std::fs::remove_file(&marker);
        return;
    }
    if let Some(dir) = marker.parent() {
        let _ = std::fs::create_dir_all(dir);
    }
    let _ = std::fs::write(&marker, "");
}
/// Load the set of trained response timestamps (millis since epoch).
///
/// Returns an empty set when the file is missing, unreadable, or not a JSON
/// array of integers.
pub fn load_trained() -> HashSet<i64> {
    std::fs::read_to_string(trained_path())
        .ok()
        .and_then(|json| serde_json::from_str(&json).ok())
        .unwrap_or_default()
}
/// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ms: i64) {
let mut trained = load_trained();
trained.insert(timestamp_ms);
let path = trained_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(json) = serde_json::to_string(&trained) {
let _ = std::fs::write(&path, json);
}
}
/// Timestamp of a node in milliseconds since the epoch.
///
/// A leaf uses its own timestamp. A branch uses its first child, which must
/// be a leaf. Returns `None` when there is no such leaf or it has no
/// timestamp.
pub fn node_timestamp_ms(node: &AstNode) -> Option<i64> {
    let stamp = match node {
        AstNode::Leaf(leaf) => leaf.timestamp(),
        AstNode::Branch { children, .. } => children
            .first()
            .and_then(|child| child.leaf())
            .and_then(|leaf| leaf.timestamp()),
    };
    stamp.map(|ts| ts.timestamp_millis())
}
// ── Training API ────────────────────────────────────────────────
/// Training sample for /train endpoint.
/// Serialized as one element of `training_data.samples` in the request body.
#[derive(serde::Serialize)]
struct TrainingSample {
    /// Prompt tokens preceding the response.
    context_ids: Vec<u32>,
    /// Response tokens to train on.
    continuation_ids: Vec<u32>,
}
/// Data needed to send a training sample.
pub struct TrainData {
    /// Prompt tokens preceding the response.
    pub context_ids: Vec<u32>,
    /// Response tokens to train on.
    pub continuation_ids: Vec<u32>,
    /// Response timestamp (millis) recorded via mark_trained() on success.
    /// It is not sent to the server.
    pub timestamp_ms: i64,
}
/// Send training samples to the server.
///
/// Returns job_id on success, marks each sample as trained.
pub async fn send_to_train(
samples: Vec<TrainData>,
client: &ApiClient,
) -> anyhow::Result<String> {
if samples.is_empty() {
anyhow::bail!("no samples to train");
}
let api_samples: Vec<TrainingSample> = samples.iter()
.map(|s| TrainingSample {
context_ids: s.context_ids.clone(),
continuation_ids: s.continuation_ids.clone(),
})
.collect();
let body = serde_json::json!({
"training_data": {
"samples": api_samples,
}
});
let http = http_client();
let url = format!("{}/train", client.base_url());
let response = http.send_json("POST", &url, &[], &body).await?;
let status = response.status();
let result: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("train API HTTP {}: {}", status, msg);
}
// Mark all samples as trained
for s in &samples {
mark_trained(s.timestamp_ms);
}
let job_id = result.get("job_id")
.and_then(|j| j.as_str())
.unwrap_or("unknown")
.to_string();
dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id);
Ok(job_id)
}