Mind's impl had accumulated ~50 lines of setup glue per scoring flow
(memory, memory-full, finetune): snapshot config, clone handles,
resolve context, spawn task, route results back through BgEvent,
write stats. The shape was identical; only the middle changed.
Introduce the MindTriggered trait:
pub trait MindTriggered {
fn trigger(&self);
}
Each flow becomes a struct next to its scoring code that owns its
dependencies and a JoinHandle (behind a sync Mutex for interior
mutability):
subconscious::learn::MemoryScoring (Score, ScoreFull)
subconscious::learn::FinetuneScoring (ScoreFinetune)
Mind holds one of each and dispatches in one line:
MindCommand::Score => self.memory_scoring.trigger(),
MindCommand::ScoreFull => self.memory_scoring.trigger_full(),
MindCommand::ScoreFinetune => self.finetune_scoring.trigger(),
Each struct picks its own trigger semantics — memory scoring is
no-op-if-running (!handle.is_finished()); finetune is abort-restart.
Falls out:
- BgEvent / bg_tx / bg_rx disappear entirely. Tasks write directly
to their slice of MindState and call agent.state.changed.notify_one()
to wake the UI. The bg_rx arm in Mind's select loop is gone.
- agent.state.memory_scoring_in_flight was duplicating
shared.scoring_in_flight via BgEvent routing; now the JoinHandle
alone tells us, and shared.scoring_in_flight is written directly
by the task for the UI.
- start_memory_scoring / start_full_scoring / start_finetune_scoring
methods on Mind are deleted; Mind no longer knows the setup shape
of any scoring flow.
- FinetuneScoringStats moves from mind/ to subconscious/learn.rs
next to the function that produces it.
No behavior change — same flows, same trigger points, same semantics.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
831 lines
29 KiB
Rust
831 lines
29 KiB
Rust
// training.rs — Memory importance scoring via /v1/score
|
||
//
|
||
// Three scoring modes, all built on the same call_score() primitive:
|
||
//
|
||
// score_memories() — Full N×M matrix (memories × responses) for the
|
||
// debug screen. Expensive: N+1 API calls.
|
||
//
|
||
// memory_score() — Single memory importance. Scores the 50 messages
|
||
// after it was surfaced, with/without that memory.
|
||
// 2 API calls.
|
||
//
|
||
// finetune_score() — Identifies training candidates. Scores recent
|
||
// messages with all memories stripped. Responses
|
||
// with high divergence depend on memories the model
|
||
// hasn't internalized. 2 API calls.
|
||
|
||
use std::sync::Arc;
|
||
|
||
use crate::agent::api::ApiClient;
|
||
use crate::agent::context::{
|
||
Ast, AstNode, ContextState, Role, WireImage,
|
||
is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
|
||
};
|
||
use crate::mind::{MindState, MindTriggered, TaskHandle};
|
||
use crate::subconscious::generate::gen_continuation;
|
||
|
||
// Five-minute cap on each scoring HTTP request (applied in http_client()).
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||
|
||
// ── Score API ───────────────────────────────────────────────────
|
||
|
||
/// One /score result: the summed token logprob for a scored range.
#[derive(serde::Deserialize)]
struct ScoreResult {
    total_logprob: f64,
}
|
||
|
||
/// Top-level /score response body: one `ScoreResult` per requested range.
#[derive(serde::Deserialize)]
struct ScoreResponse {
    scores: Vec<ScoreResult>,
}
|
||
|
||
fn http_client() -> crate::agent::api::http::HttpClient {
|
||
crate::agent::api::http::HttpClient::builder()
|
||
.timeout(SCORE_TIMEOUT)
|
||
.build()
|
||
}
|
||
|
||
/// POST a scoring request to `{base_url}/score` and return the per-range results.
///
/// Sends `prompt` token IDs plus `score_ranges`; images, if any, are attached
/// as base64 data URIs under `multi_modal_data`; `priority` is forwarded
/// verbatim when present. An empty `ranges` short-circuits to `Ok(vec![])`
/// without any network traffic.
///
/// # Errors
/// Non-success HTTP status, an `error` field in the response body, or a body
/// that fails to parse as `ScoreResponse`.
async fn call_score(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    prompt: &[u32],
    images: &[WireImage],
    ranges: &[(usize, usize)],
    priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
    // Nothing to score — skip the round-trip.
    if ranges.is_empty() {
        return Ok(Vec::new());
    }
    let url = format!("{}/score", client.base_url());
    let auth = format!("Bearer {}", client.api_key());
    let mut body = serde_json::json!({
        "model": client.model,
        "prompt": prompt,
        "score_ranges": ranges,
        "logprobs": 1,
    });
    // Images ride along as data URIs; omitted entirely when there are none.
    if !images.is_empty() {
        use base64::Engine;
        let b64 = base64::engine::general_purpose::STANDARD;
        let uris: Vec<String> = images.iter()
            .map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes)))
            .collect();
        body["multi_modal_data"] = serde_json::json!({ "image": uris });
    }
    if let Some(p) = priority {
        body["priority"] = serde_json::json!(p);
    }
    let response = http
        .send_json("POST", &url, &[
            ("authorization", &auth),
        ], &body)
        .await?;

    // Read the body even on failure — error details live in its "error" field.
    let status = response.status();
    let body: serde_json::Value = response.json().await?;

    if !status.is_success() {
        let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
        anyhow::bail!("score API HTTP {}: {}", status, msg);
    }
    // Some failures come back with a 2xx status but an "error" field.
    if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
        anyhow::bail!("score API error: {}", err);
    }

    let result: ScoreResponse = serde_json::from_value(body)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    Ok(result.scores)
}
|
||
|
||
/// Compute per-position logprob divergence: how much worse the model
|
||
/// scores each response without something vs with it.
|
||
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
|
||
baseline.iter().enumerate()
|
||
.map(|(i, base)| {
|
||
let without_lp = without.get(i).map(|s| s.total_logprob).unwrap_or(base.total_logprob);
|
||
(base.total_logprob - without_lp).max(0.0)
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// Score two message sets and return total divergence.
///
/// Builds two wire prompts over the same `range` — a baseline with nothing
/// skipped and one with `skip`-matching nodes removed — scores both via
/// call_score at the same priority, and returns the per-response
/// divergences alongside the baseline scores.
async fn score_divergence<F>(
    http: &crate::agent::api::http::HttpClient,
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    skip: F,
    priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)>
where F: FnMut(&AstNode) -> bool,
{
    // Baseline keeps every node; the comparison prompt drops `skip` matches.
    let (baseline_tokens, baseline_images, baseline_ranges) =
        context.wire_prompt(range.clone(), |_| false);
    let (without_tokens, without_images, without_ranges) =
        context.wire_prompt(range, skip);
    let baseline = call_score(http, client, &baseline_tokens, &baseline_images,
        &baseline_ranges, priority).await?;
    let without = call_score(http, client, &without_tokens, &without_images,
        &without_ranges, priority).await?;
    let divs = divergence(&baseline, &without);
    Ok((divs, baseline))
}
|
||
|
||
// ── Full matrix scoring (debug screen) ──────────────────────────
|
||
|
||
/// Score how important each memory is to the conversation (full matrix).
|
||
pub async fn score_memories(
|
||
client: &ApiClient,
|
||
agent: &std::sync::Arc<crate::agent::Agent>,
|
||
) -> anyhow::Result<()> {
|
||
// Collect memory keys and response indices under a brief lock
|
||
let (memory_keys, response_indices) = {
|
||
let ctx = agent.context.lock().await;
|
||
// Include identity nodes and conversation memories
|
||
let mut keys: Vec<String> = ctx.identity().iter()
|
||
.chain(ctx.conversation().iter())
|
||
.filter_map(|node| memory_key(node).map(String::from))
|
||
.collect();
|
||
keys.dedup();
|
||
let resp: Vec<usize> = ctx.conversation().iter().enumerate()
|
||
.filter(|(_, node)| is_assistant(node))
|
||
.map(|(i, _)| i)
|
||
.collect();
|
||
(keys, resp)
|
||
};
|
||
|
||
if memory_keys.is_empty() || response_indices.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
let total = memory_keys.len();
|
||
dbglog!("[scoring-full] starting: {} memories × {} responses",
|
||
total, response_indices.len());
|
||
|
||
let http = http_client();
|
||
|
||
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
||
let (baseline_tokens, baseline_images, baseline_ranges) = {
|
||
let ctx = agent.context.lock().await;
|
||
ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
|
||
};
|
||
let baseline = call_score(&http, client, &baseline_tokens, &baseline_images,
|
||
&baseline_ranges, Some(5)).await?;
|
||
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
||
|
||
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
||
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
||
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
||
let (tokens, images, ranges) = {
|
||
let ctx = agent.context.lock().await;
|
||
ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
|
||
};
|
||
let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await {
|
||
Ok(without) => {
|
||
let divs = divergence(&baseline, &without);
|
||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||
dbglog!("[scoring-full] {}/{}: {} max_div={:.3}",
|
||
mem_idx + 1, total, key, max_div);
|
||
divs
|
||
}
|
||
Err(e) => {
|
||
dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}",
|
||
mem_idx + 1, total, key, e);
|
||
vec![0.0; baseline.len()]
|
||
}
|
||
};
|
||
// Write this memory's scores to the live AST nodes
|
||
{
|
||
let mut ctx = agent.context.lock().await;
|
||
let mut set_count = 0;
|
||
|
||
for (resp_idx, &idx) in response_indices.iter().enumerate() {
|
||
if idx >= ctx.conversation().len() { continue; }
|
||
let node = &mut ctx.conversation_mut()[idx];
|
||
if let AstNode::Branch {
|
||
role: Role::Assistant, memory_scores, ..
|
||
} = node {
|
||
if let Some(&score) = row.get(resp_idx) {
|
||
if score > 0.01 {
|
||
memory_scores.insert(key.clone(), score);
|
||
set_count += 1;
|
||
} else {
|
||
memory_scores.remove(key.as_str());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count);
|
||
}
|
||
agent.state.lock().await.changed.notify_one();
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Find the entry index after `start` that contains the Nth assistant response.
|
||
/// Returns (end_index, true) if N responses were found, (entries.len(), false) if not.
|
||
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) {
|
||
let mut count = 0;
|
||
for i in start..entries.len() {
|
||
if is_assistant(&entries[i]) {
|
||
count += 1;
|
||
if count >= n { return (i + 1, true); }
|
||
}
|
||
}
|
||
(entries.len(), false)
|
||
}
|
||
|
||
// ── Single memory scoring ───────────────────────────────────────
|
||
|
||
/// Score how important a single memory is to the conversation.
///
/// Scores the 50 messages after the memory was surfaced — the window
/// where it could have influenced responses. Returns the sum of
/// divergence, or 0.0 if the memory isn't in the conversation.
pub async fn score_memory(
    context: &ContextState,
    key: &str,
    client: &ApiClient,
) -> anyhow::Result<f64> {
    // Cap on assistant responses considered after the memory first appears.
    const RESPONSE_WINDOW: usize = 50;

    // Locate the memory's first occurrence; an absent memory scores 0.0.
    let entries = context.conversation();
    let first_pos = match entries.iter().position(|node| memory_key(node) == Some(key)) {
        Some(p) => p,
        None => return Ok(0.0),
    };

    // Window runs from the memory through its RESPONSE_WINDOW-th following
    // response (or to the end of the conversation if there are fewer).
    let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW);
    let range = first_pos..end;
    // No assistant responses in the window → nothing to measure.
    if !entries[range.clone()].iter().any(|node| is_assistant(node)) {
        return Ok(0.0);
    }

    // Two /score calls: baseline vs the same range with this memory skipped.
    let http = http_client();
    let (divs, _) = score_divergence(&http, client, context, range,
        |n| memory_key(n) == Some(key), Some(5)).await?;

    Ok(divs.iter().sum())
}
|
||
|
||
// ── Background memory scoring ───────────────────────────────────
|
||
|
||
/// Score memories in the conversation that are due for re-scoring.
///
/// Checks the graph for each memory's last_scored timestamp and scores
/// nodes that haven't been scored within `max_age_secs`, oldest first.
/// Each result is handed to `on_score(key, max_divergence)`; persistence
/// of weights/last_scored is the callback's responsibility.
///
/// Returns the number of nodes scored.
pub async fn score_memories_incremental<F, Fut>(
    context: &ContextState,
    max_age_secs: i64,
    response_window: usize,
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
    mut on_score: F,
) -> anyhow::Result<usize>
where
    F: FnMut(String, f64) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let now = chrono::Utc::now().timestamp();

    // Collect unique memory keys with their first position
    let mut seen = std::collections::HashSet::new();
    let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)

    let store_arc = crate::hippocampus::access_local()?;

    {
        let store = &*store_arc;
        // Identity nodes always score at position 0; conversation nodes at their index
        let identity_nodes = context.identity().iter().map(|n| (0, n));
        let conv_nodes = context.conversation().iter().enumerate();
        for (pos, node) in identity_nodes.chain(conv_nodes) {
            if let Some(key) = memory_key(node) {
                // First occurrence wins; later duplicates are ignored.
                if !seen.insert(key.to_owned()) { continue; }
                // Unknown nodes get last_scored = 0 → maximally overdue.
                let last_scored = store.get_node(key)
                    .ok()
                    .flatten()
                    .map(|n| n.last_scored)
                    .unwrap_or(0);
                if now - last_scored >= max_age_secs {
                    candidates.push((pos, key.to_owned(), last_scored));
                }
            }
        }
    }

    // Score oldest-first
    candidates.sort_by_key(|&(_, _, last)| last);

    let http = http_client();
    let mut scored = 0;

    let entries = context.conversation();
    let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum();
    let token_cutoff = total_tokens * 60 / 100;

    // Precompute cumulative token position for each entry
    let mut cumulative: Vec<usize> = Vec::with_capacity(entries.len());
    let mut running = 0;
    for e in entries {
        running += e.tokens();
        cumulative.push(running);
    }

    let total = candidates.len();
    dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total);
    let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await;

    for (pos, key, _) in &candidates {
        // Only score memories in the first 60% of the conversation by tokens —
        // recent memories don't have enough responses to evaluate yet.
        let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens);
        if cum > token_cutoff {
            dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff);
            continue;
        }
        // Window: from the memory through its response_window-th response.
        let (end, _) = nth_response_end(context.conversation(), *pos, response_window);
        let range = *pos..end;
        if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) {
            dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end);
            continue;
        }

        activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
        // A failed candidate is logged and skipped; the loop continues.
        match score_divergence(&http, client, context, range,
                |n| memory_key(n) == Some(key), Some(5)).await {
            Ok((divs, _)) => {
                let n_responses = divs.len();
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!(
                    "[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
                );
                on_score(key.clone(), max_div).await;
                scored += 1;
            }
            Err(e) => {
                dbglog!(
                    "[scoring] {} FAILED: {:#}", key, e,
                );
            }
        }
    }

    Ok(scored)
}
|
||
|
||
/// Memory scoring — two modes sharing an in-flight handle (only one
/// runs at a time): `trigger()` for incremental, `trigger_full()` for
/// the N×M debug matrix.
pub struct MemoryScoring {
    // Agent whose context gets scored; tasks wake its UI via state.changed.
    agent: Arc<crate::agent::Agent>,
    // Mind-visible state; tasks flip scoring_in_flight here.
    shared: Arc<std::sync::Mutex<MindState>>,
    // Where incremental runs persist the collected memory-score snapshot.
    scores_path: std::path::PathBuf,
    // Single task slot shared by both modes (trigger_if_idle semantics).
    task: TaskHandle,
}
|
||
|
||
impl MemoryScoring {
    /// Bundle the dependencies for both scoring modes; no task starts yet.
    pub fn new(
        agent: Arc<crate::agent::Agent>,
        shared: Arc<std::sync::Mutex<MindState>>,
        scores_path: std::path::PathBuf,
    ) -> Self {
        Self { agent, shared, scores_path, task: TaskHandle::new() }
    }

    /// Start the full N×M matrix run — a no-op if a run (either mode)
    /// is already active.
    pub fn trigger_full(&self) {
        self.task.trigger_if_idle(run_full(self.agent.clone(), self.shared.clone()));
    }
}
|
||
|
||
impl MindTriggered for MemoryScoring {
    /// Start an incremental scoring run — a no-op if a run (either mode)
    /// is already active.
    fn trigger(&self) {
        self.task.trigger_if_idle(run_incremental(
            self.agent.clone(), self.shared.clone(), self.scores_path.clone(),
        ));
    }
}
|
||
|
||
/// Background task body for incremental memory scoring.
///
/// Flips `scoring_in_flight` on MindState around the run, waking the UI
/// via `agent.state.changed` after each visible change. The per-score
/// callback writes each result into the live context and saves the full
/// score snapshot to `scores_path`.
async fn run_incremental(
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
    scores_path: std::path::PathBuf,
) {
    // Mark in-flight for the UI before any work starts.
    shared.lock().unwrap().scoring_in_flight = true;
    agent.state.lock().await.changed.notify_one();

    // Snapshot config once for the whole run.
    let cfg = crate::config::get();
    let max_age = cfg.scoring_interval_secs;
    let response_window = cfg.scoring_response_window;

    // Clone the context so scoring reads a snapshot without holding the lock.
    let (context, client) = {
        let ctx = agent.context.lock().await.clone();
        (ctx, agent.client.clone())
    };

    // Result deliberately discarded — a failed background pass only means
    // no scores were updated this round.
    let _result = score_memories_incremental(
        &context, max_age as i64, response_window, &client, &agent,
        |key: String, score: f64| {
            let agent = agent.clone();
            let path = scores_path.clone();
            async move {
                let scores_snapshot = {
                    let mut ctx = agent.context.lock().await;
                    // The memory may have moved/vanished since the snapshot
                    // was taken; only persist when it's still addressable.
                    let found = crate::mind::find_memory_by_key(&ctx, &key);
                    match found {
                        Some((section, i)) => {
                            ctx.set_score(section, i, Some(score));
                            dbglog!("[scoring] persisted {} → {:.3} ({:?}[{}])",
                                key, score, section, i);
                        }
                        None => {
                            dbglog!(
                                "[scoring] DROP {}: find_memory_by_key None (id={}, cv={})",
                                key, ctx.identity().len(), ctx.conversation().len()
                            );
                        }
                    }
                    // Snapshot under the lock, release it, then wake the UI.
                    let snapshot = crate::mind::collect_memory_scores(&ctx);
                    drop(ctx);
                    agent.state.lock().await.changed.notify_one();
                    snapshot
                };
                // Persist outside the context lock.
                crate::mind::save_memory_scores(&scores_snapshot, &path);
            }
        },
    ).await;

    // Clear the flag and wake the UI, success or not.
    shared.lock().unwrap().scoring_in_flight = false;
    agent.state.lock().await.changed.notify_one();
}
|
||
|
||
async fn run_full(
|
||
agent: Arc<crate::agent::Agent>,
|
||
shared: Arc<std::sync::Mutex<MindState>>,
|
||
) {
|
||
shared.lock().unwrap().scoring_in_flight = true;
|
||
agent.state.lock().await.changed.notify_one();
|
||
|
||
let client = agent.client.clone();
|
||
match score_memories(&client, &agent).await {
|
||
Ok(()) => {},
|
||
Err(e) => { dbglog!("[scoring-full] FAILED: {:#}", e); }
|
||
}
|
||
|
||
shared.lock().unwrap().scoring_in_flight = false;
|
||
agent.state.lock().await.changed.notify_one();
|
||
}
|
||
|
||
// ── Fine-tuning scoring ─────────────────────────────────────────
|
||
|
||
/// Score which recent responses are candidates for fine-tuning.
|
||
///
|
||
/// Removes all memories and scores the most recent `count` messages.
|
||
/// Responses with high divergence depend on memories the model hasn't
|
||
/// internalized — these are fine-tuning candidates.
|
||
///
|
||
/// Returns (entry_index, divergence) pairs, sorted by divergence descending.
|
||
pub async fn score_finetune(
|
||
context: &ContextState,
|
||
count: usize,
|
||
client: &ApiClient,
|
||
) -> anyhow::Result<Vec<(usize, f64)>> {
|
||
let entries = context.conversation();
|
||
let range = entries.len().saturating_sub(count)..entries.len();
|
||
|
||
let response_positions: Vec<usize> = range.clone()
|
||
.filter(|&i| is_assistant(&entries[i]))
|
||
.collect();
|
||
if response_positions.is_empty() {
|
||
return Ok(Vec::new());
|
||
}
|
||
|
||
let http = http_client();
|
||
let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?;
|
||
|
||
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
||
.enumerate()
|
||
.map(|(i, &entry_idx)| (entry_idx, divs.get(i).copied().unwrap_or(0.0)))
|
||
.collect();
|
||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||
Ok(results)
|
||
}
|
||
|
||
/// Enriched finetune candidate with context for review.
#[derive(Clone, Debug)]
pub struct FinetuneCandidate {
    /// Index of the response node in the conversation entries.
    pub entry_idx: usize,
    /// Logprob divergence with vs without memories; higher means the
    /// response leaned harder on un-internalized memories.
    pub divergence: f64,
    /// Rendered text of the assistant response itself.
    pub response_text: String,
    /// Last couple of user/assistant messages before this response,
    /// already rendered with role markers, for F6 display context.
    pub prior_context: String,
    /// Token IDs for context (everything before the response).
    pub context_ids: Vec<u32>,
    /// Token IDs for the response (what we're training on).
    pub continuation_ids: Vec<u32>,
    /// What the model would have said without memories (if generated).
    pub alternate_text: Option<String>,
    /// Timestamp in nanos — used as unique key for trained-set dedup.
    pub timestamp_ns: i64,
}
|
||
|
||
/// Score and enrich finetune candidates with full context.
|
||
///
|
||
/// Candidates are delivered via `on_candidate` one-at-a-time as they become
|
||
/// ready: scoring happens once (one /score call), then for each candidate
|
||
/// that passes the threshold we optionally generate an alternate response
|
||
/// and then emit it. The activity status is updated during the alternate
|
||
/// phase so the UI doesn't look stuck.
|
||
///
|
||
/// Returns (count_above_threshold, max_divergence).
|
||
pub async fn score_finetune_candidates(
|
||
context: &ContextState,
|
||
count: usize,
|
||
client: &ApiClient,
|
||
min_divergence: f64,
|
||
generate_alternates: bool,
|
||
activity: &crate::agent::ActivityGuard,
|
||
mut on_candidate: impl FnMut(FinetuneCandidate),
|
||
) -> anyhow::Result<(usize, f64)> {
|
||
let scores = score_finetune(context, count, client).await?;
|
||
|
||
let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max);
|
||
|
||
let entries = context.conversation();
|
||
let trained = load_trained();
|
||
let mut candidates: Vec<FinetuneCandidate> = Vec::new();
|
||
|
||
for (entry_idx, divergence) in scores {
|
||
if divergence < min_divergence {
|
||
continue;
|
||
}
|
||
|
||
let node = &entries[entry_idx];
|
||
|
||
// Skip if already trained on.
|
||
let timestamp_ns = node_timestamp_ns(node);
|
||
if trained.contains(×tamp_ns) {
|
||
continue;
|
||
}
|
||
|
||
// Extract response text — content of the assistant turn.
|
||
let response_text = match node {
|
||
AstNode::Branch { children, .. } => render_branch_text(children),
|
||
_ => continue,
|
||
};
|
||
|
||
// Skip turns that produced nothing human-visible (e.g., a
|
||
// tool-only turn, or an interrupted generation). They'd show
|
||
// up as blank cards and we'd still burn alternate-gen on them.
|
||
if response_text.trim().is_empty() {
|
||
continue;
|
||
}
|
||
|
||
// Build the last couple of user/assistant exchanges for review.
|
||
let prior_context = render_prior_context(entries, entry_idx, 2);
|
||
|
||
// Build token IDs: context = everything before response, continuation = response.
|
||
let (context_ids, _, _) = context.wire_prompt(0..entry_idx, |_| false);
|
||
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
|
||
|
||
candidates.push(FinetuneCandidate {
|
||
entry_idx,
|
||
divergence,
|
||
response_text,
|
||
prior_context,
|
||
context_ids,
|
||
continuation_ids,
|
||
alternate_text: None,
|
||
timestamp_ns,
|
||
});
|
||
}
|
||
|
||
let total = candidates.len();
|
||
let gen_alternates = generate_alternates && total > 0;
|
||
|
||
for (i, mut candidate) in candidates.into_iter().enumerate() {
|
||
if gen_alternates {
|
||
activity.update(
|
||
format!("finetune: generating alternate {}/{}", i + 1, total)
|
||
).await;
|
||
match gen_continuation(context, candidate.entry_idx, is_memory_node, client).await {
|
||
Ok(text) => candidate.alternate_text = Some(text),
|
||
Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
|
||
}
|
||
}
|
||
on_candidate(candidate);
|
||
}
|
||
|
||
Ok((total, max_divergence))
|
||
}
|
||
|
||
/// Stats from a finetune scoring run. Stored on MindState for UI display.
#[derive(Clone, Debug)]
pub struct FinetuneScoringStats {
    /// Assistant responses in the scored window.
    pub responses_considered: usize,
    /// Candidates whose divergence met the threshold (0 when the run errored).
    pub above_threshold: usize,
    /// Minimum divergence required, as configured at run time.
    pub threshold: f64,
    /// Largest divergence observed in the window (0.0 when the run errored).
    pub max_divergence: f64,
    /// Error text when the scoring run failed; None on success.
    pub error: Option<String>,
}
|
||
|
||
/// Finetune scoring — `trigger()` aborts any in-flight run and starts
/// a fresh one, clearing the previous candidates.
pub struct FinetuneScoring {
    // Agent whose conversation is scored; also used to wake the UI.
    agent: Arc<crate::agent::Agent>,
    // Mind-visible state: candidates and last-run stats land here.
    shared: Arc<std::sync::Mutex<MindState>>,
    // Slot for the single background run (abort-restart semantics).
    task: TaskHandle,
}
|
||
|
||
impl FinetuneScoring {
    /// Bundle the dependencies; nothing runs until `trigger()`.
    pub fn new(
        agent: Arc<crate::agent::Agent>,
        shared: Arc<std::sync::Mutex<MindState>>,
    ) -> Self {
        Self { agent, shared, task: TaskHandle::new() }
    }
}
|
||
|
||
impl MindTriggered for FinetuneScoring {
    /// Start a fresh run. Uses `TaskHandle::trigger` (not `trigger_if_idle`),
    /// replacing any run currently in flight.
    fn trigger(&self) {
        self.task.trigger(run_finetune(self.agent.clone(), self.shared.clone()));
    }
}
|
||
|
||
/// Background task body for finetune candidate scoring.
///
/// Clears the previous candidate list, scores the most recent half of the
/// conversation, streams each passing candidate into
/// `MindState::finetune_candidates`, and records run stats (or the error)
/// in `finetune_last_run`. The UI is woken via `agent.state.changed`.
async fn run_finetune(
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
) {
    // Snapshot config once for the whole run.
    let (threshold, gen_alternates) = {
        let app = crate::config::app();
        (app.learn.threshold, app.learn.generate_alternates)
    };

    // Fresh run — clear previous candidates.
    shared.lock().unwrap().finetune_candidates.clear();
    agent.state.lock().await.changed.notify_one();

    let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await;

    // Clone the context so scoring doesn't hold the context lock.
    let (context, client) = {
        let ctx = agent.context.lock().await;
        (ctx.clone(), agent.client.clone())
    };

    // Score the most recent half of the conversation.
    let entries = context.conversation();
    let score_count = entries.len() / 2;
    let range_start = entries.len() - score_count;
    let responses_considered: usize = entries[range_start..].iter()
        .filter(|n| matches!(n, AstNode::Branch { role: Role::Assistant, .. }))
        .count();

    activity.update(format!("finetune: scoring {} responses...", responses_considered)).await;

    let stats = {
        let shared = shared.clone();
        let agent = agent.clone();
        match score_finetune_candidates(
            &context, score_count, &client, threshold,
            gen_alternates, &activity,
            // Stream each candidate into shared state as it becomes ready;
            // try_lock keeps the callback non-blocking — a missed notify
            // just delays the UI refresh.
            move |c| {
                shared.lock().unwrap().finetune_candidates.push(c);
                if let Ok(st) = agent.state.try_lock() { st.changed.notify_one(); }
            },
        ).await {
            Ok((above_threshold, max_div)) => FinetuneScoringStats {
                responses_considered,
                above_threshold,
                threshold,
                max_divergence: max_div,
                error: None,
            },
            // Errors are captured in the stats rather than propagated.
            Err(e) => FinetuneScoringStats {
                responses_considered,
                above_threshold: 0,
                threshold,
                max_divergence: 0.0,
                error: Some(format!("{}", e)),
            },
        }
    };

    shared.lock().unwrap().finetune_last_run = Some(stats);
    agent.state.lock().await.changed.notify_one();
}
|
||
|
||
// ── Finetune config and persistence ─────────────────────────────
|
||
|
||
use std::path::PathBuf;
|
||
use std::collections::HashSet;
|
||
|
||
// Home-relative path of the persisted trained-response timestamp set.
const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json";
|
||
|
||
fn trained_path() -> PathBuf {
|
||
dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE)
|
||
}
|
||
|
||
/// Load set of trained response timestamps (nanos since epoch).
|
||
pub fn load_trained() -> HashSet<i64> {
|
||
let path = trained_path();
|
||
match std::fs::read_to_string(&path) {
|
||
Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
|
||
Err(_) => HashSet::new(),
|
||
}
|
||
}
|
||
|
||
/// Mark a response as trained by its timestamp.
|
||
pub fn mark_trained(timestamp_ns: i64) {
|
||
let mut trained = load_trained();
|
||
trained.insert(timestamp_ns);
|
||
let path = trained_path();
|
||
if let Some(parent) = path.parent() {
|
||
let _ = std::fs::create_dir_all(parent);
|
||
}
|
||
if let Ok(json) = serde_json::to_string(&trained) {
|
||
let _ = std::fs::write(&path, json);
|
||
}
|
||
}
|
||
|
||
/// Get timestamp in nanoseconds from an AstNode.
|
||
/// i64-ns representation covers 1677..2262 via chrono; timestamps
|
||
/// outside that window would be bugs we'd want to surface, hence panic.
|
||
pub fn node_timestamp_ns(node: &AstNode) -> i64 {
|
||
let ts = match node {
|
||
AstNode::Leaf(leaf) => leaf.timestamp(),
|
||
AstNode::Branch { timestamp, .. } => *timestamp,
|
||
};
|
||
ts.timestamp_nanos_opt()
|
||
.expect("timestamp outside i64-ns representable range (1677..2262)")
|
||
}
|
||
|
||
// ── Training API ────────────────────────────────────────────────
|
||
|
||
/// Training sample for /train endpoint.
/// Wire-format mirror of `TrainData`, minus the timestamp.
#[derive(serde::Serialize)]
struct TrainingSample {
    // Prompt tokens preceding the response.
    context_ids: Vec<u32>,
    // Response tokens to train on.
    continuation_ids: Vec<u32>,
}
|
||
|
||
/// Data needed to send a training sample.
pub struct TrainData {
    /// Prompt tokens preceding the response.
    pub context_ids: Vec<u32>,
    /// Response tokens to train on.
    pub continuation_ids: Vec<u32>,
    /// Unique key for trained-set dedup (see `mark_trained`).
    pub timestamp_ns: i64,
}
|
||
|
||
/// Send training samples to the server.
///
/// Returns job_id on success, marks each sample as trained.
///
/// # Errors
/// Empty `samples`, transport failure, or a non-success HTTP status.
pub async fn send_to_train(
    samples: Vec<TrainData>,
    client: &ApiClient,
) -> anyhow::Result<String> {
    if samples.is_empty() {
        anyhow::bail!("no samples to train");
    }

    // Strip timestamps down to the wire format the API accepts.
    let api_samples: Vec<TrainingSample> = samples.iter()
        .map(|s| TrainingSample {
            context_ids: s.context_ids.clone(),
            continuation_ids: s.continuation_ids.clone(),
        })
        .collect();

    let body = serde_json::json!({
        "training_data": {
            "samples": api_samples,
        }
    });

    // NOTE(review): unlike call_score, no authorization header is sent here —
    // confirm the /train endpoint is intentionally unauthenticated.
    let http = http_client();
    let url = format!("{}/train", client.base_url());
    let response = http.send_json("POST", &url, &[], &body).await?;

    let status = response.status();
    let result: serde_json::Value = response.json().await?;

    if !status.is_success() {
        let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
        anyhow::bail!("train API HTTP {}: {}", status, msg);
    }

    // Mark all samples as trained
    // (only reached after the server reported a success status).
    for s in &samples {
        mark_trained(s.timestamp_ns);
    }

    // A missing or non-string job_id degrades to "unknown" rather than failing.
    let job_id = result.get("job_id")
        .and_then(|j| j.as_str())
        .unwrap_or("unknown")
        .to_string();

    dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id);
    Ok(job_id)
}
|