Simplify trim_entries, kill ContextBudget
trim_entries is now a simple loop:
1. Drop duplicate memories and DMN entries.
2. While over budget: if memories > 50% of entry tokens, drop the lowest-scored memory; otherwise drop the oldest conversation entry.
3. Snap to a user message boundary.

ContextBudget is gone — sections already have cached token totals:
- total_tokens() on ContextState replaces budget.total()
- format_budget() on ContextState replaces budget.format()
- trim() takes fixed_tokens: usize (system + identity + journal)

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
62996e27d7
commit
b892cae2be
4 changed files with 68 additions and 133 deletions
|
|
@ -99,8 +99,8 @@ impl ContextSection {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dedup and trim entries to fit within context budget.
|
/// Dedup and trim entries to fit within context budget.
|
||||||
pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) {
|
pub fn trim(&mut self, fixed_tokens: usize) {
|
||||||
let result = trim_entries(&self.entries, tokenizer, budget);
|
let result = trim_entries(&self.entries, fixed_tokens);
|
||||||
self.entries = result;
|
self.entries = result;
|
||||||
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
|
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +130,19 @@ impl ContextState {
|
||||||
+ self.journal.tokens() + self.conversation.tokens()
|
+ self.journal.tokens() + self.conversation.tokens()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Budget status string for debug logging.
|
||||||
|
pub fn format_budget(&self) -> String {
|
||||||
|
let window = context_window();
|
||||||
|
if window == 0 { return String::new(); }
|
||||||
|
let used = self.total_tokens();
|
||||||
|
let free = window.saturating_sub(used);
|
||||||
|
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
|
||||||
|
format!("sys:{}% id:{}% jnl:{}% conv:{}% free:{}%",
|
||||||
|
pct(self.system.tokens()), pct(self.identity.tokens()),
|
||||||
|
pct(self.journal.tokens()), pct(self.conversation.tokens()),
|
||||||
|
pct(free))
|
||||||
|
}
|
||||||
|
|
||||||
/// All sections as a slice for iteration.
|
/// All sections as a slice for iteration.
|
||||||
pub fn sections(&self) -> [&ContextSection; 4] {
|
pub fn sections(&self) -> [&ContextSection; 4] {
|
||||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||||
|
|
@ -149,112 +162,76 @@ fn context_budget_tokens() -> usize {
|
||||||
|
|
||||||
/// Dedup and trim conversation entries to fit within the context budget.
|
/// Dedup and trim conversation entries to fit within the context budget.
|
||||||
///
|
///
|
||||||
/// 1. Dedup: if the same memory key appears multiple times, keep only
|
/// Phase 1: Drop duplicate memories (keep last) and DMN entries.
|
||||||
/// the latest render (drop the earlier Memory entry and its
|
/// Phase 2: While over budget, drop lowest-scored memory (or if memories
|
||||||
/// corresponding assistant tool_call message).
|
/// are under 50%, drop oldest conversation entry).
|
||||||
/// 2. Trim: drop oldest entries until the conversation fits, snapping
|
fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEntry> {
|
||||||
/// to user message boundaries.
|
let max_tokens = context_budget_tokens();
|
||||||
fn trim_entries(
|
|
||||||
entries: &[ContextEntry],
|
// Phase 1: dedup memories by key (keep last), drop DMN entries
|
||||||
_tokenizer: &CoreBPE,
|
|
||||||
budget: &ContextBudget,
|
|
||||||
) -> Vec<ContextEntry> {
|
|
||||||
let fixed_tokens = budget.system + budget.identity + budget.journal;
|
|
||||||
// --- Phase 1: dedup memory entries by key (keep last) ---
|
|
||||||
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
||||||
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
||||||
|
|
||||||
for (i, ce) in entries.iter().enumerate() {
|
for (i, ce) in entries.iter().enumerate() {
|
||||||
if let ConversationEntry::Memory { key, .. } = &ce.entry {
|
if ce.entry.is_dmn() {
|
||||||
|
drop_indices.insert(i);
|
||||||
|
} else if let ConversationEntry::Memory { key, .. } = &ce.entry {
|
||||||
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
|
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
|
||||||
drop_indices.insert(prev);
|
drop_indices.insert(prev);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let deduped: Vec<ContextEntry> = entries.iter().enumerate()
|
let mut result: Vec<ContextEntry> = entries.iter().enumerate()
|
||||||
.filter(|(i, _)| !drop_indices.contains(i))
|
.filter(|(i, _)| !drop_indices.contains(i))
|
||||||
.map(|(_, e)| e.clone())
|
.map(|(_, e)| e.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// --- Phase 2: trim to fit context budget ---
|
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() };
|
||||||
let max_tokens = context_budget_tokens();
|
let mem_total = |r: &[ContextEntry]| -> usize {
|
||||||
|
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum()
|
||||||
|
};
|
||||||
|
|
||||||
let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect();
|
dbglog!("[trim] max={} fixed={} total={} entries={}",
|
||||||
let entry_total: usize = msg_costs.iter().sum();
|
max_tokens, fixed_tokens, fixed_tokens + entry_total(&result), result.len());
|
||||||
let total: usize = fixed_tokens + entry_total;
|
|
||||||
|
|
||||||
let mem_tokens: usize = deduped.iter()
|
// Phase 2: while over budget, evict
|
||||||
.filter(|ce| ce.entry.is_memory())
|
while fixed_tokens + entry_total(&result) > max_tokens {
|
||||||
.map(|ce| ce.tokens).sum();
|
let mt = mem_total(&result);
|
||||||
let conv_tokens: usize = entry_total - mem_tokens;
|
let ct = entry_total(&result) - mt;
|
||||||
|
|
||||||
dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
|
if mt > ct && let Some(i) = lowest_scored_memory(&result) {
|
||||||
max_tokens, fixed_tokens, mem_tokens, conv_tokens, total, deduped.len());
|
// If memories > 50% of entry tokens, drop lowest-scored memory
|
||||||
|
result.remove(i);
|
||||||
// Phase 2a: evict all DMN entries first — they're ephemeral
|
} else if let Some(i) = result.iter().position(|e| !e.entry.is_memory()) {
|
||||||
let mut drop = vec![false; deduped.len()];
|
// Otherwise drop oldest conversation entry
|
||||||
let mut trimmed = total;
|
result.remove(i);
|
||||||
let mut cur_mem = mem_tokens;
|
} else {
|
||||||
|
break;
|
||||||
for i in 0..deduped.len() {
|
|
||||||
if deduped[i].entry.is_dmn() {
|
|
||||||
drop[i] = true;
|
|
||||||
trimmed -= msg_costs[i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 2b: if memories > 50% of context, evict lowest-scored first
|
// Snap to user message boundary at the start
|
||||||
if cur_mem > conv_tokens && trimmed > max_tokens {
|
while let Some(first) = result.first() {
|
||||||
let mut mem_indices: Vec<usize> = (0..deduped.len())
|
if first.entry.message().role == Role::User { break; }
|
||||||
.filter(|&i| !drop[i] && deduped[i].entry.is_memory())
|
result.remove(0);
|
||||||
.collect();
|
|
||||||
mem_indices.sort_by(|&a, &b| {
|
|
||||||
let sa = match &deduped[a].entry {
|
|
||||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
|
||||||
_ => 0.0,
|
|
||||||
};
|
|
||||||
let sb = match &deduped[b].entry {
|
|
||||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
|
||||||
_ => 0.0,
|
|
||||||
};
|
|
||||||
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
});
|
|
||||||
for i in mem_indices {
|
|
||||||
if cur_mem <= conv_tokens { break; }
|
|
||||||
if trimmed <= max_tokens { break; }
|
|
||||||
drop[i] = true;
|
|
||||||
trimmed -= msg_costs[i];
|
|
||||||
cur_mem -= msg_costs[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 2c: if still over, drop oldest conversation entries
|
dbglog!("[trim] result={} total={}", result.len(), fixed_tokens + entry_total(&result));
|
||||||
for i in 0..deduped.len() {
|
|
||||||
if trimmed <= max_tokens { break; }
|
|
||||||
if drop[i] { continue; }
|
|
||||||
drop[i] = true;
|
|
||||||
trimmed -= msg_costs[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walk forward to include complete conversation boundaries
|
|
||||||
let mut result: Vec<ContextEntry> = Vec::new();
|
|
||||||
let mut skipping = true;
|
|
||||||
for (i, ce) in deduped.into_iter().enumerate() {
|
|
||||||
if skipping {
|
|
||||||
if drop[i] { continue; }
|
|
||||||
// Snap to user message boundary
|
|
||||||
if ce.entry.message().role != Role::User { continue; }
|
|
||||||
skipping = false;
|
|
||||||
}
|
|
||||||
result.push(ce);
|
|
||||||
}
|
|
||||||
|
|
||||||
dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
|
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn lowest_scored_memory(entries: &[ContextEntry]) -> Option<usize> {
|
||||||
|
entries.iter().enumerate()
|
||||||
|
.filter(|(_, e)| e.entry.is_memory())
|
||||||
|
.min_by(|(_, a), (_, b)| {
|
||||||
|
let sa = match &a.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
|
||||||
|
let sb = match &b.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 };
|
||||||
|
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
}
|
||||||
|
|
||||||
/// Count the token footprint of a message using BPE tokenization.
|
/// Count the token footprint of a message using BPE tokenization.
|
||||||
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
||||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||||
|
|
@ -433,30 +410,3 @@ impl ContextState {
|
||||||
parts.join("\n\n---\n\n")
|
parts.join("\n\n---\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Token budget per context category — cheap to compute, no formatting.
|
|
||||||
pub struct ContextBudget {
|
|
||||||
pub system: usize,
|
|
||||||
pub identity: usize,
|
|
||||||
pub journal: usize,
|
|
||||||
pub memory: usize,
|
|
||||||
pub conversation: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ContextBudget {
|
|
||||||
pub fn total(&self) -> usize {
|
|
||||||
self.system + self.identity + self.journal + self.memory + self.conversation
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn format(&self) -> String {
|
|
||||||
let window = context_window();
|
|
||||||
if window == 0 { return String::new(); }
|
|
||||||
let used = self.total();
|
|
||||||
let free = window.saturating_sub(used);
|
|
||||||
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
|
|
||||||
format!("sys:{}% id:{}% jnl:{}% mem:{}% conv:{}% free:{}%",
|
|
||||||
pct(self.system), pct(self.identity), pct(self.journal),
|
|
||||||
pct(self.memory), pct(self.conversation), pct(free))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -690,21 +690,6 @@ impl Agent {
|
||||||
self.push_message(Message::tool_result(&call.id, &output));
|
self.push_message(Message::tool_result(&call.id, &output));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Token budget by category — just reads cached section totals.
|
|
||||||
pub fn context_budget(&self) -> context::ContextBudget {
|
|
||||||
let memory: usize = self.context.conversation.entries().iter()
|
|
||||||
.filter(|e| e.entry.is_memory())
|
|
||||||
.map(|e| e.tokens)
|
|
||||||
.sum();
|
|
||||||
let conv_total = self.context.conversation.tokens();
|
|
||||||
context::ContextBudget {
|
|
||||||
system: self.context.system.tokens(),
|
|
||||||
identity: self.context.identity.tokens(),
|
|
||||||
journal: self.context.journal.tokens(),
|
|
||||||
memory,
|
|
||||||
conversation: conv_total - memory,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Context state sections — just returns references to the live data.
|
/// Context state sections — just returns references to the live data.
|
||||||
pub fn context_sections(&self) -> [&ContextSection; 4] {
|
pub fn context_sections(&self) -> [&ContextSection; 4] {
|
||||||
|
|
@ -907,8 +892,9 @@ impl Agent {
|
||||||
self.load_startup_journal();
|
self.load_startup_journal();
|
||||||
|
|
||||||
// Dedup memory, trim to budget
|
// Dedup memory, trim to budget
|
||||||
let budget = self.context_budget();
|
let fixed = self.context.system.tokens() + self.context.identity.tokens()
|
||||||
self.context.conversation.trim(&budget, &self.tokenizer);
|
+ self.context.journal.tokens();
|
||||||
|
self.context.conversation.trim(fixed);
|
||||||
|
|
||||||
let after = self.context.conversation.len();
|
let after = self.context.conversation.len();
|
||||||
let after_mem = self.context.conversation.entries().iter()
|
let after_mem = self.context.conversation.entries().iter()
|
||||||
|
|
@ -920,8 +906,7 @@ impl Agent {
|
||||||
self.generation += 1;
|
self.generation += 1;
|
||||||
self.last_prompt_tokens = 0;
|
self.last_prompt_tokens = 0;
|
||||||
|
|
||||||
let budget = self.context_budget();
|
dbglog!("[compact] budget: {}", self.context.format_budget());
|
||||||
dbglog!("[compact] budget: {}", budget.format());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Restore from the conversation log. Builds the context window
|
/// Restore from the conversation log. Builds the context window
|
||||||
|
|
@ -960,7 +945,7 @@ impl Agent {
|
||||||
self.context.conversation.set_entries(all);
|
self.context.conversation.set_entries(all);
|
||||||
self.compact();
|
self.compact();
|
||||||
// Estimate prompt tokens so status bar isn't 0 on startup
|
// Estimate prompt tokens so status bar isn't 0 on startup
|
||||||
self.last_prompt_tokens = self.context_budget().total() as u32;
|
self.last_prompt_tokens = self.context.total_tokens() as u32;
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -337,7 +337,7 @@ impl Mind {
|
||||||
MindCommand::Compact => {
|
MindCommand::Compact => {
|
||||||
let threshold = compaction_threshold(&self.config.app) as usize;
|
let threshold = compaction_threshold(&self.config.app) as usize;
|
||||||
let mut ag = self.agent.lock().await;
|
let mut ag = self.agent.lock().await;
|
||||||
if ag.context_budget().total() > threshold {
|
if ag.context.total_tokens() > threshold {
|
||||||
ag.compact();
|
ag.compact();
|
||||||
ag.notify("compacted");
|
ag.notify("compacted");
|
||||||
}
|
}
|
||||||
|
|
@ -437,7 +437,7 @@ impl Mind {
|
||||||
|
|
||||||
// Compact if over budget before sending
|
// Compact if over budget before sending
|
||||||
let threshold = compaction_threshold(&self.config.app) as usize;
|
let threshold = compaction_threshold(&self.config.app) as usize;
|
||||||
if ag.context_budget().total() > threshold {
|
if ag.context.total_tokens() > threshold {
|
||||||
ag.compact();
|
ag.compact();
|
||||||
ag.notify("compacted");
|
ag.notify("compacted");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -863,7 +863,7 @@ impl ScreenView for InteractScreen {
|
||||||
agent.expire_activities();
|
agent.expire_activities();
|
||||||
app.status.prompt_tokens = agent.last_prompt_tokens();
|
app.status.prompt_tokens = agent.last_prompt_tokens();
|
||||||
app.status.model = agent.model().to_string();
|
app.status.model = agent.model().to_string();
|
||||||
app.status.context_budget = agent.context_budget().format();
|
app.status.context_budget = agent.context.format_budget();
|
||||||
app.activity = agent.activities.last()
|
app.activity = agent.activities.last()
|
||||||
.map(|a| a.label.clone())
|
.map(|a| a.label.clone())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue