// consciousness/src/thought/training.rs

// training.rs — Memory importance scoring via /v1/score
//
// Drops each memory from the context one at a time, calls the vLLM
// /v1/score endpoint to get logprobs for assistant responses.
// Produces a divergence matrix: memories × responses.
//
// Row sums = memory importance (for graph weight updates)
// Column sums = response memory-dependence (training candidates)
use crate::agent::api::ApiClient;
use crate::agent::types::*;
use crate::agent::ui_channel::{UiMessage, UiSender};
/// Result of scoring one conversation's memory usage.
///
/// Produced by `score_memories` via leave-one-out scoring: each memory entry
/// is dropped in turn and the per-response logprob divergence is recorded,
/// yielding a memories × responses matrix together with its row and column sums.
pub struct MemoryScore {
    /// memory_key → importance score (sum of divergence across all responses).
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories).
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]; values are clamped >= 0.
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored; parallel to the rows of `matrix`.
    pub memory_keys: Vec<String>,
    /// Conversation entry indices of the assistant responses; parallel to the matrix columns.
    pub response_entry_indices: Vec<usize>,
}
impl MemoryScore {
    /// Get the most important memories for a given conversation entry index.
    ///
    /// Looks up which matrix column corresponds to `entry_idx`, keeps every
    /// memory whose divergence there exceeds the 0.01 noise floor, and returns
    /// the (key, score) pairs sorted highest-score-first. Empty if `entry_idx`
    /// is not a scored assistant response.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        let resp_idx = match self.response_entry_indices.iter().position(|&i| i == entry_idx) {
            Some(idx) => idx,
            None => return Vec::new(),
        };
        let mut ranked: Vec<(&str, f64)> = Vec::new();
        for (key, row) in self.memory_keys.iter().zip(self.matrix.iter()) {
            let score = row.get(resp_idx).copied().unwrap_or(0.0);
            // Skip near-zero divergences — they are noise, not signal.
            if score > 0.01 {
                ranked.push((key.as_str(), score));
            }
        }
        // Descending by score; divergences are clamped non-negative upstream,
        // so the partial_cmp fallback never actually fires.
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked
    }
}
/// Score how important each memory is to the conversation.
///
/// Leave-one-out scoring: one baseline /v1/score call with the full context,
/// then one call per memory with that memory's entry removed. The drop in
/// per-message logprob (clamped at zero) fills one matrix row per memory.
/// Row sums become memory importance weights; column sums become per-response
/// memory-dependence scores.
///
/// Returns an empty `MemoryScore` when the context holds no memories or no
/// assistant responses; otherwise errors from the /v1/score calls propagate.
pub async fn score_memories(
    context: &ContextState,
    client: &ApiClient,
    ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
    // (entry index, key) of every Memory entry — these become matrix rows.
    let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
        .filter_map(|(i, e)| match e {
            ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
            _ => None,
        })
        .collect();
    // Entry indices of assistant responses — these become matrix columns.
    let response_indices: Vec<usize> = context.entries.iter().enumerate()
        .filter(|(_, e)| e.message().role == Role::Assistant)
        .map(|(i, _)| i)
        .collect();
    // Nothing to score — return an empty result without touching the API.
    if memories.is_empty() || response_indices.is_empty() {
        return Ok(MemoryScore {
            memory_weights: Vec::new(),
            response_scores: Vec::new(),
            matrix: Vec::new(),
            memory_keys: Vec::new(),
            response_entry_indices: Vec::new(),
        });
    }
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] scoring {} memories × {} responses via /v1/score",
        memories.len(), response_indices.len(),
    )));
    // FIX vs original: `.build().unwrap_or_default()` silently discarded the
    // builder error and fell back to `Client::default()` — losing the pool
    // setting, and `Default` panics if TLS setup fails. Propagate instead.
    let http = reqwest::Client::builder()
        .pool_max_idle_per_host(2)
        .build()
        .map_err(|e| anyhow::anyhow!("failed to build HTTP client for scoring: {e}"))?;
    // Build the messages array from context.
    let all_messages = build_messages(context);
    let roles: Vec<&str> = all_messages.iter()
        .map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or("?"))
        .collect();
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] sending {} messages, roles: {:?}",
        all_messages.len(), roles,
    )));
    // Baseline: score with all memories present.
    let baseline = call_score(&http, client, &all_messages, ui_tx).await?;
    let _ = ui_tx.send(UiMessage::Debug(format!(
        "[training] baseline: {} messages scored",
        baseline.len(),
    )));
    // For each memory, drop it and measure divergence against the baseline.
    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(memories.len());
    let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
    for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
        let _ = ui_tx.send(UiMessage::Activity(format!(
            "scoring {}/{}...", mem_idx + 1, memories.len(),
        )));
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] scoring memory {}/{}: {}",
            mem_idx + 1, memories.len(), key,
        )));
        // Same context minus this one memory entry.
        let filtered_messages = build_messages_without(context, *entry_idx);
        let without = call_score(&http, client, &filtered_messages, ui_tx).await?;
        // Match scores by message index; a message missing from the "without"
        // run defaults to the baseline value, i.e. zero divergence.
        let mut row = Vec::with_capacity(baseline.len());
        for base_score in &baseline {
            let base_lp = base_score.total_logprob;
            let without_lp = without.iter()
                .find(|s| s.message_index == base_score.message_index)
                .map(|s| s.total_logprob)
                .unwrap_or(base_lp);
            // Positive = memory helped (logprob was higher with it present).
            // Memories that *hurt* are clamped to zero, not penalized.
            row.push((base_lp - without_lp).max(0.0));
        }
        matrix.push(row);
    }
    // Clear the activity indicator.
    let _ = ui_tx.send(UiMessage::Activity(String::new()));
    // Row sums → per-memory importance.
    let memory_weights: Vec<(String, f64)> = memory_keys.iter()
        .zip(matrix.iter())
        .map(|(key, row)| (key.clone(), row.iter().sum()))
        .collect();
    // Column sums → per-response memory dependence.
    // NOTE(review): rows are as long as the baseline score list, which is
    // assumed to align one-to-one with `response_indices`; extra columns are
    // silently dropped by the `j < n_responses` guard. Confirm /v1/score
    // returns exactly one score per assistant response.
    let n_responses = response_indices.len();
    let mut response_scores = vec![0.0; n_responses];
    for row in &matrix {
        for (j, &v) in row.iter().enumerate() {
            if j < n_responses {
                response_scores[j] += v;
            }
        }
    }
    // Log a per-memory importance summary.
    for (key, score) in &memory_weights {
        let _ = ui_tx.send(UiMessage::Debug(format!(
            "[training] {} → importance {:.1}", key, score,
        )));
    }
    Ok(MemoryScore {
        memory_weights,
        response_scores,
        matrix,
        memory_keys,
        response_entry_indices: response_indices,
    })
}
/// One per-message score entry from the /v1/score endpoint.
#[derive(serde::Deserialize)]
struct ScoreMessageResult {
    // Index of the message within the request's `messages` array; used to
    // match baseline and leave-one-out runs in score_memories.
    message_index: usize,
    // Sum of token logprobs for the message; a higher value with a memory
    // present means the memory helped that response.
    total_logprob: f64,
}
/// Top-level /v1/score response body: `{"scores": [...]}`.
#[derive(serde::Deserialize)]
struct ScoreApiResponse {
    scores: Vec<ScoreMessageResult>,
}
/// Build the messages array for the /v1/score endpoint from ContextState.
///
/// Layout: system prompt first, then the rendered context message (if any)
/// as a user turn, then every conversation entry in order.
fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
    let mut messages = vec![
        serde_json::json!({"role": "system", "content": &context.system_prompt}),
    ];
    let rendered = context.render_context_message();
    if !rendered.is_empty() {
        messages.push(serde_json::json!({"role": "user", "content": rendered}));
    }
    messages.extend(context.entries.iter().map(|entry| {
        let msg = entry.api_message();
        serde_json::json!({
            "role": msg.role_str(),
            "content": msg.content_text(),
        })
    }));
    messages
}
/// Build messages with one entry removed.
///
/// Identical to `build_messages` except the conversation entry at
/// `skip_idx` is omitted — used for leave-one-out memory scoring.
fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
    let mut messages = vec![
        serde_json::json!({"role": "system", "content": &context.system_prompt}),
    ];
    let rendered = context.render_context_message();
    if !rendered.is_empty() {
        messages.push(serde_json::json!({"role": "user", "content": rendered}));
    }
    let kept = context.entries.iter().enumerate()
        .filter(|(i, _)| *i != skip_idx)
        .map(|(_, entry)| {
            let msg = entry.api_message();
            serde_json::json!({
                "role": msg.role_str(),
                "content": msg.content_text(),
            })
        });
    messages.extend(kept);
    messages
}
/// Call the /v1/score endpoint and return per-message logprobs.
async fn call_score(
http: &reqwest::Client,
client: &ApiClient,
messages: &[serde_json::Value],
ui_tx: &UiSender,
) -> anyhow::Result<Vec<ScoreMessageResult>> {
let request = serde_json::json!({
"model": client.model,
"messages": messages,
"logprobs": 1,
});
let response = http
.post(format!("{}/score", client.base_url()))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", client.api_key()))
.json(&request)
.send()
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let msg = body.get("error")
.and_then(|e| e.as_str())
.unwrap_or("unknown error");
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] score API error: {}", msg,
)));
anyhow::bail!("score API error: {}", msg);
}
let result: ScoreApiResponse = serde_json::from_value(body)?;
Ok(result.scores)
}