diff --git a/src/agent/context.rs b/src/agent/context.rs index e3930a3..40a1054 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -99,8 +99,8 @@ impl ContextSection { } /// Dedup and trim entries to fit within context budget. - pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) { - let result = trim_entries(&self.entries, tokenizer, budget); + pub fn trim(&mut self, fixed_tokens: usize) { + let result = trim_entries(&self.entries, fixed_tokens); self.entries = result; self.tokens = self.entries.iter().map(|e| e.tokens).sum(); } @@ -130,6 +130,19 @@ impl ContextState { + self.journal.tokens() + self.conversation.tokens() } + /// Budget status string for debug logging. + pub fn format_budget(&self) -> String { + let window = context_window(); + if window == 0 { return String::new(); } + let used = self.total_tokens(); + let free = window.saturating_sub(used); + let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) }; + format!("sys:{}% id:{}% jnl:{}% conv:{}% free:{}%", + pct(self.system.tokens()), pct(self.identity.tokens()), + pct(self.journal.tokens()), pct(self.conversation.tokens()), + pct(free)) + } + /// All sections as a slice for iteration. pub fn sections(&self) -> [&ContextSection; 4] { [&self.system, &self.identity, &self.journal, &self.conversation] @@ -149,112 +162,76 @@ fn context_budget_tokens() -> usize { /// Dedup and trim conversation entries to fit within the context budget. /// -/// 1. Dedup: if the same memory key appears multiple times, keep only -/// the latest render (drop the earlier Memory entry and its -/// corresponding assistant tool_call message). -/// 2. Trim: drop oldest entries until the conversation fits, snapping -/// to user message boundaries. 
-fn trim_entries( - entries: &[ContextEntry], - _tokenizer: &CoreBPE, - budget: &ContextBudget, -) -> Vec<ContextEntry> { - let fixed_tokens = budget.system + budget.identity + budget.journal; - // --- Phase 1: dedup memory entries by key (keep last) --- +/// Phase 1: Drop duplicate memories (keep last) and DMN entries. +/// Phase 2: While over budget, drop lowest-scored memory (or if memories +/// are under 50%, drop oldest conversation entry). +fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEntry> { + let max_tokens = context_budget_tokens(); + + // Phase 1: dedup memories by key (keep last), drop DMN entries let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new(); let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new(); for (i, ce) in entries.iter().enumerate() { - if let ConversationEntry::Memory { key, .. } = &ce.entry { + if ce.entry.is_dmn() { + drop_indices.insert(i); + } else if let ConversationEntry::Memory { key, .. 
} = &ce.entry { if let Some(prev) = seen_keys.insert(key.as_str(), i) { drop_indices.insert(prev); } } } - let deduped: Vec<ContextEntry> = entries.iter().enumerate() + let mut result: Vec<ContextEntry> = entries.iter().enumerate() .filter(|(i, _)| !drop_indices.contains(i)) .map(|(_, e)| e.clone()) .collect(); - // --- Phase 2: trim to fit context budget --- - let max_tokens = context_budget_tokens(); + let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() }; + let mem_total = |r: &[ContextEntry]| -> usize { + r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum() + }; - let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect(); - let entry_total: usize = msg_costs.iter().sum(); - let total: usize = fixed_tokens + entry_total; + dbglog!("[trim] max={} fixed={} total={} entries={}", + max_tokens, fixed_tokens, fixed_tokens + entry_total(&result), result.len()); - let mem_tokens: usize = deduped.iter() - .filter(|ce| ce.entry.is_memory()) - .map(|ce| ce.tokens).sum(); - let conv_tokens: usize = entry_total - mem_tokens; + // Phase 2: while over budget, evict + while fixed_tokens + entry_total(&result) > max_tokens { + let mt = mem_total(&result); + let ct = entry_total(&result) - mt; - dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}", - max_tokens, fixed_tokens, mem_tokens, conv_tokens, total, deduped.len()); - - // Phase 2a: evict all DMN entries first — they're ephemeral - let mut drop = vec![false; deduped.len()]; - let mut trimmed = total; - let mut cur_mem = mem_tokens; - - for i in 0..deduped.len() { - if deduped[i].entry.is_dmn() { - drop[i] = true; - trimmed -= msg_costs[i]; + if mt > ct && let Some(i) = lowest_scored_memory(&result) { + // If memories > 50% of entry tokens, drop lowest-scored memory + result.remove(i); + } else if let Some(i) = result.iter().position(|e| !e.entry.is_memory()) { + // Otherwise drop oldest conversation entry + result.remove(i); + } else { + break; + } } - // Phase 2b: if memories > 
50% of context, evict lowest-scored first - if cur_mem > conv_tokens && trimmed > max_tokens { - let mut mem_indices: Vec<usize> = (0..deduped.len()) - .filter(|&i| !drop[i] && deduped[i].entry.is_memory()) - .collect(); - mem_indices.sort_by(|&a, &b| { - let sa = match &deduped[a].entry { - ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), - _ => 0.0, - }; - let sb = match &deduped[b].entry { - ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), - _ => 0.0, - }; - sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal) - }); - for i in mem_indices { - if cur_mem <= conv_tokens { break; } - if trimmed <= max_tokens { break; } - drop[i] = true; - trimmed -= msg_costs[i]; - cur_mem -= msg_costs[i]; - } + // Snap to user message boundary at the start + while let Some(first) = result.first() { + if first.entry.message().role == Role::User { break; } + result.remove(0); } - // Phase 2c: if still over, drop oldest conversation entries - for i in 0..deduped.len() { - if trimmed <= max_tokens { break; } - if drop[i] { continue; } - drop[i] = true; - trimmed -= msg_costs[i]; - } - - // Walk forward to include complete conversation boundaries - let mut result: Vec<ContextEntry> = Vec::new(); - let mut skipping = true; - for (i, ce) in deduped.into_iter().enumerate() { - if skipping { - if drop[i] { continue; } - // Snap to user message boundary - if ce.entry.message().role != Role::User { continue; } - skipping = false; - } - result.push(ce); - } - - dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed); - + dbglog!("[trim] result={} total={}", result.len(), fixed_tokens + entry_total(&result)); result } +fn lowest_scored_memory(entries: &[ContextEntry]) -> Option<usize> { + entries.iter().enumerate() + .filter(|(_, e)| e.entry.is_memory()) + .min_by(|(_, a), (_, b)| { + let sa = match &a.entry { ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0), _ => 0.0 }; + let sb = match &b.entry { ConversationEntry::Memory { score, .. 
} => score.unwrap_or(0.0), _ => 0.0 }; + sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i) +} + /// Count the token footprint of a message using BPE tokenization. pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize { let count = |s: &str| tokenizer.encode_with_special_tokens(s).len(); @@ -433,30 +410,3 @@ impl ContextState { parts.join("\n\n---\n\n") } } - -/// Token budget per context category — cheap to compute, no formatting. -pub struct ContextBudget { - pub system: usize, - pub identity: usize, - pub journal: usize, - pub memory: usize, - pub conversation: usize, -} - -impl ContextBudget { - pub fn total(&self) -> usize { - self.system + self.identity + self.journal + self.memory + self.conversation - } - - pub fn format(&self) -> String { - let window = context_window(); - if window == 0 { return String::new(); } - let used = self.total(); - let free = window.saturating_sub(used); - let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) }; - format!("sys:{}% id:{}% jnl:{}% mem:{}% conv:{}% free:{}%", - pct(self.system), pct(self.identity), pct(self.journal), - pct(self.memory), pct(self.conversation), pct(free)) - } -} - diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 0adbeba..8abea91 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -690,21 +690,6 @@ impl Agent { self.push_message(Message::tool_result(&call.id, &output)); } - /// Token budget by category — just reads cached section totals. 
- pub fn context_budget(&self) -> context::ContextBudget { - let memory: usize = self.context.conversation.entries().iter() - .filter(|e| e.entry.is_memory()) - .map(|e| e.tokens) - .sum(); - let conv_total = self.context.conversation.tokens(); - context::ContextBudget { - system: self.context.system.tokens(), - identity: self.context.identity.tokens(), - journal: self.context.journal.tokens(), - memory, - conversation: conv_total - memory, - } - } /// Context state sections — just returns references to the live data. pub fn context_sections(&self) -> [&ContextSection; 4] { @@ -907,8 +892,9 @@ impl Agent { self.load_startup_journal(); // Dedup memory, trim to budget - let budget = self.context_budget(); - self.context.conversation.trim(&budget, &self.tokenizer); + let fixed = self.context.system.tokens() + self.context.identity.tokens() + + self.context.journal.tokens(); + self.context.conversation.trim(fixed); let after = self.context.conversation.len(); let after_mem = self.context.conversation.entries().iter() @@ -920,8 +906,7 @@ impl Agent { self.generation += 1; self.last_prompt_tokens = 0; - let budget = self.context_budget(); - dbglog!("[compact] budget: {}", budget.format()); + dbglog!("[compact] budget: {}", self.context.format_budget()); } /// Restore from the conversation log. 
Builds the context window @@ -960,7 +945,7 @@ impl Agent { self.context.conversation.set_entries(all); self.compact(); // Estimate prompt tokens so status bar isn't 0 on startup - self.last_prompt_tokens = self.context_budget().total() as u32; + self.last_prompt_tokens = self.context.total_tokens() as u32; true } diff --git a/src/mind/mod.rs b/src/mind/mod.rs index 8a1cbca..90c711a 100644 --- a/src/mind/mod.rs +++ b/src/mind/mod.rs @@ -337,7 +337,7 @@ impl Mind { MindCommand::Compact => { let threshold = compaction_threshold(&self.config.app) as usize; let mut ag = self.agent.lock().await; - if ag.context_budget().total() > threshold { + if ag.context.total_tokens() > threshold { ag.compact(); ag.notify("compacted"); } @@ -437,7 +437,7 @@ impl Mind { // Compact if over budget before sending let threshold = compaction_threshold(&self.config.app) as usize; - if ag.context_budget().total() > threshold { + if ag.context.total_tokens() > threshold { ag.compact(); ag.notify("compacted"); } diff --git a/src/user/chat.rs b/src/user/chat.rs index 0efd370..92126cb 100644 --- a/src/user/chat.rs +++ b/src/user/chat.rs @@ -863,7 +863,7 @@ impl ScreenView for InteractScreen { agent.expire_activities(); app.status.prompt_tokens = agent.last_prompt_tokens(); app.status.model = agent.model().to_string(); - app.status.context_budget = agent.context_budget().format(); + app.status.context_budget = agent.context.format_budget(); app.activity = agent.activities.last() .map(|a| a.label.clone()) .unwrap_or_default();