// context.rs — Context window management // // Token counting, conversation trimming, and error classification. // Journal entries are loaded from the memory graph store, not from // a flat file — the parse functions are gone. use std::sync::{Arc, RwLock}; use crate::agent::api::types::*; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use tiktoken_rs::CoreBPE; use crate::agent::tools::working_stack; /// A section of the context window, possibly with children. #[derive(Debug, Clone)] pub struct ContextSection { pub name: String, pub tokens: usize, pub content: String, pub children: Vec, } /// Shared, live context state — agent writes, TUI reads for the debug screen. pub type SharedContextState = Arc>>; /// Create a new shared context state. pub fn shared_context_state() -> SharedContextState { Arc::new(RwLock::new(Vec::new())) } /// A single journal entry with its timestamp and content. #[derive(Debug, Clone)] pub struct JournalEntry { pub timestamp: DateTime, pub content: String, } /// Context window size in tokens (from config). pub fn context_window() -> usize { crate::config::get().api_context_window } /// Context budget in tokens: 80% of the model's context window. /// The remaining 20% is reserved for model output. fn context_budget_tokens() -> usize { context_window() * 80 / 100 } /// Dedup and trim conversation entries to fit within the context budget. /// /// 1. Dedup: if the same memory key appears multiple times, keep only /// the latest render (drop the earlier Memory entry and its /// corresponding assistant tool_call message). /// 2. Trim: drop oldest entries until the conversation fits, snapping /// to user message boundaries. pub fn trim_entries( context: &ContextState, entries: &[ConversationEntry], tokenizer: &CoreBPE, ) -> Vec { // --- Phase 1: dedup memory entries by key (keep last) --- let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new(); let mut drop_indices: std::collections::HashSet = std::collections::HashSet::new(); for (i, entry) in entries.iter().enumerate() { if let ConversationEntry::Memory { key, .. } = entry { if let Some(prev) = seen_keys.insert(key.as_str(), i) { drop_indices.insert(prev); } } } let deduped: Vec = entries.iter().enumerate() .filter(|(i, _)| !drop_indices.contains(i)) .map(|(_, e)| e.clone()) .collect(); // --- Phase 2: trim to fit context budget --- // Everything in the context window is a message. Count them all, // trim entries until the total fits. let max_tokens = context_budget_tokens(); let count_msg = |m: &Message| msg_token_count(tokenizer, m); let fixed_cost = count_msg(&Message::system(&context.system_prompt)) + count_msg(&Message::user(context.render_context_message())) + count_msg(&Message::user(render_journal(&context.journal))); let msg_costs: Vec = deduped.iter() .map(|e| count_msg(e.api_message())).collect(); let entry_total: usize = msg_costs.iter().sum(); let total: usize = fixed_cost + entry_total; let mem_tokens: usize = deduped.iter().zip(&msg_costs) .filter(|(e, _)| e.is_memory()) .map(|(_, &c)| c).sum(); let conv_tokens: usize = entry_total - mem_tokens; dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}", max_tokens, fixed_cost, mem_tokens, conv_tokens, total, deduped.len()); // Phase 2a: evict all DMN entries first — they're ephemeral let mut drop = vec![false; deduped.len()]; let mut trimmed = total; let mut cur_mem = mem_tokens; for i in 0..deduped.len() { if deduped[i].is_dmn() { drop[i] = true; trimmed -= msg_costs[i]; } } // Phase 2b: if memories > 50% of entries, evict oldest memories if cur_mem > conv_tokens && trimmed > max_tokens { for i in 0..deduped.len() { if drop[i] { continue; } if !deduped[i].is_memory() { continue; } if cur_mem <= conv_tokens { break; } if trimmed <= max_tokens { break; } drop[i] = true; trimmed -= msg_costs[i]; cur_mem -= msg_costs[i]; } } // Phase 2b: drop oldest entries until under budget for i in 0..deduped.len() { if trimmed <= max_tokens { break; } if drop[i] { continue; } drop[i] = true; trimmed -= msg_costs[i]; } // Walk forward to include complete conversation boundaries let mut result: Vec = Vec::new(); let mut skipping = true; for (i, entry) in deduped.into_iter().enumerate() { if skipping { if drop[i] { continue; } // Snap to user message boundary if entry.message().role != Role::User { continue; } skipping = false; } result.push(entry); } dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed); result } /// Count the token footprint of a message using BPE tokenization. pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize { let count = |s: &str| tokenizer.encode_with_special_tokens(s).len(); let content = msg.content.as_ref().map_or(0, |c| match c { MessageContent::Text(s) => count(s), MessageContent::Parts(parts) => parts.iter() .map(|p| match p { ContentPart::Text { text } => count(text), ContentPart::ImageUrl { .. } => 85, }) .sum(), }); let tools = msg.tool_calls.as_ref().map_or(0, |calls| { calls.iter() .map(|c| count(&c.function.arguments) + count(&c.function.name)) .sum() }); content + tools } /// Detect context window overflow errors from the API. pub fn is_context_overflow(err: &anyhow::Error) -> bool { let msg = err.to_string().to_lowercase(); msg.contains("context length") || msg.contains("token limit") || msg.contains("too many tokens") || msg.contains("maximum context") || msg.contains("prompt is too long") || msg.contains("request too large") || msg.contains("input validation error") || msg.contains("content length limit") || (msg.contains("400") && msg.contains("tokens")) } /// Detect model/provider errors delivered inside the SSE stream. pub fn is_stream_error(err: &anyhow::Error) -> bool { err.to_string().contains("model stream error") } // --- Context state types --- /// Conversation entry — either a regular message or memory content. /// Memory entries preserve the original message for KV cache round-tripping. #[derive(Debug, Clone, PartialEq)] pub enum ConversationEntry { Message(Message), Memory { key: String, message: Message }, /// DMN heartbeat/autonomous prompt — evicted aggressively during compaction. Dmn(Message), } // Custom serde: serialize Memory with a "memory_key" field added to the message, // plain messages serialize as-is. This keeps the conversation log readable. impl Serialize for ConversationEntry { fn serialize(&self, s: S) -> Result { use serde::ser::SerializeMap; match self { Self::Message(m) | Self::Dmn(m) => m.serialize(s), Self::Memory { key, message } => { let json = serde_json::to_value(message).map_err(serde::ser::Error::custom)?; let mut map = s.serialize_map(None)?; if let serde_json::Value::Object(obj) = json { for (k, v) in obj { map.serialize_entry(&k, &v)?; } } map.serialize_entry("memory_key", key)?; map.end() } } } } impl<'de> Deserialize<'de> for ConversationEntry { fn deserialize>(d: D) -> Result { let mut json: serde_json::Value = serde_json::Value::deserialize(d)?; if let Some(key) = json.as_object_mut().and_then(|o| o.remove("memory_key")) { let key = key.as_str().unwrap_or("").to_string(); let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?; Ok(Self::Memory { key, message }) } else { let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?; Ok(Self::Message(message)) } } } impl ConversationEntry { /// Get the API message for sending to the model. pub fn api_message(&self) -> &Message { match self { Self::Message(m) | Self::Dmn(m) => m, Self::Memory { message, .. } => message, } } pub fn is_memory(&self) -> bool { matches!(self, Self::Memory { .. }) } pub fn is_dmn(&self) -> bool { matches!(self, Self::Dmn(_)) } /// Get a reference to the inner message. pub fn message(&self) -> &Message { match self { Self::Message(m) | Self::Dmn(m) => m, Self::Memory { message, .. } => message, } } /// Get a mutable reference to the inner message. pub fn message_mut(&mut self) -> &mut Message { match self { Self::Message(m) | Self::Dmn(m) => m, Self::Memory { message, .. } => message, } } } #[derive(Clone)] pub struct ContextState { pub system_prompt: String, pub personality: Vec<(String, String)>, pub journal: Vec, pub working_stack: Vec, /// Conversation entries — messages and memory, interleaved in order. /// Does NOT include system prompt, personality, or journal. pub entries: Vec, } pub fn render_journal(entries: &[JournalEntry]) -> String { if entries.is_empty() { return String::new(); } let mut text = String::from("[Earlier — from your journal]\n\n"); for entry in entries { use std::fmt::Write; writeln!(text, "## {}\n{}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content).ok(); } text } impl ContextState { pub fn render_context_message(&self) -> String { let mut parts: Vec = self.personality.iter() .map(|(name, content)| format!("## {}\n\n{}", name, content)) .collect(); let instructions = std::fs::read_to_string(working_stack::instructions_path()).unwrap_or_default(); let mut stack_section = instructions; if self.working_stack.is_empty() { stack_section.push_str("\n## Current stack\n\n(empty)\n"); } else { stack_section.push_str("\n## Current stack\n\n"); for (i, item) in self.working_stack.iter().enumerate() { if i == self.working_stack.len() - 1 { stack_section.push_str(&format!("→ {}\n", item)); } else { stack_section.push_str(&format!(" [{}] {}\n", i, item)); } } } parts.push(stack_section); parts.join("\n\n---\n\n") } } /// Total tokens used across all context sections. pub fn sections_used(sections: &[ContextSection]) -> usize { sections.iter().map(|s| s.tokens).sum() } /// Budget status string derived from context sections. pub fn sections_budget_string(sections: &[ContextSection]) -> String { let window = context_window(); if window == 0 { return String::new(); } let used: usize = sections.iter().map(|s| s.tokens).sum(); let free = window.saturating_sub(used); let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) }; let parts: Vec = sections.iter() .map(|s| { // Short label from section name let label = match s.name.as_str() { n if n.starts_with("System") => "sys", n if n.starts_with("Personality") => "id", n if n.starts_with("Journal") => "jnl", n if n.starts_with("Working") => "stack", n if n.starts_with("Memory") => "mem", n if n.starts_with("Conversation") => "conv", _ => return String::new(), }; format!("{}:{}%", label, pct(s.tokens)) }) .filter(|s| !s.is_empty()) .collect(); format!("{} free:{}%", parts.join(" "), pct(free)) }