From 214806cb9060663377e53e1384fdd2cf308b5994 Mon Sep 17 00:00:00 2001
From: Kent Overstreet
Date: Thu, 2 Apr 2026 15:28:00 -0400
Subject: [PATCH] move context functions from agent/context.rs to thought/context.rs

trim_conversation moved to thought/context.rs where model_context_window,
msg_token_count, is_context_overflow, is_stream_error already lived.
Delete the duplicate agent/context.rs (94 lines).

Co-Authored-By: Proof of Concept
---
 src/agent/context.rs   | 94 ------------------------------------------
 src/agent/mod.rs       |  1 -
 src/agent/runner.rs    | 14 +++----
 src/bin/poc-agent.rs   |  4 +-
 src/thought/context.rs | 39 ++++++++++++++++++
 5 files changed, 48 insertions(+), 104 deletions(-)
 delete mode 100644 src/agent/context.rs

diff --git a/src/agent/context.rs b/src/agent/context.rs
deleted file mode 100644
index eeca48f..0000000
--- a/src/agent/context.rs
+++ /dev/null
@@ -1,94 +0,0 @@
-// context.rs — Context window management
-//
-// Token counting and conversation trimming for the context window.
-
-use crate::agent::types::*;
-use tiktoken_rs::CoreBPE;
-
-/// Look up a model's context window size in tokens.
-pub fn model_context_window(_model: &str) -> usize {
-    crate::config::get().api_context_window
-}
-
-/// Context budget in tokens: 60% of the model's context window.
-fn context_budget_tokens(model: &str) -> usize {
-    model_context_window(model) * 60 / 100
-}
-
-/// Trim conversation to fit within the context budget.
-/// Returns the trimmed conversation messages (oldest dropped first).
-pub fn trim_conversation(
-    context: &ContextState,
-    conversation: &[Message],
-    model: &str,
-    tokenizer: &CoreBPE,
-) -> Vec<Message> {
-    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
-    let max_tokens = context_budget_tokens(model);
-
-    let identity_cost = count(&context.system_prompt)
-        + context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
-    let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
-    let reserve = max_tokens / 4;
-    let available = max_tokens
-        .saturating_sub(identity_cost)
-        .saturating_sub(journal_cost)
-        .saturating_sub(reserve);
-
-    let msg_costs: Vec<usize> = conversation.iter()
-        .map(|m| msg_token_count(tokenizer, m)).collect();
-    let total: usize = msg_costs.iter().sum();
-
-    let mut skip = 0;
-    let mut trimmed = total;
-    while trimmed > available && skip < conversation.len() {
-        trimmed -= msg_costs[skip];
-        skip += 1;
-    }
-
-    // Walk forward to user message boundary
-    while skip < conversation.len() && conversation[skip].role != Role::User {
-        skip += 1;
-    }
-
-    conversation[skip..].to_vec()
-}
-
-/// Count the token footprint of a message using BPE tokenization.
-pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
-    let content = msg.content.as_ref().map_or(0, |c| match c {
-        MessageContent::Text(s) => tokenizer.encode_with_special_tokens(s).len(),
-        MessageContent::Parts(parts) => parts.iter()
-            .map(|p| match p {
-                ContentPart::Text { text } => tokenizer.encode_with_special_tokens(text).len(),
-                ContentPart::ImageUrl { .. } => 85,
-            })
-            .sum(),
-    });
-    let tools = msg.tool_calls.as_ref().map_or(0, |calls| {
-        calls.iter()
-            .map(|c| tokenizer.encode_with_special_tokens(&c.function.arguments).len()
-                + tokenizer.encode_with_special_tokens(&c.function.name).len())
-            .sum()
-    });
-    content + tools
-}
-
-/// Detect context window overflow errors from the API.
-pub fn is_context_overflow(err: &anyhow::Error) -> bool {
-    let msg = err.to_string().to_lowercase();
-    msg.contains("context length")
-        || msg.contains("token limit")
-        || msg.contains("too many tokens")
-        || msg.contains("maximum context")
-        || msg.contains("prompt is too long")
-        || msg.contains("request too large")
-        || msg.contains("input validation error")
-        || msg.contains("content length limit")
-        || (msg.contains("400") && msg.contains("tokens"))
-}
-
-/// Detect model/provider errors delivered inside the SSE stream.
-pub fn is_stream_error(err: &anyhow::Error) -> bool {
-    err.to_string().contains("model stream error")
-}
diff --git a/src/agent/mod.rs b/src/agent/mod.rs
index 6c9a6dc..300b9e0 100644
--- a/src/agent/mod.rs
+++ b/src/agent/mod.rs
@@ -15,7 +15,6 @@ pub mod tools;
 pub mod ui_channel;
 pub mod runner;
 pub mod cli;
-pub mod context;
 pub mod dmn;
 pub mod identity;
 pub mod log;
diff --git a/src/agent/runner.rs b/src/agent/runner.rs
index 2b18074..42a7cf6 100644
--- a/src/agent/runner.rs
+++ b/src/agent/runner.rs
@@ -180,8 +180,8 @@ impl Agent {
     /// every startup/compaction.
     pub fn budget(&self) -> ContextBudget {
         let count_str = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
-        let count_msg = |m: &Message| crate::agent::context::msg_token_count(&self.tokenizer, m);
-        let window = crate::agent::context::model_context_window(&self.client.model);
+        let count_msg = |m: &Message| crate::thought::context::msg_token_count(&self.tokenizer, m);
+        let window = crate::thought::context::model_context_window(&self.client.model);
         self.context.budget(&count_str, &count_msg, window)
     }
 
@@ -326,7 +326,7 @@ impl Agent {
        // Handle stream errors with retry logic
        if let Some(e) = stream_error {
            let err = anyhow::anyhow!("{}", e);
-           if crate::agent::context::is_context_overflow(&err) && overflow_retries < 2 {
+           if crate::thought::context::is_context_overflow(&err) && overflow_retries < 2 {
                overflow_retries += 1;
                let _ = ui_tx.send(UiMessage::Info(format!(
                    "[context overflow — compacting and retrying ({}/2)]",
@@ -335,7 +335,7 @@ impl Agent {
                self.emergency_compact();
                continue;
            }
-           if crate::agent::context::is_stream_error(&err) && empty_retries < 2 {
+           if crate::thought::context::is_stream_error(&err) && empty_retries < 2 {
                empty_retries += 1;
                let _ = ui_tx.send(UiMessage::Info(format!(
                    "[stream error: {} — retrying ({}/2)]",
@@ -790,7 +790,7 @@ impl Agent {
 
         // Walk backwards from cutoff, accumulating entries within 5% of context
         let count = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
-        let context_window = crate::agent::context::model_context_window(&self.client.model);
+        let context_window = crate::thought::context::model_context_window(&self.client.model);
         let journal_budget = context_window * 5 / 100;
         dbg_log!("[journal] budget={} tokens ({}*5%)", journal_budget, context_window);
 
@@ -976,7 +976,7 @@ impl Agent {
     fn do_compact(&mut self) {
         let conversation: Vec<Message> = self.context.entries.iter()
             .map(|e| e.api_message().clone()).collect();
-        let messages = crate::agent::context::trim_conversation(
+        let messages = crate::thought::context::trim_conversation(
            &self.context,
            &conversation,
            &self.client.model,
@@ -1040,7 +1040,7 @@ impl Agent {
        let n = entries.len();
        let conversation: Vec<Message> = entries.iter()
            .map(|e| e.api_message().clone()).collect();
-       let trimmed = crate::agent::context::trim_conversation(
+       let trimmed = crate::thought::context::trim_conversation(
            &self.context,
            &conversation,
            &self.client.model,
diff --git a/src/bin/poc-agent.rs b/src/bin/poc-agent.rs
index bbf390d..6fb9061 100644
--- a/src/bin/poc-agent.rs
+++ b/src/bin/poc-agent.rs
@@ -41,13 +41,13 @@ use poc_memory::agent::ui_channel::{ContextInfo, StatusInfo, StreamTarget, UiMes
 /// Hard compaction threshold — context is rebuilt immediately.
 /// Uses config percentage of model context window.
 fn compaction_threshold(model: &str, app: &AppConfig) -> u32 {
-    (context::model_context_window(model) as u32) * app.compaction.hard_threshold_pct / 100
+    (poc_memory::thought::context::model_context_window(model) as u32) * app.compaction.hard_threshold_pct / 100
 }
 
 /// Soft threshold — nudge the model to journal before compaction.
 /// Fires once; the hard threshold handles the actual rebuild.
 fn pre_compaction_threshold(model: &str, app: &AppConfig) -> u32 {
-    (context::model_context_window(model) as u32) * app.compaction.soft_threshold_pct / 100
+    (poc_memory::thought::context::model_context_window(model) as u32) * app.compaction.soft_threshold_pct / 100
 }
 
 #[tokio::main]
diff --git a/src/thought/context.rs b/src/thought/context.rs
index 32ccb90..f2266ec 100644
--- a/src/thought/context.rs
+++ b/src/thought/context.rs
@@ -423,6 +423,45 @@ pub fn is_stream_error(err: &anyhow::Error) -> bool {
     err.to_string().contains("model stream error")
 }
 
+/// Trim conversation to fit within the context budget.
+/// Returns the trimmed conversation messages (oldest dropped first).
+pub fn trim_conversation(
+    context: &ContextState,
+    conversation: &[Message],
+    model: &str,
+    tokenizer: &CoreBPE,
+) -> Vec<Message> {
+    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
+    let max_tokens = context_budget_tokens(model);
+
+    let identity_cost = count(&context.system_prompt)
+        + context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
+    let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
+    let reserve = max_tokens / 4;
+    let available = max_tokens
+        .saturating_sub(identity_cost)
+        .saturating_sub(journal_cost)
+        .saturating_sub(reserve);
+
+    let msg_costs: Vec<usize> = conversation.iter()
+        .map(|m| msg_token_count(tokenizer, m)).collect();
+    let total: usize = msg_costs.iter().sum();
+
+    let mut skip = 0;
+    let mut trimmed = total;
+    while trimmed > available && skip < conversation.len() {
+        trimmed -= msg_costs[skip];
+        skip += 1;
+    }
+
+    // Walk forward to user message boundary
+    while skip < conversation.len() && conversation[skip].role != Role::User {
+        skip += 1;
+    }
+
+    conversation[skip..].to_vec()
+}
+
 fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
     msg.timestamp
         .as_ref()