consciousness/src/agent/training.rs
Kent Overstreet 9bebbcb635 Move API code from user/ to agent/
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
2026-04-04 00:34:48 -04:00

295 lines
10 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// training.rs — Memory importance scoring via /v1/score
//
// Drops each memory from the context one at a time, calls the vLLM
// /v1/score endpoint to get logprobs for assistant responses.
// Produces a divergence matrix: memories × responses.
//
// Row sums = memory importance (for graph weight updates)
// Column sums = response memory-dependence (training candidates)
use std::time::Instant;
use super::api::ApiClient;
use crate::agent::api::types::*;
use crate::agent::context::{ConversationEntry, ContextState};
use crate::user::ui_channel::{UiMessage, UiSender};
/// Upper bound on how long a single /v1/score request may take.
///
/// Applied to the HTTP client used for scoring, so it covers the baseline
/// request and every per-memory request individually (two minutes each).
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
/// Result of scoring one conversation's memory usage.
///
/// `matrix[m][r]` is the divergence for memory `m` and scored response `r`.
/// Rows are parallel to `memory_keys`. Columns are matched to
/// `response_entry_indices` by position.
#[derive(Debug, Clone, Default)]
pub struct MemoryScore {
    /// memory_key → importance score (sum of divergence across all responses)
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories)
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored, one per matrix row
    pub memory_keys: Vec<String>,
    /// Conversation entry indices of the assistant responses
    pub response_entry_indices: Vec<usize>,
}

impl MemoryScore {
    /// Divergences at or below this value are not reported by
    /// [`MemoryScore::important_memories_for_entry`].
    const MIN_DIVERGENCE: f64 = 0.01;

