fix scoring: HTTP error checking, context refresh, chunk logging

Check HTTP status from logprobs API (was silently ignoring 500s).
Call publish_context_state() after storing scores so F10 screen
updates. Add chunk size logging for OOM debugging.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-02 22:47:44 -04:00
parent 29b3aeca57
commit 78abf90461
3 changed files with 28 additions and 5 deletions

View file

@ -84,7 +84,7 @@ pub async fn score_memories(
)));
// Baseline: logprobs with all memories present
let baseline = get_response_logprobs(context, &context.entries, client).await?;
let baseline = get_response_logprobs(context, &context.entries, client, ui_tx).await?;
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] baseline: {} response tokens scored",
@ -110,7 +110,7 @@ pub async fn score_memories(
.map(|(_, e)| e.clone())
.collect();
let without = get_response_logprobs(context, &filtered, client).await?;
let without = get_response_logprobs(context, &filtered, client, ui_tx).await?;
// Compute per-response divergence
let mut row = Vec::new();
@ -194,6 +194,7 @@ async fn get_response_logprobs(
context: &ContextState,
entries: &[ConversationEntry],
client: &ApiClient,
ui_tx: &UiSender,
) -> anyhow::Result<Vec<Vec<f64>>> {
// Build the fixed prefix (system prompt + personality)
let mut prefix = Vec::new();
@ -214,6 +215,22 @@ async fn get_response_logprobs(
let mut all_responses: Vec<Vec<f64>> = Vec::new();
use crate::agent::ui_channel::UiMessage;
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] {} chunks, prefix={}K chars, budget={}K chars",
chunks.len(), prefix_chars / 1024, budget / 1024,
)));
for (chunk_idx, chunk) in chunks.iter().enumerate() {
let chunk_chars: usize = chunk.iter()
.map(|e| e.message().content_text().len())
.sum();
let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] chunk {}/{}: {} entries, {}K chars",
chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024,
)));
}
for chunk in &chunks {
let mut msgs = prefix.clone();
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
@ -277,10 +294,15 @@ async fn call_prompt_logprobs(
.send()
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
if let Some(err) = body.get("error") {
anyhow::bail!("API error: {}", err);
if !status.is_success() {
let msg = body.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("unknown error");
anyhow::bail!("HTTP {} from logprobs API: {}", status, msg);
}
let prompt_logprobs = body.get("prompt_logprobs")