consciousness/src/thought/context.rs

105 lines
3.5 KiB
Rust
Raw Normal View History

// context.rs — Context window management
//
// Token counting, conversation trimming, and error classification.
// Journal entries are loaded from the memory graph store rather than
// parsed from a flat file, so this module contains no journal parsing code.
use crate::agent::types::*;
use chrono::{DateTime, Utc};
use tiktoken_rs::CoreBPE;
/// One journal record: a UTC timestamp paired with its text.
#[derive(Clone, Debug)]
pub struct JournalEntry {
    /// Timestamp associated with the entry, in UTC.
    pub timestamp: DateTime<Utc>,
    /// Body text of the entry.
    pub content: String,
}
/// Return the context window size, in tokens, to use for a model.
///
/// The model name is currently ignored. Every model gets the
/// `api_context_window` value from the global configuration.
pub fn model_context_window(_model: &str) -> usize {
    let config = crate::config::get();
    config.api_context_window
}
/// Token budget for a request: 60% of the context window reported by
/// `model_context_window`, rounded down by integer division.
fn context_budget_tokens(model: &str) -> usize {
    let window = model_context_window(model);
    window * 60 / 100
}
/// Trim a conversation so it fits within the context budget.
///
/// The space available for messages starts from `context_budget_tokens(model)`.
/// It then subtracts the tokens used by:
/// - the system prompt,
/// - the personality entries (their text only),
/// - the journal entries (their `content` only; timestamps are not counted),
/// - a reserve of one quarter of the budget.
///
/// Oldest messages are dropped first until the rest fits.
///
/// If any message had to be dropped, the cut point moves forward to the next
/// `Role::User` message, so the kept history starts on a user turn. If no user
/// message follows the cut, the result is empty. If nothing needed to be
/// dropped, the conversation is returned whole and is not realigned.
///
/// Returns the kept suffix of `conversation` as owned messages.
pub fn trim_conversation(
    context: &ContextState,
    conversation: &[Message],
    model: &str,
    tokenizer: &CoreBPE,
) -> Vec<Message> {
    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
    let max_tokens = context_budget_tokens(model);
    let identity_cost = count(&context.system_prompt)
        + context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
    let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
    // A quarter of the budget is held back.
    // NOTE(review): presumably headroom for the response; confirm.
    let reserve = max_tokens / 4;
    let available = max_tokens
        .saturating_sub(identity_cost)
        .saturating_sub(journal_cost)
        .saturating_sub(reserve);

    let msg_costs: Vec<usize> = conversation
        .iter()
        .map(|m| msg_token_count(tokenizer, m))
        .collect();
    let mut remaining: usize = msg_costs.iter().sum();

    // Drop the oldest messages until what is left fits in `available`.
    let mut skip = 0;
    while remaining > available && skip < conversation.len() {
        remaining -= msg_costs[skip];
        skip += 1;
    }

    // Realign only when something was actually dropped. A cut can land in the
    // middle of a turn (for instance between a tool call and its result).
    // An untrimmed conversation must not lose its leading messages just
    // because it does not open with a user message.
    if skip > 0 {
        skip = conversation[skip..]
            .iter()
            .position(|m| m.role == Role::User)
            .map_or(conversation.len(), |offset| skip + offset);
    }

    conversation[skip..].to_vec()
}
/// Estimate how many tokens a message contributes, using BPE tokenization.
///
/// The count includes:
/// - text content, whether a single string or text parts;
/// - a flat charge for each image part;
/// - the function name and argument text of every tool call.
///
/// No other fields of the message are counted.
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
    // Fixed charge per image part; images are not run through the tokenizer.
    const IMAGE_PART_TOKENS: usize = 85;
    let tokens_in = |text: &str| tokenizer.encode_with_special_tokens(text).len();

    let mut total = 0;

    if let Some(body) = &msg.content {
        total += match body {
            MessageContent::Text(text) => tokens_in(text),
            MessageContent::Parts(parts) => {
                let mut parts_total = 0;
                for part in parts {
                    parts_total += match part {
                        ContentPart::Text { text } => tokens_in(text),
                        ContentPart::ImageUrl { .. } => IMAGE_PART_TOKENS,
                    };
                }
                parts_total
            }
        };
    }

    if let Some(calls) = &msg.tool_calls {
        for call in calls {
            total += tokens_in(&call.function.arguments) + tokens_in(&call.function.name);
        }
    }

    total
}
/// Heuristically decide whether an API error means the request exceeded the
/// model's context window.
///
/// The error's `to_string()` text is lowercased and searched for known
/// provider phrasings. Text that mentions both "400" and "tokens" is also
/// treated as an overflow.
///
/// NOTE(review): `to_string()` on an `anyhow::Error` renders only the
/// outermost message. A phrase that appears only in a wrapped cause will not
/// match.
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
    const OVERFLOW_PHRASES: [&str; 8] = [
        "context length",
        "token limit",
        "too many tokens",
        "maximum context",
        "prompt is too long",
        "request too large",
        "input validation error",
        "content length limit",
    ];

    let lowered = err.to_string().to_lowercase();
    if OVERFLOW_PHRASES.iter().any(|&phrase| lowered.contains(phrase)) {
        return true;
    }
    lowered.contains("400") && lowered.contains("tokens")
}
/// Return `true` when the error text contains the "model stream error"
/// marker. That marker is used for model/provider errors delivered inside the
/// SSE stream.
///
/// The match is case-sensitive.
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    const STREAM_ERROR_MARKER: &str = "model stream error";
    let rendered = err.to_string();
    rendered.contains(STREAM_ERROR_MARKER)
}