refactor: extract context building into context.rs
Move context window building functions from agent.rs to context.rs: - build_context_window, plan_context, render_journal_text, assemble_context - truncate_at_section, find_journal_cutoff, msg_token_count_fn - model_context_window, context_budget_tokens - is_context_overflow, is_stream_error, msg_token_count Also moved ContextPlan struct to types.rs. Net: -307 lines in agent.rs, +232 in context.rs, +62 in types.rs
This commit is contained in:
parent
e79f17c2c8
commit
d04d41e993
1 changed file with 232 additions and 0 deletions
232
poc-agent/src/context.rs
Normal file
232
poc-agent/src/context.rs
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
// context.rs — Context window building and management
|
||||||
|
//
|
||||||
|
// Pure functions for building the agent's context window from journal
|
||||||
|
// entries and conversation messages. No mutable state.
|
||||||
|
|
||||||
|
use crate::journal;
|
||||||
|
use crate::types::{ContextPlan, ContextState, Message};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use tiktoken_rs::CoreBPE;
|
||||||
|
|
||||||
|
/// Build a context window from conversation messages + journal entries.
|
||||||
|
pub fn build_context_window(
|
||||||
|
context: &ContextState,
|
||||||
|
conversation: &[Message],
|
||||||
|
model: &str,
|
||||||
|
tokenizer: &CoreBPE,
|
||||||
|
) -> (Vec<Message>, String) {
|
||||||
|
let journal_path = journal::default_journal_path();
|
||||||
|
let all_entries = journal::parse_journal(&journal_path);
|
||||||
|
crate::dbglog!("[ctx] {} journal entries from {}", all_entries.len(), journal_path.display());
|
||||||
|
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||||
|
|
||||||
|
let system_prompt = context.system_prompt.clone();
|
||||||
|
let context_message = context.render_context_message();
|
||||||
|
|
||||||
|
let max_tokens = context_budget_tokens(model);
|
||||||
|
let memory_cap = max_tokens / 2;
|
||||||
|
let memory_tokens = count(&context_message);
|
||||||
|
let context_message = if memory_tokens > memory_cap {
|
||||||
|
crate::dbglog!("[ctx] memory too large: {} tokens > {} cap, truncating", memory_tokens, memory_cap);
|
||||||
|
truncate_at_section(&context_message, memory_cap, &count)
|
||||||
|
} else {
|
||||||
|
context_message
|
||||||
|
};
|
||||||
|
|
||||||
|
let recent_start = find_journal_cutoff(conversation, all_entries.last());
|
||||||
|
let recent = &conversation[recent_start..];
|
||||||
|
|
||||||
|
let plan = plan_context(&system_prompt, &context_message, recent, &all_entries, model, &count);
|
||||||
|
let journal_text = render_journal_text(&all_entries, &plan);
|
||||||
|
|
||||||
|
let messages = assemble_context(system_prompt, context_message, &journal_text, recent, &plan);
|
||||||
|
(messages, journal_text)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Total context window size (in tokens) for a given model name.
///
/// Matching is case-insensitive substring matching; unrecognized models
/// fall back to a conservative 128k window.
pub fn model_context_window(model: &str) -> usize {
    let name = model.to_lowercase();
    match () {
        _ if name.contains("opus") || name.contains("sonnet") => 200_000,
        _ if name.contains("qwen") => 131_072,
        _ => 128_000,
    }
}
|
||||||
|
|
||||||
|
fn context_budget_tokens(model: &str) -> usize {
|
||||||
|
model_context_window(model) * 60 / 100
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Plan how to fill the context budget: how many journal entries appear
/// in full vs. header-only, and how many leading conversation messages
/// to trim. Pure token arithmetic; no I/O.
///
/// Budget layout: system prompt + memory are fixed costs, 25% of the
/// budget is reserved as headroom, and the remainder is split between
/// conversation and journal. The journal gets at least 15% of the
/// remainder, of which 70% goes to full entries and the rest to
/// one-line headers.
fn plan_context(
    system_prompt: &str,
    context_message: &str,
    recent: &[Message],
    entries: &[journal::JournalEntry],
    model: &str,
    count: &dyn Fn(&str) -> usize,
) -> ContextPlan {
    let max_tokens = context_budget_tokens(model);
    let identity_cost = count(system_prompt);
    let memory_cost = count(context_message);
    // Hold back a quarter of the budget for the response / overhead.
    let reserve = max_tokens / 4;
    let available = max_tokens.saturating_sub(identity_cost).saturating_sub(memory_cost).saturating_sub(reserve);

    let conv_costs: Vec<usize> = recent.iter().map(|m| msg_token_count_fn(m, count)).collect();
    let total_conv: usize = conv_costs.iter().sum();

    // The journal is guaranteed a floor of 15% of `available`, even when
    // the conversation alone would consume everything.
    let journal_min = available * 15 / 100;
    let journal_budget = available.saturating_sub(total_conv).max(journal_min);
    let full_budget = journal_budget * 70 / 100;
    let header_budget = journal_budget.saturating_sub(full_budget);

    // Newest-to-oldest, include entries in full until the budget runs
    // out. +10 per entry approximates heading/formatting overhead.
    let mut full_used = 0;
    let mut n_full = 0;
    for entry in entries.iter().rev() {
        let cost = count(&entry.content) + 10;
        if full_used + cost > full_budget { break; }
        full_used += cost;
        n_full += 1;
    }
    let full_start = entries.len().saturating_sub(n_full);

    // Older entries are represented by their first non-empty line only.
    let mut header_used = 0;
    let mut n_headers = 0;
    for entry in entries[..full_start].iter().rev() {
        let first_line = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
        let cost = count(first_line) + 10;
        if header_used + cost > header_budget { break; }
        header_used += cost;
        n_headers += 1;
    }
    let header_start = full_start.saturating_sub(n_headers);

    // If journal + conversation still exceed `available`, drop the oldest
    // conversation messages first...
    let journal_used = full_used + header_used;
    let mut conv_trim = 0;
    let mut trimmed_conv = total_conv;
    while trimmed_conv + journal_used > available && conv_trim < recent.len() {
        trimmed_conv -= conv_costs[conv_trim];
        conv_trim += 1;
    }
    // ...then advance to the next user message so the surviving slice
    // starts on a user turn (no orphaned assistant/tool messages).
    // NOTE(review): this skip does not subtract from `trimmed_conv`, so
    // `_conv_tokens` can slightly overcount — confirm that is intended.
    while conv_trim < recent.len() && recent[conv_trim].role != crate::types::Role::User {
        conv_trim += 1;
    }

    ContextPlan {
        header_start, full_start, entry_count: entries.len(), conv_trim,
        _conv_count: recent.len(), _full_tokens: full_used, _header_tokens: header_used,
        _conv_tokens: trimmed_conv, _available: available,
    }
}
|
||||||
|
|
||||||
|
/// Render the planned journal slice as one text message: timestamped
/// one-line headers for older entries, then full bodies for the newest.
/// Returns an empty string when the plan includes no entries at all.
fn render_journal_text(entries: &[journal::JournalEntry], plan: &ContextPlan) -> String {
    if plan.header_start >= plan.entry_count {
        return String::new();
    }

    let mut out = String::from("[Earlier in this conversation — from your journal]\n\n");

    let headers = &entries[plan.header_start..plan.full_start];
    for entry in headers {
        let title = entry.content.lines().find(|l| !l.trim().is_empty()).unwrap_or("(empty)");
        out.push_str(&format!("## {} — {}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), title));
    }

    let full = &entries[plan.full_start..];
    // Visual divider only when both sections are present.
    if !headers.is_empty() && !full.is_empty() {
        out.push_str("\n---\n\n");
    }

    for entry in full {
        out.push_str(&format!("## {}\n\n{}\n\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content));
    }
    out
}
|
||||||
|
|
||||||
|
fn assemble_context(
|
||||||
|
system_prompt: String,
|
||||||
|
context_message: String,
|
||||||
|
journal_text: &str,
|
||||||
|
recent: &[Message],
|
||||||
|
plan: &ContextPlan,
|
||||||
|
) -> Vec<Message> {
|
||||||
|
let mut messages = vec![Message::system(system_prompt)];
|
||||||
|
if !context_message.is_empty() { messages.push(Message::user(context_message)); }
|
||||||
|
|
||||||
|
let final_recent = &recent[plan.conv_trim..];
|
||||||
|
|
||||||
|
if !journal_text.is_empty() {
|
||||||
|
messages.push(Message::user(journal_text.to_string()));
|
||||||
|
} else if !final_recent.is_empty() {
|
||||||
|
messages.push(Message::user(
|
||||||
|
"Your context was just rebuilt. Memory files have been \
|
||||||
|
reloaded. Your recent conversation continues below. \
|
||||||
|
Earlier context is in your journal and memory files."
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
messages.extend(final_recent.iter().cloned());
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_at_section(text: &str, max_tokens: usize, count: &dyn Fn(&str) -> usize) -> String {
|
||||||
|
let mut boundaries = vec![0usize];
|
||||||
|
for (i, line) in text.lines().enumerate() {
|
||||||
|
if line.trim() == "---" || line.starts_with("## ") {
|
||||||
|
let offset = text.lines().take(i).map(|l| l.len() + 1).sum::<usize>();
|
||||||
|
boundaries.push(offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
boundaries.push(text.len());
|
||||||
|
|
||||||
|
let mut best = 0;
|
||||||
|
for &end in &boundaries[1..] {
|
||||||
|
let slice = &text[..end];
|
||||||
|
if count(slice) <= max_tokens { best = end; }
|
||||||
|
else { break; }
|
||||||
|
}
|
||||||
|
if best == 0 { best = text.len().min(max_tokens * 3); }
|
||||||
|
|
||||||
|
let truncated = &text[..best];
|
||||||
|
crate::dbglog!("[ctx] truncated memory from {} to {} chars ({} tokens)", text.len(), truncated.len(), count(truncated));
|
||||||
|
truncated.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_journal_cutoff(conversation: &[Message], newest_entry: Option<&journal::JournalEntry>) -> usize {
|
||||||
|
let cutoff = match newest_entry { Some(entry) => entry.timestamp, None => return 0 };
|
||||||
|
|
||||||
|
let mut split = conversation.len();
|
||||||
|
for (i, msg) in conversation.iter().enumerate() {
|
||||||
|
if let Some(ts) = parse_msg_timestamp(msg) {
|
||||||
|
if ts > cutoff { split = i; break; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while split > 0 && split < conversation.len() && conversation[split].role != crate::types::Role::User {
|
||||||
|
split -= 1;
|
||||||
|
}
|
||||||
|
split
|
||||||
|
}
|
||||||
|
|
||||||
|
fn msg_token_count_fn(msg: &Message, count: &dyn Fn(&str) -> usize) -> usize {
|
||||||
|
let content = msg.content.as_ref().map_or(0, |c| match c {
|
||||||
|
crate::types::MessageContent::Text(s) => count(s),
|
||||||
|
crate::types::MessageContent::Parts(parts) => parts.iter().map(|p| match p {
|
||||||
|
crate::types::ContentPart::Text { text } => count(text),
|
||||||
|
crate::types::ContentPart::ImageUrl { .. } => 85,
|
||||||
|
}).sum(),
|
||||||
|
});
|
||||||
|
let tools = msg.tool_calls.as_ref().map_or(0, |calls| calls.iter().map(|c| count(&c.function.arguments) + count(&c.function.name)).sum());
|
||||||
|
content + tools
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_msg_timestamp(msg: &Message) -> Option<DateTime<Utc>> {
|
||||||
|
msg.timestamp.as_ref().and_then(|ts| DateTime::parse_from_rfc3339(ts).ok()).map(|dt| dt.with_timezone(&Utc))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
|
||||||
|
let msg = err.to_string().to_lowercase();
|
||||||
|
msg.contains("context length") || msg.contains("token limit") || msg.contains("too many tokens")
|
||||||
|
|| msg.contains("maximum context") || msg.contains("prompt is too long") || msg.contains("request too large")
|
||||||
|
|| msg.contains("input validation error") || msg.contains("content length limit")
|
||||||
|
|| (msg.contains("400") && msg.contains("tokens"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
||||||
|
err.to_string().contains("model stream error")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
||||||
|
msg_token_count_fn(msg, &|s| tokenizer.encode_with_special_tokens(s).len())
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue