// training.rs — Memory importance scoring via /v1/score // // Three scoring modes, all built on the same call_score() primitive: // // score_memories() — Full N×M matrix (memories × responses) for the // debug screen. Expensive: N+1 API calls. // // memory_score() — Single memory importance. Scores the 50 messages // after it was surfaced, with/without that memory. // 2 API calls. // // finetune_score() — Identifies training candidates. Scores recent // messages with all memories stripped. Responses // with high divergence depend on memories the model // hasn't internalized. 2 API calls. use crate::agent::api::ApiClient; use crate::agent::context::{ Ast, AstNode, ContextState, Role, WireImage, is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context, }; use crate::subconscious::generate::gen_continuation; const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); // ── Score API ─────────────────────────────────────────────────── #[derive(serde::Deserialize)] struct ScoreResult { total_logprob: f64, } #[derive(serde::Deserialize)] struct ScoreResponse { scores: Vec, } fn http_client() -> crate::agent::api::http::HttpClient { crate::agent::api::http::HttpClient::builder() .timeout(SCORE_TIMEOUT) .build() } async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, prompt: &[u32], images: &[WireImage], ranges: &[(usize, usize)], priority: Option, ) -> anyhow::Result> { // Nothing to score — skip the round-trip. if ranges.is_empty() { return Ok(Vec::new()); } let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); let mut body = serde_json::json!({ "model": client.model, "prompt": prompt, "score_ranges": ranges, "logprobs": 1, }); if !images.is_empty() { use base64::Engine; let b64 = base64::engine::general_purpose::STANDARD; let uris: Vec = images.iter() .map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes))) .collect(); body["multi_modal_data"] = serde_json::json!({ "image": uris }); } if let Some(p) = priority { body["priority"] = serde_json::json!(p); } let response = http .send_json("POST", &url, &[ ("authorization", &auth), ], &body) .await?; let status = response.status(); let body: serde_json::Value = response.json().await?; if !status.is_success() { let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); anyhow::bail!("score API HTTP {}: {}", status, msg); } if let Some(err) = body.get("error").and_then(|e| e.as_str()) { anyhow::bail!("score API error: {}", err); } let result: ScoreResponse = serde_json::from_value(body) .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; Ok(result.scores) } /// Compute per-position logprob divergence: how much worse the model /// scores each response without something vs with it. fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec { baseline.iter().enumerate() .map(|(i, base)| { let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob); (base.total_logprob - without_lp).max(0.0) }) .collect() } /// Score two message sets and return total divergence. async fn score_divergence( http: &crate::agent::api::http::HttpClient, client: &ApiClient, context: &ContextState, range: std::ops::Range, skip: F, priority: Option, ) -> anyhow::Result<(Vec, Vec)> where F: FnMut(&AstNode) -> bool, { let (baseline_tokens, baseline_images, baseline_ranges) = context.wire_prompt(range.clone(), |_| false); let (without_tokens, without_images, without_ranges) = context.wire_prompt(range, skip); let baseline = call_score(http, client, &baseline_tokens, &baseline_images, &baseline_ranges, priority).await?; let without = call_score(http, client, &without_tokens, &without_images, &without_ranges, priority).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } // ── Full matrix scoring (debug screen) ────────────────────────── /// Score how important each memory is to the conversation (full matrix). pub async fn score_memories( client: &ApiClient, agent: &std::sync::Arc, ) -> anyhow::Result<()> { // Collect memory keys and response indices under a brief lock let (memory_keys, response_indices) = { let ctx = agent.context.lock().await; // Include identity nodes and conversation memories let mut keys: Vec = ctx.identity().iter() .chain(ctx.conversation().iter()) .filter_map(|node| memory_key(node).map(String::from)) .collect(); keys.dedup(); let resp: Vec = ctx.conversation().iter().enumerate() .filter(|(_, node)| is_assistant(node)) .map(|(i, _)| i) .collect(); (keys, resp) }; if memory_keys.is_empty() || response_indices.is_empty() { return Ok(()); } let total = memory_keys.len(); dbglog!("[scoring-full] starting: {} memories × {} responses", total, response_indices.len()); let http = http_client(); let activity = crate::agent::start_activity(agent, "scoring: baseline").await; let (baseline_tokens, baseline_images, baseline_ranges) = { let ctx = agent.context.lock().await; ctx.wire_prompt(0..ctx.conversation().len(), |_| false) }; let baseline = call_score(&http, client, &baseline_tokens, &baseline_images, &baseline_ranges, Some(5)).await?; dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len()); for (mem_idx, key) in memory_keys.iter().enumerate() { activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await; dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key); let (tokens, images, ranges) = { let ctx = agent.context.lock().await; ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str())) }; let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await { Ok(without) => { let divs = divergence(&baseline, &without); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); dbglog!("[scoring-full] {}/{}: {} max_div={:.3}", mem_idx + 1, total, key, max_div); divs } Err(e) => { dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}", mem_idx + 1, total, key, e); vec![0.0; baseline.len()] } }; // Write this memory's scores to the live AST nodes { let mut ctx = agent.context.lock().await; let mut set_count = 0; for (resp_idx, &idx) in response_indices.iter().enumerate() { if idx >= ctx.conversation().len() { continue; } let node = &mut ctx.conversation_mut()[idx]; if let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node { if let Some(&score) = row.get(resp_idx) { if score > 0.01 { memory_scores.insert(key.clone(), score); set_count += 1; } else { memory_scores.remove(key.as_str()); } } } } dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count); } agent.state.lock().await.changed.notify_one(); } Ok(()) } /// Find the entry index after `start` that contains the Nth assistant response. /// Returns (end_index, true) if N responses were found, (entries.len(), false) if not. fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) { let mut count = 0; for i in start..entries.len() { if is_assistant(&entries[i]) { count += 1; if count >= n { return (i + 1, true); } } } (entries.len(), false) } // ── Single memory scoring ─────────────────────────────────────── /// Score how important a single memory is to the conversation. /// /// Scores the 50 messages after the memory was surfaced — the window /// where it could have influenced responses. Returns the sum of /// divergence, or 0.0 if the memory isn't in the conversation. pub async fn score_memory( context: &ContextState, key: &str, client: &ApiClient, ) -> anyhow::Result { const RESPONSE_WINDOW: usize = 50; let entries = context.conversation(); let first_pos = match entries.iter().position(|node| memory_key(node) == Some(key)) { Some(p) => p, None => return Ok(0.0), }; let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW); let range = first_pos..end; if !entries[range.clone()].iter().any(|node| is_assistant(node)) { return Ok(0.0); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, |n| memory_key(n) == Some(key), Some(5)).await?; Ok(divs.iter().sum()) } // ── Background memory scoring ─────────────────────────────────── /// Score memories in the conversation that are due for re-scoring. /// /// Checks the graph for each memory's last_scored timestamp. Scores /// nodes that haven't been scored within `max_age_secs`, oldest first. /// Updates the graph weight (EWMA) and last_scored after each. /// /// Returns the number of nodes scored and their (key, score) pairs. pub async fn score_memories_incremental( context: &ContextState, max_age_secs: i64, response_window: usize, client: &ApiClient, agent: &std::sync::Arc, mut on_score: F, ) -> anyhow::Result where F: FnMut(String, f64) -> Fut, Fut: std::future::Future, { let now = chrono::Utc::now().timestamp(); // Collect unique memory keys with their first position let mut seen = std::collections::HashSet::new(); let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored) let store_arc = crate::hippocampus::access_local()?; { let store = &*store_arc; // Identity nodes always score at position 0; conversation nodes at their index let identity_nodes = context.identity().iter().map(|n| (0, n)); let conv_nodes = context.conversation().iter().enumerate(); for (pos, node) in identity_nodes.chain(conv_nodes) { if let Some(key) = memory_key(node) { if !seen.insert(key.to_owned()) { continue; } let last_scored = store.get_node(key) .ok() .flatten() .map(|n| n.last_scored) .unwrap_or(0); if now - last_scored >= max_age_secs { candidates.push((pos, key.to_owned(), last_scored)); } } } } // Score oldest-first candidates.sort_by_key(|&(_, _, last)| last); let http = http_client(); let mut scored = 0; let entries = context.conversation(); let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum(); let token_cutoff = total_tokens * 60 / 100; // Precompute cumulative token position for each entry let mut cumulative: Vec = Vec::with_capacity(entries.len()); let mut running = 0; for e in entries { running += e.tokens(); cumulative.push(running); } let total = candidates.len(); dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total); let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await; for (pos, key, _) in &candidates { // Only score memories in the first 60% of the conversation by tokens — // recent memories don't have enough responses to evaluate yet. let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens); if cum > token_cutoff { dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff); continue; } let (end, _) = nth_response_end(context.conversation(), *pos, response_window); let range = *pos..end; if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) { dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end); continue; } activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await; match score_divergence(&http, client, context, range, |n| memory_key(n) == Some(key), Some(5)).await { Ok((divs, _)) => { let n_responses = divs.len(); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); dbglog!( "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses, ); on_score(key.clone(), max_div).await; scored += 1; } Err(e) => { dbglog!( "[scoring] {} FAILED: {:#}", key, e, ); } } } Ok(scored) } // ── Fine-tuning scoring ───────────────────────────────────────── /// Score which recent responses are candidates for fine-tuning. /// /// Removes all memories and scores the most recent `count` messages. /// Responses with high divergence depend on memories the model hasn't /// internalized — these are fine-tuning candidates. /// /// Returns (entry_index, divergence) pairs, sorted by divergence descending. pub async fn score_finetune( context: &ContextState, count: usize, client: &ApiClient, ) -> anyhow::Result> { let entries = context.conversation(); let range = entries.len().saturating_sub(count)..entries.len(); let response_positions: Vec = range.clone() .filter(|&i| is_assistant(&entries[i])) .collect(); if response_positions.is_empty() { return Ok(Vec::new()); } let http = http_client(); let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?; let mut results: Vec<(usize, f64)> = response_positions.iter() .enumerate() .map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0))) .collect(); results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(results) } /// Enriched finetune candidate with context for review. #[derive(Clone, Debug)] pub struct FinetuneCandidate { pub entry_idx: usize, pub divergence: f64, pub response_text: String, /// Last couple of user/assistant messages before this response, /// already rendered with role markers, for F6 display context. pub prior_context: String, /// Token IDs for context (everything before the response). pub context_ids: Vec, /// Token IDs for the response (what we're training on). pub continuation_ids: Vec, /// What the model would have said without memories (if generated). pub alternate_text: Option, /// Timestamp in nanos — used as unique key for trained-set dedup. pub timestamp_ns: i64, } /// Score and enrich finetune candidates with full context. /// /// Candidates are delivered via `on_candidate` one-at-a-time as they become /// ready: scoring happens once (one /score call), then for each candidate /// that passes the threshold we optionally generate an alternate response /// and then emit it. The activity status is updated during the alternate /// phase so the UI doesn't look stuck. /// /// Returns (count_above_threshold, max_divergence). pub async fn score_finetune_candidates( context: &ContextState, count: usize, client: &ApiClient, min_divergence: f64, generate_alternates: bool, activity: &crate::agent::ActivityGuard, mut on_candidate: impl FnMut(FinetuneCandidate), ) -> anyhow::Result<(usize, f64)> { let scores = score_finetune(context, count, client).await?; let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max); let entries = context.conversation(); let trained = load_trained(); let mut candidates: Vec = Vec::new(); for (entry_idx, divergence) in scores { if divergence < min_divergence { continue; } let node = &entries[entry_idx]; // Skip if already trained on. let timestamp_ns = node_timestamp_ns(node); if trained.contains(×tamp_ns) { continue; } // Extract response text — content of the assistant turn. let response_text = match node { AstNode::Branch { children, .. } => render_branch_text(children), _ => continue, }; // Skip turns that produced nothing human-visible (e.g., a // tool-only turn, or an interrupted generation). They'd show // up as blank cards and we'd still burn alternate-gen on them. if response_text.trim().is_empty() { continue; } // Build the last couple of user/assistant exchanges for review. let prior_context = render_prior_context(entries, entry_idx, 2); // Build token IDs: context = everything before response, continuation = response. let (context_ids, _, _) = context.wire_prompt(0..entry_idx, |_| false); let continuation_ids: Vec = node.token_ids().into_iter().collect(); candidates.push(FinetuneCandidate { entry_idx, divergence, response_text, prior_context, context_ids, continuation_ids, alternate_text: None, timestamp_ns, }); } let total = candidates.len(); let gen_alternates = generate_alternates && total > 0; for (i, mut candidate) in candidates.into_iter().enumerate() { if gen_alternates { activity.update( format!("finetune: generating alternate {}/{}", i + 1, total) ).await; match gen_continuation(context, candidate.entry_idx, is_memory_node, client).await { Ok(text) => candidate.alternate_text = Some(text), Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e), } } on_candidate(candidate); } Ok((total, max_divergence)) } // ── Finetune config and persistence ───────────────────────────── use std::path::PathBuf; use std::collections::HashSet; const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json"; fn trained_path() -> PathBuf { dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE) } /// Load set of trained response timestamps (nanos since epoch). pub fn load_trained() -> HashSet { let path = trained_path(); match std::fs::read_to_string(&path) { Ok(content) => serde_json::from_str(&content).unwrap_or_default(), Err(_) => HashSet::new(), } } /// Mark a response as trained by its timestamp. pub fn mark_trained(timestamp_ns: i64) { let mut trained = load_trained(); trained.insert(timestamp_ns); let path = trained_path(); if let Some(parent) = path.parent() { let _ = std::fs::create_dir_all(parent); } if let Ok(json) = serde_json::to_string(&trained) { let _ = std::fs::write(&path, json); } } /// Get timestamp in nanoseconds from an AstNode. /// i64-ns representation covers 1677..2262 via chrono; timestamps /// outside that window would be bugs we'd want to surface, hence panic. pub fn node_timestamp_ns(node: &AstNode) -> i64 { let ts = match node { AstNode::Leaf(leaf) => leaf.timestamp(), AstNode::Branch { timestamp, .. } => *timestamp, }; ts.timestamp_nanos_opt() .expect("timestamp outside i64-ns representable range (1677..2262)") } // ── Training API ──────────────────────────────────────────────── /// Training sample for /train endpoint. #[derive(serde::Serialize)] struct TrainingSample { context_ids: Vec, continuation_ids: Vec, } /// Data needed to send a training sample. pub struct TrainData { pub context_ids: Vec, pub continuation_ids: Vec, pub timestamp_ns: i64, } /// Send training samples to the server. /// /// Returns job_id on success, marks each sample as trained. pub async fn send_to_train( samples: Vec, client: &ApiClient, ) -> anyhow::Result { if samples.is_empty() { anyhow::bail!("no samples to train"); } let api_samples: Vec = samples.iter() .map(|s| TrainingSample { context_ids: s.context_ids.clone(), continuation_ids: s.continuation_ids.clone(), }) .collect(); let body = serde_json::json!({ "training_data": { "samples": api_samples, } }); let http = http_client(); let url = format!("{}/train", client.base_url()); let response = http.send_json("POST", &url, &[], &body).await?; let status = response.status(); let result: serde_json::Value = response.json().await?; if !status.is_success() { let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error"); anyhow::bail!("train API HTTP {}: {}", status, msg); } // Mark all samples as trained for s in &samples { mark_trained(s.timestamp_ns); } let job_id = result.get("job_id") .and_then(|j| j.as_str()) .unwrap_or("unknown") .to_string(); dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id); Ok(job_id) }