Simplify trim_entries, kill ContextBudget

trim_entries is now a simple loop:
1. Drop duplicate memories and DMN entries
2. While over budget: if memories > 50% of entry tokens, drop
   lowest-scored memory; otherwise drop oldest conversation entry
3. Snap to user message boundary

ContextBudget is gone — sections already have cached token totals:
- total_tokens() on ContextState replaces budget.total()
- format_budget() on ContextState replaces budget.format()
- trim() takes fixed_tokens: usize (system + identity + journal)

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
Kent Overstreet 2026-04-07 20:55:35 -04:00
parent 62996e27d7
commit b892cae2be
4 changed files with 68 additions and 133 deletions

View file

@ -99,8 +99,8 @@ impl ContextSection {
}
/// Dedup and trim entries to fit within context budget.
pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) {
let result = trim_entries(&self.entries, tokenizer, budget);
pub fn trim(&mut self, fixed_tokens: usize) {
    // Replace the entry list with the deduped/trimmed version, then
    // refresh the cached token total from the surviving entries.
    self.entries = trim_entries(&self.entries, fixed_tokens);
    self.tokens = self.entries.iter().map(|entry| entry.tokens).sum();
}
@ -130,6 +130,19 @@ impl ContextState {
+ self.journal.tokens() + self.conversation.tokens()
}
/// Budget status string for debug logging.
pub fn format_budget(&self) -> String {
    let window = context_window();
    if window == 0 {
        return String::new();
    }
    let free = window.saturating_sub(self.total_tokens());
    // Nonzero counts are rounded up to at least 1% so small sections
    // never display as 0% while still holding tokens.
    let pct = |tokens: usize| {
        if tokens == 0 {
            0
        } else {
            ((tokens * 100) / window).max(1)
        }
    };
    format!(
        "sys:{}% id:{}% jnl:{}% conv:{}% free:{}%",
        pct(self.system.tokens()),
        pct(self.identity.tokens()),
        pct(self.journal.tokens()),
        pct(self.conversation.tokens()),
        pct(free)
    )
}
/// All sections as a slice for iteration.
pub fn sections(&self) -> [&ContextSection; 4] {
[&self.system, &self.identity, &self.journal, &self.conversation]
@ -149,112 +162,76 @@ fn context_budget_tokens() -> usize {
/// Dedup and trim conversation entries to fit within the context budget.
///
/// 1. Dedup: if the same memory key appears multiple times, keep only
/// the latest render (drop the earlier Memory entry and its
/// corresponding assistant tool_call message).
/// 2. Trim: drop oldest entries until the conversation fits, snapping
/// to user message boundaries.
fn trim_entries(
entries: &[ContextEntry],
_tokenizer: &CoreBPE,
budget: &ContextBudget,
) -> Vec<ContextEntry> {
let fixed_tokens = budget.system + budget.identity + budget.journal;
// --- Phase 1: dedup memory entries by key (keep last) ---
/// Phase 1: Drop duplicate memories (keep last) and DMN entries.
/// Phase 2: While over budget, drop lowest-scored memory (or if memories
/// are under 50%, drop oldest conversation entry).
fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEntry> {
let max_tokens = context_budget_tokens();
// Phase 1: dedup memories by key (keep last), drop DMN entries
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
for (i, ce) in entries.iter().enumerate() {
if let ConversationEntry::Memory { key, .. } = &ce.entry {
if ce.entry.is_dmn() {
drop_indices.insert(i);
} else if let ConversationEntry::Memory { key, .. } = &ce.entry {
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
drop_indices.insert(prev);
}
}
}
let deduped: Vec<ContextEntry> = entries.iter().enumerate()
let mut result: Vec<ContextEntry> = entries.iter().enumerate()
.filter(|(i, _)| !drop_indices.contains(i))
.map(|(_, e)| e.clone())
.collect();
// --- Phase 2: trim to fit context budget ---
let max_tokens = context_budget_tokens();
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() };
let mem_total = |r: &[ContextEntry]| -> usize {
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum()
};
let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect();
let entry_total: usize = msg_costs.iter().sum();
let total: usize = fixed_tokens + entry_total;
dbglog!("[trim] max={} fixed={} total={} entries={}",
max_tokens, fixed_tokens, fixed_tokens + entry_total(&result), result.len());
let mem_tokens: usize = deduped.iter()
.filter(|ce| ce.entry.is_memory())
.map(|ce| ce.tokens).sum();
let conv_tokens: usize = entry_total - mem_tokens;
// Phase 2: while over budget, evict
while fixed_tokens + entry_total(&result) > max_tokens {
let mt = mem_total(&result);
let ct = entry_total(&result) - mt;
dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
max_tokens, fixed_tokens, mem_tokens, conv_tokens, total, deduped.len());
// Phase 2a: evict all DMN entries first — they're ephemeral
let mut drop = vec![false; deduped.len()];
let mut trimmed = total;
let mut cur_mem = mem_tokens;
for i in 0..deduped.len() {
if deduped[i].entry.is_dmn() {
drop[i] = true;
trimmed -= msg_costs[i];
if mt > ct && let Some(i) = lowest_scored_memory(&result) {
// If memories > 50% of entry tokens, drop lowest-scored memory
result.remove(i);
} else if let Some(i) = result.iter().position(|e| !e.entry.is_memory()) {
// Otherwise drop oldest conversation entry
result.remove(i);
} else {
break;
}
}
// Phase 2b: if memories > 50% of context, evict lowest-scored first
if cur_mem > conv_tokens && trimmed > max_tokens {
let mut mem_indices: Vec<usize> = (0..deduped.len())
.filter(|&i| !drop[i] && deduped[i].entry.is_memory())
.collect();
mem_indices.sort_by(|&a, &b| {
let sa = match &deduped[a].entry {
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
_ => 0.0,
};
let sb = match &deduped[b].entry {
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
_ => 0.0,
};
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
});
for i in mem_indices {
if cur_mem <= conv_tokens { break; }
if trimmed <= max_tokens { break; }
drop[i] = true;
trimmed -= msg_costs[i];
cur_mem -= msg_costs[i];
}
// Snap to user message boundary at the start
while let Some(first) = result.first() {
if first.entry.message().role == Role::User { break; }
result.remove(0);
}
// Phase 2c: if still over, drop oldest conversation entries
for i in 0..deduped.len() {
if trimmed <= max_tokens { break; }
if drop[i] { continue; }
drop[i] = true;
trimmed -= msg_costs[i];
}
// Walk forward to include complete conversation boundaries
let mut result: Vec<ContextEntry> = Vec::new();
let mut skipping = true;
for (i, ce) in deduped.into_iter().enumerate() {
if skipping {
if drop[i] { continue; }
// Snap to user message boundary
if ce.entry.message().role != Role::User { continue; }
skipping = false;
}
result.push(ce);
}
dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
dbglog!("[trim] result={} total={}", result.len(), fixed_tokens + entry_total(&result));
result
}
/// Index of the lowest-scored memory entry, or None if there are no memories.
/// Non-scored memories count as 0.0; ties keep the earliest entry.
fn lowest_scored_memory(entries: &[ContextEntry]) -> Option<usize> {
    // Pull the score out of a Memory entry; anything else scores 0.0.
    let score_of = |ce: &ContextEntry| match &ce.entry {
        ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
        _ => 0.0,
    };
    entries
        .iter()
        .enumerate()
        .filter(|(_, ce)| ce.entry.is_memory())
        .min_by(|(_, a), (_, b)| {
            // NaN scores compare as Equal rather than panicking.
            score_of(a)
                .partial_cmp(&score_of(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(idx, _)| idx)
}
/// Count the token footprint of a message using BPE tokenization.
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
@ -433,30 +410,3 @@ impl ContextState {
parts.join("\n\n---\n\n")
}
}
/// Token budget per context category — cheap to compute, no formatting.
pub struct ContextBudget {
    /// Tokens used by the system prompt section.
    pub system: usize,
    /// Tokens used by the identity section.
    pub identity: usize,
    /// Tokens used by the journal section.
    pub journal: usize,
    /// Tokens used by memory entries within the conversation.
    pub memory: usize,
    /// Tokens used by non-memory conversation entries.
    pub conversation: usize,
}
impl ContextBudget {
    /// Total tokens across every category.
    pub fn total(&self) -> usize {
        [
            self.system,
            self.identity,
            self.journal,
            self.memory,
            self.conversation,
        ]
        .into_iter()
        .sum()
    }

    /// Budget status string for debug logging; empty when no context
    /// window is configured.
    pub fn format(&self) -> String {
        let window = context_window();
        if window == 0 {
            return String::new();
        }
        let free = window.saturating_sub(self.total());
        // Nonzero counts round up to at least 1% so they stay visible.
        let pct = |n: usize| {
            if n == 0 {
                0
            } else {
                ((n * 100) / window).max(1)
            }
        };
        format!(
            "sys:{}% id:{}% jnl:{}% mem:{}% conv:{}% free:{}%",
            pct(self.system),
            pct(self.identity),
            pct(self.journal),
            pct(self.memory),
            pct(self.conversation),
            pct(free)
        )
    }
}

View file

@ -690,21 +690,6 @@ impl Agent {
self.push_message(Message::tool_result(&call.id, &output));
}
/// Token budget by category — just reads cached section totals.
pub fn context_budget(&self) -> context::ContextBudget {
    let ctx = &self.context;
    // Memory tokens are the subset of conversation entries flagged as
    // memories; the remainder is plain conversation.
    let mut memory: usize = 0;
    for entry in ctx.conversation.entries() {
        if entry.entry.is_memory() {
            memory += entry.tokens;
        }
    }
    context::ContextBudget {
        system: ctx.system.tokens(),
        identity: ctx.identity.tokens(),
        journal: ctx.journal.tokens(),
        memory,
        conversation: ctx.conversation.tokens() - memory,
    }
}
/// Context state sections — just returns references to the live data.
pub fn context_sections(&self) -> [&ContextSection; 4] {
@ -907,8 +892,9 @@ impl Agent {
self.load_startup_journal();
// Dedup memory, trim to budget
let budget = self.context_budget();
self.context.conversation.trim(&budget, &self.tokenizer);
let fixed = self.context.system.tokens() + self.context.identity.tokens()
+ self.context.journal.tokens();
self.context.conversation.trim(fixed);
let after = self.context.conversation.len();
let after_mem = self.context.conversation.entries().iter()
@ -920,8 +906,7 @@ impl Agent {
self.generation += 1;
self.last_prompt_tokens = 0;
let budget = self.context_budget();
dbglog!("[compact] budget: {}", budget.format());
dbglog!("[compact] budget: {}", self.context.format_budget());
}
/// Restore from the conversation log. Builds the context window
@ -960,7 +945,7 @@ impl Agent {
self.context.conversation.set_entries(all);
self.compact();
// Estimate prompt tokens so status bar isn't 0 on startup
self.last_prompt_tokens = self.context_budget().total() as u32;
self.last_prompt_tokens = self.context.total_tokens() as u32;
true
}

View file

@ -337,7 +337,7 @@ impl Mind {
MindCommand::Compact => {
let threshold = compaction_threshold(&self.config.app) as usize;
let mut ag = self.agent.lock().await;
if ag.context_budget().total() > threshold {
if ag.context.total_tokens() > threshold {
ag.compact();
ag.notify("compacted");
}
@ -437,7 +437,7 @@ impl Mind {
// Compact if over budget before sending
let threshold = compaction_threshold(&self.config.app) as usize;
if ag.context_budget().total() > threshold {
if ag.context.total_tokens() > threshold {
ag.compact();
ag.notify("compacted");
}

View file

@ -863,7 +863,7 @@ impl ScreenView for InteractScreen {
agent.expire_activities();
app.status.prompt_tokens = agent.last_prompt_tokens();
app.status.model = agent.model().to_string();
app.status.context_budget = agent.context_budget().format();
app.status.context_budget = agent.context.format_budget();
app.activity = agent.activities.last()
.map(|a| a.label.clone())
.unwrap_or_default();