consciousness/src/thought/training.rs
Kent Overstreet df9b610c7f add memory importance scoring via prompt logprobs
score_memories() drops each memory from the context one at a time,
runs prompt_logprobs against the full conversation, and builds a
divergence matrix: memories × responses.

Row sums = memory importance (for graph weight updates)
Column sums = response memory-dependence (training candidates)

Uses vLLM's prompt_logprobs to check "would the model have said
this without this memory?" — one forward pass per memory, all
responses scored at once. ~3s per memory on B200.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-02 22:13:55 -04:00

226 lines
7.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// training.rs — Memory importance scoring via prompt logprobs
//
// Drops each memory from the context one at a time, runs prompt_logprobs
// to see how the model's confidence in its responses changes. Produces
// a divergence matrix: memories × responses.
//
// Row sums = memory importance (for graph weight updates)
// Column sums = response memory-dependence (training candidates)
use crate::agent::api::ApiClient;
use crate::agent::types::*;
use crate::agent::ui_channel::UiSender;
/// Result of scoring one conversation's memory usage.
///
/// `matrix[m][r]` is the divergence of response `r` when memory `m`
/// (`memory_keys[m]`) is removed from the context. `memory_weights` holds the
/// row sums and `response_scores` the column sums of that matrix.
#[derive(Debug, Clone, Default)]
pub struct MemoryScore {
    /// (memory_key, importance) — importance is the sum of divergence across all responses
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories)
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored, in the same order as the rows of `matrix`
    pub memory_keys: Vec<String>,
}
/// Score how important each memory is to the conversation.
///
/// Makes one baseline request with the full context. It then makes one
/// request per `ConversationEntry::Memory` entry, with that single entry
/// removed, and compares the logprobs of assistant response tokens between
/// the two (see [`response_divergence`]). The resulting matrix has one row
/// per memory, in context order, and one column per assistant response
/// parsed from the baseline prompt.
///
/// Progress is reported as `UiMessage::Debug` messages on `ui_tx`; send
/// failures are ignored. When the context has no memory entries or no
/// assistant entries, an empty [`MemoryScore`] is returned without making
/// any requests.
///
/// # Errors
///
/// Propagates the first error from any logprob request.
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
    use crate::agent::ui_channel::UiMessage;

    // (entry index, memory key) for every memory entry, in context order.
    let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
        .filter_map(|(i, e)| match e {
            ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
            _ => None,
        })
        .collect();
    let response_indices: Vec<usize> = context.entries.iter().enumerate()
        .filter(|(_, e)| e.message().role == Role::Assistant)
        .map(|(i, _)| i)
        .collect();
    if memories.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            matrix: Vec::new(),
            memory_keys: Vec::new(),
        });
    }
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] scoring {} memories × {} responses",
        memories.len(), response_indices.len(),
    )));

    // Baseline: logprobs with all memories present
    let baseline = get_response_logprobs(context, &context.entries, client).await?;
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] baseline: {} response tokens scored",
        baseline.iter().map(|r| r.len()).sum::<usize>(),
    )));
    // Matrix columns follow the responses parsed out of the rendered prompt,
    // not the assistant entries counted above. Surface any mismatch instead
    // of silently misattributing scores.
    if baseline.len() != response_indices.len() {
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] warning: parsed {} responses from prompt, context has {} assistant entries",
            baseline.len(), response_indices.len(),
        )));
    }

    // For each memory, drop it and measure divergence
    let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(memories.len());
    for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] scoring memory {}/{}: {}",
            mem_idx + 1, memories.len(), key,
        )));
        // Build entries without this memory
        let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
            .filter(|(i, _)| *i != *entry_idx)
            .map(|(_, e)| e.clone())
            .collect();
        let without = get_response_logprobs(context, &filtered, client).await?;
        // Responses are paired positionally below; a count change means the
        // pairing is off, so make it visible rather than letting zip truncate.
        if without.len() != baseline.len() {
            let _ = ui_tx.send(UiMessage::Debug(format!(
                "[training] warning: {} responses without {} vs {} in baseline",
                without.len(), key, baseline.len(),
            )));
        }
        let row: Vec<f64> = baseline.iter()
            .zip(&without)
            .map(|(base_lps, without_lps)| response_divergence(base_lps, without_lps))
            .collect();
        matrix.push(row);
    }

    // Row sums: how much each memory helped across all responses.
    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(&matrix)
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();
    // Column sums: how much each response leaned on memories. Sized to the
    // baseline's response count, which is what the matrix columns index, so
    // no column is dropped and no phantom column is invented.
    let mut response_scores: Vec<f64> = vec![0.0_f64; baseline.len()];
    for row in &matrix {
        for (score, v) in response_scores.iter_mut().zip(row) {
            *score += *v;
        }
    }

    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] done. top memory: {:?}",
        memory_weights.iter()
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(k, v)| format!("{}: {:.1}", k, v)),
    )));
    Ok(MemoryScore {
        memory_weights,
        response_scores,
        matrix,
        memory_keys,
    })
}

/// Total logprob lost on one response when a memory is removed.
///
/// Tokens are paired positionally. The sum counts `base - without` only for
/// tokens where the baseline (memory present) was more confident, so a
/// positive result means the memory helped. Tokens whose logprob rose
/// without the memory contribute nothing. Extra tokens in the longer slice
/// are ignored.
fn response_divergence(base: &[f64], without: &[f64]) -> f64 {
    base.iter()
        .zip(without)
        .map(|(b, w)| b - w)
        .filter(|d| *d > 0.0)
        .sum()
}
/// Get logprobs for all assistant response tokens in a conversation.
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
/// containing logprobs for each token in that response.
async fn get_response_logprobs(
context: &ContextState,
entries: &[ConversationEntry],
client: &ApiClient,
) -> anyhow::Result<Vec<Vec<f64>>> {
// Assemble messages the same way the runner does
let mut msgs = Vec::new();
msgs.push(Message::system(&context.system_prompt));
let ctx = context.render_context_message();
if !ctx.is_empty() {
msgs.push(Message::user(ctx));
}
msgs.extend(entries.iter().map(|e| e.api_message().clone()));
// Call the API with prompt_logprobs
let request = serde_json::json!({
"model": client.model,
"messages": msgs,
"max_tokens": 1,
"prompt_logprobs": 1,
"stream": false,
});
let response = reqwest::Client::new()
.post(format!("{}/chat/completions", client.base_url()))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", client.api_key()))
.json(&request)
.send()
.await?;
let body: serde_json::Value = response.json().await?;
if let Some(err) = body.get("error") {
anyhow::bail!("API error: {}", err);
}
let prompt_logprobs = body.get("prompt_logprobs")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;
// Find assistant response boundaries using special tokens
// Pattern: <|im_start|> assistant \n [<think>...</think>] response <|im_end|>
let mut responses: Vec<Vec<f64>> = Vec::new();
let mut in_assistant = false;
let mut in_think = false;
let mut current_response: Vec<f64> = Vec::new();
for entry in prompt_logprobs {
let Some(obj) = entry.as_object() else { continue };
let first = obj.values().next();
let Some(info) = first.and_then(|v| v.as_object()) else { continue };
let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
match token {
"<|im_start|>" => {
in_assistant = false;
in_think = false;
}
"assistant" if !in_assistant => {
in_assistant = true;
in_think = false;
current_response.clear();
}
"<think>" if in_assistant => {
in_think = true;
}
"</think>" if in_assistant => {
in_think = false;
}
"<|im_end|>" if in_assistant => {
if !current_response.is_empty() {
responses.push(std::mem::take(&mut current_response));
}
in_assistant = false;
}
"\n" if in_assistant && current_response.is_empty() => {
// Skip the newline right after "assistant"
}
_ if in_assistant && !in_think => {
current_response.push(logprob);
}
_ => {}
}
}
Ok(responses)
}