wire up /score command and debug screen for memory importance

/score snapshots the context and client, releases the agent lock, and
runs scoring in the background. Only one scoring task runs at a time
(guarded by the scoring_in_flight flag). Results are stored on Agent and
shown on the F10 context debug screen, with an importance score per memory.

ApiClient derives Clone. ContextState derives Clone.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-02 22:18:50 -04:00
parent df9b610c7f
commit c01d4a5b08
7 changed files with 64 additions and 4 deletions

View file

@ -57,6 +57,7 @@ pub enum StreamEvent {
Error(String), Error(String),
} }
#[derive(Clone)]
pub struct ApiClient { pub struct ApiClient {
client: Client, client: Client,
api_key: String, api_key: String,

View file

@ -76,6 +76,10 @@ pub struct Agent {
session_id: String, session_id: String,
/// Agent orchestration state (surface-observe, journal, reflect). /// Agent orchestration state (surface-observe, journal, reflect).
pub agent_cycles: crate::subconscious::subconscious::AgentCycleState, pub agent_cycles: crate::subconscious::subconscious::AgentCycleState,
/// Latest memory importance scores from training scorer.
pub memory_scores: Option<crate::thought::training::MemoryScore>,
/// Whether a /score task is currently running.
pub scoring_in_flight: bool,
} }
fn render_journal(entries: &[journal::JournalEntry]) -> String { fn render_journal(entries: &[journal::JournalEntry]) -> String {
@ -125,6 +129,8 @@ impl Agent {
prompt_file, prompt_file,
session_id, session_id,
agent_cycles, agent_cycles,
memory_scores: None,
scoring_in_flight: false,
}; };
agent.load_startup_journal(); agent.load_startup_journal();
@ -670,8 +676,16 @@ impl Agent {
_ => unreachable!(), _ => unreachable!(),
}; };
let text = entry.message().content_text(); let text = entry.message().content_text();
let score = self.memory_scores.as_ref()
.and_then(|s| s.memory_weights.iter()
.find(|(k, _)| k == key)
.map(|(_, v)| *v));
let label = match score {
Some(v) => format!("{} (importance: {:.1})", key, v),
None => key.to_string(),
};
ContextSection { ContextSection {
name: key.to_string(), name: label,
tokens: count(text), tokens: count(text),
content: String::new(), content: String::new(),
children: Vec::new(), children: Vec::new(),
@ -970,6 +984,10 @@ impl Agent {
} }
/// Returns an owned clone of the agent's API client, so a caller can keep
/// using the client after it has released its borrow of the agent.
pub fn client_clone(&self) -> ApiClient {
    self.client.clone()
}

/// Mutable access to conversation entries (for /retry).
pub fn entries_mut(&mut self) -> &mut Vec<ConversationEntry> {
    &mut self.context.entries
}

View file

@ -397,6 +397,7 @@ impl ConversationEntry {
} }
} }
#[derive(Clone)]
pub struct ContextState { pub struct ContextState {
pub system_prompt: String, pub system_prompt: String,
pub personality: Vec<(String, String)>, pub personality: Vec<(String, String)>,

View file

@ -354,6 +354,7 @@ impl Session {
("/save", "Save session to disk"), ("/save", "Save session to disk"),
("/retry", "Re-run last turn"), ("/retry", "Re-run last turn"),
("/model", "Show/switch model (/model <name>)"), ("/model", "Show/switch model (/model <name>)"),
("/score", "Score memory importance"),
("/dmn", "Show DMN state"), ("/dmn", "Show DMN state"),
("/sleep", "Put DMN to sleep"), ("/sleep", "Put DMN to sleep"),
("/wake", "Wake DMN to foraging"), ("/wake", "Wake DMN to foraging"),
@ -422,6 +423,46 @@ impl Session {
} }
Command::Handled Command::Handled
} }
"/score" => {
// Test and claim the in-flight flag under a single lock acquisition, so
// no other lock holder can slip in between the check and the set.
{
    let mut agent = self.agent.lock().await;
    if agent.scoring_in_flight {
        let _ = self.ui_tx.send(UiMessage::Info(
            "(scoring already in progress)".into()
        ));
        return Command::Handled;
    }
    agent.scoring_in_flight = true;
}
let agent = self.agent.clone();
let ui_tx = self.ui_tx.clone();
// Scoring runs in a detached task. The agent lock is held only to take
// the snapshot and, later, to publish the result; it is not held while
// score_memories is awaited.
tokio::spawn(async move {
    let (context, client) = {
        let agent = agent.lock().await;
        (agent.context.clone(), agent.client_clone())
    };
    let result = poc_memory::thought::training::score_memories(
        &context, &client, &ui_tx,
    ).await;
    let mut agent = agent.lock().await;
    // Clear the flag on both success and failure so /score can be re-run.
    agent.scoring_in_flight = false;
    match result {
        Ok(scores) => {
            let _ = ui_tx.send(UiMessage::Info(format!(
                "[memory scoring complete: {} memories scored]",
                scores.memory_keys.len(),
            )));
            agent.memory_scores = Some(scores);
        }
        Err(e) => {
            let _ = ui_tx.send(UiMessage::Info(format!(
                "[memory scoring failed: {:#}]", e,
            )));
        }
    }
});
Command::Handled
}
"/dmn" => { "/dmn" => {
let _ = self let _ = self
.ui_tx .ui_tx

View file

@ -13,7 +13,6 @@
// Phase 2 will inline job logic; Phase 3 integrates into poc-agent. // Phase 2 will inline job logic; Phase 3 integrates into poc-agent.
use jobkit::{Choir, ExecutionContext, TaskError, TaskInfo, TaskStatus}; use jobkit::{Choir, ExecutionContext, TaskError, TaskInfo, TaskStatus};
use std::collections::{HashMap, HashSet};
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};

View file

@ -102,7 +102,7 @@ pub fn run_and_apply_excluded(
log: &(dyn Fn(&str) + Sync), log: &(dyn Fn(&str) + Sync),
exclude: &std::collections::HashSet<String>, exclude: &std::collections::HashSet<String>,
) -> Result<(), String> { ) -> Result<(), String> {
let result = run_one_agent_excluded(store, agent_name, batch_size, llm_tag, log, exclude)?; let _result = run_one_agent_excluded(store, agent_name, batch_size, llm_tag, log, exclude)?;
Ok(()) Ok(())
} }

View file

@ -89,7 +89,7 @@ pub async fn score_memories(
// Compute per-response divergence // Compute per-response divergence
let mut row = Vec::new(); let mut row = Vec::new();
for (resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() { for (_resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() {
// Sum of logprob drops across tokens in this response // Sum of logprob drops across tokens in this response
// Positive = memory helped (logprob was higher with it) // Positive = memory helped (logprob was higher with it)
let divergence: f64 = base_lps.iter().zip(without_lps.iter()) let divergence: f64 = base_lps.iter().zip(without_lps.iter())