diff --git a/src/bin/poc-agent.rs b/src/bin/poc-agent.rs index 745f21b..deb96f3 100644 --- a/src/bin/poc-agent.rs +++ b/src/bin/poc-agent.rs @@ -424,39 +424,39 @@ impl Session { Command::Handled } "/score" => { - { - let agent = self.agent.lock().await; + // Snapshot context+client while we have the lock, + // so the scoring task doesn't need to wait for turns. + let (context, client) = { + let mut agent = self.agent.lock().await; if agent.scoring_in_flight { let _ = self.ui_tx.send(UiMessage::Info( "(scoring already in progress)".into() )); return Command::Handled; } - } - self.agent.lock().await.scoring_in_flight = true; + agent.scoring_in_flight = true; + (agent.context.clone(), agent.client_clone()) + }; let agent = self.agent.clone(); let ui_tx = self.ui_tx.clone(); + let _ = self.ui_tx.send(UiMessage::Debug("[score] task spawning".into())); tokio::spawn(async move { - let (context, client) = { - let agent = agent.lock().await; - (agent.context.clone(), agent.client_clone()) - }; + let _ = ui_tx.send(UiMessage::Debug("[score] task started, calling score_memories".into())); let result = poc_memory::thought::training::score_memories( &context, &client, &ui_tx, ).await; + let _ = ui_tx.send(UiMessage::Debug("[score] score_memories returned, acquiring lock".into())); + // Store results — brief lock, just setting fields let mut agent = agent.lock().await; + let _ = ui_tx.send(UiMessage::Debug("[score] lock acquired, storing results".into())); agent.scoring_in_flight = false; match result { Ok(scores) => { - let _ = ui_tx.send(UiMessage::Info(format!( - "[memory scoring complete: {} memories scored]", - scores.memory_keys.len(), - ))); agent.memory_scores = Some(scores); } Err(e) => { let _ = ui_tx.send(UiMessage::Info(format!( - "[memory scoring failed: {:#}]", e, + "[scoring failed: {:#}]", e, ))); } } diff --git a/src/thought/training.rs b/src/thought/training.rs index 3399317..a7d9a63 100644 --- a/src/thought/training.rs +++ b/src/thought/training.rs @@ -7,10 +7,14 @@ // Row sums = memory importance (for graph weight updates) // Column sums = response memory-dependence (training candidates) +use std::time::Instant; use crate::agent::api::ApiClient; use crate::agent::types::*; use crate::agent::ui_channel::{UiMessage, UiSender}; +/// Timeout for individual /v1/score API calls. +const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); + /// Result of scoring one conversation's memory usage. pub struct MemoryScore { /// memory_key → importance score (sum of divergence across all responses) @@ -62,6 +66,9 @@ pub async fn score_memories( .collect(); if memories.is_empty() || response_indices.is_empty() { + let _ = ui_tx.send(UiMessage::Debug( + "[training] nothing to score (no memories or no responses)".into() + )); return Ok(MemoryScore { memory_weights: Vec::new(), response_scores: Vec::new(), @@ -71,65 +78,83 @@ pub async fn score_memories( }); } - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] scoring {} memories × {} responses via /v1/score", + let _ = ui_tx.send(UiMessage::Info(format!( + "[scoring {} memories × {} responses]", memories.len(), response_indices.len(), ))); let http = reqwest::Client::builder() + .timeout(SCORE_TIMEOUT) .pool_max_idle_per_host(2) .build() .unwrap_or_default(); - // Build the messages array from context let all_messages = build_messages(context); - let roles: Vec<&str> = all_messages.iter() - .map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or("?")) - .collect(); let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] sending {} messages, roles: {:?}", - all_messages.len(), roles, + "[training] {} messages in context", + all_messages.len(), ))); // Baseline: score with all memories present - let baseline = call_score(&http, client, &all_messages, ui_tx).await?; - + let payload_size = serde_json::to_string(&all_messages) + .map(|s| s.len()).unwrap_or(0); let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] baseline: {} messages scored", - baseline.len(), + "[training] payload size: {}KB", + payload_size / 1024, + ))); + let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into())); + let start = Instant::now(); + let baseline = call_score(&http, client, &all_messages).await?; + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] baseline: {} responses scored in {:.1}s", + baseline.len(), start.elapsed().as_secs_f64(), ))); // For each memory, drop it and measure divergence let mut matrix: Vec> = Vec::new(); let memory_keys: Vec = memories.iter().map(|(_, k)| k.clone()).collect(); + let total = memories.len(); for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() { let _ = ui_tx.send(UiMessage::Activity(format!( - "scoring {}/{}...", mem_idx + 1, memories.len(), - ))); - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] scoring memory {}/{}: {}", - mem_idx + 1, memories.len(), key, + "scoring {}/{}: {}...", mem_idx + 1, total, key, ))); - // Build messages without this memory + let start = Instant::now(); let filtered_messages = build_messages_without(context, *entry_idx); - let without = call_score(&http, client, &filtered_messages, ui_tx).await?; + let without = call_score(&http, client, &filtered_messages).await; - // Match scores by message index and compute divergence - let mut row = Vec::new(); - for base_score in &baseline { - let base_lp = base_score.total_logprob; - let without_lp = without.iter() - .find(|s| s.message_index == base_score.message_index) - .map(|s| s.total_logprob) - .unwrap_or(base_lp); - // Positive = memory helped (logprob was higher with it) - let divergence = (base_lp - without_lp).max(0.0); - row.push(divergence); + match without { + Ok(without) => { + let elapsed = start.elapsed().as_secs_f64(); + // Match scores by message index and compute divergence + let mut row = Vec::new(); + for base_score in &baseline { + let base_lp = base_score.total_logprob; + let without_lp = without.iter() + .find(|s| s.message_index == base_score.message_index) + .map(|s| s.total_logprob) + .unwrap_or(base_lp); + let divergence = (base_lp - without_lp).max(0.0); + row.push(divergence); + } + let importance: f64 = row.iter().sum(); + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] {}/{} {} → {:.1} ({:.1}s)", + mem_idx + 1, total, key, importance, elapsed, + ))); + matrix.push(row); + } + Err(e) => { + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] {}/{} {} FAILED: {:#}", + mem_idx + 1, total, key, e, + ))); + // Push zero row so matrix stays aligned + matrix.push(vec![0.0; baseline.len()]); + } } - matrix.push(row); } let _ = ui_tx.send(UiMessage::Activity(String::new())); @@ -150,12 +175,10 @@ pub async fn score_memories( } } - // Log summary - for (key, score) in &memory_weights { - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] {} → importance {:.1}", key, score, - ))); - } + let _ = ui_tx.send(UiMessage::Info(format!( + "[scoring complete: {} memories scored]", + memory_keys.len(), + ))); Ok(MemoryScore { memory_weights, @@ -220,7 +243,6 @@ async fn call_score( http: &reqwest::Client, client: &ApiClient, messages: &[serde_json::Value], - ui_tx: &UiSender, ) -> anyhow::Result> { let request = serde_json::json!({ "model": client.model, @@ -234,7 +256,14 @@ async fn call_score( .header("Authorization", format!("Bearer {}", client.api_key())) .json(&request) .send() - .await?; + .await + .map_err(|e| { + if e.is_timeout() { + anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs()) + } else { + anyhow::anyhow!("score request failed: {}", e) + } + })?; let status = response.status(); let body: serde_json::Value = response.json().await?; @@ -243,12 +272,15 @@ async fn call_score( let msg = body.get("error") .and_then(|e| e.as_str()) .unwrap_or("unknown error"); - let _ = ui_tx.send(UiMessage::Debug(format!( - "[training] score API error: {}", msg, - ))); - anyhow::bail!("score API error: {}", msg); + anyhow::bail!("score API HTTP {}: {}", status, msg); } - let result: ScoreApiResponse = serde_json::from_value(body)?; + // Check for error in body (score endpoint returns dict on error) + if let Some(err) = body.get("error").and_then(|e| e.as_str()) { + anyhow::bail!("score API error: {}", err); + } + + let result: ScoreApiResponse = serde_json::from_value(body) + .map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?; Ok(result.scores) }