    /// Get the most important memories for a given conversation entry index.
    ///
    /// Returns `(memory_key, divergence)` pairs whose divergence exceeds
    /// [`Self::MIN_DIVERGENCE`], sorted from most to least important. Ties
    /// keep their `memory_keys` order. A row that is too short to contain
    /// the response column counts as 0.
    ///
    /// Returns an empty vector if `entry_idx` is not one of
    /// `response_entry_indices`.
    pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
        let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
        else {
            return Vec::new();
        };
        let mut result: Vec<(&str, f64)> = self
            .memory_keys
            .iter()
            .zip(&self.matrix)
            .filter_map(|(key, row)| {
                let score = row.get(resp_idx).copied().unwrap_or(0.0);
                (score > Self::MIN_DIVERGENCE).then_some((key.as_str(), score))
            })
            .collect();
        // NaN never passes the `>` filter above, so total_cmp orders these
        // exactly as partial_cmp would, without needing a fallback.
        result.sort_by(|a, b| b.1.total_cmp(&a.1));
        result
    }
}
/// Score how important each memory is to the conversation.
pub async fn score_memories(
context: &ContextState,
client: &ApiClient,
ui_tx: &UiSender,
) -> anyhow::Result<MemoryScore> {
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] in score_memories"
)));
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
.filter_map(|(i, e)| match e {
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
_ => None,
})
.collect();
let response_indices: Vec<usize> = context.entries.iter().enumerate()
.filter(|(_, e)| e.message().role == Role::Assistant)
.map(|(i, _)| i)
.collect();
if memories.is_empty() || response_indices.is_empty() {
let _ = ui_tx.send(UiMessage::Debug(
"[training] nothing to score (no memories or no responses)".into()
));
return Ok(MemoryScore {
memory_weights: Vec::new(),
response_scores: Vec::new(),
matrix: Vec::new(),
memory_keys: Vec::new(),
response_entry_indices: Vec::new(),
});
}
let _ = ui_tx.send(UiMessage::Info(format!(
"[scoring {} memories × {} responses]",
memories.len(), response_indices.len(),
)));
let http = reqwest::Client::builder()
.timeout(SCORE_TIMEOUT)
.pool_max_idle_per_host(2)
.build()
.unwrap_or_default();
let all_messages = build_messages(context);
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {} messages in context",
all_messages.len(),
)));
// Baseline: score with all memories present
let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into()));
let payload_size = serde_json::to_string(&all_messages)
.map(|s| s.len()).unwrap_or(0);
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] payload size: {}KB",
payload_size / 1024,
)));
let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
let start = Instant::now();
let baseline = call_score(&http, client, &all_messages).await?;
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] baseline: {} responses scored in {:.1}s",
baseline.len(), start.elapsed().as_secs_f64(),
)));
// For each memory, drop it and measure divergence
let mut matrix: Vec<Vec<f64>> = Vec::new();
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
let total = memories.len();
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
let _ = ui_tx.send(UiMessage::Activity(format!(
"scoring {}/{}: {}...", mem_idx + 1, total, key,
)));
let start = Instant::now();
let filtered_messages = build_messages_without(context, *entry_idx);
let without = call_score(&http, client, &filtered_messages).await;
match without {
Ok(without) => {
let elapsed = start.elapsed().as_secs_f64();
// Match scores by position (nth scored response),
// not message_index — indices shift when a memory
// is removed from the conversation.
let mut row = Vec::new();
for (i, base_score) in baseline.iter().enumerate() {
let base_lp = base_score.total_logprob;
let without_lp = without.get(i)
.map(|s| s.total_logprob)
.unwrap_or(base_lp);
let divergence = (base_lp - without_lp).max(0.0);
row.push(divergence);
}
let importance: f64 = row.iter().sum();
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {}/{} {}{:.1} ({:.1}s)",
mem_idx + 1, total, key, importance, elapsed,
)));
matrix.push(row);
}
Err(e) => {
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {}/{} {} FAILED: {:#}",
mem_idx + 1, total, key, e,
)));
// Push zero row so matrix stays aligned
matrix.push(vec![0.0; baseline.len()]);
}
}
}
let _ = ui_tx.send(UiMessage::Activity(String::new()));
// Compute scores
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
.zip(matrix.iter())
.map(|(key, row)| (key.clone(), row.iter().sum()))
.collect();
let n_responses = response_indices.len();
let mut response_scores = vec![0.0; n_responses];
for row in &matrix {
for (j, &v) in row.iter().enumerate() {
if j < n_responses {
response_scores[j] += v;
}
}
}
let _ = ui_tx.send(UiMessage::Info(format!(
"[scoring complete: {} memories scored]",
memory_keys.len(),
)));
Ok(MemoryScore {
memory_weights,
response_scores,
matrix,
memory_keys,
response_entry_indices: response_indices,
})
}
/// Score response from the /v1/score endpoint.
///
/// One element of the `scores` array in a /v1/score reply. Field names
/// must match the JSON keys the server sends.
#[derive(serde::Deserialize)]
struct ScoreMessageResult {
/// Index of the scored message in the request's `messages` array
/// (presumably — confirm against the server). It is deserialized but never
/// read here, hence the `dead_code` allow: scores are matched to responses
/// by position, because message indices shift when a memory is removed.
#[allow(dead_code)]
message_index: usize,
/// Total log-probability of the message. Presumably summed over its
/// tokens — confirm against the /v1/score implementation.
total_logprob: f64,
}
/// Body of a successful /v1/score reply.
///
/// Only `scores` is read. Any other top-level fields are ignored, because
/// serde skips unknown fields unless `deny_unknown_fields` is set.
#[derive(serde::Deserialize)]
struct ScoreApiResponse {
/// One entry per scored message, in the order the server returns them.
scores: Vec<ScoreMessageResult>,
}
/// Build the messages array for the /v1/score endpoint from ContextState.
///
/// The array is built in this order:
/// - the system prompt;
/// - the rendered context message as a `user` message, only when it is
///   non-empty;
/// - one `{role, content}` message per conversation entry, in order.
fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
    build_messages_filtered(context, None)
}

/// Build messages with one entry removed.
///
/// Identical to `build_messages`, except that `context.entries[skip_idx]` is
/// omitted. An out-of-range `skip_idx` omits nothing.
fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
    build_messages_filtered(context, Some(skip_idx))
}

/// Shared implementation of `build_messages` and `build_messages_without`.
///
/// Emits the score payload, leaving out the conversation entry at index
/// `skip` when one is given.
fn build_messages_filtered(context: &ContextState, skip: Option<usize>) -> Vec<serde_json::Value> {
    let mut msgs = Vec::new();
    msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
    let ctx = context.render_context_message();
    if !ctx.is_empty() {
        msgs.push(serde_json::json!({"role": "user", "content": ctx}));
    }
    msgs.extend(
        context
            .entries
            .iter()
            .enumerate()
            .filter(|&(i, _)| Some(i) != skip)
            .map(|(_, entry)| {
                let m = entry.api_message();
                serde_json::json!({
                    "role": m.role_str(),
                    "content": m.content_text(),
                })
            }),
    );
    msgs
}
/// Call the /v1/score endpoint and return per-message logprobs.
///
/// POSTs `messages` to `{base_url}/score`, using the model, base URL and API
/// key from `client`, and returns the `scores` array of the reply.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the request fails or times out;
/// - the response body cannot be read;
/// - the server answers with a non-success status;
/// - the body carries an `error` field;
/// - the body does not parse as a score response.
async fn call_score(
    http: &reqwest::Client,
    client: &ApiClient,
    messages: &[serde_json::Value],
) -> anyhow::Result<Vec<ScoreMessageResult>> {
    let request = serde_json::json!({
        "model": client.model,
        "messages": messages,
        "logprobs": 1,
    });
    let response = http
        .post(format!("{}/score", client.base_url()))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", client.api_key()))
        .json(&request)
        .send()
        .await
        .map_err(|e| {
            if e.is_timeout() {
                anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
            } else {
                anyhow::anyhow!("score request failed: {}", e)
            }
        })?;
    let status = response.status();
    // Read the raw body before decoding it. Error replies (e.g. from a proxy,
    // or a crashed server) are often not JSON, and decoding them directly
    // would produce a parse error that hides the HTTP status.
    let text = response.text().await.map_err(|e| {
        anyhow::anyhow!("score API HTTP {}: failed to read response body: {}", status, e)
    })?;
    if !status.is_success() {
        let parsed: Option<serde_json::Value> = serde_json::from_str(&text).ok();
        let msg = parsed
            .as_ref()
            .and_then(score_error_message)
            .map(str::to_owned)
            .unwrap_or_else(|| body_preview(&text));
        anyhow::bail!("score API HTTP {}: {}", status, msg);
    }
    let body: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    // Check for error in body (score endpoint returns dict on error)
    if let Some(err) = score_error_message(&body) {
        anyhow::bail!("score API error: {}", err);
    }
    let result: ScoreApiResponse = serde_json::from_value(body)
        .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
    Ok(result.scores)
}

/// Extract an error message from a score API body, if there is one.
///
/// Accepts both a plain string (`{"error": "..."}`) and an object with a
/// message (`{"error": {"message": "..."}}`).
fn score_error_message(body: &serde_json::Value) -> Option<&str> {
    match body.get("error")? {
        serde_json::Value::String(s) => Some(s.as_str()),
        other => other.get("message").and_then(serde_json::Value::as_str),
    }
}

/// Short preview of a raw response body, for error messages.
///
/// Returns "unknown error" for an empty or whitespace-only body. Otherwise
/// returns the trimmed body, cut to at most 200 characters (on a char
/// boundary) and marked with "…" when cut.
fn body_preview(text: &str) -> String {
    const MAX_CHARS: usize = 200;
    let trimmed = text.trim();
    if trimmed.is_empty() {
        return "unknown error".to_string();
    }
    let mut preview: String = trimmed.chars().take(MAX_CHARS).collect();
    if trimmed.chars().nth(MAX_CHARS).is_some() {
        preview.push('…');
    }
    preview
}