diff --git a/src/agent/runner.rs b/src/agent/runner.rs index cba6296..de3f659 100644 --- a/src/agent/runner.rs +++ b/src/agent/runner.rs @@ -103,7 +103,7 @@ impl Agent { journal: Vec::new(), working_stack: Vec::new(), loaded_nodes: Vec::new(), - messages: Vec::new(), + entries: Vec::new(), }; let session_id = format!("poc-agent-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")); let agent_cycles = crate::subconscious::subconscious::AgentCycleState::new(&session_id); @@ -140,7 +140,7 @@ impl Agent { if !jnl.is_empty() { msgs.push(Message::user(jnl)); } - msgs.extend(self.context.messages.iter().cloned()); + msgs.extend(self.context.entries.iter().map(|e| e.api_message().clone())); msgs } @@ -173,7 +173,7 @@ impl Agent { eprintln!("warning: failed to log message: {:#}", e); } } - self.context.messages.push(msg); + self.context.entries.push(ConversationEntry::Message(msg)); } /// Push a context-only message (system prompt, identity context, @@ -673,32 +673,41 @@ impl Agent { } // Conversation — each message as a child - let conv_messages = &self.context.messages; + let conv_messages = &self.context.entries; let conv_children: Vec = conv_messages.iter().enumerate() - .map(|(i, msg)| { - let text = msg.content.as_ref() + .map(|(i, entry)| { + let m = entry.message(); + let text = m.content.as_ref() .map(|c| c.as_text().to_string()) .unwrap_or_default(); - let tool_info = msg.tool_calls.as_ref().map(|tc| { + let tool_info = m.tool_calls.as_ref().map(|tc| { tc.iter() .map(|c| c.function.name.clone()) .collect::>() .join(", ") }); - let label = match (&msg.role, &tool_info) { - (_, Some(tools)) => format!("[tool_call: {}]", tools), - _ => { - let preview: String = text.chars().take(60).collect(); - let preview = preview.replace('\n', " "); - if text.len() > 60 { format!("{}...", preview) } else { preview } + let label = if entry.is_memory() { + if let ConversationEntry::Memory { key, .. } = entry { + format!("[memory: {}]", key) + } else { unreachable!() } + } else { + match &tool_info { + Some(tools) => format!("[tool_call: {}]", tools), + None => { + let preview: String = text.chars().take(60).collect(); + let preview = preview.replace('\n', " "); + if text.len() > 60 { format!("{}...", preview) } else { preview } + } } }; let tokens = count(&text); - let role_name = match msg.role { - Role::Assistant => "PoC", - Role::User => "Kent", - Role::Tool => "tool", - Role::System => "system", + let role_name = if entry.is_memory() { "mem" } else { + match m.role { + Role::Assistant => "PoC", + Role::User => "Kent", + Role::Tool => "tool", + Role::System => "system", + } }; ContextSection { name: format!("[{}] {}: {}", i, role_name, label), @@ -846,7 +855,8 @@ impl Agent { /// all previous ones. The tool result message (right before each image /// message) already records what was loaded, so no info is lost. fn age_out_images(&mut self) { - for msg in &mut self.context.messages { + for entry in &mut self.context.entries { + let msg = entry.message_mut(); if let Some(MessageContent::Parts(parts)) = &msg.content { let has_images = parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })); if !has_images { @@ -891,7 +901,8 @@ impl Agent { let mut strip_ids: Vec = Vec::new(); let mut strip_msg_indices: Vec = Vec::new(); - for (i, msg) in self.context.messages.iter().enumerate() { + for (i, entry) in self.context.entries.iter().enumerate() { + let msg = entry.message(); if msg.role != Role::Assistant { continue; } @@ -917,8 +928,8 @@ impl Agent { } // Remove in reverse order to preserve indices - self.context.messages.retain(|msg| { - // Strip the assistant messages we identified + self.context.entries.retain(|entry| { + let msg = entry.message(); if msg.role == Role::Assistant { if let Some(calls) = &msg.tool_calls { if calls.iter().all(|c| strip_ids.contains(&c.id)) { @@ -926,7 +937,6 @@ impl Agent { } } } - // Strip matching tool results if msg.role == Role::Tool { if let Some(ref id) = msg.tool_call_id { if strip_ids.contains(id) { @@ -955,7 +965,8 @@ impl Agent { /// Internal compaction — rebuilds context window from current messages. fn do_compact(&mut self) { - let conversation: Vec = self.context.messages.clone(); + let conversation: Vec = self.context.entries.iter() + .map(|e| e.api_message().clone()).collect(); let (messages, journal) = crate::agent::context::build_context_window( &self.context, &conversation, @@ -963,7 +974,8 @@ impl Agent { &self.tokenizer, ); self.context.journal = journal::parse_journal_text(&journal); - self.context.messages = messages; + self.context.entries = messages.into_iter() + .map(ConversationEntry::Message).collect(); self.last_prompt_tokens = 0; self.publish_context_state(); @@ -1025,8 +1037,9 @@ impl Agent { dbglog!("[restore] journal text: {} chars, {} lines", journal.len(), journal.lines().count()); self.context.journal = journal::parse_journal_text(&journal); - self.context.messages = messages; - dbglog!("[restore] built context window: {} messages", self.context.messages.len()); + self.context.entries = messages.into_iter() + .map(ConversationEntry::Message).collect(); + dbglog!("[restore] built context window: {} entries", self.context.entries.len()); self.last_prompt_tokens = 0; self.publish_context_state(); @@ -1043,19 +1056,19 @@ impl Agent { &self.client.model } - /// Get the conversation history for persistence. - pub fn messages(&self) -> &[Message] { - &self.context.messages + /// Get the conversation entries for persistence. + pub fn entries(&self) -> &[ConversationEntry] { + &self.context.entries } - /// Mutable access to conversation history (for /retry). - pub fn messages_mut(&mut self) -> &mut Vec { - &mut self.context.messages + /// Mutable access to conversation entries (for /retry). + pub fn entries_mut(&mut self) -> &mut Vec { + &mut self.context.entries } - /// Restore from a saved conversation. - pub fn restore(&mut self, messages: Vec) { - self.context.messages = messages; + /// Restore from saved conversation entries. + pub fn restore(&mut self, entries: Vec) { + self.context.entries = entries; } } diff --git a/src/agent/types.rs b/src/agent/types.rs index dbd8f4f..9491a7e 100644 --- a/src/agent/types.rs +++ b/src/agent/types.rs @@ -322,19 +322,93 @@ impl ToolDef { } /// Mutable context state — the structured regions of the context window. +/// Conversation entry — either a regular message or memory content. +/// Memory entries preserve the original message for KV cache round-tripping. #[derive(Debug, Clone)] +pub enum ConversationEntry { + Message(Message), + Memory { key: String, message: Message }, +} + +// Custom serde: serialize Memory with a "memory_key" field added to the message, +// plain messages serialize as-is. This keeps the conversation log readable. +impl Serialize for ConversationEntry { + fn serialize(&self, s: S) -> Result { + use serde::ser::SerializeMap; + match self { + Self::Message(m) => m.serialize(s), + Self::Memory { key, message } => { + // Serialize message fields + memory_key + let json = serde_json::to_value(message).map_err(serde::ser::Error::custom)?; + let mut map = s.serialize_map(None)?; + if let serde_json::Value::Object(obj) = json { + for (k, v) in obj { + map.serialize_entry(&k, &v)?; + } + } + map.serialize_entry("memory_key", key)?; + map.end() + } + } + } +} + +impl<'de> Deserialize<'de> for ConversationEntry { + fn deserialize>(d: D) -> Result { + let mut json: serde_json::Value = serde_json::Value::deserialize(d)?; + if let Some(key) = json.as_object_mut().and_then(|o| o.remove("memory_key")) { + let key = key.as_str().unwrap_or("").to_string(); + let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?; + Ok(Self::Memory { key, message }) + } else { + let message: Message = serde_json::from_value(json).map_err(serde::de::Error::custom)?; + Ok(Self::Message(message)) + } + } +} + +impl ConversationEntry { + /// Get the API message for sending to the model. + pub fn api_message(&self) -> &Message { + match self { + Self::Message(m) => m, + Self::Memory { message, .. } => message, + } + } + + pub fn is_memory(&self) -> bool { + matches!(self, Self::Memory { .. }) + } + + /// Get a reference to the inner message. + pub fn message(&self) -> &Message { + match self { + Self::Message(m) => m, + Self::Memory { message, .. } => message, + } + } + + /// Get a mutable reference to the inner message. + pub fn message_mut(&mut self) -> &mut Message { + match self { + Self::Message(m) => m, + Self::Memory { message, .. } => message, + } + } +} + pub struct ContextState { pub system_prompt: String, pub personality: Vec<(String, String)>, pub journal: Vec, pub working_stack: Vec, - /// Memory nodes currently loaded in the context window. + /// Memory nodes currently loaded — for debug display and refresh. + /// Content is NOT duplicated here; the actual content is in entries + /// as ConversationEntry::Memory. pub loaded_nodes: Vec, - /// Conversation messages (user, assistant, tool turns). - /// Does NOT include system prompt, personality, or journal — - /// those are rendered from their typed sources when assembling - /// the API call. - pub messages: Vec, + /// Conversation entries — messages and memory, interleaved in order. + /// Does NOT include system prompt, personality, or journal. + pub entries: Vec, } // TODO: these should not be hardcoded absolute paths @@ -349,7 +423,12 @@ impl ContextState { let id = count_str(&self.system_prompt) + self.personality.iter().map(|(_, c)| count_str(c)).sum::(); let jnl: usize = self.journal.iter().map(|e| count_str(&e.content)).sum(); - let (mem, conv) = self.split_memory_conversation(count_msg); + let mut mem = 0; + let mut conv = 0; + for entry in &self.entries { + let tokens = count_msg(entry.api_message()); + if entry.is_memory() { mem += tokens } else { conv += tokens } + } ContextBudget { identity_tokens: id, memory_tokens: mem, @@ -359,40 +438,6 @@ impl ContextState { } } - /// Split conversation messages into memory tool interactions and - /// everything else. Returns (memory_tokens, conversation_tokens). - pub fn split_memory_conversation(&self, count: &dyn Fn(&Message) -> usize) -> (usize, usize) { - // Collect tool_call_ids that belong to memory tools - let mut memory_call_ids: std::collections::HashSet = std::collections::HashSet::new(); - for msg in &self.messages { - if let Some(ref calls) = msg.tool_calls { - for call in calls { - if call.function.name.starts_with("memory_") - || call.function.name.starts_with("journal_") { - memory_call_ids.insert(call.id.clone()); - } - } - } - } - - let mut mem_tokens = 0; - let mut conv_tokens = 0; - for msg in &self.messages { - let tokens = count(msg); - let is_memory = match &msg.tool_call_id { - Some(id) => memory_call_ids.contains(id), - None => msg.tool_calls.as_ref().map_or(false, |calls| - calls.iter().all(|c| memory_call_ids.contains(&c.id))), - }; - if is_memory { - mem_tokens += tokens; - } else { - conv_tokens += tokens; - } - } - (mem_tokens, conv_tokens) - } - pub fn render_context_message(&self) -> String { let mut parts: Vec = self.personality.iter() .map(|(name, content)| format!("## {}\n\n{}", name, content)) diff --git a/src/bin/poc-agent.rs b/src/bin/poc-agent.rs index 09e481f..eda6961 100644 --- a/src/bin/poc-agent.rs +++ b/src/bin/poc-agent.rs @@ -464,9 +464,9 @@ impl Session { } "/context" => { if let Ok(agent) = self.agent.try_lock() { - let msgs = agent.messages(); + let msgs = agent.entries(); let total_chars: usize = - msgs.iter().map(|m| m.content_text().len()).sum(); + msgs.iter().map(|e| e.message().content_text().len()).sum(); let prompt_tokens = agent.last_prompt_tokens(); let threshold = compaction_threshold(agent.model(), &self.config.app); let _ = self.ui_tx.send(UiMessage::Info(format!( @@ -587,15 +587,15 @@ impl Session { return Command::Handled; } let mut agent_guard = self.agent.lock().await; - let msgs = agent_guard.messages_mut(); + let entries = agent_guard.entries_mut(); let mut last_user_text = None; - while let Some(msg) = msgs.last() { - if msg.role == poc_memory::agent::types::Role::User { + while let Some(entry) = entries.last() { + if entry.message().role == poc_memory::agent::types::Role::User { last_user_text = - Some(msgs.pop().unwrap().content_text().to_string()); + Some(entries.pop().unwrap().message().content_text().to_string()); break; } - msgs.pop(); + entries.pop(); } drop(agent_guard); match last_user_text { @@ -936,7 +936,7 @@ async fn run(cli: cli::CliArgs) -> Result<()> { config.context_parts.clone(), ); if restored { - replay_session_to_ui(agent_guard.messages(), &ui_tx); + replay_session_to_ui(agent_guard.entries(), &ui_tx); let _ = ui_tx.send(UiMessage::Info( "--- restored from conversation log ---".into(), )); @@ -944,7 +944,7 @@ async fn run(cli: cli::CliArgs) -> Result<()> { if let Ok(data) = std::fs::read_to_string(&session_file) { if let Ok(messages) = serde_json::from_str(&data) { agent_guard.restore(messages); - replay_session_to_ui(agent_guard.messages(), &ui_tx); + replay_session_to_ui(agent_guard.entries(), &ui_tx); let _ = ui_tx.send(UiMessage::Info( "--- restored from session file ---".into(), )); @@ -1104,7 +1104,7 @@ fn drain_ui_messages(rx: &mut ui_channel::UiReceiver, app: &mut tui::App) { } fn save_session(agent: &Agent, path: &PathBuf) -> Result<()> { - let data = serde_json::to_string_pretty(agent.messages())?; + let data = serde_json::to_string_pretty(agent.entries())?; std::fs::write(path, data)?; Ok(()) } @@ -1186,21 +1186,23 @@ async fn run_tool_tests(ui_tx: &ui_channel::UiSender, tracker: &tools::ProcessTr /// conversation history immediately on restart. Shows user input, /// assistant responses, and brief tool call summaries. Skips the system /// prompt, context message, DMN plumbing, and image injection messages. -fn replay_session_to_ui(messages: &[types::Message], ui_tx: &ui_channel::UiSender) { +fn replay_session_to_ui(entries: &[types::ConversationEntry], ui_tx: &ui_channel::UiSender) { use poc_memory::agent::ui_channel::StreamTarget; - dbglog!("[replay] replaying {} messages to UI", messages.len()); - for (i, m) in messages.iter().enumerate() { + dbglog!("[replay] replaying {} entries to UI", entries.len()); + for (i, e) in entries.iter().enumerate() { + let m = e.message(); let preview: String = m.content_text().chars().take(60).collect(); - dbglog!("[replay] [{}] {:?} tc={} tcid={:?} {:?}", - i, m.role, m.tool_calls.as_ref().map_or(0, |t| t.len()), + dbglog!("[replay] [{}] {:?} mem={} tc={} tcid={:?} {:?}", + i, m.role, e.is_memory(), m.tool_calls.as_ref().map_or(0, |t| t.len()), m.tool_call_id.as_deref(), preview); } let mut seen_first_user = false; let mut target = StreamTarget::Conversation; - for msg in messages { + for entry in entries { + let msg = entry.message(); match msg.role { types::Role::System => {} types::Role::User => {