From f1397b7783f625743facb7e42a19fadc20b775bf Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Wed, 8 Apr 2026 13:35:04 -0400 Subject: [PATCH] Redesign context AST: typed NodeBody, Role as grammar roles, tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Role is now just System/User/Assistant — maps 1:1 to the grammar. Leaf types are NodeBody variants: Content, Thinking, ToolCall, ToolResult, Memory, Dmn, Log. Each variant renders itself; no Role needed on leaves. AstNode is Leaf(NodeLeaf) | Branch{role, children}. ContextState holds four Vec sections directly. Moved tool call XML parsing from api/parsing.rs into context_new.rs so all grammar knowledge lives in one place. Tokenizer encode() now returns empty vec when uninitialized instead of panicking, so tests work without the tokenizer file. 26 tests: XML parsing, incremental streaming (char-by-char feeds found and fixed a lookahead bug), rendering for all node types, tokenizer round-trip verification. Co-Authored-By: Proof of Concept --- src/agent/context_new.rs | 1066 ++++++++++++++++++++++++++------------ src/agent/tokenizer.rs | 25 +- 2 files changed, 752 insertions(+), 339 deletions(-) diff --git a/src/agent/context_new.rs b/src/agent/context_new.rs index 0609596..35325f0 100644 --- a/src/agent/context_new.rs +++ b/src/agent/context_new.rs @@ -1,25 +1,33 @@ // context.rs — Context window as an AST // -// The context window is a tree of AstNodes. Each node has a role, a body -// (either leaf content or children), and cached token IDs. The full prompt -// for the model is a depth-first traversal. Streaming responses are parsed -// into new nodes by the ResponseParser. +// The context window is a tree of AstNodes. Each node is either a leaf +// (typed content with cached token IDs) or a branch (role + children). +// The full prompt is a depth-first traversal of the sections in ContextState. +// Streaming responses are parsed into new nodes by the ResponseParser. 
// // Grammar (EBNF): // // context = section* ; -// section = message* ; -// message = IM_START role "\n" body IM_END "\n" ; -// body = content | element* ; (* leaf or branch *) +// section = (message | leaf)* ; +// message = IM_START role "\n" element* IM_END "\n" ; +// role = "system" | "user" | "assistant" ; // element = thinking | tool_call | content ; -// thinking = "" content "" ; +// thinking = "" TEXT "" ; // tool_call = "\n" tool_xml "\n" ; // tool_xml = "\n" param* "" ; // param = "\n" VALUE "\n\n" ; -// content = TOKEN* ; +// content = TEXT ; // -// The AST is uniform: one AstNode type for everything. The grammar -// constraints are enforced by construction, not by the type system. +// Self-wrapping leaves (not inside a message branch): +// dmn = IM_START "dmn\n" TEXT IM_END "\n" ; +// memory = IM_START "memory\n" TEXT IM_END "\n" ; +// tool_result = IM_START "tool\n" TEXT IM_END "\n" ; +// +// Non-visible leaves (not in prompt): +// log = TEXT ; +// +// Role is only for branch (interior) nodes. Leaf type is determined by +// the NodeBody variant. Grammar constraints enforced by construction. use chrono::{DateTime, Utc}; use super::tokenizer; @@ -28,122 +36,132 @@ use super::tokenizer; // Types // --------------------------------------------------------------------------- +/// Branch roles — maps directly to the grammar's message roles. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Role { - // Sections (top-level groupings) - SystemSection, - IdentitySection, - JournalSection, - ConversationSection, - - // Messages (top-level, get im_start/im_end wrapping) System, User, Assistant, - Tool, - Dmn, - Memory, - Log, - - // Children of Assistant - Thinking, - ToolCall, - Content, } +/// Leaf content — each variant knows how to render itself. 
#[derive(Debug, Clone)] pub enum NodeBody { - Leaf { text: String, token_ids: Vec }, - Branch(Vec), + // Children of message branches — rendered without im_start/im_end + Content(String), + Thinking(String), + ToolCall { name: String, arguments: String }, + + // Self-wrapping leaves — render their own im_start/im_end + ToolResult(String), + Memory { key: String, text: String, score: Option }, + Dmn(String), + + // Non-visible (0 tokens in prompt) + Log(String), } +/// A leaf node: typed content with cached token IDs. #[derive(Debug, Clone)] -pub struct AstNode { - role: Role, +pub struct NodeLeaf { body: NodeBody, + token_ids: Vec, timestamp: Option>, - memory_key: Option, - memory_score: Option, - tool_name: Option, - tool_args: Option, - tool_call_id: Option, +} + +/// A node in the context AST. +#[derive(Debug, Clone)] +pub enum AstNode { + Leaf(NodeLeaf), + Branch { role: Role, children: Vec }, +} + +/// The context window: four sections as Vec. +/// All mutation goes through ContextState methods to maintain the invariant +/// that token_ids on every leaf matches its rendered text. +pub struct ContextState { + system: Vec, + identity: Vec, + journal: Vec, + conversation: Vec, +} + +/// Identifies a section for mutation methods. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Section { + System, + Identity, + Journal, + Conversation, +} + +/// State machine for parsing a streaming assistant response into an AstNode. +/// Feed text chunks as they arrive; completed tool calls are returned for +/// immediate dispatch. 
+pub struct ResponseParser { + buf: String, + content_parts: Vec, + children: Vec, + in_think: bool, + think_buf: String, + in_tool_call: bool, + tool_call_buf: String, } impl Role { pub fn as_str(&self) -> &'static str { match self { - Self::System | Self::SystemSection => "system", - Self::User | Self::IdentitySection | Self::JournalSection => "user", + Self::System => "system", + Self::User => "user", Self::Assistant => "assistant", - Self::Tool => "tool", - Self::Dmn => "dmn", - Self::Memory => "memory", - Self::Log => "log", - Self::Thinking => "thinking", - Self::ToolCall => "tool_call", - Self::Content => "content", - Self::ConversationSection => "conversation", } } - - /// Whether this node contributes tokens to the prompt. - pub fn is_prompt_visible(&self) -> bool { - !matches!(self, Self::Thinking | Self::Log) - } - - /// Whether this is a top-level message (gets im_start/im_end wrapping). - pub fn is_message(&self) -> bool { - matches!(self, Self::System | Self::User | Self::Assistant | - Self::Tool | Self::Dmn | Self::Memory | Self::Log) - } - - /// Whether this is a section (container for messages). - pub fn is_section(&self) -> bool { - matches!(self, Self::SystemSection | Self::IdentitySection | - Self::JournalSection | Self::ConversationSection) - } } -impl AstNode { - /// Create a leaf node (content string, no children). - pub fn leaf(role: Role, content: impl Into) -> Self { - let content = content.into(); - let token_ids = tokenize_leaf(role, &content); - Self { - role, body: NodeBody::Leaf { text: content, token_ids }, - timestamp: None, memory_key: None, memory_score: None, - tool_name: None, tool_args: None, tool_call_id: None, +impl NodeBody { + /// Render this leaf body to text for the prompt. 
+ fn render(&self) -> String { + match self { + Self::Content(text) => text.clone(), + Self::Thinking(_) => String::new(), + Self::Log(_) => String::new(), + Self::ToolCall { name, arguments } => { + let xml = format_tool_call_xml(name, arguments); + format!("\n{}\n\n", xml) + } + Self::ToolResult(text) => + format!("<|im_start|>tool\n{}<|im_end|>\n", text), + Self::Memory { text, .. } => + format!("<|im_start|>memory\n{}<|im_end|>\n", text), + Self::Dmn(text) => + format!("<|im_start|>dmn\n{}<|im_end|>\n", text), } } - /// Create a branch node (children, no direct content). - pub fn branch(role: Role, children: Vec) -> Self { - Self { - role, body: NodeBody::Branch(children), - timestamp: None, memory_key: None, memory_score: None, - tool_name: None, tool_args: None, tool_call_id: None, + /// Whether this leaf contributes tokens to the prompt. + fn is_prompt_visible(&self) -> bool { + !matches!(self, Self::Thinking(_) | Self::Log(_)) + } + + /// The text content of this leaf (for display, not rendering). + pub fn text(&self) -> &str { + match self { + Self::Content(t) | Self::Thinking(t) | Self::Log(t) + | Self::ToolResult(t) | Self::Dmn(t) => t, + Self::ToolCall { name, .. } => name, + Self::Memory { text, .. } => text, } } +} - /// Create a memory node. - pub fn memory(key: impl Into, content: impl Into) -> Self { - let mut node = Self::leaf(Role::Memory, content); - node.memory_key = Some(key.into()); - node - } - - /// Create a tool call node. 
- pub fn tool_call(id: impl Into, name: impl Into, args: impl Into) -> Self { - let name = name.into(); - let args = args.into(); - let id = id.into(); - // Format the XML body for tokenization - let xml = format_tool_call_xml(&name, &args); - let mut node = Self::leaf(Role::ToolCall, xml); - node.tool_name = Some(name); - node.tool_args = Some(args); - node.tool_call_id = Some(id); - node +impl NodeLeaf { + fn new(body: NodeBody) -> Self { + let token_ids = if body.is_prompt_visible() { + tokenizer::encode(&body.render()) + } else { + vec![] + }; + Self { body, token_ids, timestamp: None } } pub fn with_timestamp(mut self, ts: DateTime) -> Self { @@ -151,80 +169,147 @@ impl AstNode { self } - /// Token count — leaf returns cached len, branch sums children recursively. - pub fn tokens(&self) -> usize { - match &self.body { - NodeBody::Leaf { token_ids, .. } => token_ids.len(), - NodeBody::Branch(children) => children.iter().map(|c| c.tokens()).sum(), - } - } - - /// Get token IDs — leaf returns cached, branch walks children. - pub fn token_ids(&self) -> Vec { - match &self.body { - NodeBody::Leaf { token_ids, .. } => token_ids.clone(), - NodeBody::Branch(children) => { - tokenize_branch(self.role, children) - } - } - } - - /// Render this node to text (same output as tokenization, but UTF-8). - pub fn render(&self) -> String { - match &self.body { - NodeBody::Leaf { text, .. } => render_leaf(self.role, text), - NodeBody::Branch(children) => render_branch(self.role, children), - } - } - - /// Get the content string (leaf nodes only). - pub fn content(&self) -> &str { - match &self.body { - NodeBody::Leaf { text, .. } => text, - NodeBody::Branch(_) => "", - } - } - - /// Get children (branch nodes only). - pub fn children(&self) -> &[AstNode] { - match &self.body { - NodeBody::Branch(c) => c, - NodeBody::Leaf { .. 
} => &[], - } - } - - pub fn role(&self) -> Role { self.role } + pub fn body(&self) -> &NodeBody { &self.body } + pub fn token_ids(&self) -> &[u32] { &self.token_ids } + pub fn tokens(&self) -> usize { self.token_ids.len() } pub fn timestamp(&self) -> Option> { self.timestamp } - pub fn memory_key(&self) -> Option<&str> { self.memory_key.as_deref() } - pub fn memory_score(&self) -> Option { self.memory_score } - pub fn tool_name(&self) -> Option<&str> { self.tool_name.as_deref() } - pub fn tool_args(&self) -> Option<&str> { self.tool_args.as_deref() } - pub fn tool_call_id(&self) -> Option<&str> { self.tool_call_id.as_deref() } +} + +impl AstNode { + // -- Leaf constructors ---------------------------------------------------- + + pub fn content(text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::Content(text.into()))) + } + + pub fn thinking(text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::Thinking(text.into()))) + } + + pub fn tool_call(name: impl Into, arguments: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::ToolCall { + name: name.into(), + arguments: arguments.into(), + })) + } + + pub fn tool_result(text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::ToolResult(text.into()))) + } + + pub fn memory(key: impl Into, text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::Memory { + key: key.into(), + text: text.into(), + score: None, + })) + } + + pub fn dmn(text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::Dmn(text.into()))) + } + + pub fn log(text: impl Into) -> Self { + Self::Leaf(NodeLeaf::new(NodeBody::Log(text.into()))) + } + + // -- Branch constructors -------------------------------------------------- + + pub fn branch(role: Role, children: Vec) -> Self { + Self::Branch { role, children } + } + + pub fn system_msg(text: impl Into) -> Self { + Self::Branch { + role: Role::System, + children: vec![Self::content(text)], + } + } + + pub fn user_msg(text: impl Into) -> Self { + 
Self::Branch { + role: Role::User, + children: vec![Self::content(text)], + } + } + + // -- Builder -------------------------------------------------------------- + + pub fn with_timestamp(mut self, ts: DateTime) -> Self { + match &mut self { + Self::Leaf(leaf) => leaf.timestamp = Some(ts), + Self::Branch { .. } => {} + } + self + } + + // -- Accessors ------------------------------------------------------------ + + pub fn tokens(&self) -> usize { + match self { + Self::Leaf(leaf) => leaf.tokens(), + Self::Branch { children, .. } => + children.iter().map(|c| c.tokens()).sum(), + } + } + + pub fn token_ids(&self) -> Vec { + match self { + Self::Leaf(leaf) => leaf.token_ids.clone(), + Self::Branch { role, children } => + tokenize_branch(*role, children), + } + } + + pub fn render(&self) -> String { + match self { + Self::Leaf(leaf) => leaf.body.render(), + Self::Branch { role, children } => + render_branch(*role, children), + } + } + + pub fn children(&self) -> &[AstNode] { + match self { + Self::Branch { children, .. } => children, + Self::Leaf(_) => &[], + } + } + + pub fn leaf(&self) -> Option<&NodeLeaf> { + match self { + Self::Leaf(l) => Some(l), + _ => None, + } + } /// Short label for the UI. 
pub fn label(&self) -> String { let cfg = crate::config::get(); - match self.role { - Role::System => "system".to_string(), - Role::User => format!("{}: {}", cfg.user_name, - truncate_preview(self.content(), 60)), - Role::Assistant => format!("{}: {}", cfg.assistant_name, - truncate_preview(self.content(), 60)), - Role::Tool => "tool_result".to_string(), - Role::Dmn => "dmn".to_string(), - Role::Memory => match (&self.memory_key, self.memory_score) { - (Some(key), Some(s)) => format!("mem: {} score:{:.1}", key, s), - (Some(key), None) => format!("mem: {}", key), - _ => "mem".to_string(), + match self { + Self::Branch { role, children } => { + let preview = children.first() + .and_then(|c| c.leaf()) + .map(|l| truncate_preview(l.body.text(), 60)) + .unwrap_or_default(); + match role { + Role::System => "system".into(), + Role::User => format!("{}: {}", cfg.user_name, preview), + Role::Assistant => format!("{}: {}", cfg.assistant_name, preview), + } + } + Self::Leaf(leaf) => match &leaf.body { + NodeBody::Content(t) => truncate_preview(t, 60), + NodeBody::Thinking(t) => format!("thinking: {}", truncate_preview(t, 60)), + NodeBody::ToolCall { name, .. } => format!("tool_call: {}", name), + NodeBody::ToolResult(_) => "tool_result".into(), + NodeBody::Memory { key, score, .. 
} => match score { + Some(s) => format!("mem: {} score:{:.1}", key, s), + None => format!("mem: {}", key), + }, + NodeBody::Dmn(_) => "dmn".into(), + NodeBody::Log(t) => format!("log: {}", truncate_preview(t, 60)), }, - Role::Thinking => format!("thinking: {}", - truncate_preview(self.content(), 60)), - Role::ToolCall => format!("tool_call: {}", - self.tool_name.as_deref().unwrap_or("?")), - Role::Content => truncate_preview(self.content(), 60), - Role::Log => format!("log: {}", - truncate_preview(self.content(), 60)), - _ => self.role.as_str().to_string(), } } } @@ -235,65 +320,19 @@ fn truncate_preview(s: &str, max: usize) -> String { if s.len() > max { format!("{}...", preview) } else { preview } } -// --------------------------------------------------------------------------- -// Serialization — two modes, same output -// --------------------------------------------------------------------------- - -/// Render a leaf node to text. -fn render_leaf(role: Role, content: &str) -> String { - if !role.is_prompt_visible() { - return String::new(); - } - - if role.is_message() { - format!("<|im_start|>{}\n{}<|im_end|>\n", role.as_str(), content) - } else { - match role { - Role::Thinking => format!("\n{}\n\n", content), - Role::ToolCall => format!("\n{}\n\n", content), - Role::Content => content.to_string(), - _ => content.to_string(), - } - } -} - -/// Render a branch node to text. 
fn render_branch(role: Role, children: &[AstNode]) -> String { - if !role.is_prompt_visible() { - return String::new(); - } - - if role.is_section() { - children.iter().map(|c| c.render()).collect() - } else if role == Role::Assistant { - let mut s = String::from("<|im_start|>assistant\n"); - for child in children { - s.push_str(&child.render()); - } - s.push_str("<|im_end|>\n"); - s - } else { - children.iter().map(|c| c.render()).collect() + let mut s = format!("<|im_start|>{}\n", role.as_str()); + for child in children { + s.push_str(&child.render()); } + s.push_str("<|im_end|>\n"); + s } -/// Tokenize a leaf node. -fn tokenize_leaf(role: Role, content: &str) -> Vec { - if !role.is_prompt_visible() { - return vec![]; - } - tokenizer::encode(&render_leaf(role, content)) -} - -/// Tokenize a branch node from its children. fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec { - if !role.is_prompt_visible() { - return vec![]; - } tokenizer::encode(&render_branch(role, children)) } -/// Format a tool call as the XML body expected by Qwen. 
fn format_tool_call_xml(name: &str, args_json: &str) -> String { let args: serde_json::Value = serde_json::from_str(args_json) .unwrap_or(serde_json::Value::Object(Default::default())); @@ -311,25 +350,69 @@ fn format_tool_call_xml(name: &str, args_json: &str) -> String { xml } -// --------------------------------------------------------------------------- -// Streaming response parser -// --------------------------------------------------------------------------- +fn normalize_xml_tags(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut chars = text.chars().peekable(); + while let Some(ch) = chars.next() { + if ch == '<' { + let mut tag = String::from('<'); + for inner in chars.by_ref() { + if inner == '>' { + tag.push('>'); + break; + } else if inner.is_whitespace() { + // Skip whitespace inside tags + } else { + tag.push(inner); + } + } + result.push_str(&tag); + } else { + result.push(ch); + } + } + result +} -/// State machine for parsing a streaming assistant response into an AstNode. -/// Feed text chunks as they arrive; completed tool calls are returned for -/// immediate dispatch. -pub struct ResponseParser { - /// Buffered text not yet committed to a node - buf: String, - /// Content fragments collected so far (between tags) - content_parts: Vec, - /// Completed child nodes - children: Vec, - /// Parse state - in_think: bool, - think_buf: String, - in_tool_call: bool, - tool_call_buf: String, +fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> { + let open = format!("<{}=", tag); + let close = format!("", tag); + + let start = s.find(&open)? 
+ open.len(); + let name_end = start + s[start..].find('>')?; + let body_start = name_end + 1; + let body_end = body_start + s[body_start..].find(&close)?; + + Some(( + s[start..name_end].trim(), + s[body_start..body_end].trim(), + &s[body_end + close.len()..], + )) +} + +fn parse_tool_call_body(body: &str) -> Option<(String, String)> { + let normalized = normalize_xml_tags(body); + let body = normalized.trim(); + parse_xml_tool_call(body) + .or_else(|| parse_json_tool_call(body)) +} + +fn parse_xml_tool_call(body: &str) -> Option<(String, String)> { + let (func_name, func_body, _) = parse_qwen_tag(body, "function")?; + let mut args = serde_json::Map::new(); + let mut rest = func_body; + while let Some((key, val, remainder)) = parse_qwen_tag(rest, "parameter") { + args.insert(key.to_string(), serde_json::Value::String(val.to_string())); + rest = remainder; + } + Some((func_name.to_string(), serde_json::to_string(&args).unwrap_or_default())) +} + +fn parse_json_tool_call(body: &str) -> Option<(String, String)> { + let v: serde_json::Value = serde_json::from_str(body).ok()?; + let name = v["name"].as_str()?; + let arguments = &v["arguments"]; + Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default())) } impl ResponseParser { @@ -358,13 +441,18 @@ impl ResponseParser { self.think_buf.push_str(&self.buf[..end]); self.buf = self.buf[end + 8..].to_string(); self.in_think = false; - self.children.push(AstNode::leaf(Role::Thinking, &self.think_buf)); + self.children.push(AstNode::thinking(&self.think_buf)); self.think_buf.clear(); continue; } None => { - self.think_buf.push_str(&self.buf); - self.buf.clear(); + // Keep last 8 chars ("".len()) as lookahead + let safe = self.buf.len().saturating_sub(8); + if safe > 0 { + let safe = self.buf.floor_char_boundary(safe); + self.think_buf.push_str(&self.buf[..safe]); + self.buf = self.buf[safe..].to_string(); + } break; } } @@ -376,12 +464,8 @@ impl ResponseParser { 
self.tool_call_buf.push_str(&self.buf[..end]); self.buf = self.buf[end + 12..].to_string(); self.in_tool_call = false; - if let Some(call) = super::api::parsing::parse_tool_call_body(&self.tool_call_buf) { - let node = AstNode::tool_call( - call.id.clone(), - call.function.name.clone(), - call.function.arguments.clone(), - ); + if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { + let node = AstNode::tool_call(name, args); new_calls.push(node.clone()); self.flush_content(); self.children.push(node); @@ -390,14 +474,18 @@ impl ResponseParser { continue; } None => { - self.tool_call_buf.push_str(&self.buf); - self.buf.clear(); + // Keep last 12 chars ("".len()) as lookahead + let safe = self.buf.len().saturating_sub(12); + if safe > 0 { + let safe = self.buf.floor_char_boundary(safe); + self.tool_call_buf.push_str(&self.buf[..safe]); + self.buf = self.buf[safe..].to_string(); + } break; } } } - // Look for tag openings let think_pos = self.buf.find(""); let tool_pos = self.buf.find(""); let next_tag = match (think_pos, tool_pos) { @@ -409,7 +497,6 @@ impl ResponseParser { match next_tag { Some(pos) => { - // Content before the tag if pos > 0 { self.content_parts.push(self.buf[..pos].to_string()); } @@ -425,8 +512,6 @@ impl ResponseParser { continue; } None => { - // No complete tag. Keep last 11 chars as lookahead - // (length of "") to handle partial tags. let safe = self.buf.len().saturating_sub(11); if safe > 0 { let safe = self.buf.floor_char_boundary(safe); @@ -441,151 +526,134 @@ impl ResponseParser { new_calls } - /// Flush accumulated content into a Content child node. fn flush_content(&mut self) { if !self.content_parts.is_empty() { let text: String = self.content_parts.drain(..).collect(); if !text.is_empty() { - self.children.push(AstNode::leaf(Role::Content, text)); + self.children.push(AstNode::content(text)); } } } /// Finalize the parse. Returns the completed assistant AstNode. 
pub fn finish(mut self) -> AstNode { - // Remaining buffer is content if !self.buf.is_empty() { self.content_parts.push(std::mem::take(&mut self.buf)); } self.flush_content(); - AstNode::branch(Role::Assistant, self.children) } /// Get the current display text (for streaming to UI). - /// Returns content accumulated since the last call. pub fn display_content(&self) -> String { self.content_parts.join("") } } -// --------------------------------------------------------------------------- -// ContextState — the full context window -// --------------------------------------------------------------------------- - -/// The context window: four sections, each a branch AstNode. -/// All mutation goes through ContextState methods to maintain the invariant -/// that token_ids on every leaf matches its rendered text. -pub struct ContextState { - system: AstNode, - identity: AstNode, - journal: AstNode, - conversation: AstNode, -} - impl ContextState { pub fn new() -> Self { Self { - system: AstNode::branch(Role::SystemSection, vec![]), - identity: AstNode::branch(Role::IdentitySection, vec![]), - journal: AstNode::branch(Role::JournalSection, vec![]), - conversation: AstNode::branch(Role::ConversationSection, vec![]), + system: Vec::new(), + identity: Vec::new(), + journal: Vec::new(), + conversation: Vec::new(), } } // -- Read access ---------------------------------------------------------- - pub fn system(&self) -> &[AstNode] { self.system.children() } - pub fn identity(&self) -> &[AstNode] { self.identity.children() } - pub fn journal(&self) -> &[AstNode] { self.journal.children() } - pub fn conversation(&self) -> &[AstNode] { self.conversation.children() } + pub fn system(&self) -> &[AstNode] { &self.system } + pub fn identity(&self) -> &[AstNode] { &self.identity } + pub fn journal(&self) -> &[AstNode] { &self.journal } + pub fn conversation(&self) -> &[AstNode] { &self.conversation } pub fn tokens(&self) -> usize { - self.system.tokens() - + self.identity.tokens() - 
+ self.journal.tokens() - + self.conversation.tokens() + self.system.iter().map(|n| n.tokens()).sum::() + + self.identity.iter().map(|n| n.tokens()).sum::() + + self.journal.iter().map(|n| n.tokens()).sum::() + + self.conversation.iter().map(|n| n.tokens()).sum::() } pub fn token_ids(&self) -> Vec { - let mut ids = self.system.token_ids(); - ids.extend(self.identity.token_ids()); - ids.extend(self.journal.token_ids()); - ids.extend(self.conversation.token_ids()); + let mut ids = Vec::new(); + for section in [&self.system, &self.identity, &self.journal, &self.conversation] { + for node in section { + ids.extend(node.token_ids()); + } + } ids } pub fn render(&self) -> String { - let mut s = self.system.render(); - s.push_str(&self.identity.render()); - s.push_str(&self.journal.render()); - s.push_str(&self.conversation.render()); + let mut s = String::new(); + for section in [&self.system, &self.identity, &self.journal, &self.conversation] { + for node in section { + s.push_str(&node.render()); + } + } s } // -- Mutation -------------------------------------------------------------- - fn section_mut(&mut self, role: Role) -> &mut AstNode { - match role { - Role::SystemSection => &mut self.system, - Role::IdentitySection => &mut self.identity, - Role::JournalSection => &mut self.journal, - Role::ConversationSection => &mut self.conversation, - _ => panic!("not a section role: {:?}", role), + fn section_mut(&mut self, section: Section) -> &mut Vec { + match section { + Section::System => &mut self.system, + Section::Identity => &mut self.identity, + Section::Journal => &mut self.journal, + Section::Conversation => &mut self.conversation, } } - fn children_mut(section: &mut AstNode) -> &mut Vec { - match &mut section.body { - NodeBody::Branch(c) => c, - _ => unreachable!("section is always a branch"), + pub fn push(&mut self, section: Section, node: AstNode) { + self.section_mut(section).push(node); + } + + /// Replace the body of a leaf at `index` in `section`. 
+ /// Re-tokenizes to maintain the invariant. + pub fn set_message(&mut self, section: Section, index: usize, body: NodeBody) { + let nodes = self.section_mut(section); + let node = &mut nodes[index]; + match node { + AstNode::Leaf(leaf) => { + let token_ids = if body.is_prompt_visible() { + tokenizer::encode(&body.render()) + } else { + vec![] + }; + leaf.body = body; + leaf.token_ids = token_ids; + } + AstNode::Branch { .. } => panic!("set_message on branch node"), } } - /// Push a node into a section. - pub fn push(&mut self, section: Role, node: AstNode) { - let s = self.section_mut(section); - Self::children_mut(s).push(node); - } - - /// Replace the text content of a leaf at `index` in `section`. - /// Re-tokenizes the leaf to maintain the invariant. - pub fn set_message(&mut self, section: Role, index: usize, text: impl Into) { - let s = self.section_mut(section); - let node = &mut Self::children_mut(s)[index]; - let text = text.into(); - let token_ids = tokenize_leaf(node.role, &text); - node.body = NodeBody::Leaf { text, token_ids }; - } - - /// Set the memory score on a node at `index` in `section`. - pub fn set_score(&mut self, section: Role, index: usize, score: Option) { - let s = self.section_mut(section); - Self::children_mut(s)[index].memory_score = score; + /// Set the memory score on a Memory leaf at `index` in `section`. + pub fn set_score(&mut self, section: Section, index: usize, score: Option) { + let node = &mut self.section_mut(section)[index]; + match node { + AstNode::Leaf(leaf) => match &mut leaf.body { + NodeBody::Memory { score: s, .. } => *s = score, + _ => panic!("set_score on non-memory node"), + }, + _ => panic!("set_score on branch node"), + } } /// Remove a node at `index` from `section`. 
- pub fn del(&mut self, section: Role, index: usize) -> AstNode { - let s = self.section_mut(section); - Self::children_mut(s).remove(index) + pub fn del(&mut self, section: Section, index: usize) -> AstNode { + self.section_mut(section).remove(index) } } -// --------------------------------------------------------------------------- -// Context window size -// --------------------------------------------------------------------------- - -/// Context window size in tokens (from config). pub fn context_window() -> usize { crate::config::get().api_context_window } -/// Context budget in tokens: 80% of the model's context window. pub fn context_budget_tokens() -> usize { context_window() * 80 / 100 } -/// Detect context window overflow errors from the API. pub fn is_context_overflow(err: &anyhow::Error) -> bool { let msg = err.to_string().to_lowercase(); msg.contains("context length") @@ -599,7 +667,345 @@ pub fn is_context_overflow(err: &anyhow::Error) -> bool { || (msg.contains("400") && msg.contains("tokens")) } -/// Detect model/provider errors delivered inside the SSE stream. pub fn is_stream_error(err: &anyhow::Error) -> bool { err.to_string().contains("model stream error") } + +#[cfg(test)] +mod tests { + use super::*; + + // -- Helpers for inspecting parse results ---------------------------------- + + /// Extract child bodies from an Assistant branch node. + fn child_bodies(node: &AstNode) -> Vec<&NodeBody> { + match node { + AstNode::Branch { children, .. 
} => + children.iter().filter_map(|c| c.leaf()).map(|l| l.body()).collect(), + _ => panic!("expected branch"), + } + } + + fn assert_content(body: &NodeBody, expected: &str) { + match body { + NodeBody::Content(t) => assert_eq!(t, expected), + other => panic!("expected Content, got {:?}", other), + } + } + + fn assert_thinking(body: &NodeBody, expected: &str) { + match body { + NodeBody::Thinking(t) => assert_eq!(t, expected), + other => panic!("expected Thinking, got {:?}", other), + } + } + + fn assert_tool_call<'a>(body: &'a NodeBody, expected_name: &str) -> &'a str { + match body { + NodeBody::ToolCall { name, arguments } => { + assert_eq!(name, expected_name); + arguments + } + other => panic!("expected ToolCall, got {:?}", other), + } + } + + // -- XML parsing tests ---------------------------------------------------- + + #[test] + fn test_tool_call_xml_parse_clean() { + let body = "\npoc-memory used core-personality\n"; + let (name, args) = parse_tool_call_body(body).unwrap(); + assert_eq!(name, "bash"); + let args: serde_json::Value = serde_json::from_str(&args).unwrap(); + assert_eq!(args["command"], "poc-memory used core-personality"); + } + + #[test] + fn test_tool_call_xml_parse_streamed_whitespace() { + let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd\n"; + let (name, args) = parse_tool_call_body(body).unwrap(); + assert_eq!(name, "bash"); + let args: serde_json::Value = serde_json::from_str(&args).unwrap(); + assert_eq!(args["command"], "pwd"); + } + + #[test] + fn test_tool_call_json_parse() { + let body = r#"{"name": "bash", "arguments": {"command": "ls"}}"#; + let (name, args) = parse_tool_call_body(body).unwrap(); + assert_eq!(name, "bash"); + let args: serde_json::Value = serde_json::from_str(&args).unwrap(); + assert_eq!(args["command"], "ls"); + } + + #[test] + fn test_normalize_preserves_content() { + let text = "\necho hello world\n"; + let normalized = normalize_xml_tags(text); + assert_eq!(normalized, text); + } + + 
#[test]
+ fn test_normalize_strips_tag_internal_whitespace() {
+ assert_eq!(normalize_xml_tags("<\nfunction\n=\nbash\n>"), "<function=bash>");
+ }
+
+ // -- ResponseParser tests -------------------------------------------------
+
+ #[test]
+ fn test_parser_plain_text() {
+ let mut p = ResponseParser::new();
+ p.feed("hello world");
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 1);
+ assert_content(bodies[0], "hello world");
+ }
+
+ #[test]
+ fn test_parser_thinking_then_content() {
+ let mut p = ResponseParser::new();
+ p.feed("<think>reasoning</think>answer");
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 2);
+ assert_thinking(bodies[0], "reasoning");
+ assert_content(bodies[1], "answer");
+ }
+
+ #[test]
+ fn test_parser_tool_call() {
+ let mut p = ResponseParser::new();
+ let calls = p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
+ assert_eq!(calls.len(), 1); // returned for immediate dispatch
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 1);
+ let args = assert_tool_call(bodies[0], "bash");
+ let args: serde_json::Value = serde_json::from_str(args).unwrap();
+ assert_eq!(args["command"], "ls");
+ }
+
+ #[test]
+ fn test_parser_content_then_tool_call_then_content() {
+ let mut p = ResponseParser::new();
+ p.feed("before");
+ p.feed("<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>");
+ p.feed("after");
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 3);
+ assert_content(bodies[0], "before");
+ assert_tool_call(bodies[1], "bash");
+ assert_content(bodies[2], "after");
+ }
+
+ #[test]
+ fn test_parser_incremental_feed() {
+ // Feed the response one character at a time
+ let text = "<think>thought</think>response";
+ let mut p = ResponseParser::new();
+ for ch in text.chars() {
+ p.feed(&ch.to_string());
+ }
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 2);
+ assert_thinking(bodies[0], "thought");
+ assert_content(bodies[1], "response");
+ }
+
+ #[test]
+ fn 
test_parser_incremental_tool_call() {
+ let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
+ let mut p = ResponseParser::new();
+ let mut total_calls = 0;
+ for ch in text.chars() {
+ total_calls += p.feed(&ch.to_string()).len();
+ }
+ assert_eq!(total_calls, 1); // exactly one tool call dispatched
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 3);
+ assert_content(bodies[0], "text");
+ assert_tool_call(bodies[1], "bash");
+ assert_content(bodies[2], "more");
+ }
+
+ #[test]
+ fn test_parser_thinking_tool_call_content() {
+ let mut p = ResponseParser::new();
+ p.feed("<think>let me think</think>");
+ p.feed("<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>");
+ p.feed("here's what I found");
+ let node = p.finish();
+ let bodies = child_bodies(&node);
+ assert_eq!(bodies.len(), 3);
+ assert_thinking(bodies[0], "let me think");
+ assert_tool_call(bodies[1], "read");
+ assert_content(bodies[2], "here's what I found");
+ }
+
+ #[test]
+ fn test_parser_finish_produces_assistant_branch() {
+ let mut p = ResponseParser::new();
+ p.feed("hello");
+ let node = p.finish();
+ match &node {
+ AstNode::Branch { role, .. 
} => assert_eq!(*role, Role::Assistant),
+ _ => panic!("expected branch"),
+ }
+ }
+
+ // -- Round-trip rendering tests -------------------------------------------
+
+ #[test]
+ fn test_render_system_msg() {
+ let node = AstNode::system_msg("you are helpful");
+ assert_eq!(node.render(), "<|im_start|>system\nyou are helpful<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_user_msg() {
+ let node = AstNode::user_msg("hello");
+ assert_eq!(node.render(), "<|im_start|>user\nhello<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_assistant_with_thinking_and_content() {
+ let node = AstNode::branch(Role::Assistant, vec![
+ AstNode::thinking("hmm"),
+ AstNode::content("answer"),
+ ]);
+ // Thinking renders as empty, content renders as-is
+ assert_eq!(node.render(), "<|im_start|>assistant\nanswer<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_tool_result() {
+ let node = AstNode::tool_result("output here");
+ assert_eq!(node.render(), "<|im_start|>tool\noutput here<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_memory() {
+ let node = AstNode::memory("identity", "I am Proof of Concept");
+ assert_eq!(node.render(), "<|im_start|>memory\nI am Proof of Concept<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_dmn() {
+ let node = AstNode::dmn("subconscious prompt");
+ assert_eq!(node.render(), "<|im_start|>dmn\nsubconscious prompt<|im_end|>\n");
+ }
+
+ #[test]
+ fn test_render_tool_call() {
+ let node = AstNode::tool_call("bash", r#"{"command":"ls"}"#);
+ let rendered = node.render();
+ assert!(rendered.contains("<tool_call>"));
+ assert!(rendered.contains("<function=bash>"));
+ assert!(rendered.contains("<parameter=command>"));
+ assert!(rendered.contains("ls"));
+ assert!(rendered.contains("</tool_call>"));
+ }
+
+ // -- Tokenizer round-trip tests -------------------------------------------
+ // These require the tokenizer file; skipped if not present. 
+ + fn init_tokenizer() -> bool { + let path = format!("{}/.consciousness/tokenizer-qwen35.json", + std::env::var("HOME").unwrap_or_default()); + if std::path::Path::new(&path).exists() { + tokenizer::init(&path); + true + } else { + false + } + } + + /// token_ids() must equal encode(render()) for all node types + fn assert_token_roundtrip(node: &AstNode) { + let rendered = node.render(); + let expected = tokenizer::encode(&rendered); + let actual = node.token_ids(); + assert_eq!(actual, expected, + "token_ids mismatch for rendered: {:?}", rendered); + } + + #[test] + fn test_tokenize_roundtrip_leaf_types() { + if !init_tokenizer() { return; } + + assert_token_roundtrip(&AstNode::system_msg("you are a helpful assistant")); + assert_token_roundtrip(&AstNode::user_msg("what is 2+2?")); + assert_token_roundtrip(&AstNode::tool_result("4")); + assert_token_roundtrip(&AstNode::memory("identity", "I am Proof of Concept")); + assert_token_roundtrip(&AstNode::dmn("check the memory store")); + assert_token_roundtrip(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#)); + } + + #[test] + fn test_tokenize_roundtrip_assistant_branch() { + if !init_tokenizer() { return; } + + let node = AstNode::branch(Role::Assistant, vec![ + AstNode::content("here's what I found:\n"), + AstNode::tool_call("bash", r#"{"command":"pwd"}"#), + AstNode::content("\nthat's the current directory"), + ]); + assert_token_roundtrip(&node); + } + + #[test] + fn test_tokenize_invisible_nodes_are_zero() { + if !init_tokenizer() { return; } + + assert_eq!(AstNode::thinking("deep thoughts").tokens(), 0); + assert_eq!(AstNode::log("debug info").tokens(), 0); + } + + #[test] + fn test_tokenize_decode_roundtrip() { + if !init_tokenizer() { return; } + + // Content without special tokens round-trips through decode + let text = "hello world, this is a test"; + let ids = tokenizer::encode(text); + let decoded = tokenizer::decode(&ids); + assert_eq!(decoded, text); + } + + #[test] + fn 
test_tokenize_context_state_matches_concatenation() {
+ if !init_tokenizer() { return; }
+
+ let mut ctx = ContextState::new();
+ ctx.push(Section::System, AstNode::system_msg("you are helpful"));
+ ctx.push(Section::Identity, AstNode::memory("name", "Proof of Concept"));
+ ctx.push(Section::Conversation, AstNode::user_msg("hi"));
+
+ let rendered = ctx.render();
+ let expected = tokenizer::encode(&rendered);
+ let actual = ctx.token_ids();
+ assert_eq!(actual, expected);
+ }
+
+ #[test]
+ fn test_parser_roundtrip_through_tokenizer() {
+ if !init_tokenizer() { return; }
+
+ // Parse a response, render it, verify it matches the expected format
+ let mut p = ResponseParser::new();
+ p.feed("I'll check that for you");
+ p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
+ let node = p.finish();
+
+ // The assistant branch should tokenize to the same as encoding its render
+ assert_token_roundtrip(&node);
+
+ // Token count should be nonzero (thinking is invisible but content + tool call are)
+ assert!(node.tokens() > 0);
+ }
+}
diff --git a/src/agent/tokenizer.rs b/src/agent/tokenizer.rs
index 5c7108a..cefd492 100644
--- a/src/agent/tokenizer.rs
+++ b/src/agent/tokenizer.rs
@@ -25,17 +25,21 @@ pub fn init(path: &str) {
TOKENIZER.set(t).ok();
}
-/// Get the global tokenizer. Panics if not initialized.
-fn get() -> &'static Tokenizer {
- TOKENIZER.get().expect("tokenizer not initialized — call tokenizer::init() first")
+/// Get the global tokenizer. Returns None if not initialized.
+fn get() -> Option<&'static Tokenizer> {
+ TOKENIZER.get()
}
/// Tokenize a raw string, returning token IDs.
+/// Returns empty vec if the tokenizer is not initialized. 
pub fn encode(text: &str) -> Vec<u32> {
- get().encode(text, false)
- .unwrap_or_else(|e| panic!("tokenization failed: {}", e))
- .get_ids()
- .to_vec()
+ match get() {
+ Some(t) => t.encode(text, false)
+ .unwrap_or_else(|e| panic!("tokenization failed: {}", e))
+ .get_ids()
+ .to_vec(),
+ None => vec![],
+ }
}
/// Tokenize a chat entry with template wrapping:
@@ -59,8 +63,11 @@ pub fn count(text: &str) -> usize {
/// Decode token IDs back to text.
pub fn decode(ids: &[u32]) -> String {
- get().decode(ids, true)
- .unwrap_or_else(|e| panic!("detokenization failed: {}", e))
+ match get() {
+ Some(t) => t.decode(ids, true)
+ .unwrap_or_else(|e| panic!("detokenization failed: {}", e)),
+ None => String::new(),
+ }
}
/// Check if the tokenizer is initialized.