add memory importance scoring via prompt logprobs
score_memories() drops each memory from the context one at a time, runs prompt_logprobs against the full conversation, and builds a divergence matrix (memories × responses). Row sums give memory importance (used for graph weight updates); column sums give response memory-dependence (training candidates). It uses vLLM's prompt_logprobs to answer "would the model have said this without this memory?" — one forward pass per memory, with all responses scored at once (~3s per memory on a B200). Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
dae0cc8191
commit
df9b610c7f
3 changed files with 230 additions and 0 deletions
|
|
@ -166,6 +166,9 @@ impl ApiClient {
|
||||||
Ok((build_response_message(content, tool_calls), usage))
|
Ok((build_response_message(content, tool_calls), usage))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the configured API base URL.
pub fn base_url(&self) -> &str { &self.base_url }
|
||||||
|
/// Returns the configured API key (sent as a bearer token).
pub fn api_key(&self) -> &str { &self.api_key }
|
||||||
|
|
||||||
/// Return a label for the active backend, used in startup info.
|
/// Return a label for the active backend, used in startup info.
|
||||||
pub fn backend_label(&self) -> &str {
|
pub fn backend_label(&self) -> &str {
|
||||||
if self.base_url.contains("openrouter") {
|
if self.base_url.contains("openrouter") {
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ pub mod glob_tool;
|
||||||
pub mod grep;
|
pub mod grep;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod read;
|
pub mod read;
|
||||||
|
pub mod training;
|
||||||
pub mod write;
|
pub mod write;
|
||||||
|
|
||||||
pub use bash::ProcessTracker;
|
pub use bash::ProcessTracker;
|
||||||
|
|
|
||||||
226
src/thought/training.rs
Normal file
226
src/thought/training.rs
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
// training.rs — Memory importance scoring via prompt logprobs
|
||||||
|
//
|
||||||
|
// Drops each memory from the context one at a time, runs prompt_logprobs
|
||||||
|
// to see how the model's confidence in its responses changes. Produces
|
||||||
|
// a divergence matrix: memories × responses.
|
||||||
|
//
|
||||||
|
// Row sums = memory importance (for graph weight updates)
|
||||||
|
// Column sums = response memory-dependence (training candidates)
|
||||||
|
|
||||||
|
use crate::agent::api::ApiClient;
|
||||||
|
use crate::agent::types::*;
|
||||||
|
use crate::agent::ui_channel::UiSender;
|
||||||
|
|
||||||
|
/// Result of scoring one conversation's memory usage.
///
/// All three views below are derived from the same divergence matrix:
/// rows are memories, columns are assistant responses.
pub struct MemoryScore {
    /// memory_key → importance score (sum of divergence across all responses)
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories)
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored (row order matches `matrix`)
    pub memory_keys: Vec<String>,
}
|
||||||
|
|
||||||
|
/// Score how important each memory is to the conversation.
|
||||||
|
///
|
||||||
|
/// For each Memory entry in the context, builds a version without it
|
||||||
|
/// and checks how the model's logprobs change for assistant responses.
|
||||||
|
pub async fn score_memories(
|
||||||
|
context: &ContextState,
|
||||||
|
client: &ApiClient,
|
||||||
|
ui_tx: &UiSender,
|
||||||
|
) -> anyhow::Result<MemoryScore> {
|
||||||
|
use crate::agent::ui_channel::UiMessage;
|
||||||
|
|
||||||
|
// Identify memory entries and assistant response positions
|
||||||
|
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
|
||||||
|
.filter_map(|(i, e)| match e {
|
||||||
|
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let response_indices: Vec<usize> = context.entries.iter().enumerate()
|
||||||
|
.filter(|(_, e)| e.message().role == Role::Assistant)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if memories.is_empty() || response_indices.is_empty() {
|
||||||
|
return Ok(MemoryScore {
|
||||||
|
memory_weights: Vec::new(),
|
||||||
|
response_scores: Vec::new(),
|
||||||
|
matrix: Vec::new(),
|
||||||
|
memory_keys: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] scoring {} memories × {} responses",
|
||||||
|
memories.len(), response_indices.len(),
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Baseline: logprobs with all memories present
|
||||||
|
let baseline = get_response_logprobs(context, &context.entries, client).await?;
|
||||||
|
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] baseline: {} response tokens scored",
|
||||||
|
baseline.iter().map(|r| r.len()).sum::<usize>(),
|
||||||
|
)));
|
||||||
|
|
||||||
|
// For each memory, drop it and measure divergence
|
||||||
|
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
||||||
|
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
|
||||||
|
|
||||||
|
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] scoring memory {}/{}: {}",
|
||||||
|
mem_idx + 1, memories.len(), key,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Build entries without this memory
|
||||||
|
let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
|
||||||
|
.filter(|(i, _)| *i != *entry_idx)
|
||||||
|
.map(|(_, e)| e.clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let without = get_response_logprobs(context, &filtered, client).await?;
|
||||||
|
|
||||||
|
// Compute per-response divergence
|
||||||
|
let mut row = Vec::new();
|
||||||
|
for (resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() {
|
||||||
|
// Sum of logprob drops across tokens in this response
|
||||||
|
// Positive = memory helped (logprob was higher with it)
|
||||||
|
let divergence: f64 = base_lps.iter().zip(without_lps.iter())
|
||||||
|
.map(|(b, w)| b - w) // positive when baseline was more confident
|
||||||
|
.filter(|d| *d > 0.0) // only count where memory helped
|
||||||
|
.sum();
|
||||||
|
row.push(divergence);
|
||||||
|
}
|
||||||
|
matrix.push(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute scores
|
||||||
|
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
||||||
|
.zip(matrix.iter())
|
||||||
|
.map(|(key, row)| (key.clone(), row.iter().sum()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let n_responses = response_indices.len();
|
||||||
|
let mut response_scores = vec![0.0; n_responses];
|
||||||
|
for row in &matrix {
|
||||||
|
for (j, &v) in row.iter().enumerate() {
|
||||||
|
if j < n_responses {
|
||||||
|
response_scores[j] += v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] done. top memory: {:?}",
|
||||||
|
memory_weights.iter()
|
||||||
|
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
|
||||||
|
.map(|(k, v)| format!("{}: {:.1}", k, v)),
|
||||||
|
)));
|
||||||
|
|
||||||
|
Ok(MemoryScore {
|
||||||
|
memory_weights,
|
||||||
|
response_scores,
|
||||||
|
matrix,
|
||||||
|
memory_keys,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get logprobs for all assistant response tokens in a conversation.
|
||||||
|
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
|
||||||
|
/// containing logprobs for each token in that response.
|
||||||
|
async fn get_response_logprobs(
|
||||||
|
context: &ContextState,
|
||||||
|
entries: &[ConversationEntry],
|
||||||
|
client: &ApiClient,
|
||||||
|
) -> anyhow::Result<Vec<Vec<f64>>> {
|
||||||
|
// Assemble messages the same way the runner does
|
||||||
|
let mut msgs = Vec::new();
|
||||||
|
msgs.push(Message::system(&context.system_prompt));
|
||||||
|
let ctx = context.render_context_message();
|
||||||
|
if !ctx.is_empty() {
|
||||||
|
msgs.push(Message::user(ctx));
|
||||||
|
}
|
||||||
|
msgs.extend(entries.iter().map(|e| e.api_message().clone()));
|
||||||
|
|
||||||
|
// Call the API with prompt_logprobs
|
||||||
|
let request = serde_json::json!({
|
||||||
|
"model": client.model,
|
||||||
|
"messages": msgs,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"prompt_logprobs": 1,
|
||||||
|
"stream": false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = reqwest::Client::new()
|
||||||
|
.post(format!("{}/chat/completions", client.base_url()))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", client.api_key()))
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let body: serde_json::Value = response.json().await?;
|
||||||
|
|
||||||
|
if let Some(err) = body.get("error") {
|
||||||
|
anyhow::bail!("API error: {}", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let prompt_logprobs = body.get("prompt_logprobs")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;
|
||||||
|
|
||||||
|
// Find assistant response boundaries using special tokens
|
||||||
|
// Pattern: <|im_start|> assistant \n [<think>...</think>] response <|im_end|>
|
||||||
|
let mut responses: Vec<Vec<f64>> = Vec::new();
|
||||||
|
let mut in_assistant = false;
|
||||||
|
let mut in_think = false;
|
||||||
|
let mut current_response: Vec<f64> = Vec::new();
|
||||||
|
|
||||||
|
for entry in prompt_logprobs {
|
||||||
|
let Some(obj) = entry.as_object() else { continue };
|
||||||
|
let first = obj.values().next();
|
||||||
|
let Some(info) = first.and_then(|v| v.as_object()) else { continue };
|
||||||
|
let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||||
|
|
||||||
|
match token {
|
||||||
|
"<|im_start|>" => {
|
||||||
|
in_assistant = false;
|
||||||
|
in_think = false;
|
||||||
|
}
|
||||||
|
"assistant" if !in_assistant => {
|
||||||
|
in_assistant = true;
|
||||||
|
in_think = false;
|
||||||
|
current_response.clear();
|
||||||
|
}
|
||||||
|
"<think>" if in_assistant => {
|
||||||
|
in_think = true;
|
||||||
|
}
|
||||||
|
"</think>" if in_assistant => {
|
||||||
|
in_think = false;
|
||||||
|
}
|
||||||
|
"<|im_end|>" if in_assistant => {
|
||||||
|
if !current_response.is_empty() {
|
||||||
|
responses.push(std::mem::take(&mut current_response));
|
||||||
|
}
|
||||||
|
in_assistant = false;
|
||||||
|
}
|
||||||
|
"\n" if in_assistant && current_response.is_empty() => {
|
||||||
|
// Skip the newline right after "assistant"
|
||||||
|
}
|
||||||
|
_ if in_assistant && !in_think => {
|
||||||
|
current_response.push(logprob);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(responses)
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue