ConversationEntry enum: typed memory vs conversation messages

Replace untyped message list with ConversationEntry enum:
- Message(Message) — regular conversation turn
- Memory { key, message } — memory content with preserved message
  for KV cache round-tripping

Budget counts memory vs conversation by matching on enum variant.
Debug screen labels memory entries with [memory: key]. No heuristic
tool-name scanning.

Custom serde: Memory serializes with a memory_key field alongside
the message fields, deserializes by checking for the field.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-02 03:26:00 -04:00
parent eb4dae04cb
commit b9e3568385
3 changed files with 153 additions and 93 deletions

View file

@ -103,7 +103,7 @@ impl Agent {
journal: Vec::new(), journal: Vec::new(),
working_stack: Vec::new(), working_stack: Vec::new(),
loaded_nodes: Vec::new(), loaded_nodes: Vec::new(),
messages: Vec::new(), entries: Vec::new(),
}; };
let session_id = format!("poc-agent-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")); let session_id = format!("poc-agent-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S"));
let agent_cycles = crate::subconscious::subconscious::AgentCycleState::new(&session_id); let agent_cycles = crate::subconscious::subconscious::AgentCycleState::new(&session_id);
@ -140,7 +140,7 @@ impl Agent {
if !jnl.is_empty() { if !jnl.is_empty() {
msgs.push(Message::user(jnl)); msgs.push(Message::user(jnl));
} }
msgs.extend(self.context.messages.iter().cloned()); msgs.extend(self.context.entries.iter().map(|e| e.api_message().clone()));
msgs msgs
} }
@ -173,7 +173,7 @@ impl Agent {
eprintln!("warning: failed to log message: {:#}", e); eprintln!("warning: failed to log message: {:#}", e);
} }
} }
self.context.messages.push(msg); self.context.entries.push(ConversationEntry::Message(msg));
} }
/// Push a context-only message (system prompt, identity context, /// Push a context-only message (system prompt, identity context,
@ -673,32 +673,41 @@ impl Agent {
} }
// Conversation — each message as a child // Conversation — each message as a child
let conv_messages = &self.context.messages; let conv_messages = &self.context.entries;
let conv_children: Vec<ContextSection> = conv_messages.iter().enumerate() let conv_children: Vec<ContextSection> = conv_messages.iter().enumerate()
.map(|(i, msg)| { .map(|(i, entry)| {
let text = msg.content.as_ref() let m = entry.message();
let text = m.content.as_ref()
.map(|c| c.as_text().to_string()) .map(|c| c.as_text().to_string())
.unwrap_or_default(); .unwrap_or_default();
let tool_info = msg.tool_calls.as_ref().map(|tc| { let tool_info = m.tool_calls.as_ref().map(|tc| {
tc.iter() tc.iter()
.map(|c| c.function.name.clone()) .map(|c| c.function.name.clone())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
}); });
let label = match (&msg.role, &tool_info) { let label = if entry.is_memory() {
(_, Some(tools)) => format!("[tool_call: {}]", tools), if let ConversationEntry::Memory { key, .. } = entry {
_ => { format!("[memory: {}]", key)
let preview: String = text.chars().take(60).collect(); } else { unreachable!() }
let preview = preview.replace('\n', " "); } else {
if text.len() > 60 { format!("{}...", preview) } else { preview } match &tool_info {
Some(tools) => format!("[tool_call: {}]", tools),
None => {
let preview: String = text.chars().take(60).collect();
let preview = preview.replace('\n', " ");
if text.len() > 60 { format!("{}...", preview) } else { preview }
}
} }
}; };
let tokens = count(&text); let tokens = count(&text);
let role_name = match msg.role { let role_name = if entry.is_memory() { "mem" } else {
Role::Assistant => "PoC", match m.role {
Role::User => "Kent", Role::Assistant => "PoC",
Role::Tool => "tool", Role::User => "Kent",
Role::System => "system", Role::Tool => "tool",
Role::System => "system",
}
}; };
ContextSection { ContextSection {
name: format!("[{}] {}: {}", i, role_name, label), name: format!("[{}] {}: {}", i, role_name, label),
@ -846,7 +855,8 @@ impl Agent {
/// all previous ones. The tool result message (right before each image /// all previous ones. The tool result message (right before each image
/// message) already records what was loaded, so no info is lost. /// message) already records what was loaded, so no info is lost.
fn age_out_images(&mut self) { fn age_out_images(&mut self) {
for msg in &mut self.context.messages { for entry in &mut self.context.entries {
let msg = entry.message_mut();
if let Some(MessageContent::Parts(parts)) = &msg.content { if let Some(MessageContent::Parts(parts)) = &msg.content {
let has_images = parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })); let has_images = parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. }));
if !has_images { if !has_images {
@ -891,7 +901,8 @@ impl Agent {
let mut strip_ids: Vec<String> = Vec::new(); let mut strip_ids: Vec<String> = Vec::new();
let mut strip_msg_indices: Vec<usize> = Vec::new(); let mut strip_msg_indices: Vec<usize> = Vec::new();
for (i, msg) in self.context.messages.iter().enumerate() { for (i, entry) in self.context.entries.iter().enumerate() {
let msg = entry.message();
if msg.role != Role::Assistant { if msg.role != Role::Assistant {
continue; continue;
} }
@ -917,8 +928,8 @@ impl Agent {
} }
// Remove in reverse order to preserve indices // Remove in reverse order to preserve indices
self.context.messages.retain(|msg| { self.context.entries.retain(|entry| {
// Strip the assistant messages we identified let msg = entry.message();
if msg.role == Role::Assistant { if msg.role == Role::Assistant {
if let Some(calls) = &msg.tool_calls { if let Some(calls) = &msg.tool_calls {
if calls.iter().all(|c| strip_ids.contains(&c.id)) { if calls.iter().all(|c| strip_ids.contains(&c.id)) {
@ -926,7 +937,6 @@ impl Agent {
} }
} }
} }
// Strip matching tool results
if msg.role == Role::Tool { if msg.role == Role::Tool {
if let Some(ref id) = msg.tool_call_id { if let Some(ref id) = msg.tool_call_id {
if strip_ids.contains(id) { if strip_ids.contains(id) {
@ -955,7 +965,8 @@ impl Agent {
/// Internal compaction — rebuilds context window from current messages. /// Internal compaction — rebuilds context window from current messages.
fn do_compact(&mut self) { fn do_compact(&mut self) {
let conversation: Vec<Message> = self.context.messages.clone(); let conversation: Vec<Message> = self.context.entries.iter()
.map(|e| e.api_message().clone()).collect();
let (messages, journal) = crate::agent::context::build_context_window( let (messages, journal) = crate::agent::context::build_context_window(
&self.context, &self.context,
&conversation, &conversation,
@ -963,7 +974,8 @@ impl Agent {
&self.tokenizer, &self.tokenizer,
); );
self.context.journal = journal::parse_journal_text(&journal); self.context.journal = journal::parse_journal_text(&journal);
self.context.messages = messages; self.context.entries = messages.into_iter()
.map(ConversationEntry::Message).collect();
self.last_prompt_tokens = 0; self.last_prompt_tokens = 0;
self.publish_context_state(); self.publish_context_state();
@ -1025,8 +1037,9 @@ impl Agent {
dbglog!("[restore] journal text: {} chars, {} lines", dbglog!("[restore] journal text: {} chars, {} lines",
journal.len(), journal.lines().count()); journal.len(), journal.lines().count());
self.context.journal = journal::parse_journal_text(&journal); self.context.journal = journal::parse_journal_text(&journal);
self.context.messages = messages; self.context.entries = messages.into_iter()
dbglog!("[restore] built context window: {} messages", self.context.messages.len()); .map(ConversationEntry::Message).collect();
dbglog!("[restore] built context window: {} entries", self.context.entries.len());
self.last_prompt_tokens = 0; self.last_prompt_tokens = 0;
self.publish_context_state(); self.publish_context_state();
@ -1043,19 +1056,19 @@ impl Agent {
&self.client.model &self.client.model
} }
/// Get the conversation history for persistence. /// Get the conversation entries for persistence.
pub fn messages(&self) -> &[Message] { pub fn entries(&self) -> &[ConversationEntry] {
&self.context.messages &self.context.entries
} }
/// Mutable access to conversation history (for /retry). /// Mutable access to conversation entries (for /retry).
pub fn messages_mut(&mut self) -> &mut Vec<Message> { pub fn entries_mut(&mut self) -> &mut Vec<ConversationEntry> {
&mut self.context.messages &mut self.context.entries
} }
/// Restore from a saved conversation. /// Restore from saved conversation entries.
pub fn restore(&mut self, messages: Vec<Message>) { pub fn restore(&mut self, entries: Vec<ConversationEntry>) {
self.context.messages = messages; self.context.entries = entries;
} }
} }

View file

@ -322,19 +322,93 @@ impl ToolDef {
} }
/// Mutable context state — the structured regions of the context window. /// Mutable context state — the structured regions of the context window.
/// Conversation entry — either a regular message or memory content.
/// Memory entries preserve the original message for KV cache round-tripping.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ConversationEntry {
/// A regular conversation turn.
Message(Message),
/// Memory content. `key` is the label shown as `[memory: key]` on the debug
/// screen and is written out as the `memory_key` field when serialized;
/// `message` is the original message, kept intact alongside it.
Memory { key: String, message: Message },
}
// Custom serde: serialize Memory with a "memory_key" field added to the message,
// plain messages serialize as-is. This keeps the conversation log readable.
impl Serialize for ConversationEntry {
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
use serde::ser::SerializeMap;
match self {
Self::Message(m) => m.serialize(s),
Self::Memory { key, message } => {
// Serialize message fields + memory_key
let json = serde_json::to_value(message).map_err(serde::ser::Error::custom)?;
let mut map = s.serialize_map(None)?;
if let serde_json::Value::Object(obj) = json {
for (k, v) in obj {
map.serialize_entry(&k, &v)?;
}
}
map.serialize_entry("memory_key", key)?;
map.end()
}
}
}
}
impl<'de> Deserialize<'de> for ConversationEntry {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let mut json: serde_json::Value = serde_json::Value::deserialize(d)?;
if let Some(key) = json.as_object_mut().and_then(|o| o.remove("memory_key")) {
let key = key.as_str().unwrap_or("").to_string();
let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?;
Ok(Self::Memory { key, message })
} else {
let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?;
Ok(Self::Message(message))
}
}
}
impl ConversationEntry {
    /// The message to send to the model for this entry.
    ///
    /// Both variants wrap exactly one message, so this is the same
    /// reference that [`message`](Self::message) returns.
    pub fn api_message(&self) -> &Message {
        self.message()
    }

    /// `true` for the `Memory` variant, `false` for a plain `Message`.
    pub fn is_memory(&self) -> bool {
        match self {
            Self::Memory { .. } => true,
            Self::Message(_) => false,
        }
    }

    /// Shared reference to the wrapped message, whichever variant this is.
    pub fn message(&self) -> &Message {
        match self {
            Self::Message(inner) | Self::Memory { message: inner, .. } => inner,
        }
    }

    /// Mutable reference to the wrapped message, whichever variant this is.
    pub fn message_mut(&mut self) -> &mut Message {
        match self {
            Self::Message(inner) | Self::Memory { message: inner, .. } => inner,
        }
    }
}
pub struct ContextState { pub struct ContextState {
pub system_prompt: String, pub system_prompt: String,
pub personality: Vec<(String, String)>, pub personality: Vec<(String, String)>,
pub journal: Vec<crate::agent::journal::JournalEntry>, pub journal: Vec<crate::agent::journal::JournalEntry>,
pub working_stack: Vec<String>, pub working_stack: Vec<String>,
/// Memory nodes currently loaded in the context window. /// Memory nodes currently loaded — for debug display and refresh.
/// Content is NOT duplicated here; the actual content is in entries
/// as ConversationEntry::Memory.
pub loaded_nodes: Vec<crate::hippocampus::memory::MemoryNode>, pub loaded_nodes: Vec<crate::hippocampus::memory::MemoryNode>,
/// Conversation messages (user, assistant, tool turns). /// Conversation entries — messages and memory, interleaved in order.
/// Does NOT include system prompt, personality, or journal — /// Does NOT include system prompt, personality, or journal.
/// those are rendered from their typed sources when assembling pub entries: Vec<ConversationEntry>,
/// the API call.
pub messages: Vec<Message>,
} }
// TODO: these should not be hardcoded absolute paths // TODO: these should not be hardcoded absolute paths
@ -349,7 +423,12 @@ impl ContextState {
let id = count_str(&self.system_prompt) let id = count_str(&self.system_prompt)
+ self.personality.iter().map(|(_, c)| count_str(c)).sum::<usize>(); + self.personality.iter().map(|(_, c)| count_str(c)).sum::<usize>();
let jnl: usize = self.journal.iter().map(|e| count_str(&e.content)).sum(); let jnl: usize = self.journal.iter().map(|e| count_str(&e.content)).sum();
let (mem, conv) = self.split_memory_conversation(count_msg); let mut mem = 0;
let mut conv = 0;
for entry in &self.entries {
let tokens = count_msg(entry.api_message());
if entry.is_memory() { mem += tokens } else { conv += tokens }
}
ContextBudget { ContextBudget {
identity_tokens: id, identity_tokens: id,
memory_tokens: mem, memory_tokens: mem,
@ -359,40 +438,6 @@ impl ContextState {
} }
} }
/// Split conversation messages into memory tool interactions and
/// everything else.
///
/// A tool call counts as a memory call when its function name starts with
/// `memory_` or `journal_`. A message is attributed to memory when it is
/// the result of such a call (matched through its `tool_call_id`), or when
/// it carries at least one tool call and every one of them is a memory call.
///
/// `count` returns the token count of a single message.
///
/// Returns `(memory_tokens, conversation_tokens)`.
pub fn split_memory_conversation(&self, count: &dyn Fn(&Message) -> usize) -> (usize, usize) {
    // Ids of every tool call that targets a memory/journal tool. The ids are
    // borrowed from the messages; the set only lives for this call.
    let memory_call_ids: std::collections::HashSet<&str> = self.messages.iter()
        .filter_map(|msg| msg.tool_calls.as_ref())
        .flatten()
        .filter(|call| call.function.name.starts_with("memory_")
            || call.function.name.starts_with("journal_"))
        .map(|call| call.id.as_str())
        .collect();

    let mut mem_tokens = 0;
    let mut conv_tokens = 0;
    for msg in &self.messages {
        let tokens = count(msg);
        let is_memory = match &msg.tool_call_id {
            // Tool result: memory iff the call it answers was a memory call.
            Some(id) => memory_call_ids.contains(id.as_str()),
            // Otherwise: memory iff it has tool calls and all are memory
            // calls. `all` is vacuously true on an empty list, so an empty
            // `tool_calls` vec must be excluded explicitly.
            None => msg.tool_calls.as_ref().map_or(false, |calls| {
                !calls.is_empty()
                    && calls.iter().all(|c| memory_call_ids.contains(c.id.as_str()))
            }),
        };
        if is_memory {
            mem_tokens += tokens;
        } else {
            conv_tokens += tokens;
        }
    }
    (mem_tokens, conv_tokens)
}
pub fn render_context_message(&self) -> String { pub fn render_context_message(&self) -> String {
let mut parts: Vec<String> = self.personality.iter() let mut parts: Vec<String> = self.personality.iter()
.map(|(name, content)| format!("## {}\n\n{}", name, content)) .map(|(name, content)| format!("## {}\n\n{}", name, content))

View file

@ -464,9 +464,9 @@ impl Session {
} }
"/context" => { "/context" => {
if let Ok(agent) = self.agent.try_lock() { if let Ok(agent) = self.agent.try_lock() {
let msgs = agent.messages(); let msgs = agent.entries();
let total_chars: usize = let total_chars: usize =
msgs.iter().map(|m| m.content_text().len()).sum(); msgs.iter().map(|e| e.message().content_text().len()).sum();
let prompt_tokens = agent.last_prompt_tokens(); let prompt_tokens = agent.last_prompt_tokens();
let threshold = compaction_threshold(agent.model(), &self.config.app); let threshold = compaction_threshold(agent.model(), &self.config.app);
let _ = self.ui_tx.send(UiMessage::Info(format!( let _ = self.ui_tx.send(UiMessage::Info(format!(
@ -587,15 +587,15 @@ impl Session {
return Command::Handled; return Command::Handled;
} }
let mut agent_guard = self.agent.lock().await; let mut agent_guard = self.agent.lock().await;
let msgs = agent_guard.messages_mut(); let entries = agent_guard.entries_mut();
let mut last_user_text = None; let mut last_user_text = None;
while let Some(msg) = msgs.last() { while let Some(entry) = entries.last() {
if msg.role == poc_memory::agent::types::Role::User { if entry.message().role == poc_memory::agent::types::Role::User {
last_user_text = last_user_text =
Some(msgs.pop().unwrap().content_text().to_string()); Some(entries.pop().unwrap().message().content_text().to_string());
break; break;
} }
msgs.pop(); entries.pop();
} }
drop(agent_guard); drop(agent_guard);
match last_user_text { match last_user_text {
@ -936,7 +936,7 @@ async fn run(cli: cli::CliArgs) -> Result<()> {
config.context_parts.clone(), config.context_parts.clone(),
); );
if restored { if restored {
replay_session_to_ui(agent_guard.messages(), &ui_tx); replay_session_to_ui(agent_guard.entries(), &ui_tx);
let _ = ui_tx.send(UiMessage::Info( let _ = ui_tx.send(UiMessage::Info(
"--- restored from conversation log ---".into(), "--- restored from conversation log ---".into(),
)); ));
@ -944,7 +944,7 @@ async fn run(cli: cli::CliArgs) -> Result<()> {
if let Ok(data) = std::fs::read_to_string(&session_file) { if let Ok(data) = std::fs::read_to_string(&session_file) {
if let Ok(messages) = serde_json::from_str(&data) { if let Ok(messages) = serde_json::from_str(&data) {
agent_guard.restore(messages); agent_guard.restore(messages);
replay_session_to_ui(agent_guard.messages(), &ui_tx); replay_session_to_ui(agent_guard.entries(), &ui_tx);
let _ = ui_tx.send(UiMessage::Info( let _ = ui_tx.send(UiMessage::Info(
"--- restored from session file ---".into(), "--- restored from session file ---".into(),
)); ));
@ -1104,7 +1104,7 @@ fn drain_ui_messages(rx: &mut ui_channel::UiReceiver, app: &mut tui::App) {
} }
fn save_session(agent: &Agent, path: &PathBuf) -> Result<()> { fn save_session(agent: &Agent, path: &PathBuf) -> Result<()> {
let data = serde_json::to_string_pretty(agent.messages())?; let data = serde_json::to_string_pretty(agent.entries())?;
std::fs::write(path, data)?; std::fs::write(path, data)?;
Ok(()) Ok(())
} }
@ -1186,21 +1186,23 @@ async fn run_tool_tests(ui_tx: &ui_channel::UiSender, tracker: &tools::ProcessTr
/// conversation history immediately on restart. Shows user input, /// conversation history immediately on restart. Shows user input,
/// assistant responses, and brief tool call summaries. Skips the system /// assistant responses, and brief tool call summaries. Skips the system
/// prompt, context message, DMN plumbing, and image injection messages. /// prompt, context message, DMN plumbing, and image injection messages.
fn replay_session_to_ui(messages: &[types::Message], ui_tx: &ui_channel::UiSender) { fn replay_session_to_ui(entries: &[types::ConversationEntry], ui_tx: &ui_channel::UiSender) {
use poc_memory::agent::ui_channel::StreamTarget; use poc_memory::agent::ui_channel::StreamTarget;
dbglog!("[replay] replaying {} messages to UI", messages.len()); dbglog!("[replay] replaying {} entries to UI", entries.len());
for (i, m) in messages.iter().enumerate() { for (i, e) in entries.iter().enumerate() {
let m = e.message();
let preview: String = m.content_text().chars().take(60).collect(); let preview: String = m.content_text().chars().take(60).collect();
dbglog!("[replay] [{}] {:?} tc={} tcid={:?} {:?}", dbglog!("[replay] [{}] {:?} mem={} tc={} tcid={:?} {:?}",
i, m.role, m.tool_calls.as_ref().map_or(0, |t| t.len()), i, m.role, e.is_memory(), m.tool_calls.as_ref().map_or(0, |t| t.len()),
m.tool_call_id.as_deref(), preview); m.tool_call_id.as_deref(), preview);
} }
let mut seen_first_user = false; let mut seen_first_user = false;
let mut target = StreamTarget::Conversation; let mut target = StreamTarget::Conversation;
for msg in messages { for entry in entries {
let msg = entry.message();
match msg.role { match msg.role {
types::Role::System => {} types::Role::System => {}
types::Role::User => { types::Role::User => {