move context functions from agent/context.rs to thought/context.rs
trim_conversation moved to thought/context.rs where model_context_window, msg_token_count, is_context_overflow, is_stream_error already lived. Delete the duplicate agent/context.rs (94 lines). Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
01bfbc0dad
commit
214806cb90
5 changed files with 48 additions and 104 deletions
|
|
@ -1,94 +0,0 @@
|
||||||
// context.rs — Context window management
|
|
||||||
//
|
|
||||||
// Token counting and conversation trimming for the context window.
|
|
||||||
|
|
||||||
use crate::agent::types::*;
|
|
||||||
use tiktoken_rs::CoreBPE;
|
|
||||||
|
|
||||||
/// Look up a model's context window size in tokens.
|
|
||||||
pub fn model_context_window(_model: &str) -> usize {
|
|
||||||
crate::config::get().api_context_window
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Context budget in tokens: 60% of the model's context window.
|
|
||||||
fn context_budget_tokens(model: &str) -> usize {
|
|
||||||
model_context_window(model) * 60 / 100
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trim conversation to fit within the context budget.
|
|
||||||
/// Returns the trimmed conversation messages (oldest dropped first).
|
|
||||||
pub fn trim_conversation(
|
|
||||||
context: &ContextState,
|
|
||||||
conversation: &[Message],
|
|
||||||
model: &str,
|
|
||||||
tokenizer: &CoreBPE,
|
|
||||||
) -> Vec<Message> {
|
|
||||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
|
||||||
let max_tokens = context_budget_tokens(model);
|
|
||||||
|
|
||||||
let identity_cost = count(&context.system_prompt)
|
|
||||||
+ context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
|
|
||||||
let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
|
|
||||||
let reserve = max_tokens / 4;
|
|
||||||
let available = max_tokens
|
|
||||||
.saturating_sub(identity_cost)
|
|
||||||
.saturating_sub(journal_cost)
|
|
||||||
.saturating_sub(reserve);
|
|
||||||
|
|
||||||
let msg_costs: Vec<usize> = conversation.iter()
|
|
||||||
.map(|m| msg_token_count(tokenizer, m)).collect();
|
|
||||||
let total: usize = msg_costs.iter().sum();
|
|
||||||
|
|
||||||
let mut skip = 0;
|
|
||||||
let mut trimmed = total;
|
|
||||||
while trimmed > available && skip < conversation.len() {
|
|
||||||
trimmed -= msg_costs[skip];
|
|
||||||
skip += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walk forward to user message boundary
|
|
||||||
while skip < conversation.len() && conversation[skip].role != Role::User {
|
|
||||||
skip += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation[skip..].to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Count the token footprint of a message using BPE tokenization.
|
|
||||||
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
|
||||||
let content = msg.content.as_ref().map_or(0, |c| match c {
|
|
||||||
MessageContent::Text(s) => tokenizer.encode_with_special_tokens(s).len(),
|
|
||||||
MessageContent::Parts(parts) => parts.iter()
|
|
||||||
.map(|p| match p {
|
|
||||||
ContentPart::Text { text } => tokenizer.encode_with_special_tokens(text).len(),
|
|
||||||
ContentPart::ImageUrl { .. } => 85,
|
|
||||||
})
|
|
||||||
.sum(),
|
|
||||||
});
|
|
||||||
let tools = msg.tool_calls.as_ref().map_or(0, |calls| {
|
|
||||||
calls.iter()
|
|
||||||
.map(|c| tokenizer.encode_with_special_tokens(&c.function.arguments).len()
|
|
||||||
+ tokenizer.encode_with_special_tokens(&c.function.name).len())
|
|
||||||
.sum()
|
|
||||||
});
|
|
||||||
content + tools
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Detect context window overflow errors from the API.
|
|
||||||
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
|
|
||||||
let msg = err.to_string().to_lowercase();
|
|
||||||
msg.contains("context length")
|
|
||||||
|| msg.contains("token limit")
|
|
||||||
|| msg.contains("too many tokens")
|
|
||||||
|| msg.contains("maximum context")
|
|
||||||
|| msg.contains("prompt is too long")
|
|
||||||
|| msg.contains("request too large")
|
|
||||||
|| msg.contains("input validation error")
|
|
||||||
|| msg.contains("content length limit")
|
|
||||||
|| (msg.contains("400") && msg.contains("tokens"))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Detect model/provider errors delivered inside the SSE stream.
|
|
||||||
pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
|
||||||
err.to_string().contains("model stream error")
|
|
||||||
}
|
|
||||||
|
|
@ -15,7 +15,6 @@ pub mod tools;
|
||||||
pub mod ui_channel;
|
pub mod ui_channel;
|
||||||
pub mod runner;
|
pub mod runner;
|
||||||
pub mod cli;
|
pub mod cli;
|
||||||
pub mod context;
|
|
||||||
pub mod dmn;
|
pub mod dmn;
|
||||||
pub mod identity;
|
pub mod identity;
|
||||||
pub mod log;
|
pub mod log;
|
||||||
|
|
|
||||||
|
|
@ -180,8 +180,8 @@ impl Agent {
|
||||||
/// every startup/compaction.
|
/// every startup/compaction.
|
||||||
pub fn budget(&self) -> ContextBudget {
|
pub fn budget(&self) -> ContextBudget {
|
||||||
let count_str = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
|
let count_str = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
|
||||||
let count_msg = |m: &Message| crate::agent::context::msg_token_count(&self.tokenizer, m);
|
let count_msg = |m: &Message| crate::thought::context::msg_token_count(&self.tokenizer, m);
|
||||||
let window = crate::agent::context::model_context_window(&self.client.model);
|
let window = crate::thought::context::model_context_window(&self.client.model);
|
||||||
self.context.budget(&count_str, &count_msg, window)
|
self.context.budget(&count_str, &count_msg, window)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -326,7 +326,7 @@ impl Agent {
|
||||||
// Handle stream errors with retry logic
|
// Handle stream errors with retry logic
|
||||||
if let Some(e) = stream_error {
|
if let Some(e) = stream_error {
|
||||||
let err = anyhow::anyhow!("{}", e);
|
let err = anyhow::anyhow!("{}", e);
|
||||||
if crate::agent::context::is_context_overflow(&err) && overflow_retries < 2 {
|
if crate::thought::context::is_context_overflow(&err) && overflow_retries < 2 {
|
||||||
overflow_retries += 1;
|
overflow_retries += 1;
|
||||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||||
"[context overflow — compacting and retrying ({}/2)]",
|
"[context overflow — compacting and retrying ({}/2)]",
|
||||||
|
|
@ -335,7 +335,7 @@ impl Agent {
|
||||||
self.emergency_compact();
|
self.emergency_compact();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if crate::agent::context::is_stream_error(&err) && empty_retries < 2 {
|
if crate::thought::context::is_stream_error(&err) && empty_retries < 2 {
|
||||||
empty_retries += 1;
|
empty_retries += 1;
|
||||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||||
"[stream error: {} — retrying ({}/2)]",
|
"[stream error: {} — retrying ({}/2)]",
|
||||||
|
|
@ -790,7 +790,7 @@ impl Agent {
|
||||||
|
|
||||||
// Walk backwards from cutoff, accumulating entries within 5% of context
|
// Walk backwards from cutoff, accumulating entries within 5% of context
|
||||||
let count = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
|
let count = |s: &str| self.tokenizer.encode_with_special_tokens(s).len();
|
||||||
let context_window = crate::agent::context::model_context_window(&self.client.model);
|
let context_window = crate::thought::context::model_context_window(&self.client.model);
|
||||||
let journal_budget = context_window * 5 / 100;
|
let journal_budget = context_window * 5 / 100;
|
||||||
dbg_log!("[journal] budget={} tokens ({}*5%)", journal_budget, context_window);
|
dbg_log!("[journal] budget={} tokens ({}*5%)", journal_budget, context_window);
|
||||||
|
|
||||||
|
|
@ -976,7 +976,7 @@ impl Agent {
|
||||||
fn do_compact(&mut self) {
|
fn do_compact(&mut self) {
|
||||||
let conversation: Vec<Message> = self.context.entries.iter()
|
let conversation: Vec<Message> = self.context.entries.iter()
|
||||||
.map(|e| e.api_message().clone()).collect();
|
.map(|e| e.api_message().clone()).collect();
|
||||||
let messages = crate::agent::context::trim_conversation(
|
let messages = crate::thought::context::trim_conversation(
|
||||||
&self.context,
|
&self.context,
|
||||||
&conversation,
|
&conversation,
|
||||||
&self.client.model,
|
&self.client.model,
|
||||||
|
|
@ -1040,7 +1040,7 @@ impl Agent {
|
||||||
let n = entries.len();
|
let n = entries.len();
|
||||||
let conversation: Vec<Message> = entries.iter()
|
let conversation: Vec<Message> = entries.iter()
|
||||||
.map(|e| e.api_message().clone()).collect();
|
.map(|e| e.api_message().clone()).collect();
|
||||||
let trimmed = crate::agent::context::trim_conversation(
|
let trimmed = crate::thought::context::trim_conversation(
|
||||||
&self.context,
|
&self.context,
|
||||||
&conversation,
|
&conversation,
|
||||||
&self.client.model,
|
&self.client.model,
|
||||||
|
|
|
||||||
|
|
@ -41,13 +41,13 @@ use poc_memory::agent::ui_channel::{ContextInfo, StatusInfo, StreamTarget, UiMes
|
||||||
/// Hard compaction threshold — context is rebuilt immediately.
|
/// Hard compaction threshold — context is rebuilt immediately.
|
||||||
/// Uses config percentage of model context window.
|
/// Uses config percentage of model context window.
|
||||||
fn compaction_threshold(model: &str, app: &AppConfig) -> u32 {
|
fn compaction_threshold(model: &str, app: &AppConfig) -> u32 {
|
||||||
(context::model_context_window(model) as u32) * app.compaction.hard_threshold_pct / 100
|
(poc_memory::thought::context::model_context_window(model) as u32) * app.compaction.hard_threshold_pct / 100
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Soft threshold — nudge the model to journal before compaction.
|
/// Soft threshold — nudge the model to journal before compaction.
|
||||||
/// Fires once; the hard threshold handles the actual rebuild.
|
/// Fires once; the hard threshold handles the actual rebuild.
|
||||||
fn pre_compaction_threshold(model: &str, app: &AppConfig) -> u32 {
|
fn pre_compaction_threshold(model: &str, app: &AppConfig) -> u32 {
|
||||||
(context::model_context_window(model) as u32) * app.compaction.soft_threshold_pct / 100
|
(poc_memory::thought::context::model_context_window(model) as u32) * app.compaction.soft_threshold_pct / 100
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|
|
||||||
|
|
@ -423,6 +423,45 @@ pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
||||||
err.to_string().contains("model stream error")
|
err.to_string().contains("model stream error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trim conversation to fit within the context budget.
|
||||||
|
/// Returns the trimmed conversation messages (oldest dropped first).
|
||||||
|
pub fn trim_conversation(
|
||||||
|
context: &ContextState,
|
||||||
|
conversation: &[Message],
|
||||||
|
model: &str,
|
||||||
|
tokenizer: &CoreBPE,
|
||||||
|
) -> Vec<Message> {
|
||||||
|
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||||
|
let max_tokens = context_budget_tokens(model);
|
||||||
|
|
||||||
|
let identity_cost = count(&context.system_prompt)
|
||||||
|
+ context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
|
||||||
|
let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
|
||||||
|
let reserve = max_tokens / 4;
|
||||||
|
let available = max_tokens
|
||||||
|
.saturating_sub(identity_cost)
|
||||||
|
.saturating_sub(journal_cost)
|
||||||
|
.saturating_sub(reserve);
|
||||||
|
|
||||||
|
let msg_costs: Vec<usize> = conversation.iter()
|
||||||
|
.map(|m| msg_token_count(tokenizer, m)).collect();
|
||||||
|
let total: usize = msg_costs.iter().sum();
|
||||||
|
|
||||||
|
let mut skip = 0;
|
||||||
|
let mut trimmed = total;
|
||||||
|
while trimmed > available && skip < conversation.len() {
|
||||||
|
trimmed -= msg_costs[skip];
|
||||||
|
skip += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk forward to user message boundary
|
||||||
|
while skip < conversation.len() && conversation[skip].role != Role::User {
|
||||||
|
skip += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation[skip..].to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
|
fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
|
||||||
msg.timestamp
|
msg.timestamp
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue