src/thought -> src/agent
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
39d6ca3fe0
commit
2f0c7ce5c2
21 changed files with 57 additions and 141 deletions
125
src/agent/context.rs
Normal file
125
src/agent/context.rs
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// context.rs — Context window management
|
||||
//
|
||||
// Token counting, conversation trimming, and error classification.
|
||||
// Journal entries are loaded from the memory graph store, not from
|
||||
// a flat file — the parse functions are gone.
|
||||
|
||||
use crate::user::types::*;
|
||||
use chrono::{DateTime, Utc};
|
||||
use tiktoken_rs::CoreBPE;
|
||||
|
||||
/// A single journal entry with its timestamp and content.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JournalEntry {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Context window size in tokens (from config).
|
||||
pub fn context_window() -> usize {
|
||||
crate::config::get().api_context_window
|
||||
}
|
||||
|
||||
/// Context budget in tokens: 80% of the model's context window.
|
||||
/// The remaining 20% is reserved for model output.
|
||||
fn context_budget_tokens() -> usize {
|
||||
context_window() * 80 / 100
|
||||
}
|
||||
|
||||
/// Dedup and trim conversation entries to fit within the context budget.
|
||||
///
|
||||
/// 1. Dedup: if the same memory key appears multiple times, keep only
|
||||
/// the latest render (drop the earlier Memory entry and its
|
||||
/// corresponding assistant tool_call message).
|
||||
/// 2. Trim: drop oldest entries until the conversation fits, snapping
|
||||
/// to user message boundaries.
|
||||
pub fn trim_entries(
|
||||
context: &ContextState,
|
||||
entries: &[ConversationEntry],
|
||||
tokenizer: &CoreBPE,
|
||||
) -> Vec<ConversationEntry> {
|
||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||
|
||||
// --- Phase 1: dedup memory entries by key (keep last) ---
|
||||
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
||||
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
||||
|
||||
for (i, entry) in entries.iter().enumerate() {
|
||||
if let ConversationEntry::Memory { key, .. } = entry {
|
||||
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
|
||||
drop_indices.insert(prev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let deduped: Vec<ConversationEntry> = entries.iter().enumerate()
|
||||
.filter(|(i, _)| !drop_indices.contains(i))
|
||||
.map(|(_, e)| e.clone())
|
||||
.collect();
|
||||
|
||||
// --- Phase 2: trim to fit context budget ---
|
||||
let max_tokens = context_budget_tokens();
|
||||
let identity_cost = count(&context.system_prompt)
|
||||
+ context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
|
||||
let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
|
||||
let available = max_tokens
|
||||
.saturating_sub(identity_cost)
|
||||
.saturating_sub(journal_cost);
|
||||
|
||||
let msg_costs: Vec<usize> = deduped.iter()
|
||||
.map(|e| msg_token_count(tokenizer, e.message())).collect();
|
||||
let total: usize = msg_costs.iter().sum();
|
||||
|
||||
let mut skip = 0;
|
||||
let mut trimmed = total;
|
||||
while trimmed > available && skip < deduped.len() {
|
||||
trimmed -= msg_costs[skip];
|
||||
skip += 1;
|
||||
}
|
||||
|
||||
// Walk forward to user message boundary
|
||||
while skip < deduped.len() && deduped[skip].message().role != Role::User {
|
||||
skip += 1;
|
||||
}
|
||||
|
||||
deduped[skip..].to_vec()
|
||||
}
|
||||
|
||||
/// Count the token footprint of a message using BPE tokenization.
|
||||
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||
let content = msg.content.as_ref().map_or(0, |c| match c {
|
||||
MessageContent::Text(s) => count(s),
|
||||
MessageContent::Parts(parts) => parts.iter()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => count(text),
|
||||
ContentPart::ImageUrl { .. } => 85,
|
||||
})
|
||||
.sum(),
|
||||
});
|
||||
let tools = msg.tool_calls.as_ref().map_or(0, |calls| {
|
||||
calls.iter()
|
||||
.map(|c| count(&c.function.arguments) + count(&c.function.name))
|
||||
.sum()
|
||||
});
|
||||
content + tools
|
||||
}
|
||||
|
||||
/// Detect context window overflow errors from the API.
|
||||
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
|
||||
let msg = err.to_string().to_lowercase();
|
||||
msg.contains("context length")
|
||||
|| msg.contains("token limit")
|
||||
|| msg.contains("too many tokens")
|
||||
|| msg.contains("maximum context")
|
||||
|| msg.contains("prompt is too long")
|
||||
|| msg.contains("request too large")
|
||||
|| msg.contains("input validation error")
|
||||
|| msg.contains("content length limit")
|
||||
|| (msg.contains("400") && msg.contains("tokens"))
|
||||
}
|
||||
|
||||
/// Detect model/provider errors delivered inside the SSE stream.
|
||||
pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
||||
err.to_string().contains("model stream error")
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue