diff --git a/src/agent/runner.rs b/src/agent/runner.rs index a17756b..6259db3 100644 --- a/src/agent/runner.rs +++ b/src/agent/runner.rs @@ -882,7 +882,7 @@ impl Agent { } /// Push the current context summary to the shared state for the TUI to read. - fn publish_context_state(&self) { + pub fn publish_context_state(&self) { let summary = self.context_state_summary(); if let Ok(mut dbg) = std::fs::OpenOptions::new().create(true).append(true) .open("/tmp/poc-journal-debug.log") { diff --git a/src/bin/poc-agent.rs b/src/bin/poc-agent.rs index 0497ce4..745f21b 100644 --- a/src/bin/poc-agent.rs +++ b/src/bin/poc-agent.rs @@ -460,6 +460,7 @@ impl Session { ))); } } + agent.publish_context_state(); }); Command::Handled } diff --git a/src/thought/training.rs b/src/thought/training.rs index 45deced..c67b8ca 100644 --- a/src/thought/training.rs +++ b/src/thought/training.rs @@ -84,7 +84,7 @@ pub async fn score_memories( ))); // Baseline: logprobs with all memories present - let baseline = get_response_logprobs(context, &context.entries, client).await?; + let baseline = get_response_logprobs(context, &context.entries, client, ui_tx).await?; let _ = ui_tx.send(UiMessage::Debug(format!( "[training] baseline: {} response tokens scored", @@ -110,7 +110,7 @@ pub async fn score_memories( .map(|(_, e)| e.clone()) .collect(); - let without = get_response_logprobs(context, &filtered, client).await?; + let without = get_response_logprobs(context, &filtered, client, ui_tx).await?; // Compute per-response divergence let mut row = Vec::new(); @@ -194,6 +194,7 @@ async fn get_response_logprobs( context: &ContextState, entries: &[ConversationEntry], client: &ApiClient, + ui_tx: &UiSender, ) -> anyhow::Result>> { // Build the fixed prefix (system prompt + personality) let mut prefix = Vec::new(); @@ -214,6 +215,22 @@ async fn get_response_logprobs( let mut all_responses: Vec> = Vec::new(); + use crate::agent::ui_channel::UiMessage; + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] {} chunks, prefix={}K chars, budget={}K chars", + chunks.len(), prefix_chars / 1024, budget / 1024, + ))); + + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let chunk_chars: usize = chunk.iter() + .map(|e| e.message().content_text().len()) + .sum(); + let _ = ui_tx.send(UiMessage::Debug(format!( + "[training] chunk {}/{}: {} entries, {}K chars", + chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024, + ))); + } + for chunk in &chunks { let mut msgs = prefix.clone(); msgs.extend(chunk.iter().map(|e| e.api_message().clone())); @@ -277,10 +294,15 @@ async fn call_prompt_logprobs( .send() .await?; + let status = response.status(); let body: serde_json::Value = response.json().await?; - if let Some(err) = body.get("error") { - anyhow::bail!("API error: {}", err); + if !status.is_success() { + let msg = body.get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .unwrap_or("unknown error"); + anyhow::bail!("HTTP {} from logprobs API: {}", status, msg); } let prompt_logprobs = body.get("prompt_logprobs")