Simplify trim_entries, kill ContextBudget
trim_entries is now a simple loop: 1. Drop duplicate memories and DMN entries 2. While over budget: if memories > 50% of entry tokens, drop lowest-scored memory; otherwise drop oldest conversation entry 3. Snap to user message boundary ContextBudget is gone — sections already have cached token totals: - total_tokens() on ContextState replaces budget.total() - format_budget() on ContextState replaces budget.format() - trim() takes fixed_tokens: usize (system + identity + journal) Co-Authored-By: Proof of Concept <poc@bcachefs.org> Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
62996e27d7
commit
b892cae2be
4 changed files with 68 additions and 133 deletions
|
|
@ -99,8 +99,8 @@ impl ContextSection {
|
|||
}
|
||||
|
||||
/// Dedup and trim entries to fit within context budget.
|
||||
pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) {
|
||||
let result = trim_entries(&self.entries, tokenizer, budget);
|
||||
pub fn trim(&mut self, fixed_tokens: usize) {
|
||||
let result = trim_entries(&self.entries, fixed_tokens);
|
||||
self.entries = result;
|
||||
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
|
||||
}
|
||||
|
|
@ -130,6 +130,19 @@ impl ContextState {
|
|||
+ self.journal.tokens() + self.conversation.tokens()
|
||||
}
|
||||
|
||||
/// Budget status string for debug logging.
|
||||
pub fn format_budget(&self) -> String {
|
||||
let window = context_window();
|
||||
if window == 0 { return String::new(); }
|
||||
let used = self.total_tokens();
|
||||
let free = window.saturating_sub(used);
|
||||
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
|
||||
format!("sys:{}% id:{}% jnl:{}% conv:{}% free:{}%",
|
||||
pct(self.system.tokens()), pct(self.identity.tokens()),
|
||||
pct(self.journal.tokens()), pct(self.conversation.tokens()),
|
||||
pct(free))
|
||||
}
|
||||
|
||||
/// All sections as a slice for iteration.
|
||||
pub fn sections(&self) -> [&ContextSection; 4] {
|
||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||
|
|
@ -149,112 +162,76 @@ fn context_budget_tokens() -> usize {
|
|||
|
||||
/// Dedup and trim conversation entries to fit within the context budget.
|
||||
///
|
||||
/// 1. Dedup: if the same memory key appears multiple times, keep only
|
||||
/// the latest render (drop the earlier Memory entry and its
|
||||
/// corresponding assistant tool_call message).
|
||||
/// 2. Trim: drop oldest entries until the conversation fits, snapping
|
||||
/// to user message boundaries.
|
||||
fn trim_entries(
|
||||
entries: &[ContextEntry],
|
||||
_tokenizer: &CoreBPE,
|
||||
budget: &ContextBudget,
|
||||
) -> Vec<ContextEntry> {
|
||||
let fixed_tokens = budget.system + budget.identity + budget.journal;
|
||||
// --- Phase 1: dedup memory entries by key (keep last) ---
|
||||
/// Phase 1: Drop duplicate memories (keep last) and DMN entries.
|
||||
/// Phase 2: While over budget, drop lowest-scored memory (or if memories
|
||||
/// are under 50%, drop oldest conversation entry).
|
||||
fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEntry> {
|
||||
let max_tokens = context_budget_tokens();
|
||||
|
||||
// Phase 1: dedup memories by key (keep last), drop DMN entries
|
||||
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
||||
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
||||
|
||||
for (i, ce) in entries.iter().enumerate() {
|
||||
if let ConversationEntry::Memory { key, .. } = &ce.entry {
|
||||
if ce.entry.is_dmn() {
|
||||
drop_indices.insert(i);
|
||||
} else if let ConversationEntry::Memory { key, .. } = &ce.entry {
|
||||
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
|
||||
drop_indices.insert(prev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let deduped: Vec<ContextEntry> = entries.iter().enumerate()
|
||||
let mut result: Vec<ContextEntry> = entries.iter().enumerate()
|
||||
.filter(|(i, _)| !drop_indices.contains(i))
|
||||
.map(|(_, e)| e.clone())
|
||||
.collect();
|
||||
|
||||
// --- Phase 2: trim to fit context budget ---
|
||||
let max_tokens = context_budget_tokens();
|
||||
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() };
|
||||
let mem_total = |r: &[ContextEntry]| -> usize {
|
||||
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum()
|
||||
};
|
||||
|
||||
let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect();
|
||||
let entry_total: usize = msg_costs.iter().sum();
|
||||
let total: usize = fixed_tokens + entry_total;
|
||||
dbglog!("[trim] max={} fixed={} total={} entries={}",
|
||||
max_tokens, fixed_tokens, fixed_tokens + entry_total(&result), result.len());
|
||||
|
||||
let mem_tokens: usize = deduped.iter()
|
||||
.filter(|ce| ce.entry.is_memory())
|
||||
.map(|ce| ce.tokens).sum();
|
||||
let conv_tokens: usize = entry_total - mem_tokens;
|
||||
// Phase 2: while over budget, evict
|
||||
while fixed_tokens + entry_total(&result) > max_tokens {
|
||||
let mt = mem_total(&result);
|
||||
let ct = entry_total(&result) - mt;
|
||||
|
||||
dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
|
||||
max_tokens, fixed_tokens, mem_tokens, conv_tokens, total, deduped.len());
|
||||
|
||||
// Phase 2a: evict all DMN entries first — they're ephemeral
|
||||
let mut drop = vec![false; deduped.len()];
|
||||
let mut trimmed = total;
|
||||
let mut cur_mem = mem_tokens;
|
||||
|
||||
for i in 0..deduped.len() {
|
||||
if deduped[i].entry.is_dmn() {
|
||||
drop[i] = true;
|
||||
trimmed -= msg_costs[i];
|
||||
if mt > ct && let Some(i) = lowest_scored_memory(&result) {
|
||||
// If memories > 50% of entry tokens, drop lowest-scored memory
|
||||
result.remove(i);
|
||||
} else if let Some(i) = result.iter().position(|e| !e.entry.is_memory()) {
|
||||
// Otherwise drop oldest conversation entry
|
||||
result.remove(i);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2b: if memories > 50% of context, evict lowest-scored first
|
||||
if cur_mem > conv_tokens && trimmed > max_tokens {
|
||||
let mut mem_indices: Vec<usize> = (0..deduped.len())
|
||||
.filter(|&i| !drop[i] && deduped[i].entry.is_memory())
|
||||
.collect();
|
||||
mem_indices.sort_by(|&a, &b| {
|
||||
let sa = match &deduped[a].entry {
|
||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
};
|
||||
let sb = match &deduped[b].entry {
|
||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
};
|
||||
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
for i in mem_indices {
|
||||
if cur_mem <= conv_tokens { break; }
|
||||
if trimmed <= max_tokens { break; }
|
||||
drop[i] = true;
|
||||
trimmed -= msg_costs[i];
|
||||
cur_mem -= msg_costs[i];
|
||||
}
|
||||
// Snap to user message boundary at the start
|
||||
while let Some(first) = result.first() {
|
||||
if first.entry.message().role == Role::User { break; }
|
||||
result.remove(0);
|
||||
}
|
||||
|
||||
// Phase 2c: if still over, drop oldest conversation entries
|
||||
for i in 0..deduped.len() {
|
||||
if trimmed <= max_tokens { break; }
|
||||
if drop[i] { continue; }
|
||||
drop[i] = true;
|
||||
trimmed -= msg_costs[i];
|
||||
}
|
||||
|
||||
// Walk forward to include complete conversation boundaries
|
||||
let mut result: Vec<ContextEntry> = Vec::new();
|
||||
let mut skipping = true;
|
||||
for (i, ce) in deduped.into_iter().enumerate() {
|
||||
if skipping {
|
||||
if drop[i] { continue; }
|
||||
// Snap to user message boundary
|
||||
if ce.entry.message().role != Role::User { continue; }
|
||||
skipping = false;
|
||||
}
|
||||
result.push(ce);
|
||||
}
|
||||
|
||||
dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
|
||||
|
||||
dbglog!("[trim] result={} total={}", result.len(), fixed_tokens + entry_total(&result));
|
||||
result
|
||||
}
|
||||
|
||||
fn lowest_scored_memory(entries: &[ContextEntry]) -> Option<usize> {
|
||||
entries.iter().enumerate()
|
||||
.filter(|(_, e)| e.entry.is_memory())
|
||||
.min_by(|(_, a), (_, b)| {
|
||||
let sa = match &a.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
|
||||
let sb = match &b.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
|
||||
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
}
|
||||
|
||||
/// Count the token footprint of a message using BPE tokenization.
|
||||
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||
|
|
@ -433,30 +410,3 @@ impl ContextState {
|
|||
parts.join("\n\n---\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
/// Token budget per context category — cheap to compute, no formatting.
|
||||
pub struct ContextBudget {
|
||||
pub system: usize,
|
||||
pub identity: usize,
|
||||
pub journal: usize,
|
||||
pub memory: usize,
|
||||
pub conversation: usize,
|
||||
}
|
||||
|
||||
impl ContextBudget {
|
||||
pub fn total(&self) -> usize {
|
||||
self.system + self.identity + self.journal + self.memory + self.conversation
|
||||
}
|
||||
|
||||
pub fn format(&self) -> String {
|
||||
let window = context_window();
|
||||
if window == 0 { return String::new(); }
|
||||
let used = self.total();
|
||||
let free = window.saturating_sub(used);
|
||||
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
|
||||
format!("sys:{}% id:{}% jnl:{}% mem:{}% conv:{}% free:{}%",
|
||||
pct(self.system), pct(self.identity), pct(self.journal),
|
||||
pct(self.memory), pct(self.conversation), pct(free))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -690,21 +690,6 @@ impl Agent {
|
|||
self.push_message(Message::tool_result(&call.id, &output));
|
||||
}
|
||||
|
||||
/// Token budget by category — just reads cached section totals.
|
||||
pub fn context_budget(&self) -> context::ContextBudget {
|
||||
let memory: usize = self.context.conversation.entries().iter()
|
||||
.filter(|e| e.entry.is_memory())
|
||||
.map(|e| e.tokens)
|
||||
.sum();
|
||||
let conv_total = self.context.conversation.tokens();
|
||||
context::ContextBudget {
|
||||
system: self.context.system.tokens(),
|
||||
identity: self.context.identity.tokens(),
|
||||
journal: self.context.journal.tokens(),
|
||||
memory,
|
||||
conversation: conv_total - memory,
|
||||
}
|
||||
}
|
||||
|
||||
/// Context state sections — just returns references to the live data.
|
||||
pub fn context_sections(&self) -> [&ContextSection; 4] {
|
||||
|
|
@ -907,8 +892,9 @@ impl Agent {
|
|||
self.load_startup_journal();
|
||||
|
||||
// Dedup memory, trim to budget
|
||||
let budget = self.context_budget();
|
||||
self.context.conversation.trim(&budget, &self.tokenizer);
|
||||
let fixed = self.context.system.tokens() + self.context.identity.tokens()
|
||||
+ self.context.journal.tokens();
|
||||
self.context.conversation.trim(fixed);
|
||||
|
||||
let after = self.context.conversation.len();
|
||||
let after_mem = self.context.conversation.entries().iter()
|
||||
|
|
@ -920,8 +906,7 @@ impl Agent {
|
|||
self.generation += 1;
|
||||
self.last_prompt_tokens = 0;
|
||||
|
||||
let budget = self.context_budget();
|
||||
dbglog!("[compact] budget: {}", budget.format());
|
||||
dbglog!("[compact] budget: {}", self.context.format_budget());
|
||||
}
|
||||
|
||||
/// Restore from the conversation log. Builds the context window
|
||||
|
|
@ -960,7 +945,7 @@ impl Agent {
|
|||
self.context.conversation.set_entries(all);
|
||||
self.compact();
|
||||
// Estimate prompt tokens so status bar isn't 0 on startup
|
||||
self.last_prompt_tokens = self.context_budget().total() as u32;
|
||||
self.last_prompt_tokens = self.context.total_tokens() as u32;
|
||||
true
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue