consciousness/src/agent/context.rs
Kent Overstreet b37b6d7495 Kill log callback — use ConversationEntry::Log for debug traces
Add Log variant to ConversationEntry that serializes to the
conversation log but is filtered out on read-back and API calls.
AutoAgent writes debug/status info (turns, tokens, tool calls)
through the conversation log instead of a callback parameter.

Removes the log callback from run_one_agent, call_api_with_tools,
and all callers.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-07 01:23:22 -04:00

372 lines
13 KiB
Rust

// context.rs — Context window management
//
// Token counting, conversation trimming, and error classification.
// Journal entries are loaded from the memory graph store, not from
// a flat file — the parse functions are gone.
use std::sync::{Arc, RwLock};
use crate::agent::api::types::*;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use crate::agent::tools::working_stack;
/// A section of the context window, possibly with children.
///
/// The agent fills these in while assembling a request; the TUI debug
/// screen reads them through [`SharedContextState`].
#[derive(Debug, Clone)]
pub struct ContextSection {
/// Section name; `sections_budget_string` matches on prefixes like
/// "System", "Journal", "Memory", "Conversation".
pub name: String,
/// Token footprint of this section's content.
pub tokens: usize,
/// Rendered text of the section.
pub content: String,
/// Nested sub-sections. NOTE(review): whether a child's `tokens` are
/// already included in the parent's count is not visible here — confirm
/// with the code that builds these sections.
pub children: Vec<ContextSection>,
}
/// Shared, live context state — agent writes, TUI reads for the debug screen.
pub type SharedContextState = Arc<RwLock<Vec<ContextSection>>>;

/// Create a fresh shared context state containing no sections.
pub fn shared_context_state() -> SharedContextState {
    Arc::new(RwLock::new(Vec::default()))
}
/// A single journal entry with its timestamp and content.
///
/// Per the file header, journal entries are loaded from the memory graph
/// store; `render_journal` formats them into the context window.
#[derive(Debug, Clone)]
pub struct JournalEntry {
/// When the entry was written (UTC).
pub timestamp: DateTime<Utc>,
/// Entry body, rendered verbatim under a timestamp heading.
pub content: String,
}
/// Context window size in tokens (from config).
pub fn context_window() -> usize {
    let cfg = crate::config::get();
    cfg.api_context_window
}
/// Context budget in tokens: 80% of the model's context window.
/// The remaining 20% is reserved for model output.
fn context_budget_tokens() -> usize {
    let window = context_window();
    window * 80 / 100
}
/// Dedup and trim conversation entries to fit within the context budget.
///
/// 1. Dedup: if the same memory key appears multiple times, keep only
///    the latest render. NOTE(review): only the earlier Memory entry is
///    dropped here — the corresponding assistant tool_call message is
///    not; confirm whether a caller handles that.
/// 2. Trim: drop oldest entries until the conversation fits, snapping
///    to user message boundaries.
pub fn trim_entries(
    context: &ContextState,
    entries: &[ConversationEntry],
    tokenizer: &CoreBPE,
) -> Vec<ConversationEntry> {
    use std::collections::{HashMap, HashSet};

    // --- Phase 1: dedup memory entries by key (keep last) ---
    let mut seen_keys: HashMap<&str, usize> = HashMap::new();
    let mut drop_indices: HashSet<usize> = HashSet::new();
    for (i, entry) in entries.iter().enumerate() {
        if let ConversationEntry::Memory { key, .. } = entry {
            // insert() returns the previous index for this key, if any;
            // that earlier render is stale, so mark it for removal.
            if let Some(prev) = seen_keys.insert(key.as_str(), i) {
                drop_indices.insert(prev);
            }
        }
    }
    let deduped: Vec<ConversationEntry> = entries
        .iter()
        .enumerate()
        .filter(|(i, _)| !drop_indices.contains(i))
        .map(|(_, e)| e.clone())
        .collect();

    // --- Phase 2: trim to fit context budget ---
    // Everything in the context window is a message. Count them all,
    // trim entries until the total fits.
    let max_tokens = context_budget_tokens();
    let count_msg = |m: &Message| msg_token_count(tokenizer, m);
    // Fixed overhead that is always sent: system prompt, rendered
    // context message, and the journal.
    let fixed_cost = count_msg(&Message::system(&context.system_prompt))
        + count_msg(&Message::user(context.render_context_message()))
        + count_msg(&Message::user(render_journal(&context.journal)));
    // Log entries never reach the API, so they cost nothing.
    let msg_costs: Vec<usize> = deduped
        .iter()
        .map(|e| if e.is_log() { 0 } else { count_msg(e.api_message()) })
        .collect();
    let entry_total: usize = msg_costs.iter().sum();
    let total: usize = fixed_cost + entry_total;
    let mem_tokens: usize = deduped
        .iter()
        .zip(&msg_costs)
        .filter(|(e, _)| e.is_memory())
        .map(|(_, &c)| c)
        .sum();
    let conv_tokens: usize = entry_total - mem_tokens;
    dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
        max_tokens, fixed_cost, mem_tokens, conv_tokens, total, deduped.len());

    // Phase 2a: evict all DMN entries first — they're ephemeral
    let mut drop = vec![false; deduped.len()];
    let mut trimmed = total;
    let mut cur_mem = mem_tokens;
    for i in 0..deduped.len() {
        if deduped[i].is_dmn() {
            drop[i] = true;
            trimmed -= msg_costs[i];
        }
    }
    // Phase 2b: if memories outweigh conversation, evict oldest memories
    // until they no longer dominate (or we fit the budget).
    if cur_mem > conv_tokens && trimmed > max_tokens {
        for i in 0..deduped.len() {
            if drop[i] { continue; }
            if !deduped[i].is_memory() { continue; }
            if cur_mem <= conv_tokens { break; }
            if trimmed <= max_tokens { break; }
            drop[i] = true;
            trimmed -= msg_costs[i];
            cur_mem -= msg_costs[i];
        }
    }
    // Phase 2c: drop oldest entries until under budget
    for i in 0..deduped.len() {
        if trimmed <= max_tokens { break; }
        if drop[i] { continue; }
        drop[i] = true;
        trimmed -= msg_costs[i];
    }

    // Walk forward past dropped entries, then snap to the first user
    // message so the retained window starts on a conversation boundary.
    let mut result: Vec<ConversationEntry> = Vec::new();
    let mut skipping = true;
    for (i, entry) in deduped.into_iter().enumerate() {
        if skipping {
            if drop[i] { continue; }
            // Log entries carry no message (entry.message() panics on
            // them) and can never be a boundary — skip any that precede
            // the first surviving user message.
            if entry.is_log() { continue; }
            // Snap to user message boundary
            if entry.message().role != Role::User { continue; }
            skipping = false;
        }
        result.push(entry);
    }
    dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
    result
}
/// Count the token footprint of a message using BPE tokenization.
///
/// Sums the tokenized length of all text content (images are charged a
/// flat 85 tokens) plus the tokenized tool-call names and arguments.
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
    let mut total = 0usize;
    if let Some(content) = msg.content.as_ref() {
        total += match content {
            MessageContent::Text(s) => count(s),
            MessageContent::Parts(parts) => {
                let mut n = 0usize;
                for part in parts {
                    n += match part {
                        ContentPart::Text { text } => count(text),
                        // Flat per-image cost, matching the original estimate.
                        ContentPart::ImageUrl { .. } => 85,
                    };
                }
                n
            }
        };
    }
    if let Some(calls) = msg.tool_calls.as_ref() {
        for call in calls {
            total += count(&call.function.name) + count(&call.function.arguments);
        }
    }
    total
}
/// Detect context window overflow errors from the API.
///
/// Matches case-insensitively against the phrasings various providers
/// use, plus the generic "HTTP 400 mentioning tokens" shape.
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
    const NEEDLES: [&str; 8] = [
        "context length",
        "token limit",
        "too many tokens",
        "maximum context",
        "prompt is too long",
        "request too large",
        "input validation error",
        "content length limit",
    ];
    let msg = err.to_string().to_lowercase();
    NEEDLES.iter().any(|needle| msg.contains(needle))
        || (msg.contains("400") && msg.contains("tokens"))
}
/// Detect model/provider errors delivered inside the SSE stream.
///
/// Looks for the "model stream error" marker in the rendered error text.
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    let rendered = err.to_string();
    rendered.contains("model stream error")
}
// --- Context state types ---
/// Conversation entry — either a regular message or memory content.
/// Memory entries preserve the original message for KV cache round-tripping.
#[derive(Debug, Clone, PartialEq)]
pub enum ConversationEntry {
/// A plain API message.
Message(Message),
/// A rendered memory, keyed so stale renders of the same key can be
/// deduped (see `trim_entries`); the message is what goes to the API.
Memory { key: String, message: Message },
/// DMN heartbeat/autonomous prompt — evicted aggressively during compaction.
Dmn(Message),
/// Debug/status log line — written to conversation log for tracing,
/// skipped on read-back.
Log(String),
}
// Custom serde: serialize Memory with a "memory_key" field added to the message,
// Log as a single-field {"log": "..."} object, and plain messages as-is.
// This keeps the conversation log readable.
impl Serialize for ConversationEntry {
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;
        match self {
            // NOTE(review): Dmn serializes exactly like Message, so the
            // variant tag is lost on read-back — confirm this is intended.
            Self::Message(m) | Self::Dmn(m) => m.serialize(s),
            Self::Memory { key, message } => {
                // Flatten the message's JSON object, then append
                // "memory_key" so Deserialize can reconstruct the variant.
                let json = serde_json::to_value(message).map_err(serde::ser::Error::custom)?;
                let mut map = s.serialize_map(None)?;
                if let serde_json::Value::Object(obj) = json {
                    for (k, v) in obj {
                        map.serialize_entry(&k, &v)?;
                    }
                }
                map.serialize_entry("memory_key", key)?;
                map.end()
            }
            Self::Log(text) => {
                let mut map = s.serialize_map(Some(1))?;
                map.serialize_entry("log", text)?;
                map.end()
            }
        }
    }
}
impl<'de> Deserialize<'de> for ConversationEntry {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let mut json: serde_json::Value = serde_json::Value::deserialize(d)?;
// Log entries — skip on read-back
if json.get("log").is_some() {
let text = json["log"].as_str().unwrap_or("").to_string();
return Ok(Self::Log(text));
}
if let Some(key) = json.as_object_mut().and_then(|o| o.remove("memory_key")) {
let key = key.as_str().unwrap_or("").to_string();
let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?;
Ok(Self::Memory { key, message })
} else {
let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?;
Ok(Self::Message(message))
}
}
}
impl ConversationEntry {
    /// The message to send to the model on an API call.
    ///
    /// # Panics
    /// Panics on `Log` entries — they must be filtered out before any
    /// API request is assembled.
    pub fn api_message(&self) -> &Message {
        match self {
            Self::Log(_) => panic!("Log entries have no API message"),
            entry => entry.message(),
        }
    }
    /// True for `Memory` entries.
    pub fn is_memory(&self) -> bool {
        matches!(self, Self::Memory { .. })
    }
    /// True for `Dmn` entries.
    pub fn is_dmn(&self) -> bool {
        matches!(self, Self::Dmn(_))
    }
    /// True for `Log` entries.
    pub fn is_log(&self) -> bool {
        matches!(self, Self::Log(_))
    }
    /// Borrow the inner message.
    ///
    /// # Panics
    /// Panics on `Log` entries, which carry no message.
    pub fn message(&self) -> &Message {
        match self {
            Self::Memory { message, .. } => message,
            Self::Message(m) | Self::Dmn(m) => m,
            Self::Log(_) => panic!("Log entries have no message"),
        }
    }
    /// Mutably borrow the inner message.
    ///
    /// # Panics
    /// Panics on `Log` entries, which carry no message.
    pub fn message_mut(&mut self) -> &mut Message {
        match self {
            Self::Memory { message, .. } => message,
            Self::Message(m) | Self::Dmn(m) => m,
            Self::Log(_) => panic!("Log entries have no message"),
        }
    }
}
#[derive(Clone)]
pub struct ContextState {
/// System prompt; counted as fixed context cost in `trim_entries`.
pub system_prompt: String,
/// (section name, content) pairs rendered as "## name" sections by
/// `render_context_message`.
pub personality: Vec<(String, String)>,
/// Journal entries; rendered by `render_journal` and counted as fixed
/// context cost.
pub journal: Vec<JournalEntry>,
/// Working-stack items; the last element is rendered as the bare
/// current item, earlier ones with " [i]" prefixes.
pub working_stack: Vec<String>,
/// Conversation entries — messages and memory, interleaved in order.
/// Does NOT include system prompt, personality, or journal.
pub entries: Vec<ConversationEntry>,
}
pub fn render_journal(entries: &[JournalEntry]) -> String {
if entries.is_empty() { return String::new(); }
let mut text = String::from("[Earlier — from your journal]\n\n");
for entry in entries {
use std::fmt::Write;
writeln!(text, "## {}\n{}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content).ok();
}
text
}
impl ContextState {
    /// Render the personality sections plus the working-stack section
    /// into one context message, joined by horizontal rules.
    pub fn render_context_message(&self) -> String {
        use std::fmt::Write;
        let mut parts: Vec<String> = Vec::with_capacity(self.personality.len() + 1);
        for (name, content) in &self.personality {
            parts.push(format!("## {}\n\n{}", name, content));
        }
        // Stack instructions live in a file on disk; a missing or
        // unreadable file renders as an empty prefix, not an error.
        let mut stack_section =
            std::fs::read_to_string(working_stack::instructions_path()).unwrap_or_default();
        match self.working_stack.split_last() {
            None => stack_section.push_str("\n## Current stack\n\n(empty)\n"),
            Some((top, rest)) => {
                stack_section.push_str("\n## Current stack\n\n");
                // Earlier items carry their index; the top of the stack
                // is rendered bare on the final line.
                for (i, item) in rest.iter().enumerate() {
                    let _ = writeln!(stack_section, " [{}] {}", i, item);
                }
                let _ = writeln!(stack_section, "{}", top);
            }
        }
        parts.push(stack_section);
        parts.join("\n\n---\n\n")
    }
}
/// Total tokens used across all context sections.
pub fn sections_used(sections: &[ContextSection]) -> usize {
    sections.iter().fold(0, |total, section| total + section.tokens)
}
/// Budget status string derived from context sections.
pub fn sections_budget_string(sections: &[ContextSection]) -> String {
let window = context_window();
if window == 0 { return String::new(); }
let used: usize = sections.iter().map(|s| s.tokens).sum();
let free = window.saturating_sub(used);
let pct = |n: usize| if n == 0 { 0 } else { ((n * 100) / window).max(1) };
let parts: Vec<String> = sections.iter()
.map(|s| {
// Short label from section name
let label = match s.name.as_str() {
n if n.starts_with("System") => "sys",
n if n.starts_with("Personality") => "id",
n if n.starts_with("Journal") => "jnl",
n if n.starts_with("Working") => "stack",
n if n.starts_with("Memory") => "mem",
n if n.starts_with("Conversation") => "conv",
_ => return String::new(),
};
format!("{}:{}%", label, pct(s.tokens))
})
.filter(|s| !s.is_empty())
.collect();
format!("{} free:{}%", parts.join(" "), pct(free))
}