consciousness/src/thought/training.rs

// training.rs — Memory importance scoring via prompt logprobs
//
// Drops each memory from the context one at a time and re-runs
// prompt_logprobs to see how the model's confidence in its responses
// changes. The result is a divergence matrix: memories × responses.
//
// Row sums = memory importance (for graph weight updates)
// Column sums = response memory-dependence (training candidates)
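//
// Worked example (2 memories × 2 responses):
//   matrix = [[1.2, 0.0],
//             [0.3, 2.1]]
//   row sums    → importance:        mem0 = 1.2, mem1 = 2.4
//   column sums → memory-dependence: resp0 = 1.5, resp1 = 2.1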

use crate::agent::api::ApiClient;
use crate::agent::types::*;
use crate::agent::ui_channel::UiSender;

/// Result of scoring one conversation's memory usage.
pub struct MemoryScore {
/// memory_key → importance score (sum of divergence across all responses)
pub memory_weights: Vec<(String, f64)>,
/// response_index → memory-dependence score (sum of divergence across all memories)
pub response_scores: Vec<f64>,
/// Full matrix: divergence[memory_idx][response_idx]
pub matrix: Vec<Vec<f64>>,
/// Keys of memories that were scored
pub memory_keys: Vec<String>,
/// Conversation entry indices of the assistant responses (maps response_idx → entry_idx)
pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
/// Get the most important memories for a given conversation entry index.
/// Returns (memory_key, divergence_score) sorted by importance.
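/// Sketch of expected usage (illustrative, not a doctest; `score` is a
/// hypothetical `MemoryScore` returned by `score_memories`):
/// ```ignore
/// for (key, div) in score.important_memories_for_entry(7) {
///     println!("{key} contributed {div:.2}");
/// }
/// ```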
pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
else { return Vec::new() };
let mut result: Vec<(&str, f64)> = self.memory_keys.iter()
.zip(self.matrix.iter())
.filter_map(|(key, row)| {
let score = row.get(resp_idx).copied().unwrap_or(0.0);
if score > 0.01 { Some((key.as_str(), score)) } else { None }
})
.collect();
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
result
}
}

/// Score how important each memory is to the conversation.
///
/// For each Memory entry in the context, builds a version without it
/// and checks how the model's logprobs change for assistant responses.
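/// Sketch of a call site (illustrative, not a doctest; the graph update is a
/// hypothetical consumer of the row sums):
/// ```ignore
/// let score = score_memories(&context, &client, &ui_tx).await?;
/// for (key, weight) in &score.memory_weights {
///     graph.adjust_weight(key, *weight); // hypothetical graph API
/// }
/// ```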
pub async fn score_memories(
context: &ContextState,
client: &ApiClient,
ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
use crate::agent::ui_channel::UiMessage;
// Identify memory entries and assistant response positions
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
.filter_map(|(i, e)| match e {
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
_ => None,
})
.collect();
let response_indices: Vec<usize> = context.entries.iter().enumerate()
.filter(|(_, e)| e.message().role == Role::Assistant)
.map(|(i, _)| i)
.collect();
if memories.is_empty() || response_indices.is_empty() {
return Ok(MemoryScore {
memory_weights: Vec::new(),
response_scores: Vec::new(),
matrix: Vec::new(),
memory_keys: Vec::new(),
response_entry_indices: Vec::new(),
});
}
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] scoring {} memories × {} responses",
memories.len(), response_indices.len(),
)));
// Shared HTTP client for connection reuse across all scoring calls
let http = reqwest::Client::builder()
.pool_max_idle_per_host(2)
.build()
.unwrap_or_default();
// Baseline: logprobs with all memories present
let baseline = get_response_logprobs(context, &context.entries, client, &http, ui_tx).await?;
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] baseline: {} response tokens scored",
baseline.iter().map(|r| r.len()).sum::<usize>(),
)));
// For each memory, drop it and measure divergence
let mut matrix: Vec<Vec<f64>> = Vec::new();
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
let _ = ui_tx.send(UiMessage::Activity(format!(
"scoring {}/{}...", mem_idx + 1, memories.len(),
)));
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] scoring memory {}/{}: {}",
mem_idx + 1, memories.len(), key,
)));
// Build entries without this memory
let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
.filter(|(i, _)| *i != *entry_idx)
.map(|(_, e)| e.clone())
.collect();
let without = get_response_logprobs(context, &filtered, client, &http, ui_tx).await?;
// Compute per-response divergence
let mut row = Vec::new();
for (base_lps, without_lps) in baseline.iter().zip(without.iter()) {
// Sum of logprob drops across tokens in this response
// Positive = memory helped (logprob was higher with it)
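// e.g. baseline [-0.2, -1.0] vs. without [-0.9, -0.8]:
//      deltas [0.7, -0.2] → keep 0.7 → divergence 0.7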
let divergence: f64 = base_lps.iter().zip(without_lps.iter())
.map(|(b, w)| b - w) // positive when baseline was more confident
.filter(|d| *d > 0.0) // only count where memory helped
.sum();
row.push(divergence);
}
matrix.push(row);
}
// Compute scores
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
.zip(matrix.iter())
.map(|(key, row)| (key.clone(), row.iter().sum()))
.collect();
let n_responses = response_indices.len();
let mut response_scores = vec![0.0; n_responses];
for row in &matrix {
for (j, &v) in row.iter().enumerate() {
if j < n_responses {
response_scores[j] += v;
}
}
}
let _ = ui_tx.send(UiMessage::Activity(String::new()));
// Log summary per memory
for (key, score) in &memory_weights {
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {} → importance {:.1}", key, score,
)));
}
// Log per-response breakdown for the most important memories
let mut sorted_mems: Vec<(usize, &str, f64)> = memory_keys.iter().enumerate()
.map(|(i, k)| (i, k.as_str(), memory_weights[i].1))
.collect();
sorted_mems.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
for (mem_i, key, total) in sorted_mems.iter().take(5) {
if *total <= 0.0 { continue; }
let row = &matrix[*mem_i];
let top_responses: Vec<String> = row.iter().enumerate()
.filter(|(_, v)| **v > 0.1)
.map(|(j, v)| format!("resp[{}]={:.1}", j, v))
.collect();
if !top_responses.is_empty() {
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {} ({:.1}): {}", key, total, top_responses.join(", "),
)));
}
}
Ok(MemoryScore {
memory_weights,
response_scores,
matrix,
memory_keys,
response_entry_indices: response_indices,
})
}

/// Rough token estimate: ~4 chars per token.
const CHARS_PER_TOKEN: usize = 4;

/// Get logprobs for all assistant response tokens in a conversation.
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
/// containing logprobs for each token in that response.
///
/// Chunks the conversation into ~50K token segments (rounded to
/// assistant message boundaries) to avoid OOM from the logprobs
/// tensor allocation.
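/// e.g. with scoring_chunk_tokens = 50_000, the chunk budget is roughly
/// 50_000 × 4 = 200K chars, minus the rendered prefix.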
async fn get_response_logprobs(
context: &ContextState,
entries: &[ConversationEntry],
client: &ApiClient,
http: &reqwest::Client,
ui_tx: &UiSender,
) -> anyhow::Result<Vec<Vec<f64>>> {
// Build the fixed prefix (system prompt + personality)
let mut prefix = Vec::new();
prefix.push(Message::system(&context.system_prompt));
let ctx = context.render_context_message();
if !ctx.is_empty() {
prefix.push(Message::user(ctx));
}
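// The prefix is resent with every chunk so all chunks are scored under the
// same system prompt and rendered context.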
let prefix_chars: usize = prefix.iter()
.map(|m| m.content_text().len())
.sum();
// Split entries into chunks that fit within the token budget,
// each ending at an assistant message boundary.
let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN;
let budget = max_chunk_chars.saturating_sub(prefix_chars);
let chunks = chunk_entries(entries, budget);
let mut all_responses: Vec<Vec<f64>> = Vec::new();
use crate::agent::ui_channel::UiMessage;
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {} chunks, prefix={}K chars, budget={}K chars",
chunks.len(), prefix_chars / 1024, budget / 1024,
)));
for (chunk_idx, chunk) in chunks.iter().enumerate() {
let chunk_chars: usize = chunk.iter()
.map(|e| e.message().content_text().len())
.sum();
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] chunk {}/{}: {} entries, {}K chars",
chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024,
)));
}
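// Chunks are scored in order, so `all_responses` stays aligned with the
// assistant entries' original order across chunk boundaries.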
for chunk in &chunks {
let mut msgs = prefix.clone();
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
let result = call_prompt_logprobs(&msgs, client, http).await?;
all_responses.extend(result);
}
Ok(all_responses)
}

/// Split entries into chunks of approximately `budget_chars` each,
/// ending at assistant message boundaries.
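/// e.g. entry sizes [10K user, 30K assistant, 20K user] with a 35K budget
/// cut after the assistant reply: [[10K, 30K], [20K]].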
fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<ConversationEntry>> {
let mut chunks = Vec::new();
let mut current = Vec::new();
let mut current_chars = 0;
for entry in entries {
let entry_chars = entry.message().content_text().len();
current_chars += entry_chars;
current.push(entry.clone());
// If over budget and we just added an assistant message, cut here
if current_chars >= budget_chars && entry.message().role == Role::Assistant {
chunks.push(std::mem::take(&mut current));
current_chars = 0;
}
}
if !current.is_empty() {
chunks.push(current);
}
// If `entries` was empty, no chunks are produced; the caller then gets an
// empty result instead of making a pointless prefix-only API call.
chunks
}

/// Make a single prompt_logprobs API call and extract response logprobs.
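/// The response is parsed for a top-level `prompt_logprobs` array (as served
/// by vLLM's OpenAI-compatible endpoint): one map per prompt token, `null`
/// for the first. Shape assumed by the parser below (token ids illustrative):
/// ```ignore
/// "prompt_logprobs": [ null,
///   { "151644": { "logprob": -0.01, "rank": 1, "decoded_token": "<|im_start|>" } },
///   ...
/// ]
/// ```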
async fn call_prompt_logprobs(
msgs: &[Message],
client: &ApiClient,
http: &reqwest::Client,
) -> anyhow::Result<Vec<Vec<f64>>> {
let request = serde_json::json!({
"model": client.model,
"messages": msgs,
"max_tokens": 1,
"prompt_logprobs": 1,
"stream": false,
});
let response = http
.post(format!("{}/chat/completions", client.base_url()))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", client.api_key()))
.json(&request)
.send()
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = body.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("unknown error");
anyhow::bail!("HTTP {} from logprobs API: {}", status, msg);
}
let prompt_logprobs = body.get("prompt_logprobs")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;
// Find assistant response boundaries using special tokens
// Pattern: <|im_start|> assistant \n [<think>...</think>] response <|im_end|>
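// e.g. <|im_start|> assistant \n <think> … </think> Hello <|im_end|>
//      → only the logprobs for "Hello"'s tokens are collected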
let mut responses: Vec<Vec<f64>> = Vec::new();
let mut in_assistant = false;
let mut in_think = false;
// True only for the token immediately after <|im_start|>, so a literal
// "assistant" inside message text cannot start a response span.
let mut after_start = false;
let mut current_response: Vec<f64> = Vec::new();
for entry in prompt_logprobs {
let Some(obj) = entry.as_object() else { continue };
// Each map entry is token_id → {logprob, rank, decoded_token}. With
// prompt_logprobs=1 the map can hold both the actual prompt token and the
// top-1 alternative; serde_json sorts keys, so the first value is not
// reliably the actual token. Pick the entry with the largest rank instead
// (the actual token whenever the two differ).
let Some(info) = obj.values()
.filter_map(|v| v.as_object())
.max_by_key(|o| o.get("rank").and_then(|r| r.as_u64()).unwrap_or(0))
else { continue };
let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
match token {
"<|im_start|>" => {
in_assistant = false;
in_think = false;
}
"assistant" if !in_assistant => {
in_assistant = true;
in_think = false;
current_response.clear();
}
"<think>" if in_assistant => {
in_think = true;
}
"</think>" if in_assistant => {
in_think = false;
}
"<|im_end|>" if in_assistant => {
if !current_response.is_empty() {
responses.push(std::mem::take(&mut current_response));
}
in_assistant = false;
}
"\n" if in_assistant && current_response.is_empty() => {
// Skip the newline right after "assistant"
}
_ if in_assistant && !in_think => {
current_response.push(logprob);
}
_ => {}
}
// Only a token directly following <|im_start|> can be a role marker.
after_start = token == "<|im_start|>";
}
Ok(responses)
}