consciousness/src/thought/training.rs
Kent Overstreet 78abf90461 fix scoring: HTTP error checking, context refresh, chunk logging
Check HTTP status from logprobs API (was silently ignoring 500s).
Call publish_context_state() after storing scores so F10 screen
updates. Add chunk size logging for OOM debugging.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-02 22:47:44 -04:00

359 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// training.rs — Memory importance scoring via prompt logprobs
//
// Drops each memory from the context one at a time, runs prompt_logprobs
// to see how the model's confidence in its responses changes. Produces
// a divergence matrix: memories × responses.
//
// Row sums = memory importance (for graph weight updates)
// Column sums = response memory-dependence (training candidates)
use crate::agent::api::ApiClient;
use crate::agent::types::*;
use crate::agent::ui_channel::UiSender;
/// Result of scoring one conversation's memory usage.
///
/// `matrix` is indexed as `[memory_idx][response_idx]`. `memory_keys` maps the
/// first axis back to memory keys and `response_entry_indices` maps the second
/// axis back to conversation entry indices.
#[derive(Debug, Clone, Default)]
pub struct MemoryScore {
    /// memory_key → importance score (sum of divergence across all responses)
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories)
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored
    pub memory_keys: Vec<String>,
    /// Conversation entry indices of the assistant responses (maps response_idx → entry_idx)
    pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
    /// Get the most important memories for a given conversation entry index.
    ///
    /// Returns `(memory_key, divergence_score)` pairs sorted by descending
    /// score. Memories whose divergence for that response is at or below
    /// 0.01 are left out, and a matrix row that is too short to contain the
    /// response counts as 0.0. Returns an empty vec when `entry_idx` is not
    /// one of the scored responses.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        // Divergences at or below this are treated as noise.
        const MIN_SCORE: f64 = 0.01;

        let Some(resp_idx) = self
            .response_entry_indices
            .iter()
            .position(|&i| i == entry_idx)
        else {
            return Vec::new();
        };

        let mut result: Vec<(&str, f64)> = self
            .memory_keys
            .iter()
            .zip(self.matrix.iter())
            .filter_map(|(key, row)| {
                let score = row.get(resp_idx).copied().unwrap_or(0.0);
                (score > MIN_SCORE).then_some((key.as_str(), score))
            })
            .collect();

        // Highest divergence first; the sort is stable, so ties keep memory order.
        result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        result
    }
}
/// Score how important each memory is to the conversation.
///
/// For each Memory entry in the context, builds a version without it
/// and checks how the model's logprobs change for assistant responses.
/// Only positive per-token drops (baseline more confident than the ablated
/// run) are summed into the divergence for a (memory, response) pair.
///
/// Progress and diagnostics are sent on `ui_tx`; send failures are ignored.
/// Returns an empty `MemoryScore` when the context has no memory entries or
/// no assistant responses.
///
/// # Errors
///
/// Propagates any error from the underlying logprobs requests.
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
    use crate::agent::ui_channel::UiMessage;

    // Identify memory entries and assistant response positions
    let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
        .filter_map(|(i, e)| match e {
            ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
            _ => None,
        })
        .collect();
    let response_indices: Vec<usize> = context.entries.iter().enumerate()
        .filter(|(_, e)| e.message().role == Role::Assistant)
        .map(|(i, _)| i)
        .collect();

    if memories.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            matrix: Vec::new(),
            memory_keys: Vec::new(),
            response_entry_indices: Vec::new(),
        });
    }

    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] scoring {} memories × {} responses",
        memories.len(), response_indices.len(),
    )));

    // Baseline: logprobs with all memories present
    let baseline = get_response_logprobs(context, &context.entries, client, ui_tx).await?;
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] baseline: {} response tokens scored",
        baseline.iter().map(|r| r.len()).sum::<usize>(),
    )));

    // Matrix columns are mapped back to entries through `response_indices`,
    // so a count mismatch means that mapping cannot be trusted. Surface it
    // instead of silently producing shifted scores.
    if baseline.len() != response_indices.len() {
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] warning: baseline yielded {} scored responses for {} assistant entries; response indices may be misaligned",
            baseline.len(), response_indices.len(),
        )));
    }

    // For each memory, drop it and measure divergence
    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(memories.len());
    let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();

    for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
        let _ = ui_tx.send(UiMessage::Activity(format!(
            "scoring {}/{}...", mem_idx + 1, memories.len(),
        )));
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] scoring memory {}/{}: {}",
            mem_idx + 1, memories.len(), key,
        )));

        // Build entries without this memory
        let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
            .filter(|(i, _)| *i != *entry_idx)
            .map(|(_, e)| e.clone())
            .collect();
        let without = get_response_logprobs(context, &filtered, client, ui_tx).await?;

        // `zip` below stops at the shorter side, so a differing response
        // count would pair up the wrong responses without any other sign.
        if without.len() != baseline.len() {
            let _ = ui_tx.send(UiMessage::Debug(format!(
                "[training] warning: {} responses without memory {} vs {} in baseline; divergences may be misaligned",
                without.len(), key, baseline.len(),
            )));
        }

        // Compute per-response divergence
        let row: Vec<f64> = baseline.iter().zip(without.iter())
            .map(|(base_lps, without_lps)| {
                // Sum of logprob drops across tokens in this response
                // Positive = memory helped (logprob was higher with it)
                base_lps.iter().zip(without_lps.iter())
                    .map(|(b, w)| b - w) // positive when baseline was more confident
                    .filter(|d| *d > 0.0) // only count where memory helped
                    .sum::<f64>()
            })
            .collect();
        matrix.push(row);
    }

    // Compute scores: row sums per memory, column sums per response
    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(matrix.iter())
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();

    let n_responses = response_indices.len();
    let mut response_scores = vec![0.0_f64; n_responses];
    for row in &matrix {
        // zip ignores any columns beyond the known responses
        for (total, &v) in response_scores.iter_mut().zip(row.iter()) {
            *total += v;
        }
    }

    // Clear the activity indicator
    let _ = ui_tx.send(UiMessage::Activity(String::new()));

    // Log summary per memory
    for (key, score) in &memory_weights {
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] {} → importance {:.1}", key, score,
        )));
    }

    // Log per-response breakdown for the most important memories
    let mut sorted_mems: Vec<(usize, &str, f64)> = memory_keys.iter().enumerate()
        .map(|(i, k)| (i, k.as_str(), memory_weights[i].1))
        .collect();
    sorted_mems.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
    for (mem_i, key, total) in sorted_mems.iter().take(5) {
        if *total <= 0.0 { continue; }
        let row = &matrix[*mem_i];
        let top_responses: Vec<String> = row.iter().enumerate()
            .filter(|(_, v)| **v > 0.1)
            .map(|(j, v)| format!("resp[{}]={:.1}", j, v))
            .collect();
        if !top_responses.is_empty() {
            let _ = ui_tx.send(UiMessage::Debug(format!(
                "[training] {} ({:.1}): {}", key, total, top_responses.join(", "),
            )));
        }
    }

    Ok(MemoryScore {
        memory_weights,
        response_scores,
        matrix,
        memory_keys,
        response_entry_indices: response_indices,
    })
}
/// Rough token estimate: ~4 chars per token.
///
/// `get_response_logprobs` multiplies the configured `scoring_chunk_tokens`
/// by this to turn the token budget into a character budget for chunking.
const CHARS_PER_TOKEN: usize = 4;
/// Collect logprobs for every assistant response token in `entries`.
///
/// The result holds one inner vec per assistant response, in conversation
/// order, each containing the logprobs of that response's tokens.
///
/// The conversation is split into chunks sized from the configured
/// `scoring_chunk_tokens` (rounded to assistant message boundaries) to avoid
/// OOM from the logprobs tensor allocation. Every chunk is sent on its own,
/// preceded by the same fixed prefix (system prompt plus the rendered context
/// message, if any).
async fn get_response_logprobs(
    context: &ContextState,
    entries: &[ConversationEntry],
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<Vec<Vec<f64>>> {
    use crate::agent::ui_channel::UiMessage;

    // Fixed prefix shared by all chunks (system prompt + personality)
    let mut prefix = vec![Message::system(&context.system_prompt)];
    let rendered = context.render_context_message();
    if !rendered.is_empty() {
        prefix.push(Message::user(rendered));
    }
    let mut prefix_chars: usize = 0;
    for m in &prefix {
        prefix_chars += m.content_text().len();
    }

    // Whatever the prefix leaves of the per-request character allowance is
    // what the conversation entries of one chunk may use.
    let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN;
    let budget = max_chunk_chars.saturating_sub(prefix_chars);
    let chunks = chunk_entries(entries, budget);
    let total_chunks = chunks.len();

    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] {} chunks, prefix={}K chars, budget={}K chars",
        total_chunks, prefix_chars / 1024, budget / 1024,
    )));

    // Report every chunk's size up front, before any request is made.
    for (n, chunk) in chunks.iter().enumerate() {
        let chunk_chars = chunk.iter()
            .map(|e| e.message().content_text().len())
            .sum::<usize>();
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] chunk {}/{}: {} entries, {}K chars",
            n + 1, total_chunks, chunk.len(), chunk_chars / 1024,
        )));
    }

    // One request per chunk; responses are concatenated in chunk order.
    let mut all_responses: Vec<Vec<f64>> = Vec::new();
    for chunk in chunks.iter() {
        let mut msgs = prefix.clone();
        for e in chunk {
            msgs.push(e.api_message().clone());
        }
        let chunk_responses = call_prompt_logprobs(&msgs, client).await?;
        all_responses.extend(chunk_responses);
    }
    Ok(all_responses)
}
/// Split entries into chunks of approximately `budget_chars` each,
/// ending at assistant message boundaries.
///
/// A chunk is closed as soon as its accumulated content length reaches
/// `budget_chars` AND the entry just added is an assistant message, so a
/// chunk may exceed the budget. Entries left over after the last cut form a
/// final chunk. Empty input yields no chunks at all, so callers do not issue
/// a request for an empty conversation.
fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<ConversationEntry>> {
    let mut chunks = Vec::new();
    let mut current = Vec::new();
    let mut current_chars: usize = 0;

    for entry in entries {
        current_chars += entry.message().content_text().len();
        current.push(entry.clone());

        // If over budget and we just added an assistant message, cut here
        if current_chars >= budget_chars && entry.message().role == Role::Assistant {
            chunks.push(std::mem::take(&mut current));
            current_chars = 0;
        }
    }

    // Trailing entries that never hit a cut point. This also covers the
    // "everything fit in one chunk" case.
    if !current.is_empty() {
        chunks.push(current);
    }
    chunks
}
/// Make a single prompt_logprobs API call and extract response logprobs.
///
/// Posts `msgs` to `{base_url}/chat/completions` with `prompt_logprobs: 1`
/// and `max_tokens: 1`, then walks the returned per-token records to collect
/// the logprobs of assistant response tokens (one inner vec per response).
///
/// # Errors
///
/// Fails on transport errors, on a non-success HTTP status (the message
/// carries the API's `error.message` when present, otherwise the start of the
/// raw body), on a body that is not valid JSON, or when the response has no
/// `prompt_logprobs` array.
async fn call_prompt_logprobs(
    msgs: &[Message],
    client: &ApiClient,
) -> anyhow::Result<Vec<Vec<f64>>> {
    let request = serde_json::json!({
        "model": client.model,
        "messages": msgs,
        "max_tokens": 1,
        "prompt_logprobs": 1,
        "stream": false,
    });
    let response = reqwest::Client::new()
        .post(format!("{}/chat/completions", client.base_url()))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", client.api_key()))
        .json(&request)
        .send()
        .await?;

    let status = response.status();
    // Read the raw text first: an error body is not guaranteed to be JSON,
    // and decoding it as JSON up front would hide the HTTP status behind a
    // parse error.
    let text = response.text().await?;

    if !status.is_success() {
        let msg = serde_json::from_str::<serde_json::Value>(&text)
            .ok()
            .and_then(|body| {
                body.get("error")
                    .and_then(|e| e.get("message"))
                    .and_then(|m| m.as_str())
                    .map(str::to_owned)
            })
            .unwrap_or_else(|| {
                let snippet: String = text.trim().chars().take(200).collect();
                if snippet.is_empty() { "unknown error".to_owned() } else { snippet }
            });
        anyhow::bail!("HTTP {} from logprobs API: {}", status, msg);
    }

    let body: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| anyhow::anyhow!("invalid JSON from logprobs API: {}", e))?;
    let prompt_logprobs = body.get("prompt_logprobs")
        .and_then(|v| v.as_array())
        .ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;

    // Positions that are not objects, or that hold no usable record, are skipped.
    let tokens = prompt_logprobs.iter().filter_map(|entry| {
        let info = select_prompt_token_info(entry.as_object()?)?;
        let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
        let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
        Some((token, logprob))
    });
    Ok(collect_assistant_logprobs(tokens))
}

/// Pick the record describing the token that is actually in the prompt.
///
/// With `prompt_logprobs: 1` a position can carry two records: the real
/// prompt token and the model's top-1 candidate (rank 1) when they differ.
/// Map iteration order is not a reliable way to tell them apart, so the
/// record with the highest `rank` is chosen. When no ranks are present the
/// first object record wins.
/// NOTE(review): assumes a vLLM-style schema where each record carries
/// `rank`; confirm against the server in use.
fn select_prompt_token_info(
    candidates: &serde_json::Map<String, serde_json::Value>,
) -> Option<&serde_json::Map<String, serde_json::Value>> {
    let mut best: Option<(u64, &serde_json::Map<String, serde_json::Value>)> = None;
    for info in candidates.values().filter_map(|v| v.as_object()) {
        let rank = info.get("rank").and_then(|r| r.as_u64()).unwrap_or(0);
        match best {
            // Keep the earlier record on ties
            Some((best_rank, _)) if best_rank >= rank => {}
            _ => best = Some((rank, info)),
        }
    }
    best.map(|(_, info)| info)
}

/// Walk `(decoded_token, logprob)` pairs and gather assistant response logprobs.
///
/// Pattern: <|im_start|> assistant \n [<think>...</think>] response <|im_end|>
///
/// `assistant` only opens a response when it directly follows `<|im_start|>`,
/// so the same word inside user, system, or memory text does not start a
/// phantom response. Tokens between `<think>` and `</think>` are excluded,
/// leading newlines of a response are skipped, and responses with no scored
/// tokens are dropped.
fn collect_assistant_logprobs<'a>(
    tokens: impl Iterator<Item = (&'a str, f64)>,
) -> Vec<Vec<f64>> {
    let mut responses: Vec<Vec<f64>> = Vec::new();
    let mut current_response: Vec<f64> = Vec::new();
    let mut expect_role = false;
    let mut in_assistant = false;
    let mut in_think = false;

    for (token, logprob) in tokens {
        // The role slot is only the single token right after <|im_start|>.
        let at_role_slot = std::mem::replace(&mut expect_role, false);
        match token {
            "<|im_start|>" => {
                in_assistant = false;
                in_think = false;
                expect_role = true;
            }
            "assistant" if at_role_slot => {
                in_assistant = true;
                in_think = false;
                current_response.clear();
            }
            "<think>" if in_assistant => {
                in_think = true;
            }
            "</think>" if in_assistant => {
                in_think = false;
            }
            "<|im_end|>" if in_assistant => {
                if !current_response.is_empty() {
                    responses.push(std::mem::take(&mut current_response));
                }
                in_assistant = false;
            }
            "\n" if in_assistant && current_response.is_empty() => {
                // Skip the newline right after "assistant"
            }
            _ if in_assistant && !in_think => {
                current_response.push(logprob);
            }
            _ => {}
        }
    }
    responses
}