scoring: add timeouts, progress feedback, error resilience
- 120s timeout on individual /v1/score HTTP calls - Activity bar shows "scoring 3/24: memory-key..." - Info messages at start and completion - Per-memory timing and importance in debug pane - Failed individual memories log error but don't abort (zero row) - Removed duplicate completion message (info from score_memories) Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
e8c3ed3d96
commit
beb49ec477
2 changed files with 89 additions and 57 deletions
|
|
@ -424,39 +424,39 @@ impl Session {
|
||||||
Command::Handled
|
Command::Handled
|
||||||
}
|
}
|
||||||
"/score" => {
|
"/score" => {
|
||||||
{
|
// Snapshot context+client while we have the lock,
|
||||||
let agent = self.agent.lock().await;
|
// so the scoring task doesn't need to wait for turns.
|
||||||
|
let (context, client) = {
|
||||||
|
let mut agent = self.agent.lock().await;
|
||||||
if agent.scoring_in_flight {
|
if agent.scoring_in_flight {
|
||||||
let _ = self.ui_tx.send(UiMessage::Info(
|
let _ = self.ui_tx.send(UiMessage::Info(
|
||||||
"(scoring already in progress)".into()
|
"(scoring already in progress)".into()
|
||||||
));
|
));
|
||||||
return Command::Handled;
|
return Command::Handled;
|
||||||
}
|
}
|
||||||
}
|
agent.scoring_in_flight = true;
|
||||||
self.agent.lock().await.scoring_in_flight = true;
|
|
||||||
let agent = self.agent.clone();
|
|
||||||
let ui_tx = self.ui_tx.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let (context, client) = {
|
|
||||||
let agent = agent.lock().await;
|
|
||||||
(agent.context.clone(), agent.client_clone())
|
(agent.context.clone(), agent.client_clone())
|
||||||
};
|
};
|
||||||
|
let agent = self.agent.clone();
|
||||||
|
let ui_tx = self.ui_tx.clone();
|
||||||
|
let _ = self.ui_tx.send(UiMessage::Debug("[score] task spawning".into()));
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug("[score] task started, calling score_memories".into()));
|
||||||
let result = poc_memory::thought::training::score_memories(
|
let result = poc_memory::thought::training::score_memories(
|
||||||
&context, &client, &ui_tx,
|
&context, &client, &ui_tx,
|
||||||
).await;
|
).await;
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug("[score] score_memories returned, acquiring lock".into()));
|
||||||
|
// Store results — brief lock, just setting fields
|
||||||
let mut agent = agent.lock().await;
|
let mut agent = agent.lock().await;
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug("[score] lock acquired, storing results".into()));
|
||||||
agent.scoring_in_flight = false;
|
agent.scoring_in_flight = false;
|
||||||
match result {
|
match result {
|
||||||
Ok(scores) => {
|
Ok(scores) => {
|
||||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
|
||||||
"[memory scoring complete: {} memories scored]",
|
|
||||||
scores.memory_keys.len(),
|
|
||||||
)));
|
|
||||||
agent.memory_scores = Some(scores);
|
agent.memory_scores = Some(scores);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||||
"[memory scoring failed: {:#}]", e,
|
"[scoring failed: {:#}]", e,
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,14 @@
|
||||||
// Row sums = memory importance (for graph weight updates)
|
// Row sums = memory importance (for graph weight updates)
|
||||||
// Column sums = response memory-dependence (training candidates)
|
// Column sums = response memory-dependence (training candidates)
|
||||||
|
|
||||||
|
use std::time::Instant;
|
||||||
use crate::agent::api::ApiClient;
|
use crate::agent::api::ApiClient;
|
||||||
use crate::agent::types::*;
|
use crate::agent::types::*;
|
||||||
use crate::agent::ui_channel::{UiMessage, UiSender};
|
use crate::agent::ui_channel::{UiMessage, UiSender};
|
||||||
|
|
||||||
|
/// Timeout for individual /v1/score API calls.
|
||||||
|
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
|
||||||
|
|
||||||
/// Result of scoring one conversation's memory usage.
|
/// Result of scoring one conversation's memory usage.
|
||||||
pub struct MemoryScore {
|
pub struct MemoryScore {
|
||||||
/// memory_key → importance score (sum of divergence across all responses)
|
/// memory_key → importance score (sum of divergence across all responses)
|
||||||
|
|
@ -62,6 +66,9 @@ pub async fn score_memories(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if memories.is_empty() || response_indices.is_empty() {
|
if memories.is_empty() || response_indices.is_empty() {
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(
|
||||||
|
"[training] nothing to score (no memories or no responses)".into()
|
||||||
|
));
|
||||||
return Ok(MemoryScore {
|
return Ok(MemoryScore {
|
||||||
memory_weights: Vec::new(),
|
memory_weights: Vec::new(),
|
||||||
response_scores: Vec::new(),
|
response_scores: Vec::new(),
|
||||||
|
|
@ -71,52 +78,56 @@ pub async fn score_memories(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||||
"[training] scoring {} memories × {} responses via /v1/score",
|
"[scoring {} memories × {} responses]",
|
||||||
memories.len(), response_indices.len(),
|
memories.len(), response_indices.len(),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
let http = reqwest::Client::builder()
|
let http = reqwest::Client::builder()
|
||||||
|
.timeout(SCORE_TIMEOUT)
|
||||||
.pool_max_idle_per_host(2)
|
.pool_max_idle_per_host(2)
|
||||||
.build()
|
.build()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Build the messages array from context
|
|
||||||
let all_messages = build_messages(context);
|
let all_messages = build_messages(context);
|
||||||
|
|
||||||
let roles: Vec<&str> = all_messages.iter()
|
|
||||||
.map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or("?"))
|
|
||||||
.collect();
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] sending {} messages, roles: {:?}",
|
"[training] {} messages in context",
|
||||||
all_messages.len(), roles,
|
all_messages.len(),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// Baseline: score with all memories present
|
// Baseline: score with all memories present
|
||||||
let baseline = call_score(&http, client, &all_messages, ui_tx).await?;
|
let payload_size = serde_json::to_string(&all_messages)
|
||||||
|
.map(|s| s.len()).unwrap_or(0);
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] baseline: {} messages scored",
|
"[training] payload size: {}KB",
|
||||||
baseline.len(),
|
payload_size / 1024,
|
||||||
|
)));
|
||||||
|
let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
|
||||||
|
let start = Instant::now();
|
||||||
|
let baseline = call_score(&http, client, &all_messages).await?;
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] baseline: {} responses scored in {:.1}s",
|
||||||
|
baseline.len(), start.elapsed().as_secs_f64(),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// For each memory, drop it and measure divergence
|
// For each memory, drop it and measure divergence
|
||||||
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
||||||
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
|
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
|
||||||
|
let total = memories.len();
|
||||||
|
|
||||||
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
|
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
|
||||||
let _ = ui_tx.send(UiMessage::Activity(format!(
|
let _ = ui_tx.send(UiMessage::Activity(format!(
|
||||||
"scoring {}/{}...", mem_idx + 1, memories.len(),
|
"scoring {}/{}: {}...", mem_idx + 1, total, key,
|
||||||
)));
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
||||||
"[training] scoring memory {}/{}: {}",
|
|
||||||
mem_idx + 1, memories.len(), key,
|
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// Build messages without this memory
|
let start = Instant::now();
|
||||||
let filtered_messages = build_messages_without(context, *entry_idx);
|
let filtered_messages = build_messages_without(context, *entry_idx);
|
||||||
let without = call_score(&http, client, &filtered_messages, ui_tx).await?;
|
let without = call_score(&http, client, &filtered_messages).await;
|
||||||
|
|
||||||
|
match without {
|
||||||
|
Ok(without) => {
|
||||||
|
let elapsed = start.elapsed().as_secs_f64();
|
||||||
// Match scores by message index and compute divergence
|
// Match scores by message index and compute divergence
|
||||||
let mut row = Vec::new();
|
let mut row = Vec::new();
|
||||||
for base_score in &baseline {
|
for base_score in &baseline {
|
||||||
|
|
@ -125,12 +136,26 @@ pub async fn score_memories(
|
||||||
.find(|s| s.message_index == base_score.message_index)
|
.find(|s| s.message_index == base_score.message_index)
|
||||||
.map(|s| s.total_logprob)
|
.map(|s| s.total_logprob)
|
||||||
.unwrap_or(base_lp);
|
.unwrap_or(base_lp);
|
||||||
// Positive = memory helped (logprob was higher with it)
|
|
||||||
let divergence = (base_lp - without_lp).max(0.0);
|
let divergence = (base_lp - without_lp).max(0.0);
|
||||||
row.push(divergence);
|
row.push(divergence);
|
||||||
}
|
}
|
||||||
|
let importance: f64 = row.iter().sum();
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] {}/{} {} → {:.1} ({:.1}s)",
|
||||||
|
mem_idx + 1, total, key, importance, elapsed,
|
||||||
|
)));
|
||||||
matrix.push(row);
|
matrix.push(row);
|
||||||
}
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] {}/{} {} FAILED: {:#}",
|
||||||
|
mem_idx + 1, total, key, e,
|
||||||
|
)));
|
||||||
|
// Push zero row so matrix stays aligned
|
||||||
|
matrix.push(vec![0.0; baseline.len()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
||||||
|
|
||||||
|
|
@ -150,12 +175,10 @@ pub async fn score_memories(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log summary
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||||
for (key, score) in &memory_weights {
|
"[scoring complete: {} memories scored]",
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
memory_keys.len(),
|
||||||
"[training] {} → importance {:.1}", key, score,
|
|
||||||
)));
|
)));
|
||||||
}
|
|
||||||
|
|
||||||
Ok(MemoryScore {
|
Ok(MemoryScore {
|
||||||
memory_weights,
|
memory_weights,
|
||||||
|
|
@ -220,7 +243,6 @@ async fn call_score(
|
||||||
http: &reqwest::Client,
|
http: &reqwest::Client,
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
messages: &[serde_json::Value],
|
messages: &[serde_json::Value],
|
||||||
ui_tx: &UiSender,
|
|
||||||
) -> anyhow::Result<Vec<ScoreMessageResult>> {
|
) -> anyhow::Result<Vec<ScoreMessageResult>> {
|
||||||
let request = serde_json::json!({
|
let request = serde_json::json!({
|
||||||
"model": client.model,
|
"model": client.model,
|
||||||
|
|
@ -234,7 +256,14 @@ async fn call_score(
|
||||||
.header("Authorization", format!("Bearer {}", client.api_key()))
|
.header("Authorization", format!("Bearer {}", client.api_key()))
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
if e.is_timeout() {
|
||||||
|
anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
|
||||||
|
} else {
|
||||||
|
anyhow::anyhow!("score request failed: {}", e)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let body: serde_json::Value = response.json().await?;
|
let body: serde_json::Value = response.json().await?;
|
||||||
|
|
@ -243,12 +272,15 @@ async fn call_score(
|
||||||
let msg = body.get("error")
|
let msg = body.get("error")
|
||||||
.and_then(|e| e.as_str())
|
.and_then(|e| e.as_str())
|
||||||
.unwrap_or("unknown error");
|
.unwrap_or("unknown error");
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
anyhow::bail!("score API HTTP {}: {}", status, msg);
|
||||||
"[training] score API error: {}", msg,
|
|
||||||
)));
|
|
||||||
anyhow::bail!("score API error: {}", msg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ScoreApiResponse = serde_json::from_value(body)?;
|
// Check for error in body (score endpoint returns dict on error)
|
||||||
|
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
|
||||||
|
anyhow::bail!("score API error: {}", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ScoreApiResponse = serde_json::from_value(body)
|
||||||
|
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
|
||||||
Ok(result.scores)
|
Ok(result.scores)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue