WIP: ContextEntry/ContextSection data structures for incremental token counting
New types — not yet wired to callers: - ContextEntry: wraps ConversationEntry with cached token count and timestamp - ContextSection: named group of entries with cached token total. Private entries/tokens, read via entries()/tokens(). Mutation via push(entry), set(index, entry), del(index). - ContextState: system/identity/journal/conversation sections + working_stack - ConversationEntry::System variant for system prompt entries Token counting happens once at push time. Sections maintain their totals incrementally via push/set/del. No more recomputing from scratch on every budget check. Does not compile — callers need updating. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
776ac527f1
commit
62996e27d7
10 changed files with 450 additions and 403 deletions
|
|
@ -10,20 +10,130 @@ use serde::{Deserialize, Serialize};
|
|||
use tiktoken_rs::CoreBPE;
|
||||
use crate::agent::tools::working_stack;
|
||||
|
||||
/// A section of the context window, possibly with children.
|
||||
// --- Context state types ---
|
||||
|
||||
/// Conversation entry — either a regular message or memory content.
/// Memory entries preserve the original message for KV cache round-tripping.
#[derive(Debug, Clone, PartialEq)]
pub enum ConversationEntry {
    /// System prompt or system-level instruction.
    System(Message),
    /// Ordinary conversation message.
    Message(Message),
    /// Retrieved memory content; `key` is used for dedup (keep-last) and
    /// `score` drives eviction order during trimming.
    Memory { key: String, message: Message, score: Option<f64> },
    /// DMN heartbeat/autonomous prompt — evicted aggressively during compaction.
    Dmn(Message),
    /// Debug/status log line — written to conversation log for tracing,
    /// skipped on read-back.
    Log(String),
}
|
||||
|
||||
/// Entry in the context window — wraps a ConversationEntry with cached metadata.
#[derive(Debug, Clone)]
pub struct ContextEntry {
    /// The wrapped conversation entry (message, memory, DMN, or log line).
    pub entry: ConversationEntry,
    /// Cached token count, computed once when the entry is created
    /// (0 for Log entries, which are never sent to the model).
    pub tokens: usize,
    /// When this entry was added to the context; None for entries
    /// without a meaningful timestamp.
    pub timestamp: Option<DateTime<Utc>>,
}
|
||||
|
||||
/// A named section of the context window with cached token total.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextSection {
|
||||
pub name: String,
|
||||
pub tokens: usize,
|
||||
pub content: String,
|
||||
pub children: Vec<ContextSection>,
|
||||
/// Cached sum of entry tokens.
|
||||
tokens: usize,
|
||||
entries: Vec<ContextEntry>,
|
||||
}
|
||||
|
||||
/// A single journal entry with its timestamp and content.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JournalEntry {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub content: String,
|
||||
impl ContextSection {
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self { name: name.into(), tokens: 0, entries: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn entries(&self) -> &[ContextEntry] { &self.entries }
|
||||
pub fn tokens(&self) -> usize { self.tokens }
|
||||
pub fn len(&self) -> usize { self.entries.len() }
|
||||
pub fn is_empty(&self) -> bool { self.entries.is_empty() }
|
||||
|
||||
/// Push an entry, updating the cached token total.
|
||||
pub fn push(&mut self, entry: ContextEntry) {
|
||||
self.tokens += entry.tokens;
|
||||
self.entries.push(entry);
|
||||
}
|
||||
|
||||
/// Replace an entry at `index`, adjusting the token total.
|
||||
pub fn set(&mut self, index: usize, entry: ContextEntry) {
|
||||
self.tokens -= self.entries[index].tokens;
|
||||
self.tokens += entry.tokens;
|
||||
self.entries[index] = entry;
|
||||
}
|
||||
|
||||
/// Remove an entry at `index`, adjusting the token total.
|
||||
pub fn del(&mut self, index: usize) -> ContextEntry {
|
||||
let removed = self.entries.remove(index);
|
||||
self.tokens -= removed.tokens;
|
||||
removed
|
||||
}
|
||||
|
||||
/// Replace the message inside an entry, recomputing its token count.
|
||||
pub fn set_message(&mut self, index: usize, tokenizer: &CoreBPE, msg: Message) {
|
||||
let old_tokens = self.entries[index].tokens;
|
||||
*self.entries[index].entry.message_mut() = msg;
|
||||
let new_tokens = msg_token_count(tokenizer, self.entries[index].entry.api_message());
|
||||
self.entries[index].tokens = new_tokens;
|
||||
self.tokens = self.tokens - old_tokens + new_tokens;
|
||||
}
|
||||
|
||||
/// Set the score on a Memory entry. No token change.
|
||||
pub fn set_score(&mut self, index: usize, score: Option<f64>) {
|
||||
if let ConversationEntry::Memory { score: s, .. } = &mut self.entries[index].entry {
|
||||
*s = score;
|
||||
}
|
||||
}
|
||||
|
||||
/// Bulk replace all entries, recomputing token total.
|
||||
pub fn set_entries(&mut self, entries: Vec<ContextEntry>) {
|
||||
self.tokens = entries.iter().map(|e| e.tokens).sum();
|
||||
self.entries = entries;
|
||||
}
|
||||
|
||||
/// Dedup and trim entries to fit within context budget.
|
||||
pub fn trim(&mut self, budget: &ContextBudget, tokenizer: &CoreBPE) {
|
||||
let result = trim_entries(&self.entries, tokenizer, budget);
|
||||
self.entries = result;
|
||||
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
|
||||
}
|
||||
|
||||
/// Clear all entries.
|
||||
pub fn clear(&mut self) {
|
||||
self.entries.clear();
|
||||
self.tokens = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
pub struct ContextState {
    /// System prompt section.
    pub system: ContextSection,
    /// Identity/personality section.
    pub identity: ContextSection,
    /// Journal section.
    pub journal: ContextSection,
    /// Conversation section — messages and memory, interleaved in order.
    pub conversation: ContextSection,
    /// Working stack — separate from identity because it's managed
    /// by its own tool, not loaded from personality files.
    pub working_stack: Vec<String>,
}
|
||||
|
||||
impl ContextState {
|
||||
/// Total tokens across all sections.
|
||||
pub fn total_tokens(&self) -> usize {
|
||||
self.system.tokens() + self.identity.tokens()
|
||||
+ self.journal.tokens() + self.conversation.tokens()
|
||||
}
|
||||
|
||||
/// All sections as a slice for iteration.
|
||||
pub fn sections(&self) -> [&ContextSection; 4] {
|
||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||
}
|
||||
}
|
||||
|
||||
/// Context window size in tokens (from config).
|
||||
|
|
@ -44,41 +154,39 @@ fn context_budget_tokens() -> usize {
|
|||
/// corresponding assistant tool_call message).
|
||||
/// 2. Trim: drop oldest entries until the conversation fits, snapping
|
||||
/// to user message boundaries.
|
||||
pub fn trim_entries(
|
||||
entries: &[ConversationEntry],
|
||||
tokenizer: &CoreBPE,
|
||||
fn trim_entries(
|
||||
entries: &[ContextEntry],
|
||||
_tokenizer: &CoreBPE,
|
||||
budget: &ContextBudget,
|
||||
) -> Vec<ConversationEntry> {
|
||||
) -> Vec<ContextEntry> {
|
||||
let fixed_tokens = budget.system + budget.identity + budget.journal;
|
||||
// --- Phase 1: dedup memory entries by key (keep last) ---
|
||||
let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
||||
let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
||||
|
||||
for (i, entry) in entries.iter().enumerate() {
|
||||
if let ConversationEntry::Memory { key, .. } = entry {
|
||||
for (i, ce) in entries.iter().enumerate() {
|
||||
if let ConversationEntry::Memory { key, .. } = &ce.entry {
|
||||
if let Some(prev) = seen_keys.insert(key.as_str(), i) {
|
||||
drop_indices.insert(prev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let deduped: Vec<ConversationEntry> = entries.iter().enumerate()
|
||||
let deduped: Vec<ContextEntry> = entries.iter().enumerate()
|
||||
.filter(|(i, _)| !drop_indices.contains(i))
|
||||
.map(|(_, e)| e.clone())
|
||||
.collect();
|
||||
|
||||
// --- Phase 2: trim to fit context budget ---
|
||||
let max_tokens = context_budget_tokens();
|
||||
let count_msg = |m: &Message| msg_token_count(tokenizer, m);
|
||||
|
||||
let msg_costs: Vec<usize> = deduped.iter()
|
||||
.map(|e| if e.is_log() { 0 } else { count_msg(e.api_message()) }).collect();
|
||||
let msg_costs: Vec<usize> = deduped.iter().map(|e| e.tokens).collect();
|
||||
let entry_total: usize = msg_costs.iter().sum();
|
||||
let total: usize = fixed_tokens + entry_total;
|
||||
|
||||
let mem_tokens: usize = deduped.iter().zip(&msg_costs)
|
||||
.filter(|(e, _)| e.is_memory())
|
||||
.map(|(_, &c)| c).sum();
|
||||
let mem_tokens: usize = deduped.iter()
|
||||
.filter(|ce| ce.entry.is_memory())
|
||||
.map(|ce| ce.tokens).sum();
|
||||
let conv_tokens: usize = entry_total - mem_tokens;
|
||||
|
||||
dbglog!("[trim] max_tokens={} fixed={} mem={} conv={} total={} entries={}",
|
||||
|
|
@ -90,7 +198,7 @@ pub fn trim_entries(
|
|||
let mut cur_mem = mem_tokens;
|
||||
|
||||
for i in 0..deduped.len() {
|
||||
if deduped[i].is_dmn() {
|
||||
if deduped[i].entry.is_dmn() {
|
||||
drop[i] = true;
|
||||
trimmed -= msg_costs[i];
|
||||
}
|
||||
|
|
@ -99,14 +207,14 @@ pub fn trim_entries(
|
|||
// Phase 2b: if memories > 50% of context, evict lowest-scored first
|
||||
if cur_mem > conv_tokens && trimmed > max_tokens {
|
||||
let mut mem_indices: Vec<usize> = (0..deduped.len())
|
||||
.filter(|&i| !drop[i] && deduped[i].is_memory())
|
||||
.filter(|&i| !drop[i] && deduped[i].entry.is_memory())
|
||||
.collect();
|
||||
mem_indices.sort_by(|&a, &b| {
|
||||
let sa = match &deduped[a] {
|
||||
let sa = match &deduped[a].entry {
|
||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
};
|
||||
let sb = match &deduped[b] {
|
||||
let sb = match &deduped[b].entry {
|
||||
ConversationEntry::Memory { score, .. } => score.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
};
|
||||
|
|
@ -130,16 +238,16 @@ pub fn trim_entries(
|
|||
}
|
||||
|
||||
// Walk forward to include complete conversation boundaries
|
||||
let mut result: Vec<ConversationEntry> = Vec::new();
|
||||
let mut result: Vec<ContextEntry> = Vec::new();
|
||||
let mut skipping = true;
|
||||
for (i, entry) in deduped.into_iter().enumerate() {
|
||||
for (i, ce) in deduped.into_iter().enumerate() {
|
||||
if skipping {
|
||||
if drop[i] { continue; }
|
||||
// Snap to user message boundary
|
||||
if entry.message().role != Role::User { continue; }
|
||||
if ce.entry.message().role != Role::User { continue; }
|
||||
skipping = false;
|
||||
}
|
||||
result.push(entry);
|
||||
result.push(ce);
|
||||
}
|
||||
|
||||
dbglog!("[trim] result={} trimmed_total={}", result.len(), trimmed);
|
||||
|
|
@ -186,28 +294,13 @@ pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
|||
err.to_string().contains("model stream error")
|
||||
}
|
||||
|
||||
// --- Context state types ---
|
||||
|
||||
/// Conversation entry — either a regular message or memory content.
|
||||
/// Memory entries preserve the original message for KV cache round-tripping.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ConversationEntry {
|
||||
Message(Message),
|
||||
Memory { key: String, message: Message, score: Option<f64> },
|
||||
/// DMN heartbeat/autonomous prompt — evicted aggressively during compaction.
|
||||
Dmn(Message),
|
||||
/// Debug/status log line — written to conversation log for tracing,
|
||||
/// skipped on read-back.
|
||||
Log(String),
|
||||
}
|
||||
|
||||
// Custom serde: serialize Memory with a "memory_key" field added to the message,
|
||||
// plain messages serialize as-is. This keeps the conversation log readable.
|
||||
impl Serialize for ConversationEntry {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
use serde::ser::SerializeMap;
|
||||
match self {
|
||||
Self::Message(m) | Self::Dmn(m) => m.serialize(s),
|
||||
Self::System(m) | Self::Message(m) | Self::Dmn(m) => m.serialize(s),
|
||||
Self::Memory { key, message, score } => {
|
||||
let json = serde_json::to_value(message).map_err(serde::ser::Error::custom)?;
|
||||
let mut map = s.serialize_map(None)?;
|
||||
|
|
@ -259,7 +352,7 @@ impl ConversationEntry {
|
|||
/// Panics on Log entries (which should be filtered before API calls).
|
||||
pub fn api_message(&self) -> &Message {
|
||||
match self {
|
||||
Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::System(m) | Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::Memory { message, .. } => message,
|
||||
Self::Log(_) => panic!("Log entries have no API message"),
|
||||
}
|
||||
|
|
@ -281,7 +374,7 @@ impl ConversationEntry {
|
|||
/// Panics on Log entries.
|
||||
pub fn message(&self) -> &Message {
|
||||
match self {
|
||||
Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::System(m) | Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::Memory { message, .. } => message,
|
||||
Self::Log(_) => panic!("Log entries have no message"),
|
||||
}
|
||||
|
|
@ -291,38 +384,36 @@ impl ConversationEntry {
|
|||
/// Panics on Log entries.
|
||||
pub fn message_mut(&mut self) -> &mut Message {
|
||||
match self {
|
||||
Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::System(m) | Self::Message(m) | Self::Dmn(m) => m,
|
||||
Self::Memory { message, .. } => message,
|
||||
Self::Log(_) => panic!("Log entries have no message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ContextState {
|
||||
pub system_prompt: String,
|
||||
pub personality: Vec<(String, String)>,
|
||||
pub journal: Vec<JournalEntry>,
|
||||
pub working_stack: Vec<String>,
|
||||
/// Conversation entries — messages and memory, interleaved in order.
|
||||
/// Does NOT include system prompt, personality, or journal.
|
||||
pub entries: Vec<ConversationEntry>,
|
||||
}
|
||||
|
||||
pub fn render_journal(entries: &[JournalEntry]) -> String {
|
||||
if entries.is_empty() { return String::new(); }
|
||||
let mut text = String::from("[Earlier — from your journal]\n\n");
|
||||
for entry in entries {
|
||||
use std::fmt::Write;
|
||||
writeln!(text, "## {}\n{}\n", entry.timestamp.format("%Y-%m-%dT%H:%M"), entry.content).ok();
|
||||
}
|
||||
text
|
||||
}
|
||||
|
||||
impl ContextState {
|
||||
/// Render journal entries into a single text block.
|
||||
pub fn render_journal(&self) -> String {
|
||||
if self.journal.is_empty() { return String::new(); }
|
||||
let mut text = String::from("[Earlier — from your journal]\n\n");
|
||||
for e in self.journal.entries() {
|
||||
use std::fmt::Write;
|
||||
if let Some(ts) = &e.timestamp {
|
||||
writeln!(text, "## {}\n{}\n",
|
||||
ts.format("%Y-%m-%dT%H:%M"),
|
||||
e.entry.message().content_text()).ok();
|
||||
} else {
|
||||
text.push_str(&e.entry.message().content_text());
|
||||
text.push_str("\n\n");
|
||||
}
|
||||
}
|
||||
text
|
||||
}
|
||||
|
||||
/// Render identity files + working stack into a single user message.
|
||||
pub fn render_context_message(&self) -> String {
|
||||
let mut parts: Vec<String> = self.personality.iter()
|
||||
.map(|(name, content)| format!("## {}\n\n{}", name, content))
|
||||
let mut parts: Vec<String> = self.identity.entries().iter()
|
||||
.map(|e| e.entry.message().content_text().to_string())
|
||||
.collect();
|
||||
let instructions = std::fs::read_to_string(working_stack::instructions_path()).unwrap_or_default();
|
||||
let mut stack_section = instructions;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue