diff --git a/poc-agent/src/context.rs b/poc-agent/src/context.rs
new file mode 100644
index 0000000..3c7e1c9
--- /dev/null
+++ b/poc-agent/src/context.rs
@@ -0,0 +1,232 @@
+// context.rs — Context window building and management
+//
+// Pure functions for building the agent's context window from journal
+// entries and conversation messages. No mutable state.
+
+use crate::journal;
+use crate::types::{ContextPlan, ContextState, Message};
+use chrono::{DateTime, Utc};
+use tiktoken_rs::CoreBPE;
+
+/// Build a context window from conversation messages + journal entries.
+pub fn build_context_window(
+    context: &ContextState,
+    conversation: &[Message],
+    model: &str,
+    tokenizer: &CoreBPE,
+) -> (Vec<Message>, String) {
+    let journal_path = journal::default_journal_path();
+    let all_entries = journal::parse_journal(&journal_path);
+    crate::dbglog!("[ctx] {} journal entries from {}", all_entries.len(), journal_path.display());
+    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
+
+    let system_prompt = context.system_prompt.clone();
+    let context_message = context.render_context_message();
+
+    let max_tokens = context_budget_tokens(model);
+    let memory_cap = max_tokens / 2;
+    let memory_tokens = count(&context_message);
+    let context_message = if memory_tokens > memory_cap {
+        crate::dbglog!("[ctx] memory too large: {} tokens > {} cap, truncating", memory_tokens, memory_cap);
+        truncate_at_section(&context_message, memory_cap, &count)
+    } else {
+        context_message
+    };
+
+    let recent_start = find_journal_cutoff(conversation, all_entries.last());
+    let recent = &conversation[recent_start..];
+
+    let plan = plan_context(&system_prompt, &context_message, recent, &all_entries, model, &count);
+    let journal_text = render_journal_text(&all_entries, &plan);
+
+    let messages = assemble_context(system_prompt, context_message, &journal_text, recent, &plan);
+    (messages, journal_text)
+}
+
+pub fn model_context_window(model: &str) -> usize {
+    let m = model.to_lowercase();
+    if m.contains("opus") || m.contains("sonnet") { 200_000 }
+    else if m.contains("qwen") { 131_072 }
+    else { 128_000 }
+}
+
+/// Working budget for a single request: 60% of the model's context window.
+fn context_budget_tokens(model: &str) -> usize {
+    model_context_window(model) * 60 / 100
+}
+
+/// Decide how many journal entries (full text vs. one-line headers) and how
+/// much recent conversation fit within the token budget.
+fn plan_context(
+    system_prompt: &str,
+    context_message: &str,
+    recent: &[Message],
+    entries: &[journal::JournalEntry],
+    model: &str,
+    count: &dyn Fn(&str) -> usize,
+) -> ContextPlan {
+    let max_tokens = context_budget_tokens(model);
+    let identity_cost = count(system_prompt);
+    let memory_cost = count(context_message);
+    let reserve = max_tokens / 4;
+    let available = max_tokens.saturating_sub(identity_cost).saturating_sub(memory_cost).saturating_sub(reserve);
+
+    let conv_costs: Vec<usize> = recent.iter().map(|m| msg_token_count_fn(m, count)).collect();
+    let total_conv: usize = conv_costs.iter().sum();
+
+    let journal_min = available * 15 / 100;
+    let journal_budget = available.saturating_sub(total_conv).max(journal_min);
+    let full_budget = journal_budget * 70 / 100;
+    let header_budget = journal_budget.saturating_sub(full_budget);
+
+    let mut full_used = 0;
+    let mut n_full = 0;
+    for entry in entries.iter().rev() {
+        let cost = count(&entry.content) + 10;
+        if full_used + cost > full_budget { break; }
+        full_used += cost;
+        n_full += 1;
+    }
+    let full_start = entries.len().saturating_sub(n_full);
+
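+    // Older entries are represented by their first non-empty line only,
+    // spending whatever header budget remains after the full entries.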
+    let mut header_used = 0;
+    let mut n_headers = 0;
+    for entry in entries[..full_start].iter().rev() {
+        let first_line = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
+        let cost = count(first_line) + 10;
+        if header_used + cost > header_budget { break; }
+        header_used += cost;
+        n_headers += 1;
+    }
+    let header_start = full_start.saturating_sub(n_headers);
+
+    let journal_used = full_used + header_used;
+    let mut conv_trim = 0;
+    let mut trimmed_conv = total_conv;
+    while trimmed_conv + journal_used > available && conv_trim < recent.len() {
+        trimmed_conv -= conv_costs[conv_trim];
+        conv_trim += 1;
+    }
+    while conv_trim < recent.len() && recent[conv_trim].role != crate::types::Role::User {
+        conv_trim += 1;
+    }
+
+    ContextPlan {
+        header_start, full_start, entry_count: entries.len(), conv_trim,
+        _conv_count: recent.len(), _full_tokens: full_used, _header_tokens: header_used,
+        _conv_tokens: trimmed_conv, _available: available,
+    }
+}
+
+fn render_journal_text(entries: &[journal::JournalEntry], plan: &ContextPlan) -> String {
+    if plan.header_start >= plan.entry_count { return String::new(); }
+
+    let mut text = String::from("[Earlier in this conversation — from your journal]\n\n");
+
+    for entry in &entries[plan.header_start..plan.full_start] {
+        let first_line = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
+        text.push_str(&format!("## {} — {}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), first_line));
+    }
+
+    let n_headers = plan.full_start - plan.header_start;
+    let n_full = plan.entry_count - plan.full_start;
+    if n_headers > 0 && n_full > 0 { text.push_str("\n---\n\n"); }
+
+    for entry in &entries[plan.full_start..] {
+        text.push_str(&format!("## {}\n\n{}\n\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content));
+    }
+    text
+}
+
+fn assemble_context(
+    system_prompt: String,
+    context_message: String,
+    journal_text: &str,
+    recent: &[Message],
+    plan: &ContextPlan,
+) -> Vec<Message> {
+    let mut messages = vec![Message::system(system_prompt)];
+    if !context_message.is_empty() { messages.push(Message::user(context_message)); }
+
+    let final_recent = &recent[plan.conv_trim..];
+
+    if !journal_text.is_empty() {
+        messages.push(Message::user(journal_text.to_string()));
+    } else if !final_recent.is_empty() {
+        messages.push(Message::user(
+            "Your context was just rebuilt. Memory files have been \
+             reloaded. Your recent conversation continues below. \
+             Earlier context is in your journal and memory files."
+                .to_string(),
+        ));
+    }
+    messages.extend(final_recent.iter().cloned());
+    messages
+}
+
+fn truncate_at_section(text: &str, max_tokens: usize, count: &dyn Fn(&str) -> usize) -> String {
+    let mut boundaries = vec![0usize];
+    for (i, line) in text.lines().enumerate() {
+        if line.trim() == "---" || line.starts_with("## ") {
+            let offset = text.lines().take(i).map(|l| l.len() + 1).sum::<usize>();
+            boundaries.push(offset);
+        }
+    }
+    boundaries.push(text.len());
+
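+    // Keep the longest prefix of whole sections that still fits the cap.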
+    let mut best = 0;
+    for &end in &boundaries[1..] {
+        let slice = &text[..end];
+        if count(slice) <= max_tokens { best = end; }
+        else { break; }
+    }
+    if best == 0 { best = text.len().min(max_tokens * 3); }
+
+    let truncated = &text[..best];
+    crate::dbglog!("[ctx] truncated memory from {} to {} chars ({} tokens)", text.len(), truncated.len(), count(truncated));
+    truncated.to_string()
+}
+
+/// Index of the first conversation message newer than the latest journal
+/// entry, nudged back so the kept slice starts on a user message.
+fn find_journal_cutoff(conversation: &[Message], newest_entry: Option<&journal::JournalEntry>) -> usize {
+    let cutoff = match newest_entry { Some(entry) => entry.timestamp, None => return 0 };
+
+    let mut split = conversation.len();
+    for (i, msg) in conversation.iter().enumerate() {
+        if let Some(ts) = parse_msg_timestamp(msg) {
+            if ts > cutoff { split = i; break; }
+        }
+    }
+    while split > 0 && split < conversation.len() && conversation[split].role != crate::types::Role::User {
+        split -= 1;
+    }
+    split
+}
+
+fn msg_token_count_fn(msg: &Message, count: &dyn Fn(&str) -> usize) -> usize {
+    let content = msg.content.as_ref().map_or(0, |c| match c {
+        crate::types::MessageContent::Text(s) => count(s),
+        crate::types::MessageContent::Parts(parts) => parts.iter().map(|p| match p {
+            crate::types::ContentPart::Text { text } => count(text),
+            crate::types::ContentPart::ImageUrl { .. } => 85,
+        }).sum(),
+    });
+    let tools = msg.tool_calls.as_ref().map_or(0, |calls| calls.iter().map(|c| count(&c.function.arguments) + count(&c.function.name)).sum());
+    content + tools
+}
+
+fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
+    msg.timestamp.as_ref().and_then(|ts| DateTime::parse_from_rfc3339(ts).ok()).map(|dt| dt.with_timezone(&Utc))
+}
+
+/// Heuristic match on provider error messages that indicate the prompt
+/// exceeded the model's context window.
+pub fn is_context_overflow(err: &anyhow::Error) -> bool {
+    let msg = err.to_string().to_lowercase();
+    msg.contains("context length") || msg.contains("token limit") || msg.contains("too many tokens")
+        || msg.contains("maximum context") || msg.contains("prompt is too long") || msg.contains("request too large")
+        || msg.contains("input validation error") || msg.contains("content length limit")
+        || (msg.contains("400") && msg.contains("tokens"))
+}
+
+pub fn is_stream_error(err: &anyhow::Error) -> bool {
+    err.to_string().contains("model stream error")
+}
+
+pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
+    msg_token_count_fn(msg, &|s| tokenizer.encode_with_special_tokens(s).len())
+}
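
For a sense of scale: with an opus or sonnet model, the arithmetic above yields a 120,000-token working budget (60% of the 200,000-token window), a 60,000-token cap on the rendered memory message, and a 30,000-token reserve; whatever is left after the system prompt and memory is shared between recent conversation and the journal, with the journal guaranteed at least 15% and its share split 70/30 between full entries and one-line headers. Below is a minimal sketch of how the entry point might be called from the agent loop; the `agent_state` fields, the model string, and the use of `cl100k_base` as the tokenizer are assumptions for illustration, not part of this change:

    // Hypothetical caller (field names assumed for illustration).
    let tokenizer = tiktoken_rs::cl100k_base()?;
    let (messages, journal_text) = context::build_context_window(
        &agent_state.context,
        &agent_state.conversation,
        "claude-sonnet-4",
        &tokenizer,
    );
    // `messages` goes to the model; `journal_text` is the journal slice
    // that was inlined (empty when nothing from the journal fit).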