Simplify trim_entries, kill ContextBudget

trim_entries is now a simple loop:
1. Drop duplicate memories and DMN entries
2. While over budget: if memories > 50% of entry tokens, drop
   lowest-scored memory; otherwise drop oldest conversation entry
3. Snap to user message boundary

ContextBudget is gone — sections already have cached token totals:
- total_tokens() on ContextState replaces budget.total()
- format_budget() on ContextState replaces budget.format()
- trim() takes fixed_tokens: usize (system + identity + journal)

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
Kent Overstreet 2026-04-07 20:55:35 -04:00
parent 62996e27d7
commit b892cae2be
4 changed files with 68 additions and 133 deletions

View file

@ -99,8 +99,8 @@ impl ContextSection {
}
/// Dedup and trim entries to fit within context budget.
pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) {
let result = trim_entries(&self.entries, tokenizer, budget);
pub fn trim(&mut self, fixed_tokens: usize) {
let result = trim_entries(&self.entries, fixed_tokens);
self.entries = result;
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
}
@ -130,6 +130,19 @@ impl ContextState {
+ self.journal.tokens() + self.conversation.tokens()
}
/// Budget status string for debug logging.
pub fn format_budget(&self) -> String {
let window = context_window();
if window == 0 { return String::new(); }
let used = self.total_tokens();
let free = window.saturating_sub(used);
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
format!("sys:{}% id:{}% jnl:{}% conv:{}% free:{}%",
pct(self.system.tokens()), pct(self.identity.tokens()),
pct(self.journal.tokens()), pct(self.conversation.tokens()),
pct(free))
}
/// All sections as a slice for iteration.
pub fn sections(&self) -> [&ContextSection; 4] {
[&self.system, &self.identity, &self.journal, &self.conversation]
@ -149,112 +162,76 @@ fn context_budget_tokens() -> usize {
/// Dedup and trim conversation entries to fit within the context budget.
///
/// 1. Dedup: if the same memory key appears multiple times, keep only
/// the latest render (drop the earlier Memory entry and its
/// corresponding assistant tool_call message).
/// 2. Trim: drop oldest entries until the conversation fits, snapping
/// to user message boundaries.
fn trim_entries(
entries: &[ContextEntry],
_tokenizer: &CoreBPE,
budget: &ContextBudget,
) -> Vec<ContextEntry> {
let fixed_tokens = budget.system + budget.identity + budget.journal;
// --- Phase 1: dedup memory entries by key (keep last) ---
/// Phase 1: Drop duplicate memories (keep last) and DMN entries.
/// Phase 2: While over budget, drop lowest-scored memory (or if memories
/// are under 50%, drop oldest conversation entry).
fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEntry> {
let max_tokens = context_budget_tokens();
// Phase 1: dedup memories by key (keep last), drop DMN entries
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
for (i, ce) in entries.iter().enumerate() {
if let ConversationEntry::Memory { key, .. } = &ce.entry {
if ce.entry.is_dmn() {
drop_indices.insert(i);
} else if let ConversationEntry::Memory { key, .. } = &ce.entry {
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
drop_indices.insert(prev);
}
}
}
let deduped: Vec<ContextEntry> = entries.iter().enumerate()
let mut result: Vec<ContextEntry> = entries.iter().enumerate()
.filter(|(i, _)| !drop_indices.contains(i))
.map(|(_, e)| e.clone())
.collect();
// --- Phase 2: trim to fit context budget ---
let max_tokens = context_budget_tokens();
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() };
let mem_total = |r: &[ContextEntry]| -> usize {
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum()
};
let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect();
let entry_total: usize = msg_costs.iter().sum();
let total: usize = fixed_tokens + entry_total;
dbglog!("[trim] max={} fixed={} total={} entries={}",
max_tokens, fixed_tokens, fixed_tokens + entry_total(&result), result.len());
let mem_tokens: usize = deduped.iter()
.filter(|ce| ce.entry.is_memory())
.map(|ce| ce.tokens).sum();
let conv_tokens: usize = entry_total - mem_tokens;
// Phase 2: while over budget, evict
while fixed_tokens + entry_total(&result) > max_tokens {
let mt = mem_total(&result);
let ct = entry_total(&result) - mt;
dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
max_tokens, fixed_tokens, mem_tokens, conv_tokens, total, deduped.len());
// Phase 2a: evict all DMN entries first — they're ephemeral
let mut drop = vec![false; deduped.len()];
let mut trimmed = total;
let mut cur_mem = mem_tokens;
for i in 0..deduped.len() {
if deduped[i].entry.is_dmn() {
drop[i] = true;
trimmed -= msg_costs[i];
if mt > ct && let Some(i) = lowest_scored_memory(&result) {
// If memories > 50% of entry tokens, drop lowest-scored memory
result.remove(i);
} else if let Some(i) = result.iter().position(|e| !e.entry.is_memory()) {
// Otherwise drop oldest conversation entry
result.remove(i);
} else {
break;
}
}
// Phase 2b: if memories > 50% of context, evict lowest-scored first
if cur_mem > conv_tokens && trimmed > max_tokens {
let mut mem_indices: Vec<usize> = (0..deduped.len())
.filter(|&i| !drop[i] && deduped[i].entry.is_memory())
.collect();
mem_indices.sort_by(|&a, &b| {
let sa = match &deduped[a].entry {
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
_ => 0.0,
};
let sb = match &deduped[b].entry {
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
_ => 0.0,
};
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
});
for i in mem_indices {
if cur_mem <= conv_tokens { break; }
if trimmed <= max_tokens { break; }
drop[i] = true;
trimmed -= msg_costs[i];
cur_mem -= msg_costs[i];
}
// Snap to user message boundary at the start
while let Some(first) = result.first() {
if first.entry.message().role == Role::User { break; }
result.remove(0);
}
// Phase 2c: if still over, drop oldest conversation entries
for i in 0..deduped.len() {
if trimmed <= max_tokens { break; }
if drop[i] { continue; }
drop[i] = true;
trimmed -= msg_costs[i];
}
// Walk forward to include complete conversation boundaries
let mut result: Vec<ContextEntry> = Vec::new();
let mut skipping = true;
for (i, ce) in deduped.into_iter().enumerate() {
if skipping {
if drop[i] { continue; }
// Snap to user message boundary
if ce.entry.message().role != Role::User { continue; }
skipping = false;
}
result.push(ce);
}
dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
dbglog!("[trim] result={} total={}", result.len(), fixed_tokens + entry_total(&result));
result
}
fn lowest_scored_memory(entries: &[ContextEntry]) -> Option<usize> {
entries.iter().enumerate()
.filter(|(_, e)| e.entry.is_memory())
.min_by(|(_, a), (_, b)| {
let sa = match &a.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
let sb = match &b.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
}
/// Count the token footprint of a message using BPE tokenization.
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
@ -433,30 +410,3 @@ impl ContextState {
parts.join("\n\n---\n\n")
}
}
/// Token budget per context category — cheap to compute, no formatting.
pub struct ContextBudget {
pub system: usize,
pub identity: usize,
pub journal: usize,
pub memory: usize,
pub conversation: usize,
}
impl ContextBudget {
pub fn total(&self) -> usize {
self.system + self.identity + self.journal + self.memory + self.conversation
}
pub fn format(&self) -> String {
let window = context_window();
if window == 0 { return String::new(); }
let used = self.total();
let free = window.saturating_sub(used);
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
format!("sys:{}% id:{}% jnl:{}% mem:{}% conv:{}% free:{}%",
pct(self.system), pct(self.identity), pct(self.journal),
pct(self.memory), pct(self.conversation), pct(free))
}
}