consciousness/src/subconscious/learn.rs
Kent Overstreet 4225294d16 replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
2026-04-25 15:35:14 -04:00

874 lines
31 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// learn.rs — Memory importance scoring over the salience gRPC protocol.
//
// Three scoring modes, all built on call_score():
//
// score_memories() — Full N×M matrix (memories × responses) for the
// debug screen. Expensive: N+1 sessions/calls.
//
// score_memory() — Single memory importance. Scores the 50 messages
// after it was surfaced, with/without that memory.
// 2 calls.
//
// score_finetune() — Identifies training candidates. Scores recent
// messages with all memories stripped. Responses
// with high divergence depend on memories the model
// hasn't internalized. 2 calls.
//
// Each call opens an ephemeral gRPC session (reusing the shared
// tonic Channel on `ApiClient`), pushes the prompt through as
// interleaved tokens + AppendImage calls, runs Generate with
// max_tokens=0 + logprobs_ranges over the scored positions, collects
// each Token event's sampled_logprob, then drops the SessionHandle —
// which triggers a best-effort CloseSession over the shared channel.
use std::sync::Arc;
use crate::agent::api::ApiClient;
use crate::agent::api::salience::{SessionHandle, pb};
use crate::agent::context::{
Ast, AstNode, ContextState, Role, WireChunk, WireImage,
is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
};
use crate::agent::tokenizer;
use crate::mind::{MindState, MindTriggered, TaskHandle};
use crate::subconscious::generate::gen_continuation;
// ── Score API ───────────────────────────────────────────────────
/// Score of one requested range, as produced by `call_score`.
#[derive(Debug, Clone)]
struct ScoreResult {
/// Sum of the `sampled_logprob` values of every Token event whose
/// position fell inside the range.
total_logprob: f64,
}
/// Pair every `<|vision_start|>` … `<|vision_end|>` span of the flat prompt
/// with the entry of `images` at the same ordinal (first span ↔ first image,
/// and so on).
///
/// Each returned `ImageAttachment` carries a clone of the image's bytes and
/// mime type plus the span's absolute token range: `pad_range_start` is the
/// index of the VISION_START token and `pad_range_end` is one past the
/// matching VISION_END token. The result is ready to drop into
/// `GenerateRequest.images`.
///
/// # Panics
///
/// Panics when a VISION_START has no later VISION_END, or when the prompt
/// contains more vision spans than `images` has entries.
fn pair_images_to_ranges(
    prompt: &[u32],
    images: &[WireImage],
) -> Vec<pb::ImageAttachment> {
    let mut attachments: Vec<pb::ImageAttachment> = Vec::new();
    let mut pos = 0;
    while pos < prompt.len() {
        // Ordinary token: nothing to pair, move on.
        if prompt[pos] != tokenizer::VISION_START {
            pos += 1;
            continue;
        }

        // Offset of the closing token, relative to `pos`.
        let Some(close_offset) = prompt[pos..]
            .iter()
            .position(|&tok| tok == tokenizer::VISION_END)
        else {
            panic!("unmatched VISION_START at position {} in prompt", pos);
        };
        let span_end = pos + close_offset + 1;

        // One attachment is pushed per span, so the number emitted so far
        // is also the index of the image that belongs to this span.
        let ordinal = attachments.len();
        let Some(image) = images.get(ordinal) else {
            panic!("image index {} out of range for {} images", ordinal, images.len());
        };

        attachments.push(pb::ImageAttachment {
            bytes: image.bytes.clone(),
            mime: image.mime.clone(),
            pad_range_start: pos as u32,
            pad_range_end: span_end as u32,
        });
        pos = span_end;
    }
    attachments
}
/// Score `ranges` of `prompt` in one ephemeral session.
///
/// Opens a `SessionHandle`, sends a single Generate request carrying the
/// whole prompt (`max_tokens: 0`, `logprob_top_k: 0`) with `ranges` as
/// `logprobs_ranges`, and sums the `sampled_logprob` of every Token event
/// into the range(s) whose half-open `[start, end)` interval contains the
/// event's position.
///
/// Returns one `ScoreResult` per entry of `ranges`, in the same order. An
/// empty `ranges` returns an empty vector without opening a session.
/// `priority` defaults to 0 when `None`.
///
/// Errors if the session cannot be opened, the Generate call fails, or the
/// stream yields an error item. If the stream ends without a Done event the
/// totals accumulated so far are returned.
async fn call_score(
client: &ApiClient,
prompt: &[u32],
images: &[WireImage],
ranges: &[(usize, usize)],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
use futures::StreamExt;
// Nothing to score — skip the round-trip.
if ranges.is_empty() {
return Ok(Vec::new());
}
let images_pb = pair_images_to_ranges(prompt, images);
let mut handle = SessionHandle::open(client).await?;
// Final Generate: max_tokens=0 so the server runs prefill of the
// full prompt and emits Token events for each position covered
// by logprobs_ranges, then Done. logprob_top_k=0 means "just
// the sampled (prompt) token's logprob" — no top-k alternatives,
// which is all call_score historically needed. Images attach
// inline via `images`; the prompt already contains their pre-
// expanded vision blocks at the declared ranges.
let logprobs_ranges: Vec<pb::PositionRange> = ranges.iter()
.map(|(s, e)| pb::PositionRange { start: *s as u32, end: *e as u32 })
.collect();
let req = pb::GenerateRequest {
session_id: handle.session_id.clone(),
append_tokens: prompt.to_vec(),
offset: handle.committed_len,
truncating: false,
max_tokens: 0,
logprobs_ranges,
logprob_top_k: 0,
readout_ranges: Vec::new(),
temperature: 0.0,
top_p: 0.0,
top_k: 0,
stop_token_ids: Vec::new(),
priority: priority.unwrap_or(0),
images: images_pb,
};
let mut stream = handle.generate(req).await?;
// One running sum per requested range, index-aligned with `ranges`.
let mut totals = vec![0.0f64; ranges.len()];
while let Some(event) = stream.next().await {
let event = event
.map_err(|s| anyhow::anyhow!("score Generate stream: {}", s))?;
// Events with no payload carry nothing to accumulate.
let Some(inner) = event.event else { continue };
match inner {
pb::generate_event::Event::Token(t) => {
// Tokens without a sampled logprob contribute nothing.
if !t.has_sampled_logprob { continue; }
let pos = t.position as usize;
// A position is credited to every range that contains it.
for (i, (start, end)) in ranges.iter().enumerate() {
if pos >= *start && pos < *end {
totals[i] += t.sampled_logprob as f64;
}
}
}
pb::generate_event::Event::Done(_) => break,
}
}
Ok(totals.into_iter()
.map(|total_logprob| ScoreResult { total_logprob })
.collect())
}
/// Per-range logprob drop: for every baseline score, how much lower the
/// score at the same index in `without` is.
///
/// The result has one entry per element of `baseline`. Negative differences
/// (the model did better without) are clamped to 0.0, and a baseline entry
/// with no counterpart in `without` is compared against itself, yielding 0.0.
fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
    let mut drops = Vec::with_capacity(baseline.len());
    for (idx, with_lp) in baseline.iter().map(|s| s.total_logprob).enumerate() {
        let without_lp = match without.get(idx) {
            Some(score) => score.total_logprob,
            None => with_lp,
        };
        drops.push(f64::max(with_lp - without_lp, 0.0));
    }
    drops
}
/// Score `range` of the context twice — once untouched, once with every
/// node for which `skip` returns true left out — and compare the two.
///
/// Returns the per-response divergences (see `divergence`) together with
/// the baseline scores. Both scoring calls use the same `priority`; an
/// error from either call is propagated.
async fn score_divergence<F>(
    client: &ApiClient,
    context: &ContextState,
    range: std::ops::Range<usize>,
    skip: F,
    priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)>
where
    F: FnMut(&AstNode) -> bool,
{
    // Build both prompts up front; each is a (tokens, images, ranges) triple.
    let full = context.wire_prompt(range.clone(), |_| false);
    let stripped = context.wire_prompt(range, skip);

    // Baseline first, then the stripped variant.
    let with_scores = call_score(client, &full.0, &full.1, &full.2, priority).await?;
    let without_scores =
        call_score(client, &stripped.0, &stripped.1, &stripped.2, priority).await?;

    let drops = divergence(&with_scores, &without_scores);
    Ok((drops, with_scores))
}
// ── Full matrix scoring (debug screen) ──────────────────────────
/// Score how important each memory is to the conversation (full matrix).
///
/// Collects every distinct memory key found in the identity and conversation
/// sections plus the indices of all assistant responses, scores the whole
/// conversation once as a baseline, then re-scores it once per memory with
/// that memory's nodes skipped. Each memory's per-response divergence is
/// written back to the live AST via `set_branch_memory_score`; values at or
/// below 0.01 are stored as `None`.
///
/// Returns `Ok(())` immediately when there are no memories or no assistant
/// responses. A failed baseline call aborts the run with that error; a
/// failed per-memory call is logged and treated as an all-zero row.
pub async fn score_memories(
    client: &ApiClient,
    agent: &std::sync::Arc<crate::agent::Agent>,
) -> anyhow::Result<()> {
    // Collect memory keys and response indices under a brief lock
    let (memory_keys, response_indices) = {
        let ctx = agent.context.lock().await;
        // Include identity nodes and conversation memories. The same key
        // can recur at non-adjacent positions, which `Vec::dedup` would not
        // collapse, so filter through a set while keeping first-seen order.
        // Otherwise a recurring memory would be re-scored once per occurrence.
        let mut seen = std::collections::HashSet::new();
        let mut keys: Vec<String> = Vec::new();
        for node in ctx.identity().iter().chain(ctx.conversation().iter()) {
            if let Some(key) = memory_key(node) {
                if seen.insert(key.to_owned()) {
                    keys.push(key.to_owned());
                }
            }
        }
        let resp: Vec<usize> = ctx.conversation().iter().enumerate()
            .filter(|(_, node)| is_assistant(node))
            .map(|(i, _)| i)
            .collect();
        (keys, resp)
    };
    if memory_keys.is_empty() || response_indices.is_empty() {
        return Ok(());
    }

    let total = memory_keys.len();
    dbglog!("[scoring-full] starting: {} memories × {} responses",
        total, response_indices.len());
    let activity = crate::agent::start_activity(agent, "scoring: baseline").await;

    // Baseline: the full conversation with nothing skipped.
    let (baseline_tokens, baseline_images, baseline_ranges) = {
        let ctx = agent.context.lock().await;
        ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
    };
    let baseline = call_score(client, &baseline_tokens, &baseline_images,
        &baseline_ranges, Some(5)).await?;
    dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());

    for (mem_idx, key) in memory_keys.iter().enumerate() {
        activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
        dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);

        // Same conversation, minus every node carrying this memory key.
        let (tokens, images, ranges) = {
            let ctx = agent.context.lock().await;
            ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
        };
        let row = match call_score(client, &tokens, &images, &ranges, Some(5)).await {
            Ok(without) => {
                let divs = divergence(&baseline, &without);
                let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
                dbglog!("[scoring-full] {}/{}: {} max_div={:.3}",
                    mem_idx + 1, total, key, max_div);
                divs
            }
            Err(e) => {
                dbglog!("[scoring-full] {}/{}: {} FAILED: {:#}",
                    mem_idx + 1, total, key, e);
                vec![0.0; baseline.len()]
            }
        };

        // Write this memory's scores to the live AST nodes via the
        // focused setter — keeps the AST mutation surface narrow.
        {
            let mut ctx = agent.context.lock().await;
            let mut set_count = 0;
            for (resp_idx, &idx) in response_indices.iter().enumerate() {
                // The row can be shorter than the response list if the
                // context changed between locks; skip what has no score.
                let Some(&score) = row.get(resp_idx) else { continue };
                let normalized = if score > 0.01 { Some(score) } else { None };
                ctx.set_branch_memory_score(
                    crate::agent::context::Section::Conversation,
                    idx,
                    &key,
                    normalized,
                );
                if normalized.is_some() {
                    set_count += 1;
                }
            }
            dbglog!("[scoring-full] {}/{} AST: set={}", mem_idx + 1, total, set_count);
        }
        agent.state.lock().await.changed.notify_one();
    }
    Ok(())
}
/// Locate the end of the window that closes on the `n`-th assistant
/// response at or after index `start`.
///
/// Returns `(index_after_that_response, true)` when `n` responses exist,
/// otherwise `(entries.len(), false)`. A `start` beyond the slice simply
/// finds nothing.
fn nth_response_end(entries: &[AstNode], start: usize, n: usize) -> (usize, bool) {
    let mut seen = 0;
    let hit = entries
        .iter()
        .enumerate()
        .skip(start)
        .filter(|(_, node)| is_assistant(node))
        .find(|_| {
            // Count each assistant response; stop on the n-th.
            seen += 1;
            seen >= n
        });
    match hit {
        Some((idx, _)) => (idx + 1, true),
        None => (entries.len(), false),
    }
}
// ── Single memory scoring ───────────────────────────────────────
/// Score how important a single memory is to the conversation.
///
/// The scored window starts at the first node carrying `key` and runs
/// through the 50th assistant response after it (or to the end of the
/// conversation if there are fewer) — the stretch where the memory could
/// have influenced responses.
///
/// Returns the sum of the per-response divergences, or 0.0 when the memory
/// is not in the conversation or the window holds no assistant response.
pub async fn score_memory(
    context: &ContextState,
    key: &str,
    client: &ApiClient,
) -> anyhow::Result<f64> {
    const RESPONSE_WINDOW: usize = 50;

    let entries = context.conversation();
    let Some(first_pos) = entries.iter().position(|node| memory_key(node) == Some(key)) else {
        return Ok(0.0);
    };

    let (end, _) = nth_response_end(entries, first_pos, RESPONSE_WINDOW);
    let has_response = entries[first_pos..end].iter().any(|node| is_assistant(node));
    if !has_response {
        return Ok(0.0);
    }

    let (divs, _) = score_divergence(
        client,
        context,
        first_pos..end,
        |n| memory_key(n) == Some(key),
        Some(5),
    ).await?;
    Ok(divs.iter().sum())
}
// ── Background memory scoring ───────────────────────────────────
/// Score memories in the conversation that are due for re-scoring.
///
/// Looks up each distinct memory key's `last_scored` timestamp in the local
/// store (treated as 0 when the lookup fails or the node is absent) and
/// keeps those not scored within `max_age_secs`, processing them oldest
/// first. For each one, the window from its first position through the
/// `response_window`-th following assistant response is scored with and
/// without the memory, and `on_score(key, max_divergence)` is awaited.
///
/// This function does not persist anything itself — whatever bookkeeping
/// follows a score (graph weight, last_scored, …) is up to `on_score`.
///
/// Memories are skipped when they sit past the first 60% of the
/// conversation by tokens, or when their window contains no assistant
/// response. A scoring failure for one memory is logged and does not stop
/// the run.
///
/// Returns the number of memories successfully scored. Errors only if the
/// local store cannot be accessed.
pub async fn score_memories_incremental<F, Fut>(
context: &ContextState,
max_age_secs: i64,
response_window: usize,
client: &ApiClient,
agent: &std::sync::Arc<crate::agent::Agent>,
mut on_score: F,
) -> anyhow::Result<usize>
where
F: FnMut(String, f64) -> Fut,
Fut: std::future::Future<Output = ()>,
{
let now = chrono::Utc::now().timestamp();
// Collect unique memory keys with their first position
let mut seen = std::collections::HashSet::new();
let mut candidates: Vec<(usize, String, i64)> = Vec::new(); // (pos, key, last_scored)
let store_arc = crate::hippocampus::access_local()?;
{
let store = &*store_arc;
// Identity nodes always score at position 0; conversation nodes at their index
let identity_nodes = context.identity().iter().map(|n| (0, n));
let conv_nodes = context.conversation().iter().enumerate();
for (pos, node) in identity_nodes.chain(conv_nodes) {
if let Some(key) = memory_key(node) {
// Only the first occurrence of a key becomes a candidate.
if !seen.insert(key.to_owned()) { continue; }
// Lookup errors and missing nodes both count as "never scored".
let last_scored = store.get_node(key)
.ok()
.flatten()
.map(|n| n.last_scored)
.unwrap_or(0);
if now - last_scored >= max_age_secs {
candidates.push((pos, key.to_owned(), last_scored));
}
}
}
}
// Score oldest-first
candidates.sort_by_key(|&(_, _, last)| last);
let mut scored = 0;
let entries = context.conversation();
let total_tokens: usize = entries.iter().map(|n| n.tokens()).sum();
let token_cutoff = total_tokens * 60 / 100;
// Precompute cumulative token position for each entry
let mut cumulative: Vec<usize> = Vec::with_capacity(entries.len());
let mut running = 0;
for e in entries {
running += e.tokens();
cumulative.push(running);
}
let total = candidates.len();
dbglog!("[scoring] total_tokens={}, cutoff={}, {} candidates", total_tokens, token_cutoff, total);
let activity = crate::agent::start_activity(agent, format!("scoring: 0/{}", total)).await;
for (pos, key, _) in &candidates {
// Only score memories in the first 60% of the conversation by tokens —
// recent memories don't have enough responses to evaluate yet.
let cum = cumulative.get(*pos).copied().unwrap_or(total_tokens);
if cum > token_cutoff {
dbglog!("[scoring] skip {} (tokens {}/{} past cutoff)", key, cum, token_cutoff);
continue;
}
// Window: from the memory through the `response_window`-th assistant
// response after it (or the end of the conversation).
let (end, _) = nth_response_end(context.conversation(), *pos, response_window);
let range = *pos..end;
if !context.conversation()[range.clone()].iter().any(|node| is_assistant(node)) {
dbglog!("[scoring] skip {} (no assistant response in range {}..{})", key, pos, end);
continue;
}
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
match score_divergence(client, context, range,
|n| memory_key(n) == Some(key), Some(5)).await {
Ok((divs, _)) => {
let n_responses = divs.len();
// The reported score is the largest single-response divergence.
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
dbglog!(
"[scoring] {} max:{:.3} ({} responses)", key, max_div, n_responses,
);
on_score(key.clone(), max_div).await;
scored += 1;
}
Err(e) => {
dbglog!(
"[scoring] {} FAILED: {:#}", key, e,
);
}
}
}
Ok(scored)
}
/// Memory scoring — two modes sharing one in-flight handle (only one
/// runs at a time): `trigger()` starts the incremental pass,
/// `trigger_full()` the N×M debug matrix.
pub struct MemoryScoring {
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
    scores_path: std::path::PathBuf,
    task: TaskHandle,
}

impl MemoryScoring {
    /// Build a scorer around `agent` and the shared mind state;
    /// `scores_path` is handed to the incremental run for saving scores.
    pub fn new(
        agent: Arc<crate::agent::Agent>,
        shared: Arc<std::sync::Mutex<MindState>>,
        scores_path: std::path::PathBuf,
    ) -> Self {
        let task = TaskHandle::new();
        Self { agent, shared, scores_path, task }
    }

    /// Hand a full-matrix run to the task handle via `trigger_if_idle`.
    pub fn trigger_full(&self) {
        let job = run_full(self.agent.clone(), self.shared.clone());
        self.task.trigger_if_idle(job);
    }
}
impl MindTriggered for MemoryScoring {
    /// Hand an incremental scoring run to the task handle via
    /// `trigger_if_idle`.
    fn trigger(&self) {
        let job = run_incremental(
            self.agent.clone(),
            self.shared.clone(),
            self.scores_path.clone(),
        );
        self.task.trigger_if_idle(job);
    }
}
/// Body of the incremental memory-scoring task.
///
/// Sets `scoring_in_flight` on the shared state, snapshots the agent's
/// context, and runs `score_memories_incremental` over the snapshot using
/// the configured interval and response window. For every score produced,
/// the live context is updated via `set_score` (when the memory can still
/// be found), the full score table is collected and saved to `scores_path`,
/// and the UI is notified. Clears `scoring_in_flight` when done.
///
/// A failure of the scoring pass as a whole is logged rather than returned,
/// since this runs as a detached task.
async fn run_incremental(
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
    scores_path: std::path::PathBuf,
) {
    shared.lock().unwrap().scoring_in_flight = true;
    agent.state.lock().await.changed.notify_one();

    let cfg = crate::config::get();
    let max_age = cfg.scoring_interval_secs;
    let response_window = cfg.scoring_response_window;

    // Score against a snapshot so the context lock is not held while scoring.
    let (context, client) = {
        let ctx = agent.context.lock().await.clone();
        (ctx, agent.client.clone())
    };

    let result = score_memories_incremental(
        &context, max_age as i64, response_window, &client, &agent,
        |key: String, score: f64| {
            let agent = agent.clone();
            let path = scores_path.clone();
            async move {
                let scores_snapshot = {
                    let mut ctx = agent.context.lock().await;
                    // The live context may have changed since the snapshot,
                    // so look the memory up again by key.
                    let found = crate::mind::find_memory_by_key(&ctx, &key);
                    match found {
                        Some((section, i)) => {
                            ctx.set_score(section, i, Some(score));
                            dbglog!("[scoring] persisted {} → {:.3} ({:?}[{}])",
                                key, score, section, i);
                        }
                        None => {
                            dbglog!(
                                "[scoring] DROP {}: find_memory_by_key None (id={}, cv={})",
                                key, ctx.identity().len(), ctx.conversation().len()
                            );
                        }
                    }
                    let snapshot = crate::mind::collect_memory_scores(&ctx);
                    // Release the context lock before touching agent state.
                    drop(ctx);
                    agent.state.lock().await.changed.notify_one();
                    snapshot
                };
                crate::mind::save_memory_scores(&scores_snapshot, &path);
            }
        },
    ).await;

    // Previously this error was dropped silently; surface it the same way
    // run_full does so a broken scoring pass leaves a trace.
    if let Err(e) = result {
        dbglog!("[scoring] FAILED: {:#}", e);
    }

    shared.lock().unwrap().scoring_in_flight = false;
    agent.state.lock().await.changed.notify_one();
}
/// Body of the full-matrix scoring task: flags `scoring_in_flight`, runs
/// `score_memories`, logs a failure if there is one, then clears the flag.
/// The UI is notified on entry and on exit.
async fn run_full(
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
) {
    shared.lock().unwrap().scoring_in_flight = true;
    agent.state.lock().await.changed.notify_one();

    let client = agent.client.clone();
    if let Err(e) = score_memories(&client, &agent).await {
        dbglog!("[scoring-full] FAILED: {:#}", e);
    }

    shared.lock().unwrap().scoring_in_flight = false;
    agent.state.lock().await.changed.notify_one();
}
// ── Fine-tuning scoring ─────────────────────────────────────────
/// Score which recent responses are candidates for fine-tuning.
///
/// Takes the last `count` conversation entries and scores them with and
/// without memory nodes. A response whose score drops sharply once the
/// memories are gone depends on memories the model hasn't internalized —
/// that is a fine-tuning candidate.
///
/// Returns `(entry_index, divergence)` pairs for the assistant responses in
/// that window, sorted by divergence descending. Returns an empty vector
/// without scoring when the window holds no assistant response.
pub async fn score_finetune(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
) -> anyhow::Result<Vec<(usize, f64)>> {
    let entries = context.conversation();
    let window_start = entries.len().saturating_sub(count);
    let window = window_start..entries.len();

    let mut response_positions: Vec<usize> = Vec::new();
    for idx in window.clone() {
        if is_assistant(&entries[idx]) {
            response_positions.push(idx);
        }
    }
    if response_positions.is_empty() {
        return Ok(Vec::new());
    }

    let (divs, _) = score_divergence(client, context, window, is_memory_node, Some(5)).await?;

    // The i-th divergence belongs to the i-th assistant response in the
    // window; a response with no score gets 0.0.
    let mut ranked: Vec<(usize, f64)> = Vec::with_capacity(response_positions.len());
    for (slot, entry_idx) in response_positions.into_iter().enumerate() {
        let div = match divs.get(slot) {
            Some(d) => *d,
            None => 0.0,
        };
        ranked.push((entry_idx, div));
    }
    ranked.sort_by(|lhs, rhs| {
        rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(ranked)
}
/// Enriched finetune candidate with context for review.
#[derive(Clone, Debug)]
pub struct FinetuneCandidate {
/// Index of the response within the conversation entries.
pub entry_idx: usize,
/// Divergence reported for this response by `score_finetune`.
pub divergence: f64,
/// Rendered text of the assistant turn.
pub response_text: String,
/// Last couple of user/assistant messages before this response,
/// already rendered with role markers, for F6 display context.
pub prior_context: String,
/// Token IDs for context (everything before the response).
pub context_ids: Vec<u32>,
/// Token IDs for the response (what we're training on).
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
/// Score and enrich finetune candidates with full context.
///
/// Scoring runs once up front via `score_finetune`. Every response at or
/// above `min_divergence` that has not already been trained on and has
/// visible text is turned into a `FinetuneCandidate`. The candidates are
/// then delivered through `on_candidate` one at a time; when
/// `generate_alternates` is set, an alternate (memory-free) response is
/// generated for each one first, with the activity status updated so the
/// UI doesn't look stuck. A failed alternate generation is logged and the
/// candidate is still delivered, without alternate text.
///
/// Returns `(candidates_delivered, max_divergence)`, where the maximum is
/// taken over all scored responses, not only the delivered ones.
pub async fn score_finetune_candidates(
    context: &ContextState,
    count: usize,
    client: &ApiClient,
    min_divergence: f64,
    generate_alternates: bool,
    activity: &crate::agent::ActivityGuard,
    mut on_candidate: impl FnMut(FinetuneCandidate),
) -> anyhow::Result<(usize, f64)> {
    let scores = score_finetune(context, count, client).await?;

    let mut max_divergence = 0.0f64;
    for (_, d) in &scores {
        max_divergence = f64::max(max_divergence, *d);
    }

    let entries = context.conversation();
    let trained = load_trained();

    let mut pending: Vec<FinetuneCandidate> = Vec::new();
    for (entry_idx, divergence) in scores {
        if divergence < min_divergence {
            continue;
        }
        let node = &entries[entry_idx];

        // Already trained on — nothing more to do with it.
        let timestamp_ns = node_timestamp_ns(node);
        if trained.contains(&timestamp_ns) {
            continue;
        }

        // Only branch nodes carry an assistant turn to render.
        let AstNode::Branch { children, .. } = node else { continue };
        let response_text = render_branch_text(children);

        // Turns with nothing human-visible (tool-only, interrupted) would
        // show up as blank cards and still burn an alternate generation.
        if response_text.trim().is_empty() {
            continue;
        }

        // Last couple of exchanges, for the reviewer.
        let prior_context = render_prior_context(entries, entry_idx, 2);

        // Training pair: everything before the response, then the response.
        let (context_ids, _, _) = context.wire_prompt(0..entry_idx, |_| false);
        let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();

        pending.push(FinetuneCandidate {
            entry_idx,
            divergence,
            response_text,
            prior_context,
            context_ids,
            continuation_ids,
            alternate_text: None,
            timestamp_ns,
        });
    }

    let total = pending.len();
    for (done, mut candidate) in pending.into_iter().enumerate() {
        if generate_alternates {
            activity.update(
                format!("finetune: generating alternate {}/{}", done + 1, total)
            ).await;
            match gen_continuation(context, candidate.entry_idx, is_memory_node, client).await {
                Ok(text) => candidate.alternate_text = Some(text),
                Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
            }
        }
        on_candidate(candidate);
    }
    Ok((total, max_divergence))
}
/// Stats from a finetune scoring run. Stored on MindState for UI display.
#[derive(Clone, Debug)]
pub struct FinetuneScoringStats {
/// Number of assistant responses in the scored part of the conversation.
pub responses_considered: usize,
/// Candidate count reported by the scoring run (0 when it failed).
pub above_threshold: usize,
/// Divergence threshold the run was configured with.
pub threshold: f64,
/// Largest divergence seen in the run (0.0 when it failed).
pub max_divergence: f64,
/// Error message when the run failed, `None` otherwise.
pub error: Option<String>,
}
/// Finetune scoring — `trigger()` aborts any in-flight run and starts
/// a fresh one, clearing the previous candidates.
pub struct FinetuneScoring {
    agent: Arc<crate::agent::Agent>,
    shared: Arc<std::sync::Mutex<MindState>>,
    task: TaskHandle,
}

impl FinetuneScoring {
    /// Build a finetune scorer around `agent` and the shared mind state.
    pub fn new(
        agent: Arc<crate::agent::Agent>,
        shared: Arc<std::sync::Mutex<MindState>>,
    ) -> Self {
        let task = TaskHandle::new();
        Self { agent, shared, task }
    }
}
impl MindTriggered for FinetuneScoring {
    /// Hand a fresh finetune scoring run to the task handle via `trigger`.
    fn trigger(&self) {
        let job = run_finetune(self.agent.clone(), self.shared.clone());
        self.task.trigger(job);
    }
}
async fn run_finetune(
agent: Arc<crate::agent::Agent>,
shared: Arc<std::sync::Mutex<MindState>>,
) {
let (threshold, gen_alternates) = {
let app = crate::config::app();
(app.learn.threshold, app.learn.generate_alternates)
};
// Fresh run — clear previous candidates.
shared.lock().unwrap().finetune_candidates.clear();
agent.state.lock().await.changed.notify_one();
let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await;
let (context, client) = {
let ctx = agent.context.lock().await;
(ctx.clone(), agent.client.clone())
};
let entries = context.conversation();
let score_count = entries.len() / 2;
let range_start = entries.len() - score_count;
let responses_considered: usize = entries[range_start..].iter()
.filter(|n| matches!(n, AstNode::Branch { role: Role::Assistant, .. }))
.count();
activity.update(format!("finetune: scoring {} responses...", responses_considered)).await;
let stats = {
let shared = shared.clone();
let agent = agent.clone();
match score_finetune_candidates(
&context, score_count, &client, threshold,
gen_alternates, &activity,
move |c| {
shared.lock().unwrap().finetune_candidates.push(c);
{ let st = agent.state.lock_blocking(); st.changed.notify_one(); }
},
).await {
Ok((above_threshold, max_div)) => FinetuneScoringStats {
responses_considered,
above_threshold,
threshold,
max_divergence: max_div,
error: None,
},
Err(e) => FinetuneScoringStats {
responses_considered,
above_threshold: 0,
threshold,
max_divergence: 0.0,
error: Some(format!("{}", e)),
},
}
};
shared.lock().unwrap().finetune_last_run = Some(stats);
agent.state.lock().await.changed.notify_one();
}
// ── Finetune config and persistence ─────────────────────────────
use std::path::PathBuf;
use std::collections::HashSet;
/// Location of the trained-responses cache, relative to the home directory.
const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json";

/// Path of the trained-responses cache under the user's home directory.
/// When no home directory can be determined the path is left relative.
fn trained_path() -> PathBuf {
    let mut path = dirs::home_dir().unwrap_or_default();
    path.push(TRAINED_RESPONSES_FILE);
    path
}
/// Load the set of trained response timestamps (nanos since epoch).
///
/// A missing or unreadable file, and a file that does not parse, both
/// yield an empty set.
pub fn load_trained() -> HashSet<i64> {
    std::fs::read_to_string(trained_path())
        .ok()
        .and_then(|content| serde_json::from_str(&content).ok())
        .unwrap_or_default()
}
/// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ns: i64) {
let mut trained = load_trained();
trained.insert(timestamp_ns);
let path = trained_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(json) = serde_json::to_string(&trained) {
let _ = std::fs::write(&path, json);
}
}
/// Timestamp of an `AstNode` as nanoseconds in an `i64`.
///
/// # Panics
///
/// The i64-ns representation covers 1677..2262 via chrono; a timestamp
/// outside that window would be a bug worth surfacing, so it panics.
pub fn node_timestamp_ns(node: &AstNode) -> i64 {
    let nanos = match node {
        AstNode::Leaf(leaf) => leaf.timestamp().timestamp_nanos_opt(),
        AstNode::Branch { timestamp, .. } => timestamp.timestamp_nanos_opt(),
    };
    nanos.expect("timestamp outside i64-ns representable range (1677..2262)")
}
// ── Training API ────────────────────────────────────────────────
/// Training sample for /train endpoint.
///
/// Serialized as one element of `training_data.samples` in the request
/// body built by `send_to_train`.
#[derive(serde::Serialize)]
struct TrainingSample {
/// Token IDs of the context, copied from `TrainData::context_ids`.
context_ids: Vec<u32>,
/// Token IDs of the continuation, copied from `TrainData::continuation_ids`.
continuation_ids: Vec<u32>,
}
/// Data needed to send a training sample.
pub struct TrainData {
/// Token IDs sent as the sample's context.
pub context_ids: Vec<u32>,
/// Token IDs sent as the sample's continuation.
pub continuation_ids: Vec<u32>,
/// Timestamp recorded via `mark_trained` once the sample has been
/// submitted successfully.
pub timestamp_ns: i64,
}
/// Send training samples to the server.
///
/// POSTs all samples as one JSON body to `{base_url}/train` with a
/// 300-second timeout. On success every sample is marked as trained and
/// the `job_id` from the response is returned ("unknown" if the response
/// carries none).
///
/// # Errors
///
/// Fails when `samples` is empty, when the request itself fails, when the
/// server answers with a non-success status (the message is the response's
/// `error` field, or "unknown error" if the body has none or is not JSON),
/// or when a success response is not valid JSON. No sample is marked as
/// trained in any of these cases.
pub async fn send_to_train(
    samples: Vec<TrainData>,
    client: &ApiClient,
) -> anyhow::Result<String> {
    if samples.is_empty() {
        anyhow::bail!("no samples to train");
    }

    let api_samples: Vec<TrainingSample> = samples.iter()
        .map(|s| TrainingSample {
            context_ids: s.context_ids.clone(),
            continuation_ids: s.continuation_ids.clone(),
        })
        .collect();
    let body = serde_json::json!({
        "training_data": {
            "samples": api_samples,
        }
    });

    let url = format!("{}/train", client.base_url());
    let http = crate::agent::api::http::HttpClient::builder()
        .timeout(std::time::Duration::from_secs(300))
        .build();
    let response = http.send_json("POST", &url, &[], &body).await?;
    let status = response.status();

    // Don't let a non-JSON error page mask the HTTP status: inspect the
    // status first and only require a parseable body on success.
    let parsed: Result<serde_json::Value, _> = response.json().await;
    if !status.is_success() {
        let msg = parsed
            .as_ref()
            .ok()
            .and_then(|v| v.get("error"))
            .and_then(|e| e.as_str())
            .unwrap_or("unknown error");
        anyhow::bail!("train API HTTP {}: {}", status, msg);
    }
    let result = parsed?;

    // Mark all samples as trained
    for s in &samples {
        mark_trained(s.timestamp_ns);
    }

    let job_id = result.get("job_id")
        .and_then(|j| j.as_str())
        .unwrap_or("unknown")
        .to_string();
    dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id);
    Ok(job_id)
}