refactor: extract context building into context.rs

Move context window building functions from agent.rs to context.rs:
- build_context_window, plan_context, render_journal_text, assemble_context
- truncate_at_section, find_journal_cutoff, msg_token_count_fn
- model_context_window, context_budget_tokens
- is_context_overflow, is_stream_error, msg_token_count

Also moved ContextPlan struct to types.rs.

Net: -307 lines in agent.rs, +232 in context.rs, +62 in types.rs
This commit is contained in:
Kent Overstreet 2026-03-21 15:42:44 -04:00
parent e79f17c2c8
commit d04d41e993

232
poc-agent/src/context.rs Normal file
View file

@ -0,0 +1,232 @@
// context.rs — Context window building and management
//
// Pure functions for building the agent's context window from journal
// entries and conversation messages. No mutable state.
use crate::journal;
use crate::types::{ContextPlan, ContextState, Message};
use chrono::{DateTime, Utc};
use tiktoken_rs::CoreBPE;
/// Build a context window from conversation messages + journal entries.
///
/// Returns the assembled message list plus the rendered journal text
/// (the latter so callers can log/inspect what was included).
pub fn build_context_window(
    context: &ContextState,
    conversation: &[Message],
    model: &str,
    tokenizer: &CoreBPE,
) -> (Vec<Message>, String) {
    let path = journal::default_journal_path();
    let entries = journal::parse_journal(&path);
    crate::dbglog!("[ctx] {} journal entries from {}", entries.len(), path.display());

    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
    let system_prompt = context.system_prompt.clone();
    let budget = context_budget_tokens(model);

    // The rendered memory/context message may occupy at most half the budget;
    // anything beyond that is truncated at a section boundary.
    let memory_cap = budget / 2;
    let mut memory = context.render_context_message();
    let memory_tokens = count(&memory);
    if memory_tokens > memory_cap {
        crate::dbglog!("[ctx] memory too large: {} tokens > {} cap, truncating", memory_tokens, memory_cap);
        memory = truncate_at_section(&memory, memory_cap, &count);
    }

    // Conversation messages older than the newest journal entry are covered
    // by the journal; only the tail after the cutoff is kept verbatim.
    let recent_start = find_journal_cutoff(conversation, entries.last());
    let recent = &conversation[recent_start..];
    let plan = plan_context(&system_prompt, &memory, recent, &entries, model, &count);
    let journal_text = render_journal_text(&entries, &plan);
    let messages = assemble_context(system_prompt, memory, &journal_text, recent, &plan);
    (messages, journal_text)
}
/// Total context window (in tokens) for a model, keyed by case-insensitive
/// substring match on the model name; unknown models fall back to 128k.
pub fn model_context_window(model: &str) -> usize {
    // Ordered marker table: first match wins ("opus"/"sonnet" before "qwen").
    const WINDOWS: &[(&str, usize)] = &[
        ("opus", 200_000),
        ("sonnet", 200_000),
        ("qwen", 131_072),
    ];
    let name = model.to_lowercase();
    WINDOWS
        .iter()
        .find(|(marker, _)| name.contains(marker))
        .map(|&(_, window)| window)
        .unwrap_or(128_000)
}
/// Portion of the model's window (60%) that the prompt is allowed to occupy.
fn context_budget_tokens(model: &str) -> usize {
    const PROMPT_SHARE_PCT: usize = 60;
    model_context_window(model) * PROMPT_SHARE_PCT / 100
}
/// Plan how journal entries and the recent conversation share the token
/// budget. Pure planning — nothing is rendered or mutated here.
///
/// Budget layout: the system prompt, the memory/context message, and a 25%
/// reserve are taken off the top of `context_budget_tokens(model)`; the
/// remainder (`available`) is split between the journal and conversation.
fn plan_context(
    system_prompt: &str,
    context_message: &str,
    recent: &[Message],
    entries: &[journal::JournalEntry],
    model: &str,
    count: &dyn Fn(&str) -> usize,
) -> ContextPlan {
    let max_tokens = context_budget_tokens(model);
    let identity_cost = count(system_prompt);
    let memory_cost = count(context_message);
    // Keep a quarter of the budget free (reply headroom).
    let reserve = max_tokens / 4;
    let available = max_tokens.saturating_sub(identity_cost).saturating_sub(memory_cost).saturating_sub(reserve);
    let conv_costs: Vec<usize> = recent.iter().map(|m| msg_token_count_fn(m, count)).collect();
    let total_conv: usize = conv_costs.iter().sum();
    // The journal is guaranteed at least 15% of `available`, even if the
    // conversation alone would overflow; otherwise it gets the leftover.
    let journal_min = available * 15 / 100;
    let journal_budget = available.saturating_sub(total_conv).max(journal_min);
    // 70% of the journal budget goes to full entry bodies, the remainder to
    // one-line headers of older entries.
    let full_budget = journal_budget * 70 / 100;
    let header_budget = journal_budget.saturating_sub(full_budget);
    // Walk entries newest-first, admitting full bodies until the budget runs
    // out. The +10 is a rough per-entry overhead allowance (heading/blank
    // lines added at render time).
    let mut full_used = 0;
    let mut n_full = 0;
    for entry in entries.iter().rev() {
        let cost = count(&entry.content) + 10;
        if full_used + cost > full_budget { break; }
        full_used += cost;
        n_full += 1;
    }
    let full_start = entries.len().saturating_sub(n_full);
    // Entries older than `full_start` are represented by their first
    // non-blank line only, again newest-first under `header_budget`.
    let mut header_used = 0;
    let mut n_headers = 0;
    for entry in entries[..full_start].iter().rev() {
        let first_line = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
        let cost = count(first_line) + 10;
        if header_used + cost > header_budget { break; }
        header_used += cost;
        n_headers += 1;
    }
    let header_start = full_start.saturating_sub(n_headers);
    let journal_used = full_used + header_used;
    // Drop the oldest conversation messages until journal + conversation
    // fit inside `available`.
    let mut conv_trim = 0;
    let mut trimmed_conv = total_conv;
    while trimmed_conv + journal_used > available && conv_trim < recent.len() {
        trimmed_conv -= conv_costs[conv_trim];
        conv_trim += 1;
    }
    // Advance the trim point to the next user message so the retained tail
    // starts on a user turn. Note this runs even when nothing was trimmed,
    // so a leading non-user message is always skipped.
    while conv_trim < recent.len() && recent[conv_trim].role != crate::types::Role::User {
        conv_trim += 1;
    }
    // Underscore-prefixed fields appear to be diagnostics-only; note that
    // `_conv_tokens` does not account for messages skipped by the user-turn
    // alignment loop above.
    ContextPlan {
        header_start, full_start, entry_count: entries.len(), conv_trim,
        _conv_count: recent.len(), _full_tokens: full_used, _header_tokens: header_used,
        _conv_tokens: trimmed_conv, _available: available,
    }
}
/// Render the journal per the plan: older entries as one-line headers,
/// recent entries ([`ContextPlan::full_start`]..) with full bodies.
/// Returns an empty string when the plan includes no entries.
fn render_journal_text(entries: &[journal::JournalEntry], plan: &ContextPlan) -> String {
    if plan.header_start >= plan.entry_count {
        return String::new();
    }
    let mut text = String::from("[Earlier in this conversation — from your journal]\n\n");
    // Header-only section: timestamp plus the entry's first non-blank line.
    for entry in &entries[plan.header_start..plan.full_start] {
        let first_line = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
        // Bug fix: timestamp and summary were concatenated with no separator
        // ("## 2026-03-21T15:42Summary"); insert the missing space.
        text.push_str(&format!("## {} {}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), first_line));
    }
    // Visual divider between the header section and the full-body section,
    // only when both are non-empty.
    let n_headers = plan.full_start - plan.header_start;
    let n_full = plan.entry_count - plan.full_start;
    if n_headers > 0 && n_full > 0 { text.push_str("\n---\n\n"); }
    // Full section: complete entry bodies, oldest to newest.
    for entry in &entries[plan.full_start..] {
        text.push_str(&format!("## {}\n\n{}\n\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content));
    }
    text
}
/// Assemble the final message list: system prompt, optional memory message,
/// journal text (or a rebuild notice when there is none), then the retained
/// conversation tail from `plan.conv_trim` onward.
fn assemble_context(
    system_prompt: String,
    context_message: String,
    journal_text: &str,
    recent: &[Message],
    plan: &ContextPlan,
) -> Vec<Message> {
    let tail = &recent[plan.conv_trim..];
    let mut out = Vec::with_capacity(tail.len() + 3);
    out.push(Message::system(system_prompt));
    if !context_message.is_empty() {
        out.push(Message::user(context_message));
    }
    if !journal_text.is_empty() {
        out.push(Message::user(journal_text.to_string()));
    } else if !tail.is_empty() {
        // No journal to bridge the gap — tell the model its context was
        // rebuilt so the conversation tail isn't mistaken for the whole.
        out.push(Message::user(
            "Your context was just rebuilt. Memory files have been \
             reloaded. Your recent conversation continues below. \
             Earlier context is in your journal and memory files."
                .to_string(),
        ));
    }
    out.extend(tail.iter().cloned());
    out
}
/// Truncate `text` to at most `max_tokens`, cutting at the last section
/// boundary ("---" rule or "## " heading) that still fits.
///
/// If not even the first section fits, falls back to a rough byte cut of
/// `max_tokens * 3` (≈3 bytes/token), clamped to a UTF-8 char boundary.
fn truncate_at_section(text: &str, max_tokens: usize, count: &dyn Fn(&str) -> usize) -> String {
    // Byte offsets of candidate cut points, accumulated in a single pass.
    // (The original re-summed all preceding line lengths per boundary,
    // which was O(n^2) in the number of lines.)
    // NOTE: offsets assume "\n" line endings, as the original did.
    let mut boundaries = vec![0usize];
    let mut offset = 0usize;
    for line in text.lines() {
        if line.trim() == "---" || line.starts_with("## ") {
            boundaries.push(offset);
        }
        offset += line.len() + 1; // +1 for the '\n' consumed by lines()
    }
    boundaries.push(text.len());
    // Keep the largest prefix that fits; boundaries ascend, so stop at the
    // first overflow.
    let mut best = 0;
    for &end in &boundaries[1..] {
        let slice = &text[..end];
        if count(slice) <= max_tokens { best = end; }
        else { break; }
    }
    if best == 0 {
        // Bug fix: a raw byte index can land inside a multi-byte UTF-8
        // character and make `&text[..best]` panic — back up to a valid
        // char boundary first.
        let mut cut = text.len().min(max_tokens * 3);
        while cut > 0 && !text.is_char_boundary(cut) {
            cut -= 1;
        }
        best = cut;
    }
    let truncated = &text[..best];
    crate::dbglog!("[ctx] truncated memory from {} to {} chars ({} tokens)", text.len(), truncated.len(), count(truncated));
    truncated.to_string()
}
/// Index into `conversation` where the "recent" tail begins: the first
/// message strictly newer than the newest journal entry, backed up so the
/// tail starts on a user turn. Returns 0 when there is no journal entry.
fn find_journal_cutoff(conversation: &[Message], newest_entry: Option<&journal::JournalEntry>) -> usize {
    let cutoff = match newest_entry {
        Some(entry) => entry.timestamp,
        None => return 0,
    };
    // First message with a parseable timestamp after the cutoff; if none,
    // the whole conversation is covered by the journal.
    let mut split = conversation
        .iter()
        .position(|msg| parse_msg_timestamp(msg).map_or(false, |ts| ts > cutoff))
        .unwrap_or(conversation.len());
    // Walk back to a user message so the tail begins on a user turn.
    while split > 0 && split < conversation.len() && conversation[split].role != crate::types::Role::User {
        split -= 1;
    }
    split
}
fn msg_token_count_fn(msg: &Message, count: &dyn Fn(&str) -> usize) -> usize {
let content = msg.content.as_ref().map_or(0, |c| match c {
crate::types::MessageContent::Text(s) => count(s),
crate::types::MessageContent::Parts(parts) => parts.iter().map(|p| match p {
crate::types::ContentPart::Text { text } => count(text),
crate::types::ContentPart::ImageUrl { .. } => 85,
}).sum(),
});
let tools = msg.tool_calls.as_ref().map_or(0, |calls| calls.iter().map(|c| count(&c.function.arguments) + count(&c.function.name)).sum());
content + tools
}
/// Parse a message's RFC 3339 timestamp into UTC; `None` when the message
/// has no timestamp or it fails to parse.
fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
    let raw = msg.timestamp.as_ref()?;
    let parsed = DateTime::parse_from_rfc3339(raw).ok()?;
    Some(parsed.with_timezone(&Utc))
}
/// Heuristic: does this error message indicate the prompt exceeded the
/// model's context window? Matched case-insensitively against phrases
/// various providers use.
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
    const MARKERS: &[&str] = &[
        "context length",
        "token limit",
        "too many tokens",
        "maximum context",
        "prompt is too long",
        "request too large",
        "input validation error",
        "content length limit",
    ];
    let text = err.to_string().to_lowercase();
    MARKERS.iter().any(|marker| text.contains(marker))
        // Generic HTTP 400 that mentions tokens also counts.
        || (text.contains("400") && text.contains("tokens"))
}
/// True when the error text indicates a mid-stream model failure.
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    let rendered = err.to_string();
    rendered.contains("model stream error")
}
/// Token count for a single message using the given tokenizer.
/// Convenience wrapper over [`msg_token_count_fn`].
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
    let count = |text: &str| tokenizer.encode_with_special_tokens(text).len();
    msg_token_count_fn(msg, &count)
}