fix scoring: HTTP error checking, context refresh, chunk logging
Check the HTTP status returned by the logprobs API (previously, 500 responses were silently ignored). Call publish_context_state() after storing scores so that the F10 screen updates. Add chunk-size logging to help debug OOMs.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
29b3aeca57
commit
78abf90461
3 changed files with 28 additions and 5 deletions
|
|
@@ -882,7 +882,7 @@ impl Agent {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Push the current context summary to the shared state for the TUI to read.
|
/// Push the current context summary to the shared state for the TUI to read.
|
||||||
fn publish_context_state(&self) {
|
pub fn publish_context_state(&self) {
|
||||||
let summary = self.context_state_summary();
|
let summary = self.context_state_summary();
|
||||||
if let Ok(mut dbg) = std::fs::OpenOptions::new().create(true).append(true)
|
if let Ok(mut dbg) = std::fs::OpenOptions::new().create(true).append(true)
|
||||||
.open("/tmp/poc-journal-debug.log") {
|
.open("/tmp/poc-journal-debug.log") {
|
||||||
|
|
|
||||||
|
|
@@ -460,6 +460,7 @@ impl Session {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
agent.publish_context_state();
|
||||||
});
|
});
|
||||||
Command::Handled
|
Command::Handled
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -84,7 +84,7 @@ pub async fn score_memories(
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// Baseline: logprobs with all memories present
|
// Baseline: logprobs with all memories present
|
||||||
let baseline = get_response_logprobs(context, &context.entries, client).await?;
|
let baseline = get_response_logprobs(context, &context.entries, client, ui_tx).await?;
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] baseline: {} response tokens scored",
|
"[training] baseline: {} response tokens scored",
|
||||||
|
|
@@ -110,7 +110,7 @@ pub async fn score_memories(
|
||||||
.map(|(_, e)| e.clone())
|
.map(|(_, e)| e.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let without = get_response_logprobs(context, &filtered, client).await?;
|
let without = get_response_logprobs(context, &filtered, client, ui_tx).await?;
|
||||||
|
|
||||||
// Compute per-response divergence
|
// Compute per-response divergence
|
||||||
let mut row = Vec::new();
|
let mut row = Vec::new();
|
||||||
|
|
@@ -194,6 +194,7 @@ async fn get_response_logprobs(
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
entries: &[ConversationEntry],
|
entries: &[ConversationEntry],
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
|
ui_tx: &UiSender,
|
||||||
) -> anyhow::Result<Vec<Vec<f64>>> {
|
) -> anyhow::Result<Vec<Vec<f64>>> {
|
||||||
// Build the fixed prefix (system prompt + personality)
|
// Build the fixed prefix (system prompt + personality)
|
||||||
let mut prefix = Vec::new();
|
let mut prefix = Vec::new();
|
||||||
|
|
@@ -214,6 +215,22 @@ async fn get_response_logprobs(
|
||||||
|
|
||||||
let mut all_responses: Vec<Vec<f64>> = Vec::new();
|
let mut all_responses: Vec<Vec<f64>> = Vec::new();
|
||||||
|
|
||||||
|
use crate::agent::ui_channel::UiMessage;
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] {} chunks, prefix={}K chars, budget={}K chars",
|
||||||
|
chunks.len(), prefix_chars / 1024, budget / 1024,
|
||||||
|
)));
|
||||||
|
|
||||||
|
for (chunk_idx, chunk) in chunks.iter().enumerate() {
|
||||||
|
let chunk_chars: usize = chunk.iter()
|
||||||
|
.map(|e| e.message().content_text().len())
|
||||||
|
.sum();
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] chunk {}/{}: {} entries, {}K chars",
|
||||||
|
chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
for chunk in &chunks {
|
for chunk in &chunks {
|
||||||
let mut msgs = prefix.clone();
|
let mut msgs = prefix.clone();
|
||||||
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
|
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
|
||||||
|
|
@@ -277,10 +294,15 @@ async fn call_prompt_logprobs(
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
let body: serde_json::Value = response.json().await?;
|
let body: serde_json::Value = response.json().await?;
|
||||||
|
|
||||||
if let Some(err) = body.get("error") {
|
if !status.is_success() {
|
||||||
anyhow::bail!("API error: {}", err);
|
let msg = body.get("error")
|
||||||
|
.and_then(|e| e.get("message"))
|
||||||
|
.and_then(|m| m.as_str())
|
||||||
|
.unwrap_or("unknown error");
|
||||||
|
anyhow::bail!("HTTP {} from logprobs API: {}", status, msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt_logprobs = body.get("prompt_logprobs")
|
let prompt_logprobs = body.get("prompt_logprobs")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